optirs_core/regularizers/
manifold.rs1use scirs2_core::ndarray::{Array, Array2, ArrayBase, Data, Dimension, ScalarOperand};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use crate::error::{OptimError, Result};
11use crate::regularizers::Regularizer;
12
/// Graph-based (manifold) regularization.
///
/// Penalizes a parameter matrix `F` by `lambda * Σ_ij F_ij (L F)_ij`
/// (equivalently `lambda * tr(Fᵀ L F)`), where `L = D - W` is the graph
/// Laplacian built from a user-supplied square similarity matrix `W` and `D`
/// is the diagonal matrix holding the row sums of `W`.
#[derive(Debug, Clone)]
pub struct ManifoldRegularization<A: Float> {
    /// Regularization strength that scales the Laplacian penalty.
    lambda: A,
    /// Similarity (adjacency) matrix `W`; `None` until one has been set.
    similarity_matrix: Option<Array2<A>>,
    /// Diagonal degree matrix `D` whose diagonal holds the row sums of `W`.
    degree_matrix: Option<Array2<A>>,
    /// Graph Laplacian `L = D - W` used for the penalty and its gradient.
    laplacian: Option<Array2<A>>,
}
47
48impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> ManifoldRegularization<A> {
49 pub fn new(lambda: A) -> Self {
55 Self {
56 lambda,
57 similarity_matrix: None,
58 degree_matrix: None,
59 laplacian: None,
60 }
61 }
62
63 pub fn set_similarity_matrix(&mut self, similarity: Array2<A>) -> Result<()> {
69 let (rows, cols) = similarity.dim();
70 if rows != cols {
71 return Err(OptimError::InvalidConfig(
72 "Similarity matrix must be square".to_string(),
73 ));
74 }
75
76 let mut degree = Array2::zeros((rows, rows));
78 for i in 0..rows {
79 let row_sum = similarity.row(i).sum();
80 degree[[i, i]] = row_sum;
81 }
82
83 let laplacian = °ree - &similarity;
85
86 self.similarity_matrix = Some(similarity);
87 self.degree_matrix = Some(degree);
88 self.laplacian = Some(laplacian);
89
90 Ok(())
91 }
92
93 pub fn compute_penalty<S>(&self, params: &ArrayBase<S, scirs2_core::ndarray::Ix2>) -> Result<A>
95 where
96 S: Data<Elem = A>,
97 {
98 let laplacian = self
99 .laplacian
100 .as_ref()
101 .ok_or_else(|| OptimError::InvalidConfig("Similarity matrix not set".to_string()))?;
102
103 let lf = laplacian.dot(params);
106 let penalty = params
107 .iter()
108 .zip(lf.iter())
109 .map(|(p, lf)| *p * *lf)
110 .fold(A::zero(), |acc, val| acc + val);
111
112 Ok(self.lambda * penalty)
113 }
114
115 fn compute_gradient<S>(
117 &self,
118 params: &ArrayBase<S, scirs2_core::ndarray::Ix2>,
119 ) -> Result<Array2<A>>
120 where
121 S: Data<Elem = A>,
122 {
123 let laplacian = self
124 .laplacian
125 .as_ref()
126 .ok_or_else(|| OptimError::InvalidConfig("Similarity matrix not set".to_string()))?;
127
128 let gradient = laplacian.dot(params) * (A::from_f64(2.0).unwrap() * self.lambda);
130 Ok(gradient)
131 }
132}
133
134impl<
136 A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
137 D: Dimension + Send + Sync,
138 > Regularizer<A, D> for ManifoldRegularization<A>
139{
140 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
141 if params.ndim() != 2 {
142 return Ok(A::zero());
144 }
145
146 let params_2d = params
148 .view()
149 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
150 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
151
152 let gradient_update = self.compute_gradient(¶ms_2d)?;
153
154 let mut gradients_2d = gradients
156 .view_mut()
157 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
158 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
159
160 gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);
161
162 self.compute_penalty(¶ms_2d)
164 }
165
166 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
167 if params.ndim() != 2 {
168 return Ok(A::zero());
170 }
171
172 let params_2d = params
174 .view()
175 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
176 .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
177
178 self.compute_penalty(¶ms_2d)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that two 2-D arrays have the same shape and approximately equal entries.
    fn assert_matrix_eq(actual: &Array2<f64>, expected: &Array2<f64>) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(*a, *e, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_manifold_creation() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        assert_eq!(manifold.lambda, 0.01);
        assert!(manifold.similarity_matrix.is_none());
        assert!(manifold.degree_matrix.is_none());
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_set_similarity_matrix() {
        let mut manifold = ManifoldRegularization::new(0.01);
        let similarity = array![[1.0, 0.5], [0.5, 1.0]];

        assert!(manifold.set_similarity_matrix(similarity).is_ok());

        // Both row sums are 1.5, so D = 1.5 * I and L = D - W.
        let degree = manifold.degree_matrix.as_ref().unwrap();
        assert_matrix_eq(degree, &array![[1.5, 0.0], [0.0, 1.5]]);

        let laplacian = manifold.laplacian.as_ref().unwrap();
        assert_matrix_eq(laplacian, &array![[0.5, -0.5], [-0.5, 0.5]]);
    }

    #[test]
    fn test_invalid_similarity_matrix() {
        let mut manifold = ManifoldRegularization::<f64>::new(0.01);
        let similarity = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.4]];

        assert!(manifold.set_similarity_matrix(similarity).is_err());
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_penalty_without_similarity() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        let params = array![[1.0, 2.0], [3.0, 4.0]];

        assert!(manifold.compute_penalty(&params).is_err());
    }

    #[test]
    fn test_penalty_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);
        manifold
            .set_similarity_matrix(array![[1.0, 0.8], [0.8, 1.0]])
            .unwrap();

        // L = [[0.8, -0.8], [-0.8, 0.8]]; with F = I, tr(Fᵀ L F) = 1.6.
        let params = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = manifold.compute_penalty(&params).unwrap();

        assert_relative_eq!(penalty, 0.16, epsilon = 1e-10);
    }

    #[test]
    fn test_gradient_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);
        manifold
            .set_similarity_matrix(array![[1.0, 0.8], [0.8, 1.0]])
            .unwrap();

        // L F = [[-1.6, -1.6], [1.6, 1.6]]; gradient = 2 * 0.1 * L F.
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let gradient = manifold.compute_gradient(&params).unwrap();

        assert_matrix_eq(&gradient, &array![[-0.32, -0.32], [0.32, 0.32]]);
    }

    #[test]
    fn test_regularizer_trait() {
        let mut manifold = ManifoldRegularization::new(0.01);
        manifold
            .set_similarity_matrix(array![[1.0, 0.6], [0.6, 1.0]])
            .unwrap();

        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        let penalty = manifold.apply(&params, &mut gradient).unwrap();

        // Each column differs by 2 across the edge of weight 0.6: 2 * 0.6 * 4 = 4.8.
        assert_relative_eq!(penalty, 0.048, epsilon = 1e-10);
        // Gradient update = 2 * 0.01 * L F = [[-0.024, -0.024], [0.024, 0.024]].
        assert_matrix_eq(&gradient, &array![[0.076, 0.176], [0.324, 0.424]]);
    }

    #[test]
    fn test_identity_similarity() {
        let mut manifold = ManifoldRegularization::new(0.1);
        manifold
            .set_similarity_matrix(array![[1.0, 0.0], [0.0, 1.0]])
            .unwrap();

        // Self-loops only: D = W, so L = 0 and the penalty vanishes.
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let penalty = manifold.compute_penalty(&params).unwrap();

        assert_relative_eq!(penalty, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_non_2d_params_are_skipped() {
        let manifold = ManifoldRegularization::<f64>::new(0.1);
        let params = array![1.0, 2.0, 3.0];
        let mut gradient = array![0.1, 0.2, 0.3];
        let original_gradient = gradient.clone();

        assert_eq!(manifold.apply(&params, &mut gradient).unwrap(), 0.0);
        assert_eq!(manifold.penalty(&params).unwrap(), 0.0);
        assert_eq!(gradient, original_gradient);
    }
}