burn-tensor 0.1.0

This library provides multiple tensor implementations hidden behind an easy-to-use API that supports reverse-mode automatic differentiation.
use crate::{
    graph::node::{BackwardNode, BackwardNodeRef, ForwardNodeRef},
    tensor::ops::Zeros,
};
use std::{any::Any, collections::HashMap, sync::Arc};

/// Converts forward graph nodes into backward graph nodes, memoizing every
/// converted node by its id so each forward node yields one backward node.
///
/// The cache is type-erased (`Box<dyn Any>`) because conversion is generic over
/// the node's element type, and a single converter stores results for any of them.
pub struct Forward2BackwardGraphConverter {
    /// Node id -> previously built backward node reference, recovered with a
    /// downcast to the requested concrete type at lookup time.
    state: HashMap<String, Box<dyn Any>>,
}

impl Forward2BackwardGraphConverter {
    pub fn new() -> Self {
        Self {
            state: HashMap::new(),
        }
    }
    pub fn from<T: Clone + 'static + Zeros<T>>(
        &mut self,
        node: &ForwardNodeRef<T>,
    ) -> BackwardNodeRef<T> {
        match self.state.get(&node.id) {
            Some(node) => {
                let node: &BackwardNodeRef<T> = node.downcast_ref().unwrap();
                return node.clone();
            }
            None => {}
        };

        let node = Arc::new(BackwardNode::from_node(node, self));
        self.state.insert(node.id.clone(), Box::new(node.clone()));
        node
    }
}