1use crate::mds::eigh_faer;
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::traits::Fit;
40use ndarray::Array2;
41
/// Strategy used to turn pairwise sample distances into graph affinities.
///
/// The chosen variant decides how the (symmetric) affinity matrix is built
/// before the graph Laplacian is formed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Affinity {
    /// Gaussian (RBF) kernel: the weight between two distinct samples is
    /// `exp(-gamma * squared_euclidean_distance)`.
    RBF {
        /// Kernel coefficient; fitting rejects values that are not positive.
        gamma: f64,
    },
    /// Binary k-nearest-neighbour graph: two samples are connected with
    /// weight `1.0` when either one is among the other's nearest neighbours.
    NearestNeighbors {
        /// Number of neighbours selected per sample; fitting requires
        /// `1 <= n_neighbors < n_samples`.
        n_neighbors: usize,
    },
}
61
62#[derive(Debug, Clone)]
72pub struct SpectralEmbedding {
73 n_components: usize,
75 affinity: Affinity,
77}
78
79impl SpectralEmbedding {
80 #[must_use]
84 pub fn new(n_components: usize) -> Self {
85 Self {
86 n_components,
87 affinity: Affinity::RBF { gamma: 1.0 },
88 }
89 }
90
91 #[must_use]
93 pub fn with_affinity(mut self, affinity: Affinity) -> Self {
94 self.affinity = affinity;
95 self
96 }
97
98 #[must_use]
100 pub fn n_components(&self) -> usize {
101 self.n_components
102 }
103
104 #[must_use]
106 pub fn affinity(&self) -> Affinity {
107 self.affinity
108 }
109}
110
111#[derive(Debug, Clone)]
119pub struct FittedSpectralEmbedding {
120 embedding_: Array2<f64>,
122}
123
124impl FittedSpectralEmbedding {
125 #[must_use]
127 pub fn embedding(&self) -> &Array2<f64> {
128 &self.embedding_
129 }
130}
131
132fn build_affinity_matrix(x: &Array2<f64>, affinity: &Affinity) -> Array2<f64> {
138 let n = x.nrows();
139 match affinity {
140 Affinity::RBF { gamma } => {
141 let mut w = Array2::<f64>::zeros((n, n));
142 for i in 0..n {
143 for j in (i + 1)..n {
144 let mut sq = 0.0;
145 for k in 0..x.ncols() {
146 let diff = x[[i, k]] - x[[j, k]];
147 sq += diff * diff;
148 }
149 let val = (-gamma * sq).exp();
150 w[[i, j]] = val;
151 w[[j, i]] = val;
152 }
153 }
155 w
156 }
157 Affinity::NearestNeighbors { n_neighbors } => {
158 let k = *n_neighbors;
159 let mut w = Array2::<f64>::zeros((n, n));
160 let mut sq_dist = Array2::<f64>::zeros((n, n));
162 for i in 0..n {
163 for j in (i + 1)..n {
164 let mut sq = 0.0;
165 for f in 0..x.ncols() {
166 let diff = x[[i, f]] - x[[j, f]];
167 sq += diff * diff;
168 }
169 sq_dist[[i, j]] = sq;
170 sq_dist[[j, i]] = sq;
171 }
172 }
173 for i in 0..n {
175 let mut neighbors: Vec<(f64, usize)> = (0..n)
176 .filter(|&j| j != i)
177 .map(|j| (sq_dist[[i, j]], j))
178 .collect();
179 neighbors
180 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
181 for &(_, j) in neighbors.iter().take(k) {
182 w[[i, j]] = 1.0;
183 w[[j, i]] = 1.0; }
185 }
186 w
187 }
188 }
189}
190
191fn normalised_laplacian(w: &Array2<f64>) -> Array2<f64> {
193 let n = w.nrows();
194 let mut d_inv_sqrt = vec![0.0; n];
196 for i in 0..n {
197 let deg: f64 = (0..n).map(|j| w[[i, j]]).sum();
198 d_inv_sqrt[i] = if deg > 1e-15 { 1.0 / deg.sqrt() } else { 0.0 };
199 }
200
201 let mut l = Array2::<f64>::zeros((n, n));
202 for i in 0..n {
203 for j in 0..n {
204 if i == j {
205 l[[i, j]] = 1.0 - d_inv_sqrt[i] * w[[i, j]] * d_inv_sqrt[j];
206 } else {
207 l[[i, j]] = -d_inv_sqrt[i] * w[[i, j]] * d_inv_sqrt[j];
208 }
209 }
210 }
211 l
212}
213
214impl Fit<Array2<f64>, ()> for SpectralEmbedding {
219 type Fitted = FittedSpectralEmbedding;
220 type Error = FerroError;
221
222 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedSpectralEmbedding, FerroError> {
230 let n = x.nrows();
231
232 if self.n_components == 0 {
233 return Err(FerroError::InvalidParameter {
234 name: "n_components".into(),
235 reason: "must be at least 1".into(),
236 });
237 }
238 if n < 2 {
239 return Err(FerroError::InsufficientSamples {
240 required: 2,
241 actual: n,
242 context: "SpectralEmbedding::fit requires at least 2 samples".into(),
243 });
244 }
245 if self.n_components >= n {
247 return Err(FerroError::InvalidParameter {
248 name: "n_components".into(),
249 reason: format!(
250 "n_components ({}) must be less than n_samples ({})",
251 self.n_components, n
252 ),
253 });
254 }
255
256 if let Affinity::NearestNeighbors { n_neighbors } = self.affinity {
257 if n_neighbors == 0 {
258 return Err(FerroError::InvalidParameter {
259 name: "n_neighbors".into(),
260 reason: "must be at least 1".into(),
261 });
262 }
263 if n_neighbors >= n {
264 return Err(FerroError::InvalidParameter {
265 name: "n_neighbors".into(),
266 reason: format!(
267 "n_neighbors ({n_neighbors}) must be less than n_samples ({n})"
268 ),
269 });
270 }
271 }
272
273 if let Affinity::RBF { gamma } = self.affinity {
274 if gamma <= 0.0 {
275 return Err(FerroError::InvalidParameter {
276 name: "gamma".into(),
277 reason: "must be positive".into(),
278 });
279 }
280 }
281
282 let w = build_affinity_matrix(x, &self.affinity);
284
285 let l = normalised_laplacian(&w);
287
288 let (eigenvalues, eigenvectors) = eigh_faer(&l)?;
290
291 let mut indices: Vec<usize> = (0..n).collect();
293 indices.sort_by(|&a, &b| {
294 eigenvalues[a]
295 .partial_cmp(&eigenvalues[b])
296 .unwrap_or(std::cmp::Ordering::Equal)
297 });
298
299 let n_comp = self.n_components;
301 let mut embedding = Array2::<f64>::zeros((n, n_comp));
302 for (k, &idx) in indices.iter().skip(1).take(n_comp).enumerate() {
303 for i in 0..n {
304 embedding[[i, k]] = eigenvectors[[i, idx]];
305 }
306 }
307
308 Ok(FittedSpectralEmbedding {
309 embedding_: embedding,
310 })
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Two tight groups of four 2-D points each, far apart from one another.
    fn two_clusters() -> Array2<f64> {
        array![
            [0.0, 0.0],
            [0.1, 0.0],
            [0.0, 0.1],
            [0.1, 0.1],
            [5.0, 5.0],
            [5.1, 5.0],
            [5.0, 5.1],
            [5.1, 5.1],
        ]
    }

    /// Five 2-D points; small enough to trigger the size-related checks.
    fn simple_data() -> Array2<f64> {
        array![[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [1.0, 1.0],]
    }

    /// Asserts that `result` failed with `InvalidParameter` for `expected`.
    ///
    /// Checking the variant and the parameter name (rather than `is_err()`)
    /// keeps the test from passing when the call fails for another reason.
    fn assert_invalid_parameter(
        result: Result<FittedSpectralEmbedding, FerroError>,
        expected: &str,
    ) {
        match result {
            Err(FerroError::InvalidParameter { name, .. }) => assert_eq!(name, expected),
            other => panic!("expected InvalidParameter for `{expected}`, got {other:?}"),
        }
    }

    #[test]
    fn test_spectral_embedding_basic_shape() {
        let se = SpectralEmbedding::new(2);
        let x = two_clusters();
        let fitted = se.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (8, 2));
    }

    #[test]
    fn test_spectral_embedding_1d() {
        let se = SpectralEmbedding::new(1);
        let x = two_clusters();
        let fitted = se.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 1);
    }

    #[test]
    fn test_spectral_embedding_rbf_separates_clusters() {
        let se = SpectralEmbedding::new(1).with_affinity(Affinity::RBF { gamma: 1.0 });
        let x = two_clusters();
        let fitted = se.fit(&x, &()).unwrap();
        let emb = fitted.embedding();

        // The per-cluster means of the first component must differ.
        let c1_mean: f64 = (0..4).map(|i| emb[[i, 0]]).sum::<f64>() / 4.0;
        let c2_mean: f64 = (4..8).map(|i| emb[[i, 0]]).sum::<f64>() / 4.0;
        assert!(
            (c1_mean - c2_mean).abs() > 0.01,
            "clusters should be separated: c1={c1_mean}, c2={c2_mean}"
        );
    }

    #[test]
    fn test_spectral_embedding_knn_affinity() {
        let se =
            SpectralEmbedding::new(2).with_affinity(Affinity::NearestNeighbors { n_neighbors: 3 });
        let x = two_clusters();
        let fitted = se.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (8, 2));
    }

    #[test]
    fn test_spectral_embedding_invalid_n_components_zero() {
        let se = SpectralEmbedding::new(0);
        let x = simple_data();
        assert_invalid_parameter(se.fit(&x, &()), "n_components");
    }

    #[test]
    fn test_spectral_embedding_n_components_too_large() {
        // Five components requested for five samples: not strictly smaller.
        let se = SpectralEmbedding::new(5);
        let x = simple_data();
        assert_invalid_parameter(se.fit(&x, &()), "n_components");
    }

    #[test]
    fn test_spectral_embedding_insufficient_samples() {
        let se = SpectralEmbedding::new(1);
        let x = array![[1.0, 2.0]];
        let result = se.fit(&x, &());
        assert!(
            matches!(
                result,
                Err(FerroError::InsufficientSamples {
                    required: 2,
                    actual: 1,
                    ..
                })
            ),
            "expected InsufficientSamples, got {result:?}"
        );
    }

    #[test]
    fn test_spectral_embedding_knn_n_neighbors_zero() {
        let se =
            SpectralEmbedding::new(1).with_affinity(Affinity::NearestNeighbors { n_neighbors: 0 });
        let x = simple_data();
        assert_invalid_parameter(se.fit(&x, &()), "n_neighbors");
    }

    #[test]
    fn test_spectral_embedding_getters() {
        let se = SpectralEmbedding::new(3).with_affinity(Affinity::RBF { gamma: 0.5 });
        assert_eq!(se.n_components(), 3);
        assert_eq!(se.affinity(), Affinity::RBF { gamma: 0.5 });
    }

    #[test]
    fn test_spectral_embedding_knn_too_many_neighbors() {
        let se = SpectralEmbedding::new(1)
            .with_affinity(Affinity::NearestNeighbors { n_neighbors: 100 });
        let x = simple_data();
        assert_invalid_parameter(se.fit(&x, &()), "n_neighbors");
    }

    #[test]
    fn test_spectral_embedding_negative_gamma() {
        let se = SpectralEmbedding::new(1).with_affinity(Affinity::RBF { gamma: -1.0 });
        let x = simple_data();
        assert_invalid_parameter(se.fit(&x, &()), "gamma");
    }

    #[test]
    fn test_spectral_embedding_larger_dataset() {
        let n = 20;
        let d = 3;
        // Deterministic ramp of values in [0, 1).
        let mut data = Array2::<f64>::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                data[[i, j]] = (i * d + j) as f64 / (n * d) as f64;
            }
        }
        let se = SpectralEmbedding::new(2);
        let fitted = se.fit(&data, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }
}