1use crate::{similarity::SimilarityMetric, Vector};
11use anyhow::{anyhow, Result};
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
/// Clustering strategies that `ClusteringEngine::cluster` can dispatch to.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
    /// Centroid-based partitioning into a fixed number of clusters.
    KMeans,
    /// Density-based clustering; the only strategy in this file that reports noise.
    DBSCAN,
    /// Bottom-up merging of clusters until the target cluster count is reached.
    Hierarchical,
    /// K-means run on an eigenvector embedding of the pairwise-similarity matrix.
    Spectral,
    /// Greedy community detection on a similarity graph cut at the threshold.
    Community,
    /// Single-pass grouping of resources whose similarity to a seed meets the threshold.
    Similarity,
}
32
/// Settings shared by all clustering strategies; each strategy reads only the
/// fields relevant to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
    /// Strategy that `ClusteringEngine::cluster` dispatches to.
    pub algorithm: ClusteringAlgorithm,
    /// Target cluster count for k-means, hierarchical and spectral clustering;
    /// those strategies fall back to 3 when this is `None`.
    pub num_clusters: Option<usize>,
    /// Similarity cut-off: minimum similarity for an edge (community detection)
    /// or for joining a seed (similarity clustering); DBSCAN uses
    /// `1.0 - similarity_threshold` as its neighbourhood radius.
    pub similarity_threshold: f32,
    /// Minimum neighbour count for a DBSCAN core point.
    pub min_cluster_size: usize,
    /// Metric used to derive both distances and similarities between vectors.
    pub distance_metric: SimilarityMetric,
    /// Iteration cap for k-means and for community detection.
    pub max_iterations: usize,
    /// Seed for k-means centroid initialisation; a fixed seed of 42 is used when `None`.
    pub random_seed: Option<u64>,
    /// Minimum gain required before community detection moves a node.
    pub tolerance: f32,
    /// Rule for measuring the distance between two clusters in hierarchical clustering.
    pub linkage: LinkageCriterion,
}
55
56impl Default for ClusteringConfig {
57 fn default() -> Self {
58 Self {
59 algorithm: ClusteringAlgorithm::KMeans,
60 num_clusters: Some(3),
61 similarity_threshold: 0.7,
62 min_cluster_size: 3,
63 distance_metric: SimilarityMetric::Cosine,
64 max_iterations: 100,
65 random_seed: None,
66 tolerance: 1e-4,
67 linkage: LinkageCriterion::Average,
68 }
69 }
70}
71
/// How hierarchical clustering measures the distance between two clusters.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LinkageCriterion {
    /// Smallest pairwise distance between members of the two clusters.
    Single,
    /// Largest pairwise distance between members of the two clusters.
    Complete,
    /// Mean of all pairwise distances between members of the two clusters.
    Average,
    /// Ward linkage.
    /// NOTE(review): no dedicated Ward computation is visible in this file — confirm intended behaviour.
    Ward,
}
84
/// One group of resources produced by a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Identifier of the cluster within its `ClusteringResult`.
    pub id: usize,
    /// Ids of the resources assigned to this cluster.
    pub members: Vec<String>,
    /// Representative vector of the cluster, when one was computed.
    pub centroid: Option<Vector>,
    /// Summary statistics for the cluster's members.
    pub stats: ClusterStats,
}
97
/// Per-cluster summary statistics as filled in by `compute_cluster_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Number of vectors in the cluster.
    pub size: usize,
    /// Mean pairwise similarity between members (1.0 for a single member, 0.0 when empty).
    pub avg_intra_similarity: f32,
    /// Currently set to the same value as `avg_intra_similarity`.
    pub density: f32,
    /// Currently always written as 0.0 by `compute_cluster_stats`.
    pub silhouette_score: f32,
}
110
/// Outcome of a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Clusters that were formed.
    pub clusters: Vec<Cluster>,
    /// Resource ids reported as noise; only the DBSCAN path fills this in.
    pub noise: Vec<String>,
    /// Quality metrics computed over the whole clustering.
    pub quality_metrics: ClusteringQualityMetrics,
    /// Strategy that produced this result.
    pub algorithm: ClusteringAlgorithm,
    /// Copy of the configuration held by the engine that produced this result.
    pub config: ClusteringConfig,
}
125
/// Whole-clustering quality metrics produced by `compute_quality_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringQualityMetrics {
    /// Mean silhouette value over all clustered points.
    pub silhouette_score: f32,
    /// Davies-Bouldin index (0.0 when there is at most one cluster).
    pub davies_bouldin_index: f32,
    /// Calinski-Harabasz index (0.0 when it cannot be computed).
    pub calinski_harabasz_index: f32,
    /// Sum of squared member-to-centroid distances over clusters with more than one member.
    pub within_cluster_ss: f32,
    /// Size-weighted sum of squared distances from cluster centroids to the overall centroid.
    pub between_cluster_ss: f32,
}
140
/// Runs the clustering strategy selected by its `ClusteringConfig`.
pub struct ClusteringEngine {
    /// Settings applied to every clustering call made through this engine.
    config: ClusteringConfig,
}
145
146impl ClusteringEngine {
147 pub fn new(config: ClusteringConfig) -> Self {
148 Self { config }
149 }
150
151 pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
153 if resources.is_empty() {
154 return Ok(ClusteringResult {
155 clusters: Vec::new(),
156 noise: Vec::new(),
157 quality_metrics: ClusteringQualityMetrics::default(),
158 algorithm: self.config.algorithm,
159 config: self.config.clone(),
160 });
161 }
162
163 match self.config.algorithm {
164 ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
165 ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
166 ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
167 ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
168 ClusteringAlgorithm::Community => self.community_detection(resources),
169 ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
170 }
171 }
172
173 fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
175 let k = self.config.num_clusters.unwrap_or(3);
176 if k >= resources.len() {
177 return Err(anyhow!(
178 "Number of clusters must be less than number of resources"
179 ));
180 }
181
182 let mut rng = if let Some(seed) = self.config.random_seed {
183 Random::seed(seed)
184 } else {
185 Random::seed(42)
186 };
187
188 let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
190 let mut assignments = vec![0; resources.len()];
191 let mut prev_assignments = vec![usize::MAX; resources.len()];
192
193 for iteration in 0..self.config.max_iterations {
194 for (i, (_, vector)) in resources.iter().enumerate() {
196 let mut best_cluster = 0;
197 let mut best_distance = f32::INFINITY;
198
199 for (cluster_id, centroid) in centroids.iter().enumerate() {
200 let distance = self.calculate_distance(vector, centroid)?;
201 if distance < best_distance {
202 best_distance = distance;
203 best_cluster = cluster_id;
204 }
205 }
206 assignments[i] = best_cluster;
207 }
208
209 if assignments == prev_assignments {
211 break;
212 }
213
214 for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
216 let cluster_vectors: Vec<&Vector> = resources
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| assignments[*i] == cluster_id)
220 .map(|(_, (_, vector))| vector)
221 .collect();
222
223 if !cluster_vectors.is_empty() {
224 *centroid = self.compute_centroid(&cluster_vectors)?;
225 }
226 }
227
228 prev_assignments = assignments.clone();
229
230 if iteration > 0 && iteration % 10 == 0 {
231 println!(
232 "K-means iteration {}/{}",
233 iteration, self.config.max_iterations
234 );
235 }
236 }
237
238 let mut clusters = Vec::new();
240 for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
241 let members: Vec<String> = resources
242 .iter()
243 .enumerate()
244 .filter(|(i, _)| assignments[*i] == cluster_id)
245 .map(|(_, (resource_id, _))| resource_id.clone())
246 .collect();
247
248 if !members.is_empty() {
249 let cluster_vectors: Vec<&Vector> = resources
250 .iter()
251 .enumerate()
252 .filter(|(i, _)| assignments[*i] == cluster_id)
253 .map(|(_, (_, vector))| vector)
254 .collect();
255
256 let stats = self.compute_cluster_stats(&cluster_vectors)?;
257
258 clusters.push(Cluster {
259 id: cluster_id,
260 members,
261 centroid: Some(centroid.clone()),
262 stats,
263 });
264 }
265 }
266
267 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
268
269 Ok(ClusteringResult {
270 clusters,
271 noise: Vec::new(),
272 quality_metrics,
273 algorithm: ClusteringAlgorithm::KMeans,
274 config: self.config.clone(),
275 })
276 }
277
278 fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
280 let eps = 1.0 - self.config.similarity_threshold; let min_pts = self.config.min_cluster_size;
282
283 let mut visited = vec![false; resources.len()];
284 let mut cluster_assignments = vec![None; resources.len()];
285 let mut cluster_id = 0;
286 let mut noise_points = Vec::new();
287
288 for i in 0..resources.len() {
289 if visited[i] {
290 continue;
291 }
292 visited[i] = true;
293
294 let neighbors = self.find_neighbors(resources, i, eps)?;
295
296 if neighbors.len() < min_pts {
297 noise_points.push(resources[i].0.clone());
298 } else {
299 let mut cluster_queue = VecDeque::new();
300 cluster_queue.push_back(i);
301 cluster_assignments[i] = Some(cluster_id);
302
303 while let Some(point_idx) = cluster_queue.pop_front() {
304 let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
305
306 if point_neighbors.len() >= min_pts {
307 for &neighbor_idx in &point_neighbors {
308 if !visited[neighbor_idx] {
309 visited[neighbor_idx] = true;
310 cluster_queue.push_back(neighbor_idx);
311 }
312 if cluster_assignments[neighbor_idx].is_none() {
313 cluster_assignments[neighbor_idx] = Some(cluster_id);
314 }
315 }
316 }
317 }
318 cluster_id += 1;
319 }
320 }
321
322 let mut clusters = Vec::new();
324 for cid in 0..cluster_id {
325 let members: Vec<String> = resources
326 .iter()
327 .enumerate()
328 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
329 .map(|(_, (resource_id, _))| resource_id.clone())
330 .collect();
331
332 if !members.is_empty() {
333 let cluster_vectors: Vec<&Vector> = resources
334 .iter()
335 .enumerate()
336 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
337 .map(|(_, (_, vector))| vector)
338 .collect();
339
340 let stats = self.compute_cluster_stats(&cluster_vectors)?;
341 let centroid = if !cluster_vectors.is_empty() {
342 Some(self.compute_centroid(&cluster_vectors)?)
343 } else {
344 None
345 };
346
347 clusters.push(Cluster {
348 id: cid,
349 members,
350 centroid,
351 stats,
352 });
353 }
354 }
355
356 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
357
358 Ok(ClusteringResult {
359 clusters,
360 noise: noise_points,
361 quality_metrics,
362 algorithm: ClusteringAlgorithm::DBSCAN,
363 config: self.config.clone(),
364 })
365 }
366
367 fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
369 let target_clusters = self.config.num_clusters.unwrap_or(3);
370
371 let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
373
374 let mut distance_matrix = self.compute_distance_matrix(resources)?;
376
377 while clusters.len() > target_clusters {
379 let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
380
381 let cluster_j = clusters.remove(min_j.max(min_i));
383 clusters[min_i.min(min_j)].extend(cluster_j);
384
385 self.update_distance_matrix(
387 &mut distance_matrix,
388 &clusters,
389 min_i.min(min_j),
390 resources,
391 )?;
392 }
393
394 let mut result_clusters = Vec::new();
396 for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
397 let members: Vec<String> = cluster_indices
398 .iter()
399 .map(|&idx| resources[idx].0.clone())
400 .collect();
401
402 let cluster_vectors: Vec<&Vector> = cluster_indices
403 .iter()
404 .map(|&idx| &resources[idx].1)
405 .collect();
406
407 let stats = self.compute_cluster_stats(&cluster_vectors)?;
408 let centroid = if !cluster_vectors.is_empty() {
409 Some(self.compute_centroid(&cluster_vectors)?)
410 } else {
411 None
412 };
413
414 result_clusters.push(Cluster {
415 id: cluster_id,
416 members,
417 centroid,
418 stats,
419 });
420 }
421
422 let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
423
424 Ok(ClusteringResult {
425 clusters: result_clusters,
426 noise: Vec::new(),
427 quality_metrics,
428 algorithm: ClusteringAlgorithm::Hierarchical,
429 config: self.config.clone(),
430 })
431 }
432
433 fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
435 use scirs2_core::ndarray_ext::Array2;
436
437 let n = resources.len();
438 let k = self.config.num_clusters.unwrap_or(3);
439
440 if k >= n {
441 return Err(anyhow!(
442 "Number of clusters must be less than number of resources"
443 ));
444 }
445
446 let mut similarity_matrix_data = vec![0.0; n * n];
448 for i in 0..n {
449 for j in 0..n {
450 if i == j {
451 similarity_matrix_data[i * n + j] = 1.0;
452 } else {
453 let sim = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
454 similarity_matrix_data[i * n + j] = sim as f64;
455 }
456 }
457 }
458
459 let similarity_matrix = Array2::from_shape_vec((n, n), similarity_matrix_data)
460 .map_err(|e| anyhow!("Failed to create similarity matrix: {}", e))?;
461
462 let degrees: Vec<f64> = (0..n)
464 .map(|i| (0..n).map(|j| similarity_matrix[[i, j]]).sum::<f64>())
465 .collect();
466
467 let mut laplacian_data = vec![0.0; n * n];
469 for i in 0..n {
470 let d_i_sqrt = degrees[i].sqrt();
471 for j in 0..n {
472 let d_j_sqrt = degrees[j].sqrt();
473
474 if i == j {
475 laplacian_data[i * n + j] = 1.0;
476 } else if d_i_sqrt > 1e-10 && d_j_sqrt > 1e-10 {
477 laplacian_data[i * n + j] = -similarity_matrix[[i, j]] / (d_i_sqrt * d_j_sqrt);
478 }
479 }
480 }
481
482 let laplacian = Array2::from_shape_vec((n, n), laplacian_data)
483 .map_err(|e| anyhow!("Failed to create Laplacian matrix: {}", e))?;
484
485 let (eigenvalues, eigenvectors) = scirs2_linalg::eigen::eigh(&laplacian.view(), None)
488 .map_err(|e| anyhow!("Eigenvalue decomposition failed: {}", e))?;
489
490 let mut eigen_pairs: Vec<(f64, usize)> = eigenvalues
492 .iter()
493 .enumerate()
494 .map(|(idx, &val)| (val, idx))
495 .collect();
496
497 eigen_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
499
500 let selected_indices: Vec<usize> =
502 eigen_pairs.iter().take(k).map(|(_, idx)| *idx).collect();
503
504 let mut embedding_data = Vec::with_capacity(n * k);
506 for i in 0..n {
507 for &col_idx in &selected_indices {
508 embedding_data.push(eigenvectors[[i, col_idx]]);
509 }
510 }
511
512 for row_idx in 0..n {
514 let row_start = row_idx * k;
515 let row_end = row_start + k;
516 let row_slice = &embedding_data[row_start..row_end];
517
518 let norm: f64 = row_slice.iter().map(|x| x * x).sum::<f64>().sqrt();
519
520 if norm > 1e-10 {
521 for val in &mut embedding_data[row_start..row_end] {
522 *val /= norm;
523 }
524 }
525 }
526
527 let embedded_resources: Vec<(String, Vector)> = resources
529 .iter()
530 .enumerate()
531 .map(|(i, (id, _))| {
532 let row_start = i * k;
533 let row_end = row_start + k;
534 let embedding: Vec<f32> = embedding_data[row_start..row_end]
535 .iter()
536 .map(|&x| x as f32)
537 .collect();
538 (id.clone(), Vector::new(embedding))
539 })
540 .collect();
541
542 let kmeans_config = ClusteringConfig {
544 algorithm: ClusteringAlgorithm::KMeans,
545 num_clusters: Some(k),
546 ..self.config.clone()
547 };
548
549 let kmeans_engine = ClusteringEngine::new(kmeans_config);
550 let mut result = kmeans_engine.kmeans_clustering(&embedded_resources)?;
551
552 result.algorithm = ClusteringAlgorithm::Spectral;
554
555 Ok(result)
556 }
557
558 fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
560 use std::collections::HashMap;
561
562 let n = resources.len();
563 let threshold = self.config.similarity_threshold;
564
565 let mut graph: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
567 let mut total_weight = 0.0;
568
569 for i in 0..n {
570 for j in (i + 1)..n {
571 let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
572 if similarity >= threshold {
573 graph[i].push((j, similarity));
574 graph[j].push((i, similarity));
575 total_weight += similarity * 2.0; }
577 }
578 }
579
580 let mut node_to_community: Vec<usize> = (0..n).collect();
582 let mut community_weights: HashMap<usize, f32> = HashMap::new();
583
584 for (i, neighbors) in graph.iter().enumerate().take(n) {
586 let weight: f32 = neighbors.iter().map(|(_, w)| w).sum();
587 community_weights.insert(i, weight);
588 }
589
590 let mut improved = true;
592 let mut iteration = 0;
593 let max_iterations = self.config.max_iterations;
594
595 while improved && iteration < max_iterations {
596 improved = false;
597 iteration += 1;
598
599 for node in 0..n {
600 let current_community = node_to_community[node];
601
602 let mut best_community = current_community;
604 let mut best_gain = 0.0;
605
606 let mut neighbor_communities: HashMap<usize, f32> = HashMap::new();
608 for &(neighbor, weight) in &graph[node] {
609 let neighbor_comm = node_to_community[neighbor];
610 *neighbor_communities.entry(neighbor_comm).or_insert(0.0) += weight;
611 }
612
613 for (&neighbor_comm, &edge_weight) in &neighbor_communities {
615 if neighbor_comm == current_community {
616 continue;
617 }
618
619 let k_i = graph[node].iter().map(|(_, w)| w).sum::<f32>();
621 let sigma_tot = community_weights
622 .get(&neighbor_comm)
623 .copied()
624 .unwrap_or(0.0);
625
626 let gain = edge_weight - (k_i * sigma_tot) / (2.0 * total_weight);
627
628 if gain > best_gain {
629 best_gain = gain;
630 best_community = neighbor_comm;
631 }
632 }
633
634 if best_community != current_community && best_gain > self.config.tolerance {
636 let node_weight = graph[node].iter().map(|(_, w)| w).sum::<f32>();
638 *community_weights.entry(current_community).or_insert(0.0) -= node_weight;
639 *community_weights.entry(best_community).or_insert(0.0) += node_weight;
640
641 node_to_community[node] = best_community;
642 improved = true;
643 }
644 }
645 }
646
647 let mut communities: HashMap<usize, Vec<usize>> = HashMap::new();
649 for (node, &community) in node_to_community.iter().enumerate() {
650 communities.entry(community).or_default().push(node);
651 }
652
653 let mut clusters = Vec::new();
654 for (cluster_id, (_, members_idx)) in communities.iter().enumerate() {
655 let members: Vec<String> = members_idx
656 .iter()
657 .map(|&idx| resources[idx].0.clone())
658 .collect();
659
660 let cluster_vectors: Vec<&Vector> =
661 members_idx.iter().map(|&idx| &resources[idx].1).collect();
662
663 let stats = self.compute_cluster_stats(&cluster_vectors)?;
664
665 let centroid = if !cluster_vectors.is_empty() {
667 Some(self.compute_centroid(&cluster_vectors)?)
668 } else {
669 None
670 };
671
672 clusters.push(Cluster {
673 id: cluster_id,
674 members,
675 centroid,
676 stats,
677 });
678 }
679
680 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
681
682 Ok(ClusteringResult {
683 clusters,
684 noise: Vec::new(),
685 quality_metrics,
686 algorithm: ClusteringAlgorithm::Community,
687 config: self.config.clone(),
688 })
689 }
690
691 fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
693 let threshold = self.config.similarity_threshold;
694 let mut clusters = Vec::new();
695 let mut assigned = vec![false; resources.len()];
696 let mut cluster_id = 0;
697
698 for i in 0..resources.len() {
699 if assigned[i] {
700 continue;
701 }
702
703 let mut cluster_members = vec![i];
704 assigned[i] = true;
705
706 for j in (i + 1)..resources.len() {
708 if assigned[j] {
709 continue;
710 }
711
712 let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
713 if similarity >= threshold {
714 cluster_members.push(j);
715 assigned[j] = true;
716 }
717 }
718
719 let members: Vec<String> = cluster_members
720 .iter()
721 .map(|&idx| resources[idx].0.clone())
722 .collect();
723
724 let cluster_vectors: Vec<&Vector> = cluster_members
725 .iter()
726 .map(|&idx| &resources[idx].1)
727 .collect();
728
729 let stats = self.compute_cluster_stats(&cluster_vectors)?;
730 let centroid = if !cluster_vectors.is_empty() {
731 Some(self.compute_centroid(&cluster_vectors)?)
732 } else {
733 None
734 };
735
736 clusters.push(Cluster {
737 id: cluster_id,
738 members,
739 centroid,
740 stats,
741 });
742
743 cluster_id += 1;
744 }
745
746 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
747
748 Ok(ClusteringResult {
749 clusters,
750 noise: Vec::new(),
751 quality_metrics,
752 algorithm: ClusteringAlgorithm::Similarity,
753 config: self.config.clone(),
754 })
755 }
756
757 #[allow(deprecated)]
761 fn initialize_centroids_kmeans_plus_plus(
762 &self,
763 resources: &[(String, Vector)],
764 k: usize,
765 rng: &mut impl Rng,
766 ) -> Result<Vec<Vector>> {
767 let mut centroids = Vec::new();
768
769 let first_idx = rng.gen_range(0..resources.len());
771 centroids.push(resources[first_idx].1.clone());
772
773 for _ in 1..k {
775 let mut distances = Vec::new();
776 let mut total_distance = 0.0;
777
778 for (_, vector) in resources {
779 let min_dist_sq = centroids
780 .iter()
781 .map(|centroid| {
782 self.calculate_distance(vector, centroid)
783 .unwrap_or(f32::INFINITY)
784 })
785 .fold(f32::INFINITY, f32::min)
786 .powi(2);
787 distances.push(min_dist_sq);
788 total_distance += min_dist_sq;
789 }
790
791 let target = rng.random::<f32>() * total_distance;
792 let mut cumulative = 0.0;
793
794 for (i, &dist) in distances.iter().enumerate() {
795 cumulative += dist;
796 if cumulative >= target {
797 centroids.push(resources[i].1.clone());
798 break;
799 }
800 }
801 }
802
803 Ok(centroids)
804 }
805
806 fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
808 match self.config.distance_metric {
809 SimilarityMetric::Cosine => Ok(1.0 - v1.cosine_similarity(v2)?),
810 SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
811 SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
812 _ => Ok(1.0 - v1.cosine_similarity(v2)?), }
814 }
815
816 fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
818 match self.config.distance_metric {
819 SimilarityMetric::Cosine => v1.cosine_similarity(v2),
820 SimilarityMetric::Euclidean => {
821 let dist = v1.euclidean_distance(v2)?;
822 Ok(1.0 / (1.0 + dist))
823 }
824 SimilarityMetric::Manhattan => {
825 let dist = v1.manhattan_distance(v2)?;
826 Ok(1.0 / (1.0 + dist))
827 }
828 _ => v1.cosine_similarity(v2), }
830 }
831
832 fn find_neighbors(
834 &self,
835 resources: &[(String, Vector)],
836 point_idx: usize,
837 eps: f32,
838 ) -> Result<Vec<usize>> {
839 let mut neighbors = Vec::new();
840 let point = &resources[point_idx].1;
841
842 for (i, (_, vector)) in resources.iter().enumerate() {
843 if i != point_idx {
844 let distance = self.calculate_distance(point, vector)?;
845 if distance <= eps {
846 neighbors.push(i);
847 }
848 }
849 }
850
851 Ok(neighbors)
852 }
853
854 fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
856 if vectors.is_empty() {
857 return Err(anyhow!("Cannot compute centroid of empty vector set"));
858 }
859
860 let dim = vectors[0].dimensions;
861 let mut centroid_data = vec![0.0; dim];
862
863 for vector in vectors {
864 let data = vector.as_f32();
865 for (i, &value) in data.iter().enumerate() {
866 centroid_data[i] += value;
867 }
868 }
869
870 let count = vectors.len() as f32;
871 for value in &mut centroid_data {
872 *value /= count;
873 }
874
875 Ok(Vector::new(centroid_data))
876 }
877
878 fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
880 if vectors.is_empty() {
881 return Ok(ClusterStats {
882 size: 0,
883 avg_intra_similarity: 0.0,
884 density: 0.0,
885 silhouette_score: 0.0,
886 });
887 }
888
889 let size = vectors.len();
890 let mut total_similarity = 0.0;
891 let mut pair_count = 0;
892
893 for i in 0..vectors.len() {
895 for j in (i + 1)..vectors.len() {
896 let similarity = self.calculate_similarity(vectors[i], vectors[j])?;
897 total_similarity += similarity;
898 pair_count += 1;
899 }
900 }
901
902 let avg_intra_similarity = if pair_count > 0 {
903 total_similarity / pair_count as f32
904 } else {
905 1.0 };
907
908 Ok(ClusterStats {
909 size,
910 avg_intra_similarity,
911 density: avg_intra_similarity, silhouette_score: 0.0, })
914 }
915
916 fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
918 let n = resources.len();
919 let mut matrix = vec![vec![0.0; n]; n];
920
921 for i in 0..n {
922 for j in (i + 1)..n {
923 let distance = self.calculate_distance(&resources[i].1, &resources[j].1)?;
924 matrix[i][j] = distance;
925 matrix[j][i] = distance;
926 }
927 }
928
929 Ok(matrix)
930 }
931
932 fn find_closest_clusters(
934 &self,
935 clusters: &[Vec<usize>],
936 distance_matrix: &[Vec<f32>],
937 ) -> Result<(usize, usize)> {
938 let mut min_distance = f32::INFINITY;
939 let mut closest_pair = (0, 1);
940
941 for i in 0..clusters.len() {
942 for j in (i + 1)..clusters.len() {
943 let distance = self.cluster_distance(&clusters[i], &clusters[j], distance_matrix);
944 if distance < min_distance {
945 min_distance = distance;
946 closest_pair = (i, j);
947 }
948 }
949 }
950
951 Ok(closest_pair)
952 }
953
954 fn cluster_distance(
956 &self,
957 cluster1: &[usize],
958 cluster2: &[usize],
959 distance_matrix: &[Vec<f32>],
960 ) -> f32 {
961 match self.config.linkage {
962 LinkageCriterion::Single => {
963 cluster1
965 .iter()
966 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
967 .fold(f32::INFINITY, f32::min)
968 }
969 LinkageCriterion::Complete => {
970 cluster1
972 .iter()
973 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
974 .fold(0.0, f32::max)
975 }
976 LinkageCriterion::Average => {
977 let mut total = 0.0;
979 let mut count = 0;
980 for &i in cluster1 {
981 for &j in cluster2 {
982 total += distance_matrix[i][j];
983 count += 1;
984 }
985 }
986 if count > 0 {
987 total / count as f32
988 } else {
989 0.0
990 }
991 }
992 LinkageCriterion::Ward => {
993 self.cluster_distance(cluster1, cluster2, distance_matrix)
995 }
996 }
997 }
998
999 fn update_distance_matrix(
1001 &self,
1002 distance_matrix: &mut Vec<Vec<f32>>,
1003 _clusters: &[Vec<usize>],
1004 _merged_cluster: usize,
1005 resources: &[(String, Vector)],
1006 ) -> Result<()> {
1007 let new_matrix = self.compute_distance_matrix(resources)?;
1009 *distance_matrix = new_matrix;
1010 Ok(())
1011 }
1012
1013 fn compute_quality_metrics(
1015 &self,
1016 resources: &[(String, Vector)],
1017 clusters: &[Cluster],
1018 ) -> Result<ClusteringQualityMetrics> {
1019 let mut within_cluster_ss = 0.0;
1021 let mut silhouette_scores = Vec::new();
1022
1023 for cluster in clusters {
1024 if cluster.members.len() > 1 {
1025 let cluster_vectors: Vec<&Vector> = cluster
1026 .members
1027 .iter()
1028 .filter_map(|member| {
1029 resources
1030 .iter()
1031 .find(|(id, _)| id == member)
1032 .map(|(_, v)| v)
1033 })
1034 .collect();
1035
1036 if let Some(ref centroid) = cluster.centroid {
1037 for vector in &cluster_vectors {
1038 let dist = self.calculate_distance(vector, centroid)?;
1039 within_cluster_ss += dist * dist;
1040 }
1041 }
1042 }
1043 }
1044
1045 for (cluster_idx, cluster) in clusters.iter().enumerate() {
1047 let cluster_vectors: Vec<(usize, &Vector)> = cluster
1048 .members
1049 .iter()
1050 .filter_map(|member| {
1051 resources
1052 .iter()
1053 .enumerate()
1054 .find(|(_, (id, _))| id == member)
1055 .map(|(idx, (_, v))| (idx, v))
1056 })
1057 .collect();
1058
1059 for (point_idx, point_vector) in &cluster_vectors {
1061 if cluster_vectors.len() <= 1 {
1062 silhouette_scores.push(0.0);
1064 continue;
1065 }
1066
1067 let mut intra_cluster_dist = 0.0;
1069 let mut intra_count = 0;
1070 for (other_idx, other_vector) in &cluster_vectors {
1071 if point_idx != other_idx {
1072 let dist = self.calculate_distance(point_vector, other_vector)?;
1073 intra_cluster_dist += dist;
1074 intra_count += 1;
1075 }
1076 }
1077 let a = if intra_count > 0 {
1078 intra_cluster_dist / intra_count as f32
1079 } else {
1080 0.0
1081 };
1082
1083 let mut min_inter_cluster_dist = f32::INFINITY;
1085 for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
1086 if cluster_idx != other_cluster_idx {
1087 let other_cluster_vectors: Vec<&Vector> = other_cluster
1088 .members
1089 .iter()
1090 .filter_map(|member| {
1091 resources
1092 .iter()
1093 .find(|(id, _)| id == member)
1094 .map(|(_, v)| v)
1095 })
1096 .collect();
1097
1098 if !other_cluster_vectors.is_empty() {
1099 let mut inter_cluster_dist = 0.0;
1100 for other_vector in &other_cluster_vectors {
1101 let dist = self.calculate_distance(point_vector, other_vector)?;
1102 inter_cluster_dist += dist;
1103 }
1104 let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
1105 min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
1106 }
1107 }
1108 }
1109 let b = min_inter_cluster_dist;
1110
1111 let silhouette = if a.max(b) > 0.0 {
1113 (b - a) / a.max(b)
1114 } else {
1115 0.0
1116 };
1117 silhouette_scores.push(silhouette);
1118 }
1119 }
1120
1121 let silhouette_score = if !silhouette_scores.is_empty() {
1122 silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
1123 } else {
1124 0.0
1125 };
1126
1127 let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
1129
1130 let calinski_harabasz_index =
1132 self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
1133
1134 let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
1136
1137 Ok(ClusteringQualityMetrics {
1138 silhouette_score,
1139 davies_bouldin_index,
1140 calinski_harabasz_index,
1141 within_cluster_ss,
1142 between_cluster_ss,
1143 })
1144 }
1145
1146 fn calculate_davies_bouldin_index(
1148 &self,
1149 resources: &[(String, Vector)],
1150 clusters: &[Cluster],
1151 ) -> Result<f32> {
1152 if clusters.len() <= 1 {
1153 return Ok(0.0);
1154 }
1155
1156 let mut db_sum = 0.0;
1157 for i in 0..clusters.len() {
1158 let mut max_ratio: f32 = 0.0;
1159
1160 let cluster_i_vectors: Vec<&Vector> = clusters[i]
1162 .members
1163 .iter()
1164 .filter_map(|member| {
1165 resources
1166 .iter()
1167 .find(|(id, _)| id == member)
1168 .map(|(_, v)| v)
1169 })
1170 .collect();
1171
1172 if cluster_i_vectors.is_empty() {
1173 continue;
1174 }
1175
1176 let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
1178
1179 let mut avg_dist_i = 0.0;
1181 for vector in &cluster_i_vectors {
1182 avg_dist_i += self.calculate_distance(vector, ¢roid_i)?;
1183 }
1184 avg_dist_i /= cluster_i_vectors.len() as f32;
1185
1186 for (j, cluster_j) in clusters.iter().enumerate() {
1187 if i == j {
1188 continue;
1189 }
1190
1191 let cluster_j_vectors: Vec<&Vector> = cluster_j
1193 .members
1194 .iter()
1195 .filter_map(|member| {
1196 resources
1197 .iter()
1198 .find(|(id, _)| id == member)
1199 .map(|(_, v)| v)
1200 })
1201 .collect();
1202
1203 if cluster_j_vectors.is_empty() {
1204 continue;
1205 }
1206
1207 let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
1209
1210 let mut avg_dist_j = 0.0;
1212 for vector in &cluster_j_vectors {
1213 avg_dist_j += self.calculate_distance(vector, ¢roid_j)?;
1214 }
1215 avg_dist_j /= cluster_j_vectors.len() as f32;
1216
1217 let centroid_distance = self.calculate_distance(¢roid_i, ¢roid_j)?;
1219
1220 if centroid_distance > 0.0 {
1222 let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
1223 max_ratio = max_ratio.max(ratio);
1224 }
1225 }
1226 db_sum += max_ratio;
1227 }
1228
1229 Ok(db_sum / clusters.len() as f32)
1230 }
1231
1232 fn calculate_calinski_harabasz_index(
1234 &self,
1235 resources: &[(String, Vector)],
1236 clusters: &[Cluster],
1237 within_cluster_ss: f32,
1238 ) -> Result<f32> {
1239 if clusters.len() <= 1 || resources.is_empty() {
1240 return Ok(0.0);
1241 }
1242
1243 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1245 let overall_centroid = self.compute_centroid(&all_vectors)?;
1246
1247 let mut between_cluster_ss = 0.0;
1249 for cluster in clusters {
1250 let cluster_vectors: Vec<&Vector> = cluster
1251 .members
1252 .iter()
1253 .filter_map(|member| {
1254 resources
1255 .iter()
1256 .find(|(id, _)| id == member)
1257 .map(|(_, v)| v)
1258 })
1259 .collect();
1260
1261 if !cluster_vectors.is_empty() {
1262 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1263 let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1264 between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
1265 }
1266 }
1267
1268 let k = clusters.len() as f32;
1270 let n = resources.len() as f32;
1271
1272 if k >= n || within_cluster_ss <= 0.0 {
1273 return Ok(0.0);
1274 }
1275
1276 let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
1277 Ok(ch_index)
1278 }
1279
1280 fn calculate_between_cluster_ss(
1282 &self,
1283 resources: &[(String, Vector)],
1284 clusters: &[Cluster],
1285 ) -> Result<f32> {
1286 if clusters.is_empty() || resources.is_empty() {
1287 return Ok(0.0);
1288 }
1289
1290 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1292 let overall_centroid = self.compute_centroid(&all_vectors)?;
1293
1294 let mut between_cluster_ss = 0.0;
1295 for cluster in clusters {
1296 let cluster_vectors: Vec<&Vector> = cluster
1297 .members
1298 .iter()
1299 .filter_map(|member| {
1300 resources
1301 .iter()
1302 .find(|(id, _)| id == member)
1303 .map(|(_, v)| v)
1304 })
1305 .collect();
1306
1307 if !cluster_vectors.is_empty() {
1308 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1309 let distance = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1310 between_cluster_ss += cluster_vectors.len() as f32 * distance * distance;
1311 }
1312 }
1313
1314 Ok(between_cluster_ss)
1315 }
1316}
1317
1318impl Default for ClusteringQualityMetrics {
1319 fn default() -> Self {
1320 Self {
1321 silhouette_score: 0.0,
1322 davies_bouldin_index: 0.0,
1323 calinski_harabasz_index: 0.0,
1324 within_cluster_ss: 0.0,
1325 between_cluster_ss: 0.0,
1326 }
1327 }
1328}
1329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `(id, vector)` pair for a three-component vector.
    fn named(id: &str, values: [f32; 3]) -> (String, Vector) {
        (id.to_string(), Vector::new(values.to_vec()))
    }

    /// Two well-separated pairs must end up in exactly two k-means clusters.
    #[test]
    fn test_kmeans_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: Some(2),
            random_seed: Some(42),
            distance_metric: SimilarityMetric::Euclidean,
            ..Default::default()
        });

        let resources = vec![
            named("res1", [1.0, 1.0, 1.0]),
            named("res2", [1.1, 1.1, 1.1]),
            named("res3", [10.0, 10.0, 10.0]),
            named("res4", [10.1, 10.1, 10.1]),
        ];

        let result = engine.cluster(&resources).unwrap();

        assert_eq!(result.clusters.len(), 2);
        assert!(result.noise.is_empty());
    }

    /// DBSCAN on three points never yields more than two clusters.
    #[test]
    fn test_dbscan_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            similarity_threshold: 0.9,
            min_cluster_size: 2,
            ..Default::default()
        });

        let resources = vec![
            named("res1", [1.0, 1.0, 1.0]),
            named("res2", [1.1, 1.1, 1.1]),
            named("res3", [10.0, 10.0, 10.0]),
        ];

        let result = engine.cluster(&resources).unwrap();
        assert!(result.clusters.len() <= 2);
    }

    /// Mutually orthogonal vectors each form their own similarity cluster.
    #[test]
    fn test_similarity_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::Similarity,
            similarity_threshold: 0.95,
            ..Default::default()
        });

        let resources = vec![
            named("res1", [1.0, 0.0, 0.0]),
            named("res2", [0.0, 1.0, 0.0]),
            named("res3", [0.0, 0.0, 1.0]),
        ];

        let result = engine.cluster(&resources).unwrap();
        assert_eq!(result.clusters.len(), 3);
    }
}