Skip to main content

quantrs2_ml/clustering/
core.rs

1//! Core quantum clustering functionality
2
3use crate::dimensionality_reduction::QuantumDistanceMetric;
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
6
7use super::config::*;
8
9/// Clustering result containing labels and metadata
10#[derive(Debug, Clone)]
11pub struct ClusteringResult {
12    /// Cluster labels for each data point
13    pub labels: Array1<usize>,
14    /// Number of clusters found
15    pub n_clusters: usize,
16    /// Cluster centers (if available)
17    pub cluster_centers: Option<Array2<f64>>,
18    /// Inertia/within-cluster sum of squares (if available)
19    pub inertia: Option<f64>,
20    /// Cluster probabilities (for soft clustering)
21    pub probabilities: Option<Array2<f64>>,
22}
23
24/// Main quantum clusterer
25#[derive(Debug)]
26pub struct QuantumClusterer {
27    config: QuantumClusteringConfig,
28    cluster_centers: Option<Array2<f64>>,
29    labels: Option<Array1<usize>>,
30    // Algorithm-specific configurations
31    pub kmeans_config: Option<QuantumKMeansConfig>,
32    pub dbscan_config: Option<QuantumDBSCANConfig>,
33    pub spectral_config: Option<QuantumSpectralConfig>,
34    pub fuzzy_config: Option<QuantumFuzzyCMeansConfig>,
35    pub gmm_config: Option<QuantumGMMConfig>,
36}
37
38impl QuantumClusterer {
39    /// Create new quantum clusterer
40    pub fn new(config: QuantumClusteringConfig) -> Self {
41        Self {
42            config,
43            cluster_centers: None,
44            labels: None,
45            kmeans_config: None,
46            dbscan_config: None,
47            spectral_config: None,
48            fuzzy_config: None,
49            gmm_config: None,
50        }
51    }
52
53    /// Create quantum K-means clusterer
54    pub fn kmeans(config: QuantumKMeansConfig) -> Self {
55        let mut clusterer = Self::new(QuantumClusteringConfig {
56            algorithm: ClusteringAlgorithm::QuantumKMeans,
57            n_clusters: config.n_clusters,
58            max_iterations: config.max_iterations,
59            tolerance: config.tolerance,
60            num_qubits: 4,
61            random_state: config.seed,
62        });
63        clusterer.kmeans_config = Some(config);
64        clusterer
65    }
66
67    /// Create quantum DBSCAN clusterer
68    pub fn dbscan(config: QuantumDBSCANConfig) -> Self {
69        let mut clusterer = Self::new(QuantumClusteringConfig {
70            algorithm: ClusteringAlgorithm::QuantumDBSCAN,
71            n_clusters: 0, // DBSCAN determines clusters automatically
72            max_iterations: 100,
73            tolerance: 1e-4,
74            num_qubits: 4,
75            random_state: config.seed,
76        });
77        clusterer.dbscan_config = Some(config);
78        clusterer
79    }
80
81    /// Create quantum spectral clusterer
82    pub fn spectral(config: QuantumSpectralConfig) -> Self {
83        let mut clusterer = Self::new(QuantumClusteringConfig {
84            algorithm: ClusteringAlgorithm::QuantumSpectral,
85            n_clusters: config.n_clusters,
86            max_iterations: 100,
87            tolerance: 1e-4,
88            num_qubits: 4,
89            random_state: config.seed,
90        });
91        clusterer.spectral_config = Some(config);
92        clusterer
93    }
94
95    /// Compute squared Euclidean distance between two array views
96    fn squared_dist(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
97        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
98    }
99
100    /// Iterative union-find with path halving (no recursion)
101    fn uf_find(parent: &mut [usize], mut x: usize) -> usize {
102        while parent[x] != x {
103            // Path compression by halving
104            parent[x] = parent[parent[x]];
105            x = parent[x];
106        }
107        x
108    }
109
110    /// Run Lloyd's k-means algorithm with k-means++ initialization.
111    ///
112    /// Returns `(cluster_centers, labels, inertia)`.
113    fn run_kmeans(
114        &self,
115        data: &Array2<f64>,
116        k: usize,
117    ) -> Result<(Array2<f64>, Array1<usize>, f64)> {
118        let n_samples = data.nrows();
119        let n_features = data.ncols();
120        let max_iter = self.config.max_iterations;
121
122        // -----------------------------------------------------------------------
123        // k-means++ initialisation
124        // First center: deterministic – row 0, or seeded via random_state.
125        // Subsequent centers: greedy furthest-point (deterministic, avoids RNG).
126        // -----------------------------------------------------------------------
127        let mut centers = Array2::<f64>::zeros((k, n_features));
128
129        // Choose first center
130        let first_idx = self
131            .config
132            .random_state
133            .map(|s| (s as usize) % n_samples)
134            .unwrap_or(0);
135        centers.row_mut(0).assign(&data.row(first_idx));
136
137        // k-means++ subsequent centers
138        for c in 1..k {
139            // For each sample, compute minimum squared distance to any chosen center so far
140            let mut min_dists_sq = vec![f64::INFINITY; n_samples];
141            for i in 0..n_samples {
142                for prev_c in 0..c {
143                    let d = self.squared_dist(&data.row(i), &centers.row(prev_c));
144                    if d < min_dists_sq[i] {
145                        min_dists_sq[i] = d;
146                    }
147                }
148            }
149            // Greedy deterministic choice: the sample farthest from all current centers
150            let next_idx = min_dists_sq
151                .iter()
152                .enumerate()
153                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
154                .map(|(i, _)| i)
155                .unwrap_or(c % n_samples);
156            centers.row_mut(c).assign(&data.row(next_idx));
157        }
158
159        // -----------------------------------------------------------------------
160        // Lloyd's iterations
161        // -----------------------------------------------------------------------
162        let mut labels = vec![0usize; n_samples];
163
164        for _iter in 0..max_iter {
165            // ----- Assignment step -----
166            let mut changed = false;
167            for i in 0..n_samples {
168                let mut best_c = 0;
169                let mut best_d = f64::INFINITY;
170                for c in 0..k {
171                    let d = self.squared_dist(&data.row(i), &centers.row(c));
172                    if d < best_d {
173                        best_d = d;
174                        best_c = c;
175                    }
176                }
177                if labels[i] != best_c {
178                    changed = true;
179                    labels[i] = best_c;
180                }
181            }
182
183            // ----- Update step -----
184            let mut new_centers = Array2::<f64>::zeros((k, n_features));
185            let mut counts = vec![0usize; k];
186            for i in 0..n_samples {
187                let c = labels[i];
188                new_centers.row_mut(c).scaled_add(1.0, &data.row(i));
189                counts[c] += 1;
190            }
191            for c in 0..k {
192                if counts[c] > 0 {
193                    new_centers
194                        .row_mut(c)
195                        .mapv_inplace(|v| v / counts[c] as f64);
196                } else {
197                    // Empty cluster: reassign center to a guaranteed occupied data point
198                    new_centers.row_mut(c).assign(&data.row(c % n_samples));
199                }
200            }
201            centers = new_centers;
202
203            if !changed {
204                break;
205            }
206        }
207
208        // -----------------------------------------------------------------------
209        // Compute inertia (within-cluster sum of squared distances)
210        // -----------------------------------------------------------------------
211        let mut inertia = 0.0f64;
212        for i in 0..n_samples {
213            inertia += self.squared_dist(&data.row(i), &centers.row(labels[i]));
214        }
215
216        let labels_arr = Array1::from_iter(labels);
217        Ok((centers, labels_arr, inertia))
218    }
219
220    /// Density-based cluster counting using union-find over the epsilon neighbourhood.
221    ///
222    /// Uses `dbscan_config.eps` and `dbscan_config.min_samples` when available,
223    /// falling back to sensible defaults derived from the data spread.
224    fn fit_dbscan(&self, data: &Array2<f64>) -> Result<usize> {
225        let n = data.nrows();
226
227        let (eps, min_samples) = if let Some(cfg) = &self.dbscan_config {
228            (cfg.eps, cfg.min_samples)
229        } else {
230            // Estimate eps as ~10 % of the bounding-box diagonal
231            let mut max_sq = 0.0f64;
232            for i in 0..n {
233                for j in (i + 1)..n {
234                    let d = self.squared_dist(&data.row(i), &data.row(j));
235                    if d > max_sq {
236                        max_sq = d;
237                    }
238                }
239            }
240            (max_sq.sqrt() * 0.1, 2usize)
241        };
242
243        // Union-find initialisation
244        let mut parent: Vec<usize> = (0..n).collect();
245
246        for i in 0..n {
247            let mut neighbor_count = 0usize;
248            for j in 0..n {
249                if i == j {
250                    continue;
251                }
252                let d = self.squared_dist(&data.row(i), &data.row(j)).sqrt();
253                if d <= eps {
254                    neighbor_count += 1;
255                    // Union i and j
256                    let pi = Self::uf_find(&mut parent, i);
257                    let pj = Self::uf_find(&mut parent, j);
258                    if pi != pj {
259                        parent[pi] = pj;
260                    }
261                }
262            }
263            // Points with fewer than min_samples neighbours remain noise (own root)
264            let _ = neighbor_count;
265        }
266
267        // Count distinct roots – each root represents one cluster
268        let n_clusters = (0..n)
269            .filter(|&i| Self::uf_find(&mut parent, i) == i)
270            .count();
271
272        Ok(n_clusters.max(1))
273    }
274
275    /// Fit the clustering model using Lloyd's k-means with k-means++ initialization.
276    pub fn fit(&mut self, data: &Array2<f64>) -> Result<ClusteringResult> {
277        let n_samples = data.nrows();
278
279        if n_samples == 0 {
280            return Err(MLError::InvalidInput("Empty data".to_string()));
281        }
282
283        // Determine the target number of clusters
284        let n_clusters = match self.config.algorithm {
285            ClusteringAlgorithm::QuantumDBSCAN => {
286                // DBSCAN determines clusters from density
287                let auto_k = self.fit_dbscan(data)?;
288                auto_k
289            }
290            _ => {
291                // Use configured n_clusters, capped to available samples
292                self.config.n_clusters.min(n_samples).max(1)
293            }
294        };
295
296        // Run Lloyd's k-means (with k-means++ init) over the chosen k
297        let (cluster_centers, labels, inertia) = self.run_kmeans(data, n_clusters)?;
298
299        self.cluster_centers = Some(cluster_centers.clone());
300        self.labels = Some(labels.clone());
301
302        Ok(ClusteringResult {
303            labels,
304            n_clusters,
305            cluster_centers: Some(cluster_centers),
306            inertia: Some(inertia),
307            probabilities: None,
308        })
309    }
310
311    /// Predict cluster labels for new data by assigning to the nearest center.
312    pub fn predict(&self, data: &Array2<f64>) -> Result<Array1<usize>> {
313        let centers = self.cluster_centers.as_ref().ok_or_else(|| {
314            MLError::ModelNotTrained("Clusterer must be fitted before predict".to_string())
315        })?;
316
317        let k = centers.nrows();
318        let labels: Vec<usize> = (0..data.nrows())
319            .map(|i| {
320                let mut best_c = 0;
321                let mut best_d = f64::INFINITY;
322                for c in 0..k {
323                    let d = self.squared_dist(&data.row(i), &centers.row(c));
324                    if d < best_d {
325                        best_d = d;
326                        best_c = c;
327                    }
328                }
329                best_c
330            })
331            .collect();
332
333        Ok(Array1::from_iter(labels))
334    }
335
336    /// Predict cluster probabilities (for soft clustering)
337    pub fn predict_proba(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
338        if self.cluster_centers.is_none() {
339            return Err(MLError::ModelNotTrained(
340                "Clusterer must be fitted before predict_proba".to_string(),
341            ));
342        }
343
344        let n_samples = data.nrows();
345        let n_clusters = self.config.n_clusters;
346
347        // Return uniform probabilities as placeholder
348        Ok(Array2::from_elem(
349            (n_samples, n_clusters),
350            1.0 / n_clusters as f64,
351        ))
352    }
353
354    /// Compute quantum distance between two points
355    pub fn compute_quantum_distance(
356        &self,
357        point1: &Array1<f64>,
358        point2: &Array1<f64>,
359        metric: QuantumDistanceMetric,
360    ) -> Result<f64> {
361        // Placeholder implementation for quantum distance computation
362        match metric {
363            QuantumDistanceMetric::QuantumEuclidean => {
364                let diff = point1 - point2;
365                Ok(diff.dot(&diff).sqrt())
366            }
367            QuantumDistanceMetric::QuantumManhattan => {
368                Ok((point1 - point2).mapv(|x| x.abs()).sum())
369            }
370            QuantumDistanceMetric::QuantumCosine => {
371                let dot_product = point1.dot(point2);
372                let norm1 = point1.dot(point1).sqrt();
373                let norm2 = point2.dot(point2).sqrt();
374                Ok(1.0 - (dot_product / (norm1 * norm2)))
375            }
376            _ => {
377                // For other quantum metrics, return Euclidean as fallback
378                let diff = point1 - point2;
379                Ok(diff.dot(&diff).sqrt())
380            }
381        }
382    }
383
384    /// Fit and predict in one step
385    pub fn fit_predict(&mut self, data: &Array2<f64>) -> Result<Array1<usize>> {
386        let result = self.fit(data)?;
387        Ok(result.labels)
388    }
389
390    /// Get cluster centers
391    pub fn cluster_centers(&self) -> Option<&Array2<f64>> {
392        self.cluster_centers.as_ref()
393    }
394
395    /// Evaluate clustering performance
396    pub fn evaluate(
397        &self,
398        _data: &Array2<f64>,
399        _true_labels: Option<&Array1<usize>>,
400    ) -> Result<ClusteringMetrics> {
401        if self.cluster_centers.is_none() {
402            return Err(MLError::ModelNotTrained(
403                "Clusterer must be fitted before evaluation".to_string(),
404            ));
405        }
406
407        // Placeholder evaluation metrics
408        Ok(ClusteringMetrics {
409            silhouette_score: 0.5,
410            davies_bouldin_index: 1.0,
411            calinski_harabasz_index: 100.0,
412            inertia: 0.0,
413            adjusted_rand_index: None,
414            normalized_mutual_info: None,
415        })
416    }
417}
418
/// Quality metrics describing a clustering.
#[derive(Clone, Debug)]
pub struct ClusteringMetrics {
    /// Silhouette score.
    pub silhouette_score: f64,
    /// Davies-Bouldin index.
    pub davies_bouldin_index: f64,
    /// Calinski-Harabasz index.
    pub calinski_harabasz_index: f64,
    /// Within-cluster sum of squares.
    pub inertia: f64,
    /// Adjusted Rand Index; only meaningful when ground-truth labels were supplied.
    pub adjusted_rand_index: Option<f64>,
    /// Normalized Mutual Information; only meaningful when ground-truth labels were supplied.
    pub normalized_mutual_info: Option<f64>,
}
435
436/// Helper function to create default quantum K-means clusterer
437pub fn create_default_quantum_kmeans(n_clusters: usize) -> QuantumClusterer {
438    let config = QuantumKMeansConfig {
439        n_clusters,
440        ..Default::default()
441    };
442    QuantumClusterer::kmeans(config)
443}
444
445/// Helper function to create default quantum DBSCAN clusterer
446pub fn create_default_quantum_dbscan(eps: f64, min_samples: usize) -> QuantumClusterer {
447    let config = QuantumDBSCANConfig {
448        eps,
449        min_samples,
450        ..Default::default()
451    };
452    QuantumClusterer::dbscan(config)
453}