oxirs_vec/
clustering.rs

1//! Advanced clustering algorithms for vector similarity and resource grouping
2//!
3//! This module provides various clustering algorithms for the SPARQL vec:cluster function:
4//! - K-means clustering
5//! - DBSCAN (Density-Based Spatial Clustering)
6//! - Hierarchical clustering (Agglomerative)
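//! - Spectral clustering (currently falls back to k-means)
//! - Community detection for graph clustering (currently falls back to similarity clustering)
//! - Threshold-based similarity clustering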
9
10use crate::{similarity::SimilarityMetric, Vector};
11use anyhow::{anyhow, Result};
12use scirs2_core::random::{Random, Rng};
13use serde::{Deserialize, Serialize};
14use std::collections::VecDeque;
15
16/// Clustering algorithm types
17#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
18pub enum ClusteringAlgorithm {
19    /// K-means clustering
20    KMeans,
21    /// DBSCAN density-based clustering
22    DBSCAN,
23    /// Hierarchical agglomerative clustering
24    Hierarchical,
25    /// Spectral clustering
26    Spectral,
27    /// Community detection (for graph-based clustering)
28    Community,
29    /// Threshold-based similarity clustering
30    Similarity,
31}
32
33/// Clustering configuration
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ClusteringConfig {
36    /// Clustering algorithm to use
37    pub algorithm: ClusteringAlgorithm,
38    /// Number of clusters (for k-means, spectral)
39    pub num_clusters: Option<usize>,
40    /// Similarity threshold (for DBSCAN, similarity clustering)
41    pub similarity_threshold: f32,
42    /// Minimum cluster size (for DBSCAN)
43    pub min_cluster_size: usize,
44    /// Distance metric to use
45    pub distance_metric: SimilarityMetric,
46    /// Maximum iterations (for iterative algorithms)
47    pub max_iterations: usize,
48    /// Random seed for reproducibility
49    pub random_seed: Option<u64>,
50    /// Convergence tolerance
51    pub tolerance: f32,
52    /// Linkage criterion for hierarchical clustering
53    pub linkage: LinkageCriterion,
54}
55
56impl Default for ClusteringConfig {
57    fn default() -> Self {
58        Self {
59            algorithm: ClusteringAlgorithm::KMeans,
60            num_clusters: Some(3),
61            similarity_threshold: 0.7,
62            min_cluster_size: 3,
63            distance_metric: SimilarityMetric::Cosine,
64            max_iterations: 100,
65            random_seed: None,
66            tolerance: 1e-4,
67            linkage: LinkageCriterion::Average,
68        }
69    }
70}
71
72/// Linkage criteria for hierarchical clustering
73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
74pub enum LinkageCriterion {
75    /// Single linkage (minimum distance)
76    Single,
77    /// Complete linkage (maximum distance)
78    Complete,
79    /// Average linkage (average distance)
80    Average,
81    /// Ward linkage (minimize within-cluster variance)
82    Ward,
83}
84
85/// Cluster result
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct Cluster {
88    /// Cluster ID
89    pub id: usize,
90    /// Resource IDs in this cluster
91    pub members: Vec<String>,
92    /// Cluster centroid (if applicable)
93    pub centroid: Option<Vector>,
94    /// Cluster statistics
95    pub stats: ClusterStats,
96}
97
98/// Cluster statistics
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ClusterStats {
101    /// Number of members
102    pub size: usize,
103    /// Average intra-cluster similarity
104    pub avg_intra_similarity: f32,
105    /// Cluster density (for DBSCAN)
106    pub density: f32,
107    /// Silhouette score for this cluster
108    pub silhouette_score: f32,
109}
110
111/// Clustering result
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ClusteringResult {
114    /// All clusters found
115    pub clusters: Vec<Cluster>,
116    /// Noise points (for DBSCAN)
117    pub noise: Vec<String>,
118    /// Overall clustering quality metrics
119    pub quality_metrics: ClusteringQualityMetrics,
120    /// Algorithm used
121    pub algorithm: ClusteringAlgorithm,
122    /// Configuration used
123    pub config: ClusteringConfig,
124}
125
126/// Quality metrics for clustering results
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ClusteringQualityMetrics {
129    /// Silhouette score (-1 to 1, higher is better)
130    pub silhouette_score: f32,
131    /// Davies-Bouldin index (lower is better)
132    pub davies_bouldin_index: f32,
133    /// Calinski-Harabasz index (higher is better)
134    pub calinski_harabasz_index: f32,
135    /// Within-cluster sum of squares
136    pub within_cluster_ss: f32,
137    /// Between-cluster sum of squares
138    pub between_cluster_ss: f32,
139}
140
141/// Main clustering engine
142pub struct ClusteringEngine {
143    config: ClusteringConfig,
144}
145
146impl ClusteringEngine {
147    pub fn new(config: ClusteringConfig) -> Self {
148        Self { config }
149    }
150
151    /// Cluster a set of resources with their embeddings
152    pub fn cluster(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
153        if resources.is_empty() {
154            return Ok(ClusteringResult {
155                clusters: Vec::new(),
156                noise: Vec::new(),
157                quality_metrics: ClusteringQualityMetrics::default(),
158                algorithm: self.config.algorithm,
159                config: self.config.clone(),
160            });
161        }
162
163        match self.config.algorithm {
164            ClusteringAlgorithm::KMeans => self.kmeans_clustering(resources),
165            ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(resources),
166            ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(resources),
167            ClusteringAlgorithm::Spectral => self.spectral_clustering(resources),
168            ClusteringAlgorithm::Community => self.community_detection(resources),
169            ClusteringAlgorithm::Similarity => self.similarity_clustering(resources),
170        }
171    }
172
173    /// K-means clustering implementation
174    fn kmeans_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
175        let k = self.config.num_clusters.unwrap_or(3);
176        if k >= resources.len() {
177            return Err(anyhow!(
178                "Number of clusters must be less than number of resources"
179            ));
180        }
181
182        let mut rng = if let Some(seed) = self.config.random_seed {
183            Random::seed(seed)
184        } else {
185            Random::seed(42)
186        };
187
188        // Initialize centroids using k-means++
189        let mut centroids = self.initialize_centroids_kmeans_plus_plus(resources, k, &mut rng)?;
190        let mut assignments = vec![0; resources.len()];
191        let mut prev_assignments = vec![usize::MAX; resources.len()];
192
193        for iteration in 0..self.config.max_iterations {
194            // Assign points to closest centroids
195            for (i, (_, vector)) in resources.iter().enumerate() {
196                let mut best_cluster = 0;
197                let mut best_distance = f32::INFINITY;
198
199                for (cluster_id, centroid) in centroids.iter().enumerate() {
200                    let distance = self.calculate_distance(vector, centroid)?;
201                    if distance < best_distance {
202                        best_distance = distance;
203                        best_cluster = cluster_id;
204                    }
205                }
206                assignments[i] = best_cluster;
207            }
208
209            // Check for convergence
210            if assignments == prev_assignments {
211                break;
212            }
213
214            // Update centroids
215            for (cluster_id, centroid) in centroids.iter_mut().enumerate().take(k) {
216                let cluster_vectors: Vec<&Vector> = resources
217                    .iter()
218                    .enumerate()
219                    .filter(|(i, _)| assignments[*i] == cluster_id)
220                    .map(|(_, (_, vector))| vector)
221                    .collect();
222
223                if !cluster_vectors.is_empty() {
224                    *centroid = self.compute_centroid(&cluster_vectors)?;
225                }
226            }
227
228            prev_assignments = assignments.clone();
229
230            if iteration > 0 && iteration % 10 == 0 {
231                println!(
232                    "K-means iteration {}/{}",
233                    iteration, self.config.max_iterations
234                );
235            }
236        }
237
238        // Build clusters from assignments
239        let mut clusters = Vec::new();
240        for (cluster_id, centroid) in centroids.iter().enumerate().take(k) {
241            let members: Vec<String> = resources
242                .iter()
243                .enumerate()
244                .filter(|(i, _)| assignments[*i] == cluster_id)
245                .map(|(_, (resource_id, _))| resource_id.clone())
246                .collect();
247
248            if !members.is_empty() {
249                let cluster_vectors: Vec<&Vector> = resources
250                    .iter()
251                    .enumerate()
252                    .filter(|(i, _)| assignments[*i] == cluster_id)
253                    .map(|(_, (_, vector))| vector)
254                    .collect();
255
256                let stats = self.compute_cluster_stats(&cluster_vectors)?;
257
258                clusters.push(Cluster {
259                    id: cluster_id,
260                    members,
261                    centroid: Some(centroid.clone()),
262                    stats,
263                });
264            }
265        }
266
267        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
268
269        Ok(ClusteringResult {
270            clusters,
271            noise: Vec::new(),
272            quality_metrics,
273            algorithm: ClusteringAlgorithm::KMeans,
274            config: self.config.clone(),
275        })
276    }
277
278    /// DBSCAN clustering implementation
279    fn dbscan_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
280        let eps = 1.0 - self.config.similarity_threshold; // Convert similarity to distance
281        let min_pts = self.config.min_cluster_size;
282
283        let mut visited = vec![false; resources.len()];
284        let mut cluster_assignments = vec![None; resources.len()];
285        let mut cluster_id = 0;
286        let mut noise_points = Vec::new();
287
288        for i in 0..resources.len() {
289            if visited[i] {
290                continue;
291            }
292            visited[i] = true;
293
294            let neighbors = self.find_neighbors(resources, i, eps)?;
295
296            if neighbors.len() < min_pts {
297                noise_points.push(resources[i].0.clone());
298            } else {
299                let mut cluster_queue = VecDeque::new();
300                cluster_queue.push_back(i);
301                cluster_assignments[i] = Some(cluster_id);
302
303                while let Some(point_idx) = cluster_queue.pop_front() {
304                    let point_neighbors = self.find_neighbors(resources, point_idx, eps)?;
305
306                    if point_neighbors.len() >= min_pts {
307                        for &neighbor_idx in &point_neighbors {
308                            if !visited[neighbor_idx] {
309                                visited[neighbor_idx] = true;
310                                cluster_queue.push_back(neighbor_idx);
311                            }
312                            if cluster_assignments[neighbor_idx].is_none() {
313                                cluster_assignments[neighbor_idx] = Some(cluster_id);
314                            }
315                        }
316                    }
317                }
318                cluster_id += 1;
319            }
320        }
321
322        // Build clusters from assignments
323        let mut clusters = Vec::new();
324        for cid in 0..cluster_id {
325            let members: Vec<String> = resources
326                .iter()
327                .enumerate()
328                .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
329                .map(|(_, (resource_id, _))| resource_id.clone())
330                .collect();
331
332            if !members.is_empty() {
333                let cluster_vectors: Vec<&Vector> = resources
334                    .iter()
335                    .enumerate()
336                    .filter(|(i, _)| cluster_assignments[*i] == Some(cid))
337                    .map(|(_, (_, vector))| vector)
338                    .collect();
339
340                let stats = self.compute_cluster_stats(&cluster_vectors)?;
341                let centroid = if !cluster_vectors.is_empty() {
342                    Some(self.compute_centroid(&cluster_vectors)?)
343                } else {
344                    None
345                };
346
347                clusters.push(Cluster {
348                    id: cid,
349                    members,
350                    centroid,
351                    stats,
352                });
353            }
354        }
355
356        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
357
358        Ok(ClusteringResult {
359            clusters,
360            noise: noise_points,
361            quality_metrics,
362            algorithm: ClusteringAlgorithm::DBSCAN,
363            config: self.config.clone(),
364        })
365    }
366
367    /// Hierarchical clustering implementation (agglomerative)
368    fn hierarchical_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
369        let target_clusters = self.config.num_clusters.unwrap_or(3);
370
371        // Initialize each point as its own cluster
372        let mut clusters: Vec<Vec<usize>> = (0..resources.len()).map(|i| vec![i]).collect();
373
374        // Compute initial distance matrix
375        let mut distance_matrix = self.compute_distance_matrix(resources)?;
376
377        // Merge clusters until we reach the target number
378        while clusters.len() > target_clusters {
379            let (min_i, min_j) = self.find_closest_clusters(&clusters, &distance_matrix)?;
380
381            // Merge clusters
382            let cluster_j = clusters.remove(min_j.max(min_i));
383            clusters[min_i.min(min_j)].extend(cluster_j);
384
385            // Update distance matrix
386            self.update_distance_matrix(
387                &mut distance_matrix,
388                &clusters,
389                min_i.min(min_j),
390                resources,
391            )?;
392        }
393
394        // Build final cluster results
395        let mut result_clusters = Vec::new();
396        for (cluster_id, cluster_indices) in clusters.iter().enumerate() {
397            let members: Vec<String> = cluster_indices
398                .iter()
399                .map(|&idx| resources[idx].0.clone())
400                .collect();
401
402            let cluster_vectors: Vec<&Vector> = cluster_indices
403                .iter()
404                .map(|&idx| &resources[idx].1)
405                .collect();
406
407            let stats = self.compute_cluster_stats(&cluster_vectors)?;
408            let centroid = if !cluster_vectors.is_empty() {
409                Some(self.compute_centroid(&cluster_vectors)?)
410            } else {
411                None
412            };
413
414            result_clusters.push(Cluster {
415                id: cluster_id,
416                members,
417                centroid,
418                stats,
419            });
420        }
421
422        let quality_metrics = self.compute_quality_metrics(resources, &result_clusters)?;
423
424        Ok(ClusteringResult {
425            clusters: result_clusters,
426            noise: Vec::new(),
427            quality_metrics,
428            algorithm: ClusteringAlgorithm::Hierarchical,
429            config: self.config.clone(),
430        })
431    }
432
433    /// Placeholder for spectral clustering
434    fn spectral_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
435        // For now, fall back to k-means
436        // TODO: Implement proper spectral clustering with eigenvalue decomposition
437        println!("Spectral clustering not yet fully implemented, falling back to k-means");
438        self.kmeans_clustering(resources)
439    }
440
441    /// Placeholder for community detection
442    fn community_detection(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
443        // For now, fall back to similarity clustering
444        // TODO: Implement graph-based community detection algorithms like Louvain
445        println!(
446            "Community detection not yet fully implemented, falling back to similarity clustering"
447        );
448        self.similarity_clustering(resources)
449    }
450
451    /// Simple similarity-based clustering
452    fn similarity_clustering(&self, resources: &[(String, Vector)]) -> Result<ClusteringResult> {
453        let threshold = self.config.similarity_threshold;
454        let mut clusters = Vec::new();
455        let mut assigned = vec![false; resources.len()];
456        let mut cluster_id = 0;
457
458        for i in 0..resources.len() {
459            if assigned[i] {
460                continue;
461            }
462
463            let mut cluster_members = vec![i];
464            assigned[i] = true;
465
466            // Find all similar vectors
467            for j in (i + 1)..resources.len() {
468                if assigned[j] {
469                    continue;
470                }
471
472                let similarity = self.calculate_similarity(&resources[i].1, &resources[j].1)?;
473                if similarity >= threshold {
474                    cluster_members.push(j);
475                    assigned[j] = true;
476                }
477            }
478
479            let members: Vec<String> = cluster_members
480                .iter()
481                .map(|&idx| resources[idx].0.clone())
482                .collect();
483
484            let cluster_vectors: Vec<&Vector> = cluster_members
485                .iter()
486                .map(|&idx| &resources[idx].1)
487                .collect();
488
489            let stats = self.compute_cluster_stats(&cluster_vectors)?;
490            let centroid = if !cluster_vectors.is_empty() {
491                Some(self.compute_centroid(&cluster_vectors)?)
492            } else {
493                None
494            };
495
496            clusters.push(Cluster {
497                id: cluster_id,
498                members,
499                centroid,
500                stats,
501            });
502
503            cluster_id += 1;
504        }
505
506        let quality_metrics = self.compute_quality_metrics(resources, &clusters)?;
507
508        Ok(ClusteringResult {
509            clusters,
510            noise: Vec::new(),
511            quality_metrics,
512            algorithm: ClusteringAlgorithm::Similarity,
513            config: self.config.clone(),
514        })
515    }
516
517    // Helper methods
518
519    /// Initialize centroids using k-means++
520    fn initialize_centroids_kmeans_plus_plus(
521        &self,
522        resources: &[(String, Vector)],
523        k: usize,
524        rng: &mut impl Rng,
525    ) -> Result<Vec<Vector>> {
526        let mut centroids = Vec::new();
527
528        // Choose first centroid randomly
529        let first_idx = rng.gen_range(0..resources.len());
530        centroids.push(resources[first_idx].1.clone());
531
532        // Choose remaining centroids with probability proportional to squared distance
533        for _ in 1..k {
534            let mut distances = Vec::new();
535            let mut total_distance = 0.0;
536
537            for (_, vector) in resources {
538                let min_dist_sq = centroids
539                    .iter()
540                    .map(|centroid| {
541                        self.calculate_distance(vector, centroid)
542                            .unwrap_or(f32::INFINITY)
543                    })
544                    .fold(f32::INFINITY, f32::min)
545                    .powi(2);
546                distances.push(min_dist_sq);
547                total_distance += min_dist_sq;
548            }
549
550            let target = rng.gen::<f32>() * total_distance;
551            let mut cumulative = 0.0;
552
553            for (i, &dist) in distances.iter().enumerate() {
554                cumulative += dist;
555                if cumulative >= target {
556                    centroids.push(resources[i].1.clone());
557                    break;
558                }
559            }
560        }
561
562        Ok(centroids)
563    }
564
565    /// Calculate distance between two vectors
566    fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
567        match self.config.distance_metric {
568            SimilarityMetric::Cosine => Ok(1.0 - v1.cosine_similarity(v2)?),
569            SimilarityMetric::Euclidean => v1.euclidean_distance(v2),
570            SimilarityMetric::Manhattan => v1.manhattan_distance(v2),
571            _ => Ok(1.0 - v1.cosine_similarity(v2)?), // Default to cosine
572        }
573    }
574
575    /// Calculate similarity between two vectors
576    fn calculate_similarity(&self, v1: &Vector, v2: &Vector) -> Result<f32> {
577        match self.config.distance_metric {
578            SimilarityMetric::Cosine => v1.cosine_similarity(v2),
579            SimilarityMetric::Euclidean => {
580                let dist = v1.euclidean_distance(v2)?;
581                Ok(1.0 / (1.0 + dist))
582            }
583            SimilarityMetric::Manhattan => {
584                let dist = v1.manhattan_distance(v2)?;
585                Ok(1.0 / (1.0 + dist))
586            }
587            _ => v1.cosine_similarity(v2), // Default to cosine
588        }
589    }
590
591    /// Find neighbors within distance eps
592    fn find_neighbors(
593        &self,
594        resources: &[(String, Vector)],
595        point_idx: usize,
596        eps: f32,
597    ) -> Result<Vec<usize>> {
598        let mut neighbors = Vec::new();
599        let point = &resources[point_idx].1;
600
601        for (i, (_, vector)) in resources.iter().enumerate() {
602            if i != point_idx {
603                let distance = self.calculate_distance(point, vector)?;
604                if distance <= eps {
605                    neighbors.push(i);
606                }
607            }
608        }
609
610        Ok(neighbors)
611    }
612
613    /// Compute centroid of vectors
614    fn compute_centroid(&self, vectors: &[&Vector]) -> Result<Vector> {
615        if vectors.is_empty() {
616            return Err(anyhow!("Cannot compute centroid of empty vector set"));
617        }
618
619        let dim = vectors[0].dimensions;
620        let mut centroid_data = vec![0.0; dim];
621
622        for vector in vectors {
623            let data = vector.as_f32();
624            for (i, &value) in data.iter().enumerate() {
625                centroid_data[i] += value;
626            }
627        }
628
629        let count = vectors.len() as f32;
630        for value in &mut centroid_data {
631            *value /= count;
632        }
633
634        Ok(Vector::new(centroid_data))
635    }
636
637    /// Compute cluster statistics
638    fn compute_cluster_stats(&self, vectors: &[&Vector]) -> Result<ClusterStats> {
639        if vectors.is_empty() {
640            return Ok(ClusterStats {
641                size: 0,
642                avg_intra_similarity: 0.0,
643                density: 0.0,
644                silhouette_score: 0.0,
645            });
646        }
647
648        let size = vectors.len();
649        let mut total_similarity = 0.0;
650        let mut pair_count = 0;
651
652        // Calculate average intra-cluster similarity
653        for i in 0..vectors.len() {
654            for j in (i + 1)..vectors.len() {
655                let similarity = self.calculate_similarity(vectors[i], vectors[j])?;
656                total_similarity += similarity;
657                pair_count += 1;
658            }
659        }
660
661        let avg_intra_similarity = if pair_count > 0 {
662            total_similarity / pair_count as f32
663        } else {
664            1.0 // Single point cluster
665        };
666
667        Ok(ClusterStats {
668            size,
669            avg_intra_similarity,
670            density: avg_intra_similarity, // Simplified density measure
671            silhouette_score: 0.0, // Simplified for individual cluster stats - use quality metrics for full silhouette score
672        })
673    }
674
675    /// Compute distance matrix for hierarchical clustering
676    fn compute_distance_matrix(&self, resources: &[(String, Vector)]) -> Result<Vec<Vec<f32>>> {
677        let n = resources.len();
678        let mut matrix = vec![vec![0.0; n]; n];
679
680        for i in 0..n {
681            for j in (i + 1)..n {
682                let distance = self.calculate_distance(&resources[i].1, &resources[j].1)?;
683                matrix[i][j] = distance;
684                matrix[j][i] = distance;
685            }
686        }
687
688        Ok(matrix)
689    }
690
691    /// Find closest clusters for hierarchical clustering
692    fn find_closest_clusters(
693        &self,
694        clusters: &[Vec<usize>],
695        distance_matrix: &[Vec<f32>],
696    ) -> Result<(usize, usize)> {
697        let mut min_distance = f32::INFINITY;
698        let mut closest_pair = (0, 1);
699
700        for i in 0..clusters.len() {
701            for j in (i + 1)..clusters.len() {
702                let distance = self.cluster_distance(&clusters[i], &clusters[j], distance_matrix);
703                if distance < min_distance {
704                    min_distance = distance;
705                    closest_pair = (i, j);
706                }
707            }
708        }
709
710        Ok(closest_pair)
711    }
712
713    /// Calculate distance between clusters based on linkage criterion
714    fn cluster_distance(
715        &self,
716        cluster1: &[usize],
717        cluster2: &[usize],
718        distance_matrix: &[Vec<f32>],
719    ) -> f32 {
720        match self.config.linkage {
721            LinkageCriterion::Single => {
722                // Minimum distance
723                cluster1
724                    .iter()
725                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
726                    .fold(f32::INFINITY, f32::min)
727            }
728            LinkageCriterion::Complete => {
729                // Maximum distance
730                cluster1
731                    .iter()
732                    .flat_map(|&i| cluster2.iter().map(move |&j| distance_matrix[i][j]))
733                    .fold(0.0, f32::max)
734            }
735            LinkageCriterion::Average => {
736                // Average distance
737                let mut total = 0.0;
738                let mut count = 0;
739                for &i in cluster1 {
740                    for &j in cluster2 {
741                        total += distance_matrix[i][j];
742                        count += 1;
743                    }
744                }
745                if count > 0 {
746                    total / count as f32
747                } else {
748                    0.0
749                }
750            }
751            LinkageCriterion::Ward => {
752                // Simplified Ward linkage (should consider cluster variance)
753                self.cluster_distance(cluster1, cluster2, distance_matrix)
754            }
755        }
756    }
757
758    /// Update distance matrix after merging clusters
759    fn update_distance_matrix(
760        &self,
761        distance_matrix: &mut Vec<Vec<f32>>,
762        _clusters: &[Vec<usize>],
763        _merged_cluster: usize,
764        resources: &[(String, Vector)],
765    ) -> Result<()> {
766        // Simplified update - could be more efficient
767        let new_matrix = self.compute_distance_matrix(resources)?;
768        *distance_matrix = new_matrix;
769        Ok(())
770    }
771
772    /// Compute clustering quality metrics
773    fn compute_quality_metrics(
774        &self,
775        resources: &[(String, Vector)],
776        clusters: &[Cluster],
777    ) -> Result<ClusteringQualityMetrics> {
778        // Simplified quality metrics - in practice these would be more sophisticated
779        let mut within_cluster_ss = 0.0;
780        let mut silhouette_scores = Vec::new();
781
782        for cluster in clusters {
783            if cluster.members.len() > 1 {
784                let cluster_vectors: Vec<&Vector> = cluster
785                    .members
786                    .iter()
787                    .filter_map(|member| {
788                        resources
789                            .iter()
790                            .find(|(id, _)| id == member)
791                            .map(|(_, v)| v)
792                    })
793                    .collect();
794
795                if let Some(ref centroid) = cluster.centroid {
796                    for vector in &cluster_vectors {
797                        let dist = self.calculate_distance(vector, centroid)?;
798                        within_cluster_ss += dist * dist;
799                    }
800                }
801            }
802        }
803
804        // Calculate silhouette scores for all points
805        for (cluster_idx, cluster) in clusters.iter().enumerate() {
806            let cluster_vectors: Vec<(usize, &Vector)> = cluster
807                .members
808                .iter()
809                .filter_map(|member| {
810                    resources
811                        .iter()
812                        .enumerate()
813                        .find(|(_, (id, _))| id == member)
814                        .map(|(idx, (_, v))| (idx, v))
815                })
816                .collect();
817
818            // For each point in this cluster
819            for (point_idx, point_vector) in &cluster_vectors {
820                if cluster_vectors.len() <= 1 {
821                    // Single point cluster gets silhouette score of 0
822                    silhouette_scores.push(0.0);
823                    continue;
824                }
825
826                // Calculate average distance to other points in same cluster (a)
827                let mut intra_cluster_dist = 0.0;
828                let mut intra_count = 0;
829                for (other_idx, other_vector) in &cluster_vectors {
830                    if point_idx != other_idx {
831                        let dist = self.calculate_distance(point_vector, other_vector)?;
832                        intra_cluster_dist += dist;
833                        intra_count += 1;
834                    }
835                }
836                let a = if intra_count > 0 {
837                    intra_cluster_dist / intra_count as f32
838                } else {
839                    0.0
840                };
841
842                // Calculate minimum average distance to points in other clusters (b)
843                let mut min_inter_cluster_dist = f32::INFINITY;
844                for (other_cluster_idx, other_cluster) in clusters.iter().enumerate() {
845                    if cluster_idx != other_cluster_idx {
846                        let other_cluster_vectors: Vec<&Vector> = other_cluster
847                            .members
848                            .iter()
849                            .filter_map(|member| {
850                                resources
851                                    .iter()
852                                    .find(|(id, _)| id == member)
853                                    .map(|(_, v)| v)
854                            })
855                            .collect();
856
857                        if !other_cluster_vectors.is_empty() {
858                            let mut inter_cluster_dist = 0.0;
859                            for other_vector in &other_cluster_vectors {
860                                let dist = self.calculate_distance(point_vector, other_vector)?;
861                                inter_cluster_dist += dist;
862                            }
863                            let avg_dist = inter_cluster_dist / other_cluster_vectors.len() as f32;
864                            min_inter_cluster_dist = min_inter_cluster_dist.min(avg_dist);
865                        }
866                    }
867                }
868                let b = min_inter_cluster_dist;
869
870                // Calculate silhouette score for this point
871                let silhouette = if a.max(b) > 0.0 {
872                    (b - a) / a.max(b)
873                } else {
874                    0.0
875                };
876                silhouette_scores.push(silhouette);
877            }
878        }
879
880        let silhouette_score = if !silhouette_scores.is_empty() {
881            silhouette_scores.iter().sum::<f32>() / silhouette_scores.len() as f32
882        } else {
883            0.0
884        };
885
886        // Calculate Davies-Bouldin Index
887        let davies_bouldin_index = self.calculate_davies_bouldin_index(resources, clusters)?;
888
889        // Calculate Calinski-Harabasz Index
890        let calinski_harabasz_index =
891            self.calculate_calinski_harabasz_index(resources, clusters, within_cluster_ss)?;
892
893        // Calculate between-cluster sum of squares
894        let between_cluster_ss = self.calculate_between_cluster_ss(resources, clusters)?;
895
896        Ok(ClusteringQualityMetrics {
897            silhouette_score,
898            davies_bouldin_index,
899            calinski_harabasz_index,
900            within_cluster_ss,
901            between_cluster_ss,
902        })
903    }
904
905    /// Calculate Davies-Bouldin Index (lower is better)
906    fn calculate_davies_bouldin_index(
907        &self,
908        resources: &[(String, Vector)],
909        clusters: &[Cluster],
910    ) -> Result<f32> {
911        if clusters.len() <= 1 {
912            return Ok(0.0);
913        }
914
915        let mut db_sum = 0.0;
916        for i in 0..clusters.len() {
917            let mut max_ratio: f32 = 0.0;
918
919            // Get vectors for cluster i
920            let cluster_i_vectors: Vec<&Vector> = clusters[i]
921                .members
922                .iter()
923                .filter_map(|member| {
924                    resources
925                        .iter()
926                        .find(|(id, _)| id == member)
927                        .map(|(_, v)| v)
928                })
929                .collect();
930
931            if cluster_i_vectors.is_empty() {
932                continue;
933            }
934
935            // Calculate centroid for cluster i
936            let centroid_i = self.compute_centroid(&cluster_i_vectors)?;
937
938            // Calculate average distance to centroid for cluster i
939            let mut avg_dist_i = 0.0;
940            for vector in &cluster_i_vectors {
941                avg_dist_i += self.calculate_distance(vector, &centroid_i)?;
942            }
943            avg_dist_i /= cluster_i_vectors.len() as f32;
944
945            for (j, cluster_j) in clusters.iter().enumerate() {
946                if i == j {
947                    continue;
948                }
949
950                // Get vectors for cluster j
951                let cluster_j_vectors: Vec<&Vector> = cluster_j
952                    .members
953                    .iter()
954                    .filter_map(|member| {
955                        resources
956                            .iter()
957                            .find(|(id, _)| id == member)
958                            .map(|(_, v)| v)
959                    })
960                    .collect();
961
962                if cluster_j_vectors.is_empty() {
963                    continue;
964                }
965
966                // Calculate centroid for cluster j
967                let centroid_j = self.compute_centroid(&cluster_j_vectors)?;
968
969                // Calculate average distance to centroid for cluster j
970                let mut avg_dist_j = 0.0;
971                for vector in &cluster_j_vectors {
972                    avg_dist_j += self.calculate_distance(vector, &centroid_j)?;
973                }
974                avg_dist_j /= cluster_j_vectors.len() as f32;
975
976                // Calculate distance between centroids
977                let centroid_distance = self.calculate_distance(&centroid_i, &centroid_j)?;
978
979                // Calculate Davies-Bouldin ratio
980                if centroid_distance > 0.0 {
981                    let ratio: f32 = (avg_dist_i + avg_dist_j) / centroid_distance;
982                    max_ratio = max_ratio.max(ratio);
983                }
984            }
985            db_sum += max_ratio;
986        }
987
988        Ok(db_sum / clusters.len() as f32)
989    }
990
991    /// Calculate Calinski-Harabasz Index (higher is better)
992    fn calculate_calinski_harabasz_index(
993        &self,
994        resources: &[(String, Vector)],
995        clusters: &[Cluster],
996        within_cluster_ss: f32,
997    ) -> Result<f32> {
998        if clusters.len() <= 1 || resources.is_empty() {
999            return Ok(0.0);
1000        }
1001
1002        // Calculate overall centroid
1003        let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1004        let overall_centroid = self.compute_centroid(&all_vectors)?;
1005
1006        // Calculate between-cluster sum of squares
1007        let mut between_cluster_ss = 0.0;
1008        for cluster in clusters {
1009            let cluster_vectors: Vec<&Vector> = cluster
1010                .members
1011                .iter()
1012                .filter_map(|member| {
1013                    resources
1014                        .iter()
1015                        .find(|(id, _)| id == member)
1016                        .map(|(_, v)| v)
1017                })
1018                .collect();
1019
1020            if !cluster_vectors.is_empty() {
1021                let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1022                let distance_sq = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1023                between_cluster_ss += cluster_vectors.len() as f32 * distance_sq * distance_sq;
1024            }
1025        }
1026
1027        // Calinski-Harabasz = (between_cluster_ss / (k-1)) / (within_cluster_ss / (n-k))
1028        let k = clusters.len() as f32;
1029        let n = resources.len() as f32;
1030
1031        if k >= n || within_cluster_ss <= 0.0 {
1032            return Ok(0.0);
1033        }
1034
1035        let ch_index = (between_cluster_ss / (k - 1.0)) / (within_cluster_ss / (n - k));
1036        Ok(ch_index)
1037    }
1038
1039    /// Calculate between-cluster sum of squares
1040    fn calculate_between_cluster_ss(
1041        &self,
1042        resources: &[(String, Vector)],
1043        clusters: &[Cluster],
1044    ) -> Result<f32> {
1045        if clusters.is_empty() || resources.is_empty() {
1046            return Ok(0.0);
1047        }
1048
1049        // Calculate overall centroid
1050        let all_vectors: Vec<&Vector> = resources.iter().map(|(_, v)| v).collect();
1051        let overall_centroid = self.compute_centroid(&all_vectors)?;
1052
1053        let mut between_cluster_ss = 0.0;
1054        for cluster in clusters {
1055            let cluster_vectors: Vec<&Vector> = cluster
1056                .members
1057                .iter()
1058                .filter_map(|member| {
1059                    resources
1060                        .iter()
1061                        .find(|(id, _)| id == member)
1062                        .map(|(_, v)| v)
1063                })
1064                .collect();
1065
1066            if !cluster_vectors.is_empty() {
1067                let cluster_centroid = self.compute_centroid(&cluster_vectors)?;
1068                let distance = self.calculate_distance(&cluster_centroid, &overall_centroid)?;
1069                between_cluster_ss += cluster_vectors.len() as f32 * distance * distance;
1070            }
1071        }
1072
1073        Ok(between_cluster_ss)
1074    }
1075}
1076
1077impl Default for ClusteringQualityMetrics {
1078    fn default() -> Self {
1079        Self {
1080            silhouette_score: 0.0,
1081            davies_bouldin_index: 0.0,
1082            calinski_harabasz_index: 0.0,
1083            within_cluster_ss: 0.0,
1084            between_cluster_ss: 0.0,
1085        }
1086    }
1087}
1088
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kmeans_clustering() {
        // Two tight, well-separated groups along the diagonal. Euclidean
        // distance is used so the separation is measured directly.
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: Some(2),
            distance_metric: SimilarityMetric::Euclidean,
            random_seed: Some(42),
            ..Default::default()
        });

        let resources: Vec<(String, Vector)> = vec![
            ("res1", vec![1.0, 1.0, 1.0]),
            ("res2", vec![1.1, 1.1, 1.1]),
            ("res3", vec![10.0, 10.0, 10.0]),
            ("res4", vec![10.1, 10.1, 10.1]),
        ]
        .into_iter()
        .map(|(id, values)| (id.to_string(), Vector::new(values)))
        .collect();

        let outcome = engine.cluster(&resources).unwrap();

        assert_eq!(outcome.clusters.len(), 2);
        assert!(outcome.noise.is_empty());
    }

    #[test]
    fn test_dbscan_clustering() {
        // All three points lie on the same ray from the origin; only the
        // upper bound on the number of clusters is checked.
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            min_cluster_size: 2,
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let resources: Vec<(String, Vector)> = vec![
            ("res1", vec![1.0, 1.0, 1.0]),
            ("res2", vec![1.1, 1.1, 1.1]),
            ("res3", vec![10.0, 10.0, 10.0]),
        ]
        .into_iter()
        .map(|(id, values)| (id.to_string(), Vector::new(values)))
        .collect();

        let outcome = engine.cluster(&resources).unwrap();
        assert!(outcome.clusters.len() <= 2);
    }

    #[test]
    fn test_similarity_clustering() {
        let engine = ClusteringEngine::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::Similarity,
            similarity_threshold: 0.95,
            ..Default::default()
        });

        // The three unit axes are mutually orthogonal, so none of them is
        // similar enough to share a cluster with another.
        let resources: Vec<(String, Vector)> = vec![
            ("res1", vec![1.0, 0.0, 0.0]),
            ("res2", vec![0.0, 1.0, 0.0]),
            ("res3", vec![0.0, 0.0, 1.0]),
        ]
        .into_iter()
        .map(|(id, values)| (id.to_string(), Vector::new(values)))
        .collect();

        let outcome = engine.cluster(&resources).unwrap();
        assert_eq!(outcome.clusters.len(), 3);
    }
}