dendritic_autodiff/
node.rs

1use dendritic_ndarray::ndarray::NDArray;
2use std::rc::Rc; 
3use std::cell::{RefCell}; 
4
5
6/// Methods for each value in computation graph
7pub trait Node {
8    fn forward(&mut self); 
9    fn backward(&mut self, upstream_gradient: NDArray<f64>); 
10    fn value(&self) -> NDArray<f64>;
11    fn grad(&self) -> NDArray<f64>;
12    fn set_grad(&mut self, upstream_gradient: NDArray<f64>);
13}
14
15
/// Value node for the computation graph: a value together with its gradient.
///
/// Both fields live behind `Rc<RefCell<_>>`. The derived `Clone` therefore
/// copies the `Rc` handles, not the data. Every clone shares the same value and
/// gradient cells and observes writes made through any other clone.
/// The derived `Default` is only available when `T: Default`.
#[derive(Debug, Clone, Default)]
pub struct Value<T> {
    /// Shared, interior-mutable storage for the node's value.
    pub value: Rc<RefCell<T>>,
    /// Shared, interior-mutable storage for the node's gradient.
    pub gradient: Rc<RefCell<T>>,
}

impl<T: Clone> Value<T> {
    /// Create a new instance of a value for the computation graph.
    ///
    /// The value cell and the gradient cell each receive their own clone of
    /// `value`. As a result, the gradient initially equals the value rather
    /// than zero.
    /// NOTE(review): confirm that seeding the gradient this way is intended.
    pub fn new(value: &T) -> Value<T> {
        // Each call yields an independent cell holding a fresh clone.
        let fresh_cell = || Rc::new(RefCell::new(value.clone()));
        Value {
            value: fresh_cell(),
            gradient: fresh_cell(),
        }
    }

    /// Return an owned clone of the stored value.
    pub fn val(&self) -> T {
        T::clone(&self.value.borrow())
    }

    /// Return an owned clone of the stored gradient.
    pub fn grad(&self) -> T {
        T::clone(&self.gradient.borrow())
    }

    /// Overwrite the stored value with a clone of `val`.
    ///
    /// The write is visible through every clone of this `Value`.
    pub fn set_val(&mut self, val: &T) {
        RefCell::replace(&self.value, T::clone(val));
    }

    /// Overwrite the stored gradient with a clone of `value`.
    ///
    /// The write is visible through every clone of this `Value`.
    pub fn set_grad(&mut self, value: &T) {
        RefCell::replace(&self.gradient, T::clone(value));
    }
}
55
56
57impl Node for Value<NDArray<f64>> {
58
59    /// Forward operation for a value
60    fn forward(&mut self) {} 
61
62    /// Set gradient from upstream for value
63    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
64        self.gradient.replace(upstream_gradient);
65    } 
66
67    /// Set gradient from upstream in backward pass
68    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
69        self.gradient.replace(upstream_gradient);        
70    } 
71
72    /// Retrieve value from node in computation graph
73    fn value(&self) -> NDArray<f64> { 
74        self.value.borrow().clone()
75    }
76
77    /// Retrieve gradient from node in computation graph
78    fn grad(&self) -> NDArray<f64> { 
79        self.gradient.borrow().clone()
80    }
81
82}