// quantrs2_ml/sklearn_compatibility/regressors.rs
use super::{SklearnEstimator, SklearnRegressor};
use crate::error::{MLError, Result};
use crate::qnn::{QNNBuilder, QuantumNeuralNetwork};
use crate::simulator_backends::{SimulatorBackend, StatevectorBackend};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::sync::Arc;

/// Quantum multi-layer-perceptron regressor with a scikit-learn-style API.
///
/// Wraps a [`QuantumNeuralNetwork`] and mirrors the hyperparameter surface of
/// sklearn's `MLPRegressor` (hidden layer sizes, learning rate, max_iter, …).
pub struct QuantumMLPRegressor {
    // Trained network; `None` until `fit` succeeds.
    qnn: Option<QuantumNeuralNetwork>,
    // Number of units in each hidden layer, in order.
    hidden_layer_sizes: Vec<usize>,
    // Activation function name. NOTE(review): stored and reported via
    // get_params, but not visibly forwarded to the QNN builder in this file —
    // confirm whether it actually affects training.
    activation: String,
    // Optimizer name. Same caveat as `activation`.
    solver: String,
    // Step size passed to `QuantumNeuralNetwork::train`.
    learning_rate: f64,
    // Maximum number of training iterations.
    max_iter: usize,
    // Optional RNG seed. NOTE(review): not referenced anywhere in this file —
    // confirm it is consumed elsewhere or wire it up.
    random_state: Option<u64>,
    // Simulator backend used to evaluate quantum circuits.
    backend: Arc<dyn SimulatorBackend>,
    // Set to true once `fit` completes successfully.
    fitted: bool,
}
32
33impl QuantumMLPRegressor {
34 pub fn new() -> Self {
36 Self {
37 qnn: None,
38 hidden_layer_sizes: vec![10],
39 activation: "relu".to_string(),
40 solver: "adam".to_string(),
41 learning_rate: 0.001,
42 max_iter: 200,
43 random_state: None,
44 backend: Arc::new(StatevectorBackend::new(10)),
45 fitted: false,
46 }
47 }
48
49 pub fn set_hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
51 self.hidden_layer_sizes = sizes;
52 self
53 }
54
55 pub fn set_learning_rate(mut self, lr: f64) -> Self {
57 self.learning_rate = lr;
58 self
59 }
60
61 pub fn set_max_iter(mut self, max_iter: usize) -> Self {
63 self.max_iter = max_iter;
64 self
65 }
66}
67
68impl Default for QuantumMLPRegressor {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl SklearnEstimator for QuantumMLPRegressor {
75 #[allow(non_snake_case)]
76 fn fit(&mut self, X: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
77 let y = y.ok_or_else(|| {
78 MLError::InvalidConfiguration("Target values required for regression".to_string())
79 })?;
80
81 let _input_size = X.ncols();
83 let output_size = 1; let mut builder = QNNBuilder::new();
86
87 for &size in &self.hidden_layer_sizes {
89 builder = builder.add_layer(size);
90 }
91
92 builder = builder.add_layer(output_size);
94
95 let mut qnn = builder.build()?;
96
97 let y_reshaped = y.clone().into_shape((y.len(), 1)).map_err(|e| {
99 MLError::InvalidConfiguration(format!("Failed to reshape target: {}", e))
100 })?;
101
102 qnn.train(X, &y_reshaped, self.max_iter, self.learning_rate)?;
104
105 self.qnn = Some(qnn);
106 self.fitted = true;
107
108 Ok(())
109 }
110
111 fn get_params(&self) -> HashMap<String, String> {
112 let mut params = HashMap::new();
113 params.insert(
114 "hidden_layer_sizes".to_string(),
115 format!("{:?}", self.hidden_layer_sizes),
116 );
117 params.insert("activation".to_string(), self.activation.clone());
118 params.insert("solver".to_string(), self.solver.clone());
119 params.insert("learning_rate".to_string(), self.learning_rate.to_string());
120 params.insert("max_iter".to_string(), self.max_iter.to_string());
121 params
122 }
123
124 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
125 for (key, value) in params {
126 match key.as_str() {
127 "learning_rate" => {
128 self.learning_rate = value.parse().map_err(|_| {
129 MLError::InvalidConfiguration(format!("Invalid learning_rate: {}", value))
130 })?;
131 }
132 "max_iter" => {
133 self.max_iter = value.parse().map_err(|_| {
134 MLError::InvalidConfiguration(format!("Invalid max_iter: {}", value))
135 })?;
136 }
137 "activation" => {
138 self.activation = value;
139 }
140 "solver" => {
141 self.solver = value;
142 }
143 _ => {
144 }
146 }
147 }
148 Ok(())
149 }
150
151 fn is_fitted(&self) -> bool {
152 self.fitted
153 }
154}
155
156impl SklearnRegressor for QuantumMLPRegressor {
157 #[allow(non_snake_case)]
158 fn predict(&self, X: &Array2<f64>) -> Result<Array1<f64>> {
159 if !self.fitted {
160 return Err(MLError::ModelNotTrained("Model not trained".to_string()));
161 }
162
163 let qnn = self
164 .qnn
165 .as_ref()
166 .ok_or_else(|| MLError::ModelNotTrained("QNN model not initialized".to_string()))?;
167 let predictions = qnn.predict_batch(X)?;
168
169 Ok(predictions.column(0).to_owned())
171 }
172}