use crate::error::InterpolateError;
use scirs2_core::ndarray::Array2;
use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
/// Strategy used by [`select_landmarks`] to choose landmark rows of the data matrix.
#[derive(Debug, Clone)]
pub enum LandmarkStrategy {
    /// Rows drawn uniformly at random, without replacement.
    UniformRandom,
    /// Rows nearest to k-means centroids (k-means++ seeding followed by Lloyd iterations).
    KMeansCenters {
        /// Maximum number of Lloyd iterations; values below 1 are treated as 1.
        n_iter: usize,
    },
    /// Rows sampled with probability proportional to an approximate score built from
    /// Gaussian-kernel similarities to a random subset of anchor rows.
    LeverageScore {
        /// Number of anchor rows used to compute the score (clamped to `1..=n`).
        n_random_features: usize,
    },
}
/// Selects landmark rows from `x` according to `strategy`.
///
/// * `x` – data matrix, one observation per row.
/// * `m` – requested number of landmarks; clamped to `x.nrows()`.
/// * `strategy` – how the rows are chosen (see [`LandmarkStrategy`]).
/// * `seed` – seed for every random choice made by the strategy.
/// * `bandwidth` – Gaussian-kernel bandwidth; only used by
///   [`LandmarkStrategy::LeverageScore`].
///
/// Returns a matrix with `x.ncols()` columns and at most `min(m, x.nrows())` rows,
/// each a copy of a row of `x`. When the clamped `m` equals `x.nrows()`, all of `x`
/// is returned unchanged regardless of `strategy`.
///
/// # Errors
///
/// * [`InterpolateError::InsufficientData`] if `x` has no rows.
/// * [`InterpolateError::InvalidInput`] if `m == 0`.
/// * Any error produced by the selected strategy.
pub fn select_landmarks(
    x: &Array2<f64>,
    m: usize,
    strategy: &LandmarkStrategy,
    seed: u64,
    bandwidth: f64,
) -> Result<Array2<f64>, InterpolateError> {
    let n = x.nrows();
    let d = x.ncols();
    // The empty-data check must come before clamping `m`: `m.min(0)` is 0, which
    // would otherwise surface as a misleading "m must be > 0" error.
    if n == 0 {
        return Err(InterpolateError::InsufficientData(
            "Training set is empty".to_string(),
        ));
    }
    if m == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "Number of landmarks m must be > 0".to_string(),
        });
    }
    let m = m.min(n);
    if m == n {
        return Ok(x.to_owned());
    }
    let indices = match strategy {
        LandmarkStrategy::UniformRandom => uniform_random_indices(n, m, seed),
        LandmarkStrategy::KMeansCenters { n_iter } => kmeans_indices(x, m, *n_iter, seed)?,
        LandmarkStrategy::LeverageScore { n_random_features } => {
            leverage_score_indices(x, m, *n_random_features, seed, bandwidth)?
        }
    };
    let mut landmarks: Array2<f64> = Array2::zeros((indices.len(), d));
    for (dst, &src) in indices.iter().enumerate() {
        landmarks.row_mut(dst).assign(&x.row(src));
    }
    Ok(landmarks)
}
/// Returns `m` distinct indices drawn uniformly from `0..n`.
///
/// Runs the first `m` steps of a Fisher–Yates shuffle over `0..n` with an RNG
/// seeded from `seed`, so the output is deterministic for a given seed.
/// `m` must not exceed `n`.
fn uniform_random_indices(n: usize, m: usize, seed: u64) -> Vec<usize> {
    let mut rng = StdRng::seed_from_u64(seed);
    let mut pool: Vec<usize> = (0..n).collect();
    (0..m).for_each(|slot| {
        let pick = rng.random_range(slot..n);
        pool.swap(slot, pick);
    });
    pool.truncate(m);
    pool
}
/// Picks landmark row indices by clustering `x` with k-means.
///
/// Centroids are seeded with k-means++ (D² sampling driven by `seed`). They are then
/// refined by at most `max(n_iter, 1)` Lloyd iterations, stopping early once no
/// assignment changes. Each final centroid is snapped to the nearest data row that an
/// earlier centroid has not already claimed. The result therefore holds
/// `min(m, x.nrows())` distinct row indices, in centroid order.
///
/// Returns an empty vector when `x` has no rows or `m == 0`.
fn kmeans_indices(
    x: &Array2<f64>,
    m: usize,
    n_iter: usize,
    seed: u64,
) -> Result<Vec<usize>, InterpolateError> {
    let n = x.nrows();
    let d = x.ncols();
    let m = m.min(n);
    if m == 0 {
        return Ok(Vec::new());
    }
    let n_iter = n_iter.max(1);
    let mut rng = StdRng::seed_from_u64(seed);

    // k-means++ seeding.
    let mut centroids: Array2<f64> = Array2::zeros((m, d));
    let first = rng.random_range(0..n);
    centroids.row_mut(0).assign(&x.row(first));
    // min_dists[i]: squared distance from row i to its nearest chosen centroid so far.
    let mut min_dists: Vec<f64> = (0..n).map(|i| sq_dist_row(x, i, &centroids, 0)).collect();
    for c in 1..m {
        let total: f64 = min_dists.iter().sum();
        let chosen = if total < 1e-300 {
            // Every row coincides with an existing centroid, so D² sampling has no
            // mass; fall back to a uniformly random row.
            rng.random_range(0..n)
        } else {
            kmeans_sample_by_weight(&min_dists, rng.random::<f64>() * total)
        };
        centroids.row_mut(c).assign(&x.row(chosen));
        for (i, best) in min_dists.iter_mut().enumerate() {
            let dist = sq_dist_row(x, i, &centroids, c);
            if dist < *best {
                *best = dist;
            }
        }
    }

    // Lloyd iterations.
    let mut labels = vec![0usize; n];
    for _ in 0..n_iter {
        let mut changed = false;
        for (i, label) in labels.iter_mut().enumerate() {
            let nearest = kmeans_nearest_centroid(x, i, &centroids);
            if *label != nearest {
                *label = nearest;
                changed = true;
            }
        }
        if !changed {
            break;
        }
        let mut sums: Array2<f64> = Array2::zeros((m, d));
        let mut counts = vec![0usize; m];
        for (i, &c) in labels.iter().enumerate() {
            counts[c] += 1;
            for col in 0..d {
                sums[[c, col]] += x[[i, col]];
            }
        }
        for (c, &count) in counts.iter().enumerate() {
            // Empty clusters keep their previous centroid.
            if count > 0 {
                for col in 0..d {
                    centroids[[c, col]] = sums[[c, col]] / count as f64;
                }
            }
        }
    }

    // Snap centroids to distinct data rows. Each centroid greedily takes its nearest
    // still-unclaimed row. Because m <= n an unclaimed row always exists, so the
    // result has exactly m entries.
    let mut claimed = vec![false; n];
    let mut indices = Vec::with_capacity(m);
    for c in 0..m {
        let mut best: Option<(usize, f64)> = None;
        for i in (0..n).filter(|&i| !claimed[i]) {
            let dist = sq_dist_row(x, i, &centroids, c);
            let better = match best {
                None => true,
                Some((_, best_dist)) => dist < best_dist,
            };
            if better {
                best = Some((i, dist));
            }
        }
        if let Some((i, _)) = best {
            claimed[i] = true;
            indices.push(i);
        }
    }
    Ok(indices)
}

