1use crate::{similarity::SimilarityMetric, Vector};
11use anyhow::{anyhow, Result};
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
16#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
18pub enum ClusteringAlgorithm {
19 KMeans,
21 DBSCAN,
23 Hierarchical,
25 Spectral,
27 Community,
29 Similarity,
31}
32
/// Tunable parameters consumed by [`ClusteringEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
    /// Algorithm that `ClusteringEngine::cluster` dispatches to.
    pub algorithm: ClusteringAlgorithm,
    /// Desired cluster count for k-means and hierarchical clustering; both use 3 when `None`.
    pub num_clusters: Option<usize>,
    /// Minimum similarity for similarity clustering to group two resources.
    /// DBSCAN derives its neighborhood radius as `1.0 - similarity_threshold`.
    pub similarity_threshold: f32,
    /// Number of neighbors (the point itself is not counted) DBSCAN requires of a core point.
    pub min_cluster_size: usize,
    /// Metric used for all distance and similarity computations.
    pub distance_metric: SimilarityMetric,
    /// Upper bound on the number of k-means iterations.
    pub max_iterations: usize,
    /// Seed for the k-means random generator; a fixed seed of 42 is used when `None`.
    pub random_seed: Option<u64>,
    /// Convergence tolerance.
    /// NOTE(review): no code visible in this file reads this field — confirm it is still needed.
    pub tolerance: f32,
    /// Rule hierarchical clustering uses to measure the distance between two clusters.
    pub linkage: LinkageCriterion,
}
55
56impl Default for ClusteringConfig {
57 fn default() -> Self {
58 Self {
59 algorithm: ClusteringAlgorithm::KMeans,
60 num_clusters: Some(3),
61 similarity_threshold: 0.7,
62 min_cluster_size: 3,
63 distance_metric: SimilarityMetric::Cosine,
64 max_iterations: 100,
65 random_seed: None,
66 tolerance: 1e-4,
67 linkage: LinkageCriterion::Average,
68 }
69 }
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
74pub enum LinkageCriterion {
75 Single,
77 Complete,
79 Average,
81 Ward,
83}
84
/// One group of resources produced by a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Identifier assigned by the algorithm that produced the cluster.
    pub id: usize,
    /// Ids of the resources that belong to this cluster.
    pub members: Vec<String>,
    /// Representative vector of the cluster, when one was computed.
    pub centroid: Option<Vector>,
    /// Summary statistics over the member vectors.
    pub stats: ClusterStats,
}
97
/// Summary statistics for a single cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Number of vectors the statistics were computed over.
    pub size: usize,
    /// Mean pairwise similarity between members (1.0 for a single member, 0.0 for none).
    pub avg_intra_similarity: f32,
    /// Density estimate; the engine currently stores the same value as `avg_intra_similarity`.
    pub density: f32,
    /// Per-cluster silhouette score; the engine currently always stores 0.0 here.
    pub silhouette_score: f32,
}
110
/// Complete outcome of a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Clusters that were formed.
    pub clusters: Vec<Cluster>,
    /// Ids of resources reported as noise; only DBSCAN fills this list.
    pub noise: Vec<String>,
    /// Quality metrics computed over `clusters`.
    pub quality_metrics: ClusteringQualityMetrics,
    /// Algorithm that actually produced the result.
    pub algorithm: ClusteringAlgorithm,
    /// Copy of the configuration the engine ran with.
    pub config: ClusteringConfig,
}
125
/// Aggregate quality metrics for a whole clustering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringQualityMetrics {
    /// Mean silhouette value over all clustered resources.
    pub silhouette_score: f32,
    /// Davies-Bouldin index.
    pub davies_bouldin_index: f32,
    /// Calinski-Harabasz index.
    pub calinski_harabasz_index: f32,
    /// Sum of squared member-to-centroid distances, over clusters with more than one member.
    pub within_cluster_ss: f32,
    /// Size-weighted sum of squared distances between cluster centroids and the overall centroid.
    pub between_cluster_ss: f32,
}
140
141pub struct ClusteringEngine {
143 config: ClusteringConfig,
144}
145
146impl ClusteringEngine {
147 pub fn new(config: ClusteringConfig) -> Self {
148 Self { config }
149 }
150
151 pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
153 if resources.is_empty() {
154 return Ok(ClusteringResult {
155 clusters: Vec::new(),
156 noise: Vec::new(),
157 quality_metrics: ClusteringQualityMetrics::default(),
158 algorithm: self.config.algorithm,
159 config: self.config.clone(),
160 });
161 }
162
163 match self.config.algorithm {
164 ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
165 ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
166 ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
167 ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
168 ClusteringAlgorithm::Community => self.community_detection(resources),
169 ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
170 }
171 }
172
173 fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
175 let k = self.config.num_clusters.unwrap_or(3);
176 if k >= resources.len() {
177 return Err(anyhow!(
178 "Number of clusters must be less than number of resources"
179 ));
180 }
181
182 let mut rng = if let Some(seed) = self.config.random_seed {
183 Random::seed(seed)
184 } else {
185 Random::seed(42)
186 };
187
188 let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
190 let mut assignments = vec![0; resources.len()];
191 let mut prev_assignments = vec![usize::MAX; resources.len()];
192
193 for iteration in 0..self.config.max_iterations {
194 for (i, (_, vector)) in resources.iter().enumerate() {
196 let mut best_cluster = 0;
197 let mut best_distance = f32::INFINITY;
198
199 for (cluster_id, centroid) in centroids.iter().enumerate() {
200 let distance = self.calculate_distance(vector, centroid)?;
201 if distance < best_distance {
202 best_distance = distance;
203 best_cluster = cluster_id;
204 }
205 }
206 assignments[i] = best_cluster;
207 }
208
209 if assignments == prev_assignments {
211 break;
212 }
213
214 for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
216 let cluster_vectors: Vec<&Vector> = resources
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| assignments[*i] == cluster_id)
220 .map(|(_, (_, vector))| vector)
221 .collect();
222
223 if !cluster_vectors.is_empty() {
224 *centroid = self.compute_centroid(&cluster_vectors)?;
225 }
226 }
227
228 prev_assignments = assignments.clone();
229
230 if iteration > 0 && iteration % 10 == 0 {
231 println!(
232 "K-means iteration {}/{}",
233 iteration, self.config.max_iterations
234 );
235 }
236 }
237
238 let mut clusters = Vec::new();
240 for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
241 let members: Vec<String> = resources
242 .iter()
243 .enumerate()
244 .filter(|(i, _)| assignments[*i] == cluster_id)
245 .map(|(_, (resource_id, _))| resource_id.clone())
246 .collect();
247
248 if !members.is_empty() {
249 let cluster_vectors: Vec<&Vector> = resources
250 .iter()
251 .enumerate()
252 .filter(|(i, _)| assignments[*i] == cluster_id)
253 .map(|(_, (_, vector))| vector)
254 .collect();
255
256 let stats = self.compute_cluster_stats(&cluster_vectors)?;
257
258 clusters.push(Cluster {
259 id: cluster_id,
260 members,
261 centroid: Some(centroid.clone()),
262 stats,
263 });
264 }
265 }
266
267 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
268
269 Ok(ClusteringResult {
270 clusters,
271 noise: Vec::new(),
272 quality_metrics,
273 algorithm: ClusteringAlgorithm::KMeans,
274 config: self.config.clone(),
275 })
276 }
277
278 fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
280 let eps = 1.0 - self.config.similarity_threshold; let min_pts = self.config.min_cluster_size;
282
283 let mut visited = vec![false; resources.len()];
284 let mut cluster_assignments = vec![None; resources.len()];
285 let mut cluster_id = 0;
286 let mut noise_points = Vec::new();
287
288 for i in 0..resources.len() {
289 if visited[i] {
290 continue;
291 }
292 visited[i] = true;
293
294 let neighbors = self.find_neighbors(resources, i, eps)?;
295
296 if neighbors.len() < min_pts {
297 noise_points.push(resources[i].0.clone());
298 } else {
299 let mut cluster_queue = VecDeque::new();
300 cluster_queue.push_back(i);
301 cluster_assignments[i] = Some(cluster_id);
302
303 while let Some(point_idx) = cluster_queue.pop_front() {
304 let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
305
306 if point_neighbors.len() >= min_pts {
307 for &neighbor_idx in &point_neighbors {
308 if !visited[neighbor_idx] {
309 visited[neighbor_idx] = true;
310 cluster_queue.push_back(neighbor_idx);
311 }
312 if cluster_assignments[neighbor_idx].is_none() {
313 cluster_assignments[neighbor_idx] = Some(cluster_id);
314 }
315 }
316 }
317 }
318 cluster_id += 1;
319 }
320 }
321
322 let mut clusters = Vec::new();
324 for cid in 0..cluster_id {
325 let members: Vec<String> = resources
326 .iter()
327 .enumerate()
328 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
329 .map(|(_, (resource_id, _))| resource_id.clone())
330 .collect();
331
332 if !members.is_empty() {
333 let cluster_vectors: Vec<&Vector> = resources
334 .iter()
335 .enumerate()
336 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
337 .map(|(_, (_, vector))| vector)
338 .collect();
339
340 let stats = self.compute_cluster_stats(&cluster_vectors)?;
341 let centroid = if !cluster_vectors.is_empty() {
342 Some(self.compute_centroid(&cluster_vectors)?)
343 } else {
344 None
345 };
346
347 clusters.push(Cluster {
348 id: cid,
349 members,
350 centroid,
351 stats,
352 });
353 }
354 }
355
356 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
357
358 Ok(ClusteringResult {
359 clusters,
360 noise: noise_points,
361 quality_metrics,
362 algorithm: ClusteringAlgorithm::DBSCAN,
363 config: self.config.clone(),
364 })
365 }
366
367 fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
369 let target_clusters = self.config.num_clusters.unwrap_or(3);
370
371 let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
373
374 let mut distance_matrix = self.compute_distance_matrix(resources)?;
376
377 while clusters.len() > target_clusters {
379 let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
380
381 let cluster_j = clusters.remove(min_j.max(min_i));
383 clusters[min_i.min(min_j)].extend(cluster_j);
384
385 self.update_distance_matrix(
387 &mut distance_matrix,
388 &clusters,
389 min_i.min(min_j),
390 resources,
391 )?;
392 }
393
394 let mut result_clusters = Vec::new();
396 for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
397 let members: Vec<String> = cluster_indices
398 .iter()
399 .map(|&idx| resources[idx].0.clone())
400 .collect();
401
402 let cluster_vectors: Vec<&Vector> = cluster_indices
403 .iter()
404 .map(|&idx| &resources[idx].1)
405 .collect();
406
407 let stats = self.compute_cluster_stats(&cluster_vectors)?;
408 let centroid = if !cluster_vectors.is_empty() {
409 Some(self.compute_centroid(&cluster_vectors)?)
410 } else {
411 None
412 };
413
414 result_clusters.push(Cluster {
415 id: cluster_id,
416 members,
417 centroid,
418 stats,
419 });
420 }
421
422 let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
423
424 Ok(ClusteringResult {
425 clusters: result_clusters,
426 noise: Vec::new(),
427 quality_metrics,
428 algorithm: ClusteringAlgorithm::Hierarchical,
429 config: self.config.clone(),
430 })
431 }
432
433 fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
435 println!("Spectral clustering not yet fully implemented, falling back to k-means");
438 self.kmeans_clustering(resources)
439 }
440
441 fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
443 println!(
446 "Community detection not yet fully implemented, falling back to similarity clustering"
447 );
448 self.similarity_clustering(resources)
449 }
450
451 fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
453 let threshold = self.config.similarity_threshold;
454 let mut clusters = Vec::new();
455 let mut assigned = vec![false; resources.len()];
456 let mut cluster_id = 0;
457
458 for i in 0..resources.len() {
459 if assigned[i] {
460 continue;
461 }
462
463 let mut cluster_members = vec![i];
464 assigned[i] = true;
465
466 for j in (i + 1)..resources.len() {
468 if assigned[j] {
469 continue;
470 }
471
472 let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
473 if similarity >= threshold {
474 cluster_members.push(j);
475 assigned[j] = true;
476 }
477 }
478
479 let members: Vec<String> = cluster_members
480 .iter()
481 .map(|&idx| resources[idx].0.clone())
482 .collect();
483
484 let cluster_vectors: Vec<&Vector> = cluster_members
485 .iter()
486 .map(|&idx| &resources[idx].1)
487 .collect();
488
489 let stats = self.compute_cluster_stats(&cluster_vectors)?;
490 let centroid = if !cluster_vectors.is_empty() {
491 Some(self.compute_centroid(&cluster_vectors)?)
492 } else {
493 None
494 };
495
496 clusters.push(Cluster {
497 id: cluster_id,
498 members,
499 centroid,
500 stats,
501 });
502
503 cluster_id += 1;
504 }
505
506 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
507
508 Ok(ClusteringResult {
509 clusters,
510 noise: Vec::new(),
511 quality_metrics,
512 algorithm: ClusteringAlgorithm::Similarity,
513 config: self.config.clone(),
514 })
515 }
516
517 #[allow(deprecated)]
521 fn initialize_centroids_kmeans_plus_plus(
522 &self,
523 resources: &[(String, Vector)],
524 k: usize,
525 rng: &mut impl Rng,
526 ) -> Result<Vec<Vector>> {
527 let mut centroids = Vec::new();
528
529 let first_idx = rng.gen_range(0..resources.len());
531 centroids.push(resources[first_idx].1.clone());
532
533 for _ in 1..k {
535 let mut distances = Vec::new();
536 let mut total_distance = 0.0;
537
538 for (_, vector) in resources {
539 let min_dist_sq = centroids
540 .iter()
541 .map(|centroid| {
542 self.calculate_distance(vector, centroid)
543 .unwrap_or(f32::INFINITY)
544 })
545 .fold(f32::INFINITY, f32::min)
546 .powi(2);
547 distances.push(min_dist_sq);
548 total_distance += min_dist_sq;
549 }
550
551 let target = rng.random::<f32>() * total_distance;
552 let mut cumulative = 0.0;
553
554 for (i, &dist) in distances.iter().enumerate() {
555 cumulative += dist;
556 if cumulative >= target {
557 centroids.push(resources[i].1.clone());
558 break;
559 }
560 }
561 }
562
563 Ok(centroids)
564 }
565
566 fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
568 match self.config.distance_metric {
569 SimilarityMetric::Cosine => Ok(1.0 - v1.cosine_similarity(v2)?),
570 SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
571 SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
572 _ => Ok(1.0 - v1.cosine_similarity(v2)?), }
574 }
575
576 fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
578 match self.config.distance_metric {
579 SimilarityMetric::Cosine => v1.cosine_similarity(v2),
580 SimilarityMetric::Euclidean => {
581 let dist = v1.euclidean_distance(v2)?;
582 Ok(1.0 / (1.0 + dist))
583 }
584 SimilarityMetric::Manhattan => {
585 let dist = v1.manhattan_distance(v2)?;
586 Ok(1.0 / (1.0 + dist))
587 }
588 _ => v1.cosine_similarity(v2), }
590 }
591
592 fn find_neighbors(
594 &self,
595 resources: &[(String, Vector)],
596 point_idx: usize,
597 eps: f32,
598 ) -> Result<Vec<usize>> {
599 let mut neighbors = Vec::new();
600 let point = &resources[point_idx].1;
601
602 for (i, (_, vector)) in resources.iter().enumerate() {
603 if i != point_idx {
604 let distance = self.calculate_distance(point, vector)?;
605 if distance <= eps {
606 neighbors.push(i);
607 }
608 }
609 }
610
611 Ok(neighbors)
612 }
613
614 fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
616 if vectors.is_empty() {
617 return Err(anyhow!("Cannot compute centroid of empty vector set"));
618 }
619
620 let dim = vectors[0].dimensions;
621 let mut centroid_data = vec![0.0; dim];
622
623 for vector in vectors {
624 let data = vector.as_f32();
625 for (i, &value) in data.iter().enumerate() {
626 centroid_data[i] += value;
627 }
628 }
629
630 let count = vectors.len() as f32;
631 for value in &mut centroid_data {
632 *value /= count;
633 }
634
635 Ok(Vector::new(centroid_data))
636 }
637
638 fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
640 if vectors.is_empty() {
641 return Ok(ClusterStats {
642 size: 0,
643 avg_intra_similarity: 0.0,
644 density: 0.0,
645 silhouette_score: 0.0,
646 });
647 }
648
649 let size = vectors.len();
650 let mut total_similarity = 0.0;
651 let mut pair_count = 0;
652
653 for i in 0..vectors.len() {
655 for j in (i + 1)..vectors.len() {
656 let similarity = self.calculate_similarity(vectors[i], vectors[j])?;
657 total_similarity += similarity;
658 pair_count += 1;
659 }
660 }
661
662 let avg_intra_similarity = if pair_count > 0 {
663 total_similarity / pair_count as f32
664 } else {
665 1.0 };
667
668 Ok(ClusterStats {
669 size,
670 avg_intra_similarity,
671 density: avg_intra_similarity, silhouette_score: 0.0, })
674 }
675
676 fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
678 let n = resources.len();
679 let mut matrix = vec![vec![0.0; n]; n];
680
681 for i in 0..n {
682 for j in (i + 1)..n {
683 let distance = self.calculate_distance(&resources[i].1, &resources[j].1)?;
684 matrix[i][j] = distance;
685 matrix[j][i] = distance;
686 }
687 }
688
689 Ok(matrix)
690 }
691
692 fn find_closest_clusters(
694 &self,
695 clusters: &[Vec<usize>],
696 distance_matrix: &[Vec<f32>],
697 ) -> Result<(usize, usize)> {
698 let mut min_distance = f32::INFINITY;
699 let mut closest_pair = (0, 1);
700
701 for i in 0..clusters.len() {
702 for j in (i + 1)..clusters.len() {
703 let distance = self.cluster_distance(&clusters[i], &clusters[j], distance_matrix);
704 if distance < min_distance {
705 min_distance = distance;
706 closest_pair = (i, j);
707 }
708 }
709 }
710
711 Ok(closest_pair)
712 }
713
714 fn cluster_distance(
716 &self,
717 cluster1: &[usize],
718 cluster2: &[usize],
719 distance_matrix: &[Vec<f32>],
720 ) -> f32 {
721 match self.config.linkage {
722 LinkageCriterion::Single => {
723 cluster1
725 .iter()
726 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
727 .fold(f32::INFINITY, f32::min)
728 }
729 LinkageCriterion::Complete => {
730 cluster1
732 .iter()
733 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
734 .fold(0.0, f32::max)
735 }
736 LinkageCriterion::Average => {
737 let mut total = 0.0;
739 let mut count = 0;
740 for &i in cluster1 {
741 for &j in cluster2 {
742 total += distance_matrix[i][j];
743 count += 1;
744 }
745 }
746 if count > 0 {
747 total / count as f32
748 } else {
749 0.0
750 }
751 }
752 LinkageCriterion::Ward => {
753 self.cluster_distance(cluster1, cluster2, distance_matrix)
755 }
756 }
757 }
758
759 fn update_distance_matrix(
761 &self,
762 distance_matrix: &mut Vec<Vec<f32>>,
763 _clusters: &[Vec<usize>],
764 _merged_cluster: usize,
765 resources: &[(String, Vector)],
766 ) -> Result<()> {
767 let new_matrix = self.compute_distance_matrix(resources)?;
769 *distance_matrix = new_matrix;
770 Ok(())
771 }
772
773 fn compute_quality_metrics(
775 &self,
776 resources: &[(String, Vector)],
777 clusters: &[Cluster],
778 ) -> Result<ClusteringQualityMetrics> {
779 let mut within_cluster_ss = 0.0;
781 let mut silhouette_scores = Vec::new();
782
783 for cluster in clusters {
784 if cluster.members.len() > 1 {
785 let cluster_vectors: Vec<&Vector> = cluster
786 .members
787 .iter()
788 .filter_map(|member| {
789 resources
790 .iter()
791 .find(|(id, _)| id == member)
792 .map(|(_, v)| v)
793 })
794 .collect();
795
796 if let Some(ref centroid) = cluster.centroid {
797 for vector in &cluster_vectors {
798 let dist = self.calculate_distance(vector, centroid)?;
799 within_cluster_ss += dist * dist;
800 }
801 }
802 }
803 }
804
805 for (cluster_idx, cluster) in clusters.iter().enumerate() {
807 let cluster_vectors: Vec<(usize, &Vector)> = cluster
808 .members
809 .iter()
810 .filter_map(|member| {
811 resources
812 .iter()
813 .enumerate()
814 .find(|(_, (id, _))| id == member)
815 .map(|(idx, (_, v))| (idx, v))
816 })
817 .collect();
818
819 for (point_idx, point_vector) in &cluster_vectors {
821 if cluster_vectors.len() <= 1 {
822 silhouette_scores.push(0.0);
824 continue;
825 }
826
827 let mut intra_cluster_dist = 0.0;
829 let mut intra_count = 0;
830 for (other_idx, other_vector) in &cluster_vectors {
831 if point_idx != other_idx {
832 let dist = self.calculate_distance(point_vector, other_vector)?;
833 intra_cluster_dist += dist;
834 intra_count += 1;
835 }
836 }
837 let a = if intra_count > 0 {
838 intra_cluster_dist / intra_count as f32
839 } else {
840 0.0
841 };
842
843 let mut min_inter_cluster_dist = f32::INFINITY;
845 for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
846 if cluster_idx != other_cluster_idx {
847 let other_cluster_vectors: Vec<&Vector> = other_cluster
848 .members
849 .iter()
850 .filter_map(|member| {
851 resources
852 .iter()
853 .find(|(id, _)| id == member)
854 .map(|(_, v)| v)
855 })
856 .collect();
857
858 if !other_cluster_vectors.is_empty() {
859 let mut inter_cluster_dist = 0.0;
860 for other_vector in &other_cluster_vectors {
861 let dist = self.calculate_distance(point_vector, other_vector)?;
862 inter_cluster_dist += dist;
863 }
864 let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
865 min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
866 }
867 }
868 }
869 let b = min_inter_cluster_dist;
870
871 let silhouette = if a.max(b) > 0.0 {
873 (b - a) / a.max(b)
874 } else {
875 0.0
876 };
877 silhouette_scores.push(silhouette);
878 }
879 }
880
881 let silhouette_score = if !silhouette_scores.is_empty() {
882 silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
883 } else {
884 0.0
885 };
886
887 let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
889
890 let calinski_harabasz_index =
892 self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
893
894 let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
896
897 Ok(ClusteringQualityMetrics {
898 silhouette_score,
899 davies_bouldin_index,
900 calinski_harabasz_index,
901 within_cluster_ss,
902 between_cluster_ss,
903 })
904 }
905
906 fn calculate_davies_bouldin_index(
908 &self,
909 resources: &[(String, Vector)],
910 clusters: &[Cluster],
911 ) -> Result<f32> {
912 if clusters.len() <= 1 {
913 return Ok(0.0);
914 }
915
916 let mut db_sum = 0.0;
917 for i in 0..clusters.len() {
918 let mut max_ratio: f32 = 0.0;
919
920 let cluster_i_vectors: Vec<&Vector> = clusters[i]
922 .members
923 .iter()
924 .filter_map(|member| {
925 resources
926 .iter()
927 .find(|(id, _)| id == member)
928 .map(|(_, v)| v)
929 })
930 .collect();
931
932 if cluster_i_vectors.is_empty() {
933 continue;
934 }
935
936 let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
938
939 let mut avg_dist_i = 0.0;
941 for vector in &cluster_i_vectors {
942 avg_dist_i += self.calculate_distance(vector, ¢roid_i)?;
943 }
944 avg_dist_i /= cluster_i_vectors.len() as f32;
945
946 for (j, cluster_j) in clusters.iter().enumerate() {
947 if i == j {
948 continue;
949 }
950
951 let cluster_j_vectors: Vec<&Vector> = cluster_j
953 .members
954 .iter()
955 .filter_map(|member| {
956 resources
957 .iter()
958 .find(|(id, _)| id == member)
959 .map(|(_, v)| v)
960 })
961 .collect();
962
963 if cluster_j_vectors.is_empty() {
964 continue;
965 }
966
967 let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
969
970 let mut avg_dist_j = 0.0;
972 for vector in &cluster_j_vectors {
973 avg_dist_j += self.calculate_distance(vector, ¢roid_j)?;
974 }
975 avg_dist_j /= cluster_j_vectors.len() as f32;
976
977 let centroid_distance = self.calculate_distance(¢roid_i, ¢roid_j)?;
979
980 if centroid_distance > 0.0 {
982 let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
983 max_ratio = max_ratio.max(ratio);
984 }
985 }
986 db_sum += max_ratio;
987 }
988
989 Ok(db_sum / clusters.len() as f32)
990 }
991
992 fn calculate_calinski_harabasz_index(
994 &self,
995 resources: &[(String, Vector)],
996 clusters: &[Cluster],
997 within_cluster_ss: f32,
998 ) -> Result<f32> {
999 if clusters.len() <= 1 || resources.is_empty() {
1000 return Ok(0.0);
1001 }
1002
1003 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1005 let overall_centroid = self.compute_centroid(&all_vectors)?;
1006
1007 let mut between_cluster_ss = 0.0;
1009 for cluster in clusters {
1010 let cluster_vectors: Vec<&Vector> = cluster
1011 .members
1012 .iter()
1013 .filter_map(|member| {
1014 resources
1015 .iter()
1016 .find(|(id, _)| id == member)
1017 .map(|(_, v)| v)
1018 })
1019 .collect();
1020
1021 if !cluster_vectors.is_empty() {
1022 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1023 let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1024 between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
1025 }
1026 }
1027
1028 let k = clusters.len() as f32;
1030 let n = resources.len() as f32;
1031
1032 if k >= n || within_cluster_ss <= 0.0 {
1033 return Ok(0.0);
1034 }
1035
1036 let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
1037 Ok(ch_index)
1038 }
1039
1040 fn calculate_between_cluster_ss(
1042 &self,
1043 resources: &[(String, Vector)],
1044 clusters: &[Cluster],
1045 ) -> Result<f32> {
1046 if clusters.is_empty() || resources.is_empty() {
1047 return Ok(0.0);
1048 }
1049
1050 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1052 let overall_centroid = self.compute_centroid(&all_vectors)?;
1053
1054 let mut between_cluster_ss = 0.0;
1055 for cluster in clusters {
1056 let cluster_vectors: Vec<&Vector> = cluster
1057 .members
1058 .iter()
1059 .filter_map(|member| {
1060 resources
1061 .iter()
1062 .find(|(id, _)| id == member)
1063 .map(|(_, v)| v)
1064 })
1065 .collect();
1066
1067 if !cluster_vectors.is_empty() {
1068 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1069 let distance = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1070 between_cluster_ss += cluster_vectors.len() as f32 * distance * distance;
1071 }
1072 }
1073
1074 Ok(between_cluster_ss)
1075 }
1076}
1077
1078impl Default for ClusteringQualityMetrics {
1079 fn default() -> Self {
1080 Self {
1081 silhouette_score: 0.0,
1082 davies_bouldin_index: 0.0,
1083 calinski_harabasz_index: 0.0,
1084 within_cluster_ss: 0.0,
1085 between_cluster_ss: 0.0,
1086 }
1087 }
1088}
1089
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(id, vector)` entry for a three-component vector.
    fn resource(id: &str, values: [f32; 3]) -> (String, Vector) {
        (id.to_string(), Vector::new(values.to_vec()))
    }

    /// Runs an engine configured with `config` over `resources`.
    fn run(config: ClusteringConfig, resources: &[(String, Vector)]) -> ClusteringResult {
        ClusteringEngine::new(config).cluster(resources).unwrap()
    }

    /// Two well-separated pairs must come out as two k-means clusters.
    #[test]
    fn test_kmeans_clustering() {
        let resources = vec![
            resource("res1", [1.0, 1.0, 1.0]),
            resource("res2", [1.1, 1.1, 1.1]),
            resource("res3", [10.0, 10.0, 10.0]),
            resource("res4", [10.1, 10.1, 10.1]),
        ];

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: Some(2),
            random_seed: Some(42),
            distance_metric: SimilarityMetric::Euclidean,
            ..Default::default()
        };

        let result = run(config, &resources);

        assert_eq!(result.clusters.len(), 2);
        assert!(result.noise.is_empty());
    }

    /// DBSCAN over three resources never yields more than two clusters.
    #[test]
    fn test_dbscan_clustering() {
        let resources = vec![
            resource("res1", [1.0, 1.0, 1.0]),
            resource("res2", [1.1, 1.1, 1.1]),
            resource("res3", [10.0, 10.0, 10.0]),
        ];

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            similarity_threshold: 0.9,
            min_cluster_size: 2,
            ..Default::default()
        };

        let result = run(config, &resources);
        assert!(result.clusters.len() <= 2);
    }

    /// Mutually orthogonal vectors share no similarity, so each forms its own cluster.
    #[test]
    fn test_similarity_clustering() {
        let resources = vec![
            resource("res1", [1.0, 0.0, 0.0]),
            resource("res2", [0.0, 1.0, 0.0]),
            resource("res3", [0.0, 0.0, 1.0]),
        ];

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::Similarity,
            similarity_threshold: 0.95,
            ..Default::default()
        };

        let result = run(config, &resources);
        assert_eq!(result.clusters.len(), 3);
    }
}