quantrs2_ml/sklearn_compatibility/
clustering.rs

1//! Sklearn-compatible clustering algorithms
2
3use super::{SklearnClusterer, SklearnEstimator};
4use crate::clustering::core::QuantumClusterer;
5use crate::error::{MLError, Result};
6use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Quantum K-Means (sklearn-compatible)
12pub struct QuantumKMeans {
13    /// Internal clusterer
14    clusterer: Option<QuantumClusterer>,
15    /// Number of clusters
16    n_clusters: usize,
17    /// Maximum iterations
18    max_iter: usize,
19    /// Tolerance
20    tol: f64,
21    /// Random state
22    random_state: Option<u64>,
23    /// Backend
24    backend: Arc<dyn SimulatorBackend>,
25    /// Fitted flag
26    fitted: bool,
27    /// Cluster centers
28    cluster_centers_: Option<Array2<f64>>,
29    /// Labels
30    labels_: Option<Array1<i32>>,
31}
32
33impl QuantumKMeans {
34    /// Create new Quantum K-Means
35    pub fn new(n_clusters: usize) -> Self {
36        Self {
37            clusterer: None,
38            n_clusters,
39            max_iter: 300,
40            tol: 1e-4,
41            random_state: None,
42            backend: Arc::new(StatevectorBackend::new(10)),
43            fitted: false,
44            cluster_centers_: None,
45            labels_: None,
46        }
47    }
48
49    /// Set maximum iterations
50    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
51        self.max_iter = max_iter;
52        self
53    }
54
55    /// Set tolerance
56    pub fn set_tol(mut self, tol: f64) -> Self {
57        self.tol = tol;
58        self
59    }
60
61    /// Set random state
62    pub fn set_random_state(mut self, random_state: u64) -> Self {
63        self.random_state = Some(random_state);
64        self
65    }
66}
67
68impl SklearnEstimator for QuantumKMeans {
69    #[allow(non_snake_case)]
70    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
71        let config = crate::clustering::config::QuantumClusteringConfig {
72            algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
73            n_clusters: self.n_clusters,
74            max_iterations: self.max_iter,
75            tolerance: self.tol,
76            num_qubits: 4,
77            random_state: self.random_state,
78        };
79        let mut clusterer = QuantumClusterer::new(config);
80
81        let result = clusterer.fit_predict(X)?;
82        // Convert usize to i32 for sklearn compatibility
83        let result_i32 = result.mapv(|x| x as i32);
84        self.labels_ = Some(result_i32);
85        self.cluster_centers_ = None; // TODO: Get cluster centers from clusterer
86
87        self.clusterer = Some(clusterer);
88        self.fitted = true;
89
90        Ok(())
91    }
92
93    fn get_params(&self) -> HashMap<String, String> {
94        let mut params = HashMap::new();
95        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
96        params.insert("max_iter".to_string(), self.max_iter.to_string());
97        params.insert("tol".to_string(), self.tol.to_string());
98        if let Some(rs) = self.random_state {
99            params.insert("random_state".to_string(), rs.to_string());
100        }
101        params
102    }
103
104    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
105        for (key, value) in params {
106            match key.as_str() {
107                "n_clusters" => {
108                    self.n_clusters = value.parse().map_err(|_| {
109                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
110                    })?;
111                }
112                "max_iter" => {
113                    self.max_iter = value.parse().map_err(|_| {
114                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
115                    })?;
116                }
117                "tol" => {
118                    self.tol = value.parse().map_err(|_| {
119                        MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
120                    })?;
121                }
122                "random_state" => {
123                    self.random_state = Some(value.parse().map_err(|_| {
124                        MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
125                    })?);
126                }
127                _ => {
128                    // Skip unknown parameters
129                }
130            }
131        }
132        Ok(())
133    }
134
135    fn is_fitted(&self) -> bool {
136        self.fitted
137    }
138}
139
140impl SklearnClusterer for QuantumKMeans {
141    #[allow(non_snake_case)]
142    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
143        if !self.fitted {
144            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
145        }
146
147        let clusterer = self
148            .clusterer
149            .as_ref()
150            .ok_or_else(|| MLError::ModelNotTrained("Clusterer not initialized".to_string()))?;
151        let result = clusterer.predict(X)?;
152        // Convert usize to i32 for sklearn compatibility
153        Ok(result.mapv(|x| x as i32))
154    }
155
156    fn cluster_centers(&self) -> Option<&Array2<f64>> {
157        self.cluster_centers_.as_ref()
158    }
159}
160
161/// DBSCAN clustering algorithm
162pub struct DBSCAN {
163    /// Epsilon - neighborhood radius
164    eps: f64,
165    /// Minimum samples for core points
166    min_samples: usize,
167    /// Fitted labels
168    labels: Option<Array1<i32>>,
169    /// Core sample indices
170    core_sample_indices: Vec<usize>,
171}
172
173impl DBSCAN {
174    /// Create new DBSCAN
175    pub fn new(eps: f64, min_samples: usize) -> Self {
176        Self {
177            eps,
178            min_samples,
179            labels: None,
180            core_sample_indices: Vec::new(),
181        }
182    }
183
184    /// Set eps
185    pub fn eps(mut self, eps: f64) -> Self {
186        self.eps = eps;
187        self
188    }
189
190    /// Set min_samples
191    pub fn min_samples(mut self, min_samples: usize) -> Self {
192        self.min_samples = min_samples;
193        self
194    }
195
196    /// Get labels
197    pub fn labels(&self) -> Option<&Array1<i32>> {
198        self.labels.as_ref()
199    }
200
201    /// Get core sample indices
202    pub fn core_sample_indices(&self) -> &[usize] {
203        &self.core_sample_indices
204    }
205
206    /// Compute distance matrix
207    #[allow(non_snake_case)]
208    fn compute_distances(&self, X: &Array2<f64>) -> Array2<f64> {
209        let n = X.nrows();
210        let mut distances = Array2::zeros((n, n));
211
212        for i in 0..n {
213            for j in i + 1..n {
214                let mut dist = 0.0;
215                for k in 0..X.ncols() {
216                    let diff = X[[i, k]] - X[[j, k]];
217                    dist += diff * diff;
218                }
219                let dist = dist.sqrt();
220                distances[[i, j]] = dist;
221                distances[[j, i]] = dist;
222            }
223        }
224
225        distances
226    }
227
228    /// Get number of clusters found
229    pub fn n_clusters(&self) -> Option<usize> {
230        self.labels.as_ref().map(|labels| {
231            let max_label = labels.iter().max().copied().unwrap_or(-1);
232            if max_label >= 0 {
233                (max_label + 1) as usize
234            } else {
235                0
236            }
237        })
238    }
239
240    /// Fit the model (internal)
241    #[allow(non_snake_case)]
242    fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
243        let n = X.nrows();
244        let distances = self.compute_distances(X);
245
246        // Find neighbors for each point
247        let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
248        for i in 0..n {
249            for j in 0..n {
250                if i != j && distances[[i, j]] <= self.eps {
251                    neighbors[i].push(j);
252                }
253            }
254        }
255
256        // Identify core points
257        self.core_sample_indices.clear();
258        for (i, n_neighbors) in neighbors.iter().enumerate() {
259            if n_neighbors.len() >= self.min_samples {
260                self.core_sample_indices.push(i);
261            }
262        }
263
264        // Label points
265        let mut labels = Array1::from_elem(n, -1_i32); // -1 = noise
266        let mut visited = vec![false; n];
267        let mut cluster_id = 0_i32;
268
269        for &core_idx in &self.core_sample_indices {
270            if visited[core_idx] {
271                continue;
272            }
273
274            // BFS to expand cluster
275            let mut stack = vec![core_idx];
276            while let Some(idx) = stack.pop() {
277                if visited[idx] {
278                    continue;
279                }
280                visited[idx] = true;
281                labels[idx] = cluster_id;
282
283                // If this is a core point, expand
284                if neighbors[idx].len() >= self.min_samples {
285                    for &neighbor in &neighbors[idx] {
286                        if !visited[neighbor] {
287                            stack.push(neighbor);
288                        }
289                    }
290                }
291            }
292            cluster_id += 1;
293        }
294
295        self.labels = Some(labels);
296        Ok(())
297    }
298}
299
300impl SklearnEstimator for DBSCAN {
301    #[allow(non_snake_case)]
302    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
303        self.fit_internal(X)
304    }
305
306    fn get_params(&self) -> HashMap<String, String> {
307        let mut params = HashMap::new();
308        params.insert("eps".to_string(), self.eps.to_string());
309        params.insert("min_samples".to_string(), self.min_samples.to_string());
310        params
311    }
312
313    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
314        for (key, value) in params {
315            match key.as_str() {
316                "eps" => {
317                    self.eps = value.parse().map_err(|_| {
318                        MLError::InvalidConfiguration(format!("Invalid eps: {}", value))
319                    })?;
320                }
321                "min_samples" => {
322                    self.min_samples = value.parse().map_err(|_| {
323                        MLError::InvalidConfiguration(format!("Invalid min_samples: {}", value))
324                    })?;
325                }
326                _ => {}
327            }
328        }
329        Ok(())
330    }
331
332    fn is_fitted(&self) -> bool {
333        self.labels.is_some()
334    }
335}
336
337impl SklearnClusterer for DBSCAN {
338    #[allow(non_snake_case)]
339    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
340        // For DBSCAN, predict returns labels from fit
341        // New points would need special handling
342        self.labels
343            .clone()
344            .ok_or_else(|| MLError::ModelNotTrained("DBSCAN not fitted".to_string()))
345    }
346}
347
348/// Agglomerative Clustering
349pub struct AgglomerativeClustering {
350    /// Number of clusters
351    n_clusters: usize,
352    /// Linkage type
353    linkage: String,
354    /// Fitted labels
355    labels: Option<Array1<i32>>,
356}
357
358impl AgglomerativeClustering {
359    /// Create new AgglomerativeClustering
360    pub fn new(n_clusters: usize) -> Self {
361        Self {
362            n_clusters,
363            linkage: "ward".to_string(),
364            labels: None,
365        }
366    }
367
368    /// Set linkage
369    pub fn linkage(mut self, linkage: &str) -> Self {
370        self.linkage = linkage.to_string();
371        self
372    }
373
374    /// Get number of clusters
375    pub fn get_n_clusters(&self) -> Option<usize> {
376        if self.labels.is_some() {
377            Some(self.n_clusters)
378        } else {
379            None
380        }
381    }
382
383    /// Fit internal
384    #[allow(non_snake_case)]
385    fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
386        let n = X.nrows();
387
388        // Compute distance matrix
389        let mut distances = Array2::from_elem((n, n), f64::INFINITY);
390        for i in 0..n {
391            for j in i + 1..n {
392                let mut dist = 0.0;
393                for k in 0..X.ncols() {
394                    let diff = X[[i, k]] - X[[j, k]];
395                    dist += diff * diff;
396                }
397                distances[[i, j]] = dist.sqrt();
398                distances[[j, i]] = distances[[i, j]];
399            }
400            distances[[i, i]] = 0.0;
401        }
402
403        // Initialize clusters: each point is its own cluster
404        let mut cluster_assignment: Vec<usize> = (0..n).collect();
405        let mut active_clusters: Vec<bool> = vec![true; n];
406        let mut cluster_sizes: Vec<usize> = vec![1; n];
407
408        // Merge clusters until we have n_clusters
409        let mut num_clusters = n;
410        while num_clusters > self.n_clusters {
411            // Find closest pair of clusters
412            let mut min_dist = f64::INFINITY;
413            let mut merge_i = 0;
414            let mut merge_j = 0;
415
416            for i in 0..n {
417                if !active_clusters[i] {
418                    continue;
419                }
420                for j in i + 1..n {
421                    if !active_clusters[j] {
422                        continue;
423                    }
424                    if distances[[i, j]] < min_dist {
425                        min_dist = distances[[i, j]];
426                        merge_i = i;
427                        merge_j = j;
428                    }
429                }
430            }
431
432            // Merge j into i
433            for k in 0..n {
434                if cluster_assignment[k] == merge_j {
435                    cluster_assignment[k] = merge_i;
436                }
437            }
438            active_clusters[merge_j] = false;
439            cluster_sizes[merge_i] += cluster_sizes[merge_j];
440
441            // Update distances (using average linkage as default)
442            for k in 0..n {
443                if k != merge_i && active_clusters[k] {
444                    let new_dist = match self.linkage.as_str() {
445                        "single" => distances[[merge_i, k]].min(distances[[merge_j, k]]),
446                        "complete" => distances[[merge_i, k]].max(distances[[merge_j, k]]),
447                        "average" | _ => {
448                            let s_i = cluster_sizes[merge_i] as f64;
449                            let s_j = cluster_sizes[merge_j] as f64;
450                            (distances[[merge_i, k]] * (s_i - cluster_sizes[merge_j] as f64)
451                                + distances[[merge_j, k]] * s_j)
452                                / s_i
453                        }
454                    };
455                    distances[[merge_i, k]] = new_dist;
456                    distances[[k, merge_i]] = new_dist;
457                }
458            }
459
460            num_clusters -= 1;
461        }
462
463        // Remap cluster labels to 0..n_clusters-1
464        let unique_clusters: Vec<usize> = cluster_assignment
465            .iter()
466            .copied()
467            .collect::<std::collections::HashSet<_>>()
468            .into_iter()
469            .collect();
470        let label_map: std::collections::HashMap<usize, i32> = unique_clusters
471            .iter()
472            .enumerate()
473            .map(|(i, &c)| (c, i as i32))
474            .collect();
475
476        let labels = cluster_assignment
477            .iter()
478            .map(|&c| *label_map.get(&c).unwrap_or(&0))
479            .collect();
480        self.labels = Some(Array1::from_vec(labels));
481
482        Ok(())
483    }
484}
485
486impl SklearnEstimator for AgglomerativeClustering {
487    #[allow(non_snake_case)]
488    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
489        self.fit_internal(X)
490    }
491
492    fn get_params(&self) -> HashMap<String, String> {
493        let mut params = HashMap::new();
494        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
495        params.insert("linkage".to_string(), self.linkage.clone());
496        params
497    }
498
499    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
500        for (key, value) in params {
501            match key.as_str() {
502                "n_clusters" => {
503                    self.n_clusters = value.parse().map_err(|_| {
504                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
505                    })?;
506                }
507                "linkage" => {
508                    self.linkage = value;
509                }
510                _ => {}
511            }
512        }
513        Ok(())
514    }
515
516    fn is_fitted(&self) -> bool {
517        self.labels.is_some()
518    }
519}
520
521impl SklearnClusterer for AgglomerativeClustering {
522    #[allow(non_snake_case)]
523    fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
524        self.labels
525            .clone()
526            .ok_or_else(|| MLError::ModelNotTrained("Not fitted".to_string()))
527    }
528}