use crate::{similarity::SimilarityMetric, Vector};
use anyhow::{anyhow, Result};
use scirs2_core::random::{Random, Rng, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Clustering strategies supported by `ClusteringEngine`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
/// Lloyd's k-means with k-means++ seeding.
KMeans,
/// Density-based clustering; points in sparse regions become noise.
DBSCAN,
/// Agglomerative (bottom-up) hierarchical clustering.
Hierarchical,
/// Spectral clustering: k-means on a normalized-Laplacian eigenvector embedding.
Spectral,
/// Greedy modularity-based community detection over a similarity graph.
Community,
/// Single-pass greedy grouping by pairwise similarity threshold.
Similarity,
}
/// Tunable parameters shared by all clustering algorithms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
/// Which algorithm `ClusteringEngine::cluster` dispatches to.
pub algorithm: ClusteringAlgorithm,
/// Target cluster count for KMeans/Hierarchical/Spectral; `None` means 3.
pub num_clusters: Option<usize>,
/// Minimum similarity for Similarity/Community edges; DBSCAN derives its
/// eps as `1.0 - similarity_threshold`.
pub similarity_threshold: f32,
/// DBSCAN `min_pts`: neighbor count required for a core point.
pub min_cluster_size: usize,
/// Metric backing all distance/similarity computations.
pub distance_metric: SimilarityMetric,
/// Iteration cap for KMeans and community detection.
pub max_iterations: usize,
/// RNG seed for KMeans initialization; `None` falls back to 42.
pub random_seed: Option<u64>,
/// Minimum modularity gain required to move a node in community detection.
pub tolerance: f32,
/// Cluster-distance criterion for hierarchical clustering.
pub linkage: LinkageCriterion,
}
impl Default for ClusteringConfig {
/// Sensible starting point: k-means into 3 clusters under the cosine
/// metric, 0.7 similarity threshold, DBSCAN min size 3, up to 100
/// iterations, unseeded RNG (the engine substitutes seed 42), 1e-4
/// tolerance, and average linkage.
fn default() -> Self {
Self {
algorithm: ClusteringAlgorithm::KMeans,
num_clusters: Some(3),
similarity_threshold: 0.7,
min_cluster_size: 3,
distance_metric: SimilarityMetric::Cosine,
max_iterations: 100,
random_seed: None,
tolerance: 1e-4,
linkage: LinkageCriterion::Average,
}
}
}
/// How hierarchical clustering scores the distance between two clusters
/// when choosing which pair to merge next.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LinkageCriterion {
/// Minimum pairwise point distance (nearest members).
Single,
/// Maximum pairwise point distance (farthest members).
Complete,
/// Mean over all cross-cluster point-pair distances.
Average,
/// Ward's minimum-variance criterion.
Ward,
}
/// One group of resources produced by a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
/// 0-based index of this cluster within the result.
pub id: usize,
/// Resource identifiers belonging to this cluster.
pub members: Vec<String>,
/// Mean vector of the members, when one could be computed.
pub centroid: Option<Vector>,
/// Size/similarity statistics for this cluster.
pub stats: ClusterStats,
}
/// Per-cluster summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
/// Number of member resources.
pub size: usize,
/// Mean pairwise similarity between members (1.0 for singletons).
pub avg_intra_similarity: f32,
/// Currently mirrors `avg_intra_similarity`.
pub density: f32,
/// Per-cluster silhouette; currently always 0.0 — silhouette is only
/// computed at the result level in `compute_quality_metrics`.
pub silhouette_score: f32,
}
/// Full output of a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
/// Non-empty clusters discovered by the algorithm.
pub clusters: Vec<Cluster>,
/// Resource ids left unassigned (only DBSCAN produces noise).
pub noise: Vec<String>,
/// Partition-level quality scores.
pub quality_metrics: ClusteringQualityMetrics,
/// Algorithm that produced this result.
pub algorithm: ClusteringAlgorithm,
/// Configuration the run was executed with.
pub config: ClusteringConfig,
}
/// Aggregate quality scores for a clustering partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringQualityMetrics {
/// Mean silhouette over all points; range [-1, 1], higher is better.
pub silhouette_score: f32,
/// Davies-Bouldin index; lower is better.
pub davies_bouldin_index: f32,
/// Calinski-Harabasz index; higher is better.
pub calinski_harabasz_index: f32,
/// Sum of squared point-to-centroid distances within clusters.
pub within_cluster_ss: f32,
/// Size-weighted squared centroid-to-global-centroid distances.
pub between_cluster_ss: f32,
}
/// Runs the configured clustering algorithm over a set of
/// (resource id, vector) pairs.
pub struct ClusteringEngine {
// Immutable after construction; every algorithm reads from it.
config: ClusteringConfig,
}
impl ClusteringEngine {
/// Create an engine that clusters according to `config`.
pub fn new(config: ClusteringConfig) -> Self {
Self { config }
}
/// Cluster the given resources with the configured algorithm.
///
/// An empty input short-circuits to an empty, zero-metric result rather
/// than an error; otherwise dispatches on `config.algorithm`.
pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
if resources.is_empty() {
return Ok(ClusteringResult {
clusters: Vec::new(),
noise: Vec::new(),
quality_metrics: ClusteringQualityMetrics::default(),
algorithm: self.config.algorithm,
config: self.config.clone(),
});
}
match self.config.algorithm {
ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
ClusteringAlgorithm::Community => self.community_detection(resources),
ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
}
}
/// K-means (Lloyd's algorithm) with k-means++ seeding.
///
/// `k` comes from `config.num_clusters` (3 when unset). Runs at most
/// `config.max_iterations` assign/update rounds and stops early once no
/// assignment changes. Empty clusters keep their previous centroid and
/// are omitted from the output.
///
/// # Errors
/// Fails when `k == 0` or `k >= resources.len()`.
fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
    let k = self.config.num_clusters.unwrap_or(3);
    // Guard k == 0: the k-means++ initializer would otherwise still
    // return one centroid and silently produce a single cluster.
    if k == 0 {
        return Err(anyhow!("Number of clusters must be greater than zero"));
    }
    if k >= resources.len() {
        return Err(anyhow!(
            "Number of clusters must be less than number of resources"
        ));
    }
    // Seed defaults to 42 so unseeded runs stay reproducible.
    let mut rng = Random::seed(self.config.random_seed.unwrap_or(42));
    let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
    let mut assignments = vec![0; resources.len()];
    // usize::MAX sentinel guarantees the first convergence check fails.
    let mut prev_assignments = vec![usize::MAX; resources.len()];
    for _ in 0..self.config.max_iterations {
        // Assignment step: each point joins its nearest centroid.
        for (i, (_, vector)) in resources.iter().enumerate() {
            let mut best_cluster = 0;
            let mut best_distance = f32::INFINITY;
            for (cluster_id, centroid) in centroids.iter().enumerate() {
                let distance = self.calculate_distance(vector, centroid)?;
                if distance < best_distance {
                    best_distance = distance;
                    best_cluster = cluster_id;
                }
            }
            assignments[i] = best_cluster;
        }
        // Converged: nothing moved since the previous round.
        if assignments == prev_assignments {
            break;
        }
        // Update step: recenter each non-empty cluster on its mean.
        for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
            let cluster_vectors: Vec<&Vector> = resources
                .iter()
                .enumerate()
                .filter(|(i, _)| assignments[*i] == cluster_id)
                .map(|(_, (_, vector))| vector)
                .collect();
            if !cluster_vectors.is_empty() {
                *centroid = self.compute_centroid(&cluster_vectors)?;
            }
        }
        prev_assignments.clone_from(&assignments);
        // NOTE: a previous version printed progress to stdout every 10
        // iterations; library code should not write to stdout, so that
        // side channel was removed.
    }
    // Materialize non-empty clusters together with their statistics.
    let mut clusters = Vec::new();
    for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
        let members: Vec<String> = resources
            .iter()
            .enumerate()
            .filter(|(i, _)| assignments[*i] == cluster_id)
            .map(|(_, (resource_id, _))| resource_id.clone())
            .collect();
        if !members.is_empty() {
            let cluster_vectors: Vec<&Vector> = resources
                .iter()
                .enumerate()
                .filter(|(i, _)| assignments[*i] == cluster_id)
                .map(|(_, (_, vector))| vector)
                .collect();
            let stats = self.compute_cluster_stats(&cluster_vectors)?;
            clusters.push(Cluster {
                id: cluster_id,
                members,
                centroid: Some(centroid.clone()),
                stats,
            });
        }
    }
    let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
    Ok(ClusteringResult {
        clusters,
        noise: Vec::new(),
        quality_metrics,
        algorithm: ClusteringAlgorithm::KMeans,
        config: self.config.clone(),
    })
}
/// DBSCAN density-based clustering.
///
/// `eps = 1.0 - config.similarity_threshold` and
/// `min_pts = config.min_cluster_size`. Core points (at least `min_pts`
/// neighbors within `eps`) grow clusters by BFS expansion; everything no
/// cluster claims ends up in the noise list.
///
/// Fix over the previous version: noise is derived *after* expansion
/// from the points that remained unassigned. Previously a border point
/// visited before its cluster was recorded as noise and could later
/// also be assigned to a cluster, appearing in both lists.
fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
    let eps = 1.0 - self.config.similarity_threshold;
    let min_pts = self.config.min_cluster_size;
    let mut visited = vec![false; resources.len()];
    let mut cluster_assignments: Vec<Option<usize>> = vec![None; resources.len()];
    let mut cluster_id = 0;
    for i in 0..resources.len() {
        if visited[i] {
            continue;
        }
        visited[i] = true;
        let neighbors = self.find_neighbors(resources, i, eps)?;
        if neighbors.len() >= min_pts {
            // Core point: breadth-first expansion of a new cluster.
            let mut cluster_queue = VecDeque::new();
            cluster_queue.push_back(i);
            cluster_assignments[i] = Some(cluster_id);
            while let Some(point_idx) = cluster_queue.pop_front() {
                let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
                // Only core points propagate the cluster further; border
                // points are absorbed but do not expand.
                if point_neighbors.len() >= min_pts {
                    for &neighbor_idx in &point_neighbors {
                        if !visited[neighbor_idx] {
                            visited[neighbor_idx] = true;
                            cluster_queue.push_back(neighbor_idx);
                        }
                        if cluster_assignments[neighbor_idx].is_none() {
                            cluster_assignments[neighbor_idx] = Some(cluster_id);
                        }
                    }
                }
            }
            cluster_id += 1;
        }
    }
    // Noise = points that no cluster claimed.
    let noise_points: Vec<String> = resources
        .iter()
        .enumerate()
        .filter(|(i, _)| cluster_assignments[*i].is_none())
        .map(|(_, (resource_id, _))| resource_id.clone())
        .collect();
    let mut clusters = Vec::new();
    for cid in 0..cluster_id {
        let members: Vec<String> = resources
            .iter()
            .enumerate()
            .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
            .map(|(_, (resource_id, _))| resource_id.clone())
            .collect();
        if !members.is_empty() {
            let cluster_vectors: Vec<&Vector> = resources
                .iter()
                .enumerate()
                .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
                .map(|(_, (_, vector))| vector)
                .collect();
            let stats = self.compute_cluster_stats(&cluster_vectors)?;
            // Members are non-empty here, so a centroid always exists.
            let centroid = Some(self.compute_centroid(&cluster_vectors)?);
            clusters.push(Cluster {
                id: cid,
                members,
                centroid,
                stats,
            });
        }
    }
    let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
    Ok(ClusteringResult {
        clusters,
        noise: noise_points,
        quality_metrics,
        algorithm: ClusteringAlgorithm::DBSCAN,
        config: self.config.clone(),
    })
}
/// Agglomerative hierarchical clustering down to `config.num_clusters`
/// (default 3) clusters: start with one singleton cluster per resource
/// and repeatedly merge the closest pair under the configured linkage.
fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
let target_clusters = self.config.num_clusters.unwrap_or(3);
// Clusters hold resource indices; the distance matrix stays indexed by
// resource (not cluster) position throughout.
let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
let mut distance_matrix = self.compute_distance_matrix(resources)?;
while clusters.len() > target_clusters {
let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
// Remove the higher index first so the lower one stays valid.
// (find_closest_clusters already returns i < j; max/min is defensive.)
let cluster_j = clusters.remove(min_j.max(min_i));
clusters[min_i.min(min_j)].extend(cluster_j);
// NOTE(review): this recomputes the full pairwise matrix even though
// pairwise resource distances never change after a merge — redundant
// work rather than a correctness issue.
self.update_distance_matrix(
&mut distance_matrix,
&clusters,
min_i.min(min_j),
resources,
)?;
}
// Materialize the surviving clusters with ids in their final order.
let mut result_clusters = Vec::new();
for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
let members: Vec<String> = cluster_indices
.iter()
.map(|&idx| resources[idx].0.clone())
.collect();
let cluster_vectors: Vec<&Vector> = cluster_indices
.iter()
.map(|&idx| &resources[idx].1)
.collect();
let stats = self.compute_cluster_stats(&cluster_vectors)?;
let centroid = if !cluster_vectors.is_empty() {
Some(self.compute_centroid(&cluster_vectors)?)
} else {
None
};
result_clusters.push(Cluster {
id: cluster_id,
members,
centroid,
stats,
});
}
let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
Ok(ClusteringResult {
clusters: result_clusters,
noise: Vec::new(),
quality_metrics,
algorithm: ClusteringAlgorithm::Hierarchical,
config: self.config.clone(),
})
}
/// Spectral clustering: build a dense similarity matrix, form the
/// symmetric normalized Laplacian, embed each point into the k
/// smallest-eigenvalue eigenvectors, row-normalize, then run k-means on
/// the embedding and relabel the result as Spectral.
///
/// # Errors
/// Fails when `k >= n` or when matrix construction / the eigenvalue
/// decomposition fails.
fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
use scirs2_core::ndarray_ext::Array2;
let n = resources.len();
let k = self.config.num_clusters.unwrap_or(3);
if k >= n {
return Err(anyhow!(
"Number of clusters must be less than number of resources"
));
}
// Dense pairwise similarity matrix; self-similarity pinned to 1.0.
let mut similarity_matrix_data = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
if i == j {
similarity_matrix_data[i * n + j] = 1.0;
} else {
let sim = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
similarity_matrix_data[i * n + j] = sim as f64;
}
}
}
let similarity_matrix = Array2::from_shape_vec((n, n), similarity_matrix_data)
.map_err(|e| anyhow!("Failed to create similarity matrix: {}", e))?;
// Node degrees: row sums of the similarity matrix.
let degrees: Vec<f64> = (0..n)
.map(|i| (0..n).map(|j| similarity_matrix[[i, j]]).sum::<f64>())
.collect();
// Symmetric normalized Laplacian L = I - D^{-1/2} W D^{-1/2};
// entries touching a near-zero degree are left at 0.
let mut laplacian_data = vec![0.0; n * n];
for i in 0..n {
let d_i_sqrt = degrees[i].sqrt();
for j in 0..n {
let d_j_sqrt = degrees[j].sqrt();
if i == j {
laplacian_data[i * n + j] = 1.0;
} else if d_i_sqrt > 1e-10 && d_j_sqrt > 1e-10 {
laplacian_data[i * n + j] = -similarity_matrix[[i, j]] / (d_i_sqrt * d_j_sqrt);
}
}
}
let laplacian = Array2::from_shape_vec((n, n), laplacian_data)
.map_err(|e| anyhow!("Failed to create Laplacian matrix: {}", e))?;
let (eigenvalues, eigenvectors) = scirs2_linalg::eigen::eigh(&laplacian.view(), None)
.map_err(|e| anyhow!("Eigenvalue decomposition failed: {}", e))?;
// Pick the k eigenvectors with the smallest eigenvalues (NaNs compare
// Equal and therefore keep their position).
let mut eigen_pairs: Vec<(f64, usize)> = eigenvalues
.iter()
.enumerate()
.map(|(idx, &val)| (val, idx))
.collect();
eigen_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
let selected_indices: Vec<usize> =
eigen_pairs.iter().take(k).map(|(_, idx)| *idx).collect();
// Row i of the embedding = point i's coordinates in the k chosen
// eigenvectors, stored row-major.
let mut embedding_data = Vec::with_capacity(n * k);
for i in 0..n {
for &col_idx in &selected_indices {
embedding_data.push(eigenvectors[[i, col_idx]]);
}
}
// Normalize each embedding row to unit length; zero rows stay zero.
for row_idx in 0..n {
let row_start = row_idx * k;
let row_end = row_start + k;
let row_slice = &embedding_data[row_start..row_end];
let norm: f64 = row_slice.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm > 1e-10 {
for val in &mut embedding_data[row_start..row_end] {
*val /= norm;
}
}
}
// Re-wrap the embedded rows as resources (same ids) for k-means.
let embedded_resources: Vec<(String, Vector)> = resources
.iter()
.enumerate()
.map(|(i, (id, _))| {
let row_start = i * k;
let row_end = row_start + k;
let embedding: Vec<f32> = embedding_data[row_start..row_end]
.iter()
.map(|&x| x as f32)
.collect();
(id.clone(), Vector::new(embedding))
})
.collect();
let kmeans_config = ClusteringConfig {
algorithm: ClusteringAlgorithm::KMeans,
num_clusters: Some(k),
..self.config.clone()
};
let kmeans_engine = ClusteringEngine::new(kmeans_config);
let mut result = kmeans_engine.kmeans_clustering(&embedded_resources)?;
// Report as Spectral even though k-means performed the final step.
result.algorithm = ClusteringAlgorithm::Spectral;
Ok(result)
}
/// Greedy modularity-based community detection (Louvain-style local
/// moves on a thresholded similarity graph).
///
/// Edges connect pairs whose similarity meets
/// `config.similarity_threshold`. Each node starts in its own community
/// and is moved to the neighboring community with the largest positive
/// modularity gain, until no move clears `config.tolerance` or
/// `config.max_iterations` passes complete.
fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
use std::collections::HashMap;
let n = resources.len();
let threshold = self.config.similarity_threshold;
// Undirected weighted adjacency list; only edges at/above threshold.
let mut graph: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
let mut total_weight = 0.0;
for i in 0..n {
for j in (i + 1)..n {
let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
if similarity >= threshold {
graph[i].push((j, similarity));
graph[j].push((i, similarity));
// Each undirected edge is counted once per endpoint (2m form).
total_weight += similarity * 2.0; }
}
}
// Every node starts as its own community.
let mut node_to_community: Vec<usize> = (0..n).collect();
// Sigma_tot per community: total incident edge weight.
let mut community_weights: HashMap<usize, f32> = HashMap::new();
for (i, neighbors) in graph.iter().enumerate().take(n) {
let weight: f32 = neighbors.iter().map(|(_, w)| w).sum();
community_weights.insert(i, weight);
}
let mut improved = true;
let mut iteration = 0;
let max_iterations = self.config.max_iterations;
while improved && iteration < max_iterations {
improved = false;
iteration += 1;
for node in 0..n {
let current_community = node_to_community[node];
let mut best_community = current_community;
let mut best_gain = 0.0;
// Edge weight from `node` into each neighboring community.
let mut neighbor_communities: HashMap<usize, f32> = HashMap::new();
for &(neighbor, weight) in &graph[node] {
let neighbor_comm = node_to_community[neighbor];
*neighbor_communities.entry(neighbor_comm).or_insert(0.0) += weight;
}
for (&neighbor_comm, &edge_weight) in &neighbor_communities {
if neighbor_comm == current_community {
continue;
}
// Modularity-gain estimate for moving `node` into
// `neighbor_comm`: links gained minus the expected links
// under the configuration null model.
let k_i = graph[node].iter().map(|(_, w)| w).sum::<f32>();
let sigma_tot = community_weights
.get(&neighbor_comm)
.copied()
.unwrap_or(0.0);
let gain = edge_weight - (k_i * sigma_tot) / (2.0 * total_weight);
if gain > best_gain {
best_gain = gain;
best_community = neighbor_comm;
}
}
// Only move when the gain clears the configured tolerance.
if best_community != current_community && best_gain > self.config.tolerance {
let node_weight = graph[node].iter().map(|(_, w)| w).sum::<f32>();
*community_weights.entry(current_community).or_insert(0.0) -= node_weight;
*community_weights.entry(best_community).or_insert(0.0) += node_weight;
node_to_community[node] = best_community;
improved = true;
}
}
}
// Group nodes by their final community label.
let mut communities: HashMap<usize, Vec<usize>> = HashMap::new();
for (node, &community) in node_to_community.iter().enumerate() {
communities.entry(community).or_default().push(node);
}
// Cluster ids are renumbered by HashMap iteration order, so they are
// dense but not tied to the original community labels.
let mut clusters = Vec::new();
for (cluster_id, (_, members_idx)) in communities.iter().enumerate() {
let members: Vec<String> = members_idx
.iter()
.map(|&idx| resources[idx].0.clone())
.collect();
let cluster_vectors: Vec<&Vector> =
members_idx.iter().map(|&idx| &resources[idx].1).collect();
let stats = self.compute_cluster_stats(&cluster_vectors)?;
let centroid = if !cluster_vectors.is_empty() {
Some(self.compute_centroid(&cluster_vectors)?)
} else {
None
};
clusters.push(Cluster {
id: cluster_id,
members,
centroid,
stats,
});
}
let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
Ok(ClusteringResult {
clusters,
noise: Vec::new(),
quality_metrics,
algorithm: ClusteringAlgorithm::Community,
config: self.config.clone(),
})
}
/// Greedy threshold clustering: scan resources in order; each unassigned
/// resource seeds a new cluster and absorbs every later unassigned
/// resource whose similarity to the seed meets the threshold.
fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
    let threshold = self.config.similarity_threshold;
    let mut assigned = vec![false; resources.len()];
    let mut clusters = Vec::new();
    for seed in 0..resources.len() {
        if assigned[seed] {
            continue;
        }
        // Start a fresh cluster anchored on this seed.
        assigned[seed] = true;
        let mut member_indices = vec![seed];
        for candidate in (seed + 1)..resources.len() {
            if assigned[candidate] {
                continue;
            }
            let sim = self.calculate_similarity(&resources[seed].1, &resources[candidate].1)?;
            if sim >= threshold {
                assigned[candidate] = true;
                member_indices.push(candidate);
            }
        }
        let members: Vec<String> = member_indices
            .iter()
            .map(|&idx| resources[idx].0.clone())
            .collect();
        let vectors: Vec<&Vector> = member_indices
            .iter()
            .map(|&idx| &resources[idx].1)
            .collect();
        let stats = self.compute_cluster_stats(&vectors)?;
        let centroid = if vectors.is_empty() {
            None
        } else {
            Some(self.compute_centroid(&vectors)?)
        };
        // A cluster is pushed on every pass, so its id equals the count
        // of clusters created so far.
        clusters.push(Cluster {
            id: clusters.len(),
            members,
            centroid,
            stats,
        });
    }
    let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
    Ok(ClusteringResult {
        clusters,
        noise: Vec::new(),
        quality_metrics,
        algorithm: ClusteringAlgorithm::Similarity,
        config: self.config.clone(),
    })
}
/// k-means++ seeding: the first centroid is uniform-random; each
/// subsequent centroid is sampled with probability proportional to its
/// squared distance (D^2) to the nearest centroid chosen so far.
///
/// Fix over the previous version: floating-point rounding (or an
/// all-zero D^2 profile when every point coincides with a centroid)
/// could make the roulette-wheel selection fall through without pushing
/// anything, returning fewer than `k` centroids. A deterministic
/// fallback now guarantees exactly `k` centroids.
#[allow(deprecated)]
fn initialize_centroids_kmeans_plus_plus(
    &self,
    resources: &[(String, Vector)],
    k: usize,
    rng: &mut impl Rng,
) -> Result<Vec<Vector>> {
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.random_range(0..resources.len());
    centroids.push(resources[first_idx].1.clone());
    for _ in 1..k {
        // D^2 weights: squared distance to the nearest existing centroid.
        let mut distances = Vec::with_capacity(resources.len());
        let mut total_distance = 0.0;
        for (_, vector) in resources {
            let min_dist_sq = centroids
                .iter()
                .map(|centroid| {
                    self.calculate_distance(vector, centroid)
                        .unwrap_or(f32::INFINITY)
                })
                .fold(f32::INFINITY, f32::min)
                .powi(2);
            distances.push(min_dist_sq);
            total_distance += min_dist_sq;
        }
        // Roulette-wheel selection proportional to the D^2 weights.
        let target = rng.random::<f32>() * total_distance;
        let mut cumulative = 0.0;
        let mut chosen = None;
        for (i, &dist) in distances.iter().enumerate() {
            cumulative += dist;
            if cumulative >= target {
                chosen = Some(i);
                break;
            }
        }
        // Fallback: if rounding left the wheel unconsumed, take the last
        // point so the function always returns exactly `k` centroids.
        let idx = chosen.unwrap_or(resources.len() - 1);
        centroids.push(resources[idx].1.clone());
    }
    Ok(centroids)
}
/// Distance between two vectors under the configured metric. Cosine —
/// and any metric without a dedicated distance form — maps similarity to
/// distance via `1 - similarity`.
fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
    match self.config.distance_metric {
        SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
        SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
        // Cosine plus the fallback for every other metric.
        _ => Ok(1.0 - v1.cosine_similarity(v2)?),
    }
}
/// Similarity between two vectors under the configured metric. Distance
/// metrics are converted to a similarity in (0, 1] via `1 / (1 + d)`;
/// cosine similarity is used directly and is also the fallback.
fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
    match self.config.distance_metric {
        SimilarityMetric::Euclidean => Ok(1.0 / (1.0 + v1.euclidean_distance(v2)?)),
        SimilarityMetric::Manhattan => Ok(1.0 / (1.0 + v1.manhattan_distance(v2)?)),
        _ => v1.cosine_similarity(v2),
    }
}
/// Indices of every resource within `eps` of `point_idx` under the
/// configured distance metric. The point itself is excluded.
fn find_neighbors(
    &self,
    resources: &[(String, Vector)],
    point_idx: usize,
    eps: f32,
) -> Result<Vec<usize>> {
    let center = &resources[point_idx].1;
    let mut hits = Vec::new();
    for (idx, (_, candidate)) in resources.iter().enumerate() {
        if idx == point_idx {
            continue;
        }
        if self.calculate_distance(center, candidate)? <= eps {
            hits.push(idx);
        }
    }
    Ok(hits)
}
/// Element-wise mean of a non-empty set of vectors. All vectors are
/// expected to share the first vector's dimensionality.
///
/// # Errors
/// Fails when `vectors` is empty.
fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
    let first = vectors
        .first()
        .ok_or_else(|| anyhow!("Cannot compute centroid of empty vector set"))?;
    // Accumulate component sums sized by the first vector's dimensions.
    let mut sums = vec![0.0f32; first.dimensions];
    for vector in vectors {
        for (i, &component) in vector.as_f32().iter().enumerate() {
            sums[i] += component;
        }
    }
    let n = vectors.len() as f32;
    Ok(Vector::new(sums.into_iter().map(|s| s / n).collect()))
}
/// Summary statistics for one cluster: size and the mean pairwise
/// similarity between members (also reported as density). The
/// per-cluster silhouette is left at 0; silhouettes are computed at the
/// result level.
fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
    if vectors.is_empty() {
        return Ok(ClusterStats {
            size: 0,
            avg_intra_similarity: 0.0,
            density: 0.0,
            silhouette_score: 0.0,
        });
    }
    // Average similarity over all unordered member pairs.
    let mut sum = 0.0;
    let mut pairs = 0usize;
    for (i, a) in vectors.iter().enumerate() {
        for b in &vectors[i + 1..] {
            sum += self.calculate_similarity(a, b)?;
            pairs += 1;
        }
    }
    // A singleton cluster is perfectly self-similar by convention.
    let avg = if pairs == 0 { 1.0 } else { sum / pairs as f32 };
    Ok(ClusterStats {
        size: vectors.len(),
        avg_intra_similarity: avg,
        density: avg,
        silhouette_score: 0.0,
    })
}
/// Symmetric n-by-n matrix of pairwise resource distances with a zero
/// diagonal, indexed by position in `resources`.
fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
    let n = resources.len();
    let mut matrix = vec![vec![0.0; n]; n];
    // Each unordered pair is computed once and mirrored.
    for row in 0..n {
        for col in (row + 1)..n {
            let d = self.calculate_distance(&resources[row].1, &resources[col].1)?;
            matrix[row][col] = d;
            matrix[col][row] = d;
        }
    }
    Ok(matrix)
}
/// Pair of cluster indices `(i, j)` with `i < j` whose linkage distance
/// is minimal; ties keep the first pair found in scan order.
fn find_closest_clusters(
    &self,
    clusters: &[Vec<usize>],
    distance_matrix: &[Vec<f32>],
) -> Result<(usize, usize)> {
    let mut best_pair = (0, 1);
    let mut best_dist = f32::INFINITY;
    for (i, left) in clusters.iter().enumerate() {
        for (j, right) in clusters.iter().enumerate().skip(i + 1) {
            let d = self.cluster_distance(left, right, distance_matrix);
            if d < best_dist {
                best_dist = d;
                best_pair = (i, j);
            }
        }
    }
    Ok(best_pair)
}
/// Distance between two clusters under the configured linkage
/// criterion, computed from the pairwise point distance matrix.
///
/// Fix over the previous version: the `Ward` arm called this function
/// recursively with unchanged arguments, which was unconditional
/// infinite recursion (stack overflow) for any Ward-configured run.
fn cluster_distance(
    &self,
    cluster1: &[usize],
    cluster2: &[usize],
    distance_matrix: &[Vec<f32>],
) -> f32 {
    match self.config.linkage {
        // Minimum distance over all cross-cluster point pairs.
        LinkageCriterion::Single => cluster1
            .iter()
            .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
            .fold(f32::INFINITY, f32::min),
        // Maximum distance over all cross-cluster point pairs.
        LinkageCriterion::Complete => cluster1
            .iter()
            .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
            .fold(0.0, f32::max),
        // Mean distance over all cross-cluster point pairs.
        LinkageCriterion::Average => {
            let pair_count = cluster1.len() * cluster2.len();
            if pair_count == 0 {
                0.0
            } else {
                let total: f32 = cluster1
                    .iter()
                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
                    .sum();
                total / pair_count as f32
            }
        }
        // Ward's criterion proper needs cluster centroids, which a
        // pairwise distance matrix does not provide. Approximate the
        // merge cost with the size-weighted average-linkage distance,
        // (|A||B| / (|A| + |B|)) * d_avg, which preserves Ward's bias
        // toward merging small, close clusters.
        LinkageCriterion::Ward => {
            let pair_count = cluster1.len() * cluster2.len();
            if pair_count == 0 {
                0.0
            } else {
                let total: f32 = cluster1
                    .iter()
                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
                    .sum();
                let avg = total / pair_count as f32;
                let (n1, n2) = (cluster1.len() as f32, cluster2.len() as f32);
                (n1 * n2 / (n1 + n2)) * avg
            }
        }
    }
}
/// Hook invoked after two clusters are merged during hierarchical
/// clustering.
///
/// The matrix is indexed by *resource* indices, and pairwise resource
/// distances never change when clusters merge, so there is nothing to
/// update: cluster-level distances are derived on demand by
/// `cluster_distance`. The previous implementation recomputed the
/// entire O(n^2) matrix on every merge — producing a byte-identical
/// result — and that wasted work is now skipped.
fn update_distance_matrix(
    &self,
    _distance_matrix: &mut Vec<Vec<f32>>,
    _clusters: &[Vec<usize>],
    _merged_cluster: usize,
    _resources: &[(String, Vector)],
) -> Result<()> {
    Ok(())
}
/// Partition-level quality scores.
///
/// `within_cluster_ss` sums squared point-to-centroid distances for
/// clusters that carry a stored centroid and have more than one member.
/// The silhouette uses a = mean intra-cluster distance, b = smallest
/// mean distance to any other cluster, score = (b - a) / max(a, b);
/// singleton clusters contribute 0.
///
/// NOTE(review): member vectors are looked up with a linear `find` over
/// `resources` per member (O(n*m)); fine at small scale, a candidate
/// for an id-to-index map if this becomes hot.
fn compute_quality_metrics(
&self,
resources: &[(String, Vector)],
clusters: &[Cluster],
) -> Result<ClusteringQualityMetrics> {
let mut within_cluster_ss = 0.0;
let mut silhouette_scores = Vec::new();
for cluster in clusters {
if cluster.members.len() > 1 {
let cluster_vectors: Vec<&Vector> = cluster
.members
.iter()
.filter_map(|member| {
resources
.iter()
.find(|(id, _)| id == member)
.map(|(_, v)| v)
})
.collect();
// Only clusters that stored a centroid contribute to WSS.
if let Some(ref centroid) = cluster.centroid {
for vector in &cluster_vectors {
let dist = self.calculate_distance(vector, centroid)?;
within_cluster_ss += dist * dist;
}
}
}
}
// Per-point silhouette scores.
for (cluster_idx, cluster) in clusters.iter().enumerate() {
let cluster_vectors: Vec<(usize, &Vector)> = cluster
.members
.iter()
.filter_map(|member| {
resources
.iter()
.enumerate()
.find(|(_, (id, _))| id == member)
.map(|(idx, (_, v))| (idx, v))
})
.collect();
for (point_idx, point_vector) in &cluster_vectors {
// Singleton clusters get a neutral silhouette of 0.
if cluster_vectors.len() <= 1 {
silhouette_scores.push(0.0);
continue;
}
// a: mean distance to the other members of the same cluster.
let mut intra_cluster_dist = 0.0;
let mut intra_count = 0;
for (other_idx, other_vector) in &cluster_vectors {
if point_idx != other_idx {
let dist = self.calculate_distance(point_vector, other_vector)?;
intra_cluster_dist += dist;
intra_count += 1;
}
}
let a = if intra_count > 0 {
intra_cluster_dist / intra_count as f32
} else {
0.0
};
// b: smallest mean distance to the members of any other cluster.
let mut min_inter_cluster_dist = f32::INFINITY;
for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
if cluster_idx != other_cluster_idx {
let other_cluster_vectors: Vec<&Vector> = other_cluster
.members
.iter()
.filter_map(|member| {
resources
.iter()
.find(|(id, _)| id == member)
.map(|(_, v)| v)
})
.collect();
if !other_cluster_vectors.is_empty() {
let mut inter_cluster_dist = 0.0;
for other_vector in &other_cluster_vectors {
let dist = self.calculate_distance(point_vector, other_vector)?;
inter_cluster_dist += dist;
}
let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
}
}
}
let b = min_inter_cluster_dist;
let silhouette = if a.max(b) > 0.0 {
(b - a) / a.max(b)
} else {
0.0
};
silhouette_scores.push(silhouette);
}
}
let silhouette_score = if !silhouette_scores.is_empty() {
silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
} else {
0.0
};
let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
let calinski_harabasz_index =
self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
Ok(ClusteringQualityMetrics {
silhouette_score,
davies_bouldin_index,
calinski_harabasz_index,
within_cluster_ss,
between_cluster_ss,
})
}
/// Davies-Bouldin index: mean over clusters of the worst-case ratio
/// (s_i + s_j) / d(c_i, c_j), where s is the mean member-to-centroid
/// distance and d the centroid-to-centroid distance. Lower is better;
/// 0 for fewer than two clusters.
///
/// Fix over the previous version: `&centroid_i` / `&centroid_j` had been
/// corrupted into the mojibake `¢roid_i` / `¢roid_j` (an HTML-entity
/// collapse of `&cent`), which does not compile; the references are
/// restored.
fn calculate_davies_bouldin_index(
    &self,
    resources: &[(String, Vector)],
    clusters: &[Cluster],
) -> Result<f32> {
    if clusters.len() <= 1 {
        return Ok(0.0);
    }
    let mut db_sum = 0.0;
    for i in 0..clusters.len() {
        let mut max_ratio: f32 = 0.0;
        let cluster_i_vectors: Vec<&Vector> = clusters[i]
            .members
            .iter()
            .filter_map(|member| {
                resources
                    .iter()
                    .find(|(id, _)| id == member)
                    .map(|(_, v)| v)
            })
            .collect();
        if cluster_i_vectors.is_empty() {
            continue;
        }
        // Centroids are recomputed from members rather than read from
        // the cluster, so the index does not depend on stored state.
        let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
        let mut avg_dist_i = 0.0;
        for vector in &cluster_i_vectors {
            avg_dist_i += self.calculate_distance(vector, &centroid_i)?;
        }
        avg_dist_i /= cluster_i_vectors.len() as f32;
        for (j, cluster_j) in clusters.iter().enumerate() {
            if i == j {
                continue;
            }
            let cluster_j_vectors: Vec<&Vector> = cluster_j
                .members
                .iter()
                .filter_map(|member| {
                    resources
                        .iter()
                        .find(|(id, _)| id == member)
                        .map(|(_, v)| v)
                })
                .collect();
            if cluster_j_vectors.is_empty() {
                continue;
            }
            let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
            let mut avg_dist_j = 0.0;
            for vector in &cluster_j_vectors {
                avg_dist_j += self.calculate_distance(vector, &centroid_j)?;
            }
            avg_dist_j /= cluster_j_vectors.len() as f32;
            let centroid_distance = self.calculate_distance(&centroid_i, &centroid_j)?;
            // Coincident centroids are skipped to avoid division by zero.
            if centroid_distance > 0.0 {
                let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
                max_ratio = max_ratio.max(ratio);
            }
        }
        db_sum += max_ratio;
    }
    Ok(db_sum / clusters.len() as f32)
}
/// Calinski-Harabasz index: (B / (k - 1)) / (W / (n - k)), where B is
/// the between-cluster and W the within-cluster sum of squares. Higher
/// is better; 0 when the index is undefined (k <= 1, k >= n, W <= 0).
fn calculate_calinski_harabasz_index(
&self,
resources: &[(String, Vector)],
clusters: &[Cluster],
within_cluster_ss: f32,
) -> Result<f32> {
if clusters.len() <= 1 || resources.is_empty() {
return Ok(0.0);
}
let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
let overall_centroid = self.compute_centroid(&all_vectors)?;
let mut between_cluster_ss = 0.0;
for cluster in clusters {
let cluster_vectors: Vec<&Vector> = cluster
.members
.iter()
.filter_map(|member| {
resources
.iter()
.find(|(id, _)| id == member)
.map(|(_, v)| v)
})
.collect();
if !cluster_vectors.is_empty() {
let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
// NOTE(review): despite its name, `distance_sq` holds the plain
// distance; it is squared in the accumulation below.
let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
}
}
let k = clusters.len() as f32;
let n = resources.len() as f32;
if k >= n || within_cluster_ss <= 0.0 {
return Ok(0.0);
}
let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
Ok(ch_index)
}
/// Between-cluster sum of squares: for each cluster, the squared
/// distance from its centroid to the global centroid, weighted by
/// cluster size. Returns 0 for empty input.
fn calculate_between_cluster_ss(
    &self,
    resources: &[(String, Vector)],
    clusters: &[Cluster],
) -> Result<f32> {
    if resources.is_empty() || clusters.is_empty() {
        return Ok(0.0);
    }
    let everything: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
    let global_centroid = self.compute_centroid(&everything)?;
    let mut total = 0.0;
    for cluster in clusters {
        // Resolve member ids back to their vectors; unknown ids are
        // skipped.
        let member_vectors: Vec<&Vector> = cluster
            .members
            .iter()
            .filter_map(|id| {
                resources
                    .iter()
                    .find(|(rid, _)| rid == id)
                    .map(|(_, v)| v)
            })
            .collect();
        if member_vectors.is_empty() {
            continue;
        }
        let centroid = self.compute_centroid(&member_vectors)?;
        let d = self.calculate_distance(&centroid, &global_centroid)?;
        total += member_vectors.len() as f32 * d * d;
    }
    Ok(total)
}
}
impl Default for ClusteringQualityMetrics {
/// All-zero metrics, used for the empty-input result.
fn default() -> Self {
Self {
silhouette_score: 0.0,
davies_bouldin_index: 0.0,
calinski_harabasz_index: 0.0,
within_cluster_ss: 0.0,
between_cluster_ss: 0.0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Two well-separated pairs under Euclidean distance must split into
/// exactly two clusters with no noise.
#[test]
fn test_kmeans_clustering() -> Result<()> {
let config = ClusteringConfig {
algorithm: ClusteringAlgorithm::KMeans,
num_clusters: Some(2),
random_seed: Some(42),
distance_metric: SimilarityMetric::Euclidean, ..Default::default()
};
let engine = ClusteringEngine::new(config);
let resources = vec![
("res1".to_string(), Vector::new(vec![1.0, 1.0, 1.0])),
("res2".to_string(), Vector::new(vec![1.1, 1.1, 1.1])),
("res3".to_string(), Vector::new(vec![10.0, 10.0, 10.0])),
("res4".to_string(), Vector::new(vec![10.1, 10.1, 10.1])),
];
let result = engine.cluster(&resources)?;
assert_eq!(result.clusters.len(), 2);
assert!(result.noise.is_empty());
Ok(())
}
/// A close pair plus one outlier: DBSCAN should produce at most two
/// clusters (the outlier may land in noise instead).
#[test]
fn test_dbscan_clustering() -> Result<()> {
let config = ClusteringConfig {
algorithm: ClusteringAlgorithm::DBSCAN,
similarity_threshold: 0.9,
min_cluster_size: 2,
..Default::default()
};
let engine = ClusteringEngine::new(config);
let resources = vec![
("res1".to_string(), Vector::new(vec![1.0, 1.0, 1.0])),
("res2".to_string(), Vector::new(vec![1.1, 1.1, 1.1])),
("res3".to_string(), Vector::new(vec![10.0, 10.0, 10.0])),
];
let result = engine.cluster(&resources)?;
assert!(result.clusters.len() <= 2);
Ok(())
}
/// Orthogonal unit vectors have zero cosine similarity, so a 0.95
/// threshold must put each of the three in its own cluster.
#[test]
fn test_similarity_clustering() -> Result<()> {
let config = ClusteringConfig {
algorithm: ClusteringAlgorithm::Similarity,
similarity_threshold: 0.95,
..Default::default()
};
let engine = ClusteringEngine::new(config);
let resources = vec![
("res1".to_string(), Vector::new(vec![1.0, 0.0, 0.0])),
("res2".to_string(), Vector::new(vec![0.0, 1.0, 0.0])),
("res3".to_string(), Vector::new(vec![0.0, 0.0, 1.0])),
];
let result = engine.cluster(&resources)?;
assert_eq!(result.clusters.len(), 3);
Ok(())
}
}