/// Index of the row of `centroids` closest to row `row` of `x`; ties go to the
/// lowest centroid index.
fn kmeans_nearest_centroid(x: &Array2<f64>, row: usize, centroids: &Array2<f64>) -> usize {
    let mut best = 0;
    let mut best_dist = f64::INFINITY;
    for c in 0..centroids.nrows() {
        let dist = sq_dist_row(x, row, centroids, c);
        if dist < best_dist {
            best = c;
            best_dist = dist;
        }
    }
    best
}

/// Inverse-CDF draw: returns the first index at which the running sum of `weights`
/// reaches `target`.
///
/// Zero-weight entries are skipped, so they are never returned while some weight is
/// positive. If rounding leaves the running sum just short of `target`, the last
/// positive-weight index is returned instead of blindly picking the final row.
fn kmeans_sample_by_weight(weights: &[f64], target: f64) -> usize {
    let mut remaining = target;
    let mut last_positive = weights.len().saturating_sub(1);
    for (i, &w) in weights.iter().enumerate() {
        if w > 0.0 {
            last_positive = i;
            remaining -= w;
            if remaining <= 0.0 {
                return i;
            }
        }
    }
    last_positive
}
fn leverage_score_indices(
x: &Array2<f64>,
m: usize,
n_random_features: usize,
seed: u64,
bandwidth: f64,
) -> Result<Vec<usize>, InterpolateError> {
let n = x.nrows();
let d = x.ncols();
let s = n_random_features.min(n).max(1);
let mut rng = StdRng::seed_from_u64(seed.wrapping_add(7919));
let anchor_indices: Vec<usize> = {
let mut idx: Vec<usize> = (0..n).collect();
for i in 0..s {
let j = rng.random_range(i..n);
idx.swap(i, j);
}
idx[..s].to_vec()
};
let mut sketch = vec![0.0f64; n * s];
for i in 0..n {
for (j, &aj) in anchor_indices.iter().enumerate() {
let sq: f64 = (0..d).map(|k| (x[[i, k]] - x[[aj, k]]).powi(2)).sum();
sketch[i * s + j] = (-sq / (2.0 * bandwidth * bandwidth)).exp();
}
}
let mut scores: Vec<f64> = (0..n)
.map(|i| (0..s).map(|j| sketch[i * s + j].powi(2)).sum::<f64>())
.collect();
let total: f64 = scores.iter().sum();
if total < 1e-300 {
return Ok(uniform_random_indices(n, m, seed));
}
for s_val in scores.iter_mut() {
*s_val /= total;
}
let mut selected: std::collections::HashSet<usize> = std::collections::HashSet::new();
let mut attempts = 0usize;
while selected.len() < m && attempts < n * 10 {
attempts += 1;
let mut thresh = rng.random::<f64>();
let mut cum = 0.0;
for (i, &score) in scores.iter().enumerate() {
cum += score;
if thresh <= cum {
selected.insert(i);
break;
}
}
if attempts % (n + 1) == 0 && selected.len() < m {
let fallback = rng.random_range(0..n);
selected.insert(fallback);
}
}
if selected.len() < m {
let extra = uniform_random_indices(n, m, seed.wrapping_add(1));
for idx in extra {
if selected.len() >= m {
break;
}
selected.insert(idx);
}
}
let mut indices: Vec<usize> = selected.into_iter().collect();
indices.sort_unstable();
indices.truncate(m);
Ok(indices)
}
/// Squared Euclidean distance between row `row` of `x` and row `c` of `centroids`,
/// taken over the `x.ncols()` columns of `x`.
fn sq_dist_row(x: &Array2<f64>, row: usize, centroids: &Array2<f64>, c: usize) -> f64 {
    let mut acc = 0.0;
    for k in 0..x.ncols() {
        let diff = x[[row, k]] - centroids[[c, k]];
        acc += diff.powi(2);
    }
    acc
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// `n` two-dimensional points `(t, t²)` with `t = i / n` for `i` in `0..n`.
    fn make_grid(n: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, 2), |(i, j)| {
            let t = i as f64 / n as f64;
            if j == 0 { t } else { t.powi(2) }
        })
    }

    #[test]
    fn uniform_produces_m_landmarks() {
        let grid = make_grid(50);
        let chosen = select_landmarks(&grid, 10, &LandmarkStrategy::UniformRandom, 0, 1.0)
            .expect("uniform landmark selection");
        assert_eq!(chosen.dim(), (10, 2));
    }

    #[test]
    fn kmeans_produces_up_to_m_landmarks() {
        let strategy = LandmarkStrategy::KMeansCenters { n_iter: 20 };
        let chosen = select_landmarks(&make_grid(40), 8, &strategy, 1, 1.0)
            .expect("kmeans landmark selection");
        assert!((1..=8).contains(&chosen.nrows()));
    }

    #[test]
    fn leverage_score_produces_landmarks() {
        let strategy = LandmarkStrategy::LeverageScore {
            n_random_features: 20,
        };
        let chosen = select_landmarks(&make_grid(60), 12, &strategy, 2, 1.0)
            .expect("leverage-score landmark selection");
        assert!((1..=12).contains(&chosen.nrows()));
    }

    #[test]
    fn uniform_when_m_equals_n() {
        let grid = make_grid(5);
        let chosen = select_landmarks(&grid, 5, &LandmarkStrategy::UniformRandom, 0, 1.0)
            .expect("m == n case");
        assert_eq!(chosen.nrows(), 5);
    }
}