use crate::mds::eigh_faer;
use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::Fit;
use ndarray::Array2;
/// Locally Linear Embedding (LLE) configuration.
///
/// Build one with [`LLE::new`], adjust it with the `with_*` builder methods, and
/// call `fit` on a sample matrix to obtain the low-dimensional embedding.
#[derive(Debug, Clone)]
pub struct LLE {
    /// Dimensionality of the output embedding.
    n_components: usize,
    /// Number of nearest neighbours each sample is reconstructed from.
    n_neighbors: usize,
    /// Regularisation strength added to each local Gram matrix.
    reg: f64,
}

impl LLE {
    /// Neighbourhood size used until [`LLE::with_n_neighbors`] overrides it.
    const DEFAULT_N_NEIGHBORS: usize = 5;
    /// Regularisation used until [`LLE::with_reg`] overrides it.
    const DEFAULT_REG: f64 = 1e-3;

    /// Create a configuration that embeds into `n_components` dimensions.
    ///
    /// Starts with 5 neighbours and a regularisation of `1e-3`.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            n_neighbors: Self::DEFAULT_N_NEIGHBORS,
            reg: Self::DEFAULT_REG,
        }
    }

    /// Replace the neighbourhood size.
    ///
    /// `fit` requires `1 <= k < n_samples`.
    #[must_use]
    pub fn with_n_neighbors(self, k: usize) -> Self {
        Self {
            n_neighbors: k,
            ..self
        }
    }

    /// Replace the regularisation strength. `fit` rejects negative values.
    #[must_use]
    pub fn with_reg(self, reg: f64) -> Self {
        Self { reg, ..self }
    }

    /// Number of output dimensions.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }

    /// Number of nearest neighbours used per sample.
    #[must_use]
    pub fn n_neighbors(&self) -> usize {
        self.n_neighbors
    }

    /// Regularisation strength.
    #[must_use]
    pub fn reg(&self) -> f64 {
        self.reg
    }
}
/// Result of fitting [`LLE`]: the embedded coordinates of the training samples.
#[derive(Debug, Clone)]
pub struct FittedLLE {
    /// Embedding matrix with one row per training sample and one column per
    /// output component.
    embedding_: Array2<f64>,
}

