use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::dtypes::DType;
use crate::errors::{EtensorError, EtensorResult};
use crate::autograd::tape;
use crate::autograd::gradients::Gradients;
/// Runs reverse-mode differentiation starting from `root`.
///
/// The gradient of `root` is seeded with a buffer of ones containing
/// `root.shape.num_elements()` values and stored under `root.id`. Then every action
/// returned by [`tape::take`] is replayed in reverse recording order (newest first).
/// Each action reads from and writes into the shared [`Gradients`] map.
///
/// # Errors
///
/// Returns [`EtensorError::AutogradError`] when:
/// - `root.requires_grad` is false, or
/// - `root` is not an `F32` tensor on the CPU (the only seeding currently supported).
///
/// Both checks happen before the tape is touched. Errors from `Gradients::insert` or
/// from any action's `backward` are propagated unchanged. Actions that have not run yet
/// at that point are dropped, because they were already taken off the tape.
pub fn backward(root: &Tensor) -> EtensorResult<Gradients> {
    // Only tensors that opted into gradient tracking may start a backward pass.
    if !root.requires_grad {
        return Err(EtensorError::AutogradError(String::from(
            "Cannot call backward() on a tensor where requires_grad is false.",
        )));
    }

    // The seed buffer below is built with `from_f32_vec`, so only CPU/F32 roots are accepted.
    let seedable = root.dtype == DType::F32 && root.device.is_cpu();
    if !seedable {
        return Err(EtensorError::AutogradError(String::from(
            "Phase 3 backward engine currently only supports CPU Float32 seeding.",
        )));
    }

    let mut grads = Gradients::new();

    // d(root)/d(root) is 1 for every element, so the seed is an all-ones buffer.
    let ones = vec![1.0; root.shape.num_elements()];
    grads.insert(root.id, Buffer::from_f32_vec(ones))?;

    // Replay newest-first so that, for a tape recorded in forward order, an action's
    // output gradient is already in `grads` when it runs. Stops at the first error.
    tape::take()
        .into_iter()
        .rev()
        .try_for_each(|action| action.backward(&mut grads))?;

    Ok(grads)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Tensor, TensorId};
    use crate::shape::Shape;
    use crate::device::Device;
    use crate::autograd::tape::{TapeAction, record};

    /// Test double for an identity op: copies the gradient stored for `output_id`
    /// to `input_id` unchanged.
    struct MockPassThroughBackward {
        input_id: TensorId,
        output_id: TensorId,
    }

    impl TapeAction for MockPassThroughBackward {
        fn backward(&self, grads: &mut Gradients) -> EtensorResult<()> {
            // A missing upstream gradient is reported as an engine error instead of a
            // panic, so the failure surfaces through `backward()`'s `EtensorResult`.
            let dy_buffer = grads
                .get(&self.output_id)
                .ok_or_else(|| {
                    EtensorError::AutogradError(
                        "MockPassThrough: no gradient found for its output tensor".to_string(),
                    )
                })?
                .clone();
            grads.insert(self.input_id, dy_buffer)?;
            Ok(())
        }

        fn name(&self) -> String {
            "MockPassThrough".to_string()
        }
    }

    #[test]
    fn test_requires_grad_guard() {
        let shape = Shape::new(vec![1]);
        let data = Buffer::from_f32_vec(vec![42.0]);
        let t = Tensor::new(data, shape, Device::Cpu, DType::F32, false);

        // Anything other than the specific AutogradError (including Ok) means the guard
        // was bypassed.
        match backward(&t) {
            Err(EtensorError::AutogradError(msg)) => {
                assert!(msg.contains("requires_grad is false"));
            }
            _ => panic!("Engine bypassed the requires_grad safety guard!"),
        }
    }

    #[test]
    fn test_engine_seeding_and_execution() {
        // Start from an empty tape: discard any actions left over from earlier recordings.
        let _ = tape::take();

        let shape = Shape::new(vec![1]);
        let input_tensor = Tensor::new(
            Buffer::from_f32_vec(vec![5.0]), shape.clone(), Device::Cpu, DType::F32, true
        );
        let output_tensor = Tensor::new(
            Buffer::from_f32_vec(vec![5.0]), shape, Device::Cpu, DType::F32, true
        );
        record(Box::new(MockPassThroughBackward {
            input_id: input_tensor.id,
            output_id: output_tensor.id,
        }));

        let grads = backward(&output_tensor).unwrap();

        // The root is seeded with ones.
        let dy = grads.get(&output_tensor.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(dy[0], 1.0);

        // The pass-through action propagated the seed to its input.
        let dx = grads.get(&input_tensor.id).unwrap().as_f32_slice().unwrap();
        assert_eq!(dx[0], 1.0);

        // backward() must have drained the tape it replayed.
        assert_eq!(tape::take().len(), 0);
    }
}