optirs_core/regularizers/orthogonal.rs

// Orthogonal regularization
//
// Orthogonal regularization encourages weight matrices to be orthogonal,
// which helps with gradient flow and prevents vanishing/exploding gradients.

use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, Ix2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// Orthogonal regularization
///
/// Encourages weight matrices to be orthogonal by penalizing the squared
/// Frobenius norm of the difference between W^T * W and the identity matrix,
/// scaled by the regularization strength `lambda`.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::{OrthogonalRegularization, Regularizer};
///
/// let ortho_reg = OrthogonalRegularization::new(0.01);
/// let weights = array![[1.0, 0.0], [0.0, 1.0]];
/// let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
///
/// // Apply orthogonal regularization
/// let penalty = ortho_reg.apply(&weights, &mut gradient).unwrap();
///
/// // The identity matrix is already orthogonal, so the penalty is zero.
/// assert!(penalty.abs() < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization<A: Float> {
    /// Regularization strength
    lambda: A,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> OrthogonalRegularization<A> {
    /// Create a new orthogonal regularization
    ///
    /// # Arguments
    ///
    /// * `lambda` - Regularization strength
    pub fn new(lambda: A) -> Self {
        Self { lambda }
    }

    /// Compute orthogonal penalty for a 2D weight matrix
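    ///
    /// For a weight matrix `W`, this is `lambda * ||W^T W - I||_F^2` over the leading
    /// `min(rows, cols)` block of `W^T W`; when `W` has more columns than rows, the
    /// squares of the remaining entries of `W^T W` are added as well.
    ///
    /// A small worked example (values computed by hand):
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::OrthogonalRegularization;
    ///
    /// let reg = OrthogonalRegularization::new(1.0);
    ///
    /// // For W = [[1, 1], [0, 1]], W^T W = [[1, 1], [1, 2]], so W^T W - I = [[0, 1], [1, 1]]
    /// // and the squared Frobenius norm is 0 + 1 + 1 + 1 = 3.
    /// let w = array![[1.0, 1.0], [0.0, 1.0]];
    /// let penalty = reg.compute_penalty_2d(&w);
    /// assert!((penalty - 3.0).abs() < 1e-10);
    /// ```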
    pub fn compute_penalty_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> A {
        let n = weights.nrows().min(weights.ncols());
        let eye = Array::<A, Ix2>::eye(n);

        // Compute W^T * W
        let wtw = weights.t().dot(weights);

        // Accumulate the squared Frobenius norm of (W^T * W - I) over the leading n x n block
        let mut penalty = A::zero();
        for i in 0..n {
            for j in 0..n {
                let diff = wtw[[i, j]] - eye[[i, j]];
                penalty = penalty + diff * diff;
            }
        }

        // For non-square matrices, also penalize the remaining entries of W^T * W
        if weights.nrows() != weights.ncols() {
            let (rows, cols) = wtw.dim();
            for i in 0..rows {
                for j in 0..cols {
                    if i >= n || j >= n {
                        penalty = penalty + wtw[[i, j]] * wtw[[i, j]];
                    }
                }
            }
        }

        self.lambda * penalty
    }

    /// Compute the gradient of the orthogonal penalty with respect to the weights
    fn compute_gradient_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> Array<A, Ix2> {
        let n = weights.nrows().min(weights.ncols());

        // Compute W^T * W
        let wtw = weights.t().dot(weights);

        // Subtract the identity from the leading n x n diagonal of W^T * W
        let mut diff = wtw;
        for i in 0..n {
            diff[[i, i]] = diff[[i, i]] - A::one();
        }

        // Gradient of lambda * ||W^T W - I||_F^2: 4 * lambda * W * (W^T W - I)
        weights.dot(&diff) * (A::from_f64(4.0).unwrap() * self.lambda)
    }
}

// Implement Regularizer trait
impl<
        A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dimension + Send + Sync,
    > Regularizer<A, D> for OrthogonalRegularization<A>
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            // Only apply to 2D weight matrices
            return Ok(A::zero());
        }

        // Downcast to 2D
        let params_2d = params
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        let gradient_update = self.compute_gradient_2d(&params_2d);

        // Add orthogonal regularization gradient
        let mut gradients_2d = gradients
            .view_mut()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        // Manual element-wise addition
        gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);

        // Return penalty
        Ok(self.compute_penalty_2d(&params_2d))
    }

    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            // Only apply to 2D weight matrices
            return Ok(A::zero());
        }

        // Downcast to 2D
        let params_2d = params
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;

        Ok(self.compute_penalty_2d(&params_2d))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Array3};

    #[test]
    fn test_orthogonal_creation() {
        let ortho = OrthogonalRegularization::<f64>::new(0.01);
        assert_eq!(ortho.lambda, 0.01);
    }

    #[test]
    fn test_identity_matrix_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);

        // Identity matrix is already orthogonal, penalty should be 0
        let weights = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_non_orthogonal_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);

        // Non-orthogonal matrix should have non-zero penalty
        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert!(penalty > 0.0);
    }

    #[test]
    fn test_rectangular_matrix() {
        let ortho = OrthogonalRegularization::new(0.01);

        // Rectangular matrix whose leading 2x2 block of W^T W is the identity and
        // whose extra column contributes nothing, so the penalty is exactly zero
        let weights = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }
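
    // Worked check for a wide matrix: with W = [[1, 0, 1], [0, 1, 0]], W^T W has an
    // identity 2x2 leading block and three extra unit entries, so the raw penalty is 3
    // and the scaled penalty is lambda * 3 (values computed by hand for this test).
    #[test]
    fn test_wide_matrix_penalty_value() {
        let ortho = OrthogonalRegularization::new(0.5);

        let weights = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert_relative_eq!(penalty, 0.5 * 3.0, epsilon = 1e-10);
    }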

    #[test]
    fn test_gradient_computation() {
        let ortho = OrthogonalRegularization::new(0.1);

        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let gradient = ortho.compute_gradient_2d(&weights);

        // Gradient should not be zero for a non-orthogonal matrix
        assert!(gradient.mapv(f64::abs).sum() > 0.0);
    }

    #[test]
    fn test_regularizer_trait() {
        let ortho = OrthogonalRegularization::new(0.01);

        let params = array![[1.0, 0.5], [0.5, 1.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradient = gradient.clone();

        let penalty = ortho.apply(&params, &mut gradient).unwrap();

        // Penalty should be positive
        assert!(penalty > 0.0);

        // Gradient should be modified
        assert_ne!(gradient, original_gradient);

        // Penalty from apply should match penalty method
        let penalty2 = ortho.penalty(&params).unwrap();
        assert_relative_eq!(penalty, penalty2, epsilon = 1e-10);
    }
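
    // Finite-difference sanity check: the analytic gradient 4 * lambda * W * (W^T W - I)
    // should match a central-difference approximation of compute_penalty_2d. This is an
    // added numerical check; the tolerance below is chosen loosely for f64.
    #[test]
    fn test_gradient_matches_finite_differences() {
        let ortho = OrthogonalRegularization::new(0.1);

        let weights = array![[1.0, 0.5], [0.3, 1.2]];
        let analytic = ortho.compute_gradient_2d(&weights);

        let eps = 1e-6;
        for i in 0..2 {
            for j in 0..2 {
                let mut plus = weights.clone();
                let mut minus = weights.clone();
                plus[[i, j]] += eps;
                minus[[i, j]] -= eps;
                let numeric = (ortho.compute_penalty_2d(&plus)
                    - ortho.compute_penalty_2d(&minus))
                    / (2.0 * eps);
                assert_relative_eq!(analytic[[i, j]], numeric, epsilon = 1e-6);
            }
        }
    }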

    #[test]
    fn test_non_2d_array() {
        let ortho = OrthogonalRegularization::new(0.01);

        // 3D array - should return zero penalty
        let params = Array3::<f64>::zeros((2, 2, 2));
        let mut gradient = Array3::<f64>::zeros((2, 2, 2));

        let penalty = ortho.apply(&params, &mut gradient).unwrap();
        assert_eq!(penalty, 0.0);

        // Gradient should be unchanged
        assert_eq!(gradient, Array3::<f64>::zeros((2, 2, 2)));
    }
}