optirs_core/regularizers/
spectral_norm.rs1use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::Rng;
10use scirs2_core::Random;
11use std::cell::RefCell;
12use std::fmt::Debug;
13
14use crate::error::{OptimError, Result};
15use crate::regularizers::Regularizer;
16
17#[derive(Debug)]
35pub struct SpectralNorm<A: Float> {
36 n_power_iterations: usize,
38 eps: A,
40 u: RefCell<Option<Array<A, scirs2_core::ndarray::Ix1>>>,
42 v: RefCell<Option<Array<A, scirs2_core::ndarray::Ix1>>>,
44 rng: RefCell<Random<scirs2_core::random::rngs::StdRng>>,
46}
47
48impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> SpectralNorm<A> {
49 pub fn new(n_poweriterations: usize) -> Self {
55 Self {
56 n_power_iterations: n_poweriterations,
57 eps: A::from_f64(1e-12).unwrap_or_else(|| A::epsilon()),
58 u: RefCell::new(None),
59 v: RefCell::new(None),
60 rng: RefCell::new(Random::seed(42)),
61 }
62 }
63
64 fn compute_spectral_norm(&self, weights: &Array2<A>) -> Result<A> {
66 let (m, n) = (weights.nrows(), weights.ncols());
67
68 {
70 let u_ref = self.u.borrow();
71 let needs_init = u_ref.is_none() || u_ref.as_ref().is_none_or(|arr| arr.len() != m);
72 drop(u_ref);
73 if needs_init {
74 let mut rng = self.rng.borrow_mut();
75 let new_u = Array::from_shape_fn((m,), |_| {
76 let val: f64 = rng.gen_range(0.0..1.0);
77 A::from_f64(val).unwrap_or_else(|| A::one())
78 });
79 *self.u.borrow_mut() = Some(new_u);
80 }
81 }
82
83 {
84 let v_ref = self.v.borrow();
85 let needs_init = v_ref.is_none() || v_ref.as_ref().is_none_or(|arr| arr.len() != n);
86 drop(v_ref);
87 if needs_init {
88 let mut rng = self.rng.borrow_mut();
89 let new_v = Array::from_shape_fn((n,), |_| {
90 let val: f64 = rng.gen_range(0.0..1.0);
91 A::from_f64(val).unwrap_or_else(|| A::one())
92 });
93 *self.v.borrow_mut() = Some(new_v);
94 }
95 }
96
97 let mut u = self
98 .u
99 .borrow()
100 .as_ref()
101 .ok_or_else(|| {
102 OptimError::InvalidParameter("Left singular vector not initialized".to_string())
103 })?
104 .clone();
105 let mut v = self
106 .v
107 .borrow()
108 .as_ref()
109 .ok_or_else(|| {
110 OptimError::InvalidParameter("Right singular vector not initialized".to_string())
111 })?
112 .clone();
113
114 for _ in 0..self.n_power_iterations {
116 let wt_u = weights.t().dot(&u);
118 let v_norm = (wt_u.dot(&wt_u) + self.eps).sqrt();
119 v = wt_u / v_norm;
120
121 let w_v = weights.dot(&v);
123 let u_norm = (w_v.dot(&w_v) + self.eps).sqrt();
124 u = w_v / u_norm;
125 }
126
127 *self.u.borrow_mut() = Some(u.clone());
129 *self.v.borrow_mut() = Some(v.clone());
130
131 let w_v = weights.dot(&v);
133 let spectral_norm = u.dot(&w_v);
134
135 Ok(spectral_norm)
136 }
137
138 pub fn normalize(&self, weights: &Array2<A>) -> Result<Array2<A>> {
140 let spectral_norm = self.compute_spectral_norm(weights)?;
141
142 if spectral_norm > self.eps {
143 Ok(weights / spectral_norm)
144 } else {
145 Ok(weights.clone())
146 }
147 }
148
149 pub fn normalize_conv4d(&self, weights: &Array4<A>) -> Result<Array4<A>> {
151 let shape = weights.shape();
153 let out_channels = shape[0];
154 let in_channels = shape[1];
155 let kernel_h = shape[2];
156 let kernel_w = shape[3];
157
158 let weights_2d = weights
159 .to_shape((out_channels, in_channels * kernel_h * kernel_w))
160 .map_err(|e| OptimError::InvalidConfig(format!("Cannot reshape weights: {}", e)))?;
161 let weights_2d_owned = weights_2d.to_owned();
162 let normalized_2d = self.normalize(&weights_2d_owned)?;
163
164 let normalized_4d = normalized_2d
166 .to_shape((out_channels, in_channels, kernel_h, kernel_w))
167 .map_err(|e| {
168 OptimError::InvalidConfig(format!("Cannot reshape normalized weights: {}", e))
169 })?;
170 Ok(normalized_4d.to_owned())
171 }
172}
173
174impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
176 for SpectralNorm<A>
177{
178 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
179 Ok(A::zero())
182 }
183
184 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
185 Ok(A::zero())
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The constructor stores the requested iteration count.
    #[test]
    fn test_spectral_norm_creation() {
        let sn = SpectralNorm::<f64>::new(5);
        assert_eq!(sn.n_power_iterations, 5);
    }

    /// A diagonal matrix has its largest absolute diagonal entry as spectral norm.
    #[test]
    fn test_spectral_norm_2d() {
        let sn = SpectralNorm::new(10);

        let weights = array![[1.0, 0.0], [0.0, 2.0]];

        let spectral_norm = sn
            .compute_spectral_norm(&weights)
            .expect("test: compute_spectral_norm failed");

        // Power iteration is approximate, hence the loose tolerance.
        assert_relative_eq!(spectral_norm, 2.0, epsilon = 0.1);
    }

    /// After normalization the spectral norm should be approximately one.
    #[test]
    fn test_normalize_2d() {
        let sn = SpectralNorm::new(10);

        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let normalized = sn.normalize(&weights).expect("test: normalize failed");

        let spec_norm = sn
            .compute_spectral_norm(&normalized)
            .expect("test: compute_spectral_norm failed");
        assert_relative_eq!(spec_norm, 1.0, epsilon = 0.1);
    }

    /// Normalizing a 4-D tensor preserves its shape.
    #[test]
    fn test_conv4d_normalization() {
        let sn = SpectralNorm::new(5);

        let weights = Array::from_shape_fn((2, 3, 3, 3), |(o, i, h, w)| {
            (o * 27 + i * 9 + h * 3 + w) as f64
        });

        let normalized = sn
            .normalize_conv4d(&weights)
            .expect("test: normalize_conv4d failed");

        assert_eq!(normalized.shape(), weights.shape());
    }

    /// An all-zero 4-D tensor must be handled without returning an error.
    #[test]
    fn test_invalid_conv4d() {
        let sn = SpectralNorm::<f64>::new(5);

        let weights = Array::zeros((2, 3, 4, 4));

        assert!(sn.normalize_conv4d(&weights).is_ok());
    }

    /// The `Regularizer` implementation is a no-op: zero penalty, gradients untouched.
    #[test]
    fn test_regularizer_trait() {
        let sn = SpectralNorm::new(5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        let penalty = sn.penalty(&params).expect("test: penalty failed");
        assert_eq!(penalty, 0.0);

        let apply_result = sn
            .apply(&params, &mut gradient)
            .expect("test: apply failed");
        assert_eq!(apply_result, 0.0);

        assert_eq!(gradient, array![[0.1, 0.2], [0.3, 0.4]]);
    }
}