1use super::{natural_gradient, quantum_fisher_information, QMLCircuit};
7use crate::{
8 error::{QuantRS2Error, QuantRS2Result},
9 gpu::{GpuBackendFactory, GpuStateVector},
10};
11use ndarray::{Array1, Array2};
12use num_complex::Complex64;
13use std::collections::HashMap;
15
/// Loss functions that can be selected for QML training.
///
/// In this file only `MSE` and `CrossEntropy` are actually evaluated by the
/// trainer's loss computation; the remaining variants currently produce a
/// loss of `0.0` there.
///
/// `PartialEq`/`Eq` are derived so variants can be compared directly
/// (e.g. in tests or configuration checks).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossFunction {
    /// Mean squared error between output and target.
    MSE,
    /// Cross-entropy between target and output values.
    CrossEntropy,
    /// Presumably a state-fidelity based loss (not implemented by the trainer here).
    Fidelity,
    /// Presumably an expectation-value loss for variational algorithms
    /// (not implemented by the trainer here).
    Variational,
    /// Presumably a user-supplied loss (not implemented by the trainer here).
    Custom,
}
30
/// Optimizer selection together with its hyperparameters.
///
/// NOTE(review): in this file the trainer's parameter update only acts for
/// `GradientDescent`, `Adam` and `QuantumNatural`; `NaturalGradient` and
/// `BFGS` leave the parameters unchanged.
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Plain gradient descent: each parameter moves by `-learning_rate * gradient`.
    GradientDescent {
        /// Step size applied to the gradient.
        learning_rate: f64,
    },
    /// Adam with bias-corrected first and second moment estimates.
    Adam {
        /// Step size applied to the bias-corrected update.
        learning_rate: f64,
        /// Decay rate of the first-moment (mean) estimate.
        beta1: f64,
        /// Decay rate of the second-moment (squared gradient) estimate.
        beta2: f64,
        /// Small constant added to the denominator for numerical stability.
        epsilon: f64,
    },
    /// Natural gradient descent (no update is performed for it in this file).
    NaturalGradient {
        /// Step size.
        learning_rate: f64,
        /// Presumably a regularization term for the Fisher matrix — verify against the implementation.
        regularization: f64,
    },
    /// BFGS quasi-Newton method (no update is performed for it in this file).
    BFGS,
    /// Gradient preconditioned with the quantum Fisher information.
    QuantumNatural {
        /// Step size.
        learning_rate: f64,
        /// Regularization value forwarded to the natural-gradient computation.
        regularization: f64,
    },
}
56
/// Settings that control a training run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Upper bound on the number of epochs.
    pub max_epochs: usize,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Training stops once the absolute change in training loss between
    /// consecutive epochs falls below this value.
    pub tolerance: f64,
    /// Whether a GPU backend is requested at the start of training.
    pub use_gpu: bool,
    /// Presumably the fraction of data held out for validation; it is not
    /// read anywhere in this file — verify against callers.
    pub validation_split: f64,
    /// Number of epochs without validation improvement after which training
    /// stops; `None` disables early stopping.
    pub early_stopping_patience: Option<usize>,
    /// Maximum L2 norm of the gradient; `None` disables clipping.
    pub gradient_clip: Option<f64>,
}

impl Default for TrainingConfig {
    /// Returns the default configuration: 100 epochs, batches of 32,
    /// tolerance `1e-6`, GPU enabled, 20% validation split, patience of 10
    /// and gradient clipping at norm `1.0`.
    fn default() -> Self {
        TrainingConfig {
            max_epochs: 100,
            batch_size: 32,
            tolerance: 1e-6,
            use_gpu: true,
            validation_split: 0.2,
            early_stopping_patience: Some(10),
            gradient_clip: Some(1.0),
        }
    }
}
89
/// Metrics collected while training.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Training loss recorded once per epoch.
    pub loss_history: Vec<f64>,
    /// Validation loss recorded once per epoch (only when validation data is supplied).
    pub val_loss_history: Vec<f64>,
    /// L2 norm of the (possibly clipped) gradient, recorded once per batch.
    pub gradient_norms: Vec<f64>,
    /// Parameter vector recorded after every parameter update.
    pub parameter_history: Vec<Vec<f64>>,
    /// Lowest validation loss observed so far; `f64::INFINITY` until a
    /// validation loss has been recorded.
    pub best_val_loss: f64,
    /// Parameters that produced `best_val_loss`; empty until a validation
    /// loss has been recorded.
    pub best_parameters: Vec<f64>,
}

