optirs_core/regularizers/weight_standardization.rs

// Weight Standardization
//
// Weight Standardization is a technique that normalizes the weights of convolutional
// layers by standardizing them along the channel dimension. This improves training
// stability, particularly in small-batch training when combined with activation
// normalization such as Group Normalization.

use scirs2_core::ndarray::{Array, Array2, Array4, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// Weight Standardization
///
/// Weight Standardization normalizes the weights along the channel dimension by
/// adjusting them to have zero mean and unit variance. This helps with training
/// stability, especially when combined with normalization layers such as Group
/// Normalization or Batch Normalization.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::{WeightStandardization, Regularizer};
///
/// let weight_std = WeightStandardization::new(1e-5);
/// let weights = array![[1.0, 2.0], [3.0, 4.0]];
/// let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
///
/// // Get standardized weights
/// let standardized = weight_std.standardize(&weights).unwrap();
///
/// // Apply during training (modifies gradients)
/// let _ = weight_std.apply(&weights, &mut gradients);
/// ```
#[derive(Debug, Clone)]
pub struct WeightStandardization<A: Float> {
    /// Small constant for numerical stability
    eps: A,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> WeightStandardization<A> {
    /// Create a new Weight Standardization regularizer
    ///
    /// # Arguments
    ///
    /// * `eps` - Small constant for numerical stability (typically 1e-5)
    pub fn new(eps: f64) -> Self {
        Self {
            eps: A::from_f64(eps).unwrap(),
        }
    }

    /// Apply weight standardization to a 2D weight matrix
    ///
    /// Standardizes each row to have zero mean and unit variance.
    ///
    /// # Arguments
    ///
    /// * `weights` - 2D weight matrix
    ///
    /// # Returns
    ///
    /// Standardized weights with zero mean and unit variance per row
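    ///
    /// # Example
    ///
    /// A minimal usage sketch, following the struct-level example above for the
    /// import paths.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::WeightStandardization;
    ///
    /// let ws = WeightStandardization::new(1e-5);
    /// let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    /// let standardized = ws.standardize(&weights).unwrap();
    ///
    /// // Each row now has approximately zero mean and unit variance.
    /// assert!(standardized.row(0).sum().abs() < 1e-10);
    /// ```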
    pub fn standardize(&self, weights: &Array2<A>) -> Result<Array2<A>> {
        // Calculate mean for each row (output channel)
        let n_cols = weights.ncols();
        let n_cols_f = A::from_usize(n_cols).unwrap();

        // Calculate mean, subtract from weights, then calculate variance and normalize
        let means = weights.sum_axis(scirs2_core::ndarray::Axis(1)) / n_cols_f;

        // Subtract mean from weights
        let mut centered = weights.clone();
        for i in 0..weights.nrows() {
            for j in 0..weights.ncols() {
                centered[[i, j]] = centered[[i, j]] - means[i];
            }
        }

        // Calculate variance
        let mut var = Array::zeros(weights.nrows());
        for i in 0..weights.nrows() {
            let mut sum_sq = A::zero();
            for j in 0..weights.ncols() {
                sum_sq = sum_sq + centered[[i, j]] * centered[[i, j]];
            }
            var[i] = sum_sq / n_cols_f;
        }

        // Normalize
        let mut standardized = centered.clone();
        for i in 0..weights.nrows() {
            let denom = (var[i] + self.eps).sqrt();
            for j in 0..weights.ncols() {
                standardized[[i, j]] = centered[[i, j]] / denom;
            }
        }

        Ok(standardized)
    }

    /// Apply weight standardization to 4D convolutional weights
    ///
    /// # Arguments
    ///
    /// * `weights` - Convolutional weights with shape [out_channels, in_channels, height, width]
    ///
    /// # Returns
    ///
    /// Standardized convolutional weights
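    ///
    /// # Example
    ///
    /// A minimal sketch for a small kernel tensor; it assumes the same crate
    /// paths as the struct-level example and the layout documented above.
    ///
    /// ```
    /// use scirs2_core::ndarray::Array4;
    /// use optirs_core::regularizers::WeightStandardization;
    ///
    /// let ws = WeightStandardization::new(1e-5);
    /// // 2 output channels, 2 input channels, 2x2 kernels
    /// let weights = Array4::from_shape_fn((2, 2, 2, 2), |(o, i, h, w)| {
    ///     (o * 8 + i * 4 + h * 2 + w) as f64
    /// });
    /// let standardized = ws.standardize_conv4d(&weights).unwrap();
    ///
    /// // The shape is preserved; each output-channel slice is standardized.
    /// assert_eq!(standardized.shape(), weights.shape());
    /// ```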
    pub fn standardize_conv4d(&self, weights: &Array4<A>) -> Result<Array4<A>> {
        let shape = weights.shape();
        if shape.len() != 4 {
            return Err(OptimError::InvalidConfig(
                "Expected 4D weights for conv4d standardization".to_string(),
            ));
        }

        let out_channels = shape[0];
        let in_channels = shape[1];
        let kernel_h = shape[2];
        let kernel_w = shape[3];
        let n_elements = in_channels * kernel_h * kernel_w;
        let n_elements_f = A::from_usize(n_elements).unwrap();

        // Calculate mean for each output channel
        let mut means = Array::zeros(out_channels);

        for c_out in 0..out_channels {
            let mut sum = A::zero();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        sum = sum + weights[[c_out, c_in, h, w]];
                    }
                }
            }
            means[c_out] = sum / n_elements_f;
        }

        // Center the weights
        let mut centered = weights.clone();

        for c_out in 0..out_channels {
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        centered[[c_out, c_in, h, w]] = weights[[c_out, c_in, h, w]] - means[c_out];
                    }
                }
            }
        }

        // Calculate variance for each output channel
        let mut vars = Array::zeros(out_channels);

        for c_out in 0..out_channels {
            let mut sum_sq = A::zero();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        sum_sq =
                            sum_sq + centered[[c_out, c_in, h, w]] * centered[[c_out, c_in, h, w]];
                    }
                }
            }
            vars[c_out] = sum_sq / n_elements_f;
        }

        // Standardize
        let mut standardized = centered.clone();

        for c_out in 0..out_channels {
            let std_dev = (vars[c_out] + self.eps).sqrt();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        standardized[[c_out, c_in, h, w]] = centered[[c_out, c_in, h, w]] / std_dev;
                    }
                }
            }
        }

        Ok(standardized)
    }

    /// Calculate the gradients of weight standardization
    ///
    /// # Arguments
    ///
    /// * `weights` - Original weights
    /// * `grad_output` - Gradient from the next layer
    ///
    /// # Returns
    ///
    /// The gradient for the weights
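    ///
    /// # Note
    ///
    /// The gradient is approximated numerically with a forward difference: each
    /// entry is `sum((standardize(W + epsilon * E_ij) - standardize(W)) * grad_output) / epsilon`
    /// for a small step `epsilon`, where `E_ij` is the matrix with a one at
    /// position `(i, j)` and zeros elsewhere, and the product and sum run
    /// element-wise over the whole matrix.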
    fn compute_gradients<S1, S2>(
        &self,
        weights: &ArrayBase<S1, scirs2_core::ndarray::Ix2>,
        grad_output: &ArrayBase<S2, scirs2_core::ndarray::Ix2>,
    ) -> Result<Array2<A>>
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        // For simplicity, we're implementing a numerical approximation of the gradient.
        // In a real-world scenario, you would implement the analytical gradient.

        // Convert views to owned arrays to ensure we can modify them
        let weights = weights.to_owned();
        let grad_output = grad_output.to_owned();

        let n_rows = weights.nrows();
        let n_cols = weights.ncols();
        let epsilon = A::from_f64(1e-6).unwrap();

        let mut gradients = Array2::zeros((n_rows, n_cols));
        let standardized = self.standardize(&weights)?;

        // Numerical gradient approximation
        for i in 0..n_rows {
            for j in 0..n_cols {
                let mut weights_plus = weights.clone();
                weights_plus[[i, j]] = weights_plus[[i, j]] + epsilon;

                let standardized_plus = self.standardize(&weights_plus)?;

                // Approximate the gradient with a forward difference
                let diff = &standardized_plus - &standardized;

                // Element-wise multiplication with grad_output and sum
                let mut grad_sum = A::zero();
                for r in 0..n_rows {
                    for c in 0..n_cols {
                        grad_sum = grad_sum + diff[[r, c]] * grad_output[[r, c]];
                    }
                }

                gradients[[i, j]] = grad_sum / epsilon;
            }
        }

        Ok(gradients)
    }
}

impl<
        A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dimension + Send + Sync,
    > Regularizer<A, D> for WeightStandardization<A>
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Check if we have 2D parameters
        if params.ndim() != 2 {
            // For simplicity, only handle 2D weights for gradient computation
            // In practice, you would also handle 4D conv weights
            return Ok(A::zero());
        }

        // Downcast to 2D
        let params_2d = params
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        let gradients_2d = gradients
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        // Compute the gradient corrections
        let corrections = self.compute_gradients(&params_2d, &gradients_2d)?;

        // Apply the corrections to the gradients
        let mut grad_mut = gradients
            .view_mut()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        // Add the corrections to the gradients
        grad_mut.zip_mut_with(&corrections, |g, &c| *g = *g + c);

        // Weight standardization doesn't add a penalty term
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Weight standardization doesn't add a penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weight_standardization_creation() {
        let ws = WeightStandardization::<f64>::new(1e-5);
        assert_eq!(ws.eps, 1e-5);
    }

    #[test]
    fn test_standardize_2d() {
        let ws = WeightStandardization::new(1e-5);

        // Create a simple 2D weight matrix
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let standardized = ws.standardize(&weights).unwrap();

        // Check shape is preserved
        assert_eq!(standardized.shape(), weights.shape());

        // Check means are close to zero
        let mean1 = standardized.row(0).sum() / 3.0;
        let mean2 = standardized.row(1).sum() / 3.0;

        assert_relative_eq!(mean1, 0.0, epsilon = 1e-10);
        assert_relative_eq!(mean2, 0.0, epsilon = 1e-10);

        // Check variances are close to 1
        let var1 = standardized.row(0).mapv(|x| x * x).sum() / 3.0;
        let var2 = standardized.row(1).mapv(|x| x * x).sum() / 3.0;

        println!("var1 = {}, var2 = {}", var1, var2);

        // The eps added to the variance before normalizing makes these slightly below 1.0
        assert!((var1 - 1.0).abs() < 2e-4);
        assert!((var2 - 1.0).abs() < 2e-4);
    }

    #[test]
    fn test_standardize_conv4d() {
        let ws = WeightStandardization::new(1e-5);

        // Create a simple 4D convolutional weight tensor
        let weights = Array4::from_shape_fn((2, 2, 2, 2), |idx| {
            let (a, b, c, d) = (idx.0, idx.1, idx.2, idx.3);
            (a * 8 + b * 4 + c * 2 + d) as f64
        });

        let standardized = ws.standardize_conv4d(&weights).unwrap();

        // Check shape is preserved
        assert_eq!(standardized.shape(), weights.shape());

        // Check means are close to zero for each output channel
        let mut sum1 = 0.0;
        let mut sum2 = 0.0;

        for c_in in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    sum1 += standardized[[0, c_in, h, w]];
                    sum2 += standardized[[1, c_in, h, w]];
                }
            }
        }

        let mean1 = sum1 / 8.0;
        let mean2 = sum2 / 8.0;

        assert_relative_eq!(mean1, 0.0, epsilon = 1e-10);
        assert_relative_eq!(mean2, 0.0, epsilon = 1e-10);

        // Check variances are close to 1 for each output channel (eps makes them slightly smaller)
        let mut sum_sq1 = 0.0;
        let mut sum_sq2 = 0.0;

        for c_in in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    sum_sq1 += standardized[[0, c_in, h, w]] * standardized[[0, c_in, h, w]];
                    sum_sq2 += standardized[[1, c_in, h, w]] * standardized[[1, c_in, h, w]];
                }
            }
        }

        let var1 = sum_sq1 / 8.0;
        let var2 = sum_sq2 / 8.0;

        assert!((var1 - 1.0).abs() < 1e-5);
        assert!((var2 - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_regularizer_trait() {
        let ws = WeightStandardization::new(1e-5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let orig_gradients = gradients.clone();

        let penalty = ws.apply(&params, &mut gradients).unwrap();

        // Penalty should be zero
        assert_eq!(penalty, 0.0);

        // Gradients should be modified
        assert_ne!(gradients, orig_gradients);
    }
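
    // Additional checks that follow directly from the implementation above:
    // the penalty term is always zero, a constant row standardizes to zeros
    // (eps keeps the division finite), and non-2D parameters pass through
    // apply() untouched.
    #[test]
    fn test_penalty_is_zero() {
        let ws = WeightStandardization::new(1e-5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let penalty = ws.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn test_standardize_constant_row() {
        let ws = WeightStandardization::new(1e-5);
        // The first row has zero variance; eps keeps the denominator finite, so
        // the standardized entries collapse to zero instead of producing NaN.
        let weights = array![[2.0, 2.0, 2.0], [1.0, 2.0, 3.0]];
        let standardized = ws.standardize(&weights).unwrap();
        for &v in standardized.row(0).iter() {
            assert!(v.abs() < 1e-12);
        }
    }

    #[test]
    fn test_apply_non_2d_is_noop() {
        let ws = WeightStandardization::new(1e-5);
        // apply() only corrects 2D weights; 1D parameters are left unchanged.
        let params = array![1.0, 2.0, 3.0];
        let mut gradients = array![0.1, 0.2, 0.3];
        let orig_gradients = gradients.clone();
        let penalty = ws.apply(&params, &mut gradients).unwrap();
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, orig_gradients);
    }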
}