use crate::kmeans::KMeans;
use ferrolearn_core::NdarrayFaerBackend;
use ferrolearn_core::backend::Backend;
use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::Float;
/// Unfitted spectral clustering estimator holding only hyper-parameters.
///
/// All fields are public and can also be set through the `with_*` builder
/// methods; validation happens when `fit` is called, not at construction.
#[derive(Debug, Clone)]
pub struct SpectralClustering<F> {
/// Number of clusters to produce; `fit` returns an error when this is `0`.
pub n_clusters: usize,
/// Kernel coefficient used as `exp(-gamma * squared_distance)` when the
/// affinity matrix is built; `fit` returns an error when this is `<= 0`.
pub gamma: F,
/// Passed to `KMeans::with_n_init` for the final clustering step.
pub n_init: usize,
/// Optional seed passed to `KMeans::with_random_state`; when `None`, that
/// method is not called.
pub random_state: Option<u64>,
}
impl<F: Float> SpectralClustering<F> {
    /// Builds an estimator for `n_clusters` clusters with the defaults
    /// `gamma = 1`, `n_init = 10` and no random seed.
    #[must_use]
    pub fn new(n_clusters: usize) -> Self {
        let gamma = F::one();
        let n_init = 10;
        Self {
            n_clusters,
            gamma,
            n_init,
            random_state: None,
        }
    }

    /// Returns the estimator with its kernel coefficient replaced by `gamma`.
    #[must_use]
    pub fn with_gamma(self, gamma: F) -> Self {
        Self { gamma, ..self }
    }

    /// Returns the estimator with its `n_init` setting replaced.
    #[must_use]
    pub fn with_n_init(self, n_init: usize) -> Self {
        Self { n_init, ..self }
    }

