quantrs2_ml/sklearn_compatibility/
clustering.rs1use super::{SklearnClusterer, SklearnEstimator};
4use crate::clustering::core::QuantumClusterer;
5use crate::error::{MLError, Result};
6use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct QuantumKMeans {
13 clusterer: Option<QuantumClusterer>,
15 n_clusters: usize,
17 max_iter: usize,
19 tol: f64,
21 random_state: Option<u64>,
23 backend: Arc<dyn SimulatorBackend>,
25 fitted: bool,
27 cluster_centers_: Option<Array2<f64>>,
29 labels_: Option<Array1<i32>>,
31}
32
33impl QuantumKMeans {
34 pub fn new(n_clusters: usize) -> Self {
36 Self {
37 clusterer: None,
38 n_clusters,
39 max_iter: 300,
40 tol: 1e-4,
41 random_state: None,
42 backend: Arc::new(StatevectorBackend::new(10)),
43 fitted: false,
44 cluster_centers_: None,
45 labels_: None,
46 }
47 }
48
49 pub fn set_max_iter(mut self, max_iter: usize) -> Self {
51 self.max_iter = max_iter;
52 self
53 }
54
55 pub fn set_tol(mut self, tol: f64) -> Self {
57 self.tol = tol;
58 self
59 }
60
61 pub fn set_random_state(mut self, random_state: u64) -> Self {
63 self.random_state = Some(random_state);
64 self
65 }
66}
67
68impl SklearnEstimator for QuantumKMeans {
69 #[allow(non_snake_case)]
70 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
71 let config = crate::clustering::config::QuantumClusteringConfig {
72 algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
73 n_clusters: self.n_clusters,
74 max_iterations: self.max_iter,
75 tolerance: self.tol,
76 num_qubits: 4,
77 random_state: self.random_state,
78 };
79 let mut clusterer = QuantumClusterer::new(config);
80
81 let result = clusterer.fit_predict(X)?;
82 let result_i32 = result.mapv(|x| x as i32);
84 self.labels_ = Some(result_i32);
85
86 let n_features = X.ncols();
89 let n_clusters = self.n_clusters;
90 let mut centers = Array2::<f64>::zeros((n_clusters, n_features));
91 let mut counts = vec![0usize; n_clusters];
92 for (i, &label) in result.iter().enumerate() {
93 let k = label.min(n_clusters - 1);
94 counts[k] += 1;
95 for j in 0..n_features {
96 centers[[k, j]] += X[[i, j]];
97 }
98 }
99 for k in 0..n_clusters {
100 let count = counts[k];
101 if count > 0 {
102 for j in 0..n_features {
103 centers[[k, j]] /= count as f64;
104 }
105 }
106 }
107 self.cluster_centers_ = Some(centers);
108
109 self.clusterer = Some(clusterer);
110 self.fitted = true;
111
112 Ok(())
113 }
114
115 fn get_params(&self) -> HashMap<String, String> {
116 let mut params = HashMap::new();
117 params.insert("n_clusters".to_string(), self.n_clusters.to_string());
118 params.insert("max_iter".to_string(), self.max_iter.to_string());
119 params.insert("tol".to_string(), self.tol.to_string());
120 if let Some(rs) = self.random_state {
121 params.insert("random_state".to_string(), rs.to_string());
122 }
123 params
124 }
125
126 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
127 for (key, value) in params {
128 match key.as_str() {
129 "n_clusters" => {
130 self.n_clusters = value.parse().map_err(|_| {
131 MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
132 })?;
133 }
134 "max_iter" => {
135 self.max_iter = value.parse().map_err(|_| {
136 MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
137 })?;
138 }
139 "tol" => {
140 self.tol = value.parse().map_err(|_| {
141 MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
142 })?;
143 }
144 "random_state" => {
145 self.random_state = Some(value.parse().map_err(|_| {
146 MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
147 })?);
148 }
149 _ => {
150 }
152 }
153 }
154 Ok(())
155 }
156
157 fn is_fitted(&self) -> bool {
158 self.fitted
159 }
160}
161
162impl SklearnClusterer for QuantumKMeans {
163 #[allow(non_snake_case)]
164 fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
165 if !self.fitted {
166 return Err(MLError::ModelNotTrained("Model not trained".to_string()));
167 }
168
169 let clusterer = self
170 .clusterer
171 .as_ref()
172 .ok_or_else(|| MLError::ModelNotTrained("Clusterer not initialized".to_string()))?;
173 let result = clusterer.predict(X)?;
174 Ok(result.mapv(|x| x as i32))
176 }
177
178 fn cluster_centers(&self) -> Option<&Array2<f64>> {
179 self.cluster_centers_.as_ref()
180 }
181}
182
183pub struct DBSCAN {
185 eps: f64,
187 min_samples: usize,
189 labels: Option<Array1<i32>>,
191 core_sample_indices: Vec<usize>,
193}
194
195impl DBSCAN {
196 pub fn new(eps: f64, min_samples: usize) -> Self {
198 Self {
199 eps,
200 min_samples,
201 labels: None,
202 core_sample_indices: Vec::new(),
203 }
204 }
205
206 pub fn eps(mut self, eps: f64) -> Self {
208 self.eps = eps;
209 self
210 }
211
212 pub fn min_samples(mut self, min_samples: usize) -> Self {
214 self.min_samples = min_samples;
215 self
216 }
217
218 pub fn labels(&self) -> Option<&Array1<i32>> {
220 self.labels.as_ref()
221 }
222
223 pub fn core_sample_indices(&self) -> &[usize] {
225 &self.core_sample_indices
226 }
227
228 #[allow(non_snake_case)]
230 fn compute_distances(&self, X: &Array2<f64>) -> Array2<f64> {
231 let n = X.nrows();
232 let mut distances = Array2::zeros((n, n));
233
234 for i in 0..n {
235 for j in i + 1..n {
236 let mut dist = 0.0;
237 for k in 0..X.ncols() {
238 let diff = X[[i, k]] - X[[j, k]];
239 dist += diff * diff;
240 }
241 let dist = dist.sqrt();
242 distances[[i, j]] = dist;
243 distances[[j, i]] = dist;
244 }
245 }
246
247 distances
248 }
249
250 pub fn n_clusters(&self) -> Option<usize> {
252 self.labels.as_ref().map(|labels| {
253 let max_label = labels.iter().max().copied().unwrap_or(-1);
254 if max_label >= 0 {
255 (max_label + 1) as usize
256 } else {
257 0
258 }
259 })
260 }
261
262 #[allow(non_snake_case)]
264 fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
265 let n = X.nrows();
266 let distances = self.compute_distances(X);
267
268 let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
270 for i in 0..n {
271 for j in 0..n {
272 if i != j && distances[[i, j]] <= self.eps {
273 neighbors[i].push(j);
274 }
275 }
276 }
277
278 self.core_sample_indices.clear();
280 for (i, n_neighbors) in neighbors.iter().enumerate() {
281 if n_neighbors.len() >= self.min_samples {
282 self.core_sample_indices.push(i);
283 }
284 }
285
286 let mut labels = Array1::from_elem(n, -1_i32); let mut visited = vec![false; n];
289 let mut cluster_id = 0_i32;
290
291 for &core_idx in &self.core_sample_indices {
292 if visited[core_idx] {
293 continue;
294 }
295
296 let mut stack = vec![core_idx];
298 while let Some(idx) = stack.pop() {
299 if visited[idx] {
300 continue;
301 }
302 visited[idx] = true;
303 labels[idx] = cluster_id;
304
305 if neighbors[idx].len() >= self.min_samples {
307 for &neighbor in &neighbors[idx] {
308 if !visited[neighbor] {
309 stack.push(neighbor);
310 }
311 }
312 }
313 }
314 cluster_id += 1;
315 }
316
317 self.labels = Some(labels);
318 Ok(())
319 }
320}
321
322impl SklearnEstimator for DBSCAN {
323 #[allow(non_snake_case)]
324 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
325 self.fit_internal(X)
326 }
327
328 fn get_params(&self) -> HashMap<String, String> {
329 let mut params = HashMap::new();
330 params.insert("eps".to_string(), self.eps.to_string());
331 params.insert("min_samples".to_string(), self.min_samples.to_string());
332 params
333 }
334
335 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
336 for (key, value) in params {
337 match key.as_str() {
338 "eps" => {
339 self.eps = value.parse().map_err(|_| {
340 MLError::InvalidConfiguration(format!("Invalid eps: {}", value))
341 })?;
342 }
343 "min_samples" => {
344 self.min_samples = value.parse().map_err(|_| {
345 MLError::InvalidConfiguration(format!("Invalid min_samples: {}", value))
346 })?;
347 }
348 _ => {}
349 }
350 }
351 Ok(())
352 }
353
354 fn is_fitted(&self) -> bool {
355 self.labels.is_some()
356 }
357}
358
359impl SklearnClusterer for DBSCAN {
360 #[allow(non_snake_case)]
361 fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
362 self.labels
365 .clone()
366 .ok_or_else(|| MLError::ModelNotTrained("DBSCAN not fitted".to_string()))
367 }
368}
369
370pub struct AgglomerativeClustering {
372 n_clusters: usize,
374 linkage: String,
376 labels: Option<Array1<i32>>,
378}
379
380impl AgglomerativeClustering {
381 pub fn new(n_clusters: usize) -> Self {
383 Self {
384 n_clusters,
385 linkage: "ward".to_string(),
386 labels: None,
387 }
388 }
389
390 pub fn linkage(mut self, linkage: &str) -> Self {
392 self.linkage = linkage.to_string();
393 self
394 }
395
396 pub fn get_n_clusters(&self) -> Option<usize> {
398 if self.labels.is_some() {
399 Some(self.n_clusters)
400 } else {
401 None
402 }
403 }
404
405 #[allow(non_snake_case)]
407 fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
408 let n = X.nrows();
409
410 let mut distances = Array2::from_elem((n, n), f64::INFINITY);
412 for i in 0..n {
413 for j in i + 1..n {
414 let mut dist = 0.0;
415 for k in 0..X.ncols() {
416 let diff = X[[i, k]] - X[[j, k]];
417 dist += diff * diff;
418 }
419 distances[[i, j]] = dist.sqrt();
420 distances[[j, i]] = distances[[i, j]];
421 }
422 distances[[i, i]] = 0.0;
423 }
424
425 let mut cluster_assignment: Vec<usize> = (0..n).collect();
427 let mut active_clusters: Vec<bool> = vec![true; n];
428 let mut cluster_sizes: Vec<usize> = vec![1; n];
429
430 let mut num_clusters = n;
432 while num_clusters > self.n_clusters {
433 let mut min_dist = f64::INFINITY;
435 let mut merge_i = 0;
436 let mut merge_j = 0;
437
438 for i in 0..n {
439 if !active_clusters[i] {
440 continue;
441 }
442 for j in i + 1..n {
443 if !active_clusters[j] {
444 continue;
445 }
446 if distances[[i, j]] < min_dist {
447 min_dist = distances[[i, j]];
448 merge_i = i;
449 merge_j = j;
450 }
451 }
452 }
453
454 for k in 0..n {
456 if cluster_assignment[k] == merge_j {
457 cluster_assignment[k] = merge_i;
458 }
459 }
460 active_clusters[merge_j] = false;
461 cluster_sizes[merge_i] += cluster_sizes[merge_j];
462
463 for k in 0..n {
465 if k != merge_i && active_clusters[k] {
466 let new_dist = match self.linkage.as_str() {
467 "single" => distances[[merge_i, k]].min(distances[[merge_j, k]]),
468 "complete" => distances[[merge_i, k]].max(distances[[merge_j, k]]),
469 "average" | _ => {
470 let s_i = cluster_sizes[merge_i] as f64;
471 let s_j = cluster_sizes[merge_j] as f64;
472 (distances[[merge_i, k]] * (s_i - cluster_sizes[merge_j] as f64)
473 + distances[[merge_j, k]] * s_j)
474 / s_i
475 }
476 };
477 distances[[merge_i, k]] = new_dist;
478 distances[[k, merge_i]] = new_dist;
479 }
480 }
481
482 num_clusters -= 1;
483 }
484
485 let unique_clusters: Vec<usize> = cluster_assignment
487 .iter()
488 .copied()
489 .collect::<std::collections::HashSet<_>>()
490 .into_iter()
491 .collect();
492 let label_map: std::collections::HashMap<usize, i32> = unique_clusters
493 .iter()
494 .enumerate()
495 .map(|(i, &c)| (c, i as i32))
496 .collect();
497
498 let labels = cluster_assignment
499 .iter()
500 .map(|&c| *label_map.get(&c).unwrap_or(&0))
501 .collect();
502 self.labels = Some(Array1::from_vec(labels));
503
504 Ok(())
505 }
506}
507
508impl SklearnEstimator for AgglomerativeClustering {
509 #[allow(non_snake_case)]
510 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
511 self.fit_internal(X)
512 }
513
514 fn get_params(&self) -> HashMap<String, String> {
515 let mut params = HashMap::new();
516 params.insert("n_clusters".to_string(), self.n_clusters.to_string());
517 params.insert("linkage".to_string(), self.linkage.clone());
518 params
519 }
520
521 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
522 for (key, value) in params {
523 match key.as_str() {
524 "n_clusters" => {
525 self.n_clusters = value.parse().map_err(|_| {
526 MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
527 })?;
528 }
529 "linkage" => {
530 self.linkage = value;
531 }
532 _ => {}
533 }
534 }
535 Ok(())
536 }
537
538 fn is_fitted(&self) -> bool {
539 self.labels.is_some()
540 }
541}
542
543impl SklearnClusterer for AgglomerativeClustering {
544 #[allow(non_snake_case)]
545 fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
546 self.labels
547 .clone()
548 .ok_or_else(|| MLError::ModelNotTrained("Not fitted".to_string()))
549 }
550}