optirs_core/regularizers/
orthogonal.rs1use scirs2_core::ndarray::{Array, Array3, ArrayBase, Data, Dimension, Ix2, ScalarOperand};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use crate::error::{OptimError, Result};
11use crate::regularizers::Regularizer;
12
/// Orthogonality regularizer for two-dimensional weight matrices.
///
/// For a weight matrix `W` with `r` rows and `c` columns the penalty is
/// `lambda * ||WᵀW - E||_F²`. Here `E` is the `c × c` matrix whose first
/// `min(r, c)` diagonal entries are one and all other entries are zero. When
/// `r >= c`, `E` is simply the identity.
///
/// Parameter arrays that are not two-dimensional are not regularized.
///
/// Trait bounds are placed on the `impl` blocks rather than on the struct, so
/// the type can be named in generic code without restating them.
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization<A> {
    /// Regularization strength; scales both the penalty and its gradient.
    lambda: A,
}
36
37impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> OrthogonalRegularization<A> {
38 pub fn new(lambda: A) -> Self {
44 Self { lambda }
45 }
46
47 pub fn compute_penalty_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> A {
49 let n = weights.nrows().min(weights.ncols());
50 let eye = Array::<A, Ix2>::eye(n);
51
52 let wtw = weights.t().dot(weights);
54
55 let mut penalty = A::zero();
57 for i in 0..n {
58 for j in 0..n {
59 let diff = wtw[[i, j]] - eye[[i, j]];
60 penalty = penalty + diff * diff;
61 }
62 }
63
64 if weights.nrows() != weights.ncols() {
66 let (rows, cols) = wtw.dim();
67 for i in 0..rows {
68 for j in 0..cols {
69 if i >= n || j >= n {
70 penalty = penalty + wtw[[i, j]] * wtw[[i, j]];
71 }
72 }
73 }
74 }
75
76 self.lambda * penalty
77 }
78
79 fn compute_gradient_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> Array<A, Ix2> {
81 let n = weights.nrows().min(weights.ncols());
82
83 let wtw = weights.t().dot(weights);
85
86 let mut diff = wtw.clone();
88 for i in 0..n {
89 diff[[i, i]] = diff[[i, i]] - A::one();
90 }
91
92 weights.dot(&diff) * (A::from_f64(2.0).unwrap() * self.lambda)
93 }
94}
95
96impl<
98 A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
99 D: Dimension + Send + Sync,
100 > Regularizer<A, D> for OrthogonalRegularization<A>
101{
102 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
103 if params.ndim() != 2 {
104 return Ok(A::zero());
106 }
107
108 let params_2d = params
110 .view()
111 .into_dimensionality::<Ix2>()
112 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
113
114 let gradient_update = self.compute_gradient_2d(¶ms_2d);
115
116 let mut gradients_2d = gradients
118 .view_mut()
119 .into_dimensionality::<Ix2>()
120 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
121
122 gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);
124
125 Ok(self.compute_penalty_2d(¶ms_2d))
127 }
128
129 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
130 if params.ndim() != 2 {
131 return Ok(A::zero());
133 }
134
135 let params_2d = params
137 .view()
138 .into_dimensionality::<Ix2>()
139 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
140
141 Ok(self.compute_penalty_2d(¶ms_2d))
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_orthogonal_creation() {
        let ortho = OrthogonalRegularization::<f64>::new(0.01);
        assert_eq!(ortho.lambda, 0.01);
    }

    #[test]
    fn test_identity_matrix_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);

        let weights = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_non_orthogonal_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);

        // WᵀW = [[1.25, 1.0], [1.0, 1.25]], so WᵀW - I = [[0.25, 1.0], [1.0, 0.25]].
        // The squared Frobenius norm is 2 * 0.0625 + 2 * 1.0 = 2.125.
        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert!(penalty > 0.0);
        assert_relative_eq!(penalty, 0.01 * 2.125, epsilon = 1e-12);
    }

    #[test]
    fn test_rectangular_matrix() {
        let ortho = OrthogonalRegularization::new(0.01);

        // Orthonormal rows: WᵀW = diag(1, 1, 0), which is exactly the target.
        let weights = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let penalty = ortho.compute_penalty_2d(&weights);

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_wide_non_orthogonal_matrix() {
        let ortho = OrthogonalRegularization::new(0.01);

        // W = [[1, 1, 0]] gives WᵀW = [[1, 1, 0], [1, 1, 0], [0, 0, 0]].
        // The target is diag(1, 0, 0), so the residual has three entries equal to one.
        let weights = array![[1.0, 1.0, 0.0]];

        assert_relative_eq!(ortho.compute_penalty_2d(&weights), 0.03, epsilon = 1e-12);
    }

    #[test]
    fn test_tall_matrix() {
        let ortho = OrthogonalRegularization::new(0.01);

        let orthonormal_columns = array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]];
        assert_relative_eq!(
            ortho.compute_penalty_2d(&orthonormal_columns),
            0.0,
            epsilon = 1e-10
        );

        // A single column of norm 2: WᵀW = [[4]], residual 3, squared 9.
        let scaled_column = array![[2.0], [0.0]];
        assert_relative_eq!(
            ortho.compute_penalty_2d(&scaled_column),
            0.09,
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_gradient_computation() {
        let ortho = OrthogonalRegularization::new(0.1);

        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let gradient = ortho.compute_gradient_2d(&weights);

        assert_eq!(gradient.dim(), weights.dim());
        assert!(gradient.abs().sum() > 0.0);
    }

    #[test]
    fn test_gradient_vanishes_at_orthogonal_points() {
        let ortho = OrthogonalRegularization::new(0.1);

        // A permutation matrix is orthogonal, so the residual is zero.
        let permutation = array![[0.0, 1.0], [1.0, 0.0]];
        assert_relative_eq!(
            ortho.compute_gradient_2d(&permutation).abs().sum(),
            0.0,
            epsilon = 1e-12
        );

        // Orthonormal rows already match the truncated-identity target.
        let wide = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert_relative_eq!(
            ortho.compute_gradient_2d(&wide).abs().sum(),
            0.0,
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_regularizer_trait() {
        let ortho = OrthogonalRegularization::new(0.01);

        let params = array![[1.0, 0.5], [0.5, 1.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradient = gradient.clone();

        let penalty = ortho.apply(&params, &mut gradient).unwrap();

        assert!(penalty > 0.0);
        assert_relative_eq!(penalty, 0.01 * 2.125, epsilon = 1e-12);

        // apply() must have added a non-zero update to the gradient.
        assert_ne!(gradient, original_gradient);

        // penalty() and apply() report the same value for the same parameters.
        let penalty2 = ortho.penalty(&params).unwrap();
        assert_relative_eq!(penalty, penalty2, epsilon = 1e-10);
    }

    #[test]
    fn test_non_2d_array() {
        let ortho = OrthogonalRegularization::new(0.01);

        let params = Array3::<f64>::zeros((2, 2, 2));
        let mut gradient = Array3::<f64>::zeros((2, 2, 2));

        let penalty = ortho.apply(&params, &mut gradient).unwrap();
        assert_eq!(penalty, 0.0);

        // Non-2D parameters are ignored entirely: gradients stay untouched.
        assert_eq!(gradient, Array3::<f64>::zeros((2, 2, 2)));
    }
}