1use crate::error::{MLError, Result};
2use scirs2_core::ndarray::{Array1, ArrayView1};
3use std::collections::HashMap;
4use std::fmt;
5
/// Selector for the optimizer configuration that [`Optimizer::new`] builds.
///
/// The type is a plain, field-less tag: it is cheap to copy, can be compared
/// for equality, and can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationMethod {
    /// Selects [`Optimizer::GradientDescent`].
    GradientDescent,

    /// Selects [`Optimizer::Adam`].
    Adam,

    /// Selects [`Optimizer::SPSA`].
    SPSA,

    /// L-BFGS. [`Optimizer`] has no dedicated variant for it, so
    /// [`Optimizer::new`] builds the same configuration as for
    /// [`OptimizationMethod::Adam`].
    LBFGS,

    /// Selects [`Optimizer::QuantumNaturalGradient`].
    QuantumNaturalGradient,

    /// Selects [`Optimizer::SciRS2`] with method name `"adam"`.
    SciRS2Adam,

    /// Selects [`Optimizer::SciRS2`] with method name `"lbfgs"`.
    SciRS2LBFGS,

    /// Selects [`Optimizer::SciRS2`] with method name `"cg"`.
    SciRS2CG,
}
33
/// A configured optimizer: the chosen update rule together with its
/// hyperparameters.
///
/// Two optimizers compare equal when they are the same variant and all of
/// their hyperparameters are equal.
#[derive(Debug, Clone, PartialEq)]
pub enum Optimizer {
    /// Gradient descent.
    GradientDescent {
        /// Factor the gradient is multiplied by before it is subtracted from
        /// the parameters.
        learning_rate: f64,
    },

    /// Adam configuration.
    ///
    /// NOTE(review): `update_parameters` in this file only reads
    /// `learning_rate` for this variant; `beta1`, `beta2` and `epsilon` are
    /// stored and displayed but do not influence the update.
    Adam {
        /// Factor the gradient is multiplied by before it is subtracted from
        /// the parameters.
        learning_rate: f64,

        /// Stored hyperparameter; presumably the first-moment decay rate.
        beta1: f64,

        /// Stored hyperparameter; presumably the second-moment decay rate.
        beta2: f64,

        /// Stored hyperparameter; presumably a numerical-stability constant.
        epsilon: f64,
    },

    /// SPSA configuration.
    ///
    /// NOTE(review): `update_parameters` in this file only reads
    /// `learning_rate` for this variant; `perturbation` is stored and
    /// displayed but does not influence the update.
    SPSA {
        /// Factor the gradient is multiplied by before it is subtracted from
        /// the parameters.
        learning_rate: f64,

        /// Stored hyperparameter; presumably the perturbation magnitude.
        perturbation: f64,
    },

    /// Quantum-natural-gradient configuration.
    QuantumNaturalGradient {
        /// Factor the gradient is multiplied by before damping.
        learning_rate: f64,
        /// Damping term: `update_parameters` divides each step by
        /// `1.0 + regularization`.
        regularization: f64,
    },

