// optirs_core/regularizers/weight_standardization.rs

use scirs2_core::ndarray::{Array, Array2, Array4, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Weight Standardization (Qiao et al., 2019).
///
/// Re-parameterizes weights so that the weights feeding each output unit have
/// zero mean and unit variance; often combined with group normalization in
/// micro-batch training.
#[derive(Debug, Clone)]
pub struct WeightStandardization<A: Float> {
    /// Small constant added to the variance for numerical stability.
    eps: A,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> WeightStandardization<A> {
    /// Creates a new weight standardizer with the given stability epsilon
    /// (e.g. `1e-5`).
    pub fn new(eps: f64) -> Self {
        Self {
            eps: A::from_f64(eps).unwrap(),
        }
    }

    /// Standardizes a 2D weight matrix so that each row (one output unit)
    /// has zero mean and unit variance.
    pub fn standardize(&self, weights: &Array2<A>) -> Result<Array2<A>> {
        let n_cols = weights.ncols();
        let n_cols_f = A::from_usize(n_cols).unwrap();

        // Per-row means.
        let means = weights.sum_axis(scirs2_core::ndarray::Axis(1)) / n_cols_f;

        // Center each row around its mean.
        let mut centered = weights.clone();
        for i in 0..weights.nrows() {
            for j in 0..weights.ncols() {
                centered[[i, j]] = centered[[i, j]] - means[i];
            }
        }

        // Per-row variances of the centered weights.
        let mut var = Array::zeros(weights.nrows());
        for i in 0..weights.nrows() {
            let mut sum_sq = A::zero();
            for j in 0..weights.ncols() {
                sum_sq = sum_sq + centered[[i, j]] * centered[[i, j]];
            }
            var[i] = sum_sq / n_cols_f;
        }

        // Divide each row by its stabilized standard deviation.
        let mut standardized = centered.clone();
        for i in 0..weights.nrows() {
            let denom = (var[i] + self.eps).sqrt();
            for j in 0..weights.ncols() {
                standardized[[i, j]] = centered[[i, j]] / denom;
            }
        }

        Ok(standardized)
    }
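
    // Illustrative usage sketch (the bindings below are hypothetical):
    //
    //     let ws = WeightStandardization::<f64>::new(1e-5);
    //     let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    //     let standardized = ws.standardize(&weights)?;
    //     // Each row now has ~zero mean and variance ~1 (slightly below 1
    //     // because of `eps` in the denominator).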

    /// Standardizes 4D convolutional weights laid out as
    /// `(out_channels, in_channels, kernel_h, kernel_w)` so that the filter
    /// of each output channel has zero mean and unit variance.
    pub fn standardize_conv4d(&self, weights: &Array4<A>) -> Result<Array4<A>> {
        // `Array4` guarantees a 4D shape, so no runtime rank check is needed.
        let shape = weights.shape();
        let out_channels = shape[0];
        let in_channels = shape[1];
        let kernel_h = shape[2];
        let kernel_w = shape[3];
        let n_elements = in_channels * kernel_h * kernel_w;
        let n_elements_f = A::from_usize(n_elements).unwrap();

        // Per-output-channel means.
        let mut means = Array::zeros(out_channels);
        for c_out in 0..out_channels {
            let mut sum = A::zero();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        sum = sum + weights[[c_out, c_in, h, w]];
                    }
                }
            }
            means[c_out] = sum / n_elements_f;
        }

        // Center each filter around its channel mean.
        let mut centered = weights.clone();
        for c_out in 0..out_channels {
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        centered[[c_out, c_in, h, w]] =
                            weights[[c_out, c_in, h, w]] - means[c_out];
                    }
                }
            }
        }

        // Per-output-channel variances.
        let mut vars = Array::zeros(out_channels);
        for c_out in 0..out_channels {
            let mut sum_sq = A::zero();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        sum_sq =
                            sum_sq + centered[[c_out, c_in, h, w]] * centered[[c_out, c_in, h, w]];
                    }
                }
            }
            vars[c_out] = sum_sq / n_elements_f;
        }

        // Divide each filter by its stabilized standard deviation.
        let mut standardized = centered.clone();
        for c_out in 0..out_channels {
            let std_dev = (vars[c_out] + self.eps).sqrt();
            for c_in in 0..in_channels {
                for h in 0..kernel_h {
                    for w in 0..kernel_w {
                        standardized[[c_out, c_in, h, w]] =
                            centered[[c_out, c_in, h, w]] / std_dev;
                    }
                }
            }
        }

        Ok(standardized)
    }
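
    // Illustrative usage sketch (shape values are hypothetical):
    //
    //     let filters = Array4::<f64>::from_elem((64, 3, 3, 3), 0.5);
    //     let standardized = ws.standardize_conv4d(&filters)?;
    //     // Each of the 64 filters is standardized over its 3 * 3 * 3 elements.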

    /// Approximates the gradient of the standardization transform by forward
    /// finite differences: each weight is perturbed by a small epsilon, the
    /// matrix is re-standardized, and the induced change is contracted with
    /// `grad_output`.
    ///
    /// This costs one `standardize` pass per weight, so it is only suitable
    /// for small matrices; an analytic backward pass would be cheaper.
    fn compute_gradients<S1, S2>(
        &self,
        weights: &ArrayBase<S1, scirs2_core::ndarray::Ix2>,
        grad_output: &ArrayBase<S2, scirs2_core::ndarray::Ix2>,
    ) -> Result<Array2<A>>
    where
        S1: Data<Elem = A>,
        S2: Data<Elem = A>,
    {
        let weights = weights.to_owned();
        let grad_output = grad_output.to_owned();

        let n_rows = weights.nrows();
        let n_cols = weights.ncols();
        let epsilon = A::from_f64(1e-6).unwrap();

        let mut gradients = Array2::zeros((n_rows, n_cols));
        let standardized = self.standardize(&weights)?;

        for i in 0..n_rows {
            for j in 0..n_cols {
                // Perturb a single weight and re-standardize.
                let mut weights_plus = weights.clone();
                weights_plus[[i, j]] = weights_plus[[i, j]] + epsilon;
                let standardized_plus = self.standardize(&weights_plus)?;

                // Contract the directional derivative with the upstream gradient.
                let diff = &standardized_plus - &standardized;
                let mut grad_sum = A::zero();
                for r in 0..n_rows {
                    for c in 0..n_cols {
                        grad_sum = grad_sum + diff[[r, c]] * grad_output[[r, c]];
                    }
                }

                gradients[[i, j]] = grad_sum / epsilon;
            }
        }

        Ok(gradients)
    }
}

impl<
        A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dimension + Send + Sync,
    > Regularizer<A, D> for WeightStandardization<A>
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Standardization is only defined here for 2D weight matrices; other
        // shapes pass through unchanged.
        if params.ndim() != 2 {
            return Ok(A::zero());
        }

        let params_2d = params
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        let gradients_2d = gradients
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        let corrections = self.compute_gradients(&params_2d, &gradients_2d)?;

        let mut grad_mut = gradients
            .view_mut()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        // Fold the standardization correction into the existing gradients.
        grad_mut.zip_mut_with(&corrections, |g, &c| *g = *g + c);

        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Weight standardization reshapes gradients rather than adding a loss term.
        Ok(A::zero())
    }
}
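
// Illustrative usage sketch (assumes 2D `params` and `gradients` owned by the
// caller): as a `Regularizer`, weight standardization corrects gradients in
// place and never contributes a loss penalty.
//
//     let ws = WeightStandardization::<f64>::new(1e-5);
//     let penalty = ws.apply(&params, &mut gradients)?; // always Ok(0.0)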

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weight_standardization_creation() {
        let ws = WeightStandardization::<f64>::new(1e-5);
        assert_eq!(ws.eps, 1e-5);
    }

    #[test]
    fn test_standardize_2d() {
        let ws = WeightStandardization::new(1e-5);

        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let standardized = ws.standardize(&weights).unwrap();

        assert_eq!(standardized.shape(), weights.shape());

        // Each row should have zero mean after standardization.
        let mean1 = standardized.row(0).sum() / 3.0;
        let mean2 = standardized.row(1).sum() / 3.0;

        assert_relative_eq!(mean1, 0.0, epsilon = 1e-10);
        assert_relative_eq!(mean2, 0.0, epsilon = 1e-10);

        // Each row should have approximately unit variance; `eps` in the
        // denominator keeps it slightly below 1.
        let var1 = standardized.row(0).mapv(|x| x * x).sum() / 3.0;
        let var2 = standardized.row(1).mapv(|x| x * x).sum() / 3.0;

        assert!((var1 - 1.0).abs() < 2e-4);
        assert!((var2 - 1.0).abs() < 2e-4);
    }

    #[test]
    fn test_standardize_conv4d() {
        let ws = WeightStandardization::new(1e-5);

        let weights = Array4::from_shape_fn((2, 2, 2, 2), |(a, b, c, d)| {
            (a * 8 + b * 4 + c * 2 + d) as f64
        });

        let standardized = ws.standardize_conv4d(&weights).unwrap();

        assert_eq!(standardized.shape(), weights.shape());

        // Each output channel should have zero mean over its 8 elements.
        let mut sum1 = 0.0;
        let mut sum2 = 0.0;
        for c_in in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    sum1 += standardized[[0, c_in, h, w]];
                    sum2 += standardized[[1, c_in, h, w]];
                }
            }
        }

        let mean1 = sum1 / 8.0;
        let mean2 = sum2 / 8.0;

        assert_relative_eq!(mean1, 0.0, epsilon = 1e-10);
        assert_relative_eq!(mean2, 0.0, epsilon = 1e-10);

        // ...and approximately unit variance.
        let mut sum_sq1 = 0.0;
        let mut sum_sq2 = 0.0;
        for c_in in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    sum_sq1 += standardized[[0, c_in, h, w]] * standardized[[0, c_in, h, w]];
                    sum_sq2 += standardized[[1, c_in, h, w]] * standardized[[1, c_in, h, w]];
                }
            }
        }

        let var1 = sum_sq1 / 8.0;
        let var2 = sum_sq2 / 8.0;

        assert!((var1 - 1.0).abs() < 1e-5);
        assert!((var2 - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_regularizer_trait() {
        let ws = WeightStandardization::new(1e-5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let orig_gradients = gradients.clone();

        let penalty = ws.apply(&params, &mut gradients).unwrap();

        // No loss penalty is contributed...
        assert_eq!(penalty, 0.0);

        // ...but the gradients are corrected in place.
        assert_ne!(gradients, orig_gradients);
    }
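
    // Illustrative sketch (a test we add here as an assumption about intended
    // behavior): non-2D parameters should pass through `apply` untouched.
    #[test]
    fn test_apply_non_2d_is_noop() {
        let ws = WeightStandardization::new(1e-5);
        let params = array![1.0, 2.0, 3.0];
        let mut gradients = array![0.1, 0.2, 0.3];
        let orig = gradients.clone();

        let penalty = ws.apply(&params, &mut gradients).unwrap();
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, orig);
    }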
}