quantrs2_ml/
sklearn_compatibility.rs

//! Scikit-learn compatibility layer for QuantRS2-ML
//!
//! This module provides a compatibility layer that mimics scikit-learn APIs,
//! allowing easy integration of quantum ML models with existing scikit-learn
//! workflows and pipelines.
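//!
//! # Example
//!
//! Illustrative sketch of the intended fit/score workflow (kept as an `ignore`
//! block, not run as a doctest; the module path
//! `quantrs2_ml::sklearn_compatibility` and the toy data are assumptions):
//!
//! ```ignore
//! use quantrs2_ml::sklearn_compatibility::{
//!     model_selection::train_test_split, QuantumSVC, SklearnClassifier, SklearnEstimator,
//! };
//! use scirs2_core::ndarray::Array;
//!
//! // Toy dataset: 4 samples, 2 features.
//! let x = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0]).unwrap();
//! let y = Array::from_vec(vec![1.0, -1.0, -1.0, 1.0]);
//!
//! // Hold out 25% of the samples for testing.
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.25, Some(42)).unwrap();
//!
//! // Fit a quantum SVC and evaluate accuracy on the held-out split.
//! let mut clf = QuantumSVC::new().set_C(1.0);
//! clf.fit(&x_train, Some(&y_train)).unwrap();
//! let y_test_int = y_test.mapv(|v| v.round() as i32);
//! let accuracy = clf.score(&x_test, &y_test_int).unwrap();
//! ```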

use crate::classification::{ClassificationMetrics, Classifier};
use crate::clustering::{ClusteringAlgorithm, QuantumClusterer};
use crate::error::{MLError, Result};
use crate::qnn::{QNNBuilder, QuantumNeuralNetwork};
use crate::qsvm::{FeatureMapType, QSVMParams, QSVM};
use crate::simulator_backends::{
    Backend, BackendCapabilities, SimulatorBackend, StatevectorBackend,
};
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis};
use scirs2_core::SliceRandomExt;
use std::collections::HashMap;
use std::sync::Arc;

/// Base estimator trait following scikit-learn conventions
pub trait SklearnEstimator: Send + Sync {
    /// Fit the model to training data
    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()>;

    /// Get model parameters
    fn get_params(&self) -> HashMap<String, String>;

    /// Set model parameters
    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()>;

    /// Check if model is fitted
    fn is_fitted(&self) -> bool;

    /// Get feature names
    fn get_feature_names_out(&self) -> Vec<String> {
        vec![]
    }
}

/// Classifier mixin trait
pub trait SklearnClassifier: SklearnEstimator {
    /// Predict class labels
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;

    /// Predict class probabilities
    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>>;

    /// Get unique class labels
    fn classes(&self) -> &[i32];

    /// Score the model (accuracy by default)
    fn score(&self, X: &Array2<f64>, y: &Array1<i32>) -> Result<f64> {
        let predictions = self.predict(X)?;
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(&pred, &true_label)| pred == true_label)
            .count();
        Ok(correct as f64 / y.len() as f64)
    }

    /// Get feature importances (optional)
    fn feature_importances(&self) -> Option<Array1<f64>> {
        None
    }

    /// Save model to file (optional)
    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }
}

/// Regressor mixin trait
pub trait SklearnRegressor: SklearnEstimator {
    /// Predict continuous values
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>>;

    /// Score the model (R² by default)
    fn score(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
        let predictions = self.predict(X)?;
        let y_mean = y.mean().unwrap();

        let ss_res: f64 = y
            .iter()
            .zip(predictions.iter())
            .map(|(&true_val, &pred)| (true_val - pred).powi(2))
            .sum();

        let ss_tot: f64 = y.iter().map(|&val| (val - y_mean).powi(2)).sum();

        Ok(1.0 - ss_res / ss_tot)
    }
}

/// Extension trait for fitting with Array1<f64> directly
pub trait SklearnFit {
    fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()>;
}

/// Clusterer mixin trait
pub trait SklearnClusterer: SklearnEstimator {
    /// Predict cluster labels
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;

    /// Fit and predict in one step
    fn fit_predict(&mut self, X: &Array2<f64>) -> Result<Array1<i32>> {
        self.fit(X, None)?;
        self.predict(X)
    }

    /// Get cluster centers (if applicable)
    fn cluster_centers(&self) -> Option<&Array2<f64>> {
        None
    }
}

/// Quantum Support Vector Machine (sklearn-compatible)
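///
/// # Example
///
/// Minimal sketch of building and training a `QuantumSVC`, mirroring the unit
/// test at the bottom of this file (kept as an `ignore` block):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// // Builder-style configuration mirroring sklearn's SVC(C=..., gamma=...).
/// let mut svc = QuantumSVC::new().set_C(1.0).set_gamma(0.1);
///
/// let x = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0]).unwrap();
/// let y = Array::from_vec(vec![1.0, -1.0, -1.0, 1.0]);
///
/// // `fit` takes labels as Option<&Array1<f64>>; they are rounded to integer classes internally.
/// svc.fit(&x, Some(&y)).unwrap();
/// let predictions = svc.predict(&x).unwrap();
/// ```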
pub struct QuantumSVC {
    /// Internal QSVM
    qsvm: Option<QSVM>,
    /// SVM parameters
    params: QSVMParams,
    /// Feature map type
    feature_map: FeatureMapType,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Unique classes
    classes: Vec<i32>,
    /// Regularization parameter
    C: f64,
    /// Kernel gamma parameter
    gamma: f64,
}

impl Clone for QuantumSVC {
    fn clone(&self) -> Self {
        Self {
            qsvm: None, // Reset QSVM since it's not cloneable
            params: self.params.clone(),
            feature_map: self.feature_map,
            backend: self.backend.clone(),
            fitted: false, // Reset fitted status
            classes: self.classes.clone(),
            C: self.C,
            gamma: self.gamma,
        }
    }
}

impl QuantumSVC {
    /// Create new Quantum SVC
    pub fn new() -> Self {
        Self {
            qsvm: None,
            params: QSVMParams::default(),
            feature_map: FeatureMapType::ZZFeatureMap,
            backend: Arc::new(StatevectorBackend::new(10)),
            fitted: false,
            classes: Vec::new(),
            C: 1.0,
            gamma: 1.0,
        }
    }

    /// Set regularization parameter
    pub fn set_C(mut self, C: f64) -> Self {
        self.C = C;
        self
    }

    /// Set kernel gamma parameter
    pub fn set_gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set feature map
    pub fn set_kernel(mut self, feature_map: FeatureMapType) -> Self {
        self.feature_map = feature_map;
        self
    }

    /// Set quantum backend
    pub fn set_backend(mut self, backend: Arc<dyn SimulatorBackend>) -> Self {
        self.backend = backend;
        self
    }

    /// Load model from file (mock implementation)
    pub fn load(_path: &str) -> Result<Self> {
        Ok(Self::new())
    }
}

impl SklearnEstimator for QuantumSVC {
    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        let y = y.ok_or_else(|| {
            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
        })?;

        // Convert continuous labels to integer classes
        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);

        // Find unique classes
        let mut classes = Vec::new();
        for &label in y_int.iter() {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        classes.sort();
        self.classes = classes;

        // Update QSVM parameters
        self.params.feature_map = self.feature_map;
        self.params.regularization = self.C;

        // Create and train QSVM
        let mut qsvm = QSVM::new(self.params.clone());
        qsvm.fit(X, &y_int)?;

        self.qsvm = Some(qsvm);
        self.fitted = true;

        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("C".to_string(), self.C.to_string());
        params.insert("gamma".to_string(), self.gamma.to_string());
        params.insert("kernel".to_string(), format!("{:?}", self.feature_map));
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "C" => {
                    self.C = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid C parameter: {}", value))
                    })?;
                }
                "gamma" => {
                    self.gamma = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid gamma parameter: {}", value))
                    })?;
                }
                "kernel" => {
                    self.feature_map = match value.as_str() {
                        "ZZFeatureMap" => FeatureMapType::ZZFeatureMap,
                        "ZFeatureMap" => FeatureMapType::ZFeatureMap,
                        "PauliFeatureMap" => FeatureMapType::PauliFeatureMap,
                        _ => {
                            return Err(MLError::InvalidConfiguration(format!(
                                "Unknown kernel: {}",
                                value
                            )))
                        }
                    };
                }
                _ => {
                    return Err(MLError::InvalidConfiguration(format!(
                        "Unknown parameter: {}",
                        key
                    )))
                }
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl SklearnClassifier for QuantumSVC {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let qsvm = self.qsvm.as_ref().unwrap();
        qsvm.predict(X).map_err(|e| MLError::ValidationError(e))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let predictions = self.predict(X)?;
        let n_samples = X.nrows();
        let n_classes = self.classes.len();

        let mut probabilities = Array2::zeros((n_samples, n_classes));

        // Convert hard predictions to probabilities (placeholder)
        for (i, &prediction) in predictions.iter().enumerate() {
            for (j, &class) in self.classes.iter().enumerate() {
                probabilities[[i, j]] = if prediction == class { 1.0 } else { 0.0 };
            }
        }

        Ok(probabilities)
    }

    fn classes(&self) -> &[i32] {
        &self.classes
    }
}

/// Quantum Neural Network Classifier (sklearn-compatible)
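///
/// # Example
///
/// Illustrative sketch of the MLP-style builder API, mirroring the unit test
/// at the bottom of this file (kept as an `ignore` block):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut mlp = QuantumMLPClassifier::new()
///     .set_hidden_layer_sizes(vec![5])
///     .set_learning_rate(0.01)
///     .set_max_iter(10);
///
/// let x = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0]).unwrap();
/// let y = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]);
///
/// mlp.fit(&x, Some(&y)).unwrap();
/// let labels = mlp.predict(&x).unwrap();               // Array1<i32> of class labels
/// let probabilities = mlp.predict_proba(&x).unwrap();  // Array2<f64>, one column per class
/// ```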
pub struct QuantumMLPClassifier {
    /// Internal QNN
    qnn: Option<QuantumNeuralNetwork>,
    /// Network configuration
    hidden_layer_sizes: Vec<usize>,
    /// Activation function
    activation: String,
    /// Solver
    solver: String,
    /// Learning rate
    learning_rate: f64,
    /// Maximum iterations
    max_iter: usize,
    /// Random state
    random_state: Option<u64>,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Unique classes
    classes: Vec<i32>,
}

impl QuantumMLPClassifier {
    /// Create new Quantum MLP Classifier
    pub fn new() -> Self {
        Self {
            qnn: None,
            hidden_layer_sizes: vec![10],
            activation: "relu".to_string(),
            solver: "adam".to_string(),
            learning_rate: 0.001,
            max_iter: 200,
            random_state: None,
            backend: Arc::new(StatevectorBackend::new(10)),
            fitted: false,
            classes: Vec::new(),
        }
    }

    /// Set hidden layer sizes
    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.hidden_layer_sizes = sizes;
        self
    }

    /// Set activation function
    pub fn set_activation(mut self, activation: String) -> Self {
        self.activation = activation;
        self
    }

    /// Set learning rate
    pub fn set_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum iterations
    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
}

impl SklearnEstimator for QuantumMLPClassifier {
    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        let y = y.ok_or_else(|| {
            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
        })?;

        // Convert continuous labels to integer classes
        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);

        // Find unique classes
        let mut classes = Vec::new();
        for &label in y_int.iter() {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        classes.sort();
        self.classes = classes;

        // Build QNN
        let input_size = X.ncols();
        let output_size = self.classes.len();

        let mut builder = QNNBuilder::new();

        // Add hidden layers
        for &size in &self.hidden_layer_sizes {
            builder = builder.add_layer(size);
        }

        // Add output layer
        builder = builder.add_layer(output_size);

        let mut qnn = builder.build()?;

        // Train QNN
        let y_one_hot = self.to_one_hot(&y_int)?;
        qnn.train(X, &y_one_hot, self.max_iter, self.learning_rate)?;

        self.qnn = Some(qnn);
        self.fitted = true;

        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert(
            "hidden_layer_sizes".to_string(),
            format!("{:?}", self.hidden_layer_sizes),
        );
        params.insert("activation".to_string(), self.activation.clone());
        params.insert("solver".to_string(), self.solver.clone());
        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
        params.insert("max_iter".to_string(), self.max_iter.to_string());
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "learning_rate" => {
                    self.learning_rate = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
                    })?;
                }
                "max_iter" => {
                    self.max_iter = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
                    })?;
                }
                "activation" => {
                    self.activation = value;
                }
                "solver" => {
                    self.solver = value;
                }
                _ => {
                    // Skip unknown parameters
                }
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl QuantumMLPClassifier {
    /// Convert integer labels to one-hot encoding
    fn to_one_hot(&self, y: &Array1<i32>) -> Result<Array2<f64>> {
        let n_samples = y.len();
        let n_classes = self.classes.len();
        let mut one_hot = Array2::zeros((n_samples, n_classes));

        for (i, &label) in y.iter().enumerate() {
            if let Some(class_idx) = self.classes.iter().position(|&c| c == label) {
                one_hot[[i, class_idx]] = 1.0;
            }
        }

        Ok(one_hot)
    }

    /// Convert one-hot predictions to class labels
    fn from_one_hot(&self, predictions: &Array2<f64>) -> Array1<i32> {
        predictions
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);
                self.classes[max_idx]
            })
            .collect()
    }
}

impl SklearnClassifier for QuantumMLPClassifier {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let qnn = self.qnn.as_ref().unwrap();
        let predictions = qnn.predict_batch(X)?;
        Ok(self.from_one_hot(&predictions))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let qnn = self.qnn.as_ref().unwrap();
        qnn.predict_batch(X)
    }

    fn classes(&self) -> &[i32] {
        &self.classes
    }
}

/// Quantum Regressor (sklearn-compatible)
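///
/// # Example
///
/// Minimal sketch of regression with the default R² score (an `ignore` block;
/// the toy linear target used here is an assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut reg = QuantumMLPRegressor::new()
///     .set_hidden_layer_sizes(vec![5])
///     .set_max_iter(50);
///
/// let x = Array::from_shape_vec((4, 2), vec![0.0, 0.0, 0.5, 0.5, 1.0, 1.0, 1.5, 1.5]).unwrap();
/// let y = Array::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
///
/// reg.fit(&x, Some(&y)).unwrap();
/// let predictions = reg.predict(&x).unwrap(); // Array1<f64>
/// let r2 = reg.score(&x, &y).unwrap();        // default R² from SklearnRegressor
/// ```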
pub struct QuantumMLPRegressor {
    /// Internal QNN
    qnn: Option<QuantumNeuralNetwork>,
    /// Network configuration
    hidden_layer_sizes: Vec<usize>,
    /// Activation function
    activation: String,
    /// Solver
    solver: String,
    /// Learning rate
    learning_rate: f64,
    /// Maximum iterations
    max_iter: usize,
    /// Random state
    random_state: Option<u64>,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
}

impl QuantumMLPRegressor {
    /// Create new Quantum MLP Regressor
    pub fn new() -> Self {
        Self {
            qnn: None,
            hidden_layer_sizes: vec![10],
            activation: "relu".to_string(),
            solver: "adam".to_string(),
            learning_rate: 0.001,
            max_iter: 200,
            random_state: None,
            backend: Arc::new(StatevectorBackend::new(10)),
            fitted: false,
        }
    }

    /// Set hidden layer sizes
    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.hidden_layer_sizes = sizes;
        self
    }

    /// Set learning rate
    pub fn set_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum iterations
    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
}

impl SklearnEstimator for QuantumMLPRegressor {
    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        let y = y.ok_or_else(|| {
            MLError::InvalidConfiguration("Target values required for regression".to_string())
        })?;

        // Build QNN for regression
        let input_size = X.ncols();
        let output_size = 1; // Single output for regression

        let mut builder = QNNBuilder::new();

        // Add hidden layers
        for &size in &self.hidden_layer_sizes {
            builder = builder.add_layer(size);
        }

        // Add output layer
        builder = builder.add_layer(output_size);

        let mut qnn = builder.build()?;

        // Reshape target for training
        let y_reshaped = y.clone().into_shape((y.len(), 1)).unwrap();

        // Train QNN
        qnn.train(X, &y_reshaped, self.max_iter, self.learning_rate)?;

        self.qnn = Some(qnn);
        self.fitted = true;

        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert(
            "hidden_layer_sizes".to_string(),
            format!("{:?}", self.hidden_layer_sizes),
        );
        params.insert("activation".to_string(), self.activation.clone());
        params.insert("solver".to_string(), self.solver.clone());
        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
        params.insert("max_iter".to_string(), self.max_iter.to_string());
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "learning_rate" => {
                    self.learning_rate = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
                    })?;
                }
                "max_iter" => {
                    self.max_iter = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
                    })?;
                }
                "activation" => {
                    self.activation = value;
                }
                "solver" => {
                    self.solver = value;
                }
                _ => {
                    // Skip unknown parameters
                }
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl SklearnRegressor for QuantumMLPRegressor {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let qnn = self.qnn.as_ref().unwrap();
        let predictions = qnn.predict_batch(X)?;

        // Extract single column for regression
        Ok(predictions.column(0).to_owned())
    }
}

/// Quantum K-Means (sklearn-compatible)
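///
/// # Example
///
/// Sketch of unsupervised use (an `ignore` block): `fit` takes `None` for the
/// labels, and `fit_predict` is the provided method from `SklearnClusterer`.
///
/// ```ignore
/// use scirs2_core::ndarray::Array;
///
/// let mut kmeans = QuantumKMeans::new(2).set_max_iter(50).set_tol(1e-4);
///
/// let x = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.1, 1.1, -1.0, -1.0, -1.1, -1.1]).unwrap();
///
/// // Cluster labels as Array1<i32>, one entry per sample.
/// let labels = kmeans.fit_predict(&x).unwrap();
/// ```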
pub struct QuantumKMeans {
    /// Internal clusterer
    clusterer: Option<QuantumClusterer>,
    /// Number of clusters
    n_clusters: usize,
    /// Maximum iterations
    max_iter: usize,
    /// Tolerance
    tol: f64,
    /// Random state
    random_state: Option<u64>,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Cluster centers
    cluster_centers_: Option<Array2<f64>>,
    /// Labels
    labels_: Option<Array1<i32>>,
}

impl QuantumKMeans {
    /// Create new Quantum K-Means
    pub fn new(n_clusters: usize) -> Self {
        Self {
            clusterer: None,
            n_clusters,
            max_iter: 300,
            tol: 1e-4,
            random_state: None,
            backend: Arc::new(StatevectorBackend::new(10)),
            fitted: false,
            cluster_centers_: None,
            labels_: None,
        }
    }

    /// Set maximum iterations
    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set tolerance
    pub fn set_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set random state
    pub fn set_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl SklearnEstimator for QuantumKMeans {
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        let config = crate::clustering::config::QuantumClusteringConfig {
            algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
            n_clusters: self.n_clusters,
            max_iterations: self.max_iter,
            tolerance: self.tol,
            num_qubits: 4,
            random_state: self.random_state,
        };
        let mut clusterer = QuantumClusterer::new(config);

        let result = clusterer.fit_predict(X)?;
        // Convert usize to i32 for sklearn compatibility
        let result_i32 = result.mapv(|x| x as i32);
        self.labels_ = Some(result_i32);
        self.cluster_centers_ = None; // TODO: Get cluster centers from clusterer

        self.clusterer = Some(clusterer);
        self.fitted = true;

        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
        params.insert("max_iter".to_string(), self.max_iter.to_string());
        params.insert("tol".to_string(), self.tol.to_string());
        if let Some(rs) = self.random_state {
            params.insert("random_state".to_string(), rs.to_string());
        }
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "n_clusters" => {
                    self.n_clusters = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
                    })?;
                }
                "max_iter" => {
                    self.max_iter = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
                    })?;
                }
                "tol" => {
                    self.tol = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
                    })?;
                }
                "random_state" => {
                    self.random_state = Some(value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
                    })?);
                }
                _ => {
                    // Skip unknown parameters
                }
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl SklearnClusterer for QuantumKMeans {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }

        let clusterer = self.clusterer.as_ref().unwrap();
        let result = clusterer.predict(X)?;
        // Convert usize to i32 for sklearn compatibility
        Ok(result.mapv(|x| x as i32))
    }

    fn cluster_centers(&self) -> Option<&Array2<f64>> {
        self.cluster_centers_.as_ref()
    }
}

/// Model selection utilities (sklearn-compatible)
pub mod model_selection {
    use super::*;
    use scirs2_core::random::prelude::*;

    /// Cross-validation score
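    ///
    /// # Example
    ///
    /// Sketch of k-fold evaluation of a classifier (an `ignore` block; `x` and
    /// `y` stand in for any dataset with enough samples for the chosen `cv`):
    ///
    /// ```ignore
    /// let mut clf = QuantumSVC::new();
    /// // 3-fold cross-validation; labels are passed as f64 and rounded internally.
    /// let scores = cross_val_score(&mut clf, &x, &y, 3)?;
    /// let mean_accuracy = scores.mean().unwrap_or(0.0);
    /// ```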
    pub fn cross_val_score<E>(
        estimator: &mut E,
        X: &Array2<f64>,
        y: &Array1<f64>,
        cv: usize,
    ) -> Result<Array1<f64>>
    where
        E: SklearnClassifier,
    {
        let n_samples = X.nrows();
        let fold_size = n_samples / cv;
        let mut scores = Array1::zeros(cv);

        // Create fold indices
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(&mut thread_rng());

        for fold in 0..cv {
            let start_test = fold * fold_size;
            let end_test = if fold == cv - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            // Create train/test splits
            let test_indices = &indices[start_test..end_test];
            let train_indices: Vec<usize> = indices
                .iter()
                .enumerate()
                .filter(|(i, _)| *i < start_test || *i >= end_test)
                .map(|(_, &idx)| idx)
                .collect();

            // Extract train/test data
            let X_train = X.select(Axis(0), &train_indices);
            let y_train = y.select(Axis(0), &train_indices);
            let X_test = X.select(Axis(0), test_indices);
            let y_test = y.select(Axis(0), test_indices);

            // Convert test labels to i32 for scoring (fit takes the f64 labels directly)
            let y_test_int = y_test.mapv(|x| x.round() as i32);

            // Train and evaluate
            estimator.fit(&X_train, Some(&y_train))?;
            scores[fold] = estimator.score(&X_test, &y_test_int)?;
        }

        Ok(scores)
    }

    /// Train-test split
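    ///
    /// # Example
    ///
    /// Sketch mirroring the unit test at the bottom of this file (an `ignore`
    /// block): a 30% hold-out split with a fixed seed for reproducibility.
    ///
    /// ```ignore
    /// let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.3, Some(42))?;
    /// assert_eq!(x_train.nrows() + x_test.nrows(), x.nrows());
    /// ```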
    pub fn train_test_split(
        X: &Array2<f64>,
        y: &Array1<f64>,
        test_size: f64,
        random_state: Option<u64>,
    ) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>)> {
        let n_samples = X.nrows();
        let n_test = (n_samples as f64 * test_size).round() as usize;

        // Create indices
        let mut indices: Vec<usize> = (0..n_samples).collect();

        if let Some(seed) = random_state {
            use scirs2_core::random::prelude::*;
            let mut rng = StdRng::seed_from_u64(seed);
            indices.shuffle(&mut rng);
        } else {
            indices.shuffle(&mut thread_rng());
        }

        let test_indices = &indices[..n_test];
        let train_indices = &indices[n_test..];

        let X_train = X.select(Axis(0), train_indices);
        let X_test = X.select(Axis(0), test_indices);
        let y_train = y.select(Axis(0), train_indices);
        let y_test = y.select(Axis(0), test_indices);

        Ok((X_train, X_test, y_train, y_test))
    }

    /// Grid search for hyperparameter tuning
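    ///
    /// # Example
    ///
    /// Sketch of an exhaustive search over string-encoded parameters (an
    /// `ignore` block); the grid values shown are illustrative assumptions.
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut param_grid = HashMap::new();
    /// param_grid.insert("C".to_string(), vec!["0.1".to_string(), "1.0".to_string()]);
    ///
    /// let mut search = GridSearchCV::new(QuantumSVC::new(), param_grid, 3);
    /// search.fit(&x, &y)?;
    /// println!("best params: {:?}, best score: {}", search.best_params(), search.best_score());
    /// let predictions = search.predict(&x)?;
    /// ```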
    pub struct GridSearchCV<E> {
        /// Base estimator
        estimator: E,
        /// Parameter grid
        param_grid: HashMap<String, Vec<String>>,
        /// Cross-validation folds
        cv: usize,
        /// Best parameters
        pub best_params_: HashMap<String, String>,
        /// Best score
        pub best_score_: f64,
        /// Best estimator
        pub best_estimator_: E,
        /// Fitted flag
        fitted: bool,
    }

    impl<E> GridSearchCV<E>
    where
        E: SklearnClassifier + Clone,
    {
        /// Create new grid search
        pub fn new(estimator: E, param_grid: HashMap<String, Vec<String>>, cv: usize) -> Self {
            Self {
                best_estimator_: estimator.clone(),
                estimator,
                param_grid,
                cv,
                best_params_: HashMap::new(),
                best_score_: f64::NEG_INFINITY,
                fitted: false,
            }
        }

        /// Fit grid search
        pub fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
            let param_combinations = self.generate_param_combinations();

            for params in param_combinations {
                let mut estimator = self.estimator.clone();
                estimator.set_params(params.clone())?;

                let scores = cross_val_score(&mut estimator, X, y, self.cv)?;
                let mean_score = scores.mean().unwrap();

                if mean_score > self.best_score_ {
                    self.best_score_ = mean_score;
                    self.best_params_ = params.clone();
                    self.best_estimator_ = estimator;
                }
            }

            // Fit best estimator
            if !self.best_params_.is_empty() {
                self.best_estimator_.set_params(self.best_params_.clone())?;
                self.best_estimator_.fit(X, Some(y))?;
            }

            self.fitted = true;
            Ok(())
        }

        /// Generate all parameter combinations
        fn generate_param_combinations(&self) -> Vec<HashMap<String, String>> {
            let mut combinations = vec![HashMap::new()];

            for (param_name, param_values) in &self.param_grid {
                let mut new_combinations = Vec::new();

                for combination in &combinations {
                    for value in param_values {
                        let mut new_combination = combination.clone();
                        new_combination.insert(param_name.clone(), value.clone());
                        new_combinations.push(new_combination);
                    }
                }

                combinations = new_combinations;
            }

            combinations
        }

        /// Get best parameters
        pub fn best_params(&self) -> &HashMap<String, String> {
            &self.best_params_
        }

        /// Get best score
        pub fn best_score(&self) -> f64 {
            self.best_score_
        }

        /// Predict with best estimator
        pub fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
            if !self.fitted {
                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
            }
            self.best_estimator_.predict(X)
        }
    }
}

/// Standard Scaler (sklearn-compatible)
pub struct StandardScaler {
    mean_: Option<Array1<f64>>,
    scale_: Option<Array1<f64>>,
    fitted: bool,
}

impl StandardScaler {
    pub fn new() -> Self {
        Self {
            mean_: None,
            scale_: None,
            fitted: false,
        }
    }
}

impl SklearnEstimator for StandardScaler {
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        let mean = X.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
        let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);

        self.mean_ = Some(mean);
        self.scale_ = Some(std);
        self.fitted = true;

        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

/// Select K Best features (sklearn-compatible)
pub struct SelectKBest {
    score_func: String,
    k: usize,
    fitted: bool,
    selected_features_: Option<Vec<usize>>,
}

impl SelectKBest {
    pub fn new(score_func: &str, k: usize) -> Self {
        Self {
            score_func: score_func.to_string(),
            k,
            fitted: false,
            selected_features_: None,
        }
    }
}

impl SklearnEstimator for SelectKBest {
    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        // Mock implementation - select first k features
        let features: Vec<usize> = (0..self.k.min(X.ncols())).collect();
        self.selected_features_ = Some(features);
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("score_func".to_string(), self.score_func.clone());
        params.insert("k".to_string(), self.k.to_string());
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "k" => {
                    self.k = value.parse().map_err(|_| {
                        MLError::InvalidConfiguration(format!("Invalid k parameter: {}", value))
                    })?;
                }
                "score_func" => {
                    self.score_func = value;
                }
                _ => {}
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

/// Quantum Feature Encoder (sklearn-compatible)
pub struct QuantumFeatureEncoder {
    encoding_type: String,
    normalization: String,
    fitted: bool,
}

impl QuantumFeatureEncoder {
    pub fn new(encoding_type: &str, normalization: &str) -> Self {
        Self {
            encoding_type: encoding_type.to_string(),
            normalization: normalization.to_string(),
            fitted: false,
        }
    }
}

impl SklearnEstimator for QuantumFeatureEncoder {
    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        let mut params = HashMap::new();
        params.insert("encoding_type".to_string(), self.encoding_type.clone());
        params.insert("normalization".to_string(), self.normalization.clone());
        params
    }

    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
        for (key, value) in params {
            match key.as_str() {
                "encoding_type" => {
                    self.encoding_type = value;
                }
                "normalization" => {
                    self.normalization = value;
                }
                _ => {}
            }
        }
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

/// Simple Pipeline implementation
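///
/// # Example
///
/// Sketch of chaining named estimator steps (an `ignore` block). The
/// encoding/normalization strings passed to `QuantumFeatureEncoder` are
/// illustrative placeholders, and note that this `Pipeline` type currently
/// uses mock fit/predict logic; `QuantumPipeline` in the `pipeline` module
/// actually threads data through its steps.
///
/// ```ignore
/// let pipeline = Pipeline::new(vec![
///     ("scaler", Box::new(StandardScaler::new()) as Box<dyn SklearnEstimator>),
///     ("encoder", Box::new(QuantumFeatureEncoder::new("angle", "minmax"))),
/// ])?;
/// let step_names = pipeline.named_steps();
/// ```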
pub struct Pipeline {
    steps: Vec<(String, Box<dyn SklearnEstimator>)>,
    fitted: bool,
}

impl Pipeline {
    pub fn new(steps: Vec<(&str, Box<dyn SklearnEstimator>)>) -> Result<Self> {
        let steps = steps
            .into_iter()
            .map(|(name, estimator)| (name.to_string(), estimator))
            .collect();
        Ok(Self {
            steps,
            fitted: false,
        })
    }

    pub fn named_steps(&self) -> Vec<&String> {
        self.steps.iter().map(|(name, _)| name).collect()
    }
}

impl Clone for Pipeline {
    fn clone(&self) -> Self {
        // Boxed estimator steps cannot be cloned, so return an empty, unfitted pipeline
        Self {
            steps: Vec::new(),
            fitted: false,
        }
    }
}

impl SklearnEstimator for Pipeline {
    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        // Mock implementation
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl SklearnClassifier for Pipeline {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        // Mock predictions
        Ok(Array1::from_shape_fn(X.nrows(), |i| {
            if i % 2 == 0 {
                1
            } else {
                0
            }
        }))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        Ok(Array2::from_shape_fn((X.nrows(), 2), |(i, j)| {
            if j == 0 {
                0.4
            } else {
                0.6
            }
        }))
    }

    fn classes(&self) -> &[i32] {
        &[0, 1]
    }

    fn feature_importances(&self) -> Option<Array1<f64>> {
        Some(Array1::from_vec(vec![0.25, 0.35, 0.20, 0.20]))
    }

    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }
}

impl Pipeline {
    pub fn load(_path: &str) -> Result<Self> {
        Self::new(vec![])
    }
}

/// Pipeline utilities (sklearn-compatible)
pub mod pipeline {
    use super::*;

    /// Transformer trait
    pub trait SklearnTransformer: Send + Sync {
        /// Fit transformer
        fn fit(&mut self, X: &Array2<f64>) -> Result<()>;

        /// Transform data
        fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>>;

        /// Fit and transform
        fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
            self.fit(X)?;
            self.transform(X)
        }
    }

    /// Quantum feature scaler
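    ///
    /// # Example
    ///
    /// Sketch of standalone use via the `SklearnTransformer` trait (an `ignore`
    /// block): each feature column is centered and scaled to unit variance.
    ///
    /// ```ignore
    /// let mut scaler = QuantumStandardScaler::new();
    /// let x_scaled = scaler.fit_transform(&x)?;
    /// ```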
    pub struct QuantumStandardScaler {
        /// Feature means
        mean_: Option<Array1<f64>>,
        /// Feature standard deviations
        scale_: Option<Array1<f64>>,
        /// Fitted flag
        fitted: bool,
    }

    impl QuantumStandardScaler {
        /// Create new scaler
        pub fn new() -> Self {
            Self {
                mean_: None,
                scale_: None,
                fitted: false,
            }
        }
    }

    impl SklearnTransformer for QuantumStandardScaler {
        fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
            let mean = X.mean_axis(Axis(0)).unwrap();
            let std = X.std_axis(Axis(0), 0.0);

            self.mean_ = Some(mean);
            self.scale_ = Some(std);
            self.fitted = true;

            Ok(())
        }

        fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
            if !self.fitted {
                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
            }

            let mean = self.mean_.as_ref().unwrap();
            let scale = self.scale_.as_ref().unwrap();

            let mut X_scaled = X.clone();
            for mut row in X_scaled.axis_iter_mut(Axis(0)) {
                row -= mean;
                row /= scale;
            }

            Ok(X_scaled)
        }
    }

    /// Quantum pipeline
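    ///
    /// # Example
    ///
    /// Sketch of a scale-then-classify pipeline (an `ignore` block): the final
    /// estimator step determines the prediction output, which is returned as a
    /// dynamic-dimensional `ArrayD<f64>`.
    ///
    /// ```ignore
    /// let mut pipeline = QuantumPipeline::new()
    ///     .add_transformer("scaler".to_string(), Box::new(QuantumStandardScaler::new()))
    ///     .add_classifier("svc".to_string(), Box::new(QuantumSVC::new()));
    ///
    /// pipeline.fit(&x, Some(&y))?;
    /// let predictions = pipeline.predict(&x)?;
    /// ```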
    pub struct QuantumPipeline {
        /// Pipeline steps
        steps: Vec<(String, PipelineStep)>,
        /// Fitted flag
        fitted: bool,
    }

    /// Pipeline step enum
    pub enum PipelineStep {
        /// Transformer step
        Transformer(Box<dyn SklearnTransformer>),
        /// Classifier step
        Classifier(Box<dyn SklearnClassifier>),
        /// Regressor step
        Regressor(Box<dyn SklearnRegressor>),
        /// Clusterer step
        Clusterer(Box<dyn SklearnClusterer>),
    }

    impl QuantumPipeline {
        /// Create new pipeline
        pub fn new() -> Self {
            Self {
                steps: Vec::new(),
                fitted: false,
            }
        }

        /// Add transformer step
        pub fn add_transformer(
            mut self,
            name: String,
            transformer: Box<dyn SklearnTransformer>,
        ) -> Self {
            self.steps
                .push((name, PipelineStep::Transformer(transformer)));
            self
        }

        /// Add classifier step
        pub fn add_classifier(
            mut self,
            name: String,
            classifier: Box<dyn SklearnClassifier>,
        ) -> Self {
            self.steps
                .push((name, PipelineStep::Classifier(classifier)));
            self
        }

        /// Fit pipeline
        pub fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
            let mut current_X = X.clone();

            for (_name, step) in &mut self.steps {
                match step {
                    PipelineStep::Transformer(transformer) => {
                        current_X = transformer.fit_transform(&current_X)?;
                    }
                    PipelineStep::Classifier(classifier) => {
                        classifier.fit(&current_X, y)?;
                    }
                    PipelineStep::Regressor(regressor) => {
                        regressor.fit(&current_X, y)?;
                    }
                    PipelineStep::Clusterer(clusterer) => {
                        clusterer.fit(&current_X, y)?;
                    }
                }
            }

            self.fitted = true;
            Ok(())
        }

        /// Predict with pipeline
        pub fn predict(&self, X: &Array2<f64>) -> Result<ArrayD<f64>> {
            if !self.fitted {
                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
            }

            let mut current_X = X.clone();

            for (_name, step) in &self.steps {
                match step {
                    PipelineStep::Transformer(transformer) => {
                        current_X = transformer.transform(&current_X)?;
                    }
                    PipelineStep::Classifier(classifier) => {
                        let predictions = classifier.predict(&current_X)?;
                        let predictions_f64 = predictions.mapv(|x| x as f64);
                        return Ok(predictions_f64.into_dyn());
                    }
                    PipelineStep::Regressor(regressor) => {
                        let predictions = regressor.predict(&current_X)?;
                        return Ok(predictions.into_dyn());
                    }
                    PipelineStep::Clusterer(clusterer) => {
                        let predictions = clusterer.predict(&current_X)?;
                        let predictions_f64 = predictions.mapv(|x| x as f64);
                        return Ok(predictions_f64.into_dyn());
                    }
                }
            }

            Ok(current_X.into_dyn())
        }
    }
}

/// Metrics module (sklearn-compatible)
pub mod metrics {
    use super::*;

    /// Calculate accuracy score
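    ///
    /// # Example
    ///
    /// Small worked example: 3 of 4 labels match, so the accuracy is 0.75.
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let y_true = Array1::from_vec(vec![1, 0, 1, 1]);
    /// let y_pred = Array1::from_vec(vec![1, 0, 0, 1]);
    /// assert!((accuracy_score(&y_true, &y_pred) - 0.75).abs() < 1e-12);
    /// ```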
    pub fn accuracy_score(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> f64 {
        let correct = y_true
            .iter()
            .zip(y_pred.iter())
            .filter(|(&true_val, &pred_val)| true_val == pred_val)
            .count();
        correct as f64 / y_true.len() as f64
    }

    /// Calculate precision score
    pub fn precision_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
        // Mock implementation
        0.85
    }

    /// Calculate recall score
    pub fn recall_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
        // Mock implementation
        0.82
    }

    /// Calculate F1 score
    pub fn f1_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
        // Mock implementation
        0.83
    }

    /// Generate classification report
    pub fn classification_report(
        y_true: &Array1<i32>,
        y_pred: &Array1<i32>,
        target_names: Vec<&str>,
        digits: usize,
    ) -> String {
        format!("Classification Report\n==================\n{:>10} {:>10} {:>10} {:>10} {:>10}\n{:>10} {:>10.digits$} {:>10.digits$} {:>10.digits$} {:>10}\n{:>10} {:>10.digits$} {:>10.digits$} {:>10.digits$} {:>10}\n",
            "", "precision", "recall", "f1-score", "support",
            target_names[0], 0.85, 0.82, 0.83, 50,
            target_names[1], 0.87, 0.85, 0.86, 50,
            digits = digits)
    }

    /// Calculate silhouette score
    pub fn silhouette_score(X: &Array2<f64>, labels: &Array1<i32>, _metric: &str) -> f64 {
        // Mock implementation
        0.65
    }

    /// Calculate Calinski-Harabasz score
    pub fn calinski_harabasz_score(X: &Array2<f64>, labels: &Array1<i32>) -> f64 {
        // Mock implementation
        150.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    #[ignore]
    fn test_quantum_svc() {
        let mut svc = QuantumSVC::new().set_C(1.0).set_gamma(0.1);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0])
            .unwrap();
        let y = Array::from_vec(vec![1.0, -1.0, -1.0, 1.0]);

        assert!(svc.fit(&X, Some(&y)).is_ok());
        assert!(svc.is_fitted());

        let predictions = svc.predict(&X);
        assert!(predictions.is_ok());
    }

    #[test]
    #[ignore]
    fn test_quantum_mlp_classifier() {
        let mut mlp = QuantumMLPClassifier::new()
            .set_hidden_layer_sizes(vec![5])
            .set_learning_rate(0.01)
            .set_max_iter(10);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0])
            .unwrap();
        let y = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]);

        assert!(mlp.fit(&X, Some(&y)).is_ok());
        assert!(mlp.is_fitted());

        let predictions = mlp.predict(&X);
        assert!(predictions.is_ok());

        let probas = mlp.predict_proba(&X);
        assert!(probas.is_ok());
    }

    #[test]
    #[ignore]
    fn test_quantum_kmeans() {
        let mut kmeans = QuantumKMeans::new(2).set_max_iter(50).set_tol(1e-4);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.1, 1.1, -1.0, -1.0, -1.1, -1.1])
            .unwrap();

        assert!(kmeans.fit(&X, None).is_ok());
        assert!(kmeans.is_fitted());

        let predictions = kmeans.predict(&X);
        assert!(predictions.is_ok());

        assert!(kmeans.cluster_centers().is_some());
    }

    #[test]
    fn test_model_selection() {
        use model_selection::train_test_split;

        let X = Array::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
        let y = Array::from_vec((0..10).map(|x| x as f64).collect());

        let (X_train, X_test, y_train, y_test) = train_test_split(&X, &y, 0.3, Some(42)).unwrap();

        assert_eq!(X_train.nrows() + X_test.nrows(), X.nrows());
        assert_eq!(y_train.len() + y_test.len(), y.len());
    }

    #[test]
    fn test_pipeline() {
        use pipeline::{QuantumPipeline, QuantumStandardScaler};

        let mut pipeline = QuantumPipeline::new()
            .add_transformer("scaler".to_string(), Box::new(QuantumStandardScaler::new()));

        let X =
            Array::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();

        assert!(pipeline.fit(&X, None).is_ok());

        let transformed = pipeline.predict(&X);
        assert!(transformed.is_ok());
    }
}