//! dendritic_autodiff/regularizers.rs
//!
//! L1 and L2 regularization operations for the autodiff computation graph.

1use dendritic_ndarray::ndarray::NDArray;
2use dendritic_ndarray::ops::*;
3use std::cell::{RefCell, RefMut}; 
4use crate::node::{Node, Value}; 
5
/// Computation-graph node for L2 (ridge) regularization.
///
/// Its forward value is `lhs.value() * sum(rhs.value()^2)`, where `rhs`
/// holds the weights being penalized and `lhs` presumably holds the
/// regularization coefficient (lambda) — confirm against callers.
pub struct L2Regularization<RHS, LHS> 
where
    RHS: Node,
    LHS: Node,
{
    /// Right-hand operand: the weight tensor whose squared sum is penalized.
    pub rhs: RefCell<RHS>,
    /// Left-hand operand: scaling factor applied to the squared-weight sum.
    pub lhs: RefCell<LHS>,
    /// Cached result of the most recent forward pass.
    pub output: RefCell<Value<NDArray<f64>>>,
    /// Cached gradient from the most recent backward pass
    /// (initialized to a copy of the output by `new`).
    pub gradient: RefCell<Value<NDArray<f64>>>,
    /// Learning rate folded into the weight-update gradient in `backward`.
    pub learning_rate: f64
}
17
18
19impl<RHS, LHS> L2Regularization<RHS, LHS> 
20where
21    RHS: Node,
22    LHS: Node,
23{
24
25    /// Create new instance of L2 regularization operation
26    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {
27
28        let weights = rhs.value();
29        let w_square = weights.square().unwrap();
30        let w_sum = w_square.sum().unwrap();
31        let op_result = lhs.value().mult(w_sum).unwrap();
32        let op_value = Value::new(&op_result);
33
34        L2Regularization {
35            rhs: RefCell::new(rhs),
36            lhs: RefCell::new(lhs),
37            output: RefCell::new(op_value.clone()),
38            gradient: RefCell::new(op_value),
39            learning_rate: learning_rate
40        }
41    }
42
43    /// Get right hand side value of L2 regularization operation
44    pub fn rhs(&self) -> RefMut<dyn Node> {
45        self.rhs.borrow_mut()
46    }
47
48    /// Get left hand side value of L2 regularization operation
49    pub fn lhs(&self) -> RefMut<dyn Node> {
50        self.lhs.borrow_mut()
51    }
52
53}
54
55
56impl<LHS, RHS> Node for L2Regularization<RHS, LHS> 
57where
58    RHS: Node,
59    LHS: Node,
60{
61
62    /// Perform forward pass on L2 regularization
63    fn forward(&mut self) {
64
65        self.rhs().forward();
66        self.lhs().forward();
67
68        let weights = self.rhs().value();
69        let w_square = weights.square().unwrap();
70        let w_sum = w_square.sum().unwrap();
71        let op_result = self.lhs.borrow().value().mult(w_sum).unwrap();
72        self.output = Value::new(&op_result).into(); 
73    } 
74
75    /// Perform backward pass on L2 regularization
76    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
77        let lr = self.learning_rate / upstream_gradient.size() as f64;
78        let alpha = self.lhs().value().scalar_mult(2.0 * lr).unwrap();
79        let weight_update = self.rhs().value().scale_mult(alpha).unwrap();
80        self.gradient = Value::new(&weight_update).into();
81    }
82
83    /// Get output value of L2 regularization
84    fn value(&self) -> NDArray<f64> {
85        self.output.borrow().val().clone()
86    }
87
88    /// Get gradient of L2 regularization
89    fn grad(&self) -> NDArray<f64> {
90        self.gradient.borrow().val().clone()
91    }
92 
93    /// Set gradient of L2 regularization
94    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
95        self.gradient = Value::new(&upstream_gradient).into();
96    } 
97}
98
99
/// Computation-graph node for L1 (lasso) regularization.
///
/// Its forward value is `lhs.value() * sum(|rhs.value()|)`, where `rhs`
/// holds the weights being penalized and `lhs` presumably holds the
/// regularization coefficient (lambda) — confirm against callers.
pub struct L1Regularization<RHS, LHS> 
where
    RHS: Node,
    LHS: Node,
{
    /// Right-hand operand: the weight tensor whose absolute sum is penalized.
    pub rhs: RefCell<RHS>,
    /// Left-hand operand: scaling factor applied to the absolute-weight sum.
    pub lhs: RefCell<LHS>,
    /// Cached result of the most recent forward pass.
    pub output: RefCell<Value<NDArray<f64>>>,
    /// Cached gradient from the most recent backward pass
    /// (initialized to a copy of the output by `new`).
    pub gradient: RefCell<Value<NDArray<f64>>>,
    /// Learning rate folded into the weight-update gradient in `backward`.
    pub learning_rate: f64
}
111
112
113impl<RHS, LHS> L1Regularization<RHS, LHS> 
114where
115    RHS: Node,
116    LHS: Node,
117{
118
119    /// Create new instance of L1 regularization operation
120    pub fn new(rhs: RHS, lhs: LHS, learning_rate: f64) -> Self {
121
122        let weights = rhs.value();
123        let w_abs = weights.abs().unwrap();
124        let w_sum = w_abs.sum().unwrap();
125        let op_result = lhs.value().mult(w_sum).unwrap();
126        let op_value = Value::new(&op_result);
127
128        L1Regularization {
129            rhs: RefCell::new(rhs),
130            lhs: RefCell::new(lhs),
131            output: RefCell::new(op_value.clone()),
132            gradient: RefCell::new(op_value),
133            learning_rate: learning_rate
134        }
135    }
136
137    /// Get right hand side value of L1 regularization operation
138    pub fn rhs(&self) -> RefMut<dyn Node> {
139        self.rhs.borrow_mut()
140    }
141
142    /// Get left hand side value of L1 regularization operation
143    pub fn lhs(&self) -> RefMut<dyn Node> {
144        self.lhs.borrow_mut()
145    }
146
147}
148
149
150impl<LHS, RHS> Node for L1Regularization<RHS, LHS> 
151where
152    RHS: Node,
153    LHS: Node,
154{
155
156    /// Perform forward pass on L1 regularization
157    fn forward(&mut self) {
158
159        self.rhs().forward();
160        self.lhs().forward();
161
162        let weights = self.rhs().value();
163        let w_abs = weights.abs().unwrap();
164        let w_sum = w_abs.sum().unwrap();
165        let op_result = self.lhs.borrow().value().mult(w_sum).unwrap();
166        self.output = Value::new(&op_result).into(); 
167    } 
168
169    /// Perform backward pass on L1 regularization
170    fn backward(&mut self, upstream_gradient: NDArray<f64>) {
171        let lr = self.learning_rate / upstream_gradient.size() as f64;
172        let alpha = self.lhs().value().scalar_mult(lr).unwrap();
173        let sig = self.rhs().value().signum().unwrap();
174        let weight_update = sig.scale_mult(alpha).unwrap();
175        self.gradient = Value::new(&weight_update).into();
176        
177    }
178
179    /// Get output value of L1 regularization
180    fn value(&self) -> NDArray<f64> {
181        self.output.borrow().val().clone()
182    }
183
184    /// Get gradient of L1 regularization
185    fn grad(&self) -> NDArray<f64> {
186        self.gradient.borrow().val().clone()
187    }
188
189    /// Set gradient of L1 regularization
190    fn set_grad(&mut self, upstream_gradient: NDArray<f64>) {
191        self.gradient = Value::new(&upstream_gradient).into();
192    } 
193}