1use crate::mds::eigh_faer;
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::Fit;
39use ndarray::Array2;
40
/// Configuration for Locally Linear Embedding (LLE).
///
/// Create one with [`LLE::new`], tune it with the `with_*` builder methods,
/// and fit it to data to obtain a [`FittedLLE`].
#[derive(Debug, Clone)]
pub struct LLE {
    /// Number of dimensions of the output embedding.
    n_components: usize,
    /// Number of nearest neighbours used to reconstruct each sample.
    n_neighbors: usize,
    /// Regularisation strength for the local Gram matrices (applied relative
    /// to each matrix's trace when computing reconstruction weights).
    reg: f64,
}

impl LLE {
    /// Neighbour count installed by [`LLE::new`].
    const DEFAULT_N_NEIGHBORS: usize = 5;
    /// Regularisation strength installed by [`LLE::new`].
    const DEFAULT_REG: f64 = 1e-3;

    /// Creates a configuration that embeds into `n_components` dimensions,
    /// using 5 neighbours and `reg = 1e-3`.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        LLE {
            n_components,
            n_neighbors: Self::DEFAULT_N_NEIGHBORS,
            reg: Self::DEFAULT_REG,
        }
    }

    /// Returns this configuration with the neighbour count replaced by `k`.
    #[must_use]
    pub fn with_n_neighbors(self, k: usize) -> Self {
        Self {
            n_neighbors: k,
            ..self
        }
    }

    /// Returns this configuration with the regularisation strength replaced
    /// by `reg`.
    #[must_use]
    pub fn with_reg(self, reg: f64) -> Self {
        Self { reg, ..self }
    }

    /// Number of output dimensions.
    #[must_use]
    pub fn n_components(&self) -> usize {
        let Self { n_components, .. } = *self;
        n_components
    }

    /// Number of nearest neighbours per sample.
    #[must_use]
    pub fn n_neighbors(&self) -> usize {
        let Self { n_neighbors, .. } = *self;
        n_neighbors
    }

    /// Regularisation strength.
    #[must_use]
    pub fn reg(&self) -> f64 {
        let Self { reg, .. } = *self;
        reg
    }
}
104
105#[derive(Debug, Clone)]
113pub struct FittedLLE {
114 embedding_: Array2<f64>,
116}
117
118impl FittedLLE {
119 #[must_use]
121 pub fn embedding(&self) -> &Array2<f64> {
122 &self.embedding_
123 }
124}
125
126fn find_neighbors(x: &Array2<f64>, k: usize) -> Vec<Vec<usize>> {
133 let n = x.nrows();
134 let d = x.ncols();
135 let mut result = Vec::with_capacity(n);
136
137 for i in 0..n {
138 let mut dists: Vec<(f64, usize)> = (0..n)
139 .filter(|&j| j != i)
140 .map(|j| {
141 let mut sq = 0.0;
142 for f in 0..d {
143 let diff = x[[i, f]] - x[[j, f]];
144 sq += diff * diff;
145 }
146 (sq, j)
147 })
148 .collect();
149 dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
150 result.push(dists.iter().take(k).map(|&(_, j)| j).collect());
151 }
152 result
153}
154
155fn compute_weights(
163 x: &Array2<f64>,
164 neighbors: &[Vec<usize>],
165 reg: f64,
166) -> Result<Array2<f64>, FerroError> {
167 let n = x.nrows();
168 let d = x.ncols();
169 let mut w = Array2::<f64>::zeros((n, n));
170
171 for i in 0..n {
172 let k = neighbors[i].len();
173
174 let mut z = Array2::<f64>::zeros((k, d));
176 for (j_idx, &j) in neighbors[i].iter().enumerate() {
177 for f in 0..d {
178 z[[j_idx, f]] = x[[i, f]] - x[[j, f]];
179 }
180 }
181
182 let mut c = z.dot(&z.t());
184
185 let trace: f64 = (0..k).map(|j| c[[j, j]]).sum();
187 let reg_val = reg * trace / k as f64;
188 let reg_val = if reg_val.abs() < 1e-15 { reg } else { reg_val };
190 for j in 0..k {
191 c[[j, j]] += reg_val;
192 }
193
194 let mut augmented = Array2::<f64>::zeros((k, k + 1));
197 for r in 0..k {
198 for col in 0..k {
199 augmented[[r, col]] = c[[r, col]];
200 }
201 augmented[[r, k]] = 1.0;
202 }
203
204 for col in 0..k {
206 let mut max_val = augmented[[col, col]].abs();
208 let mut max_row = col;
209 for r in (col + 1)..k {
210 let val = augmented[[r, col]].abs();
211 if val > max_val {
212 max_val = val;
213 max_row = r;
214 }
215 }
216 if max_val < 1e-15 {
217 return Err(FerroError::NumericalInstability {
218 message: format!(
219 "Singular local covariance matrix at point {i}. \
220 Try increasing reg or n_neighbors."
221 ),
222 });
223 }
224 if max_row != col {
225 for c_idx in 0..=k {
226 let tmp = augmented[[col, c_idx]];
227 augmented[[col, c_idx]] = augmented[[max_row, c_idx]];
228 augmented[[max_row, c_idx]] = tmp;
229 }
230 }
231 let pivot = augmented[[col, col]];
232 for c_idx in col..=k {
233 augmented[[col, c_idx]] /= pivot;
234 }
235 for r in 0..k {
236 if r != col {
237 let factor = augmented[[r, col]];
238 for c_idx in col..=k {
239 augmented[[r, c_idx]] -= factor * augmented[[col, c_idx]];
240 }
241 }
242 }
243 }
244
245 let mut w_local = vec![0.0; k];
247 for j in 0..k {
248 w_local[j] = augmented[[j, k]];
249 }
250
251 let sum: f64 = w_local.iter().sum();
253 if sum.abs() > 1e-15 {
254 for val in &mut w_local {
255 *val /= sum;
256 }
257 }
258
259 for (j_idx, &j) in neighbors[i].iter().enumerate() {
261 w[[i, j]] = w_local[j_idx];
262 }
263 }
264
265 Ok(w)
266}
267
268impl Fit<Array2<f64>, ()> for LLE {
273 type Fitted = FittedLLE;
274 type Error = FerroError;
275
276 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedLLE, FerroError> {
288 let n = x.nrows();
289
290 if self.n_components == 0 {
291 return Err(FerroError::InvalidParameter {
292 name: "n_components".into(),
293 reason: "must be at least 1".into(),
294 });
295 }
296 if self.n_neighbors == 0 {
297 return Err(FerroError::InvalidParameter {
298 name: "n_neighbors".into(),
299 reason: "must be at least 1".into(),
300 });
301 }
302 if n < 2 {
303 return Err(FerroError::InsufficientSamples {
304 required: 2,
305 actual: n,
306 context: "LLE::fit requires at least 2 samples".into(),
307 });
308 }
309 if self.n_neighbors >= n {
310 return Err(FerroError::InvalidParameter {
311 name: "n_neighbors".into(),
312 reason: format!(
313 "n_neighbors ({}) must be less than n_samples ({})",
314 self.n_neighbors, n
315 ),
316 });
317 }
318 if self.n_components >= n {
320 return Err(FerroError::InvalidParameter {
321 name: "n_components".into(),
322 reason: format!(
323 "n_components ({}) must be less than n_samples ({})",
324 self.n_components, n
325 ),
326 });
327 }
328 if self.reg < 0.0 {
329 return Err(FerroError::InvalidParameter {
330 name: "reg".into(),
331 reason: "must be non-negative".into(),
332 });
333 }
334
335 let neighbors = find_neighbors(x, self.n_neighbors);
337
338 let w = compute_weights(x, &neighbors, self.reg)?;
340
341 let mut iw = Array2::<f64>::zeros((n, n));
344 for i in 0..n {
345 iw[[i, i]] = 1.0;
346 for j in 0..n {
347 iw[[i, j]] -= w[[i, j]];
348 }
349 }
350 let m = iw.t().dot(&iw);
352
353 let (eigenvalues, eigenvectors) = eigh_faer(&m)?;
355
356 let mut indices: Vec<usize> = (0..n).collect();
358 indices.sort_by(|&a, &b| {
359 eigenvalues[a]
360 .partial_cmp(&eigenvalues[b])
361 .unwrap_or(std::cmp::Ordering::Equal)
362 });
363
364 let n_comp = self.n_components;
366 let mut embedding = Array2::<f64>::zeros((n, n_comp));
367 for (k, &idx) in indices.iter().skip(1).take(n_comp).enumerate() {
368 for i in 0..n {
369 embedding[[i, k]] = eigenvectors[[i, idx]];
370 }
371 }
372
373 Ok(FittedLLE {
374 embedding_: embedding,
375 })
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A 3 × 3 grid of points with unit spacing in the plane.
    fn grid_data() -> Array2<f64> {
        array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [2.0, 1.0],
            [0.0, 2.0],
            [1.0, 2.0],
            [2.0, 2.0],
        ]
    }

    /// Six evenly spaced points on the x-axis.
    fn line_data() -> Array2<f64> {
        array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [5.0, 0.0],
        ]
    }

    /// Asserts that `result` failed with `FerroError::InvalidParameter`
    /// blaming the parameter called `expected`.
    fn assert_invalid_parameter(result: Result<FittedLLE, FerroError>, expected: &str) {
        match result {
            Err(FerroError::InvalidParameter { name, .. }) => {
                assert_eq!(name.to_string(), expected, "wrong parameter blamed");
            }
            other => panic!("expected InvalidParameter for `{expected}`, got {other:?}"),
        }
    }

    #[test]
    fn test_lle_basic_shape() {
        let lle = LLE::new(2).with_n_neighbors(3);
        let x = grid_data();
        let fitted = lle.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (9, 2));
    }

    #[test]
    fn test_lle_1d() {
        let lle = LLE::new(1).with_n_neighbors(2);
        let x = line_data();
        let fitted = lle.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 1);
    }

    #[test]
    fn test_lle_preserves_local_structure() {
        let lle = LLE::new(1).with_n_neighbors(2);
        let x = line_data();
        let fitted = lle.fit(&x, &()).unwrap();
        let emb = fitted.embedding();
        let vals: Vec<f64> = (0..6).map(|i| emb[[i, 0]]).collect();
        // The sign of an eigenvector is arbitrary, so accept either direction.
        let ascending = vals.windows(2).all(|w| w[0] <= w[1] + 1e-6);
        let descending = vals.windows(2).all(|w| w[0] >= w[1] - 1e-6);
        assert!(
            ascending || descending,
            "embedding should be monotonic: {vals:?}"
        );
    }

    #[test]
    fn test_lle_embedding_is_finite() {
        let fitted = LLE::new(2).with_n_neighbors(3).fit(&grid_data(), &()).unwrap();
        assert!(fitted.embedding().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_lle_invalid_n_components_zero() {
        let lle = LLE::new(0);
        let x = grid_data();
        assert_invalid_parameter(lle.fit(&x, &()), "n_components");
    }

    #[test]
    fn test_lle_invalid_n_neighbors_zero() {
        let lle = LLE::new(2).with_n_neighbors(0);
        let x = grid_data();
        assert_invalid_parameter(lle.fit(&x, &()), "n_neighbors");
    }

    #[test]
    fn test_lle_n_neighbors_too_large() {
        let lle = LLE::new(2).with_n_neighbors(100);
        let x = grid_data();
        assert_invalid_parameter(lle.fit(&x, &()), "n_neighbors");
    }

    #[test]
    fn test_lle_insufficient_samples() {
        let lle = LLE::new(1).with_n_neighbors(1);
        let x = array![[1.0, 2.0]];
        let result = lle.fit(&x, &());
        assert!(
            matches!(
                result,
                Err(FerroError::InsufficientSamples {
                    required: 2,
                    actual: 1,
                    ..
                })
            ),
            "expected InsufficientSamples, got {result:?}"
        );
    }

    #[test]
    fn test_lle_getters() {
        let lle = LLE::new(3).with_n_neighbors(7).with_reg(0.01);
        assert_eq!(lle.n_components(), 3);
        assert_eq!(lle.n_neighbors(), 7);
        assert!((lle.reg() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_lle_default_params() {
        let lle = LLE::new(2);
        assert_eq!(lle.n_neighbors(), 5);
        assert!((lle.reg() - 1e-3).abs() < 1e-15);
    }

    #[test]
    fn test_lle_n_components_too_large() {
        let lle = LLE::new(50);
        let x = grid_data();
        assert_invalid_parameter(lle.fit(&x, &()), "n_components");
    }

    #[test]
    fn test_lle_negative_reg() {
        let lle = LLE::new(2).with_reg(-1.0);
        let x = grid_data();
        assert_invalid_parameter(lle.fit(&x, &()), "reg");
    }

    #[test]
    fn test_find_neighbors_orders_by_distance_then_index() {
        let x = line_data();
        let nbrs = find_neighbors(&x, 2);
        assert_eq!(nbrs[0], vec![1, 2]);
        // Points 2 and 4 are equidistant from point 3; the lower index wins.
        assert_eq!(nbrs[3], vec![2, 4]);
        assert_eq!(nbrs[5], vec![4, 3]);
    }

    #[test]
    fn test_compute_weights_rows_sum_to_one() {
        let x = grid_data();
        let neighbors = find_neighbors(&x, 3);
        let w = compute_weights(&x, &neighbors, 1e-3).unwrap();
        let n = x.nrows();
        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| w[[i, j]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-10, "row {i} sums to {row_sum}");
            for j in 0..n {
                if !neighbors[i].contains(&j) {
                    assert!(w[[i, j]] == 0.0, "W[{i}, {j}] should be zero");
                }
            }
        }
    }

    #[test]
    fn test_lle_larger_dataset() {
        let n = 20;
        let d = 3;
        let mut data = Array2::<f64>::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                data[[i, j]] = (i * d + j) as f64 / (n * d) as f64;
            }
        }
        let lle = LLE::new(2).with_n_neighbors(5);
        let fitted = lle.fit(&data, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }

    #[test]
    fn test_lle_different_n_neighbors() {
        let x = grid_data();
        let lle3 = LLE::new(2).with_n_neighbors(3);
        let lle6 = LLE::new(2).with_n_neighbors(6);
        let fitted3 = lle3.fit(&x, &()).unwrap();
        let fitted6 = lle6.fit(&x, &()).unwrap();
        let diff_sum: f64 = fitted3
            .embedding()
            .iter()
            .zip(fitted6.embedding().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff_sum > 1e-10,
            "different n_neighbors should produce different embeddings (got diff_sum={diff_sum})"
        );
    }
}