quantrs2_ml/
sklearn_compatibility.rs

1//! Scikit-learn compatibility layer for QuantRS2-ML
2//!
3//! This module provides a compatibility layer that mimics scikit-learn APIs,
4//! allowing easy integration of quantum ML models with existing scikit-learn
5//! workflows and pipelines.
6
7use crate::classification::{ClassificationMetrics, Classifier};
8use crate::clustering::{ClusteringAlgorithm, QuantumClusterer};
9use crate::error::{MLError, Result};
10use crate::qnn::{QNNBuilder, QuantumNeuralNetwork};
11use crate::qsvm::{FeatureMapType, QSVMParams, QSVM};
12use crate::simulator_backends::{
13    Backend, BackendCapabilities, SimulatorBackend, StatevectorBackend,
14};
15use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis};
16use scirs2_core::SliceRandomExt;
17use std::collections::HashMap;
18use std::sync::Arc;
19
20/// Base estimator trait following scikit-learn conventions
21pub trait SklearnEstimator: Send + Sync {
22    /// Fit the model to training data
23    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()>;
24
25    /// Get model parameters
26    fn get_params(&self) -> HashMap<String, String>;
27
28    /// Set model parameters
29    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()>;
30
31    /// Check if model is fitted
32    fn is_fitted(&self) -> bool;
33
34    /// Get feature names
35    fn get_feature_names_out(&self) -> Vec<String> {
36        vec![]
37    }
38}
39
40/// Classifier mixin trait
41pub trait SklearnClassifier: SklearnEstimator {
42    /// Predict class labels
43    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
44
45    /// Predict class probabilities
46    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>>;
47
48    /// Get unique class labels
49    fn classes(&self) -> &[i32];
50
51    /// Score the model (accuracy by default)
52    fn score(&self, X: &Array2<f64>, y: &Array1<i32>) -> Result<f64> {
53        let predictions = self.predict(X)?;
54        let correct = predictions
55            .iter()
56            .zip(y.iter())
57            .filter(|(&pred, &true_label)| pred == true_label)
58            .count();
59        Ok(correct as f64 / y.len() as f64)
60    }
61
62    /// Get feature importances (optional)
63    fn feature_importances(&self) -> Option<Array1<f64>> {
64        None
65    }
66
67    /// Save model to file (optional)
68    fn save(&self, _path: &str) -> Result<()> {
69        Ok(())
70    }
71}
72
73/// Regressor mixin trait
74pub trait SklearnRegressor: SklearnEstimator {
75    /// Predict continuous values
76    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>>;
77
78    /// Score the model (R² by default)
79    fn score(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
80        let predictions = self.predict(X)?;
81        let y_mean = y.mean().unwrap_or(0.0);
82
83        let ss_res: f64 = y
84            .iter()
85            .zip(predictions.iter())
86            .map(|(&true_val, &pred)| (true_val - pred).powi(2))
87            .sum();
88
89        let ss_tot: f64 = y.iter().map(|&val| (val - y_mean).powi(2)).sum();
90
91        Ok(1.0 - ss_res / ss_tot)
92    }
93}
94
/// Extension trait for fitting with Array1<f64> directly
///
/// Unlike `SklearnEstimator::fit`, the targets here are mandatory rather than `Option`.
/// NOTE(review): a type implementing both this trait and `SklearnEstimator` has two
/// methods named `fit`; callers with both traits in scope need fully-qualified syntax.
pub trait SklearnFit {
    /// Fit the model to `X` with the required targets `y`.
    fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()>;
}
99
100/// Clusterer mixin trait
101pub trait SklearnClusterer: SklearnEstimator {
102    /// Predict cluster labels
103    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>>;
104
105    /// Fit and predict in one step
106    fn fit_predict(&mut self, X: &Array2<f64>) -> Result<Array1<i32>> {
107        self.fit(X, None)?;
108        self.predict(X)
109    }
110
111    /// Get cluster centers (if applicable)
112    fn cluster_centers(&self) -> Option<&Array2<f64>> {
113        None
114    }
115}
116
/// Quantum Support Vector Machine (sklearn-compatible)
///
/// Hyperparameters are set with the builder-style setters (`set_C`, `set_gamma`,
/// `set_kernel`, `set_backend`) or through `SklearnEstimator::set_params`. The
/// underlying `QSVM` is constructed by `fit`.
pub struct QuantumSVC {
    /// Internal QSVM; `None` until `fit` succeeds.
    qsvm: Option<QSVM>,
    /// SVM parameters; `fit` overwrites `feature_map` and `regularization` from the fields below.
    params: QSVMParams,
    /// Feature map type (exposed as the `kernel` parameter).
    feature_map: FeatureMapType,
    /// Backend
    // NOTE(review): not read by `fit` or `predict` in the code visible here — confirm whether
    // it is meant to be forwarded to the QSVM.
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Sorted unique class labels observed by `fit`.
    classes: Vec<i32>,
    /// Regularization parameter (copied into `params.regularization` by `fit`).
    C: f64,
    /// Kernel gamma parameter
    // NOTE(review): stored and reported by `get_params`, but `fit` never copies it into
    // `QSVMParams`, so it currently has no effect on training.
    gamma: f64,
}
136
137impl Clone for QuantumSVC {
138    fn clone(&self) -> Self {
139        Self {
140            qsvm: None, // Reset QSVM since it's not cloneable
141            params: self.params.clone(),
142            feature_map: self.feature_map,
143            backend: self.backend.clone(),
144            fitted: false, // Reset fitted status
145            classes: self.classes.clone(),
146            C: self.C,
147            gamma: self.gamma,
148        }
149    }
150}
151
152impl QuantumSVC {
153    /// Create new Quantum SVC
154    pub fn new() -> Self {
155        Self {
156            qsvm: None,
157            params: QSVMParams::default(),
158            feature_map: FeatureMapType::ZZFeatureMap,
159            backend: Arc::new(StatevectorBackend::new(10)),
160            fitted: false,
161            classes: Vec::new(),
162            C: 1.0,
163            gamma: 1.0,
164        }
165    }
166
167    /// Set regularization parameter
168    pub fn set_C(mut self, C: f64) -> Self {
169        self.C = C;
170        self
171    }
172
173    /// Set kernel gamma parameter
174    pub fn set_gamma(mut self, gamma: f64) -> Self {
175        self.gamma = gamma;
176        self
177    }
178
179    /// Set feature map
180    pub fn set_kernel(mut self, feature_map: FeatureMapType) -> Self {
181        self.feature_map = feature_map;
182        self
183    }
184
185    /// Set quantum backend
186    pub fn set_backend(mut self, backend: Arc<dyn SimulatorBackend>) -> Self {
187        self.backend = backend;
188        self
189    }
190
191    /// Load model from file (mock implementation)
192    pub fn load(_path: &str) -> Result<Self> {
193        Ok(Self::new())
194    }
195}
196
197impl SklearnEstimator for QuantumSVC {
198    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
199        let y = y.ok_or_else(|| {
200            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
201        })?;
202
203        // Convert continuous labels to integer classes
204        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);
205
206        // Find unique classes
207        let mut classes = Vec::new();
208        for &label in y_int.iter() {
209            if !classes.contains(&label) {
210                classes.push(label);
211            }
212        }
213        classes.sort();
214        self.classes = classes;
215
216        // Update QSVM parameters
217        self.params.feature_map = self.feature_map;
218        self.params.regularization = self.C;
219
220        // Create and train QSVM
221        let mut qsvm = QSVM::new(self.params.clone());
222        qsvm.fit(X, &y_int)?;
223
224        self.qsvm = Some(qsvm);
225        self.fitted = true;
226
227        Ok(())
228    }
229
230    fn get_params(&self) -> HashMap<String, String> {
231        let mut params = HashMap::new();
232        params.insert("C".to_string(), self.C.to_string());
233        params.insert("gamma".to_string(), self.gamma.to_string());
234        params.insert("kernel".to_string(), format!("{:?}", self.feature_map));
235        params
236    }
237
238    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
239        for (key, value) in params {
240            match key.as_str() {
241                "C" => {
242                    self.C = value.parse().map_err(|_| {
243                        MLError::InvalidConfiguration(format!("Invalid C parameter: {}", value))
244                    })?;
245                }
246                "gamma" => {
247                    self.gamma = value.parse().map_err(|_| {
248                        MLError::InvalidConfiguration(format!("Invalid gamma parameter: {}", value))
249                    })?;
250                }
251                "kernel" => {
252                    self.feature_map = match value.as_str() {
253                        "ZZFeatureMap" => FeatureMapType::ZZFeatureMap,
254                        "ZFeatureMap" => FeatureMapType::ZFeatureMap,
255                        "PauliFeatureMap" => FeatureMapType::PauliFeatureMap,
256                        _ => {
257                            return Err(MLError::InvalidConfiguration(format!(
258                                "Unknown kernel: {}",
259                                value
260                            )))
261                        }
262                    };
263                }
264                _ => {
265                    return Err(MLError::InvalidConfiguration(format!(
266                        "Unknown parameter: {}",
267                        key
268                    )))
269                }
270            }
271        }
272        Ok(())
273    }
274
275    fn is_fitted(&self) -> bool {
276        self.fitted
277    }
278}
279
280impl SklearnClassifier for QuantumSVC {
281    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
282        if !self.fitted {
283            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
284        }
285
286        let qsvm = self
287            .qsvm
288            .as_ref()
289            .ok_or_else(|| MLError::ModelNotTrained("QSVM model not initialized".to_string()))?;
290        qsvm.predict(X).map_err(MLError::ValidationError)
291    }
292
293    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
294        if !self.fitted {
295            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
296        }
297
298        let predictions = self.predict(X)?;
299        let n_samples = X.nrows();
300        let n_classes = self.classes.len();
301
302        let mut probabilities = Array2::zeros((n_samples, n_classes));
303
304        // Convert hard predictions to probabilities (placeholder)
305        for (i, &prediction) in predictions.iter().enumerate() {
306            for (j, &class) in self.classes.iter().enumerate() {
307                probabilities[[i, j]] = if prediction == class { 1.0 } else { 0.0 };
308            }
309        }
310
311        Ok(probabilities)
312    }
313
314    fn classes(&self) -> &[i32] {
315        &self.classes
316    }
317}
318
/// Quantum Neural Network Classifier (sklearn-compatible)
///
/// Configured through the builder-style setters or `SklearnEstimator::set_params`;
/// the underlying `QuantumNeuralNetwork` is built by `fit`.
pub struct QuantumMLPClassifier {
    /// Internal QNN; `None` until `fit` succeeds.
    qnn: Option<QuantumNeuralNetwork>,
    /// Network configuration: sizes of the hidden layers, in order.
    hidden_layer_sizes: Vec<usize>,
    /// Activation function
    // NOTE(review): `activation`, `solver`, `random_state` and `backend` are stored (and some
    // reported by `get_params`) but are not passed to `QNNBuilder` or `train` in the code
    // visible here — confirm whether they should be.
    activation: String,
    /// Solver
    solver: String,
    /// Learning rate (passed to `QuantumNeuralNetwork::train`).
    learning_rate: f64,
    /// Maximum iterations (passed to `QuantumNeuralNetwork::train`).
    max_iter: usize,
    /// Random state
    random_state: Option<u64>,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Sorted unique class labels observed by `fit`; column `j` of the one-hot targets
    /// corresponds to `classes[j]`.
    classes: Vec<i32>,
}
342
343impl QuantumMLPClassifier {
344    /// Create new Quantum MLP Classifier
345    pub fn new() -> Self {
346        Self {
347            qnn: None,
348            hidden_layer_sizes: vec![10],
349            activation: "relu".to_string(),
350            solver: "adam".to_string(),
351            learning_rate: 0.001,
352            max_iter: 200,
353            random_state: None,
354            backend: Arc::new(StatevectorBackend::new(10)),
355            fitted: false,
356            classes: Vec::new(),
357        }
358    }
359
360    /// Set hidden layer sizes
361    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
362        self.hidden_layer_sizes = sizes;
363        self
364    }
365
366    /// Set activation function
367    pub fn set_activation(mut self, activation: String) -> Self {
368        self.activation = activation;
369        self
370    }
371
372    /// Set learning rate
373    pub fn set_learning_rate(mut self, lr: f64) -> Self {
374        self.learning_rate = lr;
375        self
376    }
377
378    /// Set maximum iterations
379    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
380        self.max_iter = max_iter;
381        self
382    }
383}
384
385impl SklearnEstimator for QuantumMLPClassifier {
386    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
387        let y = y.ok_or_else(|| {
388            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
389        })?;
390
391        // Convert continuous labels to integer classes
392        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);
393
394        // Find unique classes
395        let mut classes = Vec::new();
396        for &label in y_int.iter() {
397            if !classes.contains(&label) {
398                classes.push(label);
399            }
400        }
401        classes.sort();
402        self.classes = classes;
403
404        // Build QNN
405        let input_size = X.ncols();
406        let output_size = self.classes.len();
407
408        let mut builder = QNNBuilder::new();
409
410        // Add hidden layers
411        for &size in &self.hidden_layer_sizes {
412            builder = builder.add_layer(size);
413        }
414
415        // Add output layer
416        builder = builder.add_layer(output_size);
417
418        let mut qnn = builder.build()?;
419
420        // Train QNN
421        let y_one_hot = self.to_one_hot(&y_int)?;
422        qnn.train(X, &y_one_hot, self.max_iter, self.learning_rate)?;
423
424        self.qnn = Some(qnn);
425        self.fitted = true;
426
427        Ok(())
428    }
429
430    fn get_params(&self) -> HashMap<String, String> {
431        let mut params = HashMap::new();
432        params.insert(
433            "hidden_layer_sizes".to_string(),
434            format!("{:?}", self.hidden_layer_sizes),
435        );
436        params.insert("activation".to_string(), self.activation.clone());
437        params.insert("solver".to_string(), self.solver.clone());
438        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
439        params.insert("max_iter".to_string(), self.max_iter.to_string());
440        params
441    }
442
443    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
444        for (key, value) in params {
445            match key.as_str() {
446                "learning_rate" => {
447                    self.learning_rate = value.parse().map_err(|_| {
448                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
449                    })?;
450                }
451                "max_iter" => {
452                    self.max_iter = value.parse().map_err(|_| {
453                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
454                    })?;
455                }
456                "activation" => {
457                    self.activation = value;
458                }
459                "solver" => {
460                    self.solver = value;
461                }
462                _ => {
463                    // Skip unknown parameters
464                }
465            }
466        }
467        Ok(())
468    }
469
470    fn is_fitted(&self) -> bool {
471        self.fitted
472    }
473}
474
475impl QuantumMLPClassifier {
476    /// Convert integer labels to one-hot encoding
477    fn to_one_hot(&self, y: &Array1<i32>) -> Result<Array2<f64>> {
478        let n_samples = y.len();
479        let n_classes = self.classes.len();
480        let mut one_hot = Array2::zeros((n_samples, n_classes));
481
482        for (i, &label) in y.iter().enumerate() {
483            if let Some(class_idx) = self.classes.iter().position(|&c| c == label) {
484                one_hot[[i, class_idx]] = 1.0;
485            }
486        }
487
488        Ok(one_hot)
489    }
490
491    /// Convert one-hot predictions to class labels
492    fn from_one_hot(&self, predictions: &Array2<f64>) -> Array1<i32> {
493        predictions
494            .axis_iter(Axis(0))
495            .map(|row| {
496                let max_idx = row
497                    .iter()
498                    .enumerate()
499                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
500                    .map(|(idx, _)| idx)
501                    .unwrap_or(0);
502                self.classes.get(max_idx).copied().unwrap_or(0)
503            })
504            .collect()
505    }
506}
507
508impl SklearnClassifier for QuantumMLPClassifier {
509    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
510        if !self.fitted {
511            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
512        }
513
514        let qnn = self
515            .qnn
516            .as_ref()
517            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
518        let predictions = qnn.predict_batch(X)?;
519        Ok(self.from_one_hot(&predictions))
520    }
521
522    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
523        if !self.fitted {
524            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
525        }
526
527        let qnn = self
528            .qnn
529            .as_ref()
530            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
531        qnn.predict_batch(X)
532    }
533
534    fn classes(&self) -> &[i32] {
535        &self.classes
536    }
537}
538
/// Quantum Regressor (sklearn-compatible)
///
/// A quantum neural network with a single output unit, configured through the
/// builder-style setters or `SklearnEstimator::set_params`.
pub struct QuantumMLPRegressor {
    /// Internal QNN; `None` until `fit` succeeds.
    qnn: Option<QuantumNeuralNetwork>,
    /// Network configuration: sizes of the hidden layers, in order.
    hidden_layer_sizes: Vec<usize>,
    /// Activation function
    // NOTE(review): `activation`, `solver`, `random_state` and `backend` are stored (and some
    // reported by `get_params`) but are not passed to `QNNBuilder` or `train` in the code
    // visible here — confirm whether they should be.
    activation: String,
    /// Solver
    solver: String,
    /// Learning rate (passed to `QuantumNeuralNetwork::train`).
    learning_rate: f64,
    /// Maximum iterations (passed to `QuantumNeuralNetwork::train`).
    max_iter: usize,
    /// Random state
    random_state: Option<u64>,
    /// Backend
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
}
560
561impl QuantumMLPRegressor {
562    /// Create new Quantum MLP Regressor
563    pub fn new() -> Self {
564        Self {
565            qnn: None,
566            hidden_layer_sizes: vec![10],
567            activation: "relu".to_string(),
568            solver: "adam".to_string(),
569            learning_rate: 0.001,
570            max_iter: 200,
571            random_state: None,
572            backend: Arc::new(StatevectorBackend::new(10)),
573            fitted: false,
574        }
575    }
576
577    /// Set hidden layer sizes
578    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
579        self.hidden_layer_sizes = sizes;
580        self
581    }
582
583    /// Set learning rate
584    pub fn set_learning_rate(mut self, lr: f64) -> Self {
585        self.learning_rate = lr;
586        self
587    }
588
589    /// Set maximum iterations
590    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
591        self.max_iter = max_iter;
592        self
593    }
594}
595
596impl SklearnEstimator for QuantumMLPRegressor {
597    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
598        let y = y.ok_or_else(|| {
599            MLError::InvalidConfiguration("Target values required for regression".to_string())
600        })?;
601
602        // Build QNN for regression
603        let input_size = X.ncols();
604        let output_size = 1; // Single output for regression
605
606        let mut builder = QNNBuilder::new();
607
608        // Add hidden layers
609        for &size in &self.hidden_layer_sizes {
610            builder = builder.add_layer(size);
611        }
612
613        // Add output layer
614        builder = builder.add_layer(output_size);
615
616        let mut qnn = builder.build()?;
617
618        // Reshape target for training
619        let y_reshaped = y.clone().into_shape((y.len(), 1)).map_err(|e| {
620            MLError::InvalidConfiguration(format!("Failed to reshape target: {}", e))
621        })?;
622
623        // Train QNN
624        qnn.train(X, &y_reshaped, self.max_iter, self.learning_rate)?;
625
626        self.qnn = Some(qnn);
627        self.fitted = true;
628
629        Ok(())
630    }
631
632    fn get_params(&self) -> HashMap<String, String> {
633        let mut params = HashMap::new();
634        params.insert(
635            "hidden_layer_sizes".to_string(),
636            format!("{:?}", self.hidden_layer_sizes),
637        );
638        params.insert("activation".to_string(), self.activation.clone());
639        params.insert("solver".to_string(), self.solver.clone());
640        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
641        params.insert("max_iter".to_string(), self.max_iter.to_string());
642        params
643    }
644
645    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
646        for (key, value) in params {
647            match key.as_str() {
648                "learning_rate" => {
649                    self.learning_rate = value.parse().map_err(|_| {
650                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
651                    })?;
652                }
653                "max_iter" => {
654                    self.max_iter = value.parse().map_err(|_| {
655                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
656                    })?;
657                }
658                "activation" => {
659                    self.activation = value;
660                }
661                "solver" => {
662                    self.solver = value;
663                }
664                _ => {
665                    // Skip unknown parameters
666                }
667            }
668        }
669        Ok(())
670    }
671
672    fn is_fitted(&self) -> bool {
673        self.fitted
674    }
675}
676
677impl SklearnRegressor for QuantumMLPRegressor {
678    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>> {
679        if !self.fitted {
680            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
681        }
682
683        let qnn = self
684            .qnn
685            .as_ref()
686            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
687        let predictions = qnn.predict_batch(X)?;
688
689        // Extract single column for regression
690        Ok(predictions.column(0).to_owned())
691    }
692}
693
/// Quantum K-Means (sklearn-compatible)
///
/// Wraps `QuantumClusterer` configured with the `QuantumKMeans` algorithm.
pub struct QuantumKMeans {
    /// Internal clusterer; `None` until `fit` succeeds.
    clusterer: Option<QuantumClusterer>,
    /// Number of clusters
    n_clusters: usize,
    /// Maximum iterations (forwarded as `max_iterations`).
    max_iter: usize,
    /// Tolerance (forwarded as `tolerance`).
    tol: f64,
    /// Random state (forwarded to the clustering config).
    random_state: Option<u64>,
    /// Backend
    // NOTE(review): not read by `fit` or `predict` in the code visible here.
    backend: Arc<dyn SimulatorBackend>,
    /// Fitted flag
    fitted: bool,
    /// Cluster centers
    // NOTE(review): `fit` currently always sets this to `None` (see its TODO), so
    // `cluster_centers()` returns `None` even after fitting.
    cluster_centers_: Option<Array2<f64>>,
    /// Labels assigned to the training data by the last successful `fit`.
    labels_: Option<Array1<i32>>,
}
715
716impl QuantumKMeans {
717    /// Create new Quantum K-Means
718    pub fn new(n_clusters: usize) -> Self {
719        Self {
720            clusterer: None,
721            n_clusters,
722            max_iter: 300,
723            tol: 1e-4,
724            random_state: None,
725            backend: Arc::new(StatevectorBackend::new(10)),
726            fitted: false,
727            cluster_centers_: None,
728            labels_: None,
729        }
730    }
731
732    /// Set maximum iterations
733    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
734        self.max_iter = max_iter;
735        self
736    }
737
738    /// Set tolerance
739    pub fn set_tol(mut self, tol: f64) -> Self {
740        self.tol = tol;
741        self
742    }
743
744    /// Set random state
745    pub fn set_random_state(mut self, random_state: u64) -> Self {
746        self.random_state = Some(random_state);
747        self
748    }
749}
750
751impl SklearnEstimator for QuantumKMeans {
752    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
753        let config = crate::clustering::config::QuantumClusteringConfig {
754            algorithm: crate::clustering::config::ClusteringAlgorithm::QuantumKMeans,
755            n_clusters: self.n_clusters,
756            max_iterations: self.max_iter,
757            tolerance: self.tol,
758            num_qubits: 4,
759            random_state: self.random_state,
760        };
761        let mut clusterer = QuantumClusterer::new(config);
762
763        let result = clusterer.fit_predict(X)?;
764        // Convert usize to i32 for sklearn compatibility
765        let result_i32 = result.mapv(|x| x as i32);
766        self.labels_ = Some(result_i32);
767        self.cluster_centers_ = None; // TODO: Get cluster centers from clusterer
768
769        self.clusterer = Some(clusterer);
770        self.fitted = true;
771
772        Ok(())
773    }
774
775    fn get_params(&self) -> HashMap<String, String> {
776        let mut params = HashMap::new();
777        params.insert("n_clusters".to_string(), self.n_clusters.to_string());
778        params.insert("max_iter".to_string(), self.max_iter.to_string());
779        params.insert("tol".to_string(), self.tol.to_string());
780        if let Some(rs) = self.random_state {
781            params.insert("random_state".to_string(), rs.to_string());
782        }
783        params
784    }
785
786    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
787        for (key, value) in params {
788            match key.as_str() {
789                "n_clusters" => {
790                    self.n_clusters = value.parse().map_err(|_| {
791                        MLError::InvalidConfiguration(format!("Invalid n_clusters: {}", value))
792                    })?;
793                }
794                "max_iter" => {
795                    self.max_iter = value.parse().map_err(|_| {
796                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
797                    })?;
798                }
799                "tol" => {
800                    self.tol = value.parse().map_err(|_| {
801                        MLError::InvalidConfiguration(format!("Invalid tol: {}", value))
802                    })?;
803                }
804                "random_state" => {
805                    self.random_state = Some(value.parse().map_err(|_| {
806                        MLError::InvalidConfiguration(format!("Invalid random_state: {}", value))
807                    })?);
808                }
809                _ => {
810                    // Skip unknown parameters
811                }
812            }
813        }
814        Ok(())
815    }
816
817    fn is_fitted(&self) -> bool {
818        self.fitted
819    }
820}
821
822impl SklearnClusterer for QuantumKMeans {
823    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
824        if !self.fitted {
825            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
826        }
827
828        let clusterer = self
829            .clusterer
830            .as_ref()
831            .ok_or_else(|| MLError::ModelNotTrained("Clusterer not initialized".to_string()))?;
832        let result = clusterer.predict(X)?;
833        // Convert usize to i32 for sklearn compatibility
834        Ok(result.mapv(|x| x as i32))
835    }
836
837    fn cluster_centers(&self) -> Option<&Array2<f64>> {
838        self.cluster_centers_.as_ref()
839    }
840}
841
842/// Model selection utilities (sklearn-compatible)
843pub mod model_selection {
844    use super::*;
845    use scirs2_core::random::prelude::*;
846
847    /// Cross-validation score
848    pub fn cross_val_score<E>(
849        estimator: &mut E,
850        X: &Array2<f64>,
851        y: &Array1<f64>,
852        cv: usize,
853    ) -> Result<Array1<f64>>
854    where
855        E: SklearnClassifier,
856    {
857        let n_samples = X.nrows();
858        let fold_size = n_samples / cv;
859        let mut scores = Array1::zeros(cv);
860
861        // Create fold indices
862        let mut indices: Vec<usize> = (0..n_samples).collect();
863        indices.shuffle(&mut thread_rng());
864
865        for fold in 0..cv {
866            let start_test = fold * fold_size;
867            let end_test = if fold == cv - 1 {
868                n_samples
869            } else {
870                (fold + 1) * fold_size
871            };
872
873            // Create train/test splits
874            let test_indices = &indices[start_test..end_test];
875            let train_indices: Vec<usize> = indices
876                .iter()
877                .enumerate()
878                .filter(|(i, _)| *i < start_test || *i >= end_test)
879                .map(|(_, &idx)| idx)
880                .collect();
881
882            // Extract train/test data
883            let X_train = X.select(Axis(0), &train_indices);
884            let y_train = y.select(Axis(0), &train_indices);
885            let X_test = X.select(Axis(0), test_indices);
886            let y_test = y.select(Axis(0), test_indices);
887
888            // Convert to i32 for classification
889            let y_train_int = y_train.mapv(|x| x.round() as i32);
890            let y_test_int = y_test.mapv(|x| x.round() as i32);
891
892            // Train and evaluate
893            estimator.fit(&X_train, Some(&y_train))?;
894            scores[fold] = estimator.score(&X_test, &y_test_int)?;
895        }
896
897        Ok(scores)
898    }
899
900    /// Train-test split
901    pub fn train_test_split(
902        X: &Array2<f64>,
903        y: &Array1<f64>,
904        test_size: f64,
905        random_state: Option<u64>,
906    ) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>)> {
907        let n_samples = X.nrows();
908        let n_test = (n_samples as f64 * test_size).round() as usize;
909
910        // Create indices
911        let mut indices: Vec<usize> = (0..n_samples).collect();
912
913        if let Some(seed) = random_state {
914            use scirs2_core::random::prelude::*;
915            let mut rng = StdRng::seed_from_u64(seed);
916            indices.shuffle(&mut rng);
917        } else {
918            indices.shuffle(&mut thread_rng());
919        }
920
921        let test_indices = &indices[..n_test];
922        let train_indices = &indices[n_test..];
923
924        let X_train = X.select(Axis(0), train_indices);
925        let X_test = X.select(Axis(0), test_indices);
926        let y_train = y.select(Axis(0), train_indices);
927        let y_test = y.select(Axis(0), test_indices);
928
929        Ok((X_train, X_test, y_train, y_test))
930    }
931
    /// Grid search for hyperparameter tuning
    ///
    /// Evaluates combinations of the string-valued parameters in `param_grid` with
    /// `cross_val_score` and keeps the best-scoring configuration.
    pub struct GridSearchCV<E> {
        /// Base estimator; `fit` clones it for each candidate configuration.
        estimator: E,
        /// Parameter grid: parameter name -> candidate values (applied via `set_params`).
        param_grid: HashMap<String, Vec<String>>,
        /// Cross-validation folds
        cv: usize,
        /// Parameters of the best-scoring combination.
        pub best_params_: HashMap<String, String>,
        /// Highest mean cross-validation score seen; `new` initialises it to negative infinity.
        pub best_score_: f64,
        /// Estimator for the best-scoring combination, as selected by `fit`.
        pub best_estimator_: E,
        /// Fitted flag
        fitted: bool,
    }
