use super::*;
use crate::{DeviceError, DeviceResult, QuantumDevice};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

pub struct QuantumTrainer {
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    config: QMLConfig,
    model_type: QMLModelType,
    optimizer: Box<dyn QuantumOptimizer>,
    gradient_calculator: QuantumGradientCalculator,
    loss_function: Box<dyn LossFunction + Send + Sync>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingData {
    pub features: Vec<Vec<f64>>,
    pub labels: Vec<f64>,
    pub metadata: HashMap<String, String>,
}

impl TrainingData {
    pub fn new(features: Vec<Vec<f64>>, labels: Vec<f64>) -> Self {
        Self {
            features,
            labels,
            metadata: HashMap::new(),
        }
    }

    pub fn len(&self) -> usize {
        self.features.len()
    }

    pub fn is_empty(&self) -> bool {
        self.features.is_empty()
    }

    #[must_use]
    pub fn get_batch(&self, indices: &[usize]) -> Self {
        let batch_features = indices
            .iter()
            .filter_map(|&i| self.features.get(i))
            .cloned()
            .collect();
        let batch_labels = indices
            .iter()
            .filter_map(|&i| self.labels.get(i))
            .copied()
            .collect();

        Self {
            features: batch_features,
            labels: batch_labels,
            metadata: self.metadata.clone(),
        }
    }

    pub fn shuffle(&mut self) {
        // Fisher-Yates shuffle; features and labels are swapped together so
        // that feature/label pairing is preserved.
        let n = self.len();
        for i in 0..n {
            let j = fastrand::usize(i..n);
            self.features.swap(i, j);
            self.labels.swap(i, j);
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    pub model_id: String,
    pub model: QMLModel,
    pub final_loss: f64,
    pub final_accuracy: Option<f64>,
    pub training_time: Duration,
    pub convergence_achieved: bool,
    pub optimal_parameters: Vec<f64>,
    pub training_metrics: TrainingMetrics,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub loss_history: Vec<f64>,
    pub accuracy_history: Vec<f64>,
    pub validation_loss_history: Vec<f64>,
    pub validation_accuracy_history: Vec<f64>,
    pub gradient_norms: Vec<f64>,
    pub learning_rates: Vec<f64>,
    pub quantum_fidelities: Vec<f64>,
    pub execution_times: Vec<Duration>,
}

impl Default for TrainingMetrics {
    fn default() -> Self {
        Self::new()
    }
}

impl TrainingMetrics {
    pub const fn new() -> Self {
        Self {
            loss_history: Vec::new(),
            accuracy_history: Vec::new(),
            validation_loss_history: Vec::new(),
            validation_accuracy_history: Vec::new(),
            gradient_norms: Vec::new(),
            learning_rates: Vec::new(),
            quantum_fidelities: Vec::new(),
            execution_times: Vec::new(),
        }
    }

    pub fn add_epoch(
        &mut self,
        loss: f64,
        accuracy: f64,
        val_loss: Option<f64>,
        val_accuracy: Option<f64>,
        gradient_norm: f64,
        learning_rate: f64,
        quantum_fidelity: f64,
        execution_time: Duration,
    ) {
        self.loss_history.push(loss);
        self.accuracy_history.push(accuracy);
        if let Some(vl) = val_loss {
            self.validation_loss_history.push(vl);
        }
        if let Some(va) = val_accuracy {
            self.validation_accuracy_history.push(va);
        }
        self.gradient_norms.push(gradient_norm);
        self.learning_rates.push(learning_rate);
        self.quantum_fidelities.push(quantum_fidelity);
        self.execution_times.push(execution_time);
    }
}

pub trait LossFunction: Send + Sync {
    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64>;

    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>>;

    fn name(&self) -> &str;
}

pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        let mse = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / predictions.len() as f64;

        Ok(mse)
    }

    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        // d/dp_i of the mean squared error: 2 (p_i - t_i) / n.
        let gradients = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| 2.0 * (p - t) / predictions.len() as f64)
            .collect();

        Ok(gradients)
    }

    fn name(&self) -> &str {
        "MSE"
    }
}

pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        // Clip predictions away from 0 and 1 to keep the logarithms finite.
        let epsilon = 1e-15;
        let cross_entropy = -targets
            .iter()
            .zip(predictions.iter())
            .map(|(t, p)| {
                let p_clipped = p.clamp(epsilon, 1.0 - epsilon);
                (1.0 - t).mul_add((1.0 - p_clipped).ln(), t * p_clipped.ln())
            })
            .sum::<f64>()
            / predictions.len() as f64;

        Ok(cross_entropy)
    }

    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        let epsilon = 1e-15;
        let gradients = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let p_clipped = p.clamp(epsilon, 1.0 - epsilon);
                (p_clipped - t) / (p_clipped * (1.0 - p_clipped) * predictions.len() as f64)
            })
            .collect();

        Ok(gradients)
    }

    fn name(&self) -> &str {
        "CrossEntropy"
    }
}
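
// Illustrative sketch (not used by `QuantumTrainer`): additional losses can be
// plugged in by implementing `LossFunction`. This mean-absolute-error variant
// shows the pattern; the name `MAELoss` is an example, not part of the
// trainer's API.
pub struct MAELoss;

impl LossFunction for MAELoss {
    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        let mae = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| (p - t).abs())
            .sum::<f64>()
            / predictions.len() as f64;

        Ok(mae)
    }

    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        // d|p - t|/dp is the sign of (p - t); 0.0 is used at the kink p == t.
        let gradients = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let sign = if p > t {
                    1.0
                } else if p < t {
                    -1.0
                } else {
                    0.0
                };
                sign / predictions.len() as f64
            })
            .collect();

        Ok(gradients)
    }

    fn name(&self) -> &str {
        "MAE"
    }
}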

impl QuantumTrainer {
    pub fn new(
        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
        config: &QMLConfig,
        model_type: QMLModelType,
    ) -> DeviceResult<Self> {
        let optimizer = create_gradient_optimizer(
            device.clone(),
            config.optimizer.clone(),
            config.learning_rate,
        );

        let gradient_config = GradientConfig {
            method: config.gradient_method.clone(),
            shots: 1024,
            ..Default::default()
        };

        let gradient_calculator = QuantumGradientCalculator::new(device.clone(), gradient_config)?;

        // Classification-style models default to cross-entropy; everything
        // else falls back to mean squared error.
        let loss_function: Box<dyn LossFunction + Send + Sync> = match model_type {
            QMLModelType::VQC | QMLModelType::QNN => Box::new(CrossEntropyLoss),
            _ => Box::new(MSELoss),
        };

        Ok(Self {
            device,
            config: config.clone(),
            model_type,
            optimizer,
            gradient_calculator,
            loss_function,
        })
    }

    pub async fn train(
        &mut self,
        training_data: TrainingData,
        validation_data: Option<TrainingData>,
        training_history: &mut Vec<TrainingEpoch>,
    ) -> DeviceResult<TrainingResult> {
        let start_time = Instant::now();
        let model_id = format!("qml_model_{}", uuid::Uuid::new_v4());

        let mut parameters = self.initialize_parameters()?;
        let mut metrics = TrainingMetrics::new();
        let mut best_loss = f64::INFINITY;
        let mut best_parameters = parameters.clone();
        let mut patience_counter = 0;
        let early_stopping_patience = 50;

        for epoch in 0..self.config.max_epochs {
            let epoch_start = Instant::now();

            let mut epoch_data = training_data.clone();
            epoch_data.shuffle();

            let (epoch_loss, epoch_accuracy, gradient_norm) =
                self.train_epoch(&mut parameters, &epoch_data).await?;

            let (val_loss, val_accuracy) = if let Some(ref val_data) = validation_data {
                let (vl, va) = self.validate_epoch(&parameters, val_data).await?;
                (Some(vl), Some(va))
            } else {
                (None, None)
            };

            let execution_time = epoch_start.elapsed();
            let quantum_fidelity = self.estimate_quantum_fidelity(&parameters).await?;

            metrics.add_epoch(
                epoch_loss,
                epoch_accuracy,
                val_loss,
                val_accuracy,
                gradient_norm,
                self.config.learning_rate,
                quantum_fidelity,
                execution_time,
            );

            training_history.push(TrainingEpoch {
                epoch,
                loss: epoch_loss,
                accuracy: Some(epoch_accuracy),
                parameters: parameters.clone(),
                gradient_norm,
                learning_rate: self.config.learning_rate,
                execution_time,
                quantum_fidelity: Some(quantum_fidelity),
                // Fixed 10 ms estimate for classical overhead; the remainder
                // of the epoch time is attributed to quantum execution.
                classical_preprocessing_time: Duration::from_millis(10),
                quantum_execution_time: execution_time
                    .checked_sub(Duration::from_millis(10))
                    .unwrap_or(Duration::ZERO),
            });

            // Early stopping tracks validation loss when available, otherwise
            // the training loss.
            let current_loss = val_loss.unwrap_or(epoch_loss);
            if current_loss < best_loss {
                best_loss = current_loss;
                best_parameters.clone_from(&parameters);
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            if patience_counter >= early_stopping_patience {
                println!("Early stopping at epoch {epoch} due to no improvement");
                break;
            }

            if epoch_loss < self.config.convergence_tolerance {
                println!("Converged at epoch {epoch} with loss {epoch_loss:.6}");
                break;
            }

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss={:.6}, Accuracy={:.4}, Val_Loss={:.6}, Fidelity={:.4}",
                    epoch,
                    epoch_loss,
                    epoch_accuracy,
                    val_loss.unwrap_or(0.0),
                    quantum_fidelity
                );
            }
        }

        let model = QMLModel {
            model_type: self.model_type.clone(),
            parameters: best_parameters.clone(),
            circuit_structure: self.get_circuit_structure(),
            training_metadata: self.get_training_metadata(),
            performance_metrics: self.get_performance_metrics(&metrics),
        };

        Ok(TrainingResult {
            model_id,
            model,
            final_loss: best_loss,
            final_accuracy: metrics.accuracy_history.last().copied(),
            training_time: start_time.elapsed(),
            convergence_achieved: best_loss < self.config.convergence_tolerance,
            optimal_parameters: best_parameters,
            training_metrics: metrics,
        })
    }

    async fn train_epoch(
        &mut self,
        parameters: &mut Vec<f64>,
        training_data: &TrainingData,
    ) -> DeviceResult<(f64, f64, f64)> {
        let batch_size = self.config.batch_size.min(training_data.len());
        let num_batches = training_data.len().div_ceil(batch_size);

        let mut total_loss = 0.0;
        let mut total_accuracy = 0.0;
        let mut total_gradient_norm = 0.0;

        for batch_idx in 0..num_batches {
            let start_idx = batch_idx * batch_size;
            let end_idx = (start_idx + batch_size).min(training_data.len());
            let batch_indices: Vec<usize> = (start_idx..end_idx).collect();
            let batch_data = training_data.get_batch(&batch_indices);

            let predictions = self.forward_pass(parameters, &batch_data.features).await?;

            let batch_loss = self
                .loss_function
                .compute_loss(&predictions, &batch_data.labels)?;
            total_loss += batch_loss;

            let batch_accuracy = self.compute_accuracy(&predictions, &batch_data.labels)?;
            total_accuracy += batch_accuracy;

            let gradients = self.backward_pass(parameters, &batch_data).await?;
            let gradient_norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
            total_gradient_norm += gradient_norm;

            // NOTE: the batch objective is currently hard-wired to MSE rather
            // than reusing `self.loss_function`.
            let loss_fn = Arc::new(MSELoss) as Arc<dyn LossFunction + Send + Sync>;
            let objective_function = Box::new(BatchObjectiveFunction::new(
                self.device.clone(),
                batch_data,
                loss_fn,
            ));

            let optimization_result = self
                .optimizer
                .optimize(parameters.clone(), objective_function)?;

            *parameters = optimization_result.optimal_parameters;
        }

        Ok((
            total_loss / num_batches as f64,
            total_accuracy / num_batches as f64,
            total_gradient_norm / num_batches as f64,
        ))
    }

    async fn validate_epoch(
        &self,
        parameters: &[f64],
        validation_data: &TrainingData,
    ) -> DeviceResult<(f64, f64)> {
        let predictions = self
            .forward_pass(parameters, &validation_data.features)
            .await?;
        let loss = self
            .loss_function
            .compute_loss(&predictions, &validation_data.labels)?;
        let accuracy = self.compute_accuracy(&predictions, &validation_data.labels)?;

        Ok((loss, accuracy))
    }

    async fn forward_pass(
        &self,
        parameters: &[f64],
        features: &[Vec<f64>],
    ) -> DeviceResult<Vec<f64>> {
        let mut predictions = Vec::new();

        for feature_vector in features {
            let prediction = self.evaluate_model(parameters, feature_vector).await?;
            predictions.push(prediction);
        }

        Ok(predictions)
    }

    async fn backward_pass(
        &self,
        parameters: &[f64],
        batch_data: &TrainingData,
    ) -> DeviceResult<Vec<f64>> {
        // Gradients are computed on a circuit built from the first sample of
        // the batch, which stands in for the whole batch here.
        let circuit = self.build_training_circuit(parameters, &batch_data.features[0])?;

        self.gradient_calculator
            .compute_gradients(circuit, parameters.to_vec())
            .await
    }

    async fn evaluate_model(&self, parameters: &[f64], features: &[f64]) -> DeviceResult<f64> {
        let circuit = self.build_training_circuit(parameters, features)?;
        let device = self.device.read().await;
        let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;

        self.decode_quantum_output(&result)
    }

    fn build_training_circuit(
        &self,
        parameters: &[f64],
        features: &[f64],
    ) -> DeviceResult<ParameterizedQuantumCircuit> {
        match self.model_type {
            QMLModelType::VQC => self.build_vqc_circuit(parameters, features),
            QMLModelType::QNN => self.build_qnn_circuit(parameters, features),
            QMLModelType::QAOA => self.build_qaoa_circuit(parameters, features),
            _ => Err(DeviceError::InvalidInput(format!(
                "Model type {:?} not implemented",
                self.model_type
            ))),
        }
    }

    fn build_vqc_circuit(
        &self,
        parameters: &[f64],
        features: &[f64],
    ) -> DeviceResult<ParameterizedQuantumCircuit> {
        let num_qubits = (features.len() as f64).log2().ceil() as usize + 2;
        let mut circuit = ParameterizedQuantumCircuit::new(num_qubits);

        // Angle-encode the input features, one feature per qubit.
        for (i, &feature) in features.iter().enumerate() {
            if i < num_qubits {
                circuit.add_ry_gate(i, feature)?;
            }
        }

        // Each variational layer applies an RY and an RZ rotation per qubit.
        let params_per_layer = num_qubits * 2;
        let num_layers = parameters.len() / params_per_layer;

        let mut param_idx = 0;
        for _layer in 0..num_layers {
            for qubit in 0..num_qubits {
                if param_idx < parameters.len() {
                    circuit.add_ry_gate(qubit, parameters[param_idx])?;
                    param_idx += 1;
                }
                if param_idx < parameters.len() {
                    circuit.add_rz_gate(qubit, parameters[param_idx])?;
                    param_idx += 1;
                }
            }

            // Linear chain of CNOTs entangles neighbouring qubits.
            for qubit in 0..num_qubits - 1 {
                circuit.add_cnot_gate(qubit, qubit + 1)?;
            }
        }

        Ok(circuit)
    }

    fn build_qnn_circuit(
        &self,
        parameters: &[f64],
        features: &[f64],
    ) -> DeviceResult<ParameterizedQuantumCircuit> {
        // The QNN currently reuses the VQC ansatz.
        self.build_vqc_circuit(parameters, features)
    }

    fn build_qaoa_circuit(
        &self,
        _parameters: &[f64],
        _features: &[f64],
    ) -> DeviceResult<ParameterizedQuantumCircuit> {
        Err(DeviceError::InvalidInput(
            "QAOA circuit building not implemented".to_string(),
        ))
    }

    fn decode_quantum_output(&self, result: &CircuitResult) -> DeviceResult<f64> {
        // Decode the measurement counts into the probability of the first
        // qubit being measured as |1>.
        let mut expectation = 0.0;
        let total_shots = result.shots as f64;

        for (bitstring, count) in &result.counts {
            if let Some(first_bit) = bitstring.chars().next() {
                let bit_value = if first_bit == '1' { 1.0 } else { 0.0 };
                let probability = *count as f64 / total_shots;
                expectation += bit_value * probability;
            }
        }

        Ok(expectation)
    }

    fn compute_accuracy(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
        if predictions.len() != targets.len() {
            return Err(DeviceError::InvalidInput(
                "Predictions and targets must have same length".to_string(),
            ));
        }

        let correct = predictions
            .iter()
            .zip(targets.iter())
            .map(|(p, t)| {
                let predicted_class = if *p > 0.5 { 1.0 } else { 0.0 };
                if (predicted_class - t).abs() < 0.1 {
                    1.0
                } else {
                    0.0
                }
            })
            .sum::<f64>();

        Ok(correct / predictions.len() as f64)
    }

    fn initialize_parameters(&self) -> DeviceResult<Vec<f64>> {
        let param_count = match self.model_type {
            QMLModelType::QNN => 30,
            QMLModelType::QAOA => 10,
            // VQC and all remaining model types.
            _ => 20,
        };

        // Uniform random initialization in [-pi, pi).
        let parameters = (0..param_count)
            .map(|_| (fastrand::f64() * 2.0).mul_add(std::f64::consts::PI, -std::f64::consts::PI))
            .collect();

        Ok(parameters)
    }

    async fn estimate_quantum_fidelity(&self, _parameters: &[f64]) -> DeviceResult<f64> {
        // Placeholder: report a random fidelity in [0.95, 1.0) instead of a
        // real estimate.
        Ok(fastrand::f64().mul_add(0.05, 0.95))
    }

    fn get_circuit_structure(&self) -> CircuitStructure {
        // Placeholder structure; these values are not derived from the
        // actual circuit.
        CircuitStructure {
            num_qubits: 6,
            depth: 10,
            gate_types: vec!["RY".to_string(), "RZ".to_string(), "CNOT".to_string()],
            parameter_count: 20,
            entangling_gates: 5,
        }
    }

    fn get_training_metadata(&self) -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("trainer_type".to_string(), "quantum".to_string());
        metadata.insert(
            "optimizer".to_string(),
            format!("{:?}", self.config.optimizer),
        );
        metadata.insert(
            "gradient_method".to_string(),
            format!("{:?}", self.config.gradient_method),
        );
        metadata.insert(
            "learning_rate".to_string(),
            self.config.learning_rate.to_string(),
        );
        metadata
    }

    fn get_performance_metrics(&self, metrics: &TrainingMetrics) -> HashMap<String, f64> {
        let mut perf_metrics = HashMap::new();

        if let Some(&final_loss) = metrics.loss_history.last() {
            perf_metrics.insert("final_loss".to_string(), final_loss);
        }

        if let Some(&final_accuracy) = metrics.accuracy_history.last() {
            perf_metrics.insert("final_accuracy".to_string(), final_accuracy);
        }

        if !metrics.loss_history.is_empty() {
            let best_loss = metrics
                .loss_history
                .iter()
                .fold(f64::INFINITY, |a, &b| a.min(b));
            perf_metrics.insert("best_loss".to_string(), best_loss);
        }

        if !metrics.accuracy_history.is_empty() {
            let best_accuracy = metrics
                .accuracy_history
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            perf_metrics.insert("best_accuracy".to_string(), best_accuracy);
        }

        perf_metrics
    }

    async fn execute_circuit_helper(
        _device: &(dyn QuantumDevice + Send + Sync),
        circuit: &ParameterizedQuantumCircuit,
        shots: usize,
    ) -> DeviceResult<CircuitResult> {
        // Placeholder execution: return a mock 50/50 distribution over the
        // all-zeros and all-ones bitstrings instead of running on the device.
        let mut counts = std::collections::HashMap::new();
        counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
        counts.insert("1".repeat(circuit.num_qubits()), shots / 2);

        Ok(CircuitResult {
            counts,
            shots,
            metadata: std::collections::HashMap::new(),
        })
    }
}

pub struct BatchObjectiveFunction {
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    batch_data: TrainingData,
    loss_function: Arc<dyn LossFunction + Send + Sync>,
}

impl BatchObjectiveFunction {
    pub fn new(
        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
        batch_data: TrainingData,
        loss_function: Arc<dyn LossFunction + Send + Sync>,
    ) -> Self {
        Self {
            device,
            batch_data,
            loss_function,
        }
    }
}

impl ObjectiveFunction for BatchObjectiveFunction {
    fn evaluate(&self, parameters: &[f64]) -> DeviceResult<f64> {
        let mut total_loss = 0.0;

        for (_features, target) in self
            .batch_data
            .features
            .iter()
            .zip(self.batch_data.labels.iter())
        {
            // Placeholder model evaluation: the prediction is the mean of the
            // parameters and ignores the input features.
            let prediction = parameters.iter().sum::<f64>() / parameters.len() as f64;
            let loss = (prediction - target).powi(2);
            total_loss += loss;
        }

        Ok(total_loss / self.batch_data.len() as f64)
    }

    fn gradient(&self, _parameters: &[f64]) -> DeviceResult<Option<Vec<f64>>> {
        // No analytic gradient is provided.
        Ok(None)
    }

    fn metadata(&self) -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("objective_type".to_string(), "batch_training".to_string());
        metadata.insert("batch_size".to_string(), self.batch_data.len().to_string());
        metadata
    }
}

pub fn create_training_data(features: Vec<Vec<f64>>, labels: Vec<f64>) -> TrainingData {
    TrainingData::new(features, labels)
}

pub fn create_supervised_trainer(
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    model_type: QMLModelType,
    config: QMLConfig,
) -> DeviceResult<QuantumTrainer> {
    QuantumTrainer::new(device, &config, model_type)
}
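
// Illustrative end-to-end sketch of how the pieces fit together; nothing in
// this module calls it. The caller supplies any concrete `QuantumDevice`
// behind the `Arc<RwLock<_>>`.
#[allow(dead_code)]
async fn run_training_example(
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
) -> DeviceResult<TrainingResult> {
    let data = create_training_data(
        vec![vec![0.1, 0.2], vec![0.9, 0.8]],
        vec![0.0, 1.0],
    );
    let mut trainer = create_supervised_trainer(device, QMLModelType::VQC, QMLConfig::default())?;
    let mut history = Vec::new();
    trainer.train(data, None, &mut history).await
}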

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    #[test]
    fn test_training_data_creation() {
        let features = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
        let labels = vec![0.0, 1.0, 0.0];

        let training_data = TrainingData::new(features.clone(), labels.clone());

        assert_eq!(training_data.len(), 3);
        assert_eq!(training_data.features, features);
        assert_eq!(training_data.labels, labels);
    }
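
    // Illustrative test (added as an example): `shuffle` must keep each
    // feature row paired with its label, whatever order `fastrand` produces.
    #[test]
    fn test_training_data_shuffle_preserves_pairing() {
        let features = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
        let labels = vec![0.0, 1.0, 2.0, 3.0];
        let mut data = TrainingData::new(features, labels);

        data.shuffle();

        assert_eq!(data.len(), 4);
        for (feature, label) in data.features.iter().zip(data.labels.iter()) {
            // Each row was constructed so that feature[0] equals its label.
            assert_eq!(feature[0], *label);
        }
    }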

    #[test]
    fn test_training_data_batch() {
        let features = vec![
            vec![0.1, 0.2],
            vec![0.3, 0.4],
            vec![0.5, 0.6],
            vec![0.7, 0.8],
        ];
        let labels = vec![0.0, 1.0, 0.0, 1.0];
        let training_data = TrainingData::new(features, labels);

        let batch_indices = vec![0, 2];
        let batch = training_data.get_batch(&batch_indices);

        assert_eq!(batch.len(), 2);
        assert_eq!(batch.features[0], vec![0.1, 0.2]);
        assert_eq!(batch.features[1], vec![0.5, 0.6]);
        assert_eq!(batch.labels[0], 0.0);
        assert_eq!(batch.labels[1], 0.0);
    }
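
    // Illustrative edge-case test (added as an example): `get_batch` skips
    // out-of-range indices because it is built on `filter_map`.
    #[test]
    fn test_training_data_batch_skips_out_of_range() {
        let training_data = TrainingData::new(vec![vec![0.1], vec![0.2]], vec![0.0, 1.0]);

        let batch = training_data.get_batch(&[1, 5]);

        assert_eq!(batch.len(), 1);
        assert_eq!(batch.features[0], vec![0.2]);
        assert_eq!(batch.labels[0], 1.0);
    }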

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss;
        let predictions = vec![0.8, 0.2, 0.9];
        let targets = vec![1.0, 0.0, 1.0];

        let loss = loss_fn
            .compute_loss(&predictions, &targets)
            .expect("MSE loss computation should succeed");
        let expected_loss =
            ((0.8_f64 - 1.0).powi(2) + (0.2_f64 - 0.0).powi(2) + (0.9_f64 - 1.0).powi(2)) / 3.0;
        assert!((loss - expected_loss).abs() < 1e-10);

        let gradients = loss_fn
            .compute_gradients(&predictions, &targets)
            .expect("MSE gradient computation should succeed");
        assert_eq!(gradients.len(), 3);
    }
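
    // Illustrative test (added as an example): the analytic MSE gradient
    // should agree with a central finite difference of `compute_loss`.
    #[test]
    fn test_mse_gradients_match_finite_difference() {
        let loss_fn = MSELoss;
        let predictions = vec![0.8, 0.2, 0.9];
        let targets = vec![1.0, 0.0, 1.0];
        let h = 1e-6;

        let gradients = loss_fn
            .compute_gradients(&predictions, &targets)
            .expect("MSE gradient computation should succeed");

        for i in 0..predictions.len() {
            let mut plus = predictions.clone();
            plus[i] += h;
            let mut minus = predictions.clone();
            minus[i] -= h;

            let loss_plus = loss_fn.compute_loss(&plus, &targets).expect("loss(+h)");
            let loss_minus = loss_fn.compute_loss(&minus, &targets).expect("loss(-h)");
            let numeric = (loss_plus - loss_minus) / (2.0 * h);

            assert!((gradients[i] - numeric).abs() < 1e-6);
        }
    }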

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        let predictions = vec![0.8, 0.2, 0.9];
        let targets = vec![1.0, 0.0, 1.0];

        let loss = loss_fn
            .compute_loss(&predictions, &targets)
            .expect("CrossEntropy loss computation should succeed");
        assert!(loss > 0.0);

        let gradients = loss_fn
            .compute_gradients(&predictions, &targets)
            .expect("CrossEntropy gradient computation should succeed");
        assert_eq!(gradients.len(), 3);
    }
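
    // Illustrative test (added as an example): with predictions equal to the
    // targets, the clipped cross-entropy should be vanishingly small.
    #[test]
    fn test_cross_entropy_near_zero_for_perfect_predictions() {
        let loss_fn = CrossEntropyLoss;
        let predictions = vec![1.0, 0.0, 1.0];
        let targets = vec![1.0, 0.0, 1.0];

        let loss = loss_fn
            .compute_loss(&predictions, &targets)
            .expect("CrossEntropy loss computation should succeed");
        assert!(loss.abs() < 1e-9);
    }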

    #[tokio::test]
    async fn test_quantum_trainer_creation() {
        let device = create_mock_quantum_device();
        let config = QMLConfig::default();

        let trainer = QuantumTrainer::new(device, &config, QMLModelType::VQC)
            .expect("QuantumTrainer creation should succeed");
        assert_eq!(trainer.model_type, QMLModelType::VQC);
    }
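
    // Illustrative test (added as an example): `compute_accuracy` classifies
    // with a 0.5 threshold, so three of these four predictions are correct.
    #[test]
    fn test_compute_accuracy_threshold() {
        let device = create_mock_quantum_device();
        let config = QMLConfig::default();
        let trainer = QuantumTrainer::new(device, &config, QMLModelType::VQC)
            .expect("QuantumTrainer creation should succeed");

        let predictions = vec![0.9, 0.4, 0.6, 0.1];
        let targets = vec![1.0, 0.0, 0.0, 0.0];
        let accuracy = trainer
            .compute_accuracy(&predictions, &targets)
            .expect("accuracy computation should succeed");

        assert!((accuracy - 0.75).abs() < 1e-10);
    }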

    #[test]
    fn test_training_metrics() {
        let mut metrics = TrainingMetrics::new();

        metrics.add_epoch(
            0.5,
            0.8,
            Some(0.6),
            Some(0.7),
            0.1,
            0.01,
            0.95,
            Duration::from_millis(100),
        );

        assert_eq!(metrics.loss_history.len(), 1);
        assert_eq!(metrics.accuracy_history.len(), 1);
        assert_eq!(metrics.validation_loss_history.len(), 1);
        assert_eq!(metrics.validation_accuracy_history.len(), 1);
        assert_eq!(metrics.loss_history[0], 0.5);
        assert_eq!(metrics.accuracy_history[0], 0.8);
    }
}
932}