use lapl::{gaussian_similarity, knn_graph, spectral_embedding, SpectralEmbeddingConfig};
use ndarray::Array2;
use crate::{Error, Result};
/// Configuration for spectral clustering.
///
/// Create one with [`SpectralClustering::new`], adjust it with the builder
/// methods, then call `fit` (raw points) or `fit_affinity` (ready-made
/// affinity matrix) to obtain one cluster label per item.
#[derive(Debug, Clone)]
pub struct SpectralClustering {
/// Number of clusters; passed to both the spectral embedding and k-means.
k: usize,
/// How `fit` turns the input points into an affinity matrix.
affinity: AffinityType,
/// Parameter handed to `gaussian_similarity` when the affinity is `Rbf`.
sigma: f64,
/// Neighbor count handed to `knn_graph` when the affinity is `Knn`.
n_neighbors: usize,
/// Iteration cap handed to each k-means run on the embedding.
kmeans_iter: usize,
}
/// Strategy used to obtain the affinity matrix that is spectrally embedded.
#[derive(Debug, Clone, Copy)]
pub enum AffinityType {
/// Affinity computed from the points with `gaussian_similarity` and the configured `sigma`.
Rbf,
/// Affinity computed by `knn_graph` from pairwise Euclidean distances and the
/// configured neighbor count.
Knn,
/// The caller already has an affinity matrix. `fit` does not accept this
/// variant; pass the matrix to `fit_affinity` instead.
Precomputed,
}
impl SpectralClustering {
pub fn new(k: usize) -> Self {
Self {
k,
affinity: AffinityType::Rbf,
sigma: 1.0,
n_neighbors: 10,
kmeans_iter: 100,
}
}
pub fn affinity(mut self, affinity: AffinityType) -> Self {
self.affinity = affinity;
self
}
pub fn sigma(mut self, sigma: f64) -> Self {
self.sigma = sigma;
self
}
pub fn n_neighbors(mut self, n: usize) -> Self {
self.n_neighbors = n;
self
}
pub fn kmeans_iter(mut self, iter: usize) -> Self {
self.kmeans_iter = iter;
self
}
pub fn fit(&self, points: &Array2<f64>) -> Result<Vec<usize>> {
let n = points.nrows();
if n == 0 {
return Err(Error::EmptyInput);
}
if n < self.k {
return Err(Error::InvalidClusterCount {
requested: self.k,
n_items: n,
});
}
let affinity = self.build_affinity(points);
let cfg = SpectralEmbeddingConfig {
skip_first: false,
..SpectralEmbeddingConfig::default()
};
let embedding = spectral_embedding(&affinity, self.k, &cfg)
.map_err(|e| Error::Other(format!("lapl spectral_embedding failed: {e}")))?;
self.kmeans_on_embedding(&embedding)
}
pub fn fit_affinity(&self, affinity: &Array2<f64>) -> Result<Vec<usize>> {
let n = affinity.nrows();
if n == 0 {
return Err(Error::EmptyInput);
}
if affinity.ncols() != n {
return Err(Error::DimensionMismatch {
expected: n,
found: affinity.ncols(),
});
}
let cfg = SpectralEmbeddingConfig {
skip_first: false,
..SpectralEmbeddingConfig::default()
};
let embedding = spectral_embedding(affinity, self.k, &cfg)
.map_err(|e| Error::Other(format!("lapl spectral_embedding failed: {e}")))?;
self.kmeans_on_embedding(&embedding)
}
/// Builds the affinity matrix for `points` according to `self.affinity`.
///
/// * `Rbf`: delegates to `gaussian_similarity` with the configured `sigma`.
/// * `Knn`: computes the full pairwise Euclidean distance matrix and hands it
///   to `knn_graph` with the configured neighbor count.
///
/// # Panics
///
/// Panics if the affinity type is `Precomputed`; there is nothing to build
/// in that mode.
fn build_affinity(&self, points: &Array2<f64>) -> Array2<f64> {
    match self.affinity {
        AffinityType::Precomputed => {
            panic!("use fit_affinity for precomputed affinity")
        }
        AffinityType::Rbf => gaussian_similarity(points, self.sigma),
        AffinityType::Knn => {
            let n_points = points.nrows();
            let n_dims = points.ncols();
            // Entry (a, b) is the Euclidean distance between rows a and b.
            let pairwise = Array2::from_shape_fn((n_points, n_points), |(a, b)| {
                (0..n_dims)
                    .fold(0.0_f64, |acc, axis| {
                        let delta = points[[a, axis]] - points[[b, axis]];
                        acc + delta * delta
                    })
                    .sqrt()
            });
            knn_graph(&pairwise, self.n_neighbors)
        }
    }
}
fn kmeans_on_embedding(&self, embedding: &Array2<f64>) -> Result<Vec<usize>> {
let n = embedding.nrows();
let d = embedding.ncols();
let k = self.k;
let mut rows: Vec<Vec<f32>> = Vec::with_capacity(n);
for i in 0..n {
let mut row = Vec::with_capacity(d);
for j in 0..d {
row.push(embedding[[i, j]] as f32);
}
rows.push(row);
}
let base_seed = 42u64;
let mut best: Option<(f32, Vec<usize>)> = None;
for t in 0..4u64 {
let fit = clump::cluster::Kmeans::new(k)
.with_max_iter(self.kmeans_iter)
.with_tol(1e-4)
.with_seed(base_seed.wrapping_add(t))
.fit(&rows)
.map_err(|e| Error::Other(format!("clump kmeans failed: {e}")))?;
let mut wcss = 0.0f32;
for (i, &a) in fit.labels.iter().enumerate() {
let c = &fit.centroids[a];
let x = &rows[i];
let mut d2 = 0.0f32;
for j in 0..d {
let diff = x[j] - c[j];
d2 += diff * diff;
}
wcss += d2;
}
let labels = fit.labels;
match &mut best {
None => best = Some((wcss, labels)),
Some((best_wcss, best_assignments)) => {
if wcss < *best_wcss {
*best_wcss = wcss;
*best_assignments = labels;
}
}
}
}
best.map(|(_, labels)| labels).ok_or_else(|| {
Error::Other("n>0 implies at least one kmeans run but best was None".into())
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
/// Two tight, well-separated groups of three points must land in two
/// distinct clusters with the default RBF affinity.
#[test]
fn test_spectral_two_clusters() {
    let points = array![
        [0.0, 0.0],
        [0.1, 0.0],
        [0.0, 0.1],
        [5.0, 5.0],
        [5.1, 5.0],
        [5.0, 5.1],
    ];
    let model = SpectralClustering::new(2).sigma(1.0).kmeans_iter(50);
    let labels = model.fit(&points).expect("fit should succeed");

    // Points within each group share a label...
    for (a, b) in [(0, 1), (1, 2), (3, 4), (4, 5)] {
        assert_eq!(labels[a], labels[b]);
    }
    // ...and the two groups differ.
    assert_ne!(labels[0], labels[3]);
}
/// The k-NN affinity path runs end to end and yields one label per input point.
#[test]
fn test_spectral_with_knn() {
    let points = array![
        [0.0, 0.0],
        [0.1, 0.0],
        [0.0, 0.1],
        [5.0, 5.0],
        [5.1, 5.0],
        [5.0, 5.1],
    ];
    let model = SpectralClustering::new(2)
        .affinity(AffinityType::Knn)
        .n_neighbors(2);
    let labels = model.fit(&points).expect("fit should succeed");
    assert_eq!(labels.len(), 6);
}
/// Fitting a matrix with zero rows is rejected with `Error::EmptyInput`.
#[test]
fn test_spectral_empty_error() {
    let points: Array2<f64> = Array2::zeros((0, 2));
    let result = SpectralClustering::new(2).fit(&points);
    // Pin the variant: a bare `is_err()` would also pass if the input were
    // rejected for an unrelated reason.
    assert!(matches!(result, Err(Error::EmptyInput)));
}
}