1use crate::{similarity::SimilarityMetric, Vector};
11use anyhow::{anyhow, Result};
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
/// Clustering strategy selected through `ClusteringConfig::algorithm` and
/// dispatched by `ClusteringEngine::cluster`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
    /// Centroid-based partitioning into `num_clusters` groups.
    KMeans,
    /// Density-based clustering; points that end up in no cluster are reported as noise.
    DBSCAN,
    /// Agglomerative clustering: starts from singletons and merges the closest
    /// pair of clusters until `num_clusters` remain.
    Hierarchical,
    /// Routed to the spectral entry point, which currently falls back to k-means.
    Spectral,
    /// Routed to community detection, which currently falls back to similarity clustering.
    Community,
    /// Greedy grouping: each unassigned resource seeds a cluster and collects the
    /// later resources whose similarity to it meets the threshold.
    Similarity,
}
32
/// Tunable parameters shared by all clustering algorithms.
///
/// Not every field is read by every algorithm; see the per-field notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
    /// Algorithm that `ClusteringEngine::cluster` dispatches to.
    pub algorithm: ClusteringAlgorithm,
    /// Target cluster count for k-means and hierarchical clustering
    /// (both fall back to 3 when this is `None`).
    pub num_clusters: Option<usize>,
    /// Minimum similarity for grouping in similarity clustering; DBSCAN derives
    /// its neighbourhood radius as `1.0 - similarity_threshold`.
    pub similarity_threshold: f32,
    /// Minimum neighbour count (the point itself excluded) for DBSCAN to treat
    /// a point as a core point.
    pub min_cluster_size: usize,
    /// Metric used for both distance and similarity computations.
    pub distance_metric: SimilarityMetric,
    /// Upper bound on k-means refinement iterations.
    pub max_iterations: usize,
    /// Seed for k-means initialisation; a fixed seed of 42 is used when `None`.
    pub random_seed: Option<u64>,
    /// Convergence tolerance.
    /// NOTE(review): no algorithm in this file reads this field; k-means stops
    /// when assignments stop changing instead.
    pub tolerance: f32,
    /// Inter-cluster distance rule used by hierarchical clustering.
    pub linkage: LinkageCriterion,
}
55
56impl Default for ClusteringConfig {
57 fn default() -> Self {
58 Self {
59 algorithm: ClusteringAlgorithm::KMeans,
60 num_clusters: Some(3),
61 similarity_threshold: 0.7,
62 min_cluster_size: 3,
63 distance_metric: SimilarityMetric::Cosine,
64 max_iterations: 100,
65 random_seed: None,
66 tolerance: 1e-4,
67 linkage: LinkageCriterion::Average,
68 }
69 }
70}
71
/// Rule for measuring the distance between two clusters during hierarchical
/// (agglomerative) clustering.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LinkageCriterion {
    /// Smallest pairwise distance between members of the two clusters.
    Single,
    /// Largest pairwise distance between members of the two clusters.
    Complete,
    /// Mean of all pairwise distances between members of the two clusters.
    Average,
    /// Ward's criterion.
    /// NOTE(review): `cluster_distance` has no dedicated Ward computation —
    /// confirm its behaviour before relying on this variant.
    Ward,
}
84
/// One group of resources produced by a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Identifier of the cluster within its `ClusteringResult`.
    pub id: usize,
    /// Identifiers of the resources assigned to this cluster.
    pub members: Vec<String>,
    /// Mean vector of the members, when one was computed.
    pub centroid: Option<Vector>,
    /// Summary statistics for the cluster.
    pub stats: ClusterStats,
}
97
/// Per-cluster summary statistics, as filled in by `compute_cluster_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Number of member vectors.
    pub size: usize,
    /// Mean pairwise similarity between members (1.0 for a single-member cluster).
    pub avg_intra_similarity: f32,
    /// Currently set to the same value as `avg_intra_similarity`.
    pub density: f32,
    /// Per-cluster silhouette; `compute_cluster_stats` always writes 0.0 here.
    pub silhouette_score: f32,
}
110
/// Outcome of a clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Clusters that were formed.
    pub clusters: Vec<Cluster>,
    /// Identifiers of resources classified as noise (only DBSCAN fills this in).
    pub noise: Vec<String>,
    /// Quality indices computed over `clusters`.
    pub quality_metrics: ClusteringQualityMetrics,
    /// Algorithm that actually produced the result; for the fallback paths
    /// (spectral, community) this is the algorithm that was fallen back to.
    pub algorithm: ClusteringAlgorithm,
    /// Copy of the configuration the engine was running with.
    pub config: ClusteringConfig,
}
125
/// Quality indices describing a whole clustering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringQualityMetrics {
    /// Mean silhouette value over all clustered points.
    pub silhouette_score: f32,
    /// Davies-Bouldin index: mean over clusters of the worst
    /// (scatter_i + scatter_j) / centroid-distance ratio.
    pub davies_bouldin_index: f32,
    /// Calinski-Harabasz index: between-cluster to within-cluster dispersion ratio.
    pub calinski_harabasz_index: f32,
    /// Sum of squared member-to-centroid distances (clusters with more than one member).
    pub within_cluster_ss: f32,
    /// Size-weighted sum of squared distances from cluster centroids to the overall centroid.
    pub between_cluster_ss: f32,
}
140
/// Runs the clustering algorithm selected by its configuration over
/// `(resource id, vector)` pairs.
pub struct ClusteringEngine {
    /// Settings applied to every `cluster` call.
    config: ClusteringConfig,
}
145
146impl ClusteringEngine {
147 pub fn new(config: ClusteringConfig) -> Self {
148 Self { config }
149 }
150
151 pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
153 if resources.is_empty() {
154 return Ok(ClusteringResult {
155 clusters: Vec::new(),
156 noise: Vec::new(),
157 quality_metrics: ClusteringQualityMetrics::default(),
158 algorithm: self.config.algorithm,
159 config: self.config.clone(),
160 });
161 }
162
163 match self.config.algorithm {
164 ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
165 ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
166 ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
167 ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
168 ClusteringAlgorithm::Community => self.community_detection(resources),
169 ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
170 }
171 }
172
173 fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
175 let k = self.config.num_clusters.unwrap_or(3);
176 if k >= resources.len() {
177 return Err(anyhow!(
178 "Number of clusters must be less than number of resources"
179 ));
180 }
181
182 let mut rng = if let Some(seed) = self.config.random_seed {
183 Random::seed(seed)
184 } else {
185 Random::seed(42)
186 };
187
188 let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
190 let mut assignments = vec![0; resources.len()];
191 let mut prev_assignments = vec![usize::MAX; resources.len()];
192
193 for iteration in 0..self.config.max_iterations {
194 for (i, (_, vector)) in resources.iter().enumerate() {
196 let mut best_cluster = 0;
197 let mut best_distance = f32::INFINITY;
198
199 for (cluster_id, centroid) in centroids.iter().enumerate() {
200 let distance = self.calculate_distance(vector, centroid)?;
201 if distance < best_distance {
202 best_distance = distance;
203 best_cluster = cluster_id;
204 }
205 }
206 assignments[i] = best_cluster;
207 }
208
209 if assignments == prev_assignments {
211 break;
212 }
213
214 for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
216 let cluster_vectors: Vec<&Vector> = resources
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| assignments[*i] == cluster_id)
220 .map(|(_, (_, vector))| vector)
221 .collect();
222
223 if !cluster_vectors.is_empty() {
224 *centroid = self.compute_centroid(&cluster_vectors)?;
225 }
226 }
227
228 prev_assignments = assignments.clone();
229
230 if iteration > 0 && iteration % 10 == 0 {
231 println!(
232 "K-means iteration {}/{}",
233 iteration, self.config.max_iterations
234 );
235 }
236 }
237
238 let mut clusters = Vec::new();
240 for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
241 let members: Vec<String> = resources
242 .iter()
243 .enumerate()
244 .filter(|(i, _)| assignments[*i] == cluster_id)
245 .map(|(_, (resource_id, _))| resource_id.clone())
246 .collect();
247
248 if !members.is_empty() {
249 let cluster_vectors: Vec<&Vector> = resources
250 .iter()
251 .enumerate()
252 .filter(|(i, _)| assignments[*i] == cluster_id)
253 .map(|(_, (_, vector))| vector)
254 .collect();
255
256 let stats = self.compute_cluster_stats(&cluster_vectors)?;
257
258 clusters.push(Cluster {
259 id: cluster_id,
260 members,
261 centroid: Some(centroid.clone()),
262 stats,
263 });
264 }
265 }
266
267 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
268
269 Ok(ClusteringResult {
270 clusters,
271 noise: Vec::new(),
272 quality_metrics,
273 algorithm: ClusteringAlgorithm::KMeans,
274 config: self.config.clone(),
275 })
276 }
277
278 fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
280 let eps = 1.0 - self.config.similarity_threshold; let min_pts = self.config.min_cluster_size;
282
283 let mut visited = vec![false; resources.len()];
284 let mut cluster_assignments = vec![None; resources.len()];
285 let mut cluster_id = 0;
286 let mut noise_points = Vec::new();
287
288 for i in 0..resources.len() {
289 if visited[i] {
290 continue;
291 }
292 visited[i] = true;
293
294 let neighbors = self.find_neighbors(resources, i, eps)?;
295
296 if neighbors.len() < min_pts {
297 noise_points.push(resources[i].0.clone());
298 } else {
299 let mut cluster_queue = VecDeque::new();
300 cluster_queue.push_back(i);
301 cluster_assignments[i] = Some(cluster_id);
302
303 while let Some(point_idx) = cluster_queue.pop_front() {
304 let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
305
306 if point_neighbors.len() >= min_pts {
307 for &neighbor_idx in &point_neighbors {
308 if !visited[neighbor_idx] {
309 visited[neighbor_idx] = true;
310 cluster_queue.push_back(neighbor_idx);
311 }
312 if cluster_assignments[neighbor_idx].is_none() {
313 cluster_assignments[neighbor_idx] = Some(cluster_id);
314 }
315 }
316 }
317 }
318 cluster_id += 1;
319 }
320 }
321
322 let mut clusters = Vec::new();
324 for cid in 0..cluster_id {
325 let members: Vec<String> = resources
326 .iter()
327 .enumerate()
328 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
329 .map(|(_, (resource_id, _))| resource_id.clone())
330 .collect();
331
332 if !members.is_empty() {
333 let cluster_vectors: Vec<&Vector> = resources
334 .iter()
335 .enumerate()
336 .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
337 .map(|(_, (_, vector))| vector)
338 .collect();
339
340 let stats = self.compute_cluster_stats(&cluster_vectors)?;
341 let centroid = if !cluster_vectors.is_empty() {
342 Some(self.compute_centroid(&cluster_vectors)?)
343 } else {
344 None
345 };
346
347 clusters.push(Cluster {
348 id: cid,
349 members,
350 centroid,
351 stats,
352 });
353 }
354 }
355
356 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
357
358 Ok(ClusteringResult {
359 clusters,
360 noise: noise_points,
361 quality_metrics,
362 algorithm: ClusteringAlgorithm::DBSCAN,
363 config: self.config.clone(),
364 })
365 }
366
367 fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
369 let target_clusters = self.config.num_clusters.unwrap_or(3);
370
371 let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
373
374 let mut distance_matrix = self.compute_distance_matrix(resources)?;
376
377 while clusters.len() > target_clusters {
379 let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
380
381 let cluster_j = clusters.remove(min_j.max(min_i));
383 clusters[min_i.min(min_j)].extend(cluster_j);
384
385 self.update_distance_matrix(
387 &mut distance_matrix,
388 &clusters,
389 min_i.min(min_j),
390 resources,
391 )?;
392 }
393
394 let mut result_clusters = Vec::new();
396 for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
397 let members: Vec<String> = cluster_indices
398 .iter()
399 .map(|&idx| resources[idx].0.clone())
400 .collect();
401
402 let cluster_vectors: Vec<&Vector> = cluster_indices
403 .iter()
404 .map(|&idx| &resources[idx].1)
405 .collect();
406
407 let stats = self.compute_cluster_stats(&cluster_vectors)?;
408 let centroid = if !cluster_vectors.is_empty() {
409 Some(self.compute_centroid(&cluster_vectors)?)
410 } else {
411 None
412 };
413
414 result_clusters.push(Cluster {
415 id: cluster_id,
416 members,
417 centroid,
418 stats,
419 });
420 }
421
422 let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
423
424 Ok(ClusteringResult {
425 clusters: result_clusters,
426 noise: Vec::new(),
427 quality_metrics,
428 algorithm: ClusteringAlgorithm::Hierarchical,
429 config: self.config.clone(),
430 })
431 }
432
433 fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
435 println!("Spectral clustering not yet fully implemented, falling back to k-means");
438 self.kmeans_clustering(resources)
439 }
440
441 fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
443 println!(
446 "Community detection not yet fully implemented, falling back to similarity clustering"
447 );
448 self.similarity_clustering(resources)
449 }
450
451 fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
453 let threshold = self.config.similarity_threshold;
454 let mut clusters = Vec::new();
455 let mut assigned = vec![false; resources.len()];
456 let mut cluster_id = 0;
457
458 for i in 0..resources.len() {
459 if assigned[i] {
460 continue;
461 }
462
463 let mut cluster_members = vec![i];
464 assigned[i] = true;
465
466 for j in (i + 1)..resources.len() {
468 if assigned[j] {
469 continue;
470 }
471
472 let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
473 if similarity >= threshold {
474 cluster_members.push(j);
475 assigned[j] = true;
476 }
477 }
478
479 let members: Vec<String> = cluster_members
480 .iter()
481 .map(|&idx| resources[idx].0.clone())
482 .collect();
483
484 let cluster_vectors: Vec<&Vector> = cluster_members
485 .iter()
486 .map(|&idx| &resources[idx].1)
487 .collect();
488
489 let stats = self.compute_cluster_stats(&cluster_vectors)?;
490 let centroid = if !cluster_vectors.is_empty() {
491 Some(self.compute_centroid(&cluster_vectors)?)
492 } else {
493 None
494 };
495
496 clusters.push(Cluster {
497 id: cluster_id,
498 members,
499 centroid,
500 stats,
501 });
502
503 cluster_id += 1;
504 }
505
506 let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
507
508 Ok(ClusteringResult {
509 clusters,
510 noise: Vec::new(),
511 quality_metrics,
512 algorithm: ClusteringAlgorithm::Similarity,
513 config: self.config.clone(),
514 })
515 }
516
517 fn initialize_centroids_kmeans_plus_plus(
521 &self,
522 resources: &[(String, Vector)],
523 k: usize,
524 rng: &mut impl Rng,
525 ) -> Result<Vec<Vector>> {
526 let mut centroids = Vec::new();
527
528 let first_idx = rng.gen_range(0..resources.len());
530 centroids.push(resources[first_idx].1.clone());
531
532 for _ in 1..k {
534 let mut distances = Vec::new();
535 let mut total_distance = 0.0;
536
537 for (_, vector) in resources {
538 let min_dist_sq = centroids
539 .iter()
540 .map(|centroid| {
541 self.calculate_distance(vector, centroid)
542 .unwrap_or(f32::INFINITY)
543 })
544 .fold(f32::INFINITY, f32::min)
545 .powi(2);
546 distances.push(min_dist_sq);
547 total_distance += min_dist_sq;
548 }
549
550 let target = rng.gen::<f32>() * total_distance;
551 let mut cumulative = 0.0;
552
553 for (i, &dist) in distances.iter().enumerate() {
554 cumulative += dist;
555 if cumulative >= target {
556 centroids.push(resources[i].1.clone());
557 break;
558 }
559 }
560 }
561
562 Ok(centroids)
563 }
564
565 fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
567 match self.config.distance_metric {
568 SimilarityMetric::Cosine => Ok(1.0 - v1.cosine_similarity(v2)?),
569 SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
570 SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
571 _ => Ok(1.0 - v1.cosine_similarity(v2)?), }
573 }
574
575 fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
577 match self.config.distance_metric {
578 SimilarityMetric::Cosine => v1.cosine_similarity(v2),
579 SimilarityMetric::Euclidean => {
580 let dist = v1.euclidean_distance(v2)?;
581 Ok(1.0 / (1.0 + dist))
582 }
583 SimilarityMetric::Manhattan => {
584 let dist = v1.manhattan_distance(v2)?;
585 Ok(1.0 / (1.0 + dist))
586 }
587 _ => v1.cosine_similarity(v2), }
589 }
590
591 fn find_neighbors(
593 &self,
594 resources: &[(String, Vector)],
595 point_idx: usize,
596 eps: f32,
597 ) -> Result<Vec<usize>> {
598 let mut neighbors = Vec::new();
599 let point = &resources[point_idx].1;
600
601 for (i, (_, vector)) in resources.iter().enumerate() {
602 if i != point_idx {
603 let distance = self.calculate_distance(point, vector)?;
604 if distance <= eps {
605 neighbors.push(i);
606 }
607 }
608 }
609
610 Ok(neighbors)
611 }
612
613 fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
615 if vectors.is_empty() {
616 return Err(anyhow!("Cannot compute centroid of empty vector set"));
617 }
618
619 let dim = vectors[0].dimensions;
620 let mut centroid_data = vec![0.0; dim];
621
622 for vector in vectors {
623 let data = vector.as_f32();
624 for (i, &value) in data.iter().enumerate() {
625 centroid_data[i] += value;
626 }
627 }
628
629 let count = vectors.len() as f32;
630 for value in &mut centroid_data {
631 *value /= count;
632 }
633
634 Ok(Vector::new(centroid_data))
635 }
636
637 fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
639 if vectors.is_empty() {
640 return Ok(ClusterStats {
641 size: 0,
642 avg_intra_similarity: 0.0,
643 density: 0.0,
644 silhouette_score: 0.0,
645 });
646 }
647
648 let size = vectors.len();
649 let mut total_similarity = 0.0;
650 let mut pair_count = 0;
651
652 for i in 0..vectors.len() {
654 for j in (i + 1)..vectors.len() {
655 let similarity = self.calculate_similarity(vectors[i], vectors[j])?;
656 total_similarity += similarity;
657 pair_count += 1;
658 }
659 }
660
661 let avg_intra_similarity = if pair_count > 0 {
662 total_similarity / pair_count as f32
663 } else {
664 1.0 };
666
667 Ok(ClusterStats {
668 size,
669 avg_intra_similarity,
670 density: avg_intra_similarity, silhouette_score: 0.0, })
673 }
674
675 fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
677 let n = resources.len();
678 let mut matrix = vec![vec![0.0; n]; n];
679
680 for i in 0..n {
681 for j in (i + 1)..n {
682 let distance = self.calculate_distance(&resources[i].1, &resources[j].1)?;
683 matrix[i][j] = distance;
684 matrix[j][i] = distance;
685 }
686 }
687
688 Ok(matrix)
689 }
690
691 fn find_closest_clusters(
693 &self,
694 clusters: &[Vec<usize>],
695 distance_matrix: &[Vec<f32>],
696 ) -> Result<(usize, usize)> {
697 let mut min_distance = f32::INFINITY;
698 let mut closest_pair = (0, 1);
699
700 for i in 0..clusters.len() {
701 for j in (i + 1)..clusters.len() {
702 let distance = self.cluster_distance(&clusters[i], &clusters[j], distance_matrix);
703 if distance < min_distance {
704 min_distance = distance;
705 closest_pair = (i, j);
706 }
707 }
708 }
709
710 Ok(closest_pair)
711 }
712
713 fn cluster_distance(
715 &self,
716 cluster1: &[usize],
717 cluster2: &[usize],
718 distance_matrix: &[Vec<f32>],
719 ) -> f32 {
720 match self.config.linkage {
721 LinkageCriterion::Single => {
722 cluster1
724 .iter()
725 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
726 .fold(f32::INFINITY, f32::min)
727 }
728 LinkageCriterion::Complete => {
729 cluster1
731 .iter()
732 .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
733 .fold(0.0, f32::max)
734 }
735 LinkageCriterion::Average => {
736 let mut total = 0.0;
738 let mut count = 0;
739 for &i in cluster1 {
740 for &j in cluster2 {
741 total += distance_matrix[i][j];
742 count += 1;
743 }
744 }
745 if count > 0 {
746 total / count as f32
747 } else {
748 0.0
749 }
750 }
751 LinkageCriterion::Ward => {
752 self.cluster_distance(cluster1, cluster2, distance_matrix)
754 }
755 }
756 }
757
758 fn update_distance_matrix(
760 &self,
761 distance_matrix: &mut Vec<Vec<f32>>,
762 _clusters: &[Vec<usize>],
763 _merged_cluster: usize,
764 resources: &[(String, Vector)],
765 ) -> Result<()> {
766 let new_matrix = self.compute_distance_matrix(resources)?;
768 *distance_matrix = new_matrix;
769 Ok(())
770 }
771
772 fn compute_quality_metrics(
774 &self,
775 resources: &[(String, Vector)],
776 clusters: &[Cluster],
777 ) -> Result<ClusteringQualityMetrics> {
778 let mut within_cluster_ss = 0.0;
780 let mut silhouette_scores = Vec::new();
781
782 for cluster in clusters {
783 if cluster.members.len() > 1 {
784 let cluster_vectors: Vec<&Vector> = cluster
785 .members
786 .iter()
787 .filter_map(|member| {
788 resources
789 .iter()
790 .find(|(id, _)| id == member)
791 .map(|(_, v)| v)
792 })
793 .collect();
794
795 if let Some(ref centroid) = cluster.centroid {
796 for vector in &cluster_vectors {
797 let dist = self.calculate_distance(vector, centroid)?;
798 within_cluster_ss += dist * dist;
799 }
800 }
801 }
802 }
803
804 for (cluster_idx, cluster) in clusters.iter().enumerate() {
806 let cluster_vectors: Vec<(usize, &Vector)> = cluster
807 .members
808 .iter()
809 .filter_map(|member| {
810 resources
811 .iter()
812 .enumerate()
813 .find(|(_, (id, _))| id == member)
814 .map(|(idx, (_, v))| (idx, v))
815 })
816 .collect();
817
818 for (point_idx, point_vector) in &cluster_vectors {
820 if cluster_vectors.len() <= 1 {
821 silhouette_scores.push(0.0);
823 continue;
824 }
825
826 let mut intra_cluster_dist = 0.0;
828 let mut intra_count = 0;
829 for (other_idx, other_vector) in &cluster_vectors {
830 if point_idx != other_idx {
831 let dist = self.calculate_distance(point_vector, other_vector)?;
832 intra_cluster_dist += dist;
833 intra_count += 1;
834 }
835 }
836 let a = if intra_count > 0 {
837 intra_cluster_dist / intra_count as f32
838 } else {
839 0.0
840 };
841
842 let mut min_inter_cluster_dist = f32::INFINITY;
844 for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
845 if cluster_idx != other_cluster_idx {
846 let other_cluster_vectors: Vec<&Vector> = other_cluster
847 .members
848 .iter()
849 .filter_map(|member| {
850 resources
851 .iter()
852 .find(|(id, _)| id == member)
853 .map(|(_, v)| v)
854 })
855 .collect();
856
857 if !other_cluster_vectors.is_empty() {
858 let mut inter_cluster_dist = 0.0;
859 for other_vector in &other_cluster_vectors {
860 let dist = self.calculate_distance(point_vector, other_vector)?;
861 inter_cluster_dist += dist;
862 }
863 let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
864 min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
865 }
866 }
867 }
868 let b = min_inter_cluster_dist;
869
870 let silhouette = if a.max(b) > 0.0 {
872 (b - a) / a.max(b)
873 } else {
874 0.0
875 };
876 silhouette_scores.push(silhouette);
877 }
878 }
879
880 let silhouette_score = if !silhouette_scores.is_empty() {
881 silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
882 } else {
883 0.0
884 };
885
886 let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
888
889 let calinski_harabasz_index =
891 self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
892
893 let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
895
896 Ok(ClusteringQualityMetrics {
897 silhouette_score,
898 davies_bouldin_index,
899 calinski_harabasz_index,
900 within_cluster_ss,
901 between_cluster_ss,
902 })
903 }
904
905 fn calculate_davies_bouldin_index(
907 &self,
908 resources: &[(String, Vector)],
909 clusters: &[Cluster],
910 ) -> Result<f32> {
911 if clusters.len() <= 1 {
912 return Ok(0.0);
913 }
914
915 let mut db_sum = 0.0;
916 for i in 0..clusters.len() {
917 let mut max_ratio: f32 = 0.0;
918
919 let cluster_i_vectors: Vec<&Vector> = clusters[i]
921 .members
922 .iter()
923 .filter_map(|member| {
924 resources
925 .iter()
926 .find(|(id, _)| id == member)
927 .map(|(_, v)| v)
928 })
929 .collect();
930
931 if cluster_i_vectors.is_empty() {
932 continue;
933 }
934
935 let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
937
938 let mut avg_dist_i = 0.0;
940 for vector in &cluster_i_vectors {
941 avg_dist_i += self.calculate_distance(vector, ¢roid_i)?;
942 }
943 avg_dist_i /= cluster_i_vectors.len() as f32;
944
945 for (j, cluster_j) in clusters.iter().enumerate() {
946 if i == j {
947 continue;
948 }
949
950 let cluster_j_vectors: Vec<&Vector> = cluster_j
952 .members
953 .iter()
954 .filter_map(|member| {
955 resources
956 .iter()
957 .find(|(id, _)| id == member)
958 .map(|(_, v)| v)
959 })
960 .collect();
961
962 if cluster_j_vectors.is_empty() {
963 continue;
964 }
965
966 let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
968
969 let mut avg_dist_j = 0.0;
971 for vector in &cluster_j_vectors {
972 avg_dist_j += self.calculate_distance(vector, ¢roid_j)?;
973 }
974 avg_dist_j /= cluster_j_vectors.len() as f32;
975
976 let centroid_distance = self.calculate_distance(¢roid_i, ¢roid_j)?;
978
979 if centroid_distance > 0.0 {
981 let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
982 max_ratio = max_ratio.max(ratio);
983 }
984 }
985 db_sum += max_ratio;
986 }
987
988 Ok(db_sum / clusters.len() as f32)
989 }
990
991 fn calculate_calinski_harabasz_index(
993 &self,
994 resources: &[(String, Vector)],
995 clusters: &[Cluster],
996 within_cluster_ss: f32,
997 ) -> Result<f32> {
998 if clusters.len() <= 1 || resources.is_empty() {
999 return Ok(0.0);
1000 }
1001
1002 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1004 let overall_centroid = self.compute_centroid(&all_vectors)?;
1005
1006 let mut between_cluster_ss = 0.0;
1008 for cluster in clusters {
1009 let cluster_vectors: Vec<&Vector> = cluster
1010 .members
1011 .iter()
1012 .filter_map(|member| {
1013 resources
1014 .iter()
1015 .find(|(id, _)| id == member)
1016 .map(|(_, v)| v)
1017 })
1018 .collect();
1019
1020 if !cluster_vectors.is_empty() {
1021 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1022 let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1023 between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
1024 }
1025 }
1026
1027 let k = clusters.len() as f32;
1029 let n = resources.len() as f32;
1030
1031 if k >= n || within_cluster_ss <= 0.0 {
1032 return Ok(0.0);
1033 }
1034
1035 let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
1036 Ok(ch_index)
1037 }
1038
1039 fn calculate_between_cluster_ss(
1041 &self,
1042 resources: &[(String, Vector)],
1043 clusters: &[Cluster],
1044 ) -> Result<f32> {
1045 if clusters.is_empty() || resources.is_empty() {
1046 return Ok(0.0);
1047 }
1048
1049 let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1051 let overall_centroid = self.compute_centroid(&all_vectors)?;
1052
1053 let mut between_cluster_ss = 0.0;
1054 for cluster in clusters {
1055 let cluster_vectors: Vec<&Vector> = cluster
1056 .members
1057 .iter()
1058 .filter_map(|member| {
1059 resources
1060 .iter()
1061 .find(|(id, _)| id == member)
1062 .map(|(_, v)| v)
1063 })
1064 .collect();
1065
1066 if !cluster_vectors.is_empty() {
1067 let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1068 let distance = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1069 between_cluster_ss += cluster_vectors.len() as f32 * distance * distance;
1070 }
1071 }
1072
1073 Ok(between_cluster_ss)
1074 }
1075}
1076
1077impl Default for ClusteringQualityMetrics {
1078 fn default() -> Self {
1079 Self {
1080 silhouette_score: 0.0,
1081 davies_bouldin_index: 0.0,
1082 calinski_harabasz_index: 0.0,
1083 within_cluster_ss: 0.0,
1084 between_cluster_ss: 0.0,
1085 }
1086 }
1087}
1088
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(id, vector)` fixture entry from three components.
    fn resource(id: &str, components: [f32; 3]) -> (String, Vector) {
        (id.to_string(), Vector::new(components.to_vec()))
    }

    /// Two well-separated pairs must yield exactly two k-means clusters and no noise.
    #[test]
    fn test_kmeans_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: Some(2),
            random_seed: Some(42),
            distance_metric: SimilarityMetric::Euclidean,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", [1.0, 1.0, 1.0]),
            resource("res2", [1.1, 1.1, 1.1]),
            resource("res3", [10.0, 10.0, 10.0]),
            resource("res4", [10.1, 10.1, 10.1]),
        ];

        let result = engine.cluster(&resources).unwrap();

        assert_eq!(result.clusters.len(), 2);
        assert!(result.noise.is_empty());
    }

    /// DBSCAN over three points must not report more than two clusters.
    #[test]
    fn test_dbscan_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            similarity_threshold: 0.9,
            min_cluster_size: 2,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", [1.0, 1.0, 1.0]),
            resource("res2", [1.1, 1.1, 1.1]),
            resource("res3", [10.0, 10.0, 10.0]),
        ];

        let result = engine.cluster(&resources).unwrap();
        assert!(result.clusters.len() <= 2);
    }

    /// Mutually orthogonal vectors never reach the threshold, so each one
    /// forms its own cluster.
    #[test]
    fn test_similarity_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::Similarity,
            similarity_threshold: 0.95,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", [1.0, 0.0, 0.0]),
            resource("res2", [0.0, 1.0, 0.0]),
            resource("res3", [0.0, 0.0, 1.0]),
        ];

        let result = engine.cluster(&resources).unwrap();
        assert_eq!(result.clusters.len(), 3);
    }
}