dendritic_autodiff/ops.rs

use dendritic_ndarray::ndarray::NDArray;
use dendritic_ndarray::ops::*;
use std::cell::{RefCell, RefMut};
use crate::node::{Node, Value};

/// Matrix product node: output = rhs.dot(lhs).
pub struct Dot<RHS, LHS> {
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>
}


impl<RHS, LHS> Dot<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    pub fn new(rhs: RHS, lhs: LHS) -> Dot<RHS, LHS> {

        // Evaluate the product once up front so the output and gradient
        // slots start out with correctly shaped values.
        let op_result = rhs.value().dot(lhs.value().clone()).unwrap();
        let op_value = Value::new(&op_result);

        Dot {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value)
        }
    }

    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for Dot<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    fn forward(&mut self) {

        // Recompute the children first so this node reads fresh values
        // rather than their cached outputs from construction time.
        self.rhs().forward();
        self.lhs().forward();

        let rhs = self.rhs().value();
        let lhs = self.lhs().value();

        let result = rhs.dot(lhs).unwrap();
        self.output = Value::new(&result).into();
    }

    fn backward(&mut self, upstream_gradient: NDArray<f64>) {

        self.gradient = Value::new(&upstream_gradient).into();

        // For output = rhs.dot(lhs), the matrix-product gradients are
        // d(loss)/d(rhs) = upstream.dot(lhs^T) and
        // d(loss)/d(lhs) = rhs^T.dot(upstream).
        let rhs_t = self.rhs().value().transpose().unwrap();
        let lhs_t = self.lhs().value().transpose().unwrap();

        let lhs_grad = rhs_t.dot(upstream_gradient.clone()).unwrap();
        let rhs_grad = upstream_gradient.dot(lhs_t).unwrap();

        self.rhs().backward(rhs_grad);
        self.lhs().backward(lhs_grad);
    }


    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}
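
// A minimal sketch of exercising a Dot node on its own, under two
// assumptions about code outside this file: that `Value<NDArray<f64>>`
// implements `Node` (so it can serve as a leaf) and that
// `NDArray::array(shape, values)` is the dendritic_ndarray constructor.
// Adjust both if the actual APIs differ.
#[cfg(test)]
mod dot_sketch {
    use super::*;

    #[test]
    fn dot_forward_backward() {
        // Hypothetical 2x2 input and 2x1 weights.
        let x = NDArray::array(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let w = NDArray::array(vec![2, 1], vec![0.5, -0.5]).unwrap();

        let mut product = Dot::new(Value::new(&x), Value::new(&w));
        product.forward();    // output = x.dot(w), shape 2x1

        // Seed the backward pass with an all-ones upstream gradient; the
        // weight leaf then receives x^T.dot(upstream).
        let upstream = NDArray::array(vec![2, 1], vec![1.0, 1.0]).unwrap();
        product.backward(upstream);
        let _w_grad = product.lhs().grad();
    }
}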


/// Addition node: output = rhs.scale_add(lhs), i.e. lhs added onto rhs
/// (e.g. a bias term).
pub struct ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>
}


impl<RHS, LHS> ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    pub fn new(rhs: RHS, lhs: LHS) -> Self {

        let scalar_vec = lhs.value();
        let op_result = rhs.value().scale_add(scalar_vec).unwrap();
        let op_value = Value::new(&op_result);

        ScaleAdd {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value)
        }
    }

    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for ScaleAdd<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    fn forward(&mut self) {

        self.rhs().forward();
        self.lhs().forward();

        let scalar_vec = self.lhs().value();
        let op_result = self.rhs().value().scale_add(scalar_vec).unwrap();
        self.output = Value::new(&op_result).into();
    }

    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        // Addition is linear, so both operands receive the upstream
        // gradient unchanged.
        self.gradient = Value::new(&upstream_gradient).into();
        self.lhs().backward(upstream_gradient.clone());
        self.rhs().backward(upstream_gradient);
    }

    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}
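
// A rough sketch of composing Dot and ScaleAdd into a linear layer,
// output = x.dot(w) + b. It leans on the same assumptions as the Dot
// sketch above (Value leaves implement Node, NDArray::array is the
// constructor), plus one more: that scale_add accepts a bias with the
// same shape as the left-hand output. All of these are assumptions about
// code outside this file.
#[cfg(test)]
mod scale_add_sketch {
    use super::*;

    #[test]
    fn linear_layer_sketch() {
        let x = NDArray::array(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let w = NDArray::array(vec![2, 1], vec![0.5, -0.5]).unwrap();
        let b = NDArray::array(vec![2, 1], vec![0.1, 0.1]).unwrap();

        // ScaleAdd wraps the Dot node, so forward/backward calls fan out
        // through the whole expression tree.
        let product = Dot::new(Value::new(&x), Value::new(&w));
        let mut layer = ScaleAdd::new(product, Value::new(&b));
        layer.forward();

        // Addition routes the upstream gradient through unchanged, so the
        // bias leaf and the Dot node both see the same seed.
        let upstream = NDArray::array(vec![2, 1], vec![1.0, 1.0]).unwrap();
        layer.backward(upstream);
        let _b_grad = layer.lhs().grad();
    }
}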


/// L2 (ridge) penalty node: output = lhs * sum(rhs^2); rhs holds the
/// weights, lhs the regularization strength.
pub struct Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>,
    pub learning_rate: f64
}


impl<RHS, LHS> Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {

        // Penalty = lambda * sum(W^2), evaluated once at construction.
        let weights = rhs.value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = lhs.value().mult(w_sum).unwrap();
        let op_value = Value::new(&op_result);

        Regularization {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value),
            learning_rate
        }
    }

    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }

}


impl<RHS, LHS> Node for Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{

    fn forward(&mut self) {

        self.rhs().forward();
        self.lhs().forward();

        let weights = self.rhs().value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = self.lhs().value().mult(w_sum).unwrap();
        self.output = Value::new(&op_result).into();
    }

    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        // Ridge gradient: d/dW [lambda * sum(W^2)] = 2 * lambda * W, scaled
        // here by learning_rate / n, with n taken from the size of the
        // upstream gradient. The result is stored on this node (to be read
        // back via grad()) rather than propagated to the child nodes.
        let lr = self.learning_rate / upstream_gradient.size() as f64;
        let alpha = self.lhs().value().scalar_mult(2.0 * lr).unwrap();
        let weight_update = self.rhs().value().scale_mult(alpha).unwrap();
        self.gradient = Value::new(&weight_update).into();
    }

    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}
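
// A rough sketch of the ridge penalty in use. For P = lambda * sum(w_i^2),
// the per-weight derivative is dP/dw_i = 2 * lambda * w_i; backward() also
// folds in learning_rate / n, so grad() returns 2 * lambda * (lr / n) * W.
// Same leaf and constructor assumptions as the sketches above, plus the
// assumption that a 1x1 NDArray is an acceptable shape for the lambda term.
#[cfg(test)]
mod regularization_sketch {
    use super::*;

    #[test]
    fn ridge_penalty_sketch() {
        let w = NDArray::array(vec![2, 1], vec![0.5, -0.5]).unwrap();
        let lambda = NDArray::array(vec![1, 1], vec![0.01]).unwrap();

        let mut penalty = Regularization::new(Value::new(&w), Value::new(&lambda), 0.1);
        penalty.forward();    // lambda * sum(W^2)

        // Only the upstream gradient's length feeds the 1/n scaling.
        let upstream = NDArray::array(vec![2, 1], vec![1.0, 1.0]).unwrap();
        penalty.backward(upstream);
        let _penalty_grad = penalty.grad();    // 2 * lambda * (lr / n) * W
    }
}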