use crate::utils::cosine_similarity;
use faer::Side;
use faer::prelude::*;
/// Groups `embeddings` with normalized spectral clustering and returns one cluster label per row.
///
/// Steps performed:
/// 1. Build a symmetric nearest-neighbour affinity matrix from pairwise cosine similarity.
///    Each row links to its `k_nn` most similar rows (`k_nn = clamp(n / 10, 2, 10)`), and only
///    strictly positive similarities create an edge. The diagonal is set to 1.
/// 2. Form the symmetric normalized Laplacian `I - D^-1/2 * A * D^-1/2` and decompose it.
/// 3. For every candidate `k` in `2..=min(max_k, 10)`, embed each row using the `k` eigenvector
///    columns with the smallest eigenvalues, scale the row to unit length, run k-means and score
///    the result with [`compute_bic`]; the lowest score wins.
/// 4. Run k-means once more with the winning `k` and return those labels.
///
/// # Arguments
/// * `embeddings` - one vector per item; compared via `cosine_similarity`.
/// * `max_k` - upper bound on the number of clusters; internally capped at `min(n, 20)`.
///
/// # Returns
/// A vector of `embeddings.len()` labels. Empty input yields an empty vector, a single row yields
/// `[0]`, and a failed eigendecomposition yields all zeros. k-means is seeded from a freshly
/// constructed `fastrand::Rng` (see `kmeans_on_features`), so label values may differ between
/// runs.
pub fn spectral_cluster(embeddings: &[Vec<f32>], max_k: usize) -> Vec<usize> {
    let n = embeddings.len();
    if n == 0 {
        return Vec::new();
    }
    if n == 1 {
        return vec![0];
    }

    // Row-major n x n affinity matrix.
    let k_nn = (n / 10).clamp(2, 10);
    let mut aff = vec![0.0f64; n * n];
    for i in 0..n {
        aff[i * n + i] = 1.0;
        let mut neighbors: Vec<(f64, usize)> = (0..n)
            .filter(|&j| j != i)
            .map(|j| (cosine_similarity(&embeddings[i], &embeddings[j]) as f64, j))
            .collect();
        // Most similar first; the stable sort keeps index order among ties.
        neighbors.sort_by(|a, b| b.0.total_cmp(&a.0));
        for &(sim, j) in neighbors.iter().take(k_nn) {
            if sim > 0.0 {
                // Write both directions so the matrix stays symmetric even when the
                // neighbour relation itself is not mutual.
                aff[i * n + j] = sim;
                aff[j * n + i] = sim;
            }
        }
    }

    let deg: Vec<f64> = (0..n).map(|i| aff[i * n..i * n + n].iter().sum()).collect();
    let mut lap = Mat::zeros(n, n);
    for i in 0..n {
        for j in 0..n {
            let val = if i == j {
                1.0 - aff[i * n + j] / deg[i].max(1e-10)
            } else {
                -aff[i * n + j] / (deg[i].sqrt() * deg[j].sqrt()).max(1e-10)
            };
            lap[(i, j)] = val;
        }
    }

    let eig = match lap.self_adjoint_eigen(Side::Lower) {
        Ok(e) => e,
        // Degrade to a single cluster rather than propagating the numerical failure.
        Err(_) => return vec![0; n],
    };
    let s = eig.S();
    let u = eig.U();

    // (eigenvalue, column index) pairs in ascending eigenvalue order.
    let mut eig_pairs: Vec<(f64, usize)> = (0..n).map(|i| (s[i], i)).collect();
    eig_pairs.sort_by(|a, b| a.0.total_cmp(&b.0));

    let max_k = max_k.min(n).min(20);

    // Largest relative gap between consecutive eigenvalues, used as the initial guess for k.
    // NOTE(review): the BIC scan below replaces this guess as soon as any candidate yields a
    // finite score, so it only matters when the scan is empty or produces no finite score.
    let mut eigengap_k = 1usize;
    let mut best_gap = 0.0f64;
    for k in 1..max_k.saturating_sub(1) {
        let prev = eig_pairs[k].0;
        let gap = if prev > 1e-10 {
            (eig_pairs[k + 1].0 - prev) / prev
        } else {
            0.0
        };
        if gap > best_gap {
            best_gap = gap;
            eigengap_k = k;
        }
    }

    // Builds the row-normalized spectral embedding from the `dim` leading eigenvector columns.
    // Rows whose norm is at most 1e-10 are left unscaled to avoid dividing by ~0.
    let spectral_features = |dim: usize| -> Vec<Vec<f64>> {
        (0..n)
            .map(|row| {
                let mut feat: Vec<f64> = eig_pairs
                    .iter()
                    .take(dim)
                    .map(|&(_, col)| u[(row, col)])
                    .collect();
                let norm: f64 = feat.iter().map(|v| v * v).sum::<f64>().sqrt();
                if norm > 1e-10 {
                    for v in feat.iter_mut() {
                        *v /= norm;
                    }
                }
                feat
            })
            .collect()
    };

    let mut best_k = eigengap_k.max(2).min(max_k);
    let mut best_bic = f64::INFINITY;
    for k in 2..=max_k.min(10) {
        let features = spectral_features(k);
        let labels = kmeans_on_features(&features, k, 20);
        let bic = compute_bic(&features, &labels, k);
        if bic < best_bic {
            best_bic = bic;
            best_k = k;
        }
    }

    let features = spectral_features(best_k);
    kmeans_on_features(&features, best_k, 20)
}
/// Partitions `features` into at most `k` groups with Lloyd's k-means and returns one label
/// per row.
///
/// # Arguments
/// * `features` - rows to cluster; the dimensionality is taken from the first row.
/// * `k` - requested number of clusters, clamped to `1..=features.len()`.
/// * `max_iter` - maximum number of assign/update rounds.
///
/// # Returns
/// A label in `0..k'` for every row, where `k'` is the number of centroids actually seeded
/// (seeding stops early once every row coincides with an already chosen centroid). Empty input
/// yields an empty vector; `k == 1` yields all zeros.
///
/// Seeding draws from a freshly constructed `fastrand::Rng`, so results may vary between calls.
fn kmeans_on_features(features: &[Vec<f64>], k: usize, max_iter: usize) -> Vec<usize> {
    let n = features.len();
    if n == 0 {
        // Nothing to label; also avoids indexing `features[0]` and sampling from an empty range.
        return Vec::new();
    }
    let k = k.min(n).max(1);
    if k == 1 {
        return vec![0; n];
    }
    let dim = features[0].len();

    // Distance-weighted seeding: the first centroid is a uniformly random row, each further one
    // is drawn with probability proportional to its distance to the nearest chosen centroid.
    // NOTE(review): canonical k-means++ weights by the *squared* distance; the plain distance is
    // kept here to preserve the existing sampling behaviour.
    let mut rng = fastrand::Rng::new();
    let mut centroids: Vec<Vec<f64>> = Vec::with_capacity(k);
    centroids.push(features[rng.usize(0..n)].clone());
    let mut dists = vec![f64::INFINITY; n];
    for _ in 1..k {
        // Only the most recently added centroid can shrink a row's nearest-centroid distance.
        let newest = &centroids[centroids.len() - 1];
        for (nearest, feat) in dists.iter_mut().zip(features) {
            let d = euclidean_distance(feat, newest);
            if d < *nearest {
                *nearest = d;
            }
        }
        let total: f64 = dists.iter().sum();
        if total <= 0.0 {
            // Every row already coincides with a centroid; more seeds would be duplicates.
            break;
        }
        let target = rng.f64() * total;
        let mut cumsum = 0.0;
        let mut chosen = 0;
        for (i, &d) in dists.iter().enumerate() {
            cumsum += d;
            if cumsum >= target {
                chosen = i;
                break;
            }
        }
        centroids.push(features[chosen].clone());
    }

    let k = centroids.len();
    let mut labels = vec![0usize; n];
    for _ in 0..max_iter {
        // Assignment step: attach every row to its closest centroid (lowest index wins ties).
        let mut changed = false;
        for (label, feat) in labels.iter_mut().zip(features) {
            let mut best = 0usize;
            let mut best_dist = f64::INFINITY;
            for (c_idx, c) in centroids.iter().enumerate() {
                let dist = euclidean_distance(feat, c);
                if dist < best_dist {
                    best_dist = dist;
                    best = c_idx;
                }
            }
            if *label != best {
                *label = best;
                changed = true;
            }
        }
        if !changed {
            break;
        }

        // Update step: move each centroid to the mean of its members.
        let mut sums = vec![vec![0.0f64; dim]; k];
        let mut counts = vec![0usize; k];
        for (&c, feat) in labels.iter().zip(features) {
            for (acc, &v) in sums[c].iter_mut().zip(feat) {
                *acc += v;
            }
            counts[c] += 1;
        }
        for ((centroid, sum), &count) in centroids.iter_mut().zip(sums).zip(&counts) {
            // A cluster that lost all members keeps its previous centroid instead of being
            // reset to the origin, which is an arbitrary point unrelated to the data.
            if count > 0 {
                let members = count as f64;
                *centroid = sum.into_iter().map(|v| v / members).collect();
            }
        }
    }
    labels
}
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, the surplus tail of the
/// longer one is ignored.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_total: f64 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_total.sqrt()
}
/// Scores a clustering with a BIC-style criterion; lower values indicate a better trade-off
/// between fit and model size.
///
/// Centroids are recomputed as the mean of each label's members, `inertia` is the total squared
/// distance of every row to its centroid, and the score is
/// `n * ln(inertia / n) + p * ln(n)` with `p = k * (dim + 1)`.
///
/// # Arguments
/// * `features` - clustered rows; the dimensionality is taken from the first row.
/// * `labels` - cluster index for each row of `features`.
/// * `k` - number of clusters the labels were drawn from.
///
/// # Returns
/// `f64::INFINITY` for empty input. When `inertia < 1e-10` only the penalty term `p * ln(n)` is
/// returned: this avoids `ln(0)` and keeps a degenerate perfect fit (for example one cluster per
/// row) from receiving an unboundedly good score.
///
/// # Panics
/// Panics if `labels` is shorter than `features` or contains a value `>= k`.
fn compute_bic(features: &[Vec<f64>], labels: &[usize], k: usize) -> f64 {
    let n = features.len();
    if n == 0 {
        return f64::INFINITY;
    }
    let dim = features[0].len();

    // Per-cluster component sums, turned into means below.
    let mut centroids = vec![vec![0.0f64; dim]; k];
    let mut counts = vec![0usize; k];
    for (i, feat) in features.iter().enumerate() {
        let c = labels[i];
        for (acc, &v) in centroids[c].iter_mut().zip(feat) {
            *acc += v;
        }
        counts[c] += 1;
    }
    for (centroid, &count) in centroids.iter_mut().zip(&counts) {
        if count > 0 {
            let members = count as f64;
            for v in centroid.iter_mut() {
                *v /= members;
            }
        }
    }

    // Accumulate squared distances directly; taking a square root only to square it again
    // would cost precision for no benefit.
    let inertia: f64 = features
        .iter()
        .enumerate()
        .map(|(i, feat)| {
            feat.iter()
                .zip(&centroids[labels[i]])
                .map(|(x, m)| (x - m).powi(2))
                .sum::<f64>()
        })
        .sum();

    let samples = n as f64;
    let penalty = (k * (dim + 1)) as f64 * samples.ln();
    if inertia < 1e-10 {
        return penalty;
    }
    let log_likelihood = -samples * (inertia / samples).ln() / 2.0;
    -2.0 * log_likelihood + penalty
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight pairs of near-parallel vectors must end up as exactly two clusters,
    /// with each pair sharing a label.
    #[test]
    fn test_spectral_cluster_basic() {
        let points: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
        ];

        let assignment = spectral_cluster(&points, 10);
        assert_eq!(assignment.len(), 4);

        // Labels are zero-based, so the highest label plus one is the cluster count.
        let cluster_count = assignment.iter().copied().max().unwrap_or(0) + 1;
        assert_eq!(cluster_count, 2);

        assert_eq!(assignment[0], assignment[1]);
        assert_eq!(assignment[2], assignment[3]);
        assert_ne!(assignment[0], assignment[2]);
    }

    /// No input rows means no labels.
    #[test]
    fn test_spectral_cluster_empty() {
        let no_points: Vec<Vec<f32>> = Vec::new();
        let assignment = spectral_cluster(&no_points, 10);
        assert!(assignment.is_empty());
    }

    /// A lone row is always assigned to cluster 0.
    #[test]
    fn test_spectral_cluster_single() {
        let lone_point: Vec<Vec<f32>> = vec![vec![1.0, 0.0]];
        let assignment = spectral_cluster(&lone_point, 10);
        assert_eq!(assignment, vec![0]);
    }
}