quantrs2_ml/sklearn_compatibility/classifiers.rs

1//! Classifier implementations for sklearn compatibility
2
3use super::{SklearnClassifier, SklearnEstimator};
4use crate::error::{MLError, Result};
5use crate::qnn::{QNNBuilder, QuantumNeuralNetwork};
6use crate::qsvm::{FeatureMapType, QSVMParams, QSVM};
7use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
8use scirs2_core::ndarray::{Array1, Array2, Axis};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Quantum Support Vector Machine (sklearn-compatible)
13pub struct QuantumSVC {
14    /// Internal QSVM
15    qsvm: Option<QSVM>,
16    /// SVM parameters
17    params: QSVMParams,
18    /// Feature map type
19    feature_map: FeatureMapType,
20    /// Backend
21    backend: Arc<dyn SimulatorBackend>,
22    /// Fitted flag
23    fitted: bool,
24    /// Unique classes
25    classes: Vec<i32>,
26    /// Regularization parameter
27    #[allow(non_snake_case)]
28    C: f64,
29    /// Kernel gamma parameter
30    gamma: f64,
31}
32
33impl Clone for QuantumSVC {
34    fn clone(&self) -> Self {
35        Self {
36            qsvm: None,
37            params: self.params.clone(),
38            feature_map: self.feature_map,
39            backend: self.backend.clone(),
40            fitted: false,
41            classes: self.classes.clone(),
42            C: self.C,
43            gamma: self.gamma,
44        }
45    }
46}
47
48impl QuantumSVC {
49    /// Create new Quantum SVC
50    pub fn new() -> Self {
51        Self {
52            qsvm: None,
53            params: QSVMParams::default(),
54            feature_map: FeatureMapType::ZZFeatureMap,
55            backend: Arc::new(StatevectorBackend::new(10)),
56            fitted: false,
57            classes: Vec::new(),
58            C: 1.0,
59            gamma: 1.0,
60        }
61    }
62
63    /// Set regularization parameter
64    #[allow(non_snake_case)]
65    pub fn set_C(mut self, C: f64) -> Self {
66        self.C = C;
67        self
68    }
69
70    /// Set kernel gamma parameter
71    pub fn set_gamma(mut self, gamma: f64) -> Self {
72        self.gamma = gamma;
73        self
74    }
75
76    /// Set feature map
77    pub fn set_kernel(mut self, feature_map: FeatureMapType) -> Self {
78        self.feature_map = feature_map;
79        self
80    }
81
82    /// Set quantum backend
83    pub fn set_backend(mut self, backend: Arc<dyn SimulatorBackend>) -> Self {
84        self.backend = backend;
85        self
86    }
87
88    /// Load model from file
89    pub fn load(_path: &str) -> Result<Self> {
90        Ok(Self::new())
91    }
92}
93
94impl Default for QuantumSVC {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl SklearnEstimator for QuantumSVC {
101    #[allow(non_snake_case)]
102    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
103        let y = y.ok_or_else(|| {
104            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
105        })?;
106
107        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);
108
109        let mut classes = Vec::new();
110        for &label in y_int.iter() {
111            if !classes.contains(&label) {
112                classes.push(label);
113            }
114        }
115        classes.sort();
116        self.classes = classes;
117
118        self.params.feature_map = self.feature_map;
119        self.params.regularization = self.C;
120
121        let mut qsvm = QSVM::new(self.params.clone());
122        qsvm.fit(X, &y_int)?;
123
124        self.qsvm = Some(qsvm);
125        self.fitted = true;
126
127        Ok(())
128    }
129
130    fn get_params(&self) -> HashMap<String, String> {
131        let mut params = HashMap::new();
132        params.insert("C".to_string(), self.C.to_string());
133        params.insert("gamma".to_string(), self.gamma.to_string());
134        params.insert("kernel".to_string(), format!("{:?}", self.feature_map));
135        params
136    }
137
138    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
139        for (key, value) in params {
140            match key.as_str() {
141                "C" => {
142                    self.C = value.parse().map_err(|_| {
143                        MLError::InvalidConfiguration(format!("Invalid C parameter: {}", value))
144                    })?;
145                }
146                "gamma" => {
147                    self.gamma = value.parse().map_err(|_| {
148                        MLError::InvalidConfiguration(format!("Invalid gamma parameter: {}", value))
149                    })?;
150                }
151                "kernel" => {
152                    self.feature_map = match value.as_str() {
153                        "ZZFeatureMap" => FeatureMapType::ZZFeatureMap,
154                        "ZFeatureMap" => FeatureMapType::ZFeatureMap,
155                        "PauliFeatureMap" => FeatureMapType::PauliFeatureMap,
156                        _ => {
157                            return Err(MLError::InvalidConfiguration(format!(
158                                "Unknown kernel: {}",
159                                value
160                            )))
161                        }
162                    };
163                }
164                _ => {
165                    return Err(MLError::InvalidConfiguration(format!(
166                        "Unknown parameter: {}",
167                        key
168                    )))
169                }
170            }
171        }
172        Ok(())
173    }
174
175    fn is_fitted(&self) -> bool {
176        self.fitted
177    }
178}
179
180impl SklearnClassifier for QuantumSVC {
181    #[allow(non_snake_case)]
182    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
183        if !self.fitted {
184            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
185        }
186
187        let qsvm = self
188            .qsvm
189            .as_ref()
190            .ok_or_else(|| MLError::ModelNotTrained("QSVM model not initialized".to_string()))?;
191        qsvm.predict(X).map_err(MLError::ValidationError)
192    }
193
194    #[allow(non_snake_case)]
195    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
196        if !self.fitted {
197            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
198        }
199
200        let predictions = self.predict(X)?;
201        let n_samples = X.nrows();
202        let n_classes = self.classes.len();
203
204        let mut probabilities = Array2::zeros((n_samples, n_classes));
205
206        for (i, &prediction) in predictions.iter().enumerate() {
207            for (j, &class) in self.classes.iter().enumerate() {
208                probabilities[[i, j]] = if prediction == class { 1.0 } else { 0.0 };
209            }
210        }
211
212        Ok(probabilities)
213    }
214
215    fn classes(&self) -> &[i32] {
216        &self.classes
217    }
218}
219
220/// Quantum Neural Network Classifier (sklearn-compatible)
221pub struct QuantumMLPClassifier {
222    /// Internal QNN
223    qnn: Option<QuantumNeuralNetwork>,
224    /// Network configuration
225    hidden_layer_sizes: Vec<usize>,
226    /// Activation function
227    activation: String,
228    /// Solver
229    solver: String,
230    /// Learning rate
231    learning_rate: f64,
232    /// Maximum iterations
233    max_iter: usize,
234    /// Random state
235    random_state: Option<u64>,
236    /// Backend
237    backend: Arc<dyn SimulatorBackend>,
238    /// Fitted flag
239    fitted: bool,
240    /// Unique classes
241    classes: Vec<i32>,
242}
243
244impl QuantumMLPClassifier {
245    /// Create new Quantum MLP Classifier
246    pub fn new() -> Self {
247        Self {
248            qnn: None,
249            hidden_layer_sizes: vec![10],
250            activation: "relu".to_string(),
251            solver: "adam".to_string(),
252            learning_rate: 0.001,
253            max_iter: 200,
254            random_state: None,
255            backend: Arc::new(StatevectorBackend::new(10)),
256            fitted: false,
257            classes: Vec::new(),
258        }
259    }
260
261    /// Set hidden layer sizes
262    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
263        self.hidden_layer_sizes = sizes;
264        self
265    }
266
267    /// Set activation function
268    pub fn set_activation(mut self, activation: String) -> Self {
269        self.activation = activation;
270        self
271    }
272
273    /// Set learning rate
274    pub fn set_learning_rate(mut self, lr: f64) -> Self {
275        self.learning_rate = lr;
276        self
277    }
278
279    /// Set maximum iterations
280    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
281        self.max_iter = max_iter;
282        self
283    }
284
285    /// Convert integer labels to one-hot encoding
286    fn to_one_hot(&self, y: &Array1<i32>) -> Result<Array2<f64>> {
287        let n_samples = y.len();
288        let n_classes = self.classes.len();
289        let mut one_hot = Array2::zeros((n_samples, n_classes));
290
291        for (i, &label) in y.iter().enumerate() {
292            if let Some(class_idx) = self.classes.iter().position(|&c| c == label) {
293                one_hot[[i, class_idx]] = 1.0;
294            }
295        }
296
297        Ok(one_hot)
298    }
299
300    /// Convert one-hot predictions to class labels
301    fn from_one_hot(&self, predictions: &Array2<f64>) -> Array1<i32> {
302        predictions
303            .axis_iter(Axis(0))
304            .map(|row| {
305                let max_idx = row
306                    .iter()
307                    .enumerate()
308                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
309                    .map(|(idx, _)| idx)
310                    .unwrap_or(0);
311                self.classes.get(max_idx).copied().unwrap_or(0)
312            })
313            .collect()
314    }
315}
316
317impl Default for QuantumMLPClassifier {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
323impl SklearnEstimator for QuantumMLPClassifier {
324    #[allow(non_snake_case)]
325    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
326        let y = y.ok_or_else(|| {
327            MLError::InvalidConfiguration("Labels required for supervised learning".to_string())
328        })?;
329
330        let y_int: Array1<i32> = y.mapv(|val| val.round() as i32);
331
332        let mut classes = Vec::new();
333        for &label in y_int.iter() {
334            if !classes.contains(&label) {
335                classes.push(label);
336            }
337        }
338        classes.sort();
339        self.classes = classes;
340
341        let _input_size = X.ncols();
342        let output_size = self.classes.len();
343
344        let mut builder = QNNBuilder::new();
345
346        for &size in &self.hidden_layer_sizes {
347            builder = builder.add_layer(size);
348        }
349
350        builder = builder.add_layer(output_size);
351
352        let mut qnn = builder.build()?;
353
354        let y_one_hot = self.to_one_hot(&y_int)?;
355        qnn.train(X, &y_one_hot, self.max_iter, self.learning_rate)?;
356
357        self.qnn = Some(qnn);
358        self.fitted = true;
359
360        Ok(())
361    }
362
363    fn get_params(&self) -> HashMap<String, String> {
364        let mut params = HashMap::new();
365        params.insert(
366            "hidden_layer_sizes".to_string(),
367            format!("{:?}", self.hidden_layer_sizes),
368        );
369        params.insert("activation".to_string(), self.activation.clone());
370        params.insert("solver".to_string(), self.solver.clone());
371        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
372        params.insert("max_iter".to_string(), self.max_iter.to_string());
373        params
374    }
375
376    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
377        for (key, value) in params {
378            match key.as_str() {
379                "learning_rate" => {
380                    self.learning_rate = value.parse().map_err(|_| {
381                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
382                    })?;
383                }
384                "max_iter" => {
385                    self.max_iter = value.parse().map_err(|_| {
386                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
387                    })?;
388                }
389                "activation" => {
390                    self.activation = value;
391                }
392                "solver" => {
393                    self.solver = value;
394                }
395                _ => {}
396            }
397        }
398        Ok(())
399    }
400
401    fn is_fitted(&self) -> bool {
402        self.fitted
403    }
404}
405
406impl SklearnClassifier for QuantumMLPClassifier {
407    #[allow(non_snake_case)]
408    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
409        if !self.fitted {
410            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
411        }
412
413        let qnn = self
414            .qnn
415            .as_ref()
416            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
417        let predictions = qnn.predict_batch(X)?;
418        Ok(self.from_one_hot(&predictions))
419    }
420
421    #[allow(non_snake_case)]
422    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
423        if !self.fitted {
424            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
425        }
426
427        let qnn = self
428            .qnn
429            .as_ref()
430            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
431        qnn.predict_batch(X)
432    }
433
434    fn classes(&self) -> &[i32] {
435        &self.classes
436    }
437}
438
439/// Voting Classifier for ensemble learning
440pub struct VotingClassifier {
441    /// Named classifiers
442    classifiers: Vec<(String, Box<dyn SklearnClassifier>)>,
443    /// Voting mode
444    voting: String,
445    /// Weights
446    weights: Option<Vec<f64>>,
447    /// Classes
448    classes: Vec<i32>,
449    /// Fitted flag
450    fitted: bool,
451}
452
453impl VotingClassifier {
454    /// Create new voting classifier
455    pub fn new(classifiers: Vec<(String, Box<dyn SklearnClassifier>)>) -> Self {
456        Self {
457            classifiers,
458            voting: "hard".to_string(),
459            weights: None,
460            classes: Vec::new(),
461            fitted: false,
462        }
463    }
464
465    /// Set voting mode
466    pub fn voting(mut self, voting: &str) -> Self {
467        self.voting = voting.to_string();
468        self
469    }
470
471    /// Set weights
472    pub fn weights(mut self, weights: Vec<f64>) -> Self {
473        self.weights = Some(weights);
474        self
475    }
476}
477
478impl SklearnEstimator for VotingClassifier {
479    #[allow(non_snake_case)]
480    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
481        self.fitted = true;
482        Ok(())
483    }
484
485    fn get_params(&self) -> HashMap<String, String> {
486        let mut params = HashMap::new();
487        params.insert("voting".to_string(), self.voting.clone());
488        params.insert(
489            "n_classifiers".to_string(),
490            self.classifiers.len().to_string(),
491        );
492        params
493    }
494
495    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
496        if let Some(voting) = params.get("voting") {
497            self.voting = voting.clone();
498        }
499        Ok(())
500    }
501
502    fn is_fitted(&self) -> bool {
503        self.fitted
504    }
505}
506
507impl SklearnClassifier for VotingClassifier {
508    #[allow(non_snake_case)]
509    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
510        if !self.fitted || self.classifiers.is_empty() {
511            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
512        }
513
514        let n_samples = X.nrows();
515        let mut votes = vec![HashMap::new(); n_samples];
516
517        let weights = self
518            .weights
519            .clone()
520            .unwrap_or_else(|| vec![1.0; self.classifiers.len()]);
521
522        for (i, (_, clf)) in self.classifiers.iter().enumerate() {
523            let predictions = clf.predict(X)?;
524            let weight = weights.get(i).copied().unwrap_or(1.0);
525
526            for (j, &pred) in predictions.iter().enumerate() {
527                *votes[j].entry(pred).or_insert(0.0) += weight;
528            }
529        }
530
531        let result: Array1<i32> = votes
532            .iter()
533            .map(|vote_map| {
534                vote_map
535                    .iter()
536                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
537                    .map(|(&k, _)| k)
538                    .unwrap_or(0)
539            })
540            .collect();
541
542        Ok(result)
543    }
544
545    #[allow(non_snake_case)]
546    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
547        if !self.fitted || self.classifiers.is_empty() {
548            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
549        }
550
551        let n_samples = X.nrows();
552        let n_classes = self.classes.len().max(2);
553        let mut avg_proba = Array2::zeros((n_samples, n_classes));
554
555        let weights = self
556            .weights
557            .clone()
558            .unwrap_or_else(|| vec![1.0; self.classifiers.len()]);
559        let total_weight: f64 = weights.iter().sum();
560
561        for (i, (_, clf)) in self.classifiers.iter().enumerate() {
562            let proba = clf.predict_proba(X)?;
563            let weight = weights.get(i).copied().unwrap_or(1.0);
564
565            for row in 0..n_samples {
566                for col in 0..proba.ncols().min(n_classes) {
567                    avg_proba[[row, col]] += proba[[row, col]] * weight / total_weight;
568                }
569            }
570        }
571
572        Ok(avg_proba)
573    }
574
575    fn classes(&self) -> &[i32] {
576        &self.classes
577    }
578}