1use super::{natural_gradient, quantum_fisher_information, QMLCircuit};
7use crate::{
8 error::{QuantRS2Error, QuantRS2Result},
9 gpu::GpuBackendFactory,
10};
11use ndarray::Array1;
12use std::collections::HashMap;
/// Loss functions selectable for [`QMLTrainer`].
///
/// Besides `Copy`, the enum derives `PartialEq`, `Eq` and `Hash` so a choice can be
/// compared, asserted on, or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossFunction {
    /// Mean squared error: mean of `(output - target)^2` over all components.
    MSE,
    /// Cross-entropy: `-mean(target * ln(output + 1e-10))`.
    CrossEntropy,
    /// Fidelity-based loss.
    ///
    /// NOTE(review): `QMLTrainer::compute_loss` does not evaluate this variant yet
    /// (it yields 0.0); confirm the intended definition.
    Fidelity,
    /// Variational loss.
    ///
    /// NOTE(review): not evaluated by `QMLTrainer::compute_loss` yet (yields 0.0).
    Variational,
    /// User-defined loss.
    ///
    /// NOTE(review): not evaluated by `QMLTrainer::compute_loss` yet (yields 0.0).
    Custom,
}
29
/// Parameter-update rules available to [`QMLTrainer`].
///
/// Derives `PartialEq` so configurations can be compared (fields are `f64`, so `Eq`
/// and `Hash` are intentionally not derived).
#[derive(Debug, Clone, PartialEq)]
pub enum Optimizer {
    /// Plain gradient descent: `p <- p - learning_rate * g`.
    GradientDescent { learning_rate: f64 },
    /// Adam with bias-corrected first and second moment estimates.
    Adam {
        /// Step size.
        learning_rate: f64,
        /// Decay rate of the first-moment (mean) gradient estimate.
        beta1: f64,
        /// Decay rate of the second-moment (uncentered variance) gradient estimate.
        beta2: f64,
        /// Small constant added to the update denominator for numerical stability.
        epsilon: f64,
    },
    /// Natural-gradient descent.
    ///
    /// NOTE(review): confirm how the trainer preconditions the gradient for this
    /// variant compared with [`Optimizer::QuantumNatural`].
    NaturalGradient {
        /// Step size.
        learning_rate: f64,
        /// Regularization passed to the natural-gradient solve (presumably damping
        /// added to the Fisher matrix — confirm).
        regularization: f64,
    },
    /// BFGS quasi-Newton method (no tunable parameters exposed).
    ///
    /// NOTE(review): confirm trainer support for this variant before relying on it.
    BFGS,
    /// Natural gradient preconditioned by the circuit's quantum Fisher information.
    QuantumNatural {
        /// Step size.
        learning_rate: f64,
        /// Regularization passed to the natural-gradient solve (presumably damping
        /// added to the Fisher matrix — confirm).
        regularization: f64,
    },
}
55
/// Settings controlling a [`QMLTrainer`] run.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingConfig {
    /// Upper bound on the number of training epochs.
    pub max_epochs: usize,
    /// Number of samples per mini-batch; must be non-zero.
    pub batch_size: usize,
    /// Training stops once the training loss changes by less than this between
    /// consecutive epochs.
    pub tolerance: f64,
    /// Whether to request a GPU backend via `GpuBackendFactory::create_best_available`
    /// at the start of training.
    pub use_gpu: bool,
    /// Fraction of the data intended for validation.
    ///
    /// NOTE(review): not read by the trainer in this module; callers presumably
    /// split their data themselves — confirm.
    pub validation_split: f64,
    /// Stop after this many consecutive epochs without validation-loss improvement;
    /// `None` disables early stopping.
    pub early_stopping_patience: Option<usize>,
    /// Maximum L2 norm of each batch gradient; `None` disables clipping.
    pub gradient_clip: Option<f64>,
}

impl Default for TrainingConfig {
    /// 100 epochs, batches of 32, tolerance `1e-6`, GPU requested, 20% validation
    /// split, early-stopping patience of 10, gradient norm clipped to 1.0.
    fn default() -> Self {
        Self {
            max_epochs: 100,
            batch_size: 32,
            tolerance: 1e-6,
            use_gpu: true,
            validation_split: 0.2,
            early_stopping_patience: Some(10),
            gradient_clip: Some(1.0),
        }
    }
}
88
/// Metrics accumulated by [`QMLTrainer`] while training.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingMetrics {
    /// Mean training loss of each epoch, in order.
    pub loss_history: Vec<f64>,
    /// Validation loss of each epoch (recorded only when validation data is given).
    pub val_loss_history: Vec<f64>,
    /// L2 norm of the (possibly clipped) gradient applied at each batch update.
    pub gradient_norms: Vec<f64>,
    /// Parameter vector after each batch update.
    pub parameter_history: Vec<Vec<f64>>,
    /// Lowest validation loss seen so far; `f64::INFINITY` until one is recorded.
    pub best_val_loss: f64,
    /// Parameters at which `best_val_loss` was observed; empty until then.
    pub best_parameters: Vec<f64>,
}