    /// Returns the estimator with its random seed set to `seed`.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
/// Outcome of fitting a [`SpectralClustering`] estimator: the cluster labels
/// computed for the data passed to `fit`.
#[derive(Debug, Clone)]
pub struct FittedSpectralClustering<F> {
/// Labels returned by the k-means `predict` call on the spectral embedding.
labels_: Array1<usize>,
/// Carries the float type parameter `F` without storing any value of it.
_marker: std::marker::PhantomData<F>,
}
impl<F: Float> FittedSpectralClustering<F> {
/// Borrows the cluster labels that were stored when the model was fitted.
#[must_use]
pub fn labels(&self) -> &Array1<usize> {
&self.labels_
}
}
/// Builds the dense `n x n` kernel matrix for the rows of `x`.
///
/// Off-diagonal entry `(i, j)` is `exp(-gamma * ||x_i - x_j||^2)`; diagonal
/// entries are set to `1.0` directly. The squared distance is accumulated in
/// `F` and then converted to `f64` (falling back to `0.0` if the conversion
/// fails).
fn affinity_matrix<F: Float>(x: &Array2<F>, gamma: f64) -> Array2<f64> {
    let n_samples = x.nrows();

    // Squared Euclidean distance between rows `i` and `j`, as an f64.
    let squared_distance = |i: usize, j: usize| -> f64 {
        let (row_i, row_j) = (x.row(i), x.row(j));
        let total: F = row_i
            .iter()
            .zip(row_j.iter())
            .map(|(&p, &q)| {
                let diff = p - q;
                diff * diff
            })
            .fold(F::zero(), |acc, term| acc + term);
        total.to_f64().unwrap_or(0.0)
    };

    Array2::from_shape_fn((n_samples, n_samples), |(i, j)| {
        if i != j {
            (-gamma * squared_distance(i, j)).exp()
        } else {
            1.0_f64
        }
    })
}
/// Symmetrically normalizes the affinity matrix `a` by its row sums.
///
/// With `d_i` the sum of row `i`, the result has entries
/// `a[i][j] / (sqrt(d_i) * sqrt(d_j))`, i.e. `D^{-1/2} A D^{-1/2}`. Rows whose
/// sum is not strictly positive get a scaling factor of `0.0`.
fn normalized_laplacian(a: &Array2<f64>) -> Array2<f64> {
    let size = a.nrows();

    // One scaling factor per row: 1 / sqrt(row sum), or 0 for empty rows.
    let inv_sqrt_degree: Vec<f64> = a
        .outer_iter()
        .map(|row| {
            let degree: f64 = row.iter().sum();
            if degree > 0.0 {
                degree.sqrt().recip()
            } else {
                0.0
            }
        })
        .collect();

    Array2::from_shape_fn((size, size), |(i, j)| {
        let left = inv_sqrt_degree[i];
        let right = inv_sqrt_degree[j];
        left * a[[i, j]] * right
    })
}
/// Extracts the last `k` eigenvector columns returned by the backend's `eigh`.
///
/// The output always has `k` columns. If the decomposition yields fewer than
/// `k` eigenpairs, the available ones fill the leading columns and the rest
/// stay zero.
///
/// NOTE(review): this assumes `eigh` returns eigenvalues in ascending order so
/// that the trailing columns correspond to the largest ones — confirm against
/// the backend.
///
/// # Errors
///
/// Propagates any error returned by `NdarrayFaerBackend::eigh`.
fn top_k_eigenvectors(sym: &Array2<f64>, k: usize) -> Result<Array2<f64>, FerroError> {
    let (eigenvalues, eigenvectors) = NdarrayFaerBackend::eigh(sym)?;

    let total = eigenvalues.len();
    let first = total.saturating_sub(k);
    // Number of columns that actually exist to be copied (<= k).
    let available = total - first;
    let n_rows = eigenvectors.nrows();

    let selected = Array2::<f64>::from_shape_fn((n_rows, k), |(row, col)| {
        if col < available {
            eigenvectors[[row, first + col]]
        } else {
            0.0
        }
    });
    Ok(selected)
}
/// Scales every row of `m` to unit Euclidean length.
///
/// Rows whose norm is not strictly positive (all-zero rows, or rows whose
/// norm is NaN) are copied through unchanged.
///
/// Each row norm is computed once up front; computing it inside the
/// per-element closure would redo the same reduction for every column.
fn row_normalize(m: &Array2<f64>) -> Array2<f64> {
    let norms: Vec<f64> = m
        .outer_iter()
        .map(|row| row.iter().map(|&v| v * v).sum::<f64>().sqrt())
        .collect();

    Array2::from_shape_fn(m.dim(), |(i, j)| {
        let norm = norms[i];
        if norm > 0.0 {
            m[[i, j]] / norm
        } else {
            m[[i, j]]
        }
    })
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SpectralClustering<F> {
    type Fitted = FittedSpectralClustering<F>;
    type Error = FerroError;

    /// Clusters the rows of `x` and returns the resulting labels.
    ///
    /// Pipeline:
    /// 1. build the kernel affinity matrix with `affinity_matrix`,
    /// 2. normalize it with `normalized_laplacian`,
    /// 3. take the last `n_clusters` eigenvectors via `top_k_eigenvectors`,
    /// 4. scale each embedding row to unit length with `row_normalize`,
    /// 5. run `KMeans` on the embedding and predict labels for the same rows.
    ///
    /// # Errors
    ///
    /// * `FerroError::InvalidParameter` if `n_clusters` is `0`, or if `gamma`
    ///   is not strictly positive, is NaN/infinite, or cannot be converted to
    ///   `f64`.
    /// * `FerroError::InsufficientSamples` if `x` has fewer rows than
    ///   `n_clusters` (this includes empty input).
    /// * Any error propagated from the eigendecomposition or from `KMeans`.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSpectralClustering<F>, FerroError> {
        let n_samples = x.nrows();
        let k = self.n_clusters;

        if k == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_clusters".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.gamma <= F::zero() {
            return Err(FerroError::InvalidParameter {
                name: "gamma".into(),
                reason: "must be positive".into(),
            });
        }
        // NaN compares false against everything, so the `<= 0` test above lets
        // it through; +inf yields `inf * 0 = NaN` for duplicate points. Both
        // would silently poison the affinity matrix.
        if !self.gamma.is_finite() {
            return Err(FerroError::InvalidParameter {
                name: "gamma".into(),
                reason: "must be finite".into(),
            });
        }
        // `k >= 1` at this point, so this single check also covers empty input.
        if n_samples < k {
            return Err(FerroError::InsufficientSamples {
                required: k,
                actual: n_samples,
                context: "SpectralClustering requires at least n_clusters samples".into(),
            });
        }

        // Surface a failed conversion instead of silently switching to a
        // different kernel width.
        let gamma64 = self
            .gamma
            .to_f64()
            .ok_or_else(|| FerroError::InvalidParameter {
                name: "gamma".into(),
                reason: "cannot be represented as f64".into(),
            })?;

        let aff = affinity_matrix(x, gamma64);
        let lap = normalized_laplacian(&aff);
        let embed = top_k_eigenvectors(&lap, k)?;
        let embed_norm = row_normalize(&embed);

        // KMeans is generic over `F`, so convert the f64 embedding back.
        let embed_f: Array2<F> = Array2::from_shape_fn(embed_norm.dim(), |(i, j)| {
            F::from(embed_norm[[i, j]]).unwrap_or(F::zero())
        });

        let mut km = KMeans::<F>::new(k).with_n_init(self.n_init);
        if let Some(seed) = self.random_state {
            km = km.with_random_state(seed);
        }
        let fitted_km = km.fit(&embed_f, &())?;
        let labels = fitted_km.predict(&embed_f)?;

        Ok(FittedSpectralClustering {
            labels_: labels,
            _marker: std::marker::PhantomData,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten 2-D points: five near the origin followed by five near (10, 10).
    fn two_blobs() -> Array2<f64> {
        let coords = vec![
            0.0, 0.0, 0.2, 0.1, -0.1, 0.2, 0.1, -0.1, 0.0, 0.1, 10.0, 10.0, 10.2, 10.1, 9.9,
            10.2, 10.1, 9.9, 10.0, 10.1,
        ];
        Array2::from_shape_vec((10, 2), coords).unwrap()
    }

    /// Well-separated blobs must end up in two distinct, internally
    /// consistent clusters.
    #[test]
    fn test_two_blobs_two_clusters() {
        let x = two_blobs();
        let fitted = SpectralClustering::<f64>::new(2)
            .with_gamma(0.1)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        let labels = fitted.labels();

        assert_eq!(labels.len(), 10);
        // First blob: samples 0..5 share the label of sample 0.
        for i in 1..5 {
            assert_eq!(labels[0], labels[i]);
        }
        // Second blob: samples 5..10 share the label of sample 5.
        for i in 6..10 {
            assert_eq!(labels[5], labels[i]);
        }
        assert_ne!(labels[0], labels[5]);
    }

    /// One label is produced per input row.
    #[test]
    fn test_labels_length_matches_n_samples() {
        let x = two_blobs();
        let model = SpectralClustering::<f64>::new(2).with_random_state(0);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), x.nrows());
    }

    /// Every label is a valid cluster index.
    #[test]
    fn test_labels_in_valid_range() {
        let x = two_blobs();
        let k = 2usize;
        let model = SpectralClustering::<f64>::new(k).with_random_state(1);
        let fitted = model.fit(&x, &()).unwrap();
        fitted
            .labels()
            .iter()
            .for_each(|&l| assert!(l < k, "label {l} >= n_clusters {k}"));
    }

    /// With a single cluster every sample gets label 0.
    #[test]
    fn test_single_cluster() {
        let x = two_blobs();
        let model = SpectralClustering::<f64>::new(1).with_random_state(0);
        let fitted = model.fit(&x, &()).unwrap();
        assert!(fitted.labels().iter().all(|&l| l == 0));
    }

    /// `n_clusters == 0` is rejected.
    #[test]
    fn test_invalid_n_clusters_zero() {
        let x = two_blobs();
        assert!(SpectralClustering::<f64>::new(0).fit(&x, &()).is_err());
    }

    /// `gamma == 0` is rejected.
    #[test]
    fn test_invalid_gamma_zero() {
        let x = two_blobs();
        let model = SpectralClustering::<f64>::new(2).with_gamma(0.0);
        assert!(model.fit(&x, &()).is_err());
    }

    /// Negative `gamma` is rejected.
    #[test]
    fn test_invalid_gamma_negative() {
        let x = two_blobs();
        let model = SpectralClustering::<f64>::new(2).with_gamma(-1.0);
        assert!(model.fit(&x, &()).is_err());
    }

    /// Fitting on zero rows is an error.
    #[test]
    fn test_empty_data_error() {
        let x = Array2::<f64>::zeros((0, 2));
        assert!(SpectralClustering::<f64>::new(2).fit(&x, &()).is_err());
    }

    /// Fewer samples than clusters is an error.
    #[test]
    fn test_insufficient_samples_error() {
        let x = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        assert!(SpectralClustering::<f64>::new(3).fit(&x, &()).is_err());
    }

    /// Exactly as many samples as clusters still fits.
    #[test]
    fn test_n_clusters_equals_n_samples() {
        let coords = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let x = Array2::from_shape_vec((3, 2), coords).unwrap();
        let model = SpectralClustering::<f64>::new(3).with_random_state(0);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), 3);
    }

    /// The estimator also works with `f32` data.
    #[test]
    fn test_f32_support() {
        let coords = vec![
            0.0f32, 0.0, 0.1, 0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 10.1, 9.9, 10.1,
        ];
        let x = Array2::from_shape_vec((6, 2), coords).unwrap();
        let model = SpectralClustering::<f32>::new(2)
            .with_gamma(0.1)
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.labels().len(), 6);
    }

    /// Two fits with the same seed give identical labels.
    #[test]
    fn test_reproducibility_with_seed() {
        let x = two_blobs();
        let model = SpectralClustering::<f64>::new(2)
            .with_gamma(0.1)
            .with_random_state(7);
        let first = model.fit(&x, &()).unwrap();
        let second = model.fit(&x, &()).unwrap();
        assert_eq!(first.labels(), second.labels());
    }
}