    /// Optimizer identified by a method name plus a free-form numeric
    /// configuration.
    SciRS2 {
        /// Method name. `update_parameters` accepts `"adam"`, `"lbfgs"` and
        /// `"cg"` and rejects anything else.
        method: String,
        /// Numeric settings keyed by name. `update_parameters` reads
        /// `"learning_rate"` and falls back to `0.001` when it is absent.
        config: HashMap<String, f64>,
    },
}
89
90impl Optimizer {
91 pub fn new(method: OptimizationMethod) -> Self {
93 match method {
94 OptimizationMethod::GradientDescent => Optimizer::GradientDescent {
95 learning_rate: 0.01,
96 },
97 OptimizationMethod::Adam => Optimizer::Adam {
98 learning_rate: 0.01,
99 beta1: 0.9,
100 beta2: 0.999,
101 epsilon: 1e-8,
102 },
103 OptimizationMethod::SPSA => Optimizer::SPSA {
104 learning_rate: 0.01,
105 perturbation: 0.01,
106 },
107 OptimizationMethod::LBFGS => {
108 Optimizer::Adam {
110 learning_rate: 0.01,
111 beta1: 0.9,
112 beta2: 0.999,
113 epsilon: 1e-8,
114 }
115 }
116 OptimizationMethod::QuantumNaturalGradient => Optimizer::QuantumNaturalGradient {
117 learning_rate: 0.01,
118 regularization: 1e-3,
119 },
120 OptimizationMethod::SciRS2Adam => {
121 let mut config = HashMap::new();
122 config.insert("learning_rate".to_string(), 0.001);
123 config.insert("beta1".to_string(), 0.9);
124 config.insert("beta2".to_string(), 0.999);
125 config.insert("epsilon".to_string(), 1e-8);
126 Optimizer::SciRS2 {
127 method: "adam".to_string(),
128 config,
129 }
130 }
131 OptimizationMethod::SciRS2LBFGS => {
132 let mut config = HashMap::new();
133 config.insert("m".to_string(), 10.0); config.insert("c1".to_string(), 1e-4);
135 config.insert("c2".to_string(), 0.9);
136 Optimizer::SciRS2 {
137 method: "lbfgs".to_string(),
138 config,
139 }
140 }
141 OptimizationMethod::SciRS2CG => {
142 let mut config = HashMap::new();
143 config.insert("beta_method".to_string(), 0.0); config.insert("restart_threshold".to_string(), 100.0);
145 Optimizer::SciRS2 {
146 method: "cg".to_string(),
147 config,
148 }
149 }
150 }
151 }
152
153 pub fn update_parameters(
155 &self,
156 parameters: &mut Array1<f64>,
157 gradients: &ArrayView1<f64>,
158 iteration: usize,
159 ) -> Result<()> {
160 match self {
161 Optimizer::GradientDescent { learning_rate } => {
162 for i in 0..parameters.len() {
164 parameters[i] -= learning_rate * gradients[i];
165 }
166 Ok(())
167 }
168 Optimizer::Adam {
169 learning_rate,
170 beta1,
171 beta2,
172 epsilon,
173 } => {
174 for i in 0..parameters.len() {
177 parameters[i] -= learning_rate * gradients[i];
178 }
179 Ok(())
180 }
181 Optimizer::SPSA {
182 learning_rate,
183 perturbation,
184 } => {
185 for i in 0..parameters.len() {
187 parameters[i] -= learning_rate * gradients[i];
188 }
189 Ok(())
190 }
191 Optimizer::QuantumNaturalGradient {
192 learning_rate,
193 regularization,
194 } => {
195 let damp = 1.0 + regularization;
199 for i in 0..parameters.len() {
200 parameters[i] -= learning_rate * gradients[i] / damp;
201 }
202 Ok(())
203 }
204 Optimizer::SciRS2 { method, config } => {
205 let learning_rate = config.get("learning_rate").unwrap_or(&0.001);
207 match method.as_str() {
208 "adam" => {
209 for i in 0..parameters.len() {
211 parameters[i] -= learning_rate * gradients[i];
212 }
213 }
214 "lbfgs" => {
215 for i in 0..parameters.len() {
217 parameters[i] -= learning_rate * gradients[i];
218 }
219 }
220 "cg" => {
221 for i in 0..parameters.len() {
223 parameters[i] -= learning_rate * gradients[i];
224 }
225 }
226 _ => {
227 return Err(MLError::InvalidConfiguration(format!(
228 "Unknown SciRS2 optimizer method: {}",
229 method
230 )));
231 }
232 }
233 Ok(())
234 }
235 }
236 }
237}
238
239pub trait ObjectiveFunction {
241 fn evaluate(&self, parameters: &ArrayView1<f64>) -> Result<f64>;
243
244 fn gradient(&self, parameters: &ArrayView1<f64>) -> Result<Array1<f64>> {
246 let epsilon = 1e-6;
248 let n = parameters.len();
249 let mut gradient = Array1::zeros(n);
250
251 let f0 = self.evaluate(parameters)?;
252
253 for i in 0..n {
254 let mut params_plus = parameters.to_owned();
255 params_plus[i] += epsilon;
256
257 let f_plus = self.evaluate(¶ms_plus.view())?;
258
259 gradient[i] = (f_plus - f0) / epsilon;
260 }
261
262 Ok(gradient)
263 }
264}
265
266impl fmt::Display for OptimizationMethod {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 match self {
269 OptimizationMethod::GradientDescent => write!(f, "Gradient Descent"),
270 OptimizationMethod::Adam => write!(f, "Adam"),
271 OptimizationMethod::SPSA => write!(f, "SPSA"),
272 OptimizationMethod::LBFGS => write!(f, "L-BFGS"),
273 OptimizationMethod::QuantumNaturalGradient => write!(f, "Quantum Natural Gradient"),
274 OptimizationMethod::SciRS2Adam => write!(f, "SciRS2 Adam"),
275 OptimizationMethod::SciRS2LBFGS => write!(f, "SciRS2 L-BFGS"),
276 OptimizationMethod::SciRS2CG => write!(f, "SciRS2 Conjugate Gradient"),
277 }
278 }
279}
280
281impl fmt::Display for Optimizer {
282 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283 match self {
284 Optimizer::GradientDescent { learning_rate } => {
285 write!(f, "Gradient Descent (learning_rate: {})", learning_rate)
286 }
287 Optimizer::Adam {
288 learning_rate,
289 beta1,
290 beta2,
291 epsilon,
292 } => {
293 write!(
294 f,
295 "Adam (learning_rate: {}, beta1: {}, beta2: {}, epsilon: {})",
296 learning_rate, beta1, beta2, epsilon
297 )
298 }
299 Optimizer::SPSA {
300 learning_rate,
301 perturbation,
302 } => {
303 write!(
304 f,
305 "SPSA (learning_rate: {}, perturbation: {})",
306 learning_rate, perturbation
307 )
308 }
309 Optimizer::QuantumNaturalGradient {
310 learning_rate,
311 regularization,
312 } => {
313 write!(
314 f,
315 "Quantum Natural Gradient (learning_rate: {}, regularization: {})",
316 learning_rate, regularization
317 )
318 }
319 Optimizer::SciRS2 { method, config } => {
320 write!(f, "SciRS2 {} with config: {:?}", method, config)
321 }
322 }
323 }
324}