quantrs2_ml/sklearn_compatibility/regressors.rs

1//! Sklearn-compatible regressors
2
3use super::{SklearnEstimator, SklearnRegressor};
4use crate::error::{MLError, Result};
5use crate::qnn::{QNNBuilder, QuantumNeuralNetwork};
6use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Quantum MLP Regressor (sklearn-compatible)
12pub struct QuantumMLPRegressor {
13    /// Internal QNN
14    qnn: Option<QuantumNeuralNetwork>,
15    /// Network configuration
16    hidden_layer_sizes: Vec<usize>,
17    /// Activation function
18    activation: String,
19    /// Solver
20    solver: String,
21    /// Learning rate
22    learning_rate: f64,
23    /// Maximum iterations
24    max_iter: usize,
25    /// Random state
26    random_state: Option<u64>,
27    /// Backend
28    backend: Arc<dyn SimulatorBackend>,
29    /// Fitted flag
30    fitted: bool,
31}
32
33impl QuantumMLPRegressor {
34    /// Create new Quantum MLP Regressor
35    pub fn new() -> Self {
36        Self {
37            qnn: None,
38            hidden_layer_sizes: vec![10],
39            activation: "relu".to_string(),
40            solver: "adam".to_string(),
41            learning_rate: 0.001,
42            max_iter: 200,
43            random_state: None,
44            backend: Arc::new(StatevectorBackend::new(10)),
45            fitted: false,
46        }
47    }
48
49    /// Set hidden layer sizes
50    pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
51        self.hidden_layer_sizes = sizes;
52        self
53    }
54
55    /// Set learning rate
56    pub fn set_learning_rate(mut self, lr: f64) -> Self {
57        self.learning_rate = lr;
58        self
59    }
60
61    /// Set maximum iterations
62    pub fn set_max_iter(mut self, max_iter: usize) -> Self {
63        self.max_iter = max_iter;
64        self
65    }
66}
67
68impl Default for QuantumMLPRegressor {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl SklearnEstimator for QuantumMLPRegressor {
75    #[allow(non_snake_case)]
76    fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
77        let y = y.ok_or_else(|| {
78            MLError::InvalidConfiguration("Target values required for regression".to_string())
79        })?;
80
81        // Build QNN for regression
82        let _input_size = X.ncols();
83        let output_size = 1; // Single output for regression
84
85        let mut builder = QNNBuilder::new();
86
87        // Add hidden layers
88        for &size in &self.hidden_layer_sizes {
89            builder = builder.add_layer(size);
90        }
91
92        // Add output layer
93        builder = builder.add_layer(output_size);
94
95        let mut qnn = builder.build()?;
96
97        // Reshape target for training
98        let y_reshaped = y.clone().into_shape((y.len(), 1)).map_err(|e| {
99            MLError::InvalidConfiguration(format!("Failed to reshape target: {}", e))
100        })?;
101
102        // Train QNN
103        qnn.train(X, &y_reshaped, self.max_iter, self.learning_rate)?;
104
105        self.qnn = Some(qnn);
106        self.fitted = true;
107
108        Ok(())
109    }
110
111    fn get_params(&self) -> HashMap<String, String> {
112        let mut params = HashMap::new();
113        params.insert(
114            "hidden_layer_sizes".to_string(),
115            format!("{:?}", self.hidden_layer_sizes),
116        );
117        params.insert("activation".to_string(), self.activation.clone());
118        params.insert("solver".to_string(), self.solver.clone());
119        params.insert("learning_rate".to_string(), self.learning_rate.to_string());
120        params.insert("max_iter".to_string(), self.max_iter.to_string());
121        params
122    }
123
124    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
125        for (key, value) in params {
126            match key.as_str() {
127                "learning_rate" => {
128                    self.learning_rate = value.parse().map_err(|_| {
129                        MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
130                    })?;
131                }
132                "max_iter" => {
133                    self.max_iter = value.parse().map_err(|_| {
134                        MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
135                    })?;
136                }
137                "activation" => {
138                    self.activation = value;
139                }
140                "solver" => {
141                    self.solver = value;
142                }
143                _ => {
144                    // Skip unknown parameters
145                }
146            }
147        }
148        Ok(())
149    }
150
151    fn is_fitted(&self) -> bool {
152        self.fitted
153    }
154}
155
156impl SklearnRegressor for QuantumMLPRegressor {
157    #[allow(non_snake_case)]
158    fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>> {
159        if !self.fitted {
160            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
161        }
162
163        let qnn = self
164            .qnn
165            .as_ref()
166            .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
167        let predictions = qnn.predict_batch(X)?;
168
169        // Extract single column for regression
170        Ok(predictions.column(0).to_owned())
171    }
172}