use trustformers_core::errors::Result;
use super::config::{HybridTrainingStrategy, QuantumClassicalConfig};
/// Drives hybrid quantum-classical training: dispatches epochs to the
/// configured [`HybridTrainingStrategy`], tracks per-epoch metrics, and
/// keeps a full history of metric snapshots for later aggregation.
#[derive(Debug)]
pub struct QuantumTrainingManager {
    /// Full hybrid configuration this manager was constructed from.
    pub config: QuantumClassicalConfig,
    /// Strategy used by `train_epoch` (copied out of `config`).
    pub training_strategy: HybridTrainingStrategy,
    /// Learning rate for the classical side (from `config.classical_learning_rate`).
    pub classical_lr: f64,
    /// Learning rate for the quantum side (from `config.quantum_learning_rate`).
    pub quantum_lr: f64,
    /// Metrics of the most recent (or in-progress) epoch.
    pub training_metrics: QuantumTrainingMetrics,
    /// Number of completed epochs; also selects the side in the
    /// alternating strategy.
    pub current_epoch: usize,
    /// Snapshot of `training_metrics` appended after every epoch.
    pub training_history: Vec<QuantumTrainingMetrics>,
}
/// Per-epoch metrics for a hybrid training run.
#[derive(Debug, Clone)]
pub struct QuantumTrainingMetrics {
    /// Loss attributed to the classical parameters (squared-gradient proxy).
    pub classical_loss: f64,
    /// Loss attributed to the quantum parameters (squared-gradient proxy).
    pub quantum_loss: f64,
    /// Combined loss used by convergence-rate computation.
    pub total_loss: f64,
    /// Estimated quantum state fidelity in [0, 1]; starts at 1.0.
    pub quantum_fidelity: f64,
    // NOTE(review): `classical_accuracy` and `quantum_advantage` are
    // initialized but never written anywhere in this file — presumably
    // filled in by callers; confirm.
    pub classical_accuracy: f64,
    pub quantum_advantage: f64,
    /// Wall-clock seconds spent in the last `train_epoch` call.
    pub training_time: f64,
}
impl QuantumTrainingManager {
pub fn new(config: &QuantumClassicalConfig) -> Result<Self> {
let training_metrics = QuantumTrainingMetrics {
classical_loss: 0.0,
quantum_loss: 0.0,
total_loss: 0.0,
quantum_fidelity: 1.0,
classical_accuracy: 0.0,
quantum_advantage: 0.0,
training_time: 0.0,
};
Ok(Self {
config: config.clone(),
training_strategy: config.hybrid_training_strategy.clone(),
classical_lr: config.classical_learning_rate,
quantum_lr: config.quantum_learning_rate,
training_metrics,
current_epoch: 0,
training_history: Vec::new(),
})
}
pub fn train_epoch(
&mut self,
classical_gradients: &[f32],
quantum_gradients: &[f64],
) -> Result<QuantumTrainingMetrics> {
let start_time = std::time::Instant::now();
match self.training_strategy {
HybridTrainingStrategy::Sequential => {
self.train_sequential(classical_gradients, quantum_gradients)?;
},
HybridTrainingStrategy::Alternating => {
self.train_alternating(classical_gradients, quantum_gradients)?;
},
HybridTrainingStrategy::Joint => {
self.train_joint(classical_gradients, quantum_gradients)?;
},
HybridTrainingStrategy::Adaptive => {
self.train_adaptive(classical_gradients, quantum_gradients)?;
},
}
let training_time = start_time.elapsed().as_secs_f64();
self.training_metrics.training_time = training_time;
self.training_history.push(self.training_metrics.clone());
self.current_epoch += 1;
Ok(self.training_metrics.clone())
}
fn train_sequential(
&mut self,
classical_gradients: &[f32],
quantum_gradients: &[f64],
) -> Result<()> {
self.update_classical_parameters(classical_gradients)?;
self.update_quantum_parameters(quantum_gradients)?;
Ok(())
}
fn train_alternating(
&mut self,
classical_gradients: &[f32],
quantum_gradients: &[f64],
) -> Result<()> {
if self.current_epoch.is_multiple_of(2) {
self.update_classical_parameters(classical_gradients)?;
} else {
self.update_quantum_parameters(quantum_gradients)?;
}
Ok(())
}
fn train_joint(
&mut self,
classical_gradients: &[f32],
quantum_gradients: &[f64],
) -> Result<()> {
self.update_classical_parameters(classical_gradients)?;
self.update_quantum_parameters(quantum_gradients)?;
Ok(())
}
fn train_adaptive(
&mut self,
classical_gradients: &[f32],
quantum_gradients: &[f64],
) -> Result<()> {
let classical_grad_norm =
classical_gradients.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt();
let quantum_grad_norm = quantum_gradients.iter().map(|&x| x.powi(2)).sum::<f64>().sqrt();
if classical_grad_norm as f64 > quantum_grad_norm {
self.update_classical_parameters(classical_gradients)?;
} else {
self.update_quantum_parameters(quantum_gradients)?;
}
Ok(())
}
fn update_classical_parameters(&mut self, gradients: &[f32]) -> Result<()> {
let classical_loss = gradients.iter().map(|&x| x.powi(2)).sum::<f32>() as f64;
self.training_metrics.classical_loss = classical_loss;
Ok(())
}
fn update_quantum_parameters(&mut self, gradients: &[f64]) -> Result<()> {
let quantum_loss = gradients.iter().map(|&x| x.powi(2)).sum::<f64>();
self.training_metrics.quantum_loss = quantum_loss;
self.training_metrics.quantum_fidelity = 1.0 - self.config.quantum_noise_variance;
Ok(())
}
pub fn get_training_stats(&self) -> QuantumTrainingStats {
let avg_classical_loss =
self.training_history.iter().map(|m| m.classical_loss).sum::<f64>()
/ self.training_history.len() as f64;
let avg_quantum_loss = self.training_history.iter().map(|m| m.quantum_loss).sum::<f64>()
/ self.training_history.len() as f64;
let avg_quantum_fidelity =
self.training_history.iter().map(|m| m.quantum_fidelity).sum::<f64>()
/ self.training_history.len() as f64;
QuantumTrainingStats {
total_epochs: self.current_epoch,
avg_classical_loss,
avg_quantum_loss,
avg_quantum_fidelity,
training_strategy: self.training_strategy.clone(),
convergence_rate: self.compute_convergence_rate(),
}
}
fn compute_convergence_rate(&self) -> f64 {
if self.training_history.len() < 2 {
return 0.0;
}
let first_loss = self.training_history[0].total_loss;
let last_loss = self.training_history.last().expect("operation failed").total_loss;
if first_loss > 0.0 {
(first_loss - last_loss) / first_loss
} else {
0.0
}
}
pub fn reset(&mut self) {
self.current_epoch = 0;
self.training_history.clear();
self.training_metrics = QuantumTrainingMetrics {
classical_loss: 0.0,
quantum_loss: 0.0,
total_loss: 0.0,
quantum_fidelity: 1.0,
classical_accuracy: 0.0,
quantum_advantage: 0.0,
training_time: 0.0,
};
}
}
/// Aggregated view over a whole training run, produced by
/// `QuantumTrainingManager::get_training_stats`.
#[derive(Debug, Clone)]
pub struct QuantumTrainingStats {
    /// Number of epochs completed so far (`current_epoch`).
    pub total_epochs: usize,
    /// Mean classical loss across the recorded history.
    pub avg_classical_loss: f64,
    /// Mean quantum loss across the recorded history.
    pub avg_quantum_loss: f64,
    /// Mean quantum fidelity across the recorded history.
    pub avg_quantum_fidelity: f64,
    /// Strategy that produced these statistics.
    pub training_strategy: HybridTrainingStrategy,
    /// Relative `total_loss` reduction from first to last epoch.
    pub convergence_rate: f64,
}
impl Default for QuantumTrainingMetrics {
fn default() -> Self {
Self {
classical_loss: 0.0,
quantum_loss: 0.0,
total_loss: 0.0,
quantum_fidelity: 1.0,
classical_accuracy: 0.0,
quantum_advantage: 0.0,
training_time: 0.0,
}
}
}