dendritic_autodiff/ops.rs

use dendritic_ndarray::ndarray::NDArray;
use dendritic_ndarray::ops::*;
use std::cell::{RefCell, RefMut};
use crate::node::{Node, Value};

pub struct Dot<RHS, LHS> {
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>
}


impl<RHS, LHS> Dot<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Create new instance of dot product operation in computation graph
    pub fn new(rhs: RHS, lhs: LHS) -> Dot<RHS, LHS> {

        let op_result = rhs.value().dot(lhs.value().clone()).unwrap();
        let op_value = Value::new(&op_result);

        Dot {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value)
        }
    }

    /// Get mutable reference to right hand side node of dot product operation
    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    /// Get mutable reference to left hand side node of dot product operation
    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for Dot<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Perform forward pass of dot product
    fn forward(&mut self) {

        // recompute both operands first so the product uses their refreshed values
        self.rhs().forward();
        self.lhs().forward();

        let rhs = self.rhs().value();
        let lhs = self.lhs().value();

        let result = rhs.dot(lhs).unwrap();
        self.output = Value::new(&result).into();
    }

    /// Perform backward pass of dot product
    fn backward(&mut self, upstream_gradient: NDArray<f64>) {

        self.gradient = Value::new(&upstream_gradient).into();

        let rhs_t = self.rhs().value().transpose().unwrap();
        let lhs_t = self.lhs().value().transpose().unwrap();

        // for output = rhs · lhs, the chain rule gives
        // d(rhs) = upstream · lhs^T and d(lhs) = rhs^T · upstream
        let rhs_grad = upstream_gradient.clone().dot(lhs_t).unwrap();
        let lhs_grad = rhs_t.dot(upstream_gradient).unwrap();

        self.rhs().backward(rhs_grad);
        self.lhs().backward(lhs_grad);
    }

    /// Get output value of dot product operation
    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    /// Get gradient of dot product operation
    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    /// Set gradient of dot product operation
    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}

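// A minimal usage sketch (not part of the original source): wiring two
// existing `Node` operands into a `Dot` op and running a single
// forward/backward pass. The operand types and the upstream gradient are
// assumptions supplied by the caller with compatible shapes.
pub fn dot_example<A: Node, B: Node>(
    x: A,
    w: B,
    upstream: NDArray<f64>
) -> NDArray<f64> {
    let mut dot = Dot::new(x, w);   // output = x · w, computed once at construction
    dot.forward();                  // refresh both operands, then recompute the product
    dot.backward(upstream);         // route the seed gradient to both operands
    dot.grad()                      // upstream gradient cached on the op itself
}
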
pub struct ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>
}


impl<RHS, LHS> ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Create new instance of elementwise add operation
    pub fn new(rhs: RHS, lhs: LHS) -> Self {

        let scalar_vec = lhs.value();
        let op_result = rhs.value().scale_add(scalar_vec).unwrap();
        let op_value = Value::new(&op_result);

        ScaleAdd {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value)
        }
    }

    /// Get mutable reference to right hand side node of elementwise add operation
    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    /// Get mutable reference to left hand side node of elementwise add operation
    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Perform forward pass of elementwise add operation
    fn forward(&mut self) {

        self.rhs().forward();
        self.lhs().forward();

        let scalar_vec = self.lhs().value();
        let op_result = self.rhs().value().scale_add(scalar_vec).unwrap();
        self.output = Value::new(&op_result).into();
    }

    /// Perform backward pass of elementwise add operation
    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
        self.lhs().backward(upstream_gradient.clone());
        self.rhs().backward(upstream_gradient);
    }

    /// Get output value of elementwise add operation
    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    /// Get gradient of elementwise add operation
    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    /// Set gradient of elementwise add operation
    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}

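// A minimal usage sketch (not part of the original source): composing `Dot`
// and `ScaleAdd` into a small linear-layer style graph, output = (x · w) + b.
// The operand nodes and the upstream gradient are assumptions supplied by the
// caller with compatible shapes.
pub fn linear_example<X: Node, W: Node, B: Node>(
    x: X,
    w: W,
    b: B,
    upstream: NDArray<f64>
) -> NDArray<f64> {
    let product = Dot::new(x, w);              // x · w
    let mut layer = ScaleAdd::new(product, b); // (x · w) + b
    layer.forward();                           // recompute the whole subgraph
    layer.backward(upstream);                  // gradient flows through the add into the dot
    layer.value()                              // current output of the composed graph
}
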
pub struct Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>,
    pub learning_rate: f64
}


impl<RHS, LHS> Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Create new instance of regularization operation
    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {

        let weights = rhs.value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = lhs.value().mult(w_sum).unwrap();
        let op_value = Value::new(&op_result);

        Regularization {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value),
            learning_rate
        }
    }

    /// Get mutable reference to right hand side node of regularization operation
    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    /// Get mutable reference to left hand side node of regularization operation
    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    /// Perform forward pass of regularization operation
    fn forward(&mut self) {

        self.rhs().forward();
        self.lhs().forward();

        let weights = self.rhs().value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = self.lhs().value().mult(w_sum).unwrap();
        self.output = Value::new(&op_result).into();
    }

    /// Perform backward pass of regularization operation
    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        // gradient of lambda * sum(w^2) w.r.t. the weights is 2 * lambda * w;
        // the learning rate, divided by the upstream gradient's size, is folded in here
        let lr = self.learning_rate / upstream_gradient.size() as f64;
        let alpha = self.lhs().value().scalar_mult(2.0 * lr).unwrap();
        let weight_update = self.rhs().value().scale_mult(alpha).unwrap();
        self.gradient = Value::new(&weight_update).into();
    }

    /// Get output value of regularization operation
    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    /// Get gradient of regularization operation
    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    /// Set gradient of regularization operation
    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}

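// A minimal usage sketch (not part of the original source): an L2 penalty
// term, output = lambda * sum(w^2). The weight node, the lambda node, the
// learning rate, and the upstream gradient are assumptions supplied by the
// caller; `backward` stores 2 * lambda * learning_rate / n * w as this op's
// gradient instead of recursing into its operands.
pub fn regularization_example<W: Node, L: Node>(
    weights: W,
    lambda: L,
    learning_rate: f64,
    upstream: NDArray<f64>
) -> NDArray<f64> {
    let mut penalty = Regularization::new(weights, lambda, learning_rate);
    penalty.forward();          // lambda * sum of squared weights
    penalty.backward(upstream); // scaled weight update, stored on the op
    penalty.grad()              // gradient to subtract from the weights
}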