use crate::shapes::Rank0;
use crate::tensor::*;
/// Consumes `self` and produces the [`Gradients`] computed by a backward pass.
///
/// Implementors supply the fallible [`Backward::try_backward`]; the
/// panicking [`Backward::backward`] is provided on top of it.
pub trait Backward<E, D: Storage<E>>: HasErr {
    /// Fallible backward pass.
    ///
    /// # Errors
    ///
    /// Returns `Self::Err` when the implementation fails to produce the
    /// gradients.
    fn try_backward(self) -> Result<Gradients<E, D>, Self::Err>;

    /// Infallible convenience wrapper around [`Backward::try_backward`].
    ///
    /// # Panics
    ///
    /// Panics if [`Backward::try_backward`] returns an error.
    fn backward(self) -> Gradients<E, D> {
        Self::try_backward(self).unwrap()
    }
}
/// Backward pass for a rank-0 tensor that owns its tape.
impl<E, D> Backward<E, D> for Tensor<Rank0, E, D, OwnedTape<E, D>>
where
    E: 'static + Clone,
    D: OneFillStorage<E>,
{
    /// Detaches the tape from the tensor, registers an operation that
    /// allocates this tensor's gradient and fills it with ones, executes the
    /// tape, and returns the gradients after `drop_non_leafs` has been called
    /// on them.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while executing the tape.
    fn try_backward(self) -> Result<Gradients<E, D>, Self::Err> {
        let (scalar, mut tape) = self.split_tape();
        let seed = scalar.ghost();

        // Seed operation: make sure a gradient buffer exists for this tensor,
        // then set it to ones.
        // NOTE(review): presumably the tape runs its operations in reverse
        // registration order so this seed executes first -- confirm.
        tape.add_backward_op(move |gradients| {
            gradients.try_alloc_for(&seed)?;
            let seed_grad = gradients.get_mut(&seed);
            scalar.device.try_fill_with_ones(seed_grad)
        });

        let mut gradients = tape.execute()?;
        gradients.drop_non_leafs();
        Ok(gradients)
    }
}