optirs_core/regularizers/
spectral_norm.rs

1// Spectral normalization regularization
2//
3// Spectral normalization is a weight normalization technique that controls the
4// Lipschitz constant of the neural network by normalizing the spectral norm
5// (largest singular value) of weight matrices.
6
7use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::Rng;
10use scirs2_core::Random;
11use std::fmt::Debug;
12
13use crate::error::{OptimError, Result};
14use crate::regularizers::Regularizer;
15
16/// Spectral normalization regularizer
17///
18/// Normalizes weight matrices by their spectral norm to ensure the Lipschitz
19/// constant is bounded. This helps with training stability and generalization.
20///
21/// # Example
22///
23/// ```no_run
24/// use scirs2_core::ndarray::array;
25/// use optirs_core::regularizers::SpectralNorm;
26///
27/// let mut spec_norm = SpectralNorm::new(1);
28/// let weights = array![[1.0, 2.0], [3.0, 4.0]];
29///
30/// // Normalize weights by spectral norm
31/// let normalized_weights = spec_norm.normalize(&weights).unwrap();
32/// ```
33#[derive(Debug, Clone)]
34pub struct SpectralNorm<A: Float> {
35    /// Number of power iterations for SVD approximation
36    n_power_iterations: usize,
37    /// Epsilon for numerical stability
38    eps: A,
39    /// Cached left singular vector
40    u: Option<Array<A, scirs2_core::ndarray::Ix1>>,
41    /// Cached right singular vector  
42    v: Option<Array<A, scirs2_core::ndarray::Ix1>>,
43    /// Random number generator
44    rng: Random<scirs2_core::random::rngs::StdRng>,
45}
46
47impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> SpectralNorm<A> {
48    /// Create a new spectral normalization regularizer
49    ///
50    /// # Arguments
51    ///
52    /// * `n_power_iterations` - Number of power iterations for SVD approximation
53    pub fn new(n_poweriterations: usize) -> Self {
54        Self {
55            n_power_iterations: n_poweriterations,
56            eps: A::from_f64(1e-12).unwrap(),
57            u: None,
58            v: None,
59            rng: Random::seed(42),
60        }
61    }
62
63    /// Compute the spectral norm (largest singular value) using power iteration
64    fn compute_spectral_norm(&mut self, weights: &Array2<A>) -> Result<A> {
65        let (m, n) = (weights.nrows(), weights.ncols());
66
67        // Initialize u and v if not already done
68        if self.u.is_none() || self.u.as_ref().unwrap().len() != m {
69            self.u = Some(Array::from_shape_fn((m,), |_| {
70                let val: f64 = self.rng.gen_range(0.0..1.0);
71                A::from_f64(val).unwrap()
72            }));
73        }
74
75        if self.v.is_none() || self.v.as_ref().unwrap().len() != n {
76            self.v = Some(Array::from_shape_fn((n,), |_| {
77                let val: f64 = self.rng.gen_range(0.0..1.0);
78                A::from_f64(val).unwrap()
79            }));
80        }
81
82        let mut u = self.u.as_ref().unwrap().clone();
83        let mut v = self.v.as_ref().unwrap().clone();
84
85        // Power iteration
86        for _ in 0..self.n_power_iterations {
87            // v = W^T u / ||W^T u||
88            let wt_u = weights.t().dot(&u);
89            let v_norm = (wt_u.dot(&wt_u) + self.eps).sqrt();
90            v = wt_u / v_norm;
91
92            // u = W v / ||W v||
93            let w_v = weights.dot(&v);
94            let u_norm = (w_v.dot(&w_v) + self.eps).sqrt();
95            u = w_v / u_norm;
96        }
97
98        // Update cached vectors
99        self.u = Some(u.clone());
100        self.v = Some(v.clone());
101
102        // Compute spectral norm as u^T W v
103        let w_v = weights.dot(&v);
104        let spectral_norm = u.dot(&w_v);
105
106        Ok(spectral_norm)
107    }
108
109    /// Normalize weights by spectral norm
110    pub fn normalize(&mut self, weights: &Array2<A>) -> Result<Array2<A>> {
111        let spectral_norm = self.compute_spectral_norm(weights)?;
112
113        if spectral_norm > self.eps {
114            Ok(weights / spectral_norm)
115        } else {
116            Ok(weights.clone())
117        }
118    }
119
120    /// Apply spectral normalization to 4D convolutional weights
121    pub fn normalize_conv4d(&mut self, weights: &Array4<A>) -> Result<Array4<A>> {
122        // Reshape to 2D for spectral norm computation
123        let shape = weights.shape();
124        let out_channels = shape[0];
125        let in_channels = shape[1];
126        let kernel_h = shape[2];
127        let kernel_w = shape[3];
128
129        let weights_2d = weights
130            .to_shape((out_channels, in_channels * kernel_h * kernel_w))
131            .map_err(|e| OptimError::InvalidConfig(format!("Cannot reshape weights: {}", e)))?;
132        let weights_2d_owned = weights_2d.to_owned();
133        let normalized_2d = self.normalize(&weights_2d_owned)?;
134
135        // Reshape back to 4D
136        let normalized_4d = normalized_2d
137            .to_shape((out_channels, in_channels, kernel_h, kernel_w))
138            .map_err(|e| {
139                OptimError::InvalidConfig(format!("Cannot reshape normalized weights: {}", e))
140            })?;
141        Ok(normalized_4d.to_owned())
142    }
143}
144
145// Implement Regularizer trait
146impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
147    for SpectralNorm<A>
148{
149    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
150        // For spectral normalization, we don't modify _gradients directly
151        // Instead, the normalization is typically applied during the forward pass
152        Ok(A::zero())
153    }
154
155    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
156        // Spectral normalization doesn't add a penalty term
157        Ok(A::zero())
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The constructor stores the requested number of power iterations.
    #[test]
    fn test_spectral_norm_creation() {
        let regularizer = SpectralNorm::<f64>::new(5);
        assert_eq!(regularizer.n_power_iterations, 5);
    }

    /// Power iteration recovers the largest singular value of a diagonal matrix.
    #[test]
    fn test_spectral_norm_2d() {
        let mut regularizer = SpectralNorm::new(10);

        // diag(1, 2) has singular values 1 and 2, so the spectral norm is 2.
        let diagonal = array![[1.0, 0.0], [0.0, 2.0]];

        let estimate = regularizer.compute_spectral_norm(&diagonal).unwrap();
        assert_relative_eq!(estimate, 2.0, epsilon = 0.1);
    }

    /// A normalized matrix has a spectral norm of roughly one.
    #[test]
    fn test_normalize_2d() {
        let mut regularizer = SpectralNorm::new(10);

        let matrix = array![[1.0, 2.0], [3.0, 4.0]];
        let rescaled = regularizer.normalize(&matrix).unwrap();

        let estimate = regularizer.compute_spectral_norm(&rescaled).unwrap();
        assert_relative_eq!(estimate, 1.0, epsilon = 0.1);
    }

    /// Normalizing a 4D kernel keeps its shape.
    #[test]
    fn test_conv4d_normalization() {
        let mut regularizer = SpectralNorm::new(5);

        // Layout: (out_channels, in_channels, height, width), filled with a running index.
        let kernel = Array::from_shape_fn((2, 3, 3, 3), |(o, i, h, w)| {
            (o * 27 + i * 9 + h * 3 + w) as f64
        });

        let rescaled = regularizer.normalize_conv4d(&kernel).unwrap();

        assert_eq!(rescaled.shape(), kernel.shape());
    }

    /// Despite the name, this feeds a valid (all-zero) 4D tensor and expects success.
    #[test]
    fn test_invalid_conv4d() {
        let mut regularizer = SpectralNorm::<f64>::new(5);

        let kernel = Array::zeros((2, 3, 4, 4));

        let outcome = regularizer.normalize_conv4d(&kernel);
        assert!(outcome.is_ok());
    }

    /// The `Regularizer` implementation reports zero and leaves gradients alone.
    #[test]
    fn test_regularizer_trait() {
        let regularizer = SpectralNorm::new(5);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        // No penalty term is contributed.
        let penalty = regularizer.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        // Applying the regularizer reports zero as well.
        let applied = regularizer.apply(&params, &mut gradient).unwrap();
        assert_eq!(applied, 0.0);

        // The gradient buffer must come back untouched.
        let expected_gradient = array![[0.1, 0.2], [0.3, 0.4]];
        assert_eq!(gradient, expected_gradient);
    }
}