dendritic_autodiff/
node.rs1use dendritic_ndarray::ndarray::NDArray;
2use std::rc::Rc;
3use std::cell::{RefCell};
4
5
6pub trait Node {
8 fn forward(&mut self);
9 fn backward(&mut self, upstream_gradient: NDArray<f64>);
10 fn value(&self) -> NDArray<f64>;
11 fn grad(&self) -> NDArray<f64>;
12 fn set_grad(&mut self, upstream_gradient: NDArray<f64>);
13}
14
15
/// A value paired with its gradient, each kept in shared, interiorly
/// mutable storage (`Rc<RefCell<T>>`).
///
/// The derived `Clone` is shallow: a cloned `Value` points at the *same*
/// value and gradient cells as the original, so a write made through one
/// handle is visible through every clone. The derived `Default` (available
/// when `T: Default`) creates fresh cells holding `T::default()`.
#[derive(Clone, Debug, Default)]
pub struct Value<T> {
    /// Shared cell holding the current value.
    pub value: Rc<RefCell<T>>,
    /// Shared cell holding the current gradient.
    pub gradient: Rc<RefCell<T>>,
}

impl<T: Clone> Value<T> {
    /// Creates a `Value` whose value cell and gradient cell each start with
    /// their own, independent clone of `value`.
    ///
    /// NOTE(review): the gradient is seeded with a copy of the value rather
    /// than a zero — presumably only as a same-shaped placeholder; confirm
    /// callers overwrite it via `set_grad` before reading it.
    pub fn new(value: &T) -> Self {
        // Each call produces a brand-new cell, so the two fields never alias.
        let fresh_cell = || Rc::new(RefCell::new(value.clone()));
        Self {
            value: fresh_cell(),
            gradient: fresh_cell(),
        }
    }

    /// Returns a clone of the stored value.
    ///
    /// # Panics
    /// Panics if the value cell is currently mutably borrowed.
    pub fn val(&self) -> T {
        let current = self.value.borrow();
        (*current).clone()
    }

    /// Returns a clone of the stored gradient.
    ///
    /// # Panics
    /// Panics if the gradient cell is currently mutably borrowed.
    pub fn grad(&self) -> T {
        let current = self.gradient.borrow();
        (*current).clone()
    }

    /// Overwrites the stored value with a clone of `val`.
    ///
    /// The write goes through the shared cell, so every clone of this
    /// `Value` observes it.
    ///
    /// # Panics
    /// Panics if the value cell is currently borrowed.
    pub fn set_val(&mut self, val: &T) {
        // The right-hand side (the clone) is evaluated before the cell is
        // mutably borrowed.
        *self.value.borrow_mut() = val.clone();
    }

    /// Overwrites the stored gradient with a clone of `value`.
    ///
    /// The write goes through the shared cell, so every clone of this
    /// `Value` observes it.
    ///
    /// # Panics
    /// Panics if the gradient cell is currently borrowed.
    pub fn set_grad(&mut self, value: &T) {
        *self.gradient.borrow_mut() = value.clone();
    }
}
55
56
57impl Node for Value<NDArray<f64>> {
58
59 fn forward(&mut self) {}
61
62 fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
64 self.gradient.replace(upstream_gradient);
65 }
66
67 fn backward(&mut self, upstream_gradient: NDArray<f64>) {
69 self.gradient.replace(upstream_gradient);
70 }
71
72 fn value(&self) -> NDArray<f64> {
74 self.value.borrow().clone()
75 }
76
77 fn grad(&self) -> NDArray<f64> {
79 self.gradient.borrow().clone()
80 }
81
82}