quantrs2_device/quantum_ml/
training.rs

1//! Quantum Machine Learning Training
2//!
3//! This module provides training routines for quantum machine learning models,
4//! including supervised learning, unsupervised learning, and reinforcement learning.
5
6use super::*;
7use crate::{DeviceError, DeviceResult, QuantumDevice};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::RwLock;
13
14/// Quantum trainer for ML models
15pub struct QuantumTrainer {
16    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
17    config: QMLConfig,
18    model_type: QMLModelType,
19    optimizer: Box<dyn QuantumOptimizer>,
20    gradient_calculator: QuantumGradientCalculator,
21    loss_function: Box<dyn LossFunction + Send + Sync>,
22}
23
24/// Training data structure
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TrainingData {
27    pub features: Vec<Vec<f64>>,
28    pub labels: Vec<f64>,
29    pub metadata: HashMap<String, String>,
30}
31
32impl TrainingData {
33    pub fn new(features: Vec<Vec<f64>>, labels: Vec<f64>) -> Self {
34        Self {
35            features,
36            labels,
37            metadata: HashMap::new(),
38        }
39    }
40
41    pub fn len(&self) -> usize {
42        self.features.len()
43    }
44
45    pub fn is_empty(&self) -> bool {
46        self.features.is_empty()
47    }
48
49    #[must_use]
50    pub fn get_batch(&self, indices: &[usize]) -> Self {
51        let batch_features = indices
52            .iter()
53            .filter_map(|&i| self.features.get(i))
54            .cloned()
55            .collect();
56        let batch_labels = indices
57            .iter()
58            .filter_map(|&i| self.labels.get(i))
59            .copied()
60            .collect();
61
62        Self {
63            features: batch_features,
64            labels: batch_labels,
65            metadata: self.metadata.clone(),
66        }
67    }
68
69    pub fn shuffle(&mut self) {
70        let n = self.len();
71        for i in 0..n {
72            let j = fastrand::usize(i..n);
73            self.features.swap(i, j);
74            self.labels.swap(i, j);
75        }
76    }
77}
78
79/// Training result
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TrainingResult {
82    pub model_id: String,
83    pub model: QMLModel,
84    pub final_loss: f64,
85    pub final_accuracy: Option<f64>,
86    pub training_time: Duration,
87    pub convergence_achieved: bool,
88    pub optimal_parameters: Vec<f64>,
89    pub training_metrics: TrainingMetrics,
90}
91
92/// Training metrics
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct TrainingMetrics {
95    pub loss_history: Vec<f64>,
96    pub accuracy_history: Vec<f64>,
97    pub validation_loss_history: Vec<f64>,
98    pub validation_accuracy_history: Vec<f64>,
99    pub gradient_norms: Vec<f64>,
100    pub learning_rates: Vec<f64>,
101    pub quantum_fidelities: Vec<f64>,
102    pub execution_times: Vec<Duration>,
103}
104
105impl Default for TrainingMetrics {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl TrainingMetrics {
112    pub const fn new() -> Self {
113        Self {
114            loss_history: Vec::new(),
115            accuracy_history: Vec::new(),
116            validation_loss_history: Vec::new(),
117            validation_accuracy_history: Vec::new(),
118            gradient_norms: Vec::new(),
119            learning_rates: Vec::new(),
120            quantum_fidelities: Vec::new(),
121            execution_times: Vec::new(),
122        }
123    }
124
125    pub fn add_epoch(
126        &mut self,
127        loss: f64,
128        accuracy: f64,
129        val_loss: Option<f64>,
130        val_accuracy: Option<f64>,
131        gradient_norm: f64,
132        learning_rate: f64,
133        quantum_fidelity: f64,
134        execution_time: Duration,
135    ) {
136        self.loss_history.push(loss);
137        self.accuracy_history.push(accuracy);
138        if let Some(vl) = val_loss {
139            self.validation_loss_history.push(vl);
140        }
141        if let Some(va) = val_accuracy {
142            self.validation_accuracy_history.push(va);
143        }
144        self.gradient_norms.push(gradient_norm);
145        self.learning_rates.push(learning_rate);
146        self.quantum_fidelities.push(quantum_fidelity);
147        self.execution_times.push(execution_time);
148    }
149}
150
151/// Loss function trait
152pub trait LossFunction: Send + Sync {
153    /// Compute loss value
154    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64>;
155
156    /// Compute loss gradients
157    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>>;
158
159    /// Get loss function name
160    fn name(&self) -> &str;
161}
162
163/// Mean squared error loss
164pub struct MSELoss;
165
166impl LossFunction for MSELoss {
167    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
168        if predictions.len() != targets.len() {
169            return Err(DeviceError::InvalidInput(
170                "Predictions and targets must have same length".to_string(),
171            ));
172        }
173
174        let mse = predictions
175            .iter()
176            .zip(targets.iter())
177            .map(|(p, t)| (p - t).powi(2))
178            .sum::<f64>()
179            / predictions.len() as f64;
180
181        Ok(mse)
182    }
183
184    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>> {
185        if predictions.len() != targets.len() {
186            return Err(DeviceError::InvalidInput(
187                "Predictions and targets must have same length".to_string(),
188            ));
189        }
190
191        let gradients = predictions
192            .iter()
193            .zip(targets.iter())
194            .map(|(p, t)| 2.0 * (p - t) / predictions.len() as f64)
195            .collect();
196
197        Ok(gradients)
198    }
199
200    fn name(&self) -> &'static str {
201        "MSE"
202    }
203}
204
205/// Cross-entropy loss
206pub struct CrossEntropyLoss;
207
208impl LossFunction for CrossEntropyLoss {
209    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
210        if predictions.len() != targets.len() {
211            return Err(DeviceError::InvalidInput(
212                "Predictions and targets must have same length".to_string(),
213            ));
214        }
215
216        let epsilon = 1e-15; // Prevent log(0)
217        let cross_entropy = -targets
218            .iter()
219            .zip(predictions.iter())
220            .map(|(t, p)| {
221                let p_clipped = p.clamp(epsilon, 1.0 - epsilon);
222                (1.0 - t).mul_add((1.0 - p_clipped).ln(), t * p_clipped.ln())
223            })
224            .sum::<f64>()
225            / predictions.len() as f64;
226
227        Ok(cross_entropy)
228    }
229
230    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>> {
231        if predictions.len() != targets.len() {
232            return Err(DeviceError::InvalidInput(
233                "Predictions and targets must have same length".to_string(),
234            ));
235        }
236
237        let epsilon = 1e-15;
238        let gradients = predictions
239            .iter()
240            .zip(targets.iter())
241            .map(|(p, t)| {
242                let p_clipped = p.clamp(epsilon, 1.0 - epsilon);
243                (p_clipped - t) / (p_clipped * (1.0 - p_clipped) * predictions.len() as f64)
244            })
245            .collect();
246
247        Ok(gradients)
248    }
249
250    fn name(&self) -> &'static str {
251        "CrossEntropy"
252    }
253}
254
255impl QuantumTrainer {
256    /// Create a new quantum trainer
257    pub fn new(
258        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
259        config: &QMLConfig,
260        model_type: QMLModelType,
261    ) -> DeviceResult<Self> {
262        let optimizer = create_gradient_optimizer(
263            device.clone(),
264            config.optimizer.clone(),
265            config.learning_rate,
266        );
267
268        let gradient_config = GradientConfig {
269            method: config.gradient_method.clone(),
270            shots: 1024,
271            ..Default::default()
272        };
273
274        let gradient_calculator = QuantumGradientCalculator::new(device.clone(), gradient_config)?;
275
276        let loss_function: Box<dyn LossFunction + Send + Sync> = match model_type {
277            QMLModelType::VQC | QMLModelType::QNN => Box::new(CrossEntropyLoss),
278            _ => Box::new(MSELoss),
279        };
280
281        Ok(Self {
282            device,
283            config: config.clone(),
284            model_type,
285            optimizer,
286            gradient_calculator,
287            loss_function,
288        })
289    }
290
291    /// Train a quantum ML model
292    pub async fn train(
293        &mut self,
294        training_data: TrainingData,
295        validation_data: Option<TrainingData>,
296        training_history: &mut Vec<TrainingEpoch>,
297    ) -> DeviceResult<TrainingResult> {
298        let start_time = Instant::now();
299        let model_id = format!("qml_model_{}", uuid::Uuid::new_v4());
300
301        // Initialize model parameters
302        let mut parameters = self.initialize_parameters()?;
303        let mut metrics = TrainingMetrics::new();
304        let mut best_loss = f64::INFINITY;
305        let mut best_parameters = parameters.clone();
306        let mut patience_counter = 0;
307        let early_stopping_patience = 50;
308
309        for epoch in 0..self.config.max_epochs {
310            let epoch_start = Instant::now();
311
312            // Shuffle training data
313            let mut epoch_data = training_data.clone();
314            epoch_data.shuffle();
315
316            // Training step
317            let (epoch_loss, epoch_accuracy, gradient_norm) =
318                self.train_epoch(&mut parameters, &epoch_data).await?;
319
320            // Validation step
321            let (val_loss, val_accuracy) = if let Some(ref val_data) = validation_data {
322                let (vl, va) = self.validate_epoch(&parameters, val_data).await?;
323                (Some(vl), Some(va))
324            } else {
325                (None, None)
326            };
327
328            let execution_time = epoch_start.elapsed();
329            let quantum_fidelity = self.estimate_quantum_fidelity(&parameters).await?;
330
331            // Update metrics
332            metrics.add_epoch(
333                epoch_loss,
334                epoch_accuracy,
335                val_loss,
336                val_accuracy,
337                gradient_norm,
338                self.config.learning_rate,
339                quantum_fidelity,
340                execution_time,
341            );
342
343            // Add to training history
344            training_history.push(TrainingEpoch {
345                epoch,
346                loss: epoch_loss,
347                accuracy: Some(epoch_accuracy),
348                parameters: parameters.clone(),
349                gradient_norm,
350                learning_rate: self.config.learning_rate,
351                execution_time,
352                quantum_fidelity: Some(quantum_fidelity),
353                classical_preprocessing_time: Duration::from_millis(10),
354                quantum_execution_time: execution_time
355                    .checked_sub(Duration::from_millis(10))
356                    .unwrap_or(Duration::ZERO),
357            });
358
359            // Check for improvement
360            let current_loss = val_loss.unwrap_or(epoch_loss);
361            if current_loss < best_loss {
362                best_loss = current_loss;
363                best_parameters.clone_from(&parameters);
364                patience_counter = 0;
365            } else {
366                patience_counter += 1;
367            }
368
369            // Early stopping
370            if patience_counter >= early_stopping_patience {
371                println!("Early stopping at epoch {epoch} due to no improvement");
372                break;
373            }
374
375            // Convergence check
376            if epoch_loss < self.config.convergence_tolerance {
377                println!("Converged at epoch {epoch} with loss {epoch_loss:.6}");
378                break;
379            }
380
381            // Progress logging
382            if epoch % 10 == 0 {
383                println!(
384                    "Epoch {}: Loss={:.6}, Accuracy={:.4}, Val_Loss={:.6}, Fidelity={:.4}",
385                    epoch,
386                    epoch_loss,
387                    epoch_accuracy,
388                    val_loss.unwrap_or(0.0),
389                    quantum_fidelity
390                );
391            }
392        }
393
394        // Create final model
395        let model = QMLModel {
396            model_type: self.model_type.clone(),
397            parameters: best_parameters.clone(),
398            circuit_structure: self.get_circuit_structure(),
399            training_metadata: self.get_training_metadata(),
400            performance_metrics: self.get_performance_metrics(&metrics),
401        };
402
403        Ok(TrainingResult {
404            model_id,
405            model,
406            final_loss: best_loss,
407            final_accuracy: metrics.accuracy_history.last().copied(),
408            training_time: start_time.elapsed(),
409            convergence_achieved: best_loss < self.config.convergence_tolerance,
410            optimal_parameters: best_parameters,
411            training_metrics: metrics,
412        })
413    }
414
415    /// Train for one epoch
416    async fn train_epoch(
417        &mut self,
418        parameters: &mut Vec<f64>,
419        training_data: &TrainingData,
420    ) -> DeviceResult<(f64, f64, f64)> {
421        let batch_size = self.config.batch_size.min(training_data.len());
422        let num_batches = training_data.len().div_ceil(batch_size);
423
424        let mut total_loss = 0.0;
425        let mut total_accuracy = 0.0;
426        let mut total_gradient_norm = 0.0;
427
428        for batch_idx in 0..num_batches {
429            let start_idx = batch_idx * batch_size;
430            let end_idx = (start_idx + batch_size).min(training_data.len());
431            let batch_indices: Vec<usize> = (start_idx..end_idx).collect();
432            let batch_data = training_data.get_batch(&batch_indices);
433
434            // Forward pass
435            let predictions = self.forward_pass(parameters, &batch_data.features).await?;
436
437            // Compute loss
438            let batch_loss = self
439                .loss_function
440                .compute_loss(&predictions, &batch_data.labels)?;
441            total_loss += batch_loss;
442
443            // Compute accuracy
444            let batch_accuracy = self.compute_accuracy(&predictions, &batch_data.labels)?;
445            total_accuracy += batch_accuracy;
446
447            // Backward pass - compute gradients
448            let gradients = self.backward_pass(parameters, &batch_data).await?;
449            let gradient_norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
450            total_gradient_norm += gradient_norm;
451
452            // Update parameters
453            let loss_fn = Arc::new(MSELoss {}) as Arc<dyn LossFunction + Send + Sync>;
454            let objective_function = Box::new(BatchObjectiveFunction::new(
455                self.device.clone(),
456                batch_data,
457                loss_fn,
458            ));
459
460            let optimization_result = self
461                .optimizer
462                .optimize(parameters.clone(), objective_function)?;
463
464            *parameters = optimization_result.optimal_parameters;
465        }
466
467        Ok((
468            total_loss / num_batches as f64,
469            total_accuracy / num_batches as f64,
470            total_gradient_norm / num_batches as f64,
471        ))
472    }
473
474    /// Validate for one epoch
475    async fn validate_epoch(
476        &self,
477        parameters: &[f64],
478        validation_data: &TrainingData,
479    ) -> DeviceResult<(f64, f64)> {
480        let predictions = self
481            .forward_pass(parameters, &validation_data.features)
482            .await?;
483        let loss = self
484            .loss_function
485            .compute_loss(&predictions, &validation_data.labels)?;
486        let accuracy = self.compute_accuracy(&predictions, &validation_data.labels)?;
487
488        Ok((loss, accuracy))
489    }
490
491    /// Forward pass through the quantum model
492    async fn forward_pass(
493        &self,
494        parameters: &[f64],
495        features: &[Vec<f64>],
496    ) -> DeviceResult<Vec<f64>> {
497        let mut predictions = Vec::new();
498
499        for feature_vector in features {
500            let prediction = self.evaluate_model(parameters, feature_vector).await?;
501            predictions.push(prediction);
502        }
503
504        Ok(predictions)
505    }
506
507    /// Backward pass - compute gradients
508    async fn backward_pass(
509        &self,
510        parameters: &[f64],
511        batch_data: &TrainingData,
512    ) -> DeviceResult<Vec<f64>> {
513        // Create a circuit for this batch
514        let circuit = self.build_training_circuit(parameters, &batch_data.features[0])?;
515
516        // Compute gradients using the gradient calculator
517        self.gradient_calculator
518            .compute_gradients(circuit, parameters.to_vec())
519            .await
520    }
521
522    /// Evaluate the model for a single input
523    async fn evaluate_model(&self, parameters: &[f64], features: &[f64]) -> DeviceResult<f64> {
524        let circuit = self.build_training_circuit(parameters, features)?;
525        let device = self.device.read().await;
526        let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;
527
528        // Convert quantum measurement to prediction
529        self.decode_quantum_output(&result)
530    }
531
532    /// Build training circuit
533    fn build_training_circuit(
534        &self,
535        parameters: &[f64],
536        features: &[f64],
537    ) -> DeviceResult<ParameterizedQuantumCircuit> {
538        match self.model_type {
539            QMLModelType::VQC => self.build_vqc_circuit(parameters, features),
540            QMLModelType::QNN => self.build_qnn_circuit(parameters, features),
541            QMLModelType::QAOA => self.build_qaoa_circuit(parameters, features),
542            _ => Err(DeviceError::InvalidInput(format!(
543                "Model type {:?} not implemented",
544                self.model_type
545            ))),
546        }
547    }
548
549    /// Build VQC circuit
550    fn build_vqc_circuit(
551        &self,
552        parameters: &[f64],
553        features: &[f64],
554    ) -> DeviceResult<ParameterizedQuantumCircuit> {
555        let num_qubits = (features.len() as f64).log2().ceil() as usize + 2;
556        let mut circuit = ParameterizedQuantumCircuit::new(num_qubits);
557
558        // Feature encoding
559        for (i, &feature) in features.iter().enumerate() {
560            if i < num_qubits {
561                circuit.add_ry_gate(i, feature)?;
562            }
563        }
564
565        // Parameterized layers
566        let params_per_layer = num_qubits * 2; // RY and RZ for each qubit
567        let num_layers = parameters.len() / params_per_layer;
568
569        let mut param_idx = 0;
570        for _layer in 0..num_layers {
571            // Rotation gates
572            for qubit in 0..num_qubits {
573                if param_idx < parameters.len() {
574                    circuit.add_ry_gate(qubit, parameters[param_idx])?;
575                    param_idx += 1;
576                }
577                if param_idx < parameters.len() {
578                    circuit.add_rz_gate(qubit, parameters[param_idx])?;
579                    param_idx += 1;
580                }
581            }
582
583            // Entangling gates
584            for qubit in 0..num_qubits - 1 {
585                circuit.add_cnot_gate(qubit, qubit + 1)?;
586            }
587        }
588
589        Ok(circuit)
590    }
591
592    /// Build QNN circuit (similar to VQC but different structure)
593    fn build_qnn_circuit(
594        &self,
595        parameters: &[f64],
596        features: &[f64],
597    ) -> DeviceResult<ParameterizedQuantumCircuit> {
598        // For now, use same structure as VQC
599        self.build_vqc_circuit(parameters, features)
600    }
601
602    /// Build QAOA circuit
603    fn build_qaoa_circuit(
604        &self,
605        _parameters: &[f64],
606        _features: &[f64],
607    ) -> DeviceResult<ParameterizedQuantumCircuit> {
608        // QAOA implementation would be more complex
609        Err(DeviceError::InvalidInput(
610            "QAOA circuit building not implemented".to_string(),
611        ))
612    }
613
614    /// Decode quantum output to classical prediction
615    fn decode_quantum_output(&self, result: &CircuitResult) -> DeviceResult<f64> {
616        // Simple decoding: expectation value of first qubit
617        let mut expectation = 0.0;
618        let total_shots = result.shots as f64;
619
620        for (bitstring, count) in &result.counts {
621            if let Some(first_bit) = bitstring.chars().next() {
622                let bit_value = if first_bit == '1' { 1.0 } else { 0.0 };
623                let probability = *count as f64 / total_shots;
624                expectation += bit_value * probability;
625            }
626        }
627
628        Ok(expectation)
629    }
630
631    /// Compute accuracy for classification
632    fn compute_accuracy(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64> {
633        if predictions.len() != targets.len() {
634            return Err(DeviceError::InvalidInput(
635                "Predictions and targets must have same length".to_string(),
636            ));
637        }
638
639        let correct = predictions
640            .iter()
641            .zip(targets.iter())
642            .map(|(p, t)| {
643                let predicted_class = if *p > 0.5 { 1.0 } else { 0.0 };
644                if (predicted_class - t).abs() < 0.1 {
645                    1.0
646                } else {
647                    0.0
648                }
649            })
650            .sum::<f64>();
651
652        Ok(correct / predictions.len() as f64)
653    }
654
655    /// Initialize model parameters
656    fn initialize_parameters(&self) -> DeviceResult<Vec<f64>> {
657        let param_count = match self.model_type {
658            QMLModelType::QNN => 30,
659            QMLModelType::QAOA => 10,
660            QMLModelType::VQC | _ => 20, // Default parameter count
661        };
662
663        let parameters = (0..param_count)
664            .map(|_| (fastrand::f64() * 2.0).mul_add(std::f64::consts::PI, -std::f64::consts::PI))
665            .collect();
666
667        Ok(parameters)
668    }
669
670    /// Estimate quantum fidelity
671    async fn estimate_quantum_fidelity(&self, _parameters: &[f64]) -> DeviceResult<f64> {
672        // Simplified fidelity estimate
673        Ok(fastrand::f64().mul_add(0.05, 0.95))
674    }
675
676    /// Get circuit structure description
677    fn get_circuit_structure(&self) -> CircuitStructure {
678        CircuitStructure {
679            num_qubits: 6, // Default
680            depth: 10,
681            gate_types: vec!["RY".to_string(), "RZ".to_string(), "CNOT".to_string()],
682            parameter_count: 20,
683            entangling_gates: 5,
684        }
685    }
686
687    /// Get training metadata
688    fn get_training_metadata(&self) -> HashMap<String, String> {
689        let mut metadata = HashMap::new();
690        metadata.insert("trainer_type".to_string(), "quantum".to_string());
691        metadata.insert(
692            "optimizer".to_string(),
693            format!("{:?}", self.config.optimizer),
694        );
695        metadata.insert(
696            "gradient_method".to_string(),
697            format!("{:?}", self.config.gradient_method),
698        );
699        metadata.insert(
700            "learning_rate".to_string(),
701            self.config.learning_rate.to_string(),
702        );
703        metadata
704    }
705
706    /// Get performance metrics
707    fn get_performance_metrics(&self, metrics: &TrainingMetrics) -> HashMap<String, f64> {
708        let mut perf_metrics = HashMap::new();
709
710        if let Some(&final_loss) = metrics.loss_history.last() {
711            perf_metrics.insert("final_loss".to_string(), final_loss);
712        }
713
714        if let Some(&final_accuracy) = metrics.accuracy_history.last() {
715            perf_metrics.insert("final_accuracy".to_string(), final_accuracy);
716        }
717
718        if !metrics.loss_history.is_empty() {
719            let best_loss = metrics
720                .loss_history
721                .iter()
722                .fold(f64::INFINITY, |a, &b| a.min(b));
723            perf_metrics.insert("best_loss".to_string(), best_loss);
724        }
725
726        if !metrics.accuracy_history.is_empty() {
727            let best_accuracy = metrics
728                .accuracy_history
729                .iter()
730                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
731            perf_metrics.insert("best_accuracy".to_string(), best_accuracy);
732        }
733
734        perf_metrics
735    }
736
737    /// Execute a circuit on the quantum device (helper function to work around trait object limitations)
738    async fn execute_circuit_helper(
739        device: &(dyn QuantumDevice + Send + Sync),
740        circuit: &ParameterizedQuantumCircuit,
741        shots: usize,
742    ) -> DeviceResult<CircuitResult> {
743        // For now, return a mock result since we can't execute circuits directly
744        // In a real implementation, this would need proper circuit execution
745        let mut counts = std::collections::HashMap::new();
746        counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
747        counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
748
749        Ok(CircuitResult {
750            counts,
751            shots,
752            metadata: std::collections::HashMap::new(),
753        })
754    }
755}
756
757/// Batch objective function for optimization
758pub struct BatchObjectiveFunction {
759    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
760    batch_data: TrainingData,
761    loss_function: Arc<dyn LossFunction + Send + Sync>,
762}
763
764impl BatchObjectiveFunction {
765    pub fn new(
766        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
767        batch_data: TrainingData,
768        loss_function: Arc<dyn LossFunction + Send + Sync>,
769    ) -> Self {
770        Self {
771            device,
772            batch_data,
773            loss_function,
774        }
775    }
776}
777
778impl ObjectiveFunction for BatchObjectiveFunction {
779    fn evaluate(&self, parameters: &[f64]) -> DeviceResult<f64> {
780        // Simplified batch evaluation
781        // In practice, this would run the quantum circuit for the batch
782        let mut total_loss = 0.0;
783
784        for (features, target) in self
785            .batch_data
786            .features
787            .iter()
788            .zip(self.batch_data.labels.iter())
789        {
790            // Simplified prediction
791            let prediction = parameters.iter().sum::<f64>() / parameters.len() as f64;
792            let loss = (prediction - target).powi(2);
793            total_loss += loss;
794        }
795
796        Ok(total_loss / self.batch_data.len() as f64)
797    }
798
799    fn gradient(&self, _parameters: &[f64]) -> DeviceResult<Option<Vec<f64>>> {
800        // Gradients would be computed via parameter shift rule
801        Ok(None)
802    }
803
804    fn metadata(&self) -> HashMap<String, String> {
805        let mut metadata = HashMap::new();
806        metadata.insert("objective_type".to_string(), "batch_training".to_string());
807        metadata.insert("batch_size".to_string(), self.batch_data.len().to_string());
808        metadata
809    }
810}
811
812/// Create training data from vectors
813pub fn create_training_data(features: Vec<Vec<f64>>, labels: Vec<f64>) -> TrainingData {
814    TrainingData::new(features, labels)
815}
816
817/// Create a supervised learning trainer
818pub fn create_supervised_trainer(
819    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
820    model_type: QMLModelType,
821    config: QMLConfig,
822) -> DeviceResult<QuantumTrainer> {
823    QuantumTrainer::new(device, &config, model_type)
824}
825
#[cfg(test)]
mod tests {
    //! Unit tests for the training data helpers, loss functions, metrics
    //! bookkeeping and trainer construction.

    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    #[test]
    fn test_training_data_creation() {
        let features = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
        let labels = vec![0.0, 1.0, 0.0];

        let training_data = TrainingData::new(features.clone(), labels.clone());

        assert_eq!(training_data.len(), 3);
        assert!(!training_data.is_empty());
        assert_eq!(training_data.features, features);
        assert_eq!(training_data.labels, labels);
        assert!(training_data.metadata.is_empty());
    }

    #[test]
    fn test_training_data_batch() {
        let features = vec![
            vec![0.1, 0.2],
            vec![0.3, 0.4],
            vec![0.5, 0.6],
            vec![0.7, 0.8],
        ];
        let labels = vec![0.0, 1.0, 0.0, 1.0];
        let training_data = TrainingData::new(features, labels);

        let batch_indices = vec![0, 2];
        let batch = training_data.get_batch(&batch_indices);

        assert_eq!(batch.len(), 2);
        assert_eq!(batch.features[0], vec![0.1, 0.2]);
        assert_eq!(batch.features[1], vec![0.5, 0.6]);
        assert_eq!(batch.labels[0], 0.0);
        assert_eq!(batch.labels[1], 0.0);
    }

    #[test]
    fn test_training_data_batch_skips_out_of_range_indices() {
        let training_data = TrainingData::new(vec![vec![1.0], vec![2.0]], vec![10.0, 20.0]);

        let batch = training_data.get_batch(&[1, 5]);

        assert_eq!(batch.features, vec![vec![2.0]]);
        assert_eq!(batch.labels, vec![20.0]);
    }

    #[test]
    fn test_training_data_shuffle_keeps_pairs() {
        // Each feature row carries the same value as its label, so any
        // misalignment introduced by shuffling is detectable.
        let values: Vec<f64> = (0..16).map(f64::from).collect();
        let features: Vec<Vec<f64>> = values.iter().map(|&v| vec![v]).collect();
        let mut training_data = TrainingData::new(features, values.clone());

        training_data.shuffle();

        assert_eq!(training_data.len(), values.len());
        for (row, label) in training_data.features.iter().zip(&training_data.labels) {
            assert_eq!(row[0], *label);
        }
        let mut sorted_labels = training_data.labels.clone();
        sorted_labels.sort_by(|a, b| a.partial_cmp(b).expect("labels are finite"));
        assert_eq!(sorted_labels, values);
    }

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss;
        let predictions = vec![0.8, 0.2, 0.9];
        let targets = vec![1.0, 0.0, 1.0];

        let loss = loss_fn
            .compute_loss(&predictions, &targets)
            .expect("MSE loss computation should succeed");
        let expected_loss =
            ((0.8_f64 - 1.0).powi(2) + (0.2_f64 - 0.0).powi(2) + (0.9_f64 - 1.0).powi(2)) / 3.0;
        assert!((loss - expected_loss).abs() < 1e-10);

        let gradients = loss_fn
            .compute_gradients(&predictions, &targets)
            .expect("MSE gradient computation should succeed");
        assert_eq!(gradients.len(), 3);
        for ((gradient, prediction), target) in gradients.iter().zip(&predictions).zip(&targets) {
            let expected = 2.0 * (prediction - target) / 3.0;
            assert!((gradient - expected).abs() < 1e-10);
        }
        assert_eq!(loss_fn.name(), "MSE");
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        let predictions = vec![0.8, 0.2, 0.9];
        let targets = vec![1.0, 0.0, 1.0];

        let loss = loss_fn
            .compute_loss(&predictions, &targets)
            .expect("CrossEntropy loss computation should succeed");
        assert!(loss > 0.0); // Cross-entropy should be positive

        let gradients = loss_fn
            .compute_gradients(&predictions, &targets)
            .expect("CrossEntropy gradient computation should succeed");
        assert_eq!(gradients.len(), 3);
        // Increasing a prediction towards its target of 1 lowers the loss, and
        // increasing one away from its target of 0 raises it.
        assert!(gradients[0] < 0.0);
        assert!(gradients[1] > 0.0);
        assert!(gradients[2] < 0.0);
        assert_eq!(loss_fn.name(), "CrossEntropy");
    }

    #[test]
    fn test_losses_reject_length_mismatch() {
        let predictions = vec![0.5, 0.5];
        let targets = vec![1.0];

        assert!(MSELoss.compute_loss(&predictions, &targets).is_err());
        assert!(MSELoss.compute_gradients(&predictions, &targets).is_err());
        assert!(CrossEntropyLoss
            .compute_loss(&predictions, &targets)
            .is_err());
        assert!(CrossEntropyLoss
            .compute_gradients(&predictions, &targets)
            .is_err());
    }

    #[tokio::test]
    async fn test_quantum_trainer_creation() {
        let device = create_mock_quantum_device();
        let config = QMLConfig::default();

        let trainer = QuantumTrainer::new(device, &config, QMLModelType::VQC)
            .expect("QuantumTrainer creation should succeed");
        assert_eq!(trainer.model_type, QMLModelType::VQC);
    }

    #[test]
    fn test_training_metrics() {
        let mut metrics = TrainingMetrics::new();

        metrics.add_epoch(
            0.5,
            0.8,
            Some(0.6),
            Some(0.7),
            0.1,
            0.01,
            0.95,
            Duration::from_millis(100),
        );

        assert_eq!(metrics.loss_history.len(), 1);
        assert_eq!(metrics.accuracy_history.len(), 1);
        assert_eq!(metrics.validation_loss_history.len(), 1);
        assert_eq!(metrics.validation_accuracy_history.len(), 1);
        assert_eq!(metrics.loss_history[0], 0.5);
        assert_eq!(metrics.accuracy_history[0], 0.8);
    }

    #[test]
    fn test_training_metrics_without_validation() {
        let mut metrics = TrainingMetrics::default();

        metrics.add_epoch(
            0.5,
            0.8,
            None,
            None,
            0.1,
            0.01,
            0.95,
            Duration::from_millis(100),
        );

        assert_eq!(metrics.loss_history.len(), 1);
        assert!(metrics.validation_loss_history.is_empty());
        assert!(metrics.validation_accuracy_history.is_empty());
        assert_eq!(metrics.execution_times, vec![Duration::from_millis(100)]);
    }
}