oxirs_vec/
clustering.rs

//! Advanced clustering algorithms for vector similarity and resource grouping
//!
//! This module provides the clustering algorithms behind the SPARQL `vec:cluster` function:
//! - K-means clustering
//! - DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
//! - Hierarchical clustering (agglomerative)
//! - Spectral clustering (currently falls back to k-means)
//! - Community detection for graph clustering (currently falls back to similarity clustering)
//! - Threshold-based similarity clustering
9
10use crate::{similarity::SimilarityMetric, Vector};
11use anyhow::{anyhow, Result};
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
16/// Clustering algorithm types
17#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
18pub enum ClusteringAlgorithm {
19    /// K-means clustering
20    KMeans,
21    /// DBSCAN density-based clustering
22    DBSCAN,
23    /// Hierarchical agglomerative clustering
24    Hierarchical,
25    /// Spectral clustering
26    Spectral,
27    /// Community detection (for graph-based clustering)
28    Community,
29    /// Threshold-based similarity clustering
30    Similarity,
31}
32
33/// Clustering configuration
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ClusteringConfig {
36    /// Clustering algorithm to use
37    pub algorithm: ClusteringAlgorithm,
38    /// Number of clusters (for k-means, spectral)
39    pub num_clusters: Option<usize>,
40    /// Similarity threshold (for DBSCAN, similarity clustering)
41    pub similarity_threshold: f32,
42    /// Minimum cluster size (for DBSCAN)
43    pub min_cluster_size: usize,
44    /// Distance metric to use
45    pub distance_metric: SimilarityMetric,
46    /// Maximum iterations (for iterative algorithms)
47    pub max_iterations: usize,
48    /// Random seed for reproducibility
49    pub random_seed: Option<u64>,
50    /// Convergence tolerance
51    pub tolerance: f32,
52    /// Linkage criterion for hierarchical clustering
53    pub linkage: LinkageCriterion,
54}
55
56impl Default for ClusteringConfig {
57    fn default() -> Self {
58        Self {
59            algorithm: ClusteringAlgorithm::KMeans,
60            num_clusters: Some(3),
61            similarity_threshold: 0.7,
62            min_cluster_size: 3,
63            distance_metric: SimilarityMetric::Cosine,
64            max_iterations: 100,
65            random_seed: None,
66            tolerance: 1e-4,
67            linkage: LinkageCriterion::Average,
68        }
69    }
70}
71
72/// Linkage criteria for hierarchical clustering
73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
74pub enum LinkageCriterion {
75    /// Single linkage (minimum distance)
76    Single,
77    /// Complete linkage (maximum distance)
78    Complete,
79    /// Average linkage (average distance)
80    Average,
81    /// Ward linkage (minimize within-cluster variance)
82    Ward,
83}
84
85/// Cluster result
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct Cluster {
88    /// Cluster ID
89    pub id: usize,
90    /// Resource IDs in this cluster
91    pub members: Vec<String>,
92    /// Cluster centroid (if applicable)
93    pub centroid: Option<Vector>,
94    /// Cluster statistics
95    pub stats: ClusterStats,
96}
97
98/// Cluster statistics
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ClusterStats {
101    /// Number of members
102    pub size: usize,
103    /// Average intra-cluster similarity
104    pub avg_intra_similarity: f32,
105    /// Cluster density (for DBSCAN)
106    pub density: f32,
107    /// Silhouette score for this cluster
108    pub silhouette_score: f32,
109}
110
111/// Clustering result
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ClusteringResult {
114    /// All clusters found
115    pub clusters: Vec<Cluster>,
116    /// Noise points (for DBSCAN)
117    pub noise: Vec<String>,
118    /// Overall clustering quality metrics
119    pub quality_metrics: ClusteringQualityMetrics,
120    /// Algorithm used
121    pub algorithm: ClusteringAlgorithm,
122    /// Configuration used
123    pub config: ClusteringConfig,
124}
125
126/// Quality metrics for clustering results
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ClusteringQualityMetrics {
129    /// Silhouette score (-1 to 1, higher is better)
130    pub silhouette_score: f32,
131    /// Davies-Bouldin index (lower is better)
132    pub davies_bouldin_index: f32,
133    /// Calinski-Harabasz index (higher is better)
134    pub calinski_harabasz_index: f32,
135    /// Within-cluster sum of squares
136    pub within_cluster_ss: f32,
137    /// Between-cluster sum of squares
138    pub between_cluster_ss: f32,
139}
140
141/// Main clustering engine
142pub struct ClusteringEngine {
143    config: ClusteringConfig,
144}
145
146impl ClusteringEngine {
147    pub fn new(config: ClusteringConfig) -> Self {
148        Self { config }
149    }
150
151    /// Cluster a set of resources with their embeddings
152    pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
153        if resources.is_empty() {
154            return Ok(ClusteringResult {
155                clusters: Vec::new(),
156                noise: Vec::new(),
157                quality_metrics: ClusteringQualityMetrics::default(),
158                algorithm: self.config.algorithm,
159                config: self.config.clone(),
160            });
161        }
162
163        match self.config.algorithm {
164            ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
165            ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
166            ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
167            ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
168            ClusteringAlgorithm::Community => self.community_detection(resources),
169            ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
170        }
171    }
172
173    /// K-means clustering implementation
174    fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
175        let k = self.config.num_clusters.unwrap_or(3);
176        if k >= resources.len() {
177            return Err(anyhow!(
178                "Number of clusters must be less than number of resources"
179            ));
180        }
181
182        let mut rng = if let Some(seed) = self.config.random_seed {
183            Random::seed(seed)
184        } else {
185            Random::seed(42)
186        };
187
188        // Initialize centroids using k-means++
189        let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
190        let mut assignments = vec![0; resources.len()];
191        let mut prev_assignments = vec![usize::MAX; resources.len()];
192
193        for iteration in 0..self.config.max_iterations {
194            // Assign points to closest centroids
195            for (i, (_, vector)) in resources.iter().enumerate() {
196                let mut best_cluster = 0;
197                let mut best_distance = f32::INFINITY;
198
199                for (cluster_id, centroid) in centroids.iter().enumerate() {
200                    let distance = self.calculate_distance(vector, centroid)?;
201                    if distance < best_distance {
202                        best_distance = distance;
203                        best_cluster = cluster_id;
204                    }
205                }
206                assignments[i] = best_cluster;
207            }
208
209            // Check for convergence
210            if assignments == prev_assignments {
211                break;
212            }
213
214            // Update centroids
215            for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
216                let cluster_vectors: Vec<&Vector> = resources
217                    .iter()
218                    .enumerate()
219                    .filter(|(i, _)| assignments[*i] == cluster_id)
220                    .map(|(_, (_, vector))| vector)
221                    .collect();
222
223                if !cluster_vectors.is_empty() {
224                    *centroid = self.compute_centroid(&cluster_vectors)?;
225                }
226            }
227
228            prev_assignments = assignments.clone();
229
230            if iteration > 0 && iteration % 10 == 0 {
231                println!(
232                    "K-means iteration {}/{}",
233                    iteration, self.config.max_iterations
234                );
235            }
236        }
237
238        // Build clusters from assignments
239        let mut clusters = Vec::new();
240        for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
241            let members: Vec<String> = resources
242                .iter()
243                .enumerate()
244                .filter(|(i, _)| assignments[*i] == cluster_id)
245                .map(|(_, (resource_id, _))| resource_id.clone())
246                .collect();
247
248            if !members.is_empty() {
249                let cluster_vectors: Vec<&Vector> = resources
250                    .iter()
251                    .enumerate()
252                    .filter(|(i, _)| assignments[*i] == cluster_id)
253                    .map(|(_, (_, vector))| vector)
254                    .collect();
255
256                let stats = self.compute_cluster_stats(&cluster_vectors)?;
257
258                clusters.push(Cluster {
259                    id: cluster_id,
260                    members,
261                    centroid: Some(centroid.clone()),
262                    stats,
263                });
264            }
265        }
266
267        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
268
269        Ok(ClusteringResult {
270            clusters,
271            noise: Vec::new(),
272            quality_metrics,
273            algorithm: ClusteringAlgorithm::KMeans,
274            config: self.config.clone(),
275        })
276    }
277
278    /// DBSCAN clustering implementation
279    fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
280        let eps = 1.0 - self.config.similarity_threshold; // Convert similarity to distance
281        let min_pts = self.config.min_cluster_size;
282
283        let mut visited = vec![false; resources.len()];
284        let mut cluster_assignments = vec![None; resources.len()];
285        let mut cluster_id = 0;
286        let mut noise_points = Vec::new();
287
288        for i in 0..resources.len() {
289            if visited[i] {
290                continue;
291            }
292            visited[i] = true;
293
294            let neighbors = self.find_neighbors(resources, i, eps)?;
295
296            if neighbors.len() < min_pts {
297                noise_points.push(resources[i].0.clone());
298            } else {
299                let mut cluster_queue = VecDeque::new();
300                cluster_queue.push_back(i);
301                cluster_assignments[i] = Some(cluster_id);
302
303                while let Some(point_idx) = cluster_queue.pop_front() {
304                    let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
305
306                    if point_neighbors.len() >= min_pts {
307                        for &neighbor_idx in &point_neighbors {
308                            if !visited[neighbor_idx] {
309                                visited[neighbor_idx] = true;
310                                cluster_queue.push_back(neighbor_idx);
311                            }
312                            if cluster_assignments[neighbor_idx].is_none() {
313                                cluster_assignments[neighbor_idx] = Some(cluster_id);
314                            }
315                        }
316                    }
317                }
318                cluster_id += 1;
319            }
320        }
321
322        // Build clusters from assignments
323        let mut clusters = Vec::new();
324        for cid in 0..cluster_id {
325            let members: Vec<String> = resources
326                .iter()
327                .enumerate()
328                .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
329                .map(|(_, (resource_id, _))| resource_id.clone())
330                .collect();
331
332            if !members.is_empty() {
333                let cluster_vectors: Vec<&Vector> = resources
334                    .iter()
335                    .enumerate()
336                    .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
337                    .map(|(_, (_, vector))| vector)
338                    .collect();
339
340                let stats = self.compute_cluster_stats(&cluster_vectors)?;
341                let centroid = if !cluster_vectors.is_empty() {
342                    Some(self.compute_centroid(&cluster_vectors)?)
343                } else {
344                    None
345                };
346
347                clusters.push(Cluster {
348                    id: cid,
349                    members,
350                    centroid,
351                    stats,
352                });
353            }
354        }
355
356        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
357
358        Ok(ClusteringResult {
359            clusters,
360            noise: noise_points,
361            quality_metrics,
362            algorithm: ClusteringAlgorithm::DBSCAN,
363            config: self.config.clone(),
364        })
365    }
366
367    /// Hierarchical clustering implementation (agglomerative)
368    fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
369        let target_clusters = self.config.num_clusters.unwrap_or(3);
370
371        // Initialize each point as its own cluster
372        let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
373
374        // Compute initial distance matrix
375        let mut distance_matrix = self.compute_distance_matrix(resources)?;
376
377        // Merge clusters until we reach the target number
378        while clusters.len() > target_clusters {
379            let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
380
381            // Merge clusters
382            let cluster_j = clusters.remove(min_j.max(min_i));
383            clusters[min_i.min(min_j)].extend(cluster_j);
384
385            // Update distance matrix
386            self.update_distance_matrix(
387                &mut distance_matrix,
388                &clusters,
389                min_i.min(min_j),
390                resources,
391            )?;
392        }
393
394        // Build final cluster results
395        let mut result_clusters = Vec::new();
396        for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
397            let members: Vec<String> = cluster_indices
398                .iter()
399                .map(|&idx| resources[idx].0.clone())
400                .collect();
401
402            let cluster_vectors: Vec<&Vector> = cluster_indices
403                .iter()
404                .map(|&idx| &resources[idx].1)
405                .collect();
406
407            let stats = self.compute_cluster_stats(&cluster_vectors)?;
408            let centroid = if !cluster_vectors.is_empty() {
409                Some(self.compute_centroid(&cluster_vectors)?)
410            } else {
411                None
412            };
413
414            result_clusters.push(Cluster {
415                id: cluster_id,
416                members,
417                centroid,
418                stats,
419            });
420        }
421
422        let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
423
424        Ok(ClusteringResult {
425            clusters: result_clusters,
426            noise: Vec::new(),
427            quality_metrics,
428            algorithm: ClusteringAlgorithm::Hierarchical,
429            config: self.config.clone(),
430        })
431    }
432
433    /// Placeholder for spectral clustering
434    fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
435        // For now, fall back to k-means
436        // TODO: Implement proper spectral clustering with eigenvalue decomposition
437        println!("Spectral clustering not yet fully implemented, falling back to k-means");
438        self.kmeans_clustering(resources)
439    }
440
441    /// Placeholder for community detection
442    fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
443        // For now, fall back to similarity clustering
444        // TODO: Implement graph-based community detection algorithms like Louvain
445        println!(
446            "Community detection not yet fully implemented, falling back to similarity clustering"
447        );
448        self.similarity_clustering(resources)
449    }
450
451    /// Simple similarity-based clustering
452    fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
453        let threshold = self.config.similarity_threshold;
454        let mut clusters = Vec::new();
455        let mut assigned = vec![false; resources.len()];
456        let mut cluster_id = 0;
457
458        for i in 0..resources.len() {
459            if assigned[i] {
460                continue;
461            }
462
463            let mut cluster_members = vec![i];
464            assigned[i] = true;
465
466            // Find all similar vectors
467            for j in (i + 1)..resources.len() {
468                if assigned[j] {
469                    continue;
470                }
471
472                let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
473                if similarity >= threshold {
474                    cluster_members.push(j);
475                    assigned[j] = true;
476                }
477            }
478
479            let members: Vec<String> = cluster_members
480                .iter()
481                .map(|&idx| resources[idx].0.clone())
482                .collect();
483
484            let cluster_vectors: Vec<&Vector> = cluster_members
485                .iter()
486                .map(|&idx| &resources[idx].1)
487                .collect();
488
489            let stats = self.compute_cluster_stats(&cluster_vectors)?;
490            let centroid = if !cluster_vectors.is_empty() {
491                Some(self.compute_centroid(&cluster_vectors)?)
492            } else {
493                None
494            };
495
496            clusters.push(Cluster {
497                id: cluster_id,
498                members,
499                centroid,
500                stats,
501            });
502
503            cluster_id += 1;
504        }
505
506        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
507
508        Ok(ClusteringResult {
509            clusters,
510            noise: Vec::new(),
511            quality_metrics,
512            algorithm: ClusteringAlgorithm::Similarity,
513            config: self.config.clone(),
514        })
515    }
516
517    // Helper methods
518
519    /// Initialize centroids using k-means++
520    #[allow(deprecated)]
521    fn initialize_centroids_kmeans_plus_plus(
522        &self,
523        resources: &[(String, Vector)],
524        k: usize,
525        rng: &mut impl Rng,
526    ) -> Result<Vec<Vector>> {
527        let mut centroids = Vec::new();
528
529        // Choose first centroid randomly
530        let first_idx = rng.gen_range(0..resources.len());
531        centroids.push(resources[first_idx].1.clone());
532
533        // Choose remaining centroids with probability proportional to squared distance
534        for _ in 1..k {
535            let mut distances = Vec::new();
536            let mut total_distance = 0.0;
537
538            for (_, vector) in resources {
539                let min_dist_sq = centroids
540                    .iter()
541                    .map(|centroid| {
542                        self.calculate_distance(vector, centroid)
543                            .unwrap_or(f32::INFINITY)
544                    })
545                    .fold(f32::INFINITY, f32::min)
546                    .powi(2);
547                distances.push(min_dist_sq);
548                total_distance += min_dist_sq;
549            }
550
551            let target = rng.random::<f32>() * total_distance;
552            let mut cumulative = 0.0;
553
554            for (i, &dist) in distances.iter().enumerate() {
555                cumulative += dist;
556                if cumulative >= target {
557                    centroids.push(resources[i].1.clone());
558                    break;
559                }
560            }
561        }
562
563        Ok(centroids)
564    }
565
566    /// Calculate distance between two vectors
567    fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
568        match self.config.distance_metric {
569            SimilarityMetric::Cosine => Ok(1.0 - v1.cosine_similarity(v2)?),
570            SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
571            SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
572            _ => Ok(1.0 - v1.cosine_similarity(v2)?), // Default to cosine
573        }
574    }
575
576    /// Calculate similarity between two vectors
577    fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
578        match self.config.distance_metric {
579            SimilarityMetric::Cosine => v1.cosine_similarity(v2),
580            SimilarityMetric::Euclidean => {
581                let dist = v1.euclidean_distance(v2)?;
582                Ok(1.0 / (1.0 + dist))
583            }
584            SimilarityMetric::Manhattan => {
585                let dist = v1.manhattan_distance(v2)?;
586                Ok(1.0 / (1.0 + dist))
587            }
588            _ => v1.cosine_similarity(v2), // Default to cosine
589        }
590    }
591
592    /// Find neighbors within distance eps
593    fn find_neighbors(
594        &self,
595        resources: &[(String, Vector)],
596        point_idx: usize,
597        eps: f32,
598    ) -> Result<Vec<usize>> {
599        let mut neighbors = Vec::new();
600        let point = &resources[point_idx].1;
601
602        for (i, (_, vector)) in resources.iter().enumerate() {
603            if i != point_idx {
604                let distance = self.calculate_distance(point, vector)?;
605                if distance <= eps {
606                    neighbors.push(i);
607                }
608            }
609        }
610
611        Ok(neighbors)
612    }
613
614    /// Compute centroid of vectors
615    fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
616        if vectors.is_empty() {
617            return Err(anyhow!("Cannot compute centroid of empty vector set"));
618        }
619
620        let dim = vectors[0].dimensions;
621        let mut centroid_data = vec![0.0; dim];
622
623        for vector in vectors {
624            let data = vector.as_f32();
625            for (i, &value) in data.iter().enumerate() {
626                centroid_data[i] += value;
627            }
628        }
629
630        let count = vectors.len() as f32;
631        for value in &mut centroid_data {
632            *value /= count;
633        }
634
635        Ok(Vector::new(centroid_data))
636    }
637
638    /// Compute cluster statistics
639    fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
640        if vectors.is_empty() {
641            return Ok(ClusterStats {
642                size: 0,
643                avg_intra_similarity: 0.0,
644                density: 0.0,
645                silhouette_score: 0.0,
646            });
647        }
648
649        let size = vectors.len();
650        let mut total_similarity = 0.0;
651        let mut pair_count = 0;
652
653        // Calculate average intra-cluster similarity
654        for i in 0..vectors.len() {
655            for j in (i + 1)..vectors.len() {
656                let similarity = self.calculate_similarity(vectors[i], vectors[j])?;
657                total_similarity += similarity;
658                pair_count += 1;
659            }
660        }
661
662        let avg_intra_similarity = if pair_count > 0 {
663            total_similarity / pair_count as f32
664        } else {
665            1.0 // Single point cluster
666        };
667
668        Ok(ClusterStats {
669            size,
670            avg_intra_similarity,
671            density: avg_intra_similarity, // Simplified density measure
672            silhouette_score: 0.0, // Simplified for individual cluster stats - use quality metrics for full silhouette score
673        })
674    }
675
676    /// Compute distance matrix for hierarchical clustering
677    fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
678        let n = resources.len();
679        let mut matrix = vec![vec![0.0; n]; n];
680
681        for i in 0..n {
682            for j in (i + 1)..n {
683                let distance = self.calculate_distance(&resources[i].1, &resources[j].1)?;
684                matrix[i][j] = distance;
685                matrix[j][i] = distance;
686            }
687        }
688
689        Ok(matrix)
690    }
691
692    /// Find closest clusters for hierarchical clustering
693    fn find_closest_clusters(
694        &self,
695        clusters: &[Vec<usize>],
696        distance_matrix: &[Vec<f32>],
697    ) -> Result<(usize, usize)> {
698        let mut min_distance = f32::INFINITY;
699        let mut closest_pair = (0, 1);
700
701        for i in 0..clusters.len() {
702            for j in (i + 1)..clusters.len() {
703                let distance = self.cluster_distance(&clusters[i], &clusters[j], distance_matrix);
704                if distance < min_distance {
705                    min_distance = distance;
706                    closest_pair = (i, j);
707                }
708            }
709        }
710
711        Ok(closest_pair)
712    }
713
714    /// Calculate distance between clusters based on linkage criterion
715    fn cluster_distance(
716        &self,
717        cluster1: &[usize],
718        cluster2: &[usize],
719        distance_matrix: &[Vec<f32>],
720    ) -> f32 {
721        match self.config.linkage {
722            LinkageCriterion::Single => {
723                // Minimum distance
724                cluster1
725                    .iter()
726                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
727                    .fold(f32::INFINITY, f32::min)
728            }
729            LinkageCriterion::Complete => {
730                // Maximum distance
731                cluster1
732                    .iter()
733                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
734                    .fold(0.0, f32::max)
735            }
736            LinkageCriterion::Average => {
737                // Average distance
738                let mut total = 0.0;
739                let mut count = 0;
740                for &i in cluster1 {
741                    for &j in cluster2 {
742                        total += distance_matrix[i][j];
743                        count += 1;
744                    }
745                }
746                if count > 0 {
747                    total / count as f32
748                } else {
749                    0.0
750                }
751            }
752            LinkageCriterion::Ward => {
753                // Simplified Ward linkage (should consider cluster variance)
754                self.cluster_distance(cluster1, cluster2, distance_matrix)
755            }
756        }
757    }
758
759    /// Update distance matrix after merging clusters
760    fn update_distance_matrix(
761        &self,
762        distance_matrix: &mut Vec<Vec<f32>>,
763        _clusters: &[Vec<usize>],
764        _merged_cluster: usize,
765        resources: &[(String, Vector)],
766    ) -> Result<()> {
767        // Simplified update - could be more efficient
768        let new_matrix = self.compute_distance_matrix(resources)?;
769        *distance_matrix = new_matrix;
770        Ok(())
771    }
772
773    /// Compute clustering quality metrics
774    fn compute_quality_metrics(
775        &self,
776        resources: &[(String, Vector)],
777        clusters: &[Cluster],
778    ) -> Result<ClusteringQualityMetrics> {
779        // Simplified quality metrics - in practice these would be more sophisticated
780        let mut within_cluster_ss = 0.0;
781        let mut silhouette_scores = Vec::new();
782
783        for cluster in clusters {
784            if cluster.members.len() > 1 {
785                let cluster_vectors: Vec<&Vector> = cluster
786                    .members
787                    .iter()
788                    .filter_map(|member| {
789                        resources
790                            .iter()
791                            .find(|(id, _)| id == member)
792                            .map(|(_, v)| v)
793                    })
794                    .collect();
795
796                if let Some(ref centroid) = cluster.centroid {
797                    for vector in &cluster_vectors {
798                        let dist = self.calculate_distance(vector, centroid)?;
799                        within_cluster_ss += dist * dist;
800                    }
801                }
802            }
803        }
804
805        // Calculate silhouette scores for all points
806        for (cluster_idx, cluster) in clusters.iter().enumerate() {
807            let cluster_vectors: Vec<(usize, &Vector)> = cluster
808                .members
809                .iter()
810                .filter_map(|member| {
811                    resources
812                        .iter()
813                        .enumerate()
814                        .find(|(_, (id, _))| id == member)
815                        .map(|(idx, (_, v))| (idx, v))
816                })
817                .collect();
818
819            // For each point in this cluster
820            for (point_idx, point_vector) in &cluster_vectors {
821                if cluster_vectors.len() <= 1 {
822                    // Single point cluster gets silhouette score of 0
823                    silhouette_scores.push(0.0);
824                    continue;
825                }
826
827                // Calculate average distance to other points in same cluster (a)
828                let mut intra_cluster_dist = 0.0;
829                let mut intra_count = 0;
830                for (other_idx, other_vector) in &cluster_vectors {
831                    if point_idx != other_idx {
832                        let dist = self.calculate_distance(point_vector, other_vector)?;
833                        intra_cluster_dist += dist;
834                        intra_count += 1;
835                    }
836                }
837                let a = if intra_count > 0 {
838                    intra_cluster_dist / intra_count as f32
839                } else {
840                    0.0
841                };
842
843                // Calculate minimum average distance to points in other clusters (b)
844                let mut min_inter_cluster_dist = f32::INFINITY;
845                for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
846                    if cluster_idx != other_cluster_idx {
847                        let other_cluster_vectors: Vec<&Vector> = other_cluster
848                            .members
849                            .iter()
850                            .filter_map(|member| {
851                                resources
852                                    .iter()
853                                    .find(|(id, _)| id == member)
854                                    .map(|(_, v)| v)
855                            })
856                            .collect();
857
858                        if !other_cluster_vectors.is_empty() {
859                            let mut inter_cluster_dist = 0.0;
860                            for other_vector in &other_cluster_vectors {
861                                let dist = self.calculate_distance(point_vector, other_vector)?;
862                                inter_cluster_dist += dist;
863                            }
864                            let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
865                            min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
866                        }
867                    }
868                }
869                let b = min_inter_cluster_dist;
870
871                // Calculate silhouette score for this point
872                let silhouette = if a.max(b) > 0.0 {
873                    (b - a) / a.max(b)
874                } else {
875                    0.0
876                };
877                silhouette_scores.push(silhouette);
878            }
879        }
880
881        let silhouette_score = if !silhouette_scores.is_empty() {
882            silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
883        } else {
884            0.0
885        };
886
887        // Calculate Davies-Bouldin Index
888        let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
889
890        // Calculate Calinski-Harabasz Index
891        let calinski_harabasz_index =
892            self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
893
894        // Calculate between-cluster sum of squares
895        let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
896
897        Ok(ClusteringQualityMetrics {
898            silhouette_score,
899            davies_bouldin_index,
900            calinski_harabasz_index,
901            within_cluster_ss,
902            between_cluster_ss,
903        })
904    }
905
906    /// Calculate Davies-Bouldin Index (lower is better)
907    fn calculate_davies_bouldin_index(
908        &self,
909        resources: &[(String, Vector)],
910        clusters: &[Cluster],
911    ) -> Result<f32> {
912        if clusters.len() <= 1 {
913            return Ok(0.0);
914        }
915
916        let mut db_sum = 0.0;
917        for i in 0..clusters.len() {
918            let mut max_ratio: f32 = 0.0;
919
920            // Get vectors for cluster i
921            let cluster_i_vectors: Vec<&Vector> = clusters[i]
922                .members
923                .iter()
924                .filter_map(|member| {
925                    resources
926                        .iter()
927                        .find(|(id, _)| id == member)
928                        .map(|(_, v)| v)
929                })
930                .collect();
931
932            if cluster_i_vectors.is_empty() {
933                continue;
934            }
935
936            // Calculate centroid for cluster i
937            let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
938
939            // Calculate average distance to centroid for cluster i
940            let mut avg_dist_i = 0.0;
941            for vector in &cluster_i_vectors {
942                avg_dist_i += self.calculate_distance(vector, &centroid_i)?;
943            }
944            avg_dist_i /= cluster_i_vectors.len() as f32;
945
946            for (j, cluster_j) in clusters.iter().enumerate() {
947                if i == j {
948                    continue;
949                }
950
951                // Get vectors for cluster j
952                let cluster_j_vectors: Vec<&Vector> = cluster_j
953                    .members
954                    .iter()
955                    .filter_map(|member| {
956                        resources
957                            .iter()
958                            .find(|(id, _)| id == member)
959                            .map(|(_, v)| v)
960                    })
961                    .collect();
962
963                if cluster_j_vectors.is_empty() {
964                    continue;
965                }
966
967                // Calculate centroid for cluster j
968                let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
969
970                // Calculate average distance to centroid for cluster j
971                let mut avg_dist_j = 0.0;
972                for vector in &cluster_j_vectors {
973                    avg_dist_j += self.calculate_distance(vector, &centroid_j)?;
974                }
975                avg_dist_j /= cluster_j_vectors.len() as f32;
976
977                // Calculate distance between centroids
978                let centroid_distance = self.calculate_distance(&centroid_i, &centroid_j)?;
979
980                // Calculate Davies-Bouldin ratio
981                if centroid_distance > 0.0 {
982                    let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
983                    max_ratio = max_ratio.max(ratio);
984                }
985            }
986            db_sum += max_ratio;
987        }
988
989        Ok(db_sum / clusters.len() as f32)
990    }
991
992    /// Calculate Calinski-Harabasz Index (higher is better)
993    fn calculate_calinski_harabasz_index(
994        &self,
995        resources: &[(String, Vector)],
996        clusters: &[Cluster],
997        within_cluster_ss: f32,
998    ) -> Result<f32> {
999        if clusters.len() <= 1 || resources.is_empty() {
1000            return Ok(0.0);
1001        }
1002
1003        // Calculate overall centroid
1004        let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1005        let overall_centroid = self.compute_centroid(&all_vectors)?;
1006
1007        // Calculate between-cluster sum of squares
1008        let mut between_cluster_ss = 0.0;
1009        for cluster in clusters {
1010            let cluster_vectors: Vec<&Vector> = cluster
1011                .members
1012                .iter()
1013                .filter_map(|member| {
1014                    resources
1015                        .iter()
1016                        .find(|(id, _)| id == member)
1017                        .map(|(_, v)| v)
1018                })
1019                .collect();
1020
1021            if !cluster_vectors.is_empty() {
1022                let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1023                let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1024                between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
1025            }
1026        }
1027
1028        // Calinski-Harabasz = (between_cluster_ss / (k-1)) / (within_cluster_ss / (n-k))
1029        let k = clusters.len() as f32;
1030        let n = resources.len() as f32;
1031
1032        if k >= n || within_cluster_ss <= 0.0 {
1033            return Ok(0.0);
1034        }
1035
1036        let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
1037        Ok(ch_index)
1038    }
1039
1040    /// Calculate between-cluster sum of squares
1041    fn calculate_between_cluster_ss(
1042        &self,
1043        resources: &[(String, Vector)],
1044        clusters: &[Cluster],
1045    ) -> Result<f32> {
1046        if clusters.is_empty() || resources.is_empty() {
1047            return Ok(0.0);
1048        }
1049
1050        // Calculate overall centroid
1051        let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1052        let overall_centroid = self.compute_centroid(&all_vectors)?;
1053
1054        let mut between_cluster_ss = 0.0;
1055        for cluster in clusters {
1056            let cluster_vectors: Vec<&Vector> = cluster
1057                .members
1058                .iter()
1059                .filter_map(|member| {
1060                    resources
1061                        .iter()
1062                        .find(|(id, _)| id == member)
1063                        .map(|(_, v)| v)
1064                })
1065                .collect();
1066
1067            if !cluster_vectors.is_empty() {
1068                let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1069                let distance = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1070                between_cluster_ss += cluster_vectors.len() as f32 * distance * distance;
1071            }
1072        }
1073
1074        Ok(between_cluster_ss)
1075    }
1076}
1077
1078impl Default for ClusteringQualityMetrics {
1079    fn default() -> Self {
1080        Self {
1081            silhouette_score: 0.0,
1082            davies_bouldin_index: 0.0,
1083            calinski_harabasz_index: 0.0,
1084            within_cluster_ss: 0.0,
1085            between_cluster_ss: 0.0,
1086        }
1087    }
1088}
1089
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(resource id, vector)` fixture entry.
    fn resource(id: &str, components: Vec<f32>) -> (String, Vector) {
        (id.to_string(), Vector::new(components))
    }

    /// K-means with two requested clusters on two well-separated pairs of
    /// points must yield two clusters and no noise.
    #[test]
    fn test_kmeans_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: Some(2),
            random_seed: Some(42),
            // Use Euclidean for proper distance calculation
            distance_metric: SimilarityMetric::Euclidean,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", vec![1.0, 1.0, 1.0]),
            resource("res2", vec![1.1, 1.1, 1.1]),
            resource("res3", vec![10.0, 10.0, 10.0]),
            resource("res4", vec![10.1, 10.1, 10.1]),
        ];

        let outcome = engine.cluster(&resources).unwrap();

        assert_eq!(outcome.clusters.len(), 2);
        assert!(outcome.noise.is_empty());
    }

    /// DBSCAN on three points must not report more than two clusters.
    #[test]
    fn test_dbscan_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            similarity_threshold: 0.9,
            min_cluster_size: 2,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", vec![1.0, 1.0, 1.0]),
            resource("res2", vec![1.1, 1.1, 1.1]),
            resource("res3", vec![10.0, 10.0, 10.0]),
        ];

        let outcome = engine.cluster(&resources).unwrap();
        assert!(outcome.clusters.len() <= 2);
    }

    /// Threshold-based similarity clustering keeps mutually orthogonal
    /// vectors in separate clusters.
    #[test]
    fn test_similarity_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::Similarity,
            similarity_threshold: 0.95,
            ..Default::default()
        });

        let resources = vec![
            resource("res1", vec![1.0, 0.0, 0.0]),
            resource("res2", vec![0.0, 1.0, 0.0]),
            resource("res3", vec![0.0, 0.0, 1.0]),
        ];

        let outcome = engine.cluster(&resources).unwrap();
        // Should have 3 clusters since vectors are orthogonal
        assert_eq!(outcome.clusters.len(), 3);
    }
}