impl FittedLLE {
    /// Borrow the embedding (`n_samples × n_components`).
    #[must_use]
    pub fn embedding(&self) -> &Array2<f64> {
        let Self { embedding_, .. } = self;
        embedding_
    }
}
/// Return, for every row of `x`, the indices of its `k` nearest *other* rows
/// under squared Euclidean distance, ordered from nearest to farthest.
///
/// Ties are broken by the lower row index. If fewer than `k` other rows exist,
/// all of them are returned. A row is never reported as its own neighbour.
fn find_neighbors(x: &Array2<f64>, k: usize) -> Vec<Vec<usize>> {
    let n = x.nrows();
    (0..n)
        .map(|i| {
            let xi = x.row(i);
            let mut dists: Vec<(f64, usize)> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    let xj = x.row(j);
                    let sq = xi.iter().zip(xj.iter()).fold(0.0, |acc, (a, b)| {
                        let diff = a - b;
                        acc + diff * diff
                    });
                    (sq, j)
                })
                .collect();
            // `total_cmp` is a genuine total order (NaN included), so the sort can
            // never be handed an inconsistent comparator (which may panic in
            // recent Rust). The index tie-break reproduces a stable ordering.
            dists.sort_unstable_by(|a, b| a.0.total_cmp(&b.0).then(a.1.cmp(&b.1)));
            dists.into_iter().take(k).map(|(_, j)| j).collect()
        })
        .collect()
}
/// Compute the LLE reconstruction-weight matrix `W` (`n × n`).
///
/// For each sample `i`, the offsets `z_j = x_i - x_{nbr_j}` to its neighbours
/// `neighbors[i]` form the local Gram matrix `C = Z Zᵀ` (`k × k`). `C` is
/// regularised by adding `reg * trace(C) / k` to its diagonal (or the absolute
/// `reg` when `trace(C)` is zero). The system `C w = 1` is then solved, and `w` is
/// rescaled so it sums to one. The weights are written into row `i` of `W` at
/// the neighbour columns; every other entry of `W` stays zero.
///
/// # Errors
///
/// Returns [`FerroError::NumericalInstability`] in either of two cases:
/// - a regularised local Gram matrix is numerically singular;
/// - the solved weights sum to (relatively) zero and cannot be normalised.
fn compute_weights(
    x: &Array2<f64>,
    neighbors: &[Vec<usize>],
    reg: f64,
) -> Result<Array2<f64>, FerroError> {
    let n = x.nrows();
    let d = x.ncols();
    let mut w = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        let nbrs = &neighbors[i];
        let k = nbrs.len();
        if k == 0 {
            // Nothing to reconstruct from; row `i` of `W` stays zero.
            continue;
        }
        // Offsets between sample `i` and each neighbour. Their sign is irrelevant
        // because only the Gram matrix `Z Zᵀ` is used.
        let mut z = Array2::<f64>::zeros((k, d));
        for (row, &j) in nbrs.iter().enumerate() {
            for f in 0..d {
                z[[row, f]] = x[[i, f]] - x[[j, f]];
            }
        }
        let mut c = z.dot(&z.t());
        let trace: f64 = (0..k).map(|j| c[[j, j]]).sum();
        // Regularise relative to the local scale. The absolute `reg` is only a
        // fallback for a zero trace (every neighbour coincides with sample `i`).
        // Testing the *product* against a fixed tolerance would instead swap in
        // the absolute `reg` for small-scale data and swamp its geometry.
        let reg_val = if trace > 0.0 {
            reg * trace / k as f64
        } else {
            reg
        };
        for j in 0..k {
            c[[j, j]] += reg_val;
        }
        let mut w_local = solve_unit_rhs(&c, i)?;
        // Normalise so the weights sum to one (the LLE affine constraint). The
        // check is relative to the weights' magnitude, so it does not depend on
        // the scale of the data; unnormalised weights are never returned.
        let sum: f64 = w_local.iter().sum();
        let magnitude: f64 = w_local.iter().map(|v| v.abs()).sum();
        if !sum.is_finite() || sum.abs() <= 1e-15 * magnitude {
            return Err(FerroError::NumericalInstability {
                message: format!(
                    "Reconstruction weights at point {i} sum to (nearly) zero and \
                     cannot be normalised. Try increasing reg."
                ),
            });
        }
        for val in &mut w_local {
            *val /= sum;
        }
        for (&j, &weight) in nbrs.iter().zip(&w_local) {
            w[[i, j]] = weight;
        }
    }
    Ok(w)
}

