use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Which branch of [`cond`] was actually executed in the forward pass.
///
/// Recorded on `CondBackward` so the backward node knows which branch
/// produced the outputs it tracks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CondBranch {
    /// `pred` was strictly above the 0.5 threshold; `true_fn` ran.
    True,
    /// `pred` was at or below the 0.5 threshold; `false_fn` ran.
    False,
}
/// Gradient node attached to one output of [`cond`].
///
/// Its `backward` currently forwards the incoming gradient unchanged to the
/// executed branch's output; `branch` and `operands` are retained for the
/// graph but not yet consumed by the backward pass (hence the allows).
#[derive(Debug)]
struct CondBackward<T: Float> {
    // Which branch actually ran in the forward pass.
    #[allow(dead_code)]
    branch: CondBranch,
    // All outputs produced by the executed branch.
    branch_outputs: Vec<Tensor<T>>,
    // The operand tensors the branch closure was called with.
    #[allow(dead_code)]
    operands: Vec<Tensor<T>>,
    // Index into `branch_outputs` that this node corresponds to.
    output_index: usize,
}
impl<T: Float> GradFn<T> for CondBackward<T> {
    /// Identity gradient: the incoming gradient is forwarded unchanged for
    /// the branch output this node was created for.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Touch the tracked branch output (index is valid by construction in `cond`).
        let _branch_out = &self.branch_outputs[self.output_index];
        let passthrough = grad_output.clone();
        Ok(vec![Some(passthrough)])
    }

    /// The single upstream tensor: the executed branch's output at our index.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let tracked = &self.branch_outputs[self.output_index];
        vec![tracked]
    }

    fn name(&self) -> &'static str {
        "CondBackward"
    }
}
/// Conditionally runs `true_fn` or `false_fn` on `operands`, selected by the
/// scalar tensor `pred`: a value strictly above 0.5 picks the true branch.
///
/// Only the selected closure is executed. When gradient mode is enabled and
/// any operand requires grad, every returned tensor is rebuilt through
/// `Tensor::from_operation` with a [`CondBackward`] node so it participates
/// in the autograd graph.
///
/// # Errors
/// Returns `InvalidArgument` when `pred` does not hold exactly one element,
/// and propagates any error from reading tensor data or wrapping outputs.
pub fn cond<T, TF, FF>(
    pred: &Tensor<T>,
    true_fn: TF,
    false_fn: FF,
    operands: &[Tensor<T>],
) -> FerrotorchResult<Vec<Tensor<T>>>
where
    T: Float,
    TF: FnOnce(&[Tensor<T>]) -> Vec<Tensor<T>>,
    FF: FnOnce(&[Tensor<T>]) -> Vec<Tensor<T>>,
{
    if pred.numel() != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cond: pred must be a scalar tensor (1 element), got shape {:?} ({} elements)",
                pred.shape(),
                pred.numel()
            ),
        });
    }
    // Read the single predicate value and compare against the 0.5 threshold.
    let take_true = pred.data()?[0] > T::from(0.5).unwrap();
    let (branch, branch_outputs) = if take_true {
        (CondBranch::True, true_fn(operands))
    } else {
        (CondBranch::False, false_fn(operands))
    };
    // Without grad tracking the raw branch outputs are returned as-is.
    if !(is_grad_enabled() && operands.iter().any(|op| op.requires_grad())) {
        return Ok(branch_outputs);
    }
    let saved_operands: Vec<Tensor<T>> = operands.to_vec();
    let mut wrapped = Vec::with_capacity(branch_outputs.len());
    for (idx, out) in branch_outputs.iter().enumerate() {
        let data = out.data_vec()?;
        let shape = out.shape().to_vec();
        // Each output gets its own grad node referencing the whole branch.
        let grad_fn = Arc::new(CondBackward {
            branch,
            branch_outputs: branch_outputs.clone(),
            operands: saved_operands.clone(),
            output_index: idx,
        });
        wrapped.push(Tensor::from_operation(TensorStorage::cpu(data), shape, grad_fn)?);
    }
    Ok(wrapped)
}
/// Checks that the two `cond` branches agree on output arity and shapes.
///
/// # Errors
/// `InvalidArgument` when the branches return different numbers of tensors;
/// `ShapeMismatch` when any pair of corresponding outputs differs in shape.
pub fn validate_cond_branches<T: Float>(
    true_outputs: &[Tensor<T>],
    false_outputs: &[Tensor<T>],
) -> FerrotorchResult<()> {
    let (n_true, n_false) = (true_outputs.len(), false_outputs.len());
    if n_true != n_false {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cond: true branch returns {} tensors but false branch returns {}",
                n_true, n_false
            ),
        });
    }
    for (i, (t_out, f_out)) in true_outputs.iter().zip(false_outputs).enumerate() {
        if t_out.shape() != f_out.shape() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "cond: output[{i}] shape mismatch: true branch {:?} vs false branch {:?}",
                    t_out.shape(),
                    f_out.shape()
                ),
            });
        }
    }
    Ok(())
}
/// Gradient node attached to one output of [`scan`] (either the final carry
/// or a single per-step output, see `output_index`).
#[derive(Debug)]
struct ScanBackward<T: Float> {
    // Carry values: element 0 is the initial carry, then one per step.
    carries: Vec<Tensor<T>>,
    // The per-step input tensors (retained, not yet used in backward).
    #[allow(dead_code)]
    xs: Vec<Tensor<T>>,
    // Per-step outputs from the forward pass.
    outputs: Vec<Tensor<T>>,
    // Which scan output this node backs.
    output_index: OutputKind,
}
/// Identifies which of `scan`'s outputs a `ScanBackward` node belongs to.
#[derive(Debug, Clone, Copy)]
enum OutputKind {
    /// The carry returned after the last step.
    FinalCarry,
    /// The per-step output at the given step index.
    StepOutput(usize),
}
impl<T: Float> GradFn<T> for ScanBackward<T> {
    /// Identity gradient for whichever scan output this node backs.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grads = match self.output_index {
            OutputKind::FinalCarry => {
                // Touch the last carry (always present: `scan` pushes the
                // initial carry before any steps run).
                let _last_carry = self.carries.last().unwrap();
                vec![Some(grad_output.clone())]
            }
            OutputKind::StepOutput(_) => vec![Some(grad_output.clone())],
        };
        Ok(grads)
    }

    /// The single upstream tensor for this node.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let source = match self.output_index {
            OutputKind::FinalCarry => self.carries.last().unwrap(),
            OutputKind::StepOutput(i) => &self.outputs[i],
        };
        vec![source]
    }

    fn name(&self) -> &'static str {
        "ScanBackward"
    }
}
/// Structured loop primitive: folds `fn_step` over `xs`, threading a carry.
///
/// `fn_step(carry, x)` returns `(new_carry, output)`. The function returns
/// the final carry together with the per-step outputs (one per element of
/// `xs`). For an empty `xs` the initial carry is returned unchanged and the
/// output list is empty.
///
/// When gradient mode is on and any participating tensor requires grad, the
/// returned tensors are rebuilt through `Tensor::from_operation` with
/// [`ScanBackward`] nodes so they take part in the autograd graph.
///
/// # Errors
/// Propagates any error from reading tensor data or constructing the wrapped
/// result tensors.
pub fn scan<T, F>(
    fn_step: F,
    init: &Tensor<T>,
    xs: &[Tensor<T>],
) -> FerrotorchResult<(Tensor<T>, Vec<Tensor<T>>)>
where
    T: Float,
    F: Fn(&Tensor<T>, &Tensor<T>) -> (Tensor<T>, Tensor<T>),
{
    if xs.is_empty() {
        return Ok((init.clone(), Vec::new()));
    }
    // carries[0] is the initial carry; one more entry is pushed per step.
    let mut carries: Vec<Tensor<T>> = Vec::with_capacity(xs.len() + 1);
    carries.push(init.clone());
    let mut outputs: Vec<Tensor<T>> = Vec::with_capacity(xs.len());
    let mut current_carry = init.clone();
    for x in xs {
        // FIX: the `&` here had been mangled into `¤t_carry` (an
        // HTML-escaped `&current_carry`), which does not compile.
        let (new_carry, output) = fn_step(&current_carry, x);
        carries.push(new_carry.clone());
        outputs.push(output);
        current_carry = new_carry;
    }
    let any_requires_grad = is_grad_enabled()
        && (init.requires_grad()
            || xs.iter().any(|x| x.requires_grad())
            || carries.iter().any(|c| c.requires_grad())
            || outputs.iter().any(|o| o.requires_grad()));
    if !any_requires_grad {
        return Ok((current_carry, outputs));
    }
    // Wrap the final carry so gradients can flow back through the scan.
    let final_carry_data = current_carry.data_vec()?;
    let final_carry_shape = current_carry.shape().to_vec();
    let final_carry_wrapped = Tensor::from_operation(
        TensorStorage::cpu(final_carry_data),
        final_carry_shape,
        Arc::new(ScanBackward {
            carries: carries.clone(),
            xs: xs.to_vec(),
            outputs: outputs.clone(),
            output_index: OutputKind::FinalCarry,
        }),
    )?;
    // Wrap each per-step output with its own grad node.
    let mut wrapped_outputs = Vec::with_capacity(outputs.len());
    for (i, out) in outputs.iter().enumerate() {
        let out_data = out.data_vec()?;
        let out_shape = out.shape().to_vec();
        let wrapped = Tensor::from_operation(
            TensorStorage::cpu(out_data),
            out_shape,
            Arc::new(ScanBackward {
                carries: carries.clone(),
                xs: xs.to_vec(),
                outputs: outputs.clone(),
                output_index: OutputKind::StepOutput(i),
            }),
        )?;
        wrapped_outputs.push(wrapped);
    }
    Ok((final_carry_wrapped, wrapped_outputs))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::{full, ones, zeros};

    // pred = 1.0 (> 0.5) must run the true branch (doubling).
    #[test]
    fn test_cond_true_branch() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![], false).unwrap();
        let x = ones::<f32>(&[3]).unwrap();
        let result = cond(
            &pred,
            |ops| {
                let data = ops[0].data().unwrap();
                let doubled: Vec<f32> = data.iter().map(|&v| v * 2.0).collect();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(doubled),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            |ops| {
                let data = ops[0].data().unwrap();
                let tripled: Vec<f32> = data.iter().map(|&v| v * 3.0).collect();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(tripled),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            &[x],
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let data = result[0].data().unwrap();
        assert_eq!(data, &[2.0, 2.0, 2.0]);
    }

    // pred = 0.0 (<= 0.5) must run the false branch (tripling).
    #[test]
    fn test_cond_false_branch() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0]), vec![], false).unwrap();
        let x = ones::<f32>(&[3]).unwrap();
        let result = cond(
            &pred,
            |ops| {
                let data = ops[0].data().unwrap();
                let doubled: Vec<f32> = data.iter().map(|&v| v * 2.0).collect();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(doubled),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            |ops| {
                let data = ops[0].data().unwrap();
                let tripled: Vec<f32> = data.iter().map(|&v| v * 3.0).collect();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(tripled),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            &[x],
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let data = result[0].data().unwrap();
        assert_eq!(data, &[3.0, 3.0, 3.0]);
    }

    // pred exactly 0.5 is NOT strictly greater than the threshold,
    // so the false branch runs.
    #[test]
    fn test_cond_threshold_boundary() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.5]), vec![], false).unwrap();
        let x = ones::<f32>(&[2]).unwrap();
        let result = cond(
            &pred,
            |_| vec![full::<f32>(&[2], 10.0).unwrap()],
            |_| vec![full::<f32>(&[2], 20.0).unwrap()],
            &[x],
        )
        .unwrap();
        let data = result[0].data().unwrap();
        assert_eq!(data, &[20.0, 20.0]);
    }

    // pred = 0.51 is just above the threshold: true branch runs.
    #[test]
    fn test_cond_just_above_threshold() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.51]), vec![], false).unwrap();
        let x = ones::<f32>(&[2]).unwrap();
        let result = cond(
            &pred,
            |_| vec![full::<f32>(&[2], 10.0).unwrap()],
            |_| vec![full::<f32>(&[2], 20.0).unwrap()],
            &[x],
        )
        .unwrap();
        let data = result[0].data().unwrap();
        assert_eq!(data, &[10.0, 10.0]);
    }

    // A 2-element pred is not scalar and must be rejected.
    #[test]
    fn test_cond_non_scalar_pred_error() {
        let pred = Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 0.0]), vec![2], false)
            .unwrap();
        let x = ones::<f32>(&[3]).unwrap();
        let result = cond(
            &pred,
            |_| vec![zeros::<f32>(&[3]).unwrap()],
            |_| vec![ones::<f32>(&[3]).unwrap()],
            &[x],
        );
        assert!(result.is_err());
    }

    // A branch may return more than one tensor; all are passed through.
    #[test]
    fn test_cond_multiple_outputs() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![], false).unwrap();
        let x = ones::<f32>(&[2]).unwrap();
        let result = cond(
            &pred,
            |ops| {
                let d = ops[0].data().unwrap();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(d.iter().map(|&v| v * 2.0).collect()),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                    Tensor::from_storage(
                        TensorStorage::cpu(d.iter().map(|&v| v * 3.0).collect()),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            |_| vec![zeros::<f32>(&[2]).unwrap(), zeros::<f32>(&[2]).unwrap()],
            &[x],
        )
        .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].data().unwrap(), &[2.0, 2.0]);
        assert_eq!(result[1].data().unwrap(), &[3.0, 3.0]);
    }

    // cond works with an empty operand slice.
    #[test]
    fn test_cond_empty_operands() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![], false).unwrap();
        let result = cond(
            &pred,
            |_| vec![full::<f32>(&[3], 42.0).unwrap()],
            |_| vec![full::<f32>(&[3], 0.0).unwrap()],
            &[],
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].data().unwrap(), &[42.0, 42.0, 42.0]);
    }

    // A grad-tracked operand makes the wrapped output require grad too.
    #[test]
    fn test_cond_with_requires_grad() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![], false).unwrap();
        let x = Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0, 3.0]), vec![3], true)
            .unwrap();
        let result = cond(
            &pred,
            |ops| {
                let data = ops[0].data().unwrap();
                let doubled: Vec<f32> = data.iter().map(|&v| v * 2.0).collect();
                vec![
                    Tensor::from_storage(
                        TensorStorage::cpu(doubled),
                        ops[0].shape().to_vec(),
                        false,
                    )
                    .unwrap(),
                ]
            },
            |_| vec![zeros::<f32>(&[3]).unwrap()],
            &[x],
        )
        .unwrap();
        assert!(result[0].requires_grad());
        assert_eq!(result[0].data().unwrap(), &[2.0, 4.0, 6.0]);
    }

    // A shape-[1] pred counts as scalar (numel == 1), not just shape [].
    #[test]
    fn test_cond_scalar_pred_single_element() {
        let pred =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0]), vec![1], false).unwrap();
        let x = ones::<f32>(&[2]).unwrap();
        let result = cond(
            &pred,
            |_| vec![full::<f32>(&[2], 5.0).unwrap()],
            |_| vec![full::<f32>(&[2], 0.0).unwrap()],
            &[x],
        )
        .unwrap();
        assert_eq!(result[0].data().unwrap(), &[5.0, 5.0]);
    }

    // Same arity and shapes: validation passes.
    #[test]
    fn test_validate_cond_branches_matching() {
        let a = vec![ones::<f32>(&[3, 4]).unwrap(), zeros::<f32>(&[2]).unwrap()];
        let b = vec![zeros::<f32>(&[3, 4]).unwrap(), ones::<f32>(&[2]).unwrap()];
        assert!(validate_cond_branches(&a, &b).is_ok());
    }

    // Different numbers of outputs: validation rejects.
    #[test]
    fn test_validate_cond_branches_count_mismatch() {
        let a = vec![ones::<f32>(&[3]).unwrap()];
        let b = vec![ones::<f32>(&[3]).unwrap(), ones::<f32>(&[3]).unwrap()];
        assert!(validate_cond_branches(&a, &b).is_err());
    }

    // Same arity but a differing shape: validation rejects.
    #[test]
    fn test_validate_cond_branches_shape_mismatch() {
        let a = vec![ones::<f32>(&[3]).unwrap()];
        let b = vec![ones::<f32>(&[4]).unwrap()];
        assert!(validate_cond_branches(&a, &b).is_err());
    }

    // Empty xs: scan returns init unchanged and no outputs.
    #[test]
    fn test_scan_empty_sequence() {
        let init = full::<f32>(&[2], 1.0).unwrap();
        let xs: &[Tensor<f32>] = &[];
        let (final_carry, outputs) =
            scan(|carry, _x| (carry.clone(), carry.clone()), &init, xs).unwrap();
        assert_eq!(final_carry.shape(), &[2]);
        assert_eq!(final_carry.data().unwrap(), &[1.0, 1.0]);
        assert!(outputs.is_empty());
    }

    // Running sum over [1, 2, 3]: outputs are prefix sums, carry is the total.
    #[test]
    fn test_scan_cumulative_sum() {
        let init = zeros::<f32>(&[1]).unwrap();
        let xs: Vec<Tensor<f32>> = vec![
            full::<f32>(&[1], 1.0).unwrap(),
            full::<f32>(&[1], 2.0).unwrap(),
            full::<f32>(&[1], 3.0).unwrap(),
        ];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c_data = carry.data().unwrap();
                let x_data = x.data().unwrap();
                let sum_val = c_data[0] + x_data[0];
                let new_carry =
                    Tensor::from_storage(TensorStorage::cpu(vec![sum_val]), vec![1], false)
                        .unwrap();
                let output = new_carry.clone();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(final_carry.data().unwrap(), &[6.0]);
        assert_eq!(outputs.len(), 3);
        assert_eq!(outputs[0].data().unwrap(), &[1.0]);
        assert_eq!(outputs[1].data().unwrap(), &[3.0]);
        assert_eq!(outputs[2].data().unwrap(), &[6.0]);
    }

    // Single-step scan: one output, carry advanced once.
    #[test]
    fn test_scan_single_step() {
        let init = full::<f32>(&[2], 10.0).unwrap();
        let xs = vec![full::<f32>(&[2], 5.0).unwrap()];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let new_data: Vec<f32> = c.iter().zip(xd.iter()).map(|(&a, &b)| a + b).collect();
                let new_carry = Tensor::from_storage(
                    TensorStorage::cpu(new_data.clone()),
                    carry.shape().to_vec(),
                    false,
                )
                .unwrap();
                let output = Tensor::from_storage(
                    TensorStorage::cpu(new_data),
                    carry.shape().to_vec(),
                    false,
                )
                .unwrap();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(final_carry.data().unwrap(), &[15.0, 15.0]);
        assert_eq!(outputs.len(), 1);
        assert_eq!(outputs[0].data().unwrap(), &[15.0, 15.0]);
    }

    // Multi-dimensional carry keeps its [3, 4] shape through the scan.
    #[test]
    fn test_scan_carry_shape_preserved() {
        let init = zeros::<f32>(&[3, 4]).unwrap();
        let xs = vec![ones::<f32>(&[3, 4]).unwrap(), ones::<f32>(&[3, 4]).unwrap()];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let new_data: Vec<f32> = c.iter().zip(xd.iter()).map(|(&a, &b)| a + b).collect();
                let new_carry = Tensor::from_storage(
                    TensorStorage::cpu(new_data.clone()),
                    carry.shape().to_vec(),
                    false,
                )
                .unwrap();
                let output = Tensor::from_storage(
                    TensorStorage::cpu(new_data),
                    carry.shape().to_vec(),
                    false,
                )
                .unwrap();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(final_carry.shape(), &[3, 4]);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape(), &[3, 4]);
        assert_eq!(outputs[1].shape(), &[3, 4]);
    }

    // Carry is shape [2] while step outputs are scalar sums of shape [1].
    #[test]
    fn test_scan_different_carry_and_output_shapes() {
        let init = zeros::<f32>(&[2]).unwrap();
        let xs = vec![
            full::<f32>(&[2], 1.0).unwrap(),
            full::<f32>(&[2], 2.0).unwrap(),
        ];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let new_data: Vec<f32> = c.iter().zip(xd.iter()).map(|(&a, &b)| a + b).collect();
                let new_carry = Tensor::from_storage(
                    TensorStorage::cpu(new_data.clone()),
                    carry.shape().to_vec(),
                    false,
                )
                .unwrap();
                let sum: f32 = new_data.iter().sum();
                let output =
                    Tensor::from_storage(TensorStorage::cpu(vec![sum]), vec![1], false).unwrap();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(final_carry.shape(), &[2]);
        assert_eq!(final_carry.data().unwrap(), &[3.0, 3.0]);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape(), &[1]);
        assert_eq!(outputs[0].data().unwrap(), &[2.0]);
        assert_eq!(outputs[1].shape(), &[1]);
        assert_eq!(outputs[1].data().unwrap(), &[6.0]);
    }

    // Running product over [2, 3, 4]: outputs are prefix products.
    #[test]
    fn test_scan_multiplicative_accumulation() {
        let init = full::<f32>(&[1], 1.0).unwrap();
        let xs = vec![
            full::<f32>(&[1], 2.0).unwrap(),
            full::<f32>(&[1], 3.0).unwrap(),
            full::<f32>(&[1], 4.0).unwrap(),
        ];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let product = c[0] * xd[0];
                let new_carry =
                    Tensor::from_storage(TensorStorage::cpu(vec![product]), vec![1], false)
                        .unwrap();
                let output = new_carry.clone();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(final_carry.data().unwrap(), &[24.0]);
        assert_eq!(outputs.len(), 3);
        assert_eq!(outputs[0].data().unwrap(), &[2.0]);
        assert_eq!(outputs[1].data().unwrap(), &[6.0]);
        assert_eq!(outputs[2].data().unwrap(), &[24.0]);
    }

    // A grad-tracked init makes every wrapped scan output require grad.
    #[test]
    fn test_scan_with_requires_grad() {
        let init =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0]), vec![1], true).unwrap();
        let xs = vec![
            full::<f32>(&[1], 1.0).unwrap(),
            full::<f32>(&[1], 2.0).unwrap(),
        ];
        let (final_carry, outputs) = scan(
            |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let sum = c[0] + xd[0];
                let new_carry =
                    Tensor::from_storage(TensorStorage::cpu(vec![sum]), vec![1], false).unwrap();
                let output = new_carry.clone();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert!(final_carry.requires_grad());
        assert_eq!(final_carry.data().unwrap(), &[3.0]);
        assert_eq!(outputs.len(), 2);
        assert!(outputs[0].requires_grad());
        assert!(outputs[1].requires_grad());
    }

    // Exponential moving average via scan: y_t = a*x_t + (1-a)*y_{t-1}.
    // With a = 0.3 and all-ones input, the expected series is
    // 0.3, 0.51, 0.657, 0.7599.
    #[test]
    fn test_scan_ema_filter() {
        let alpha = 0.3f32;
        let init = zeros::<f32>(&[1]).unwrap();
        let xs = vec![
            full::<f32>(&[1], 1.0).unwrap(),
            full::<f32>(&[1], 1.0).unwrap(),
            full::<f32>(&[1], 1.0).unwrap(),
            full::<f32>(&[1], 1.0).unwrap(),
        ];
        let (final_carry, outputs) = scan(
            move |carry, x| {
                let c = carry.data().unwrap();
                let xd = x.data().unwrap();
                let ema = alpha * xd[0] + (1.0 - alpha) * c[0];
                let new_carry =
                    Tensor::from_storage(TensorStorage::cpu(vec![ema]), vec![1], false).unwrap();
                let output = new_carry.clone();
                (new_carry, output)
            },
            &init,
            &xs,
        )
        .unwrap();
        assert_eq!(outputs.len(), 4);
        let eps = 1e-5;
        assert!((outputs[0].data().unwrap()[0] - 0.3).abs() < eps);
        assert!((outputs[1].data().unwrap()[0] - 0.51).abs() < eps);
        assert!((outputs[2].data().unwrap()[0] - 0.657).abs() < eps);
        assert!((outputs[3].data().unwrap()[0] - 0.7599).abs() < eps);
        assert!((final_carry.data().unwrap()[0] - 0.7599).abs() < eps);
    }
}