impl Default for TrainingMetrics {
    /// Empty histories with `best_val_loss = f64::INFINITY`.
    ///
    /// A derived `Default` would set `best_val_loss` to 0.0, which is
    /// indistinguishable from a perfect loss before any validation has run.
    /// Infinity is also the neutral starting value for tracking a minimum.
    fn default() -> Self {
        Self {
            loss_history: Vec::new(),
            val_loss_history: Vec::new(),
            gradient_norms: Vec::new(),
            parameter_history: Vec::new(),
            best_val_loss: f64::INFINITY,
            best_parameters: Vec::new(),
        }
    }
}
105
106pub struct QMLTrainer {
108 circuit: QMLCircuit,
110 loss_fn: LossFunction,
112 optimizer: Optimizer,
114 config: TrainingConfig,
116 metrics: TrainingMetrics,
118 adam_state: Option<AdamState>,
120}
121
/// Running state of the Adam optimizer, with one slot per circuit parameter.
#[derive(Debug, Clone)]
struct AdamState {
    /// First-moment (mean) estimates of the gradient.
    m: Vec<f64>,
    /// Second-moment (uncentered variance) estimates of the gradient.
    v: Vec<f64>,
    /// Number of Adam steps taken so far; drives the bias correction.
    t: usize,
}
129
130impl QMLTrainer {
131 pub fn new(
133 circuit: QMLCircuit,
134 loss_fn: LossFunction,
135 optimizer: Optimizer,
136 config: TrainingConfig,
137 ) -> Self {
138 let num_params = circuit.num_parameters;
139 let adam_state = match &optimizer {
140 Optimizer::Adam { .. } => Some(AdamState {
141 m: vec![0.0; num_params],
142 v: vec![0.0; num_params],
143 t: 0,
144 }),
145 _ => None,
146 };
147
148 Self {
149 circuit,
150 loss_fn,
151 optimizer,
152 config,
153 metrics: TrainingMetrics::default(),
154 adam_state,
155 }
156 }
157
158 pub fn train(
160 &mut self,
161 train_data: &[(Vec<f64>, Vec<f64>)],
162 val_data: Option<&[(Vec<f64>, Vec<f64>)]>,
163 ) -> QuantRS2Result<TrainingMetrics> {
164 let gpu_backend = if self.config.use_gpu {
166 Some(GpuBackendFactory::create_best_available()?)
167 } else {
168 None
169 };
170
171 let mut best_val_loss = f64::INFINITY;
172 let mut patience_counter = 0;
173
174 for epoch in 0..self.config.max_epochs {
175 let train_loss = self.train_epoch(train_data, &gpu_backend)?;
177 self.metrics.loss_history.push(train_loss);
178
179 if let Some(val_data) = val_data {
181 let val_loss = self.evaluate(val_data, &gpu_backend)?;
182 self.metrics.val_loss_history.push(val_loss);
183
184 if val_loss < best_val_loss {
186 best_val_loss = val_loss;
187 self.metrics.best_val_loss = val_loss;
188 self.metrics.best_parameters = self.get_parameters();
189 patience_counter = 0;
190 } else if let Some(patience) = self.config.early_stopping_patience {
191 patience_counter += 1;
192 if patience_counter >= patience {
193 println!("Early stopping at epoch {}", epoch);
194 break;
195 }
196 }
197 }
198
199 if epoch > 0 {
201 let loss_change =
202 (self.metrics.loss_history[epoch] - self.metrics.loss_history[epoch - 1]).abs();
203 if loss_change < self.config.tolerance {
204 println!("Converged at epoch {}", epoch);
205 break;
206 }
207 }
208
209 if epoch % 10 == 0 {
211 println!("Epoch {}: train_loss = {:.6}", epoch, train_loss);
212 if let Some(val_loss) = self.metrics.val_loss_history.last() {
213 println!(" val_loss = {:.6}", val_loss);
214 }
215 }
216 }
217
218 Ok(self.metrics.clone())
219 }
220
221 fn train_epoch(
223 &mut self,
224 data: &[(Vec<f64>, Vec<f64>)],
225 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
226 ) -> QuantRS2Result<f64> {
227 let mut epoch_loss = 0.0;
228 let num_batches = (data.len() + self.config.batch_size - 1) / self.config.batch_size;
229
230 for batch_idx in 0..num_batches {
231 let start = batch_idx * self.config.batch_size;
232 let end = (start + self.config.batch_size).min(data.len());
233 let batch = &data[start..end];
234
235 let (loss, gradients) = self.compute_batch_gradients(batch, gpu_backend)?;
237 epoch_loss += loss;
238
239 let clipped_gradients = if let Some(clip_value) = self.config.gradient_clip {
241 self.clip_gradients(&gradients, clip_value)
242 } else {
243 gradients
244 };
245
246 self.update_parameters(&clipped_gradients)?;
248
249 let grad_norm = clipped_gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
251 self.metrics.gradient_norms.push(grad_norm);
252 }
253
254 Ok(epoch_loss / num_batches as f64)
255 }
256
257 fn compute_batch_gradients(
259 &self,
260 batch: &[(Vec<f64>, Vec<f64>)],
261 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
262 ) -> QuantRS2Result<(f64, Vec<f64>)> {
263 let mut total_loss = 0.0;
264 let mut total_gradients = vec![0.0; self.circuit.num_parameters];
265
266 for (input, target) in batch {
267 let output = self.forward(input, gpu_backend)?;
269
270 let loss = self.compute_loss(&output, target)?;
272 total_loss += loss;
273
274 let gradients = vec![0.0; self.circuit.num_parameters]; for (i, &grad) in gradients.iter().enumerate() {
279 total_gradients[i] += grad;
280 }
281 }
282
283 let batch_size = batch.len() as f64;
285 total_loss /= batch_size;
286 for grad in &mut total_gradients {
287 *grad /= batch_size;
288 }
289
290 Ok((total_loss, total_gradients))
291 }
292
293 fn forward(
295 &self,
296 input: &[f64],
297 _gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
298 ) -> QuantRS2Result<Vec<f64>> {
299 Ok(vec![0.5; input.len()])
307 }
308
309 fn compute_loss(&self, output: &[f64], target: &[f64]) -> QuantRS2Result<f64> {
311 if output.len() != target.len() {
312 return Err(QuantRS2Error::InvalidInput(
313 "Output and target dimensions mismatch".to_string(),
314 ));
315 }
316
317 match self.loss_fn {
318 LossFunction::MSE => {
319 let mse = output
320 .iter()
321 .zip(target.iter())
322 .map(|(o, t)| (o - t).powi(2))
323 .sum::<f64>()
324 / output.len() as f64;
325 Ok(mse)
326 }
327 LossFunction::CrossEntropy => {
328 let epsilon = 1e-10;
329 let ce = -output
330 .iter()
331 .zip(target.iter())
332 .map(|(o, t)| t * (o + epsilon).ln())
333 .sum::<f64>()
334 / output.len() as f64;
335 Ok(ce)
336 }
337 _ => Ok(0.0), }
339 }
340
341 fn update_parameters(&mut self, gradients: &[f64]) -> QuantRS2Result<()> {
343 let current_params = self.get_parameters();
344 let new_params = match &mut self.optimizer {
345 Optimizer::GradientDescent { learning_rate } => current_params
346 .iter()
347 .zip(gradients.iter())
348 .map(|(p, g)| p - *learning_rate * g)
349 .collect(),
350
351 Optimizer::Adam {
352 learning_rate,
353 beta1,
354 beta2,
355 epsilon,
356 } => {
357 if let Some(state) = &mut self.adam_state {
358 state.t += 1;
359 let t = state.t as f64;
360
361 let mut new_params = vec![0.0; current_params.len()];
362 for i in 0..current_params.len() {
363 state.m[i] = *beta1 * state.m[i] + (1.0 - *beta1) * gradients[i];
365
366 state.v[i] = *beta2 * state.v[i] + (1.0 - *beta2) * gradients[i].powi(2);
368
369 let m_hat = state.m[i] / (1.0 - beta1.powf(t));
371
372 let v_hat = state.v[i] / (1.0 - beta2.powf(t));
374
375 new_params[i] =
377 current_params[i] - *learning_rate * m_hat / (v_hat.sqrt() + *epsilon);
378 }
379 new_params
380 } else {
381 current_params
382 }
383 }
384
385 Optimizer::QuantumNatural {
386 learning_rate: _,
387 regularization,
388 } => {
389 let state = Array1::zeros(1 << self.circuit.config.num_qubits);
391 let fisher = quantum_fisher_information(&self.circuit, &state)?;
392
393 natural_gradient(gradients, &fisher, *regularization)?
395 }
396
397 _ => current_params, };
399
400 self.circuit.set_parameters(&new_params)?;
401 self.metrics.parameter_history.push(new_params);
402
403 Ok(())
404 }
405
406 fn clip_gradients(&self, gradients: &[f64], clip_value: f64) -> Vec<f64> {
408 let norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
409
410 if norm > clip_value {
411 gradients.iter().map(|g| g * clip_value / norm).collect()
412 } else {
413 gradients.to_vec()
414 }
415 }
416
417 fn evaluate(
419 &self,
420 data: &[(Vec<f64>, Vec<f64>)],
421 gpu_backend: &Option<std::sync::Arc<dyn crate::gpu::GpuBackend>>,
422 ) -> QuantRS2Result<f64> {
423 let mut total_loss = 0.0;
424
425 for (input, target) in data {
426 let output = self.forward(input, gpu_backend)?;
427 let loss = self.compute_loss(&output, target)?;
428 total_loss += loss;
429 }
430
431 Ok(total_loss / data.len() as f64)
432 }
433
434 fn get_parameters(&self) -> Vec<f64> {
436 self.circuit.parameters().iter().map(|p| p.value).collect()
437 }
438}
439
440pub struct HyperparameterOptimizer {
442 search_space: HashMap<String, (f64, f64)>,
444 num_trials: usize,
446 strategy: HPOStrategy,
448}
449
/// Candidate-generation strategy for [`HyperparameterOptimizer`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so strategies can be compared and used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HPOStrategy {
    /// Random sampling of the search space.
    Random,
    /// Regular grid over the search space.
    Grid,
    /// Bayesian optimization.
    Bayesian,
}
459
460impl HyperparameterOptimizer {
461 pub fn new(
463 search_space: HashMap<String, (f64, f64)>,
464 num_trials: usize,
465 strategy: HPOStrategy,
466 ) -> Self {
467 Self {
468 search_space,
469 num_trials,
470 strategy,
471 }
472 }
473
474 pub fn optimize<F>(&self, _objective: F) -> QuantRS2Result<HashMap<String, f64>>
476 where
477 F: Fn(&HashMap<String, f64>) -> QuantRS2Result<f64>,
478 {
479 Ok(HashMap::new())
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qml::QMLConfig;

    /// Trainer around a default circuit with the default training configuration.
    fn make_trainer(loss_fn: LossFunction, optimizer: Optimizer) -> QMLTrainer {
        QMLTrainer::new(
            QMLCircuit::new(QMLConfig::default()),
            loss_fn,
            optimizer,
            TrainingConfig::default(),
        )
    }

    /// Plain gradient descent, used by tests that do not exercise the optimizer.
    fn sgd() -> Optimizer {
        Optimizer::GradientDescent { learning_rate: 0.1 }
    }

    fn l2_norm(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    #[test]
    fn test_trainer_creation() {
        let trainer = make_trainer(
            LossFunction::MSE,
            Optimizer::Adam {
                learning_rate: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        );

        assert_eq!(trainer.metrics.loss_history.len(), 0);
        // Adam is the only optimizer that allocates moment buffers.
        assert!(trainer.adam_state.is_some());
    }

    #[test]
    fn test_non_adam_optimizer_has_no_adam_state() {
        let trainer = make_trainer(LossFunction::MSE, sgd());
        assert!(trainer.adam_state.is_none());
    }

    #[test]
    fn test_gradient_clipping() {
        let trainer = make_trainer(LossFunction::MSE, sgd());

        // Norm 5 is scaled down to the limit of 1.
        let clipped = trainer.clip_gradients(&[3.0, 4.0], 1.0);
        assert!((l2_norm(&clipped) - 1.0).abs() < 1e-10);
        // Clipping preserves the gradient direction.
        assert!((clipped[0] / clipped[1] - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_gradient_clipping_within_bound_is_identity() {
        let trainer = make_trainer(LossFunction::MSE, sgd());
        assert_eq!(trainer.clip_gradients(&[0.3, 0.4], 1.0), vec![0.3, 0.4]);
        assert_eq!(trainer.clip_gradients(&[0.0, 0.0], 1.0), vec![0.0, 0.0]);
    }

    #[test]
    fn test_loss_computation() {
        let trainer = make_trainer(LossFunction::MSE, sgd());

        let loss = trainer
            .compute_loss(&[0.0, 0.5, 1.0], &[0.0, 0.0, 1.0])
            .unwrap();
        // Only the middle component differs: 0.5^2 / 3.
        assert!((loss - 0.25 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let trainer = make_trainer(LossFunction::CrossEntropy, sgd());

        let loss = trainer.compute_loss(&[0.5, 0.25], &[1.0, 0.0]).unwrap();
        let expected = -(0.5f64 + 1e-10).ln() / 2.0;
        assert!((loss - expected).abs() < 1e-10);
    }

    #[test]
    fn test_loss_dimension_mismatch_is_error() {
        let trainer = make_trainer(LossFunction::MSE, sgd());
        assert!(trainer.compute_loss(&[0.0, 1.0], &[0.0]).is_err());
    }
}