// optirs_core/regularizers/manifold.rs
use std::fmt::Debug;

use scirs2_core::ndarray::{Array, Array2, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

13#[derive(Debug, Clone)]
37pub struct ManifoldRegularization<A: Float> {
38 lambda: A,
40 similarity_matrix: Option<Array2<A>>,
42 degree_matrix: Option<Array2<A>>,
44 laplacian: Option<Array2<A>>,
46}
47
48impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> ManifoldRegularization<A> {
49 pub fn new(lambda: A) -> Self {
55 Self {
56 lambda,
57 similarity_matrix: None,
58 degree_matrix: None,
59 laplacian: None,
60 }
61 }
62
63 pub fn set_similarity_matrix(&mut self, similarity: Array2<A>) -> Result<()> {
69 let (rows, cols) = similarity.dim();
70 if rows != cols {
71 return Err(OptimError::InvalidConfig(
72 "Similarity matrix must be square".to_string(),
73 ));
74 }
75
76 let mut degree = Array2::zeros((rows, rows));
78 for i in 0..rows {
79 let row_sum = similarity.row(i).sum();
80 degree[[i, i]] = row_sum;
81 }
82
83 let laplacian = °ree - &similarity;
85
86 self.similarity_matrix = Some(similarity);
87 self.degree_matrix = Some(degree);
88 self.laplacian = Some(laplacian);
89
90 Ok(())
91 }
92
93 pub fn compute_penalty<S>(&self, params: &ArrayBase<S, scirs2_core::ndarray::Ix2>) -> Result<A>
95 where
96 S: Data<Elem = A>,
97 {
98 let laplacian = self
99 .laplacian
100 .as_ref()
101 .ok_or_else(|| OptimError::InvalidConfig("Similarity matrix not set".to_string()))?;
102
103 let lf = laplacian.dot(params);
106 let penalty = params
107 .iter()
108 .zip(lf.iter())
109 .map(|(p, lf)| *p * *lf)
110 .fold(A::zero(), |acc, val| acc + val);
111
112 Ok(self.lambda * penalty)
113 }
114
115 fn compute_gradient<S>(
117 &self,
118 params: &ArrayBase<S, scirs2_core::ndarray::Ix2>,
119 ) -> Result<Array2<A>>
120 where
121 S: Data<Elem = A>,
122 {
123 let laplacian = self
124 .laplacian
125 .as_ref()
126 .ok_or_else(|| OptimError::InvalidConfig("Similarity matrix not set".to_string()))?;
127
128 let gradient =
130 laplacian.dot(params) * (A::from_f64(2.0).expect("unwrap failed") * self.lambda);
131 Ok(gradient)
132 }
133}
134
135impl<
137 A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
138 D: Dimension + Send + Sync,
139 > Regularizer<A, D> for ManifoldRegularization<A>
140{
141 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
142 if params.ndim() != 2 {
143 return Ok(A::zero());
145 }
146
147 let params_2d = params
149 .view()
150 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
151 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
152
153 let gradient_update = self.compute_gradient(¶ms_2d)?;
154
155 let mut gradients_2d = gradients
157 .view_mut()
158 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
159 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
160
161 gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);
162
163 self.compute_penalty(¶ms_2d)
165 }
166
167 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
168 if params.ndim() != 2 {
169 return Ok(A::zero());
171 }
172
173 let params_2d = params
175 .view()
176 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
177 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
178
179 self.compute_penalty(¶ms_2d)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_manifold_creation() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        assert_eq!(manifold.lambda, 0.01);
        assert!(manifold.similarity_matrix.is_none());
        assert!(manifold.degree_matrix.is_none());
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_set_similarity_matrix() {
        let mut manifold = ManifoldRegularization::new(0.01);

        let similarity = array![[1.0, 0.5], [0.5, 1.0]];
        assert!(manifold.set_similarity_matrix(similarity).is_ok());
        assert!(manifold.similarity_matrix.is_some());

        // Both row sums are 1.5, so D = diag(1.5, 1.5).
        let degree = manifold
            .degree_matrix
            .as_ref()
            .expect("degree matrix should be set");
        assert_relative_eq!(degree[[0, 0]], 1.5, epsilon = 1e-10);
        assert_relative_eq!(degree[[1, 1]], 1.5, epsilon = 1e-10);
        assert_relative_eq!(degree[[0, 1]], 0.0, epsilon = 1e-10);

        // L = D - S = [[0.5, -0.5], [-0.5, 0.5]].
        let laplacian = manifold
            .laplacian
            .as_ref()
            .expect("laplacian should be set");
        assert_relative_eq!(laplacian[[0, 0]], 0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[0, 1]], -0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[1, 0]], -0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[1, 1]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_invalid_similarity_matrix() {
        let mut manifold = ManifoldRegularization::<f64>::new(0.01);

        let similarity = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.4]];
        assert!(manifold.set_similarity_matrix(similarity).is_err());
        // A rejected matrix must not leave partial state behind.
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_penalty_without_similarity() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        let params = array![[1.0, 2.0], [3.0, 4.0]];

        assert!(manifold.compute_penalty(&params).is_err());
        assert!(manifold.compute_gradient(&params).is_err());
    }

    #[test]
    fn test_penalty_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);

        let similarity = array![[1.0, 0.8], [0.8, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");

        // L = [[0.8, -0.8], [-0.8, 0.8]]; with F = I, tr(Fᵀ L F) = 1.6.
        let params = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = manifold
            .compute_penalty(&params)
            .expect("penalty should compute");

        assert_relative_eq!(penalty, 0.16, epsilon = 1e-10);
    }

    #[test]
    fn test_gradient_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);

        let similarity = array![[1.0, 0.8], [0.8, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");

        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let gradient = manifold
            .compute_gradient(&params)
            .expect("gradient should compute");

        // L F = [[-1.6, -1.6], [1.6, 1.6]]; gradient = 2 * lambda * L F.
        let expected = array![[-0.32, -0.32], [0.32, 0.32]];
        assert_eq!(gradient.dim(), (2, 2));
        for (g, e) in gradient.iter().zip(expected.iter()) {
            assert_relative_eq!(*g, *e, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let mut manifold = ManifoldRegularization::new(0.01);

        let similarity = array![[1.0, 0.6], [0.6, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");

        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        let penalty = manifold
            .apply(&params, &mut gradient)
            .expect("apply should succeed");

        // L F = [[-1.2, -1.2], [1.2, 1.2]]; tr(Fᵀ L F) = 4.8.
        assert_relative_eq!(penalty, 0.048, epsilon = 1e-10);

        // Gradient update = 2 * 0.01 * L F, added to the incoming gradient.
        let expected = array![[0.076, 0.176], [0.324, 0.424]];
        for (g, e) in gradient.iter().zip(expected.iter()) {
            assert_relative_eq!(*g, *e, epsilon = 1e-10);
        }

        let standalone = manifold
            .penalty(&params)
            .expect("penalty should compute");
        assert_relative_eq!(standalone, penalty, epsilon = 1e-12);
    }

    #[test]
    fn test_non_matrix_params_are_ignored() {
        let manifold = ManifoldRegularization::<f64>::new(0.1);

        let params = array![1.0, 2.0];
        let mut gradient = array![0.5, 0.5];

        let penalty = manifold
            .apply(&params, &mut gradient)
            .expect("1-D params should be a no-op");

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-12);
        assert_eq!(gradient, array![0.5, 0.5]);
    }

    #[test]
    fn test_identity_similarity() {
        let mut manifold = ManifoldRegularization::new(0.1);

        let similarity = array![[1.0, 0.0], [0.0, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");

        // Self-similarity only: D = I, so L = D - S is the zero matrix.
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let penalty = manifold
            .compute_penalty(&params)
            .expect("penalty should compute");
        assert_relative_eq!(penalty, 0.0, epsilon = 1e-12);

        let gradient = manifold
            .compute_gradient(&params)
            .expect("gradient should compute");
        assert!(gradient.iter().all(|g| g.abs() < 1e-12));
    }
}