Skip to main content

ferrolearn_decomp/
lle.rs

1//! Locally Linear Embedding (LLE).
2//!
3//! Non-linear dimensionality reduction that preserves local geometry by
4//! reconstructing each point from its nearest neighbors and then finding a
5//! low-dimensional embedding that preserves those reconstruction weights.
6//!
7//! # Algorithm
8//!
9//! 1. Find k-nearest neighbors for each data point.
10//! 2. Compute reconstruction weights `W` by solving local least-squares
11//!    problems: for each point, minimise `||x_i - sum_j w_ij x_j||^2`
12//!    subject to `sum_j w_ij = 1`.
13//! 3. Construct `M = (I - W)^T (I - W)` and find the bottom `n_components`
14//!    eigenvectors of `M`, excluding the trivial constant eigenvector.
15//!
16//! # Examples
17//!
18//! ```
19//! use ferrolearn_decomp::LLE;
20//! use ferrolearn_core::traits::Fit;
21//! use ndarray::array;
22//!
23//! let lle = LLE::new(2);
24//! let x = array![
25//!     [0.0, 0.0],
26//!     [1.0, 0.0],
27//!     [2.0, 0.0],
28//!     [0.0, 1.0],
29//!     [1.0, 1.0],
30//!     [2.0, 1.0],
31//! ];
32//! let fitted = lle.fit(&x, &()).unwrap();
33//! assert_eq!(fitted.embedding().ncols(), 2);
34//! ```
35
36use crate::mds::eigh_faer;
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::Fit;
39use ndarray::Array2;
40
41// ---------------------------------------------------------------------------
42// LLE (unfitted)
43// ---------------------------------------------------------------------------
44
/// Locally Linear Embedding configuration.
///
/// Holds hyperparameters for the LLE algorithm. Call [`Fit::fit`] to compute
/// the embedding and obtain a [`FittedLLE`].
///
/// Construct with [`LLE::new`]; tune via the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LLE {
    /// Number of embedding dimensions.
    n_components: usize,
    /// Number of nearest neighbors used to reconstruct each point.
    n_neighbors: usize,
    /// Regularization parameter added to the local covariance matrix
    /// (scaled by its trace; conditions the local solve when k > d).
    reg: f64,
}
58
59impl LLE {
60    /// Create a new `LLE` with `n_components` embedding dimensions.
61    ///
62    /// Defaults: `n_neighbors = 5`, `reg = 1e-3`.
63    #[must_use]
64    pub fn new(n_components: usize) -> Self {
65        Self {
66            n_components,
67            n_neighbors: 5,
68            reg: 1e-3,
69        }
70    }
71
72    /// Set the number of nearest neighbors.
73    #[must_use]
74    pub fn with_n_neighbors(mut self, k: usize) -> Self {
75        self.n_neighbors = k;
76        self
77    }
78
79    /// Set the regularization parameter.
80    #[must_use]
81    pub fn with_reg(mut self, reg: f64) -> Self {
82        self.reg = reg;
83        self
84    }
85
86    /// Return the configured number of components.
87    #[must_use]
88    pub fn n_components(&self) -> usize {
89        self.n_components
90    }
91
92    /// Return the configured number of neighbors.
93    #[must_use]
94    pub fn n_neighbors(&self) -> usize {
95        self.n_neighbors
96    }
97
98    /// Return the configured regularization parameter.
99    #[must_use]
100    pub fn reg(&self) -> f64 {
101        self.reg
102    }
103}
104
105// ---------------------------------------------------------------------------
106// FittedLLE
107// ---------------------------------------------------------------------------
108
/// A fitted LLE model holding the learned embedding.
///
/// Created by calling [`Fit::fit`] on a [`LLE`]. Each row of the embedding
/// corresponds to a row of the input data, in the same order.
#[derive(Debug, Clone)]
pub struct FittedLLE {
    /// The embedding, shape `(n_samples, n_components)`.
    embedding_: Array2<f64>,
}
117
118impl FittedLLE {
119    /// The embedding coordinates, shape `(n_samples, n_components)`.
120    #[must_use]
121    pub fn embedding(&self) -> &Array2<f64> {
122        &self.embedding_
123    }
124}
125
126// ---------------------------------------------------------------------------
127// Helpers
128// ---------------------------------------------------------------------------
129
130/// Find the k-nearest neighbors for each point.
131/// Returns `neighbors[i]` = sorted Vec of neighbor indices.
132fn find_neighbors(x: &Array2<f64>, k: usize) -> Vec<Vec<usize>> {
133    let n = x.nrows();
134    let d = x.ncols();
135    let mut result = Vec::with_capacity(n);
136
137    for i in 0..n {
138        let mut dists: Vec<(f64, usize)> = (0..n)
139            .filter(|&j| j != i)
140            .map(|j| {
141                let mut sq = 0.0;
142                for f in 0..d {
143                    let diff = x[[i, f]] - x[[j, f]];
144                    sq += diff * diff;
145                }
146                (sq, j)
147            })
148            .collect();
149        dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
150        result.push(dists.iter().take(k).map(|&(_, j)| j).collect());
151    }
152    result
153}
154
/// Solve for reconstruction weights using local covariance + regularization.
///
/// For each point i, solve for w such that:
///   - w minimises ||x_i - sum_j w_j * x_{neighbors_j}||^2
///   - sum_j w_j = 1
///
/// The constrained least-squares problem has the standard closed form:
/// solve `C w = 1` for the local Gram matrix `C`, then rescale `w` so its
/// entries sum to one.
///
/// Returns a sparse weight matrix stored as dense `(n, n)`; row `i` is
/// non-zero only at the columns listed in `neighbors[i]`.
///
/// # Errors
///
/// [`FerroError::NumericalInstability`] if a local covariance matrix is
/// singular even after regularization.
fn compute_weights(
    x: &Array2<f64>,
    neighbors: &[Vec<usize>],
    reg: f64,
) -> Result<Array2<f64>, FerroError> {
    let n = x.nrows();
    let d = x.ncols();
    let mut w = Array2::<f64>::zeros((n, n));

    for i in 0..n {
        let k = neighbors[i].len();

        // Z[j][f] = x[i][f] - x[neighbors[j]][f]
        // (neighbor differences, centered on the point being reconstructed).
        let mut z = Array2::<f64>::zeros((k, d));
        for (j_idx, &j) in neighbors[i].iter().enumerate() {
            for f in 0..d {
                z[[j_idx, f]] = x[[i, f]] - x[[j, f]];
            }
        }

        // Local covariance (Gram) matrix: C = Z * Z^T, shape (k, k).
        let mut c = z.dot(&z.t());

        // Regularization: C += reg * trace(C) * I / k
        // Scaling by the trace keeps the conditioning fix proportional to
        // the local data scale; needed whenever k > d (C is rank-deficient).
        let trace: f64 = (0..k).map(|j| c[[j, j]]).sum();
        let reg_val = reg * trace / k as f64;
        // If trace is zero (degenerate), use a small fixed regularization.
        let reg_val = if reg_val.abs() < 1e-15 { reg } else { reg_val };
        for j in 0..k {
            c[[j, j]] += reg_val;
        }

        // Solve C * w_local = ones(k)
        // Use a simple Gaussian elimination (the matrix is small, k x k).
        // Build the augmented matrix [C | 1].
        let mut augmented = Array2::<f64>::zeros((k, k + 1));
        for r in 0..k {
            for col in 0..k {
                augmented[[r, col]] = c[[r, col]];
            }
            augmented[[r, k]] = 1.0;
        }

        // Forward elimination with partial pivoting.
        // (Gauss-Jordan form: rows above AND below the pivot are eliminated,
        // so no separate back-substitution pass is needed.)
        for col in 0..k {
            // Find pivot (largest magnitude in this column, at or below the
            // diagonal) for numerical stability.
            let mut max_val = augmented[[col, col]].abs();
            let mut max_row = col;
            for r in (col + 1)..k {
                let val = augmented[[r, col]].abs();
                if val > max_val {
                    max_val = val;
                    max_row = r;
                }
            }
            if max_val < 1e-15 {
                return Err(FerroError::NumericalInstability {
                    message: format!(
                        "Singular local covariance matrix at point {i}. \
                         Try increasing reg or n_neighbors."
                    ),
                });
            }
            // Swap the pivot row into place.
            if max_row != col {
                for c_idx in 0..=k {
                    let tmp = augmented[[col, c_idx]];
                    augmented[[col, c_idx]] = augmented[[max_row, c_idx]];
                    augmented[[max_row, c_idx]] = tmp;
                }
            }
            // Scale the pivot row so the pivot element becomes 1.
            let pivot = augmented[[col, col]];
            for c_idx in col..=k {
                augmented[[col, c_idx]] /= pivot;
            }
            // Eliminate this column from every other row.
            for r in 0..k {
                if r != col {
                    let factor = augmented[[r, col]];
                    for c_idx in col..=k {
                        augmented[[r, c_idx]] -= factor * augmented[[col, c_idx]];
                    }
                }
            }
        }

        // Extract solution (augmented is now [I | w_local]).
        let mut w_local = vec![0.0; k];
        for j in 0..k {
            w_local[j] = augmented[[j, k]];
        }

        // Normalise so that sum = 1 (enforces the affine constraint).
        let sum: f64 = w_local.iter().sum();
        if sum.abs() > 1e-15 {
            for val in &mut w_local {
                *val /= sum;
            }
        }

        // Store in the weight matrix.
        for (j_idx, &j) in neighbors[i].iter().enumerate() {
            w[[i, j]] = w_local[j_idx];
        }
    }

    Ok(w)
}
267
268// ---------------------------------------------------------------------------
269// Trait implementations
270// ---------------------------------------------------------------------------
271
impl Fit<Array2<f64>, ()> for LLE {
    type Fitted = FittedLLE;
    type Error = FerroError;