/// Solve `c · w = 1` (a right-hand side of all ones) for a square `c`, using
/// Gauss–Jordan elimination with partial pivoting.
///
/// `point` is only used to identify the sample in the error message.
///
/// # Errors
///
/// Returns [`FerroError::NumericalInstability`] if some pivot is NaN or no
/// larger than `1e-15` times the largest absolute diagonal entry of `c`. This is
/// a scale-relative test for numerical singularity.
fn solve_unit_rhs(c: &Array2<f64>, point: usize) -> Result<Vec<f64>, FerroError> {
    let k = c.nrows();
    let scale = (0..k).map(|j| c[[j, j]].abs()).fold(0.0_f64, f64::max);
    let tol = 1e-15 * scale;
    // Augmented matrix [c | 1].
    let mut aug = Array2::<f64>::zeros((k, k + 1));
    for r in 0..k {
        for col in 0..k {
            aug[[r, col]] = c[[r, col]];
        }
        aug[[r, k]] = 1.0;
    }
    for col in 0..k {
        // Partial pivoting: use the row with the largest magnitude in this column.
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for r in (col + 1)..k {
            let val = aug[[r, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = r;
            }
        }
        if max_val.is_nan() || max_val <= tol {
            return Err(FerroError::NumericalInstability {
                message: format!(
                    "Singular local covariance matrix at point {point}. \
                     Try increasing reg or n_neighbors."
                ),
            });
        }
        if max_row != col {
            for c_idx in 0..=k {
                let tmp = aug[[col, c_idx]];
                aug[[col, c_idx]] = aug[[max_row, c_idx]];
                aug[[max_row, c_idx]] = tmp;
            }
        }
        let pivot = aug[[col, col]];
        for c_idx in col..=k {
            aug[[col, c_idx]] /= pivot;
        }
        // Eliminate this column from every other row (Gauss–Jordan). Entries left
        // of `col` are already zero, so only columns `col..=k` need updating.
        for r in 0..k {
            if r != col {
                let factor = aug[[r, col]];
                for c_idx in col..=k {
                    aug[[r, c_idx]] -= factor * aug[[col, c_idx]];
                }
            }
        }
    }
    Ok((0..k).map(|r| aug[[r, k]]).collect())
}
impl Fit<Array2<f64>, ()> for LLE {
    type Fitted = FittedLLE;
    type Error = FerroError;

    /// Fit LLE on `x` (one sample per row) and return the embedding of its rows.
    ///
    /// Steps:
    /// 1. Find each sample's `n_neighbors` nearest neighbours.
    /// 2. Solve for the reconstruction weights `W`.
    /// 3. Form `M = (I - W)ᵀ (I - W)`.
    /// 4. Take the eigenvectors of `M` for the `n_components` smallest
    ///    eigenvalues, after discarding the very smallest one.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] in any of these cases:
    ///   - `n_components` or `n_neighbors` is zero, or not less than the number
    ///     of samples;
    ///   - `reg` is negative or not finite;
    ///   - `x` contains NaN or infinite values.
    /// - [`FerroError::InsufficientSamples`] if `x` has fewer than 2 rows.
    /// - [`FerroError::NumericalInstability`] if a local weight system cannot be
    ///   solved.
    ///
    /// Errors from the eigendecomposition are propagated unchanged.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedLLE, FerroError> {
        let n = x.nrows();
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_neighbors == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_neighbors".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n,
                context: "LLE::fit requires at least 2 samples".into(),
            });
        }
        if self.n_neighbors >= n {
            return Err(FerroError::InvalidParameter {
                name: "n_neighbors".into(),
                reason: format!(
                    "n_neighbors ({}) must be less than n_samples ({})",
                    self.n_neighbors, n
                ),
            });
        }
        if self.n_components >= n {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) must be less than n_samples ({})",
                    self.n_components, n
                ),
            });
        }
        // `NaN < 0.0` is false, so the sign check below would let NaN through;
        // non-finite regularisation turns every weight into NaN.
        if !self.reg.is_finite() {
            return Err(FerroError::InvalidParameter {
                name: "reg".into(),
                reason: "must be finite".into(),
            });
        }
        if self.reg < 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "reg".into(),
                reason: "must be non-negative".into(),
            });
        }
        // Non-finite inputs poison the distance ranking, the weights and the
        // eigendecomposition; reject them up front.
        if x.iter().any(|v| !v.is_finite()) {
            return Err(FerroError::InvalidParameter {
                name: "x".into(),
                reason: "input must not contain NaN or infinite values".into(),
            });
        }
        let neighbors = find_neighbors(x, self.n_neighbors);
        let w = compute_weights(x, &neighbors, self.reg)?;
        let iw = Array2::<f64>::eye(n) - &w;
        let m = iw.t().dot(&iw);
        let (eigenvalues, eigenvectors) = eigh_faer(&m)?;
        // Ascending eigenvalue order; the stable sort keeps index order on ties.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&a, &b| eigenvalues[a].total_cmp(&eigenvalues[b]));
        let n_comp = self.n_components;
        let mut embedding = Array2::<f64>::zeros((n, n_comp));
        // Rows of W sum to one, so the constant vector lies in the null space of
        // M (eigenvalue ≈ 0) and carries no information: skip it.
        for (k, &idx) in indices.iter().skip(1).take(n_comp).enumerate() {
            for i in 0..n {
                embedding[[i, k]] = eigenvectors[[i, idx]];
            }
        }
        Ok(FittedLLE {
            embedding_: embedding,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Nine points on a 3 × 3 unit-spaced grid, listed row by row.
    fn grid_data() -> Array2<f64> {
        Array2::from_shape_fn((9, 2), |(r, c)| {
            let coord = if c == 0 { r % 3 } else { r / 3 };
            coord as f64
        })
    }

    /// Six evenly spaced points along the x-axis.
    fn line_data() -> Array2<f64> {
        Array2::from_shape_fn((6, 2), |(r, c)| if c == 0 { r as f64 } else { 0.0 })
    }

    #[test]
    fn test_lle_basic_shape() {
        let data = grid_data();
        let model = LLE::new(2).with_n_neighbors(3).fit(&data, &()).unwrap();
        assert_eq!(model.embedding().dim(), (9, 2));
    }

    #[test]
    fn test_lle_1d() {
        let data = line_data();
        let model = LLE::new(1).with_n_neighbors(2).fit(&data, &()).unwrap();
        assert_eq!(model.embedding().ncols(), 1);
    }

    #[test]
    fn test_lle_preserves_local_structure() {
        let data = line_data();
        let model = LLE::new(1).with_n_neighbors(2).fit(&data, &()).unwrap();
        let vals: Vec<f64> = model.embedding().column(0).to_vec();
        // Eigenvector signs are arbitrary, so accept either direction.
        let ascending = vals.windows(2).all(|pair| pair[0] <= pair[1] + 1e-6);
        let descending = vals.windows(2).all(|pair| pair[0] >= pair[1] - 1e-6);
        assert!(
            ascending || descending,
            "embedding should be monotonic: {vals:?}"
        );
    }

    #[test]
    fn test_lle_invalid_n_components_zero() {
        let data = grid_data();
        assert!(LLE::new(0).fit(&data, &()).is_err());
    }

    #[test]
    fn test_lle_invalid_n_neighbors_zero() {
        let data = grid_data();
        assert!(LLE::new(2).with_n_neighbors(0).fit(&data, &()).is_err());
    }

    #[test]
    fn test_lle_n_neighbors_too_large() {
        let data = grid_data();
        assert!(LLE::new(2).with_n_neighbors(100).fit(&data, &()).is_err());
    }

    #[test]
    fn test_lle_insufficient_samples() {
        let single = array![[1.0, 2.0]];
        assert!(LLE::new(1).with_n_neighbors(1).fit(&single, &()).is_err());
    }

    #[test]
    fn test_lle_getters() {
        let cfg = LLE::new(3).with_n_neighbors(7).with_reg(0.01);
        assert_eq!(cfg.n_components(), 3);
        assert_eq!(cfg.n_neighbors(), 7);
        assert!((cfg.reg() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_lle_default_params() {
        let cfg = LLE::new(2);
        assert_eq!(cfg.n_neighbors(), 5);
        assert!((cfg.reg() - 1e-3).abs() < 1e-15);
    }

    #[test]
    fn test_lle_n_components_too_large() {
        let data = grid_data();
        assert!(LLE::new(50).fit(&data, &()).is_err());
    }

    #[test]
    fn test_lle_negative_reg() {
        let data = grid_data();
        assert!(LLE::new(2).with_reg(-1.0).fit(&data, &()).is_err());
    }

    #[test]
    fn test_lle_larger_dataset() {
        let (n, d) = (20, 3);
        let data = Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 / (n * d) as f64);
        let model = LLE::new(2).with_n_neighbors(5).fit(&data, &()).unwrap();
        assert_eq!(model.embedding().dim(), (20, 2));
    }

    #[test]
    fn test_lle_different_n_neighbors() {
        let data = grid_data();
        let small = LLE::new(2).with_n_neighbors(3).fit(&data, &()).unwrap();
        let large = LLE::new(2).with_n_neighbors(6).fit(&data, &()).unwrap();
        let diff_sum: f64 = small
            .embedding()
            .iter()
            .zip(large.embedding().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff_sum > 1e-10,
            "different n_neighbors should produce different embeddings (got diff_sum={diff_sum})"
        );
    }
}