use crate::prelude::HardwareOptimizations;
use ndarray::Array1;
use scirs2_core::parallel_ops::*;
use serde::{Deserialize, Serialize};

use crate::circuit_interfaces::{CircuitInterface, InterfaceCircuit};
use crate::device_noise_models::DeviceNoiseModel;
use crate::error::Result;

use super::circuit::ParameterizedQuantumCircuit;
use super::config::{GradientMethod, HardwareArchitecture, OptimizerType, QMLConfig};

/// Hardware-aware trainer for parameterized quantum circuits.
pub struct QuantumMLTrainer {
    /// Training configuration.
    config: QMLConfig,
    /// Parameterized quantum circuit being trained.
    pqc: ParameterizedQuantumCircuit,
    /// Current optimizer state (parameters, gradient, moment estimates).
    optimizer_state: OptimizerState,
    /// Metrics recorded over the course of training.
    training_history: TrainingHistory,
    /// Optional device noise model.
    noise_model: Option<Box<dyn DeviceNoiseModel>>,
    /// Interface used to execute circuits.
    circuit_interface: CircuitInterface,
    /// Compiler that adapts circuits to the target hardware.
    hardware_compiler: HardwareAwareCompiler,
}

/// Optimizer state carried across training iterations.
#[derive(Debug, Clone)]
pub struct OptimizerState {
    /// Current circuit parameters.
    pub parameters: Array1<f64>,
    /// Most recently computed gradient.
    pub gradient: Array1<f64>,
    /// First-moment estimate (used by Adam).
    pub momentum: Array1<f64>,
    /// Second-moment estimate (used by Adam and RMSprop).
    pub velocity: Array1<f64>,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Number of completed optimizer iterations.
    pub iteration: usize,
}

/// Per-epoch metrics recorded during training.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingHistory {
    /// Loss value after each epoch.
    pub loss_history: Vec<f64>,
    /// L2 norm of the gradient after each epoch.
    pub gradient_norms: Vec<f64>,
    /// L2 norm of the parameters after each epoch.
    pub parameter_norms: Vec<f64>,
    /// Wall-clock duration of each epoch in seconds.
    pub epoch_times: Vec<f64>,
    /// Hardware execution metrics collected per epoch.
    pub hardware_metrics: Vec<HardwareMetrics>,
}

/// Execution metrics for a circuit compiled for quantum hardware.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareMetrics {
    /// Depth of the compiled circuit.
    pub compiled_depth: usize,
    /// Number of two-qubit gates in the compiled circuit.
    pub two_qubit_gates: usize,
    /// Circuit execution time.
    pub execution_time: f64,
    /// Estimated fidelity of the executed circuit.
    pub estimated_fidelity: f64,
    /// Additional shot overhead factor.
    pub shot_overhead: f64,
}

/// Compiles circuits for a specific hardware architecture.
#[derive(Debug, Clone)]
pub struct HardwareAwareCompiler {
    /// Target hardware architecture.
    hardware_arch: HardwareArchitecture,
    /// Hardware-specific optimization settings.
    hardware_opts: HardwareOptimizations,
    /// Statistics from the most recent compilation.
    compilation_stats: CompilationStats,
}

/// Statistics gathered during hardware-aware compilation.
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Circuit depth before compilation.
    pub original_depth: usize,
    /// Circuit depth after compilation.
    pub compiled_depth: usize,
    /// Number of SWAP gates inserted for qubit routing.
    pub swap_gates_added: usize,
    /// Time spent compiling, in seconds.
    pub compilation_time: f64,
    /// Estimated execution time of the compiled circuit.
    pub estimated_execution_time: f64,
}

/// Result of a completed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Final optimized parameters.
    pub final_parameters: Array1<f64>,
    /// Loss value at the final parameters.
    pub final_loss: f64,
    /// Number of epochs that were executed.
    pub epochs_completed: usize,
    /// Full training history.
    pub training_history: TrainingHistory,
    /// Whether the convergence criterion was met before `max_epochs`.
    pub converged: bool,
}

impl QuantumMLTrainer {
    /// Create a new trainer from a configuration, a parameterized circuit,
    /// and an optional device noise model.
    pub fn new(
        config: QMLConfig,
        pqc: ParameterizedQuantumCircuit,
        noise_model: Option<Box<dyn DeviceNoiseModel>>,
    ) -> Result<Self> {
        let num_params = pqc.num_parameters();

        // Start from the circuit's current parameters with zeroed optimizer buffers.
        let optimizer_state = OptimizerState {
            parameters: pqc.parameters.clone(),
            gradient: Array1::zeros(num_params),
            momentum: Array1::zeros(num_params),
            velocity: Array1::zeros(num_params),
            learning_rate: config.learning_rate,
            iteration: 0,
        };

        let training_history = TrainingHistory::default();
        let circuit_interface = CircuitInterface::new(Default::default())?;
        let hardware_compiler = HardwareAwareCompiler::new(
            config.hardware_architecture,
            pqc.hardware_optimizations.clone(),
        );

        Ok(Self {
            config,
            pqc,
            optimizer_state,
            training_history,
            noise_model,
            circuit_interface,
            hardware_compiler,
        })
    }

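    /// Run the full training loop, returning the final parameters and history.
    ///
    /// The loss function receives the current parameter vector and returns a
    /// scalar loss. A minimal usage sketch (assuming a `QMLConfig` and a
    /// `ParameterizedQuantumCircuit` have already been constructed elsewhere;
    /// the toy quadratic loss below stands in for a real circuit evaluation):
    ///
    /// ```ignore
    /// let mut trainer = QuantumMLTrainer::new(config, pqc, None)?;
    /// let result = trainer.train(|params| Ok(params.iter().map(|p| p * p).sum::<f64>()))?;
    /// println!("converged: {}, final loss: {}", result.converged, result.final_loss);
    /// ```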
    pub fn train<F>(&mut self, loss_function: F) -> Result<TrainingResult>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        for epoch in 0..self.config.max_epochs {
            let epoch_start = std::time::Instant::now();

            // Compute the gradient and take one optimizer step.
            let gradient = self.compute_gradient(&loss_function)?;
            self.optimizer_state.gradient = gradient;

            self.update_parameters()?;

            // Evaluate the loss at the updated parameters.
            let current_loss = loss_function(&self.optimizer_state.parameters)?;

            // Record per-epoch metrics.
            let epoch_time = epoch_start.elapsed().as_secs_f64();
            self.training_history.loss_history.push(current_loss);
            self.training_history.gradient_norms.push(
                self.optimizer_state
                    .gradient
                    .iter()
                    .map(|x| x * x)
                    .sum::<f64>()
                    .sqrt(),
            );
            self.training_history.parameter_norms.push(
                self.optimizer_state
                    .parameters
                    .iter()
                    .map(|x| x * x)
                    .sum::<f64>()
                    .sqrt(),
            );
            self.training_history.epoch_times.push(epoch_time);

            if self.check_convergence(current_loss)? {
                return Ok(TrainingResult {
                    final_parameters: self.optimizer_state.parameters.clone(),
                    final_loss: current_loss,
                    epochs_completed: epoch + 1,
                    training_history: self.training_history.clone(),
                    converged: true,
                });
            }

            self.optimizer_state.iteration += 1;
        }

        // Maximum number of epochs reached without meeting the convergence criterion.
        let final_loss = loss_function(&self.optimizer_state.parameters)?;
        Ok(TrainingResult {
            final_parameters: self.optimizer_state.parameters.clone(),
            final_loss,
            epochs_completed: self.config.max_epochs,
            training_history: self.training_history.clone(),
            converged: false,
        })
    }

    /// Compute the gradient of the loss with respect to the circuit parameters,
    /// dispatching on the configured gradient method.
    fn compute_gradient<F>(&mut self, loss_function: &F) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        match self.config.gradient_method {
            GradientMethod::ParameterShift => self.compute_parameter_shift_gradient(loss_function),
            GradientMethod::FiniteDifferences => {
                self.compute_finite_difference_gradient(loss_function)
            }
            GradientMethod::AutomaticDifferentiation => {
                self.compute_autodiff_gradient(loss_function)
            }
            GradientMethod::NaturalGradients => self.compute_natural_gradient(loss_function),
            GradientMethod::StochasticParameterShift => {
                self.compute_stochastic_parameter_shift_gradient(loss_function)
            }
        }
    }

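    /// Gradient via the two-term parameter-shift rule.
    ///
    /// For circuits whose parameterized gates are generated by Pauli rotations,
    /// the derivative with respect to parameter `i` is exact:
    /// `∂L/∂θ_i = [L(θ + (π/2)·e_i) − L(θ − (π/2)·e_i)] / 2`,
    /// at the cost of two loss evaluations per parameter. The Pauli-rotation
    /// assumption is not checked here; other gate generators would require
    /// generalized shift rules.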
    fn compute_parameter_shift_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        let num_params = self.optimizer_state.parameters.len();
        let mut gradient = Array1::zeros(num_params);
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..num_params {
            let mut params_plus = self.optimizer_state.parameters.clone();
            let mut params_minus = self.optimizer_state.parameters.clone();

            params_plus[i] += shift;
            params_minus[i] -= shift;

            let loss_plus = loss_function(&params_plus)?;
            let loss_minus = loss_function(&params_minus)?;

            gradient[i] = (loss_plus - loss_minus) / 2.0;
        }

        Ok(gradient)
    }

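    /// Gradient via one-sided finite differences.
    ///
    /// Approximates `∂L/∂θ_i ≈ [L(θ + ε·e_i) − L(θ)] / ε` with a fixed step
    /// `ε = 1e-8`. This needs only one extra loss evaluation per parameter but,
    /// unlike the parameter-shift rule, gives a biased estimate and is
    /// sensitive to noise in the loss evaluation.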
    fn compute_finite_difference_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        let num_params = self.optimizer_state.parameters.len();
        let mut gradient = Array1::zeros(num_params);
        let eps = 1e-8;

        // The loss at the current parameters is the same for every component,
        // so evaluate it once outside the loop.
        let loss_current = loss_function(&self.optimizer_state.parameters)?;

        for i in 0..num_params {
            let mut params_plus = self.optimizer_state.parameters.clone();
            params_plus[i] += eps;

            let loss_plus = loss_function(&params_plus)?;

            gradient[i] = (loss_plus - loss_current) / eps;
        }

        Ok(gradient)
    }

    /// Automatic-differentiation gradient.
    ///
    /// Not yet implemented for the circuit execution path; falls back to the
    /// parameter-shift rule.
    fn compute_autodiff_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        self.compute_parameter_shift_gradient(loss_function)
    }

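    /// Gradient intended for quantum natural gradient descent.
    ///
    /// The natural gradient preconditions the plain gradient with the inverse
    /// quantum Fisher information (Fubini-Study) metric, `F⁻¹ ∇L`. The metric
    /// computation is not implemented here yet, so this currently returns the
    /// unpreconditioned parameter-shift gradient.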
    fn compute_natural_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        let gradient = self.compute_parameter_shift_gradient(loss_function)?;

        Ok(gradient)
    }

    /// Stochastic parameter-shift gradient.
    ///
    /// Currently falls back to the deterministic parameter-shift rule over all
    /// parameters.
    fn compute_stochastic_parameter_shift_gradient<F>(
        &self,
        loss_function: &F,
    ) -> Result<Array1<f64>>
    where
        F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
    {
        self.compute_parameter_shift_gradient(loss_function)
    }

    /// Apply one optimizer step using the configured optimizer type.
    fn update_parameters(&mut self) -> Result<()> {
        match self.config.optimizer_type {
            OptimizerType::Adam => self.update_parameters_adam(),
            OptimizerType::SGD => self.update_parameters_sgd(),
            OptimizerType::RMSprop => self.update_parameters_rmsprop(),
            OptimizerType::LBFGS => self.update_parameters_lbfgs(),
            OptimizerType::QuantumNaturalGradient => self.update_parameters_qng(),
            OptimizerType::SPSA => self.update_parameters_spsa(),
        }
    }

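    /// Adam update with bias-corrected moment estimates:
    /// `m ← β₁·m + (1−β₁)·g`, `v ← β₂·v + (1−β₂)·g²`,
    /// `m̂ = m / (1−β₁ᵗ)`, `v̂ = v / (1−β₂ᵗ)`,
    /// `θ ← θ − η·m̂ / (√v̂ + ε)` with β₁ = 0.9, β₂ = 0.999, ε = 1e-8.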
    fn update_parameters_adam(&mut self) -> Result<()> {
        let beta1 = 0.9;
        let beta2 = 0.999;
        let eps = 1e-8;

        for i in 0..self.optimizer_state.parameters.len() {
            // Update biased first- and second-moment estimates.
            self.optimizer_state.momentum[i] = beta1 * self.optimizer_state.momentum[i]
                + (1.0 - beta1) * self.optimizer_state.gradient[i];
            self.optimizer_state.velocity[i] = beta2 * self.optimizer_state.velocity[i]
                + (1.0 - beta2) * self.optimizer_state.gradient[i].powi(2);

            // Bias-corrected estimates (iteration is zero-based, hence the +1).
            let m_hat = self.optimizer_state.momentum[i]
                / (1.0 - beta1.powi(self.optimizer_state.iteration as i32 + 1));
            let v_hat = self.optimizer_state.velocity[i]
                / (1.0 - beta2.powi(self.optimizer_state.iteration as i32 + 1));

            self.optimizer_state.parameters[i] -=
                self.optimizer_state.learning_rate * m_hat / (v_hat.sqrt() + eps);
        }

        Ok(())
    }

    fn update_parameters_sgd(&mut self) -> Result<()> {
        for i in 0..self.optimizer_state.parameters.len() {
            self.optimizer_state.parameters[i] -=
                self.optimizer_state.learning_rate * self.optimizer_state.gradient[i];
        }
        Ok(())
    }

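    /// RMSprop update: `v ← α·v + (1−α)·g²`, `θ ← θ − η·g / (√v + ε)`
    /// with decay `α = 0.99` and `ε = 1e-8`, reusing the `velocity` buffer
    /// for the running squared-gradient average.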
    fn update_parameters_rmsprop(&mut self) -> Result<()> {
        let alpha = 0.99;
        let eps = 1e-8;

        for i in 0..self.optimizer_state.parameters.len() {
            self.optimizer_state.velocity[i] = alpha * self.optimizer_state.velocity[i]
                + (1.0 - alpha) * self.optimizer_state.gradient[i].powi(2);
            self.optimizer_state.parameters[i] -= self.optimizer_state.learning_rate
                * self.optimizer_state.gradient[i]
                / (self.optimizer_state.velocity[i].sqrt() + eps);
        }

        Ok(())
    }

    /// L-BFGS update. The quasi-Newton history is not maintained yet, so this
    /// currently falls back to plain SGD.
    fn update_parameters_lbfgs(&mut self) -> Result<()> {
        self.update_parameters_sgd()
    }

    /// Quantum natural gradient update. The metric tensor is not applied yet,
    /// so this currently falls back to plain SGD.
    fn update_parameters_qng(&mut self) -> Result<()> {
        self.update_parameters_sgd()
    }

    /// SPSA update. Simultaneous perturbation is not implemented yet, so this
    /// currently falls back to plain SGD.
    fn update_parameters_spsa(&mut self) -> Result<()> {
        self.update_parameters_sgd()
    }

    /// Check whether the change in loss since the previous epoch is below the
    /// configured convergence tolerance.
    fn check_convergence(&self, current_loss: f64) -> Result<bool> {
        if self.training_history.loss_history.len() < 2 {
            return Ok(false);
        }

        // `current_loss` has already been pushed onto `loss_history`, so the
        // previous epoch's loss is the second-to-last entry.
        let prev_loss =
            self.training_history.loss_history[self.training_history.loss_history.len() - 2];
        let loss_change = (current_loss - prev_loss).abs();

        Ok(loss_change < self.config.convergence_tolerance)
    }

    /// Current circuit parameters.
    pub fn get_parameters(&self) -> &Array1<f64> {
        &self.optimizer_state.parameters
    }

    /// Recorded training history.
    pub fn get_training_history(&self) -> &TrainingHistory {
        &self.training_history
    }

    /// Override the learning rate.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer_state.learning_rate = lr;
    }

    /// Reset the optimizer buffers and clear the training history, keeping the
    /// current parameters.
    pub fn reset_optimizer(&mut self) {
        let num_params = self.optimizer_state.parameters.len();
        self.optimizer_state.gradient = Array1::zeros(num_params);
        self.optimizer_state.momentum = Array1::zeros(num_params);
        self.optimizer_state.velocity = Array1::zeros(num_params);
        self.optimizer_state.iteration = 0;
        self.training_history = TrainingHistory::default();
    }
}

impl HardwareAwareCompiler {
    /// Create a compiler for the given architecture and optimization settings.
    pub fn new(hardware_arch: HardwareArchitecture, hardware_opts: HardwareOptimizations) -> Self {
        Self {
            hardware_arch,
            hardware_opts,
            compilation_stats: CompilationStats::default(),
        }
    }

    /// Compile a circuit for the target hardware and record compilation statistics.
    ///
    /// Routing and gate rewriting are not implemented yet; the circuit is
    /// currently passed through unchanged, with the gate count used as a depth proxy.
    pub fn compile_circuit(&mut self, circuit: &InterfaceCircuit) -> Result<InterfaceCircuit> {
        let start_time = std::time::Instant::now();
        self.compilation_stats.original_depth = circuit.gates.len();

        let compiled_circuit = circuit.clone();

        self.compilation_stats.compiled_depth = compiled_circuit.gates.len();
        self.compilation_stats.compilation_time = start_time.elapsed().as_secs_f64();

        Ok(compiled_circuit)
    }

    /// Statistics from the most recent compilation.
    pub fn get_stats(&self) -> &CompilationStats {
        &self.compilation_stats
    }
}

impl OptimizerState {
    /// Create a zero-initialized optimizer state with the given learning rate.
    pub fn new(num_parameters: usize, learning_rate: f64) -> Self {
        Self {
            parameters: Array1::zeros(num_parameters),
            gradient: Array1::zeros(num_parameters),
            momentum: Array1::zeros(num_parameters),
            velocity: Array1::zeros(num_parameters),
            learning_rate,
            iteration: 0,
        }
    }
}

impl TrainingHistory {
    /// Loss recorded for the most recent epoch, if any.
    pub fn latest_loss(&self) -> Option<f64> {
        self.loss_history.last().copied()
    }

    /// Lowest loss recorded so far, if any. Assumes recorded losses are not NaN.
    pub fn best_loss(&self) -> Option<f64> {
        self.loss_history
            .iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .copied()
    }

    /// Average wall-clock time per epoch in seconds (0.0 if no epochs were recorded).
    pub fn average_epoch_time(&self) -> f64 {
        if self.epoch_times.is_empty() {
            0.0
        } else {
            self.epoch_times.iter().sum::<f64>() / self.epoch_times.len() as f64
        }
    }
}