quantrs2_ml/sklearn_compatibility/
clustering.rs1use super::{SklearnClusterer, SklearnEstimator};
4use crate::clustering::core::QuantumClusterer;
5use crate::error::{MLError, Result};
6use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct QuantumKMeans {
13 clusterer: Option<QuantumClusterer>,
15 n_clusters: usize,
17 max_iter: usize,
19 tol: f64,
21 random_state: Option<u64>,
23 backend: Arc<dyn SimulatorBackend>,
25 fitted: bool,
27 cluster_centers_: Option<Array2<f64>>,
29 labels_: Option<Array1<i32>>,
31}
32
33impl QuantumKMeans {
34 pub fn new(n_clusters: usize) -> Self {
36 Self {
37 clusterer: None,
38 n_clusters,
39 max_iter: 300,
40 tol: 1e-4,
41 random_state: None,
42 backend: Arc::new(StatevectorBackend::new(10)),
43 fitted: false,
44 cluster_centers_: None,
45 labels_: None,
46 }
47 }
48
49 pub fn set_max_iter(mut self, max_iter: usize) -> Self {
51 self.max_iter = max_iter;
52 self
53 }
54
55 pub fn set_tol(mut self, tol: f64) -> Self {
57 self.tol = tol;
58 self
59 }
60
61 pub fn set_random_state(mut self, random_state: u64) -> Self {
63 self.random_state = Some(random_state);
64 self
65 }
66}
67
68impl SklearnEstimator for QuantumKMeans {
69 #[allow(non_snake_case)]
70 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
71 let config = crate::clustering::config::QuantumClusteringConfig {
72 algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
73 n_clusters: self.n_clusters,
74 max_iterations: self.max_iter,
75 tolerance: self.tol,
76 num_qubits: 4,
77 random_state: self.random_state,
78 };
79 let mut clusterer = QuantumClusterer::new(config);
80
81 let result = clusterer.fit_predict(X)?;
82 let result_i32 = result.mapv(|x| x as i32);
84 self.labels_ = Some(result_i32);
85 self.cluster_centers_ = None; self.clusterer = Some(clusterer);
88 self.fitted = true;
89
90 Ok(())
91 }
92
93 fn get_params(&self) -> HashMap<String, String> {
94 let mut params = HashMap::new();
95 params.insert("n_clusters".to_string(), self.n_clusters.to_string());
96 params.insert("max_iter".to_string(), self.max_iter.to_string());
97 params.insert("tol".to_string(), self.tol.to_string());
98 if let Some(rs) = self.random_state {
99 params.insert("random_state".to_string(), rs.to_string());
100 }
101 params
102 }
103
104 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
105 for (key, value) in params {
106 match key.as_str() {
107 "n_clusters" => {
108 self.n_clusters = value.parse().map_err(|_| {
109 MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
110 })?;
111 }
112 "max_iter" => {
113 self.max_iter = value.parse().map_err(|_| {
114 MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
115 })?;
116 }
117 "tol" => {
118 self.tol = value.parse().map_err(|_| {
119 MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
120 })?;
121 }
122 "random_state" => {
123 self.random_state = Some(value.parse().map_err(|_| {
124 MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
125 })?);
126 }
127 _ => {
128 }
130 }
131 }
132 Ok(())
133 }
134
135 fn is_fitted(&self) -> bool {
136 self.fitted
137 }
138}
139
140impl SklearnClusterer for QuantumKMeans {
141 #[allow(non_snake_case)]
142 fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
143 if !self.fitted {
144 return Err(MLError::ModelNotTrained("Model not trained".to_string()));
145 }
146
147 let clusterer = self
148 .clusterer
149 .as_ref()
150 .ok_or_else(|| MLError::ModelNotTrained("Clusterer not initialized".to_string()))?;
151 let result = clusterer.predict(X)?;
152 Ok(result.mapv(|x| x as i32))
154 }
155
156 fn cluster_centers(&self) -> Option<&Array2<f64>> {
157 self.cluster_centers_.as_ref()
158 }
159}
160
161pub struct DBSCAN {
163 eps: f64,
165 min_samples: usize,
167 labels: Option<Array1<i32>>,
169 core_sample_indices: Vec<usize>,
171}
172
173impl DBSCAN {
174 pub fn new(eps: f64, min_samples: usize) -> Self {
176 Self {
177 eps,
178 min_samples,
179 labels: None,
180 core_sample_indices: Vec::new(),
181 }
182 }
183
184 pub fn eps(mut self, eps: f64) -> Self {
186 self.eps = eps;
187 self
188 }
189
190 pub fn min_samples(mut self, min_samples: usize) -> Self {
192 self.min_samples = min_samples;
193 self
194 }
195
196 pub fn labels(&self) -> Option<&Array1<i32>> {
198 self.labels.as_ref()
199 }
200
201 pub fn core_sample_indices(&self) -> &[usize] {
203 &self.core_sample_indices
204 }
205
206 #[allow(non_snake_case)]
208 fn compute_distances(&self, X: &Array2<f64>) -> Array2<f64> {
209 let n = X.nrows();
210 let mut distances = Array2::zeros((n, n));
211
212 for i in 0..n {
213 for j in i + 1..n {
214 let mut dist = 0.0;
215 for k in 0..X.ncols() {
216 let diff = X[[i, k]] - X[[j, k]];
217 dist += diff * diff;
218 }
219 let dist = dist.sqrt();
220 distances[[i, j]] = dist;
221 distances[[j, i]] = dist;
222 }
223 }
224
225 distances
226 }
227
228 pub fn n_clusters(&self) -> Option<usize> {
230 self.labels.as_ref().map(|labels| {
231 let max_label = labels.iter().max().copied().unwrap_or(-1);
232 if max_label >= 0 {
233 (max_label + 1) as usize
234 } else {
235 0
236 }
237 })
238 }
239
240 #[allow(non_snake_case)]
242 fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
243 let n = X.nrows();
244 let distances = self.compute_distances(X);
245
246 let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
248 for i in 0..n {
249 for j in 0..n {
250 if i != j && distances[[i, j]] <= self.eps {
251 neighbors[i].push(j);
252 }
253 }
254 }
255
256 self.core_sample_indices.clear();
258 for (i, n_neighbors) in neighbors.iter().enumerate() {
259 if n_neighbors.len() >= self.min_samples {
260 self.core_sample_indices.push(i);
261 }
262 }
263
264 let mut labels = Array1::from_elem(n, -1_i32); let mut visited = vec![false; n];
267 let mut cluster_id = 0_i32;
268
269 for &core_idx in &self.core_sample_indices {
270 if visited[core_idx] {
271 continue;
272 }
273
274 let mut stack = vec![core_idx];
276 while let Some(idx) = stack.pop() {
277 if visited[idx] {
278 continue;
279 }
280 visited[idx] = true;
281 labels[idx] = cluster_id;
282
283 if neighbors[idx].len() >= self.min_samples {
285 for &neighbor in &neighbors[idx] {
286 if !visited[neighbor] {
287 stack.push(neighbor);
288 }
289 }
290 }
291 }
292 cluster_id += 1;
293 }
294
295 self.labels = Some(labels);
296 Ok(())
297 }
298}
299
300impl SklearnEstimator for DBSCAN {
301 #[allow(non_snake_case)]
302 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
303 self.fit_internal(X)
304 }
305
306 fn get_params(&self) -> HashMap<String, String> {
307 let mut params = HashMap::new();
308 params.insert("eps".to_string(), self.eps.to_string());
309 params.insert("min_samples".to_string(), self.min_samples.to_string());
310 params
311 }
312
313 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
314 for (key, value) in params {
315 match key.as_str() {
316 "eps" => {
317 self.eps = value.parse().map_err(|_| {
318 MLError::InvalidConfiguration(format!("Invalid eps: {}", value))
319 })?;
320 }
321 "min_samples" => {
322 self.min_samples = value.parse().map_err(|_| {
323 MLError::InvalidConfiguration(format!("Invalid min_samples: {}", value))
324 })?;
325 }
326 _ => {}
327 }
328 }
329 Ok(())
330 }
331
332 fn is_fitted(&self) -> bool {
333 self.labels.is_some()
334 }
335}
336
337impl SklearnClusterer for DBSCAN {
338 #[allow(non_snake_case)]
339 fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
340 self.labels
343 .clone()
344 .ok_or_else(|| MLError::ModelNotTrained("DBSCAN not fitted".to_string()))
345 }
346}
347
348pub struct AgglomerativeClustering {
350 n_clusters: usize,
352 linkage: String,
354 labels: Option<Array1<i32>>,
356}
357
358impl AgglomerativeClustering {
359 pub fn new(n_clusters: usize) -> Self {
361 Self {
362 n_clusters,
363 linkage: "ward".to_string(),
364 labels: None,
365 }
366 }
367
368 pub fn linkage(mut self, linkage: &str) -> Self {
370 self.linkage = linkage.to_string();
371 self
372 }
373
374 pub fn get_n_clusters(&self) -> Option<usize> {
376 if self.labels.is_some() {
377 Some(self.n_clusters)
378 } else {
379 None
380 }
381 }
382
383 #[allow(non_snake_case)]
385 fn fit_internal(&mut self, X: &Array2<f64>) -> Result<()> {
386 let n = X.nrows();
387
388 let mut distances = Array2::from_elem((n, n), f64::INFINITY);
390 for i in 0..n {
391 for j in i + 1..n {
392 let mut dist = 0.0;
393 for k in 0..X.ncols() {
394 let diff = X[[i, k]] - X[[j, k]];
395 dist += diff * diff;
396 }
397 distances[[i, j]] = dist.sqrt();
398 distances[[j, i]] = distances[[i, j]];
399 }
400 distances[[i, i]] = 0.0;
401 }
402
403 let mut cluster_assignment: Vec<usize> = (0..n).collect();
405 let mut active_clusters: Vec<bool> = vec![true; n];
406 let mut cluster_sizes: Vec<usize> = vec![1; n];
407
408 let mut num_clusters = n;
410 while num_clusters > self.n_clusters {
411 let mut min_dist = f64::INFINITY;
413 let mut merge_i = 0;
414 let mut merge_j = 0;
415
416 for i in 0..n {
417 if !active_clusters[i] {
418 continue;
419 }
420 for j in i + 1..n {
421 if !active_clusters[j] {
422 continue;
423 }
424 if distances[[i, j]] < min_dist {
425 min_dist = distances[[i, j]];
426 merge_i = i;
427 merge_j = j;
428 }
429 }
430 }
431
432 for k in 0..n {
434 if cluster_assignment[k] == merge_j {
435 cluster_assignment[k] = merge_i;
436 }
437 }
438 active_clusters[merge_j] = false;
439 cluster_sizes[merge_i] += cluster_sizes[merge_j];
440
441 for k in 0..n {
443 if k != merge_i && active_clusters[k] {
444 let new_dist = match self.linkage.as_str() {
445 "single" => distances[[merge_i, k]].min(distances[[merge_j, k]]),
446 "complete" => distances[[merge_i, k]].max(distances[[merge_j, k]]),
447 "average" | _ => {
448 let s_i = cluster_sizes[merge_i] as f64;
449 let s_j = cluster_sizes[merge_j] as f64;
450 (distances[[merge_i, k]] * (s_i - cluster_sizes[merge_j] as f64)
451 + distances[[merge_j, k]] * s_j)
452 / s_i
453 }
454 };
455 distances[[merge_i, k]] = new_dist;
456 distances[[k, merge_i]] = new_dist;
457 }
458 }
459
460 num_clusters -= 1;
461 }
462
463 let unique_clusters: Vec<usize> = cluster_assignment
465 .iter()
466 .copied()
467 .collect::<std::collections::HashSet<_>>()
468 .into_iter()
469 .collect();
470 let label_map: std::collections::HashMap<usize, i32> = unique_clusters
471 .iter()
472 .enumerate()
473 .map(|(i, &c)| (c, i as i32))
474 .collect();
475
476 let labels = cluster_assignment
477 .iter()
478 .map(|&c| *label_map.get(&c).unwrap_or(&0))
479 .collect();
480 self.labels = Some(Array1::from_vec(labels));
481
482 Ok(())
483 }
484}
485
486impl SklearnEstimator for AgglomerativeClustering {
487 #[allow(non_snake_case)]
488 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
489 self.fit_internal(X)
490 }
491
492 fn get_params(&self) -> HashMap<String, String> {
493 let mut params = HashMap::new();
494 params.insert("n_clusters".to_string(), self.n_clusters.to_string());
495 params.insert("linkage".to_string(), self.linkage.clone());
496 params
497 }
498
499 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
500 for (key, value) in params {
501 match key.as_str() {
502 "n_clusters" => {
503 self.n_clusters = value.parse().map_err(|_| {
504 MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
505 })?;
506 }
507 "linkage" => {
508 self.linkage = value;
509 }
510 _ => {}
511 }
512 }
513 Ok(())
514 }
515
516 fn is_fitted(&self) -> bool {
517 self.labels.is_some()
518 }
519}
520
521impl SklearnClusterer for AgglomerativeClustering {
522 #[allow(non_snake_case)]
523 fn predict(&self, _X: &Array2<f64>) -> Result<Array1<i32>> {
524 self.labels
525 .clone()
526 .ok_or_else(|| MLError::ModelNotTrained("Not fitted".to_string()))
527 }
528}