use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use nalgebra::DMatrix;
use rand::prelude::*;
use rand::{SeedableRng, rngs::StdRng};
use std::collections::HashMap;
/// Builds the random number generator used by the clustering routines.
///
/// When `seed` is `Some`, the generator is seeded with that value so runs are reproducible.
/// When it is `None`, a fresh `u64` is drawn via `rand::random` and used as the seed.
fn create_rng(seed: Option<u64>) -> StdRng {
    let seed_value = seed.unwrap_or_else(|| rand::random::<u64>());
    StdRng::seed_from_u64(seed_value)
}
/// Computes a `k`-dimensional spectral embedding of `graph`.
///
/// The routine builds a symmetric weighted adjacency matrix `A` (every edge adds its weight to
/// both `(u, v)` and `(v, u)`), the diagonal matrix `D` of row sums of `A`, and the unnormalized
/// Laplacian `L = D - A`. Each node is then embedded as its row in the matrix whose columns are
/// the eigenvectors of `L` belonging to the `k` smallest eigenvalues.
///
/// Row `i` of the returned matrix corresponds to the `i`-th node yielded by `graph.nodes()`.
///
/// # Errors
///
/// Returns the error produced by `GraphinaError::invalid_graph` when the graph has no nodes,
/// when `k == 0`, or when `k` is larger than the number of nodes.
pub fn spectral_embeddings<A, W, Ty>(
    graph: &BaseGraph<A, W, Ty>,
    k: usize,
) -> Result<Vec<Vec<f64>>>
where
    W: Copy + PartialOrd + Into<f64> + From<u8>,
    Ty: GraphConstructor<A, W>,
{
    let node_list: Vec<NodeId> = graph.nodes().map(|(node, _)| node).collect();
    let n = node_list.len();
    if n == 0 {
        return Err(GraphinaError::invalid_graph(
            "SpectralEmbeddings: empty graph",
        ));
    }
    if k == 0 {
        return Err(GraphinaError::invalid_graph("SpectralEmbeddings: k=0"));
    }
    if k > n {
        return Err(GraphinaError::invalid_graph(
            "SpectralEmbeddings: k > node count",
        ));
    }

    // Dense row/column index for every node, in `graph.nodes()` order.
    let node_to_idx: HashMap<NodeId, usize> = node_list
        .iter()
        .enumerate()
        .map(|(idx, &node)| (node, idx))
        .collect();

    // Symmetrized weighted adjacency matrix.
    let mut adj = DMatrix::<f64>::zeros(n, n);
    for (u, v, &w) in graph.edges() {
        let ui = node_to_idx[&u];
        let vi = node_to_idx[&v];
        let weight: f64 = w.into();
        adj[(ui, vi)] += weight;
        adj[(vi, ui)] += weight;
    }

    // Diagonal degree matrix: each entry is the sum of the matching adjacency row.
    let mut deg = DMatrix::<f64>::zeros(n, n);
    for i in 0..n {
        let d: f64 = (0..n).map(|j| adj[(i, j)]).sum();
        deg[(i, i)] = d;
    }

    let lap = &deg - &adj;
    let eig = lap.symmetric_eigen();

    // The decomposition does not return its eigenvalues in sorted order, so rank the columns by
    // ascending eigenvalue before picking the `k` that define the embedding. The sort is stable,
    // so equal eigenvalues keep their original relative order.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&lhs, &rhs| eig.eigenvalues[lhs].total_cmp(&eig.eigenvalues[rhs]));
    let selected = &order[..k];

    let embedding: Vec<Vec<f64>> = (0..n)
        .map(|row| {
            selected
                .iter()
                .map(|&col| eig.eigenvectors[(row, col)])
                .collect()
        })
        .collect();
    Ok(embedding)
}
/// Splits the nodes of `graph` into `k` clusters using a spectral embedding followed by k-means.
///
/// The nodes are embedded with `spectral_embeddings(graph, k)` and the resulting points are
/// grouped by `k_means`, which receives `seed` to initialise its random number generator.
/// The returned vector has one entry per cluster, each holding the `NodeId`s assigned to it.
///
/// # Errors
///
/// Propagates any error returned by `spectral_embeddings`.
pub fn spectral_clustering<A, W, Ty>(
    graph: &BaseGraph<A, W, Ty>,
    k: usize,
    seed: Option<u64>,
) -> Result<Vec<Vec<NodeId>>>
where
    W: Copy + PartialOrd + Into<f64> + From<u8>,
    Ty: GraphConstructor<A, W>,
{
    // Collected in `graph.nodes()` order so index `i` here lines up with embedding row `i`.
    let ids: Vec<NodeId> = graph.nodes().map(|(id, _)| id).collect();
    spectral_embeddings(graph, k).map(|points| k_means(&points, k, seed, &ids))
}
/// Groups the rows of `data` into `k` clusters with iterative k-means and returns, for each
/// cluster, the nodes whose points were assigned to it.
///
/// * `data` - one point per node; all rows are expected to have the same length as `data[0]`.
/// * `k` - number of clusters; the result always has exactly `k` entries (some may be empty).
/// * `seed` - forwarded to `create_rng` to make centroid initialisation reproducible.
/// * `node_list` - node for each row of `data`, in the same order.
///
/// Initial centroids are sampled from `data`. Each round assigns every point to its nearest
/// centroid (first one wins on ties) and then moves each centroid to the mean of its points;
/// a cluster that ended up empty is re-seeded with a randomly chosen data point. Iteration stops
/// as soon as a round changes no assignment, or after 100 rounds.
///
/// When `data` is empty or `k == 0` there is nothing to assign and `k` empty clusters are
/// returned. Distances are compared with `f64::total_cmp`, so a NaN distance does not panic.
fn k_means(
    data: &[Vec<f64>],
    k: usize,
    seed: Option<u64>,
    node_list: &[NodeId],
) -> Vec<Vec<NodeId>> {
    const MAX_ITER: usize = 100;

    let n = data.len();
    if n == 0 || k == 0 {
        // No points to place, or no clusters to place them in.
        return vec![Vec::new(); k];
    }
    debug_assert_eq!(node_list.len(), n, "one node is required per data point");

    let d = data[0].len();
    let mut rng = create_rng(seed);
    let mut centroids: Vec<Vec<f64>> = data.choose_multiple(&mut rng, k).cloned().collect();
    let mut assignments = vec![0usize; n];

    for _ in 0..MAX_ITER {
        // Assignment step: attach every point to its nearest centroid.
        let mut changed = false;
        for (assigned, point) in assignments.iter_mut().zip(data) {
            let nearest = centroids
                .iter()
                .enumerate()
                .map(|(j, centroid)| (j, euclidean_distance(point, centroid)))
                .min_by(|a, b| a.1.total_cmp(&b.1));
            if let Some((best_j, _)) = nearest {
                if *assigned != best_j {
                    *assigned = best_j;
                    changed = true;
                }
            }
        }
        // Stable assignments are final; updated centroids would never be read again.
        if !changed {
            break;
        }

        // Update step: accumulate per-cluster coordinate sums and member counts.
        let mut new_centroids = vec![vec![0.0; d]; k];
        let mut counts = vec![0usize; k];
        for (&cluster, point) in assignments.iter().zip(data) {
            counts[cluster] += 1;
            for (acc, &val) in new_centroids[cluster].iter_mut().zip(point) {
                *acc += val;
            }
        }
        for (centroid, &count) in new_centroids.iter_mut().zip(&counts) {
            if count > 0 {
                let size = count as f64;
                for coord in centroid.iter_mut() {
                    *coord /= size;
                }
            } else {
                // Empty cluster: restart it from a random data point.
                *centroid = data[rng.random_range(0..n)].clone();
            }
        }
        centroids = new_centroids;
    }

    let mut clusters: Vec<Vec<NodeId>> = vec![Vec::new(); k];
    for (&cluster, &node) in assignments.iter().zip(node_list) {
        clusters[cluster].push(node);
    }
    clusters
}
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Coordinates are paired positionally; if the slices differ in length, the extra trailing
/// coordinates of the longer one are ignored.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_total: f64 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta.powi(2)
        })
        .sum();
    squared_total.sqrt()
}