use crate::prelude::HardwareOptimizations;
use scirs2_core::ndarray::Array1;
use scirs2_core::parallel_ops::{IndexedParallelIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use crate::circuit_interfaces::{CircuitInterface, InterfaceCircuit};
use crate::device_noise_models::DeviceNoiseModel;
use crate::error::Result;
use super::circuit::ParameterizedQuantumCircuit;
use super::config::{GradientMethod, HardwareArchitecture, OptimizerType, QMLConfig};
pub struct QuantumMLTrainer {
config: QMLConfig,
pqc: ParameterizedQuantumCircuit,
optimizer_state: OptimizerState,
training_history: TrainingHistory,
noise_model: Option<Box<dyn DeviceNoiseModel>>,
circuit_interface: CircuitInterface,
hardware_compiler: HardwareAwareCompiler,
}
#[derive(Debug, Clone)]
pub struct OptimizerState {
pub parameters: Array1<f64>,
pub gradient: Array1<f64>,
pub momentum: Array1<f64>,
pub velocity: Array1<f64>,
pub learning_rate: f64,
pub iteration: usize,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingHistory {
pub loss_history: Vec<f64>,
pub gradient_norms: Vec<f64>,
pub parameter_norms: Vec<f64>,
pub epoch_times: Vec<f64>,
pub hardware_metrics: Vec<HardwareMetrics>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareMetrics {
pub compiled_depth: usize,
pub two_qubit_gates: usize,
pub execution_time: f64,
pub estimated_fidelity: f64,
pub shot_overhead: f64,
}
#[derive(Debug, Clone)]
pub struct HardwareAwareCompiler {
hardware_arch: HardwareArchitecture,
hardware_opts: HardwareOptimizations,
compilation_stats: CompilationStats,
}
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
pub original_depth: usize,
pub compiled_depth: usize,
pub swap_gates_added: usize,
pub compilation_time: f64,
pub estimated_execution_time: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
pub final_parameters: Array1<f64>,
pub final_loss: f64,
pub epochs_completed: usize,
pub training_history: TrainingHistory,
pub converged: bool,
}
impl QuantumMLTrainer {
pub fn new(
config: QMLConfig,
pqc: ParameterizedQuantumCircuit,
noise_model: Option<Box<dyn DeviceNoiseModel>>,
) -> Result<Self> {
let num_params = pqc.num_parameters();
let optimizer_state = OptimizerState {
parameters: pqc.parameters.clone(),
gradient: Array1::zeros(num_params),
momentum: Array1::zeros(num_params),
velocity: Array1::zeros(num_params),
learning_rate: config.learning_rate,
iteration: 0,
};
let training_history = TrainingHistory::default();
let circuit_interface = CircuitInterface::new(Default::default())?;
let hardware_compiler = HardwareAwareCompiler::new(
config.hardware_architecture,
pqc.hardware_optimizations.clone(),
);
Ok(Self {
config,
pqc,
optimizer_state,
training_history,
noise_model,
circuit_interface,
hardware_compiler,
})
}
pub fn train<F>(&mut self, loss_function: F) -> Result<TrainingResult>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
let start_time = std::time::Instant::now();
for epoch in 0..self.config.max_epochs {
let epoch_start = std::time::Instant::now();
let gradient = self.compute_gradient(&loss_function)?;
self.optimizer_state.gradient = gradient;
self.update_parameters()?;
let current_loss = loss_function(&self.optimizer_state.parameters)?;
let epoch_time = epoch_start.elapsed().as_secs_f64();
self.training_history.loss_history.push(current_loss);
self.training_history.gradient_norms.push(
self.optimizer_state
.gradient
.iter()
.map(|x| x * x)
.sum::<f64>()
.sqrt(),
);
self.training_history.parameter_norms.push(
self.optimizer_state
.parameters
.iter()
.map(|x| x * x)
.sum::<f64>()
.sqrt(),
);
self.training_history.epoch_times.push(epoch_time);
if self.check_convergence(current_loss)? {
return Ok(TrainingResult {
final_parameters: self.optimizer_state.parameters.clone(),
final_loss: current_loss,
epochs_completed: epoch + 1,
training_history: self.training_history.clone(),
converged: true,
});
}
self.optimizer_state.iteration += 1;
}
let final_loss = loss_function(&self.optimizer_state.parameters)?;
Ok(TrainingResult {
final_parameters: self.optimizer_state.parameters.clone(),
final_loss,
epochs_completed: self.config.max_epochs,
training_history: self.training_history.clone(),
converged: false,
})
}
fn compute_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
match self.config.gradient_method {
GradientMethod::ParameterShift => self.compute_parameter_shift_gradient(loss_function),
GradientMethod::FiniteDifferences => {
self.compute_finite_difference_gradient(loss_function)
}
GradientMethod::AutomaticDifferentiation => {
self.compute_autodiff_gradient(loss_function)
}
GradientMethod::NaturalGradients => self.compute_natural_gradient(loss_function),
GradientMethod::StochasticParameterShift => {
self.compute_stochastic_parameter_shift_gradient(loss_function)
}
}
}
fn compute_parameter_shift_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
let num_params = self.optimizer_state.parameters.len();
let mut gradient = Array1::zeros(num_params);
let shift = std::f64::consts::PI / 2.0;
for i in 0..num_params {
let mut params_plus = self.optimizer_state.parameters.clone();
let mut params_minus = self.optimizer_state.parameters.clone();
params_plus[i] += shift;
params_minus[i] -= shift;
let loss_plus = loss_function(¶ms_plus)?;
let loss_minus = loss_function(¶ms_minus)?;
gradient[i] = (loss_plus - loss_minus) / 2.0;
}
Ok(gradient)
}
fn compute_finite_difference_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
let num_params = self.optimizer_state.parameters.len();
let mut gradient = Array1::zeros(num_params);
let eps = 1e-8;
for i in 0..num_params {
let mut params_plus = self.optimizer_state.parameters.clone();
params_plus[i] += eps;
let loss_plus = loss_function(¶ms_plus)?;
let loss_current = loss_function(&self.optimizer_state.parameters)?;
gradient[i] = (loss_plus - loss_current) / eps;
}
Ok(gradient)
}
fn compute_autodiff_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
self.compute_parameter_shift_gradient(loss_function)
}
fn compute_natural_gradient<F>(&self, loss_function: &F) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
let gradient = self.compute_parameter_shift_gradient(loss_function)?;
Ok(gradient)
}
fn compute_stochastic_parameter_shift_gradient<F>(
&self,
loss_function: &F,
) -> Result<Array1<f64>>
where
F: Fn(&Array1<f64>) -> Result<f64> + Send + Sync,
{
self.compute_parameter_shift_gradient(loss_function)
}
fn update_parameters(&mut self) -> Result<()> {
match self.config.optimizer_type {
OptimizerType::Adam => self.update_parameters_adam(),
OptimizerType::SGD => self.update_parameters_sgd(),
OptimizerType::RMSprop => self.update_parameters_rmsprop(),
OptimizerType::LBFGS => self.update_parameters_lbfgs(),
OptimizerType::QuantumNaturalGradient => self.update_parameters_qng(),
OptimizerType::SPSA => self.update_parameters_spsa(),
}
}
fn update_parameters_adam(&mut self) -> Result<()> {
let beta1 = 0.9;
let beta2 = 0.999;
let eps = 1e-8;
for i in 0..self.optimizer_state.parameters.len() {
self.optimizer_state.momentum[i] = beta1 * self.optimizer_state.momentum[i]
+ (1.0 - beta1) * self.optimizer_state.gradient[i];
self.optimizer_state.velocity[i] = beta2 * self.optimizer_state.velocity[i]
+ (1.0 - beta2) * self.optimizer_state.gradient[i].powi(2);
let m_hat = self.optimizer_state.momentum[i]
/ (1.0 - beta1.powi(self.optimizer_state.iteration as i32 + 1));
let v_hat = self.optimizer_state.velocity[i]
/ (1.0 - beta2.powi(self.optimizer_state.iteration as i32 + 1));
self.optimizer_state.parameters[i] -=
self.optimizer_state.learning_rate * m_hat / (v_hat.sqrt() + eps);
}
Ok(())
}
fn update_parameters_sgd(&mut self) -> Result<()> {
for i in 0..self.optimizer_state.parameters.len() {
self.optimizer_state.parameters[i] -=
self.optimizer_state.learning_rate * self.optimizer_state.gradient[i];
}
Ok(())
}
fn update_parameters_rmsprop(&mut self) -> Result<()> {
let alpha = 0.99;
let eps = 1e-8;
for i in 0..self.optimizer_state.parameters.len() {
self.optimizer_state.velocity[i] = alpha * self.optimizer_state.velocity[i]
+ (1.0 - alpha) * self.optimizer_state.gradient[i].powi(2);
self.optimizer_state.parameters[i] -= self.optimizer_state.learning_rate
* self.optimizer_state.gradient[i]
/ (self.optimizer_state.velocity[i].sqrt() + eps);
}
Ok(())
}
fn update_parameters_lbfgs(&mut self) -> Result<()> {
self.update_parameters_sgd()
}
fn update_parameters_qng(&mut self) -> Result<()> {
self.update_parameters_sgd()
}
fn update_parameters_spsa(&mut self) -> Result<()> {
self.update_parameters_sgd()
}
fn check_convergence(&self, current_loss: f64) -> Result<bool> {
if self.training_history.loss_history.len() < 2 {
return Ok(false);
}
let prev_loss =
self.training_history.loss_history[self.training_history.loss_history.len() - 1];
let loss_change = (current_loss - prev_loss).abs();
Ok(loss_change < self.config.convergence_tolerance)
}
#[must_use]
pub const fn get_parameters(&self) -> &Array1<f64> {
&self.optimizer_state.parameters
}
#[must_use]
pub const fn get_training_history(&self) -> &TrainingHistory {
&self.training_history
}
pub const fn set_learning_rate(&mut self, lr: f64) {
self.optimizer_state.learning_rate = lr;
}
pub fn reset_optimizer(&mut self) {
let num_params = self.optimizer_state.parameters.len();
self.optimizer_state.gradient = Array1::zeros(num_params);
self.optimizer_state.momentum = Array1::zeros(num_params);
self.optimizer_state.velocity = Array1::zeros(num_params);
self.optimizer_state.iteration = 0;
self.training_history = TrainingHistory::default();
}
}
impl HardwareAwareCompiler {
#[must_use]
pub fn new(hardware_arch: HardwareArchitecture, hardware_opts: HardwareOptimizations) -> Self {
Self {
hardware_arch,
hardware_opts,
compilation_stats: CompilationStats::default(),
}
}
pub fn compile_circuit(&mut self, circuit: &InterfaceCircuit) -> Result<InterfaceCircuit> {
let start_time = std::time::Instant::now();
self.compilation_stats.original_depth = circuit.gates.len();
let compiled_circuit = circuit.clone();
self.compilation_stats.compiled_depth = compiled_circuit.gates.len();
self.compilation_stats.compilation_time = start_time.elapsed().as_secs_f64();
Ok(compiled_circuit)
}
#[must_use]
pub const fn get_stats(&self) -> &CompilationStats {
&self.compilation_stats
}
}
impl OptimizerState {
#[must_use]
pub fn new(num_parameters: usize, learning_rate: f64) -> Self {
Self {
parameters: Array1::zeros(num_parameters),
gradient: Array1::zeros(num_parameters),
momentum: Array1::zeros(num_parameters),
velocity: Array1::zeros(num_parameters),
learning_rate,
iteration: 0,
}
}
}
impl TrainingHistory {
#[must_use]
pub fn latest_loss(&self) -> Option<f64> {
self.loss_history.last().copied()
}
#[must_use]
pub fn best_loss(&self) -> Option<f64> {
self.loss_history
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.copied()
}
#[must_use]
pub fn average_epoch_time(&self) -> f64 {
if self.epoch_times.is_empty() {
0.0
} else {
self.epoch_times.iter().sum::<f64>() / self.epoch_times.len() as f64
}
}
}