Skip to main content

quantrs2_ml/sklearn_compatibility/
clustering.rs

//! Scikit-learn-compatible clustering estimators: quantum K-means, DBSCAN, and agglomerative clustering.
2
3use super::{SklearnClusterer, SklearnEstimator};
4use crate::clustering::core::QuantumClusterer;
5use crate::error::{MLError, Result};
6use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Quantum K-Means (sklearn-compatible)
12pub struct QuantumKMeans {
13    /// Internal clusterer
14    clusterer: Option<QuantumClusterer>,
15    /// Number of clusters
16    n_clusters: usize,
17    /// Maximum iterations
18    max_iter: usize,
19    /// Tolerance
20    tol: f64,
21    /// Random state
22    random_state: Option<u64>,
23    /// Backend
24    backend: Arc<dyn SimulatorBackend>,
25    /// Fitted flag
26    fitted: bool,
27    /// Cluster centers
28    cluster_centers_: Option<Array2<f64>>,
29    /// Labels
30    labels_: Option<Array1<i32>>,
31}
32
33impl QuantumKMeans {
34    /// Create new Quantum K-Means
35    pub fn new(n_clusters: usize) -> Self {
36        Self {
37            clusterer: None,
38            n_clusters,
39            max_iter: 300,
40            tol: 1e-4,
41            random_state: None,
42            backend: Arc::new(StatevectorBackend::new(10)),
43            fitted: false,
44            cluster_centers_: None,
45            labels_: None,
46        }
47    }
48
49    /// Set maximum iterations
50    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
51        self.max_iter = max_iter;
52        self
53    }
54
55    /// Set tolerance
56    pub fn set_tol(mut self, tol: f64) -> Self {
57        self.tol = tol;
58        self
59    }
60
61    /// Set random state
62    pub fn set_random_state(mut self, random_state: u64) -> Self {
63        self.random_state = Some(random_state);
64        self
65    }
66}
67
68impl SklearnEstimator for QuantumKMeans {
69    #[allow(non_snake_case)]
70    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
71        let config = crate::clustering::config::QuantumClusteringConfig {
72            algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
73            n_clusters: self.n_clusters,
74            max_iterations: self.max_iter,
75            tolerance: self.tol,
76            num_qubits: 4,
77            random_state: self.random_state,
78        };
79        let mut clusterer = QuantumClusterer::new(config);
80
81        let result = clusterer.fit_predict(X)?;
82        // Convert usize to i32 for sklearn compatibility
83        let result_i32 = result.mapv(|x| x as i32);
84        self.labels_ = Some(result_i32);
85
86        // Compute cluster centers as the centroid of each cluster.
87        // center_k[j] = mean(X[i, j] for all i where label[i] == k)
88        let n_features = X.ncols();
89        let n_clusters = self.n_clusters;
90        let mut centers = Array2::<f64>::zeros((n_clusters, n_features));
91        let mut counts = vec![0usize; n_clusters];
92        for (i, &label) in result.iter().enumerate() {
93            let k = label.min(n_clusters - 1);
94            counts[k] += 1;
95            for j in 0..n_features {
96                centers[[k, j]] += X[[i, j]];
97            }
98        }
99        for k in 0..n_clusters {
100            let count = counts[k];
101            if count > 0 {
102                for j in 0..n_features {
103                    centers[[k, j]] /= count as f64;
104                }
105            }
106        }
107        self.cluster_centers_ = Some(centers);
108
109        self.clusterer = Some(clusterer);
110        self.fitted = true;
111
112        Ok(())
113    }
114
115    fn get_params(&self) -> HashMap<String, String> {
116        let mut params = HashMap::new();
117        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
118        params.insert("max_iter".to_string(), self.max_iter.to_string());
119        params.insert("tol".to_string(), self.tol.to_string());
120        if let Some(rs) = self.random_state {
121            params.insert("random_state".to_string(), rs.to_string());
122        }
123        params
124    }
125
126    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
127        for (key, value) in params {
128            match key.as_str() {
129                "n_clusters" => {
130                    self.n_clusters = value.parse().map_err(|_| {
131                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
132                    })?;
133                }
134                "max_iter" => {
135                    self.max_iter = value.parse().map_err(|_| {
136                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
137                    })?;
138                }
139                "tol" => {
140                    self.tol = value.parse().map_err(|_| {
141                        MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
142                    })?;
143                }
144                "random_state" => {
145                    self.random_state = Some(value.parse().map_err(|_| {
146                        MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
147                    })?);
148                }
149                _ => {
150                    // Skip unknown parameters
151                }
152            }
153        }
154        Ok(())
155    }
156
157    fn is_fitted(&self) -> bool {
158        self.fitted
159    }
160}
161
162impl SklearnClusterer for QuantumKMeans {
163    #[allow(non_snake_case)]
164    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
165        if !self.fitted {
166            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
167        }
168
169        let clusterer = self
170            .clusterer
171            .as_ref()
172            .ok_or_else(|| MLError::ModelNotTrained("Clusterer not initialized".to_string()))?;
173        let result = clusterer.predict(X)?;
174        // Convert usize to i32 for sklearn compatibility
175        Ok(result.mapv(|x| x as i32))
176    }
177
178    fn cluster_centers(&self) -> Option<&Array2<f64>> {
179        self.cluster_centers_.as_ref()
180    }
181}
182
183/// DBSCAN clustering algorithm
184pub struct DBSCAN {
185    /// Epsilon - neighborhood radius
186    eps: f64,
187    /// Minimum samples for core points
188    min_samples: usize,
189    /// Fitted labels
190    labels: Option<Array1<i32>>,
191    /// Core sample indices
192    core_sample_indices: Vec<usize>,
193}
194
195impl DBSCAN {
196    /// Create new DBSCAN
197    pub fn new(eps: f64, min_samples: usize) -> Self {
198        Self {
199            eps,
200            min_samples,
201            labels: None,
202            core_sample_indices: Vec::new(),
203        }
204    }
205
206    /// Set eps
207    pub fn eps(mut self, eps: f64) -> Self {
208        self.eps = eps;
209        self
210    }
211
212    /// Set min_samples
213    pub fn min_samples(mut self, min_samples: usize) -> Self {
214        self.min_samples = min_samples;
215        self
216    }
217
218    /// Get labels
219    pub fn labels(&self) -> Option<&Array1<i32>> {
220        self.labels.as_ref()
221    }
222
223    /// Get core sample indices
224    pub fn core_sample_indices(&self) -> &[usize] {
225        &self.core_sample_indices
226    }
227
228    /// Compute distance matrix
229    #[allow(non_snake_case)]
230    fn compute_distances(&self, X: &Array2<f64>) -> Array2<f64> {
231        let n = X.nrows();
232        let mut distances = Array2::zeros((n, n));
233
234        for i in 0..n {
235            for j in i + 1..n {
236                let mut dist = 0.0;
237                for k in 0..X.ncols() {
238                    let diff = X[[i, k]] - X[[j, k]];
239                    dist += diff * diff;
240                }
241                let dist = dist.sqrt();
242                distances[[i, j]] = dist;
243                distances[[j, i]] = dist;
244            }
245        }
246
247        distances
248    }
249
250    /// Get number of clusters found
251    pub fn n_clusters(&self) -> Option<usize> {
252        self.labels.as_ref().map(|labels| {
253            let max_label = labels.iter().max().copied().unwrap_or(-1);
254            if max_label >= 0 {
255                (max_label + 1) as usize
256            } else {
257                0
258            }
259        })
260    }
261
262    /// Fit the model (internal)
263    #[allow(non_snake_case)]
264    fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
265        let n = X.nrows();
266        let distances = self.compute_distances(X);
267
268        // Find neighbors for each point
269        let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
270        for i in 0..n {
271            for j in 0..n {
272                if i != j && distances[[i, j]] <= self.eps {
273                    neighbors[i].push(j);
274                }
275            }
276        }
277
278        // Identify core points
279        self.core_sample_indices.clear();
280        for (i, n_neighbors) in neighbors.iter().enumerate() {
281            if n_neighbors.len() >= self.min_samples {
282                self.core_sample_indices.push(i);
283            }
284        }
285
286        // Label points
287        let mut labels = Array1::from_elem(n, -1_i32); // -1 = noise
288        let mut visited = vec![false; n];
289        let mut cluster_id = 0_i32;
290
291        for &core_idx in &self.core_sample_indices {
292            if visited[core_idx] {
293                continue;
294            }
295
296            // BFS to expand cluster
297            let mut stack = vec![core_idx];
298            while let Some(idx) = stack.pop() {
299                if visited[idx] {
300                    continue;
301                }
302                visited[idx] = true;
303                labels[idx] = cluster_id;
304
305                // If this is a core point, expand
306                if neighbors[idx].len() >= self.min_samples {
307                    for &neighbor in &neighbors[idx] {
308                        if !visited[neighbor] {
309                            stack.push(neighbor);
310                        }
311                    }
312                }
313            }
314            cluster_id += 1;
315        }
316
317        self.labels = Some(labels);
318        Ok(())
319    }
320}
321
322impl SklearnEstimator for DBSCAN {
323    #[allow(non_snake_case)]
324    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
325        self.fit_internal(X)
326    }
327
328    fn get_params(&self) -> HashMap<String, String> {
329        let mut params = HashMap::new();
330        params.insert("eps".to_string(), self.eps.to_string());
331        params.insert("min_samples".to_string(), self.min_samples.to_string());
332        params
333    }
334
335    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
336        for (key, value) in params {
337            match key.as_str() {
338                "eps" => {
339                    self.eps = value.parse().map_err(|_| {
340                        MLError::InvalidConfiguration(format!("Invalid eps: {}", value))
341                    })?;
342                }
343                "min_samples" => {
344                    self.min_samples = value.parse().map_err(|_| {
345                        MLError::InvalidConfiguration(format!("Invalid min_samples: {}", value))
346                    })?;
347                }
348                _ => {}
349            }
350        }
351        Ok(())
352    }
353
354    fn is_fitted(&self) -> bool {
355        self.labels.is_some()
356    }
357}
358
359impl SklearnClusterer for DBSCAN {
360    #[allow(non_snake_case)]
361    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
362        // For DBSCAN, predict returns labels from fit
363        // New points would need special handling
364        self.labels
365            .clone()
366            .ok_or_else(|| MLError::ModelNotTrained("DBSCAN not fitted".to_string()))
367    }
368}
369
370/// Agglomerative Clustering
371pub struct AgglomerativeClustering {
372    /// Number of clusters
373    n_clusters: usize,
374    /// Linkage type
375    linkage: String,
376    /// Fitted labels
377    labels: Option<Array1<i32>>,
378}
379
380impl AgglomerativeClustering {
381    /// Create new AgglomerativeClustering
382    pub fn new(n_clusters: usize) -> Self {
383        Self {
384            n_clusters,
385            linkage: "ward".to_string(),
386            labels: None,
387        }
388    }
389
390    /// Set linkage
391    pub fn linkage(mut self, linkage: &str) -> Self {
392        self.linkage = linkage.to_string();
393        self
394    }
395
396    /// Get number of clusters
397    pub fn get_n_clusters(&self) -> Option<usize> {
398        if self.labels.is_some() {
399            Some(self.n_clusters)
400        } else {
401            None
402        }
403    }
404
405    /// Fit internal
406    #[allow(non_snake_case)]
407    fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
408        let n = X.nrows();
409
410        // Compute distance matrix
411        let mut distances = Array2::from_elem((n, n), f64::INFINITY);
412        for i in 0..n {
413            for j in i + 1..n {
414                let mut dist = 0.0;
415                for k in 0..X.ncols() {
416                    let diff = X[[i, k]] - X[[j, k]];
417                    dist += diff * diff;
418                }
419                distances[[i, j]] = dist.sqrt();
420                distances[[j, i]] = distances[[i, j]];
421            }
422            distances[[i, i]] = 0.0;
423        }
424
425        // Initialize clusters: each point is its own cluster
426        let mut cluster_assignment: Vec<usize> = (0..n).collect();
427        let mut active_clusters: Vec<bool> = vec![true; n];
428        let mut cluster_sizes: Vec<usize> = vec![1; n];
429
430        // Merge clusters until we have n_clusters
431        let mut num_clusters = n;
432        while num_clusters > self.n_clusters {
433            // Find closest pair of clusters
434            let mut min_dist = f64::INFINITY;
435            let mut merge_i = 0;
436            let mut merge_j = 0;
437
438            for i in 0..n {
439                if !active_clusters[i] {
440                    continue;
441                }
442                for j in i + 1..n {
443                    if !active_clusters[j] {
444                        continue;
445                    }
446                    if distances[[i, j]] < min_dist {
447                        min_dist = distances[[i, j]];
448                        merge_i = i;
449                        merge_j = j;
450                    }
451                }
452            }
453
454            // Merge j into i
455            for k in 0..n {
456                if cluster_assignment[k] == merge_j {
457                    cluster_assignment[k] = merge_i;
458                }
459            }
460            active_clusters[merge_j] = false;
461            cluster_sizes[merge_i] += cluster_sizes[merge_j];
462
463            // Update distances (using average linkage as default)
464            for k in 0..n {
465                if k != merge_i && active_clusters[k] {
466                    let new_dist = match self.linkage.as_str() {
467                        "single" => distances[[merge_i, k]].min(distances[[merge_j, k]]),
468                        "complete" => distances[[merge_i, k]].max(distances[[merge_j, k]]),
469                        "average" | _ => {
470                            let s_i = cluster_sizes[merge_i] as f64;
471                            let s_j = cluster_sizes[merge_j] as f64;
472                            (distances[[merge_i, k]] * (s_i - cluster_sizes[merge_j] as f64)
473                                + distances[[merge_j, k]] * s_j)
474                                / s_i
475                        }
476                    };
477                    distances[[merge_i, k]] = new_dist;
478                    distances[[k, merge_i]] = new_dist;
479                }
480            }
481
482            num_clusters -= 1;
483        }
484
485        // Remap cluster labels to 0..n_clusters-1
486        let unique_clusters: Vec<usize> = cluster_assignment
487            .iter()
488            .copied()
489            .collect::<std::collections::HashSet<_>>()
490            .into_iter()
491            .collect();
492        let label_map: std::collections::HashMap<usize, i32> = unique_clusters
493            .iter()
494            .enumerate()
495            .map(|(i, &c)| (c, i as i32))
496            .collect();
497
498        let labels = cluster_assignment
499            .iter()
500            .map(|&c| *label_map.get(&c).unwrap_or(&0))
501            .collect();
502        self.labels = Some(Array1::from_vec(labels));
503
504        Ok(())
505    }
506}
507
508impl SklearnEstimator for AgglomerativeClustering {
509    #[allow(non_snake_case)]
510    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
511        self.fit_internal(X)
512    }
513
514    fn get_params(&self) -> HashMap<String, String> {
515        let mut params = HashMap::new();
516        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
517        params.insert("linkage".to_string(), self.linkage.clone());
518        params
519    }
520
521    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
522        for (key, value) in params {
523            match key.as_str() {
524                "n_clusters" => {
525                    self.n_clusters = value.parse().map_err(|_| {
526                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
527                    })?;
528                }
529                "linkage" => {
530                    self.linkage = value;
531                }
532                _ => {}
533            }
534        }
535        Ok(())
536    }
537
538    fn is_fitted(&self) -> bool {
539        self.labels.is_some()
540    }
541}
542
543impl SklearnClusterer for AgglomerativeClustering {
544    #[allow(non_snake_case)]
545    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
546        self.labels
547            .clone()
548            .ok_or_else(|| MLError::ModelNotTrained("Not fitted".to_string()))
549    }
550}