use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
creation::{rand, randn},
Tensor,
};
pub fn encode_architecture(
operations: &[usize],
connections: &Tensor,
num_ops: usize,
) -> TorshResult<Tensor> {
let num_layers = operations.len();
let mut op_encoding = Vec::with_capacity(num_layers * num_ops);
for &op_idx in operations {
let mut one_hot = vec![0.0f32; num_ops];
if op_idx < num_ops {
one_hot[op_idx] = 1.0;
}
op_encoding.extend(one_hot);
}
let op_tensor = Tensor::from_data(
op_encoding,
vec![num_layers, num_ops],
torsh_core::device::DeviceType::Cpu,
)?;
let connections_flat = connections.view(&[-1])?;
let op_view = op_tensor.view(&[-1])?;
let encoding = Tensor::cat(&[&op_view, &connections_flat], 0)?;
Ok(encoding)
}
/// Decodes a flat architecture encoding back into per-layer operation
/// indices and a `[num_layers, num_layers]` connection matrix.
///
/// Inverse of `encode_architecture`: the first `num_layers * num_ops`
/// entries are interpreted as one-hot rows (recovered via argmax), the
/// remainder as the flattened connection matrix.
///
/// # Errors
/// Returns an error if the encoding length does not match the expected
/// `num_layers * num_ops + num_layers * num_layers` size, or if reading the
/// tensor data / building the connections tensor fails.
pub fn decode_architecture(
    encoding: &Tensor,
    num_layers: usize,
    num_ops: usize,
) -> TorshResult<(Vec<usize>, Tensor)> {
    let data = encoding.data()?;
    let ops_size = num_layers * num_ops;
    let conn_size = num_layers * num_layers;
    if data.len() != ops_size + conn_size {
        return Err(TorshError::invalid_argument_with_context(
            "Encoding size doesn't match expected dimensions",
            "decode_architecture",
        ));
    }
    // Argmax each one-hot row to recover the operation index; an empty row
    // (num_ops == 0) falls back to 0, and NaN comparisons are treated as
    // equal so they cannot panic.
    let operations: Vec<usize> = (0..num_layers)
        .map(|layer| {
            let row = &data[layer * num_ops..(layer + 1) * num_ops];
            row.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0)
        })
        .collect();
    // Everything after the operation block is the flattened connections.
    let connections = Tensor::from_data(
        data[ops_size..].to_vec(),
        vec![num_layers, num_layers],
        torsh_core::device::DeviceType::Cpu,
    )?;
    Ok((operations, connections))
}
/// Computes a DARTS mixed operation: the softmax-weighted sum of candidate
/// operation outputs.
///
/// `alpha` holds the raw architecture logits (softmaxed here along dim 0);
/// `operations` holds each candidate's precomputed output. `_x` is accepted
/// for interface compatibility but unused — the candidate outputs are
/// already evaluated.
///
/// # Errors
/// Returns an error if `operations` is empty, or if `alpha` has fewer
/// entries than there are operations (previously missing weights silently
/// defaulted to 0.0, dropping those operations from the mixture).
pub fn darts_operation(_x: &Tensor, alpha: &Tensor, operations: &[Tensor]) -> TorshResult<Tensor> {
    if operations.is_empty() {
        return Err(TorshError::invalid_argument_with_context(
            "Operations list cannot be empty",
            "darts_operation",
        ));
    }
    // Normalize the architecture logits into mixture weights.
    let alpha_softmax = alpha.softmax(0)?;
    let alpha_data = alpha_softmax.data()?;
    if alpha_data.len() < operations.len() {
        return Err(TorshError::invalid_argument_with_context(
            "Alpha must provide one weight per operation",
            "darts_operation",
        ));
    }
    // result = sum_i softmax(alpha)[i] * op_i
    let mut result = operations[0].mul_scalar(alpha_data[0])?;
    for (op_output, &weight) in operations.iter().zip(alpha_data.iter()).skip(1) {
        result = result.add(&op_output.mul_scalar(weight)?)?;
    }
    Ok(result)
}
/// Scores an architecture encoding with a linear predictor followed by a
/// sigmoid, producing values in (0, 1).
///
/// # Errors
/// Propagates any failure from the matrix multiply or the sigmoid.
pub fn predict_architecture_performance(
    encoding: &Tensor,
    predictor_weights: &Tensor,
) -> TorshResult<Tensor> {
    // Linear score, squashed to (0, 1).
    predictor_weights.matmul(encoding)?.sigmoid()
}
/// Randomly mutates an architecture for evolutionary search.
///
/// Each operation is independently replaced with a uniformly random one with
/// probability `mutation_rate`. The connection matrix is perturbed with
/// Gaussian noise scaled by `mutation_rate` and squashed back into (0, 1)
/// via sigmoid.
///
/// # Errors
/// Propagates any failure from random tensor generation or the tensor ops.
pub fn mutate_architecture(
    operations: &[usize],
    connections: &Tensor,
    mutation_rate: f32,
    num_ops: usize,
) -> TorshResult<(Vec<usize>, Tensor)> {
    let mut mutated_ops = operations.to_vec();
    for op in &mut mutated_ops {
        // Draw a uniform sample per layer to decide whether to mutate it.
        let mutate_data = rand(&[1])?.data()?;
        if *mutate_data.first().unwrap_or(&1.0) < mutation_rate {
            let new_op_data = rand(&[1])?.data()?;
            // Map a uniform sample in [0, 1) to an op index; the trailing
            // `%` guards the boundary case where the product rounds up to
            // exactly num_ops.
            *op = (*new_op_data.first().unwrap_or(&0.5) * num_ops as f32) as usize % num_ops;
        }
    }
    // Perturb connections with scaled Gaussian noise, then squash to (0, 1).
    // (The previous version also made an unused full clone of `connections`
    // here; that dead allocation is removed.)
    let noise = randn(connections.shape().dims())?;
    let noisy_connections = connections.add(&noise.mul_scalar(mutation_rate)?)?;
    Ok((mutated_ops, noisy_connections.sigmoid()?))
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::{ones, randn};

    /// Encoding then decoding must round-trip the operations and preserve
    /// the connection-matrix shape.
    #[test]
    fn test_encode_decode_architecture() -> TorshResult<()> {
        let operations = vec![0, 1, 2, 1];
        let connections = ones(&[4, 4])?;
        let num_ops = 3;

        let encoding = encode_architecture(&operations, &connections, num_ops)?;
        let (decoded_ops, decoded_connections) = decode_architecture(&encoding, 4, num_ops)?;

        assert_eq!(operations, decoded_ops);
        assert_eq!(
            connections.shape().dims(),
            decoded_connections.shape().dims()
        );
        Ok(())
    }

    /// The DARTS mixture of candidate outputs keeps the input's shape.
    #[test]
    fn test_darts_operation() -> TorshResult<()> {
        let x = randn(&[2, 4])?;
        let alpha = randn(&[3])?;
        let candidates = vec![x.mul_scalar(1.0)?, x.mul_scalar(2.0)?, x.mul_scalar(0.5)?];

        let result = darts_operation(&x, &alpha, &candidates)?;
        assert_eq!(x.shape().dims(), result.shape().dims());
        Ok(())
    }

    /// Mutation preserves the layer count, the connections shape, and keeps
    /// every operation index within range.
    #[test]
    fn test_architecture_mutation() -> TorshResult<()> {
        let operations = vec![0, 1, 2];
        let connections = ones(&[3, 3])?;
        let num_ops = 4;

        let (mutated_ops, mutated_connections) =
            mutate_architecture(&operations, &connections, 0.5, num_ops)?;

        assert_eq!(operations.len(), mutated_ops.len());
        assert_eq!(
            connections.shape().dims(),
            mutated_connections.shape().dims()
        );
        assert!(mutated_ops.iter().all(|&op| op < num_ops));
        Ok(())
    }
}