// optirs_core/regularizers/spectral_norm.rs
// Spectral normalization regularization
//
// Spectral normalization is a weight normalization technique that controls the
// Lipschitz constant of the neural network by normalizing the spectral norm
// (largest singular value) of weight matrices.

use std::cell::RefCell;
use std::fmt::Debug;

use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::Rng;
use scirs2_core::Random;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

17/// Spectral normalization regularizer
18///
19/// Normalizes weight matrices by their spectral norm to ensure the Lipschitz
20/// constant is bounded. This helps with training stability and generalization.
21///
22/// # Example
23///
24/// ```no_run
25/// use scirs2_core::ndarray::array;
26/// use optirs_core::regularizers::SpectralNorm;
27///
28/// let spec_norm = SpectralNorm::new(1);
29/// let weights = array![[1.0, 2.0], [3.0, 4.0]];
30///
31/// // Normalize weights by spectral norm
32/// let normalized_weights = spec_norm.normalize(&weights).expect("normalize failed");
33/// ```
34#[derive(Debug)]
35pub struct SpectralNorm<A: Float> {
36    /// Number of power iterations for SVD approximation
37    n_power_iterations: usize,
38    /// Epsilon for numerical stability
39    eps: A,
40    /// Cached left singular vector
41    u: RefCell<Option<Array<A, scirs2_core::ndarray::Ix1>>>,
42    /// Cached right singular vector
43    v: RefCell<Option<Array<A, scirs2_core::ndarray::Ix1>>>,
44    /// Random number generator
45    rng: RefCell<Random<scirs2_core::random::rngs::StdRng>>,
46}
47
48impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> SpectralNorm<A> {
49    /// Create a new spectral normalization regularizer
50    ///
51    /// # Arguments
52    ///
53    /// * `n_power_iterations` - Number of power iterations for SVD approximation
54    pub fn new(n_poweriterations: usize) -> Self {
55        Self {
56            n_power_iterations: n_poweriterations,
57            eps: A::from_f64(1e-12).unwrap_or_else(|| A::epsilon()),
58            u: RefCell::new(None),
59            v: RefCell::new(None),
60            rng: RefCell::new(Random::seed(42)),
61        }
62    }
63
64    /// Compute the spectral norm (largest singular value) using power iteration
65    fn compute_spectral_norm(&self, weights: &Array2<A>) -> Result<A> {
66        let (m, n) = (weights.nrows(), weights.ncols());
67
68        // Initialize u and v if not already done
69        {
70            let u_ref = self.u.borrow();
71            let needs_init = u_ref.is_none() || u_ref.as_ref().is_none_or(|arr| arr.len() != m);
72            drop(u_ref);
73            if needs_init {
74                let mut rng = self.rng.borrow_mut();
75                let new_u = Array::from_shape_fn((m,), |_| {
76                    let val: f64 = rng.gen_range(0.0..1.0);
77                    A::from_f64(val).unwrap_or_else(|| A::one())
78                });
79                *self.u.borrow_mut() = Some(new_u);
80            }
81        }
82
83        {
84            let v_ref = self.v.borrow();
85            let needs_init = v_ref.is_none() || v_ref.as_ref().is_none_or(|arr| arr.len() != n);
86            drop(v_ref);
87            if needs_init {
88                let mut rng = self.rng.borrow_mut();
89                let new_v = Array::from_shape_fn((n,), |_| {
90                    let val: f64 = rng.gen_range(0.0..1.0);
91                    A::from_f64(val).unwrap_or_else(|| A::one())
92                });
93                *self.v.borrow_mut() = Some(new_v);
94            }
95        }
96
97        let mut u = self
98            .u
99            .borrow()
100            .as_ref()
101            .ok_or_else(|| {
102                OptimError::InvalidParameter("Left singular vector not initialized".to_string())
103            })?
104            .clone();
105        let mut v = self
106            .v
107            .borrow()
108            .as_ref()
109            .ok_or_else(|| {
110                OptimError::InvalidParameter("Right singular vector not initialized".to_string())
111            })?
112            .clone();
113
114        // Power iteration
115        for _ in 0..self.n_power_iterations {
116            // v = W^T u / ||W^T u||
117            let wt_u = weights.t().dot(&u);
118            let v_norm = (wt_u.dot(&wt_u) + self.eps).sqrt();
119            v = wt_u / v_norm;
120
121            // u = W v / ||W v||
122            let w_v = weights.dot(&v);
123            let u_norm = (w_v.dot(&w_v) + self.eps).sqrt();
124            u = w_v / u_norm;
125        }
126
127        // Update cached vectors
128        *self.u.borrow_mut() = Some(u.clone());
129        *self.v.borrow_mut() = Some(v.clone());
130
131        // Compute spectral norm as u^T W v
132        let w_v = weights.dot(&v);
133        let spectral_norm = u.dot(&w_v);
134
135        Ok(spectral_norm)
136    }
137
138    /// Normalize weights by spectral norm
139    pub fn normalize(&self, weights: &Array2<A>) -> Result<Array2<A>> {
140        let spectral_norm = self.compute_spectral_norm(weights)?;
141
142        if spectral_norm > self.eps {
143            Ok(weights / spectral_norm)
144        } else {
145            Ok(weights.clone())
146        }
147    }
148
149    /// Apply spectral normalization to 4D convolutional weights
150    pub fn normalize_conv4d(&self, weights: &Array4<A>) -> Result<Array4<A>> {
151        // Reshape to 2D for spectral norm computation
152        let shape = weights.shape();
153        let out_channels = shape[0];
154        let in_channels = shape[1];
155        let kernel_h = shape[2];
156        let kernel_w = shape[3];
157
158        let weights_2d = weights
159            .to_shape((out_channels, in_channels * kernel_h * kernel_w))
160            .map_err(|e| OptimError::InvalidConfig(format!("Cannot reshape weights: {}", e)))?;
161        let weights_2d_owned = weights_2d.to_owned();
162        let normalized_2d = self.normalize(&weights_2d_owned)?;
163
164        // Reshape back to 4D
165        let normalized_4d = normalized_2d
166            .to_shape((out_channels, in_channels, kernel_h, kernel_w))
167            .map_err(|e| {
168                OptimError::InvalidConfig(format!("Cannot reshape normalized weights: {}", e))
169            })?;
170        Ok(normalized_4d.to_owned())
171    }
172}
173
174// Implement Regularizer trait
175impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
176    for SpectralNorm<A>
177{
178    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
179        // For spectral normalization, we don't modify _gradients directly
180        // Instead, the normalization is typically applied during the forward pass
181        Ok(A::zero())
182    }
183
184    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
185        // Spectral normalization doesn't add a penalty term
186        Ok(A::zero())
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_spectral_norm_creation() {
        let sn = SpectralNorm::<f64>::new(5);
        assert_eq!(sn.n_power_iterations, 5);
    }

    #[test]
    fn test_spectral_norm_2d() {
        let sn = SpectralNorm::new(10);

        // Create a simple matrix with known singular values
        // For a 2x2 matrix [[1, 0], [0, 2]], the singular values are 1 and 2
        let weights = array![[1.0, 0.0], [0.0, 2.0]];

        let spectral_norm = sn
            .compute_spectral_norm(&weights)
            .expect("test: compute_spectral_norm failed");

        // The spectral norm should be close to 2.0 (largest singular value)
        assert_relative_eq!(spectral_norm, 2.0, epsilon = 0.1);
    }

    #[test]
    fn test_normalize_2d() {
        let sn = SpectralNorm::new(10);

        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let normalized = sn.normalize(&weights).expect("test: normalize failed");

        // After normalization, the spectral norm should be close to 1
        let spec_norm = sn
            .compute_spectral_norm(&normalized)
            .expect("test: compute_spectral_norm failed");
        assert_relative_eq!(spec_norm, 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_conv4d_normalization() {
        let sn = SpectralNorm::new(5);

        // Create a 4D tensor (out_channels, in_channels, height, width)
        let weights = Array::from_shape_fn((2, 3, 3, 3), |(o, i, h, w)| {
            (o * 27 + i * 9 + h * 3 + w) as f64
        });

        let normalized = sn
            .normalize_conv4d(&weights)
            .expect("test: normalize_conv4d failed");

        // Check that the shape is preserved
        assert_eq!(normalized.shape(), weights.shape());
    }

    #[test]
    fn test_conv4d_zero_weights() {
        let sn = SpectralNorm::<f64>::new(5);

        // An all-zero 4D tensor has spectral norm 0; normalization should
        // still succeed (the weights are returned unchanged rather than
        // divided by zero).
        let weights = Array::zeros((2, 3, 4, 4));

        assert!(sn.normalize_conv4d(&weights).is_ok());
    }

    #[test]
    fn test_regularizer_trait() {
        let sn = SpectralNorm::new(5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        // Spectral norm doesn't modify gradients or add penalties
        let penalty = sn.penalty(&params).expect("test: penalty failed");
        assert_eq!(penalty, 0.0);

        let apply_result = sn
            .apply(&params, &mut gradient)
            .expect("test: apply failed");
        assert_eq!(apply_result, 0.0);

        // Gradients should be unchanged
        assert_eq!(gradient, array![[0.1, 0.2], [0.3, 0.4]]);
    }
}