impl Default for TrainingMetrics {
    /// Creates empty metrics.
    ///
    /// `best_val_loss` starts at `f64::INFINITY` rather than `0.0`: a zero
    /// default would read as a perfect validation loss even though no
    /// validation has happened yet.
    fn default() -> Self {
        Self {
            loss_history: Vec::new(),
            val_loss_history: Vec::new(),
            gradient_norms: Vec::new(),
            parameter_history: Vec::new(),
            best_val_loss: f64::INFINITY,
            best_parameters: Vec::new(),
        }
    }
}
106
107pub struct QMLTrainer {
109 circuit: QMLCircuit,
111 loss_fn: LossFunction,
113 optimizer: Optimizer,
115 config: TrainingConfig,
117 metrics: TrainingMetrics,
119 adam_state: Option<AdamState>,
121}
122
/// Running state kept by the Adam optimizer between updates.
#[derive(Debug, Clone)]
struct AdamState {
    /// First-moment (mean) estimate, one entry per parameter.
    m: Vec<f64>,
    /// Second-moment (squared gradient) estimate, one entry per parameter.
    v: Vec<f64>,
    /// Number of update steps taken so far (used for bias correction).
    t: usize,
}
130
131impl QMLTrainer {
132 pub fn new(
134 circuit: QMLCircuit,
135 loss_fn: LossFunction,
136 optimizer: Optimizer,
137 config: TrainingConfig,
138 ) -> Self {
139 let num_params = circuit.num_parameters;
140 let adam_state = match &optimizer {
141 Optimizer::Adam { .. } => Some(AdamState {
142 m: vec![0.0; num_params],
143 v: vec![0.0; num_params],
144 t: 0,
145 }),
146 _ => None,
147 };
148
149 Self {
150 circuit,
151 loss_fn,
152 optimizer,
153 config,
154 metrics: TrainingMetrics::default(),
155 adam_state,
156 }
157 }
158
159 pub fn train(
161 &mut self,
162 train_data: &[(Vec<f64>, Vec<f64>)],
163 val_data: Option<&[(Vec<f64>, Vec<f64>)]>,
164 ) -> QuantRS2Result<TrainingMetrics> {
165 let gpu_backend = if self.config.use_gpu {
167 Some(GpuBackendFactory::create_best_available()?)
168 } else {
169 None
170 };
171
172 let mut best_val_loss = f64::INFINITY;
173 let mut patience_counter = 0;
174
175 for epoch in 0..self.config.max_epochs {
176 let train_loss = self.train_epoch(train_data, &gpu_backend)?;
178 self.metrics.loss_history.push(train_loss);
179
180 if let Some(val_data) = val_data {
182 let val_loss = self.evaluate(val_data, &gpu_backend)?;
183 self.metrics.val_loss_history.push(val_loss);
184
185 if val_loss < best_val_loss {
187 best_val_loss = val_loss;
188 self.metrics.best_val_loss = val_loss;
189 self.metrics.best_parameters = self.get_parameters();
190 patience_counter = 0;
191 } else if let Some(patience) = self.config.early_stopping_patience {
192 patience_counter += 1;
193 if patience_counter >= patience {
194 println!("Early stopping at epoch {}", epoch);
195 break;
196 }
197 }
198 }
199
200 if epoch > 0 {
202 let loss_change =
203 (self.metrics.loss_history[epoch] - self.metrics.loss_history[epoch - 1]).abs();
204 if loss_change < self.config.tolerance {
205 println!("Converged at epoch {}", epoch);
206 break;
207 }
208 }
209
210 if epoch % 10 == 0 {
212 println!("Epoch {}: train_loss = {:.6}", epoch, train_loss);
213 if let Some(val_loss) = self.metrics.val_loss_history.last() {
214 println!(" val_loss = {:.6}", val_loss);
215 }
216 }
217 }
218
219 Ok(self.metrics.clone())
220 }
221
222 fn train_epoch(
224 &mut self,
225 data: &[(Vec<f64>, Vec<f64>)],
226 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
227 ) -> QuantRS2Result<f64> {
228 let mut epoch_loss = 0.0;
229 let num_batches = (data.len() + self.config.batch_size - 1) / self.config.batch_size;
230
231 for batch_idx in 0..num_batches {
232 let start = batch_idx * self.config.batch_size;
233 let end = (start + self.config.batch_size).min(data.len());
234 let batch = &data[start..end];
235
236 let (loss, gradients) = self.compute_batch_gradients(batch, gpu_backend)?;
238 epoch_loss += loss;
239
240 let clipped_gradients = if let Some(clip_value) = self.config.gradient_clip {
242 self.clip_gradients(&gradients, clip_value)
243 } else {
244 gradients
245 };
246
247 self.update_parameters(&clipped_gradients)?;
249
250 let grad_norm = clipped_gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
252 self.metrics.gradient_norms.push(grad_norm);
253 }
254
255 Ok(epoch_loss / num_batches as f64)
256 }
257
258 fn compute_batch_gradients(
260 &self,
261 batch: &[(Vec<f64>, Vec<f64>)],
262 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
263 ) -> QuantRS2Result<(f64, Vec<f64>)> {
264 let mut total_loss = 0.0;
265 let mut total_gradients = vec![0.0; self.circuit.num_parameters];
266
267 for (input, target) in batch {
268 let output = self.forward(input, gpu_backend)?;
270
271 let loss = self.compute_loss(&output, target)?;
273 total_loss += loss;
274
275 let gradients = vec![0.0; self.circuit.num_parameters]; for (i, &grad) in gradients.iter().enumerate() {
280 total_gradients[i] += grad;
281 }
282 }
283
284 let batch_size = batch.len() as f64;
286 total_loss /= batch_size;
287 for grad in &mut total_gradients {
288 *grad /= batch_size;
289 }
290
291 Ok((total_loss, total_gradients))
292 }
293
294 fn forward(
296 &self,
297 input: &[f64],
298 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
299 ) -> QuantRS2Result<Vec<f64>> {
300 Ok(vec![0.5; input.len()])
308 }
309
310 fn compute_loss(&self, output: &[f64], target: &[f64]) -> QuantRS2Result<f64> {
312 if output.len() != target.len() {
313 return Err(QuantRS2Error::InvalidInput(
314 "Output and target dimensions mismatch".to_string(),
315 ));
316 }
317
318 match self.loss_fn {
319 LossFunction::MSE => {
320 let mse = output
321 .iter()
322 .zip(target.iter())
323 .map(|(o, t)| (o - t).powi(2))
324 .sum::<f64>()
325 / output.len() as f64;
326 Ok(mse)
327 }
328 LossFunction::CrossEntropy => {
329 let epsilon = 1e-10;
330 let ce = -output
331 .iter()
332 .zip(target.iter())
333 .map(|(o, t)| t * (o + epsilon).ln())
334 .sum::<f64>()
335 / output.len() as f64;
336 Ok(ce)
337 }
338 _ => Ok(0.0), }
340 }
341
342 fn update_parameters(&mut self, gradients: &[f64]) -> QuantRS2Result<()> {
344 let current_params = self.get_parameters();
345 let new_params = match &mut self.optimizer {
346 Optimizer::GradientDescent { learning_rate } => current_params
347 .iter()
348 .zip(gradients.iter())
349 .map(|(p, g)| p - *learning_rate * g)
350 .collect(),
351
352 Optimizer::Adam {
353 learning_rate,
354 beta1,
355 beta2,
356 epsilon,
357 } => {
358 if let Some(state) = &mut self.adam_state {
359 state.t += 1;
360 let t = state.t as f64;
361
362 let mut new_params = vec![0.0; current_params.len()];
363 for i in 0..current_params.len() {
364 state.m[i] = *beta1 * state.m[i] + (1.0 - *beta1) * gradients[i];
366
367 state.v[i] = *beta2 * state.v[i] + (1.0 - *beta2) * gradients[i].powi(2);
369
370 let m_hat = state.m[i] / (1.0 - beta1.powf(t));
372
373 let v_hat = state.v[i] / (1.0 - beta2.powf(t));
375
376 new_params[i] =
378 current_params[i] - *learning_rate * m_hat / (v_hat.sqrt() + *epsilon);
379 }
380 new_params
381 } else {
382 current_params
383 }
384 }
385
386 Optimizer::QuantumNatural {
387 learning_rate,
388 regularization,
389 } => {
390 let state = Array1::zeros(1 << self.circuit.config.num_qubits);
392 let fisher = quantum_fisher_information(&self.circuit, &state)?;
393
394 natural_gradient(gradients, &fisher, *regularization)?
396 }
397
398 _ => current_params, };
400
401 self.circuit.set_parameters(&new_params)?;
402 self.metrics.parameter_history.push(new_params);
403
404 Ok(())
405 }
406
407 fn clip_gradients(&self, gradients: &[f64], clip_value: f64) -> Vec<f64> {
409 let norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
410
411 if norm > clip_value {
412 gradients.iter().map(|g| g * clip_value / norm).collect()
413 } else {
414 gradients.to_vec()
415 }
416 }
417
418 fn evaluate(
420 &self,
421 data: &[(Vec<f64>, Vec<f64>)],
422 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
423 ) -> QuantRS2Result<f64> {
424 let mut total_loss = 0.0;
425
426 for (input, target) in data {
427 let output = self.forward(input, gpu_backend)?;
428 let loss = self.compute_loss(&output, target)?;
429 total_loss += loss;
430 }
431
432 Ok(total_loss / data.len() as f64)
433 }
434
435 fn get_parameters(&self) -> Vec<f64> {
437 self.circuit.parameters().iter().map(|p| p.value).collect()
438 }
439}
440
441pub struct HyperparameterOptimizer {
443 search_space: HashMap<String, (f64, f64)>,
445 num_trials: usize,
447 strategy: HPOStrategy,
449}
450
/// Strategy used to explore the hyperparameter search space.
///
/// `PartialEq`/`Eq` are derived so strategies can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HPOStrategy {
    /// Random search.
    Random,
    /// Grid search.
    Grid,
    /// Bayesian optimization.
    Bayesian,
}
460
461impl HyperparameterOptimizer {
462 pub fn new(
464 search_space: HashMap<String, (f64, f64)>,
465 num_trials: usize,
466 strategy: HPOStrategy,
467 ) -> Self {
468 Self {
469 search_space,
470 num_trials,
471 strategy,
472 }
473 }
474
475 pub fn optimize<F>(&self, objective: F) -> QuantRS2Result<HashMap<String, f64>>
477 where
478 F: Fn(&HashMap<String, f64>) -> QuantRS2Result<f64>,
479 {
480 Ok(HashMap::new())
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qml::QMLConfig;

    /// Builds an MSE trainer on a default circuit with default training settings.
    fn make_trainer(optimizer: Optimizer) -> QMLTrainer {
        let circuit = QMLCircuit::new(QMLConfig::default());
        QMLTrainer::new(
            circuit,
            LossFunction::MSE,
            optimizer,
            TrainingConfig::default(),
        )
    }

    /// A freshly created trainer has recorded no losses yet.
    #[test]
    fn test_trainer_creation() {
        let trainer = make_trainer(Optimizer::Adam {
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        });

        assert!(trainer.metrics.loss_history.is_empty());
    }

    /// A gradient of norm 5 is scaled down to the clip value of 1.
    #[test]
    fn test_gradient_clipping() {
        let trainer = make_trainer(Optimizer::GradientDescent { learning_rate: 0.1 });

        // (3, 4) has L2 norm 5.
        let gradients = vec![3.0, 4.0];
        let clipped = trainer.clip_gradients(&gradients, 1.0);

        let squared_sum: f64 = clipped.iter().map(|g| g * g).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    /// MSE of a single 0.5 deviation across three elements is 0.25 / 3.
    #[test]
    fn test_loss_computation() {
        let trainer = make_trainer(Optimizer::GradientDescent { learning_rate: 0.1 });

        let output = vec![0.0, 0.5, 1.0];
        let target = vec![0.0, 0.0, 1.0];

        let loss = trainer.compute_loss(&output, &target).unwrap();
        let expected = 0.25 / 3.0;
        assert!((loss - expected).abs() < 1e-10);
    }
}