burn-tensor 0.1.0

This library provides multiple tensor implementations hidden behind an easy-to-use API that supports reverse-mode automatic differentiation.
use super::{BackwardRecordedOps, ForwardRecordedOps, RecordedOpsParentRef, UnaryOpsNodeState};
use crate::{
    graph::{
        converter::Forward2BackwardGraphConverter,
        node::{BackwardNodeRef, BackwardNodeState, ForwardNodeRef},
    },
    tensor::ops::Zeros,
};
use std::{ops::Add, sync::Arc};

/// A unary operation recorded in the autodiff graph.
///
/// Implementors must be `Send`, `Sync` and `Debug`; they are held behind an
/// `Arc` by the recorded-ops structs in this module.
pub trait UnaryOps<In, Out>: Send + Sync + std::fmt::Debug {
    /// Computes the value used to update the input node's gradient.
    ///
    /// `state` bundles the backward state of the operation's input node with
    /// the `BackwardNodeState<Out>` handed to `backward_step`. The returned
    /// value is passed to the input node's `update_grad`. Presumably this is the
    /// output gradient chained through the op's partial derivative — TODO
    /// confirm against implementors.
    fn partial(&self, state: &UnaryOpsNodeState<In, Out>) -> In;
}

/// Forward-graph record of a unary operation applied to a single input node.
///
/// `to_backward` turns it into a `BackwareUnaryRecordedOps`. The `ops` `Arc`
/// is shared with that backward record rather than duplicated.
/// `#[derive(new)]` generates `new(input, ops)` in field order, so the field
/// order here is part of the constructor's signature.
#[derive(new, Debug)]
pub struct ForwardUnaryRecordedOps<In, Ops> {
    // Forward node the operation was applied to.
    input: ForwardNodeRef<In>,
    // The operation itself; reference-counted so the backward record can share it.
    ops: Arc<Ops>,
}

// NOTE(review): "Backware" looks like a typo for "Backward". Renaming would break
// any external references, so it is left as-is; consider adding a correctly
// spelled alias and deprecating this name.
/// Backward-graph counterpart of `ForwardUnaryRecordedOps`.
///
/// Built by `ForwardUnaryRecordedOps::to_backward`. During a backward step it
/// asks `ops` for a value derived from the input and output node states, then
/// hands that value to the input node's `update_grad`.
/// `#[derive(new)]` generates `new(input, ops)` in field order.
#[derive(new, Debug)]
pub struct BackwareUnaryRecordedOps<In, Ops> {
    // Backward node whose gradient this operation updates; also its only parent.
    input: BackwardNodeRef<In>,
    // Operation shared (via `Arc`) with the forward record it was converted from.
    ops: Arc<Ops>,
}

impl<In, Out, Ops> ForwardRecordedOps<Out> for ForwardUnaryRecordedOps<In, Ops>
where
    In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
    Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
    Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
{
    /// Produces the backward-graph version of this unary record.
    ///
    /// The forward input node is mapped to its backward node through `graph`.
    /// The operation is shared with the new record by bumping the `Arc`
    /// reference count rather than copying the operation.
    fn to_backward(
        &self,
        graph: &mut Forward2BackwardGraphConverter,
    ) -> super::BackwardRecordedOpsRef<Out> {
        // Arguments are evaluated left to right: the input node is converted first,
        // then the shared operation handle is cloned.
        Arc::new(BackwareUnaryRecordedOps::new(
            graph.from(&self.input),
            Arc::clone(&self.ops),
        ))
    }
}

impl<In, Out, Ops> BackwardRecordedOps<Out> for BackwareUnaryRecordedOps<In, Ops>
where
    In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
    Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
    Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
{
    /// Runs one backward step for this unary operation.
    ///
    /// The input node's state and `state` are combined into a
    /// `UnaryOpsNodeState`. That is handed to the operation's `partial`, and the
    /// result is passed to the input node's `update_grad`.
    fn backward_step(&self, state: &BackwardNodeState<Out>) {
        // `state` is already a reference; pass it straight through instead of
        // borrowing it again (`&state` would be `&&BackwardNodeState<Out>`).
        // A distinct name avoids shadowing the parameter with a different type.
        let node_state = UnaryOpsNodeState::new(&self.input.state, state);
        let partial = self.ops.partial(&node_state);
        self.input.state.update_grad(partial);
    }

    /// Returns the nodes this operation depends on: a unary op has exactly one,
    /// its input.
    fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
        vec![self.input.clone()]
    }
}