    /// Fit LLE by computing reconstruction weights and finding the
    /// bottom eigenvectors of `(I - W)^T (I - W)`.
    ///
    /// The `_y` argument is unused; LLE is unsupervised.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero,
    ///   `n_neighbors` is zero, or parameters are out of range.
    /// - [`FerroError::InsufficientSamples`] if there are fewer than
    ///   `n_neighbors + 1` samples.
    /// - [`FerroError::NumericalInstability`] if a local covariance matrix
    ///   is singular.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedLLE, FerroError> {
        let n = x.nrows();

        // ---- Parameter validation (fail fast, before any heavy work) ----
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_neighbors == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_neighbors".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n,
                context: "LLE::fit requires at least 2 samples".into(),
            });
        }
        // Each point needs n_neighbors OTHER points.
        if self.n_neighbors >= n {
            return Err(FerroError::InvalidParameter {
                name: "n_neighbors".into(),
                reason: format!(
                    "n_neighbors ({}) must be less than n_samples ({})",
                    self.n_neighbors, n
                ),
            });
        }
        // Need n_components + 1 eigenvectors (skipping the trivial one).
        if self.n_components >= n {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) must be less than n_samples ({})",
                    self.n_components, n
                ),
            });
        }
        if self.reg < 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "reg".into(),
                reason: "must be non-negative".into(),
            });
        }

        // Step 1: Find neighbors.
        let neighbors = find_neighbors(x, self.n_neighbors);

        // Step 2: Compute reconstruction weights.
        let w = compute_weights(x, &neighbors, self.reg)?;

        // Step 3: Construct M = (I - W)^T (I - W).
        // I - W
        let mut iw = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            iw[[i, i]] = 1.0;
            for j in 0..n {
                iw[[i, j]] -= w[[i, j]];
            }
        }
        // M = (I-W)^T (I-W)
        let m = iw.t().dot(&iw);

        // Step 4: Eigendecompose M (symmetric by construction).
        let (eigenvalues, eigenvectors) = eigh_faer(&m)?;

        // Sort eigenvalues ascending (smallest eigenvalues give the
        // embedding coordinates that best preserve the weights).
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[a]
                .partial_cmp(&eigenvalues[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Skip the first (smallest, ~0) eigenvector, take next n_components.
        // The smallest eigenvector is the constant vector (eigenvalue ~0),
        // which carries no embedding information.
        let n_comp = self.n_components;
        let mut embedding = Array2::<f64>::zeros((n, n_comp));
        for (k, &idx) in indices.iter().skip(1).take(n_comp).enumerate() {
            for i in 0..n {
                embedding[[i, k]] = eigenvectors[[i, idx]];
            }
        }

        Ok(FittedLLE {
            embedding_: embedding,
        })
    }
}
378
379// ---------------------------------------------------------------------------
380// Tests
381// ---------------------------------------------------------------------------
382
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Helper: simple 2D grid (3x3 = 9 samples).
    fn grid_data() -> Array2<f64> {
        array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [2.0, 1.0],
            [0.0, 2.0],
            [1.0, 2.0],
            [2.0, 2.0],
        ]
    }

    /// Helper: 6 evenly spaced points on a line (intrinsic dimension 1).
    fn line_data() -> Array2<f64> {
        array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [5.0, 0.0],
        ]
    }

    // Output shape is (n_samples, n_components).
    #[test]
    fn test_lle_basic_shape() {
        let lle = LLE::new(2).with_n_neighbors(3);
        let x = grid_data();
        let fitted = lle.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (9, 2));
    }

    #[test]
    fn test_lle_1d() {
        let lle = LLE::new(1).with_n_neighbors(2);
        let x = line_data();
        let fitted = lle.fit(&x, &()).unwrap();
        assert_eq!(fitted.embedding().ncols(), 1);
    }

    #[test]
    fn test_lle_preserves_local_structure() {
        // Points on a line embedded in 1D: the embedding should roughly
        // preserve neighbor relationships.
        let lle = LLE::new(1).with_n_neighbors(2);
        let x = line_data();
        let fitted = lle.fit(&x, &()).unwrap();
        let emb = fitted.embedding();
        // Check that the embedding is monotonic (preserves ordering).
        // Either direction is acceptable: eigenvector sign is arbitrary.
        let vals: Vec<f64> = (0..6).map(|i| emb[[i, 0]]).collect();
        let ascending = vals.windows(2).all(|w| w[0] <= w[1] + 1e-6);
        let descending = vals.windows(2).all(|w| w[0] >= w[1] - 1e-6);
        assert!(
            ascending || descending,
            "embedding should be monotonic: {vals:?}"
        );
    }

    // ---- Parameter-validation errors ----

    #[test]
    fn test_lle_invalid_n_components_zero() {
        let lle = LLE::new(0);
        let x = grid_data();
        assert!(lle.fit(&x, &()).is_err());
    }

    #[test]
    fn test_lle_invalid_n_neighbors_zero() {
        let lle = LLE::new(2).with_n_neighbors(0);
        let x = grid_data();
        assert!(lle.fit(&x, &()).is_err());
    }

    #[test]
    fn test_lle_n_neighbors_too_large() {
        let lle = LLE::new(2).with_n_neighbors(100);
        let x = grid_data(); // 9 samples
        assert!(lle.fit(&x, &()).is_err());
    }

    #[test]
    fn test_lle_insufficient_samples() {
        let lle = LLE::new(1).with_n_neighbors(1);
        let x = array![[1.0, 2.0]]; // 1 sample
        assert!(lle.fit(&x, &()).is_err());
    }

    #[test]
    fn test_lle_getters() {
        let lle = LLE::new(3).with_n_neighbors(7).with_reg(0.01);
        assert_eq!(lle.n_components(), 3);
        assert_eq!(lle.n_neighbors(), 7);
        assert!((lle.reg() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_lle_default_params() {
        let lle = LLE::new(2);
        assert_eq!(lle.n_neighbors(), 5);
        assert!((lle.reg() - 1e-3).abs() < 1e-15);
    }

    #[test]
    fn test_lle_n_components_too_large() {
        let lle = LLE::new(50);
        let x = grid_data(); // 9 samples
        assert!(lle.fit(&x, &()).is_err());
    }

    #[test]
    fn test_lle_negative_reg() {
        let lle = LLE::new(2).with_reg(-1.0);
        let x = grid_data();
        assert!(lle.fit(&x, &()).is_err());
    }

    // Deterministic synthetic data, large enough to exercise neighbor
    // search and the eigendecomposition on a non-trivial matrix.
    #[test]
    fn test_lle_larger_dataset() {
        let n = 20;
        let d = 3;
        let mut data = Array2::<f64>::zeros((n, d));
        for i in 0..n {
            for j in 0..d {
                data[[i, j]] = (i * d + j) as f64 / (n * d) as f64;
            }
        }
        let lle = LLE::new(2).with_n_neighbors(5);
        let fitted = lle.fit(&data, &()).unwrap();
        assert_eq!(fitted.embedding().dim(), (20, 2));
    }

    #[test]
    fn test_lle_different_n_neighbors() {
        // Different n_neighbors should produce different embeddings.
        let x = grid_data();
        let lle3 = LLE::new(2).with_n_neighbors(3);
        let lle6 = LLE::new(2).with_n_neighbors(6);
        let fitted3 = lle3.fit(&x, &()).unwrap();
        let fitted6 = lle6.fit(&x, &()).unwrap();
        let emb3 = fitted3.embedding();
        let emb6 = fitted6.embedding();
        let mut diff_sum = 0.0;
        for (a, b) in emb3.iter().zip(emb6.iter()) {
            diff_sum += (a - b).abs();
        }
        assert!(
            diff_sum > 1e-10,
            "different n_neighbors should produce different embeddings (got diff_sum={diff_sum})"
        );
    }
}