dendritic_autodiff/regularizers.rs

use dendritic_ndarray::ndarray::NDArray;
use dendritic_ndarray::ops::*;
use std::cell::{RefCell, RefMut};
use crate::node::{Node, Value};

/// L2 (ridge) regularization node. The right operand (`rhs`) holds the model
/// weights and the left operand (`lhs`) holds the regularization strength;
/// the forward pass computes `lhs * sum(rhs^2)`.
pub struct L2Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    pub rhs: RefCell<RHS>,
    pub lhs: RefCell<LHS>,
    pub output: RefCell<Value<NDArray<f64>>>,
    pub gradient: RefCell<Value<NDArray<f64>>>,
    pub learning_rate: f64,
}

impl<RHS, LHS> L2Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    /// Builds the node and seeds `output` (and `gradient`) with the initial
    /// penalty `lhs * sum(rhs^2)` so a value is available before the first
    /// forward pass.
    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {
        let weights = rhs.value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = lhs.value().mult(w_sum).unwrap();
        let op_value = Value::new(&op_result);

        L2Regularization {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value),
            learning_rate,
        }
    }

    /// Mutable access to the right operand (the weights node).
    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    /// Mutable access to the left operand (the regularization strength node).
    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }
}

impl<RHS, LHS> Node for L2Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    /// Forward pass: recompute `lhs * sum(rhs^2)` from the current operand
    /// values and store it in `output`.
    fn forward(&mut self) {
        self.rhs().forward();
        self.lhs().forward();

        let weights = self.rhs().value();
        let w_square = weights.square().unwrap();
        let w_sum = w_square.sum().unwrap();
        let op_result = self.lhs.borrow().value().mult(w_sum).unwrap();
        self.output = Value::new(&op_result).into();
    }

    /// Backward pass: the derivative of `lambda * sum(w^2)` with respect to
    /// the weights is `2 * lambda * w`; the learning rate is averaged over
    /// the upstream gradient size before scaling.
    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        let lr = self.learning_rate / upstream_gradient.size() as f64;
        let alpha = self.lhs().value().scalar_mult(2.0 * lr).unwrap();
        let weight_update = self.rhs().value().scale_mult(alpha).unwrap();
        self.gradient = Value::new(&weight_update).into();
    }

    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}
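
// A minimal usage sketch for the L2 node above, added for illustration only.
// It assumes an `NDArray::array(shape, values)` constructor in
// `dendritic_ndarray` and assumes `Value<NDArray<f64>>` implements `Node`
// (so it can act as a leaf operand); adjust both to the actual crate API if
// they differ.
#[cfg(test)]
mod l2_usage_sketch {
    use super::*;

    #[test]
    fn l2_penalty_and_gradient() {
        // Hypothetical constructors: weights as a 3x1 array, strength as 1x1.
        let weights = NDArray::array(vec![3, 1], vec![1.0, -2.0, 3.0]).unwrap();
        let lambda = NDArray::array(vec![1, 1], vec![0.1]).unwrap();

        // Penalty seeded at construction: lambda * sum(w^2) = 0.1 * 14 = 1.4.
        let mut reg = L2Regularization::new(Value::new(&weights), Value::new(&lambda), 0.01);
        reg.forward();
        assert!(reg.value().size() >= 1);

        // Backward with a unit upstream gradient; grad() holds the weight
        // update 2 * lr * lambda * w.
        let upstream = NDArray::array(vec![1, 1], vec![1.0]).unwrap();
        reg.backward(upstream);
        assert!(reg.grad().size() >= 1);
    }
}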
98
99
100pub struct L1Regularization<RHS, LHS>
101where
102 RHS: Node,
103 LHS: Node,
104{
105 pub rhs: RefCell<RHS>,
106 pub lhs: RefCell<LHS>,
107 pub output: RefCell<Value<NDArray<f64>>>,
108 pub gradient: RefCell<Value<NDArray<f64>>>,
109 pub learning_rate: f64
110}

impl<RHS, LHS> L1Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    /// Builds the node and seeds `output` (and `gradient`) with the initial
    /// penalty `lhs * sum(|rhs|)` so a value is available before the first
    /// forward pass.
    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {
        let weights = rhs.value();
        let w_abs = weights.abs().unwrap();
        let w_sum = w_abs.sum().unwrap();
        let op_result = lhs.value().mult(w_sum).unwrap();
        let op_value = Value::new(&op_result);

        L1Regularization {
            rhs: RefCell::new(rhs),
            lhs: RefCell::new(lhs),
            output: RefCell::new(op_value.clone()),
            gradient: RefCell::new(op_value),
            learning_rate,
        }
    }

    /// Mutable access to the right operand (the weights node).
    pub fn rhs(&self) -> RefMut<dyn Node> {
        self.rhs.borrow_mut()
    }

    /// Mutable access to the left operand (the regularization strength node).
    pub fn lhs(&self) -> RefMut<dyn Node> {
        self.lhs.borrow_mut()
    }
}

impl<RHS, LHS> Node for L1Regularization<RHS, LHS>
where
    RHS: Node,
    LHS: Node,
{
    /// Forward pass: recompute `lhs * sum(|rhs|)` from the current operand
    /// values and store it in `output`.
    fn forward(&mut self) {
        self.rhs().forward();
        self.lhs().forward();

        let weights = self.rhs().value();
        let w_abs = weights.abs().unwrap();
        let w_sum = w_abs.sum().unwrap();
        let op_result = self.lhs.borrow().value().mult(w_sum).unwrap();
        self.output = Value::new(&op_result).into();
    }

    /// Backward pass: the derivative of `lambda * sum(|w|)` with respect to
    /// the weights is `lambda * sign(w)`; the learning rate is averaged over
    /// the upstream gradient size before scaling.
    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
        let lr = self.learning_rate / upstream_gradient.size() as f64;
        let alpha = self.lhs().value().scalar_mult(lr).unwrap();
        let sig = self.rhs().value().signum().unwrap();
        let weight_update = sig.scale_mult(alpha).unwrap();
        self.gradient = Value::new(&weight_update).into();
    }

    fn value(&self) -> NDArray<f64> {
        self.output.borrow().val().clone()
    }

    fn grad(&self) -> NDArray<f64> {
        self.gradient.borrow().val().clone()
    }

    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
        self.gradient = Value::new(&upstream_gradient).into();
    }
}
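
// A matching sketch for the L1 node, under the same assumptions as the L2
// sketch above (hypothetical `NDArray::array` constructor, `Value` acting as
// a leaf `Node`). The L1 update follows the sign of each weight rather than
// its magnitude.
#[cfg(test)]
mod l1_usage_sketch {
    use super::*;

    #[test]
    fn l1_penalty_and_gradient() {
        let weights = NDArray::array(vec![3, 1], vec![1.0, -2.0, 3.0]).unwrap();
        let lambda = NDArray::array(vec![1, 1], vec![0.1]).unwrap();

        // Penalty seeded at construction: lambda * sum(|w|) = 0.1 * 6 = 0.6.
        let mut reg = L1Regularization::new(Value::new(&weights), Value::new(&lambda), 0.01);
        reg.forward();
        assert!(reg.value().size() >= 1);

        // Backward: grad() holds lr * lambda * sign(w), i.e. only the
        // direction of each weight contributes to the update.
        let upstream = NDArray::array(vec![1, 1], vec![1.0]).unwrap();
        reg.backward(upstream);
        assert!(reg.grad().size() >= 1);
    }
}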