949
950    impl<E> GridSearchCV<E>
951    where
952        E: SklearnClassifier + Clone,
953    {
954        /// Create new grid search
955        pub fn new(estimator: E, param_grid: HashMap<String, Vec<String>>, cv: usize) -> Self {
956            Self {
957                best_estimator_: estimator.clone(),
958                estimator,
959                param_grid,
960                cv,
961                best_params_: HashMap::new(),
962                best_score_: f64::NEG_INFINITY,
963                fitted: false,
964            }
965        }
966
967        /// Fit grid search
968        pub fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
969            let param_combinations = self.generate_param_combinations();
970
971            for params in param_combinations {
972                let mut estimator = self.estimator.clone();
973                estimator.set_params(params.clone())?;
974
975                let scores = cross_val_score(&mut estimator, X, y, self.cv)?;
976                let mean_score = scores.mean().unwrap_or(0.0);
977
978                if mean_score > self.best_score_ {
979                    self.best_score_ = mean_score;
980                    self.best_params_ = params.clone();
981                    self.best_estimator_ = estimator;
982                }
983            }
984
985            // Fit best estimator
986            if !self.best_params_.is_empty() {
987                self.best_estimator_.set_params(self.best_params_.clone())?;
988                self.best_estimator_.fit(X, Some(y))?;
989            }
990
991            self.fitted = true;
992            Ok(())
993        }
994
995        /// Generate all parameter combinations
996        fn generate_param_combinations(&self) -> Vec<HashMap<String, String>> {
997            let mut combinations = vec![HashMap::new()];
998
999            for (param_name, param_values) in &self.param_grid {
1000                let mut new_combinations = Vec::new();
1001
1002                for combination in &combinations {
1003                    for value in param_values {
1004                        let mut new_combination = combination.clone();
1005                        new_combination.insert(param_name.clone(), value.clone());
1006                        new_combinations.push(new_combination);
1007                    }
1008                }
1009
1010                combinations = new_combinations;
1011            }
1012
1013            combinations
1014        }
1015
1016        /// Get best parameters
1017        pub fn best_params(&self) -> &HashMap<String, String> {
1018            &self.best_params_
1019        }
1020
1021        /// Get best score
1022        pub fn best_score(&self) -> f64 {
1023            self.best_score_
1024        }
1025
1026        /// Predict with best estimator
1027        pub fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
1028            if !self.fitted {
1029                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
1030            }
1031            self.best_estimator_.predict(X)
1032        }
1033    }
1034}
1035
1036/// Standard Scaler (sklearn-compatible)
1037pub struct StandardScaler {
1038    mean_: Option<Array1<f64>>,
1039    scale_: Option<Array1<f64>>,
1040    fitted: bool,
1041}
1042
1043impl StandardScaler {
1044    pub fn new() -> Self {
1045        Self {
1046            mean_: None,
1047            scale_: None,
1048            fitted: false,
1049        }
1050    }
1051}
1052
1053impl SklearnEstimator for StandardScaler {
1054    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
1055        let mean = X.mean_axis(scirs2_core::ndarray::Axis(0)).ok_or_else(|| {
1056            MLError::InvalidInput("Cannot compute mean of empty array".to_string())
1057        })?;
1058        let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
1059
1060        self.mean_ = Some(mean);
1061        self.scale_ = Some(std);
1062        self.fitted = true;
1063
1064        Ok(())
1065    }
1066
1067    fn get_params(&self) -> HashMap<String, String> {
1068        HashMap::new()
1069    }
1070
1071    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
1072        Ok(())
1073    }
1074
1075    fn is_fitted(&self) -> bool {
1076        self.fitted
1077    }
1078}
1079
/// Select K Best features (sklearn-compatible)
#[derive(Debug, Clone)]
pub struct SelectKBest {
    /// Name of the scoring function (reported by `get_params`)
    score_func: String,
    /// Maximum number of features to keep
    k: usize,
    /// Whether `fit` has completed
    fitted: bool,
    /// Column indices chosen by `fit`
    selected_features_: Option<Vec<usize>>,
}
1087
1088impl SelectKBest {
1089    pub fn new(score_func: &str, k: usize) -> Self {
1090        Self {
1091            score_func: score_func.to_string(),
1092            k,
1093            fitted: false,
1094            selected_features_: None,
1095        }
1096    }
1097}
1098
1099impl SklearnEstimator for SelectKBest {
1100    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
1101        // Mock implementation - select first k features
1102        let features: Vec<usize> = (0..self.k.min(X.ncols())).collect();
1103        self.selected_features_ = Some(features);
1104        self.fitted = true;
1105        Ok(())
1106    }
1107
1108    fn get_params(&self) -> HashMap<String, String> {
1109        let mut params = HashMap::new();
1110        params.insert("score_func".to_string(), self.score_func.clone());
1111        params.insert("k".to_string(), self.k.to_string());
1112        params
1113    }
1114
1115    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
1116        for (key, value) in params {
1117            match key.as_str() {
1118                "k" => {
1119                    self.k = value.parse().map_err(|_| {
1120                        MLError::InvalidConfiguration(format!("Invalid k parameter: {}", value))
1121                    })?;
1122                }
1123                "score_func" => {
1124                    self.score_func = value;
1125                }
1126                _ => {}
1127            }
1128        }
1129        Ok(())
1130    }
1131
1132    fn is_fitted(&self) -> bool {
1133        self.fitted
1134    }
1135}
1136
/// Quantum Feature Encoder (sklearn-compatible)
#[derive(Debug, Clone)]
pub struct QuantumFeatureEncoder {
    /// Encoding scheme name (reported by `get_params`)
    encoding_type: String,
    /// Normalization scheme name (reported by `get_params`)
    normalization: String,
    /// Whether `fit` has been called
    fitted: bool,
}
1143
1144impl QuantumFeatureEncoder {
1145    pub fn new(encoding_type: &str, normalization: &str) -> Self {
1146        Self {
1147            encoding_type: encoding_type.to_string(),
1148            normalization: normalization.to_string(),
1149            fitted: false,
1150        }
1151    }
1152}
1153
1154impl SklearnEstimator for QuantumFeatureEncoder {
1155    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
1156        self.fitted = true;
1157        Ok(())
1158    }
1159
1160    fn get_params(&self) -> HashMap<String, String> {
1161        let mut params = HashMap::new();
1162        params.insert("encoding_type".to_string(), self.encoding_type.clone());
1163        params.insert("normalization".to_string(), self.normalization.clone());
1164        params
1165    }
1166
1167    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
1168        for (key, value) in params {
1169            match key.as_str() {
1170                "encoding_type" => {
1171                    self.encoding_type = value;
1172                }
1173                "normalization" => {
1174                    self.normalization = value;
1175                }
1176                _ => {}
1177            }
1178        }
1179        Ok(())
1180    }
1181
1182    fn is_fitted(&self) -> bool {
1183        self.fitted
1184    }
1185}
1186
/// Simple Pipeline implementation
///
/// Holds named estimator steps in the order given to `Pipeline::new`.
/// NOTE(review): currently a placeholder. Its `SklearnEstimator`/`SklearnClassifier`
/// impls do not run the contained steps.
pub struct Pipeline {
    /// Named steps, in insertion order
    steps: Vec<(String, Box<dyn SklearnEstimator>)>,
    /// Set by `fit`
    fitted: bool,
}
1192
1193impl Pipeline {
1194    pub fn new(steps: Vec<(&str, Box<dyn SklearnEstimator>)>) -> Result<Self> {
1195        let steps = steps
1196            .into_iter()
1197            .map(|(name, estimator)| (name.to_string(), estimator))
1198            .collect();
1199        Ok(Self {
1200            steps,
1201            fitted: false,
1202        })
1203    }
1204
1205    pub fn named_steps(&self) -> Vec<&String> {
1206        self.steps.iter().map(|(name, _)| name).collect()
1207    }
1208}
1209
impl Clone for Pipeline {
    /// Returns an EMPTY, unfitted pipeline. The steps are not copied.
    ///
    /// NOTE(review): the steps are boxed `dyn SklearnEstimator` trait objects,
    /// and that trait offers no clone method, so this clone has no steps
    /// (`named_steps()` on it is empty). Confirm callers expect this.
    fn clone(&self) -> Self {
        // For demo purposes, return an empty pipeline (no steps, not fitted)
        Self {
            steps: Vec::new(),
            fitted: false,
        }
    }
}
1219
1220impl SklearnEstimator for Pipeline {
1221    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
1222        // Mock implementation
1223        self.fitted = true;
1224        Ok(())
1225    }
1226
1227    fn get_params(&self) -> HashMap<String, String> {
1228        HashMap::new()
1229    }
1230
1231    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
1232        Ok(())
1233    }
1234
1235    fn is_fitted(&self) -> bool {
1236        self.fitted
1237    }
1238}
1239
1240impl SklearnClassifier for Pipeline {
1241    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
1242        // Mock predictions
1243        Ok(Array1::from_shape_fn(X.nrows(), |i| {
1244            if i % 2 == 0 {
1245                1
1246            } else {
1247                0
1248            }
1249        }))
1250    }
1251
1252    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
1253        Ok(Array2::from_shape_fn((X.nrows(), 2), |(i, j)| {
1254            if j == 0 {
1255                0.4
1256            } else {
1257                0.6
1258            }
1259        }))
1260    }
1261
1262    fn classes(&self) -> &[i32] {
1263        &[0, 1]
1264    }
1265
1266    fn feature_importances(&self) -> Option<Array1<f64>> {
1267        Some(Array1::from_vec(vec![0.25, 0.35, 0.20, 0.20]))
1268    }
1269
1270    fn save(&self, _path: &str) -> Result<()> {
1271        Ok(())
1272    }
1273}
1274
1275impl Pipeline {
1276    pub fn load(_path: &str) -> Result<Self> {
1277        Ok(Self::new(vec![])?)
1278    }
1279}
1280
1281/// Pipeline utilities (sklearn-compatible)
1282pub mod pipeline {
1283    use super::*;
1284
1285    /// Transformer trait
1286    pub trait SklearnTransformer: Send + Sync {
1287        /// Fit transformer
1288        fn fit(&mut self, X: &Array2<f64>) -> Result<()>;
1289
1290        /// Transform data
1291        fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>>;
1292
1293        /// Fit and transform
1294        fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
1295            self.fit(X)?;
1296            self.transform(X)
1297        }
1298    }
1299
1300    /// Quantum feature scaler
1301    pub struct QuantumStandardScaler {
1302        /// Feature means
1303        mean_: Option<Array1<f64>>,
1304        /// Feature standard deviations
1305        scale_: Option<Array1<f64>>,
1306        /// Fitted flag
1307        fitted: bool,
1308    }
1309
1310    impl QuantumStandardScaler {
1311        /// Create new scaler
1312        pub fn new() -> Self {
1313            Self {
1314                mean_: None,
1315                scale_: None,
1316                fitted: false,
1317            }
1318        }
1319    }
1320
1321    impl SklearnTransformer for QuantumStandardScaler {
1322        fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
1323            let mean = X.mean_axis(Axis(0)).ok_or_else(|| {
1324                MLError::InvalidInput("Cannot compute mean of empty array".to_string())
1325            })?;
1326            let std = X.std_axis(Axis(0), 0.0);
1327
1328            self.mean_ = Some(mean);
1329            self.scale_ = Some(std);
1330            self.fitted = true;
1331
1332            Ok(())
1333        }
1334
1335        fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
1336            if !self.fitted {
1337                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
1338            }
1339
1340            let mean = self
1341                .mean_
1342                .as_ref()
1343                .ok_or_else(|| MLError::ModelNotTrained("Mean not initialized".to_string()))?;
1344            let scale = self
1345                .scale_
1346                .as_ref()
1347                .ok_or_else(|| MLError::ModelNotTrained("Scale not initialized".to_string()))?;
1348
1349            let mut X_scaled = X.clone();
1350            for mut row in X_scaled.axis_iter_mut(Axis(0)) {
1351                row -= mean;
1352                row /= scale;
1353            }
1354
1355            Ok(X_scaled)
1356        }
1357    }
1358
1359    /// Quantum pipeline
1360    pub struct QuantumPipeline {
1361        /// Pipeline steps
1362        steps: Vec<(String, PipelineStep)>,
1363        /// Fitted flag
1364        fitted: bool,
1365    }
1366
1367    /// Pipeline step enum
1368    pub enum PipelineStep {
1369        /// Transformer step
1370        Transformer(Box<dyn SklearnTransformer>),
1371        /// Classifier step
1372        Classifier(Box<dyn SklearnClassifier>),
1373        /// Regressor step
1374        Regressor(Box<dyn SklearnRegressor>),
1375        /// Clusterer step
1376        Clusterer(Box<dyn SklearnClusterer>),
1377    }
1378
1379    impl QuantumPipeline {
1380        /// Create new pipeline
1381        pub fn new() -> Self {
1382            Self {
1383                steps: Vec::new(),
1384                fitted: false,
1385            }
1386        }
1387
1388        /// Add transformer step
1389        pub fn add_transformer(
1390            mut self,
1391            name: String,
1392            transformer: Box<dyn SklearnTransformer>,
1393        ) -> Self {
1394            self.steps
1395                .push((name, PipelineStep::Transformer(transformer)));
1396            self
1397        }
1398
1399        /// Add classifier step
1400        pub fn add_classifier(
1401            mut self,
1402            name: String,
1403            classifier: Box<dyn SklearnClassifier>,
1404        ) -> Self {
1405            self.steps
1406                .push((name, PipelineStep::Classifier(classifier)));
1407            self
1408        }
1409
1410        /// Fit pipeline
1411        pub fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
1412            let mut current_X = X.clone();
1413
1414            for (_name, step) in &mut self.steps {
1415                match step {
1416                    PipelineStep::Transformer(transformer) => {
1417                        current_X = transformer.fit_transform(&current_X)?;
1418                    }
1419                    PipelineStep::Classifier(classifier) => {
1420                        classifier.fit(&current_X, y)?;
1421                    }
1422                    PipelineStep::Regressor(regressor) => {
1423                        regressor.fit(&current_X, y)?;
1424                    }
1425                    PipelineStep::Clusterer(clusterer) => {
1426                        clusterer.fit(&current_X, y)?;
1427                    }
1428                }
1429            }
1430
1431            self.fitted = true;
1432            Ok(())
1433        }
1434
1435        /// Predict with pipeline
1436        pub fn predict(&self, X: &Array2<f64>) -> Result<ArrayD<f64>> {
1437            if !self.fitted {
1438                return Err(MLError::ModelNotTrained("Model not trained".to_string()));
1439            }
1440
1441            let mut current_X = X.clone();
1442
1443            for (_name, step) in &self.steps {
1444                match step {
1445                    PipelineStep::Transformer(transformer) => {
1446                        current_X = transformer.transform(&current_X)?;
1447                    }
1448                    PipelineStep::Classifier(classifier) => {
1449                        let predictions = classifier.predict(&current_X)?;
1450                        let predictions_f64 = predictions.mapv(|x| x as f64);
1451                        return Ok(predictions_f64.into_dyn());
1452                    }
1453                    PipelineStep::Regressor(regressor) => {
1454                        let predictions = regressor.predict(&current_X)?;
1455                        return Ok(predictions.into_dyn());
1456                    }
1457                    PipelineStep::Clusterer(clusterer) => {
1458                        let predictions = clusterer.predict(&current_X)?;
1459                        let predictions_f64 = predictions.mapv(|x| x as f64);
1460                        return Ok(predictions_f64.into_dyn());
1461                    }
1462                }
1463            }
1464
1465            Ok(current_X.into_dyn())
1466        }
1467    }
1468}
1469
1470/// Metrics module (sklearn-compatible)
1471pub mod metrics {
1472    use super::*;
1473
1474    /// Calculate accuracy score
1475    pub fn accuracy_score(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> f64 {
1476        let correct = y_true
1477            .iter()
1478            .zip(y_pred.iter())
1479            .filter(|(&true_val, &pred_val)| true_val == pred_val)
1480            .count();
1481        correct as f64 / y_true.len() as f64
1482    }
1483
1484    /// Calculate precision score
1485    pub fn precision_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
1486        // Mock implementation
1487        0.85
1488    }
1489
1490    /// Calculate recall score
1491    pub fn recall_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
1492        // Mock implementation
1493        0.82
1494    }
1495
1496    /// Calculate F1 score
1497    pub fn f1_score(y_true: &Array1<i32>, y_pred: &Array1<i32>, _average: &str) -> f64 {
1498        // Mock implementation
1499        0.83
1500    }
1501
1502    /// Generate classification report
1503    pub fn classification_report(
1504        y_true: &Array1<i32>,
1505        y_pred: &Array1<i32>,
1506        target_names: Vec<&str>,
1507        digits: usize,
1508    ) -> String {
1509        format!("Classification Report\n==================\n{:>10} {:>10} {:>10} {:>10} {:>10}\n{:>10} {:>10.digits$} {:>10.digits$} {:>10.digits$} {:>10}\n{:>10} {:>10.digits$} {:>10.digits$} {:>10.digits$} {:>10}\n",
1510            "", "precision", "recall", "f1-score", "support",
1511            target_names[0], 0.85, 0.82, 0.83, 50,
1512            target_names[1], 0.87, 0.85, 0.86, 50,
1513            digits = digits)
1514    }
1515
1516    /// Calculate silhouette score
1517    pub fn silhouette_score(X: &Array2<f64>, labels: &Array1<i32>, _metric: &str) -> f64 {
1518        // Mock implementation
1519        0.65
1520    }
1521
1522    /// Calculate Calinski-Harabasz score
1523    pub fn calinski_harabasz_score(X: &Array2<f64>, labels: &Array1<i32>) -> f64 {
1524        // Mock implementation
1525        150.0
1526    }
1527}
1528
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    #[ignore]
    fn test_quantum_svc() {
        let mut svc = QuantumSVC::new().set_C(1.0).set_gamma(0.1);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0])
            .expect("Failed to create input array X");
        let y = Array::from_vec(vec![1.0, -1.0, -1.0, 1.0]);

        assert!(svc.fit(&X, Some(&y)).is_ok());
        assert!(svc.is_fitted());

        let predictions = svc.predict(&X);
        assert!(predictions.is_ok());
    }

    #[test]
    #[ignore]
    fn test_quantum_mlp_classifier() {
        let mut mlp = QuantumMLPClassifier::new()
            .set_hidden_layer_sizes(vec![5])
            .set_learning_rate(0.01)
            .set_max_iter(10);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0])
            .expect("Failed to create input array X");
        let y = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]);

        assert!(mlp.fit(&X, Some(&y)).is_ok());
        assert!(mlp.is_fitted());

        let predictions = mlp.predict(&X);
        assert!(predictions.is_ok());

        let probas = mlp.predict_proba(&X);
        assert!(probas.is_ok());
    }

    #[test]
    #[ignore]
    fn test_quantum_kmeans() {
        let mut kmeans = QuantumKMeans::new(2).set_max_iter(50).set_tol(1e-4);

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 1.1, 1.1, -1.0, -1.0, -1.1, -1.1])
            .expect("Failed to create input array X");

        assert!(kmeans.fit(&X, None).is_ok());
        assert!(kmeans.is_fitted());

        let predictions = kmeans.predict(&X);
        assert!(predictions.is_ok());

        assert!(kmeans.cluster_centers().is_some());
    }

    #[test]
    fn test_model_selection() {
        use model_selection::train_test_split;

        // Row r of X is [2r, 2r + 1] and y[r] = r, so each split row can be
        // checked against its label after shuffling.
        let X = Array::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect())
            .expect("Failed to create input array X");
        let y = Array::from_vec((0..10).map(|x| x as f64).collect());

        let (X_train, X_test, y_train, y_test) =
            train_test_split(&X, &y, 0.3, Some(42)).expect("train_test_split should succeed");

        assert_eq!(X_train.nrows() + X_test.nrows(), X.nrows());
        assert_eq!(y_train.len() + y_test.len(), y.len());
        // round(10 * 0.3) = 3 test samples.
        assert_eq!(X_test.nrows(), 3);
        assert_eq!(y_test.len(), 3);

        // Rows and labels must stay paired.
        for (X_part, y_part) in [(&X_train, &y_train), (&X_test, &y_test)] {
            for (row, &label) in X_part.outer_iter().zip(y_part.iter()) {
                assert_eq!(row[0], 2.0 * label);
                assert_eq!(row[1], 2.0 * label + 1.0);
            }
        }

        // Every sample lands in exactly one split.
        let mut all: Vec<f64> = y_train.iter().chain(y_test.iter()).copied().collect();
        all.sort_by(|a, b| a.partial_cmp(b).expect("labels are finite"));
        assert_eq!(all, (0..10).map(|x| x as f64).collect::<Vec<_>>());

        // The same seed reproduces the same split.
        let (_, _, y_train_again, y_test_again) =
            train_test_split(&X, &y, 0.3, Some(42)).expect("train_test_split should succeed");
        assert_eq!(y_train, y_train_again);
        assert_eq!(y_test, y_test_again);
    }

    #[test]
    fn test_pipeline() {
        use pipeline::{QuantumPipeline, QuantumStandardScaler};

        let mut pipeline = QuantumPipeline::new()
            .add_transformer("scaler".to_string(), Box::new(QuantumStandardScaler::new()));

        let X = Array::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0])
            .expect("Failed to create input array X");

        assert!(pipeline.fit(&X, None).is_ok());

        let transformed = pipeline
            .predict(&X)
            .expect("pipeline predict should succeed");
        assert_eq!(transformed.shape(), &[4, 2]);

        // Each column [1, 2, 3, 4] should be standardized to mean 0 and
        // population variance 1.
        for column in transformed.axis_iter(Axis(1)) {
            let n = column.len() as f64;
            let mean = column.sum() / n;
            let variance = column.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
            assert!(mean.abs() < 1e-12, "column mean {} is not ~0", mean);
            assert!(
                (variance - 1.0).abs() < 1e-12,
                "column variance {} is not ~1",
                variance
            );
        }
    }
}