optirs_core/regularizers/
spectral_norm.rs1use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::Rng;
10use scirs2_core::Random;
11use std::fmt::Debug;
12
13use crate::error::{OptimError, Result};
14use crate::regularizers::Regularizer;
15
16#[derive(Debug, Clone)]
34pub struct SpectralNorm<A: Float> {
35 n_power_iterations: usize,
37 eps: A,
39 u: Option<Array<A, scirs2_core::ndarray::Ix1>>,
41 v: Option<Array<A, scirs2_core::ndarray::Ix1>>,
43 rng: Random<scirs2_core::random::rngs::StdRng>,
45}
46
47impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> SpectralNorm<A> {
48 pub fn new(n_poweriterations: usize) -> Self {
54 Self {
55 n_power_iterations: n_poweriterations,
56 eps: A::from_f64(1e-12).unwrap(),
57 u: None,
58 v: None,
59 rng: Random::seed(42),
60 }
61 }
62
63 fn compute_spectral_norm(&mut self, weights: &Array2<A>) -> Result<A> {
65 let (m, n) = (weights.nrows(), weights.ncols());
66
67 if self.u.is_none() || self.u.as_ref().unwrap().len() != m {
69 self.u = Some(Array::from_shape_fn((m,), |_| {
70 let val: f64 = self.rng.gen_range(0.0..1.0);
71 A::from_f64(val).unwrap()
72 }));
73 }
74
75 if self.v.is_none() || self.v.as_ref().unwrap().len() != n {
76 self.v = Some(Array::from_shape_fn((n,), |_| {
77 let val: f64 = self.rng.gen_range(0.0..1.0);
78 A::from_f64(val).unwrap()
79 }));
80 }
81
82 let mut u = self.u.as_ref().unwrap().clone();
83 let mut v = self.v.as_ref().unwrap().clone();
84
85 for _ in 0..self.n_power_iterations {
87 let wt_u = weights.t().dot(&u);
89 let v_norm = (wt_u.dot(&wt_u) + self.eps).sqrt();
90 v = wt_u / v_norm;
91
92 let w_v = weights.dot(&v);
94 let u_norm = (w_v.dot(&w_v) + self.eps).sqrt();
95 u = w_v / u_norm;
96 }
97
98 self.u = Some(u.clone());
100 self.v = Some(v.clone());
101
102 let w_v = weights.dot(&v);
104 let spectral_norm = u.dot(&w_v);
105
106 Ok(spectral_norm)
107 }
108
109 pub fn normalize(&mut self, weights: &Array2<A>) -> Result<Array2<A>> {
111 let spectral_norm = self.compute_spectral_norm(weights)?;
112
113 if spectral_norm > self.eps {
114 Ok(weights / spectral_norm)
115 } else {
116 Ok(weights.clone())
117 }
118 }
119
120 pub fn normalize_conv4d(&mut self, weights: &Array4<A>) -> Result<Array4<A>> {
122 let shape = weights.shape();
124 let out_channels = shape[0];
125 let in_channels = shape[1];
126 let kernel_h = shape[2];
127 let kernel_w = shape[3];
128
129 let weights_2d = weights
130 .to_shape((out_channels, in_channels * kernel_h * kernel_w))
131 .map_err(|e| OptimError::InvalidConfig(format!("Cannot reshape weights: {}", e)))?;
132 let weights_2d_owned = weights_2d.to_owned();
133 let normalized_2d = self.normalize(&weights_2d_owned)?;
134
135 let normalized_4d = normalized_2d
137 .to_shape((out_channels, in_channels, kernel_h, kernel_w))
138 .map_err(|e| {
139 OptimError::InvalidConfig(format!("Cannot reshape normalized weights: {}", e))
140 })?;
141 Ok(normalized_4d.to_owned())
142 }
143}
144
145impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
147 for SpectralNorm<A>
148{
149 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
150 Ok(A::zero())
153 }
154
155 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
156 Ok(A::zero())
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The constructor stores the requested number of power iterations.
    #[test]
    fn test_spectral_norm_creation() {
        let sn = SpectralNorm::<f64>::new(5);
        assert_eq!(sn.n_power_iterations, 5);
    }

    /// A diagonal matrix has its largest absolute diagonal entry as spectral norm.
    #[test]
    fn test_spectral_norm_2d() {
        let mut sn = SpectralNorm::new(10);

        // Singular values of diag(1, 2) are 1 and 2.
        let weights = array![[1.0, 0.0], [0.0, 2.0]];

        let spectral_norm = sn.compute_spectral_norm(&weights).unwrap();

        assert_relative_eq!(spectral_norm, 2.0, epsilon = 0.1);
    }

    /// After normalization the estimated spectral norm is approximately one.
    #[test]
    fn test_normalize_2d() {
        let mut sn = SpectralNorm::new(10);

        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let normalized = sn.normalize(&weights).unwrap();

        let spec_norm = sn.compute_spectral_norm(&normalized).unwrap();
        assert_relative_eq!(spec_norm, 1.0, epsilon = 0.1);
    }

    /// 4-D normalization preserves the tensor shape.
    #[test]
    fn test_conv4d_normalization() {
        let mut sn = SpectralNorm::new(5);

        let weights = Array::from_shape_fn((2, 3, 3, 3), |(o, i, h, w)| {
            (o * 27 + i * 9 + h * 3 + w) as f64
        });

        let normalized = sn.normalize_conv4d(&weights).unwrap();

        assert_eq!(normalized.shape(), weights.shape());
    }

    /// An all-zero kernel is accepted: its norm estimate does not exceed `eps`,
    /// so normalization succeeds without dividing.
    #[test]
    fn test_invalid_conv4d() {
        let mut sn = SpectralNorm::<f64>::new(5);

        let weights = Array::zeros((2, 3, 4, 4));

        assert!(sn.normalize_conv4d(&weights).is_ok());
    }

    /// The `Regularizer` methods return zero and leave the gradient untouched.
    #[test]
    fn test_regularizer_trait() {
        let sn = SpectralNorm::new(5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        let penalty = sn.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let apply_result = sn.apply(&params, &mut gradient).unwrap();
        assert_eq!(apply_result, 0.0);

        assert_eq!(gradient, array![[0.1, 0.2], [0.3, 0.4]]);
    }
}