use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
single::{RotationX, RotationY, RotationZ},
GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use std::collections::{HashMap, HashSet, VecDeque};
use std::f64::consts::PI;
/// Strategy a [`QuantumContinualLearner`] applies to limit catastrophic forgetting when it
/// is trained on a sequence of tasks.
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation: penalize parameter drift, weighted by a Fisher
    /// information estimate computed on the previous task.
    ElasticWeightConsolidation {
        /// Multiplier on the Fisher-weighted squared parameter drift (halved in the loss).
        importance_weight: f64,
        /// Number of previous-task training rows (cycled if fewer exist) used to estimate
        /// the Fisher information.
        fisher_samples: usize,
    },
    /// Progressive networks: a new column (sub-network) is created for every task learned.
    ProgressiveNetworks {
        /// NOTE(review): not read anywhere in this module.
        lateral_connections: bool,
        /// Forwarded to the column builder. NOTE(review): not currently used to shape the column.
        adaptation_layers: usize,
    },
    /// Keep a bounded memory of past experiences for rehearsal.
    ExperienceReplay {
        /// Maximum number of experiences held in the replay buffer.
        buffer_size: usize,
        /// Replayed samples requested per batch, as a fraction of the batch's row count.
        replay_ratio: f64,
        /// Selection strategy requested for the replay buffer.
        memory_selection: MemorySelectionStrategy,
    },
    /// Reserve parameters per task according to `allocation_strategy`.
    ParameterIsolation {
        /// How parameters are allocated for each new task.
        allocation_strategy: ParameterAllocationStrategy,
        /// NOTE(review): not read anywhere in this module.
        growth_threshold: f64,
    },
    /// Gradient Episodic Memory: store per-sample gradients of each task in a memory buffer.
    GradientEpisodicMemory {
        /// NOTE(review): not read anywhere in this module.
        memory_strength: f64,
        /// NOTE(review): not read anywhere in this module.
        violation_threshold: f64,
    },
    /// Learning without Forgetting: add a distillation (KL) penalty on softened outputs.
    LearningWithoutForgetting {
        /// Multiplier on the mean distillation penalty.
        distillation_weight: f64,
        /// Softmax temperature applied to outputs before computing the KL term.
        temperature: f64,
    },
    /// Penalize drift of the parameters away from a stored snapshot.
    QuantumRegularization {
        /// Multiplier on the L1 norm of the parameter drift.
        entanglement_preservation: f64,
        /// Multiplier on the squared L2 norm of the parameter drift.
        parameter_drift_penalty: f64,
    },
}
/// How experiences are chosen when sampling from a [`MemoryBuffer`].
///
/// `MemoryBuffer::sample` implements `Random` (uniform, with replacement) and
/// `GradientImportance` (highest `importance` first). The remaining variants currently
/// fall back to random sampling.
///
/// The enum has no payload, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemorySelectionStrategy {
    /// Uniform sampling with replacement.
    Random,
    /// Highest `Experience::importance` first.
    GradientImportance,
    /// Uncertainty-based selection (no dedicated implementation yet).
    Uncertainty,
    /// Diversity-based selection (no dedicated implementation yet).
    Diversity,
    /// Selection by quantum-state metrics (no dedicated implementation yet).
    QuantumMetrics,
}
/// How the parameter-isolation strategy reserves parameters for a new task.
///
/// The learner records a per-task mask for `Masking`. The other variants have no
/// dedicated handling yet. The enum has no payload, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParameterAllocationStrategy {
    /// Grow the model with new parameters per task (not yet implemented by the learner).
    Expansion,
    /// Record a per-task boolean mask over the existing parameters.
    Masking,
    /// Hierarchical allocation (not yet implemented by the learner).
    Hierarchical,
    /// Quantum-aware allocation (not yet implemented by the learner).
    QuantumAware,
}
/// A single task in a continual-learning sequence, with its train/validation split.
#[derive(Debug, Clone)]
pub struct ContinualTask {
    /// Identifier; used as the key for per-task metrics and memory entries.
    pub task_id: String,
    /// Descriptive kind of task.
    pub task_type: TaskType,
    /// Training features, one sample per row.
    pub train_data: Array2<f64>,
    /// Class label for each training row.
    pub train_labels: Array1<usize>,
    /// Validation features, one sample per row.
    pub val_data: Array2<f64>,
    /// Class label for each validation row.
    pub val_labels: Array1<usize>,
    /// Number of classes; `create_continual_task` sets this to the largest label plus one.
    pub num_classes: usize,
    /// Free-form numeric metadata attached to the task.
    pub metadata: HashMap<String, f64>,
}
/// Descriptive kind of a [`ContinualTask`].
///
/// The enum contains no floating-point fields, so equality is total and `Eq` is derived
/// alongside `PartialEq`.
///
/// NOTE(review): the learner in this module does not branch on the task type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskType {
    /// Classification into `num_classes` classes.
    Classification { num_classes: usize },
    /// Regression with an `output_dim`-dimensional output.
    Regression { output_dim: usize },
    /// Preparation of `target_states` quantum states.
    StatePreparation { target_states: usize },
    /// An optimization problem identified by name.
    Optimization { problem_type: String },
}
/// Bounded FIFO store of [`Experience`]s used by memory-based strategies.
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    /// Stored experiences, oldest at the front.
    experiences: VecDeque<Experience>,
    /// Capacity; once reached, adding a new experience evicts from the front.
    max_size: usize,
    /// How `sample` picks experiences.
    selection_strategy: MemorySelectionStrategy,
    /// Intended index from task id to positions in `experiences`.
    /// NOTE(review): not read by any sampling code in this module.
    task_memories: HashMap<String, Vec<usize>>,
}
/// One stored training sample plus bookkeeping used by memory-based strategies.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Feature vector (one row of a task's data).
    pub input: Array1<f64>,
    /// Target vector for `input`; the learner sizes it to the task's `num_classes`.
    pub target: Array1<f64>,
    /// Id of the task the sample came from.
    pub task_id: String,
    /// Ranking key for `MemorySelectionStrategy::GradientImportance` (higher is sampled first).
    pub importance: f64,
    /// Parameter gradient recorded for the sample, if any (filled by the GEM strategy).
    pub gradient_info: Option<Array1<f64>>,
    /// Optional uncertainty estimate. NOTE(review): not read in this module.
    pub uncertainty: Option<f64>,
}
/// Trains a [`QuantumNeuralNetwork`] on a sequence of [`ContinualTask`]s.
///
/// It applies a [`ContinualLearningStrategy`] and tracks per-task and forgetting metrics.
pub struct QuantumContinualLearner {
    /// The network being trained across tasks.
    model: QuantumNeuralNetwork,
    /// Strategy chosen at construction.
    strategy: ContinualLearningStrategy,
    /// Every task passed to `learn_task`, in order.
    task_history: Vec<ContinualTask>,
    /// Index into `task_history` of the task most recently passed to `learn_task`.
    current_task: Option<usize>,
    /// Experience memory; `Some` only for strategies that allocate one in `new`.
    memory_buffer: Option<MemoryBuffer>,
    /// Per-parameter Fisher information estimate used by the EWC penalty.
    fisher_information: Option<Array1<f64>>,
    /// Snapshot of the model parameters taken by the pre-training hook.
    previous_parameters: Option<Array1<f64>>,
    /// Columns created by the progressive-networks strategy, one per learned task.
    progressive_modules: Vec<QuantumNeuralNetwork>,
    /// Per-task parameter masks keyed by task id (parameter-isolation strategy).
    parameter_masks: HashMap<String, Array1<bool>>,
    /// Metrics recorded for each task, keyed by task id.
    task_metrics: HashMap<String, TaskMetrics>,
    /// Aggregate forgetting metrics over all tasks learned so far.
    forgetting_metrics: ForgettingMetrics,
}
/// Metrics recorded for a single task.
///
/// `Default` yields all-zero metrics, which is convenient for struct-update syntax.
/// `PartialEq` allows metrics to be compared in tests.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TaskMetrics {
    /// Validation accuracy measured right after the task was learned.
    pub current_accuracy: f64,
    /// Validation accuracy at the most recent re-evaluation.
    pub retained_accuracy: f64,
    /// Training budget spent on the task (the learner stores the epoch count).
    pub learning_speed: usize,
    /// Backward-transfer estimate for this task.
    pub backward_transfer: f64,
    /// Forward-transfer estimate for this task.
    pub forward_transfer: f64,
}
/// Aggregate continual-learning metrics over every task learned so far.
///
/// `Default` yields zeroed metrics and an empty per-task map, which is the learner's
/// initial state. `PartialEq` allows snapshots to be compared.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ForgettingMetrics {
    /// Mean current validation accuracy across tasks.
    pub average_accuracy: f64,
    /// Mean accuracy drop (clamped at 0) relative to each task's just-learned accuracy.
    pub forgetting_measure: f64,
    /// Mean accuracy change on earlier tasks since they were learned.
    pub backward_transfer: f64,
    /// Forward-transfer estimate across tasks.
    pub forward_transfer: f64,
    /// Summary score: average accuracy minus the forgetting measure.
    pub continual_learning_score: f64,
    /// Forgetting per task id.
    pub per_task_forgetting: HashMap<String, f64>,
}
impl QuantumContinualLearner {
pub fn new(model: QuantumNeuralNetwork, strategy: ContinualLearningStrategy) -> Self {
let memory_buffer = match &strategy {
ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => Some(
MemoryBuffer::new(*buffer_size, MemorySelectionStrategy::Random),
),
ContinualLearningStrategy::GradientEpisodicMemory { .. } => Some(MemoryBuffer::new(
1000,
MemorySelectionStrategy::GradientImportance,
)),
_ => None,
};
Self {
model,
strategy,
task_history: Vec::new(),
current_task: None,
memory_buffer,
fisher_information: None,
previous_parameters: None,
progressive_modules: Vec::new(),
parameter_masks: HashMap::new(),
task_metrics: HashMap::new(),
forgetting_metrics: ForgettingMetrics {
average_accuracy: 0.0,
forgetting_measure: 0.0,
backward_transfer: 0.0,
forward_transfer: 0.0,
continual_learning_score: 0.0,
per_task_forgetting: HashMap::new(),
},
}
}
pub fn learn_task(
&mut self,
task: ContinualTask,
optimizer: &mut dyn Optimizer,
epochs: usize,
) -> Result<TaskMetrics> {
println!("Learning task: {}", task.task_id);
self.task_history.push(task.clone());
self.current_task = Some(self.task_history.len() - 1);
self.apply_pre_training_strategy(&task)?;
let start_time = std::time::Instant::now();
let learning_losses = self.train_on_task(&task, optimizer, epochs)?;
let learning_time = start_time.elapsed();
self.apply_post_training_strategy(&task)?;
let current_accuracy = self.evaluate_task(&task)?;
if self.memory_buffer.is_some() {
let mut buffer = self
.memory_buffer
.take()
.expect("memory_buffer verified to be Some above");
self.update_memory_buffer(&mut buffer, &task)?;
self.memory_buffer = Some(buffer);
}
let task_metrics = TaskMetrics {
current_accuracy,
retained_accuracy: current_accuracy, learning_speed: epochs, backward_transfer: 0.0, forward_transfer: 0.0, };
self.task_metrics
.insert(task.task_id.clone(), task_metrics.clone());
self.update_forgetting_metrics()?;
println!(
"Task {} learned with accuracy: {:.3}",
task.task_id, current_accuracy
);
Ok(task_metrics)
}
fn train_on_task(
&mut self,
task: &ContinualTask,
optimizer: &mut dyn Optimizer,
epochs: usize,
) -> Result<Vec<f64>> {
let mut losses = Vec::new();
let batch_size = 32;
for epoch in 0..epochs {
let mut epoch_loss = 0.0;
let num_batches = (task.train_data.nrows() + batch_size - 1) / batch_size;
for batch_idx in 0..num_batches {
let batch_start = batch_idx * batch_size;
let batch_end = (batch_start + batch_size).min(task.train_data.nrows());
let batch_data = task
.train_data
.slice(s![batch_start..batch_end, ..])
.to_owned();
let batch_labels = task
.train_labels
.slice(s![batch_start..batch_end])
.to_owned();
let (final_data, final_labels) =
self.create_training_batch(&batch_data, &batch_labels, task)?;
let batch_loss = self.compute_continual_loss(&final_data, &final_labels, task)?;
epoch_loss += batch_loss;
}
epoch_loss /= num_batches as f64;
losses.push(epoch_loss);
if epoch % 10 == 0 {
println!(" Epoch {}: Loss = {:.4}", epoch, epoch_loss);
}
}
Ok(losses)
}
fn apply_pre_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
let strategy = self.strategy.clone();
match strategy {
ContinualLearningStrategy::ElasticWeightConsolidation { .. } => {
if !self.task_history.is_empty() {
self.previous_parameters = Some(self.model.parameters.clone());
self.compute_fisher_information()?;
}
}
ContinualLearningStrategy::ProgressiveNetworks {
lateral_connections,
adaptation_layers,
} => {
self.create_progressive_column(adaptation_layers)?;
}
ContinualLearningStrategy::ParameterIsolation {
allocation_strategy,
..
} => {
self.allocate_parameters_for_task(task, &allocation_strategy)?;
}
_ => {}
}
Ok(())
}
fn apply_post_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
match &self.strategy {
ContinualLearningStrategy::ExperienceReplay { .. } => {
}
ContinualLearningStrategy::GradientEpisodicMemory { .. } => {
self.compute_gradient_memory(task)?;
}
_ => {}
}
Ok(())
}
fn create_training_batch(
&self,
current_data: &Array2<f64>,
current_labels: &Array1<usize>,
task: &ContinualTask,
) -> Result<(Array2<f64>, Array1<usize>)> {
match &self.strategy {
ContinualLearningStrategy::ExperienceReplay { replay_ratio, .. } => {
if let Some(ref buffer) = self.memory_buffer {
let num_replay = (current_data.nrows() as f64 * replay_ratio) as usize;
let replay_experiences = buffer.sample(num_replay);
let mut combined_data = current_data.clone();
let mut combined_labels = current_labels.clone();
for experience in replay_experiences {
}
Ok((combined_data, combined_labels))
} else {
Ok((current_data.clone(), current_labels.clone()))
}
}
_ => Ok((current_data.clone(), current_labels.clone())),
}
}
fn compute_continual_loss(
&self,
data: &Array2<f64>,
labels: &Array1<usize>,
task: &ContinualTask,
) -> Result<f64> {
let mut total_loss = 0.0;
for (input, &label) in data.outer_iter().zip(labels.iter()) {
let output = self.model.forward(&input.to_owned())?;
total_loss += self.cross_entropy_loss(&output, label);
}
let base_loss = total_loss / data.nrows() as f64;
let regularization = match &self.strategy {
ContinualLearningStrategy::ElasticWeightConsolidation {
importance_weight, ..
} => self.compute_ewc_regularization(*importance_weight),
ContinualLearningStrategy::LearningWithoutForgetting {
distillation_weight,
temperature,
} => self.compute_lwf_regularization(*distillation_weight, *temperature, data)?,
ContinualLearningStrategy::QuantumRegularization {
entanglement_preservation,
parameter_drift_penalty,
} => self.compute_quantum_regularization(
*entanglement_preservation,
*parameter_drift_penalty,
),
_ => 0.0,
};
Ok(base_loss + regularization)
}
fn compute_ewc_regularization(&self, importance_weight: f64) -> f64 {
if let (Some(ref fisher), Some(ref prev_params)) =
(&self.fisher_information, &self.previous_parameters)
{
let param_diff = &self.model.parameters - prev_params;
let ewc_term = fisher * ¶m_diff.mapv(|x| x.powi(2));
importance_weight * ewc_term.sum() / 2.0
} else {
0.0
}
}
fn compute_lwf_regularization(
&self,
distillation_weight: f64,
temperature: f64,
data: &Array2<f64>,
) -> Result<f64> {
if self.task_history.len() <= 1 {
return Ok(0.0);
}
let mut distillation_loss = 0.0;
for input in data.outer_iter() {
let current_output = self.model.forward(&input.to_owned())?;
let teacher_output = current_output.clone();
let student_probs = self.softmax_with_temperature(¤t_output, temperature);
let teacher_probs = self.softmax_with_temperature(&teacher_output, temperature);
for (s, t) in student_probs.iter().zip(teacher_probs.iter()) {
if *t > 1e-10 {
distillation_loss += t * (t / s).ln();
}
}
}
Ok(distillation_weight * distillation_loss / data.nrows() as f64)
}
fn compute_quantum_regularization(
&self,
entanglement_preservation: f64,
parameter_drift_penalty: f64,
) -> f64 {
let mut regularization = 0.0;
if let Some(ref prev_params) = self.previous_parameters {
let param_diff = &self.model.parameters - prev_params;
let entanglement_penalty = param_diff.mapv(|x| x.abs()).sum();
regularization += entanglement_preservation * entanglement_penalty;
}
if let Some(ref prev_params) = self.previous_parameters {
let drift = (&self.model.parameters - prev_params)
.mapv(|x| x.powi(2))
.sum();
regularization += parameter_drift_penalty * drift;
}
regularization
}
fn compute_fisher_information(&mut self) -> Result<()> {
if let ContinualLearningStrategy::ElasticWeightConsolidation { fisher_samples, .. } =
&self.strategy
{
let mut fisher = Array1::zeros(self.model.parameters.len());
if let Some(current_task_idx) = self.current_task {
if current_task_idx > 0 {
let prev_task = &self.task_history[current_task_idx - 1];
for i in 0..*fisher_samples {
let idx = i % prev_task.train_data.nrows();
let input = prev_task.train_data.row(idx).to_owned();
let label = prev_task.train_labels[idx];
let gradient = self.compute_parameter_gradient(&input, label)?;
fisher = fisher + &gradient.mapv(|x| x.powi(2));
}
fisher = fisher / *fisher_samples as f64;
}
}
self.fisher_information = Some(fisher);
}
Ok(())
}
fn create_progressive_column(&mut self, adaptation_layers: usize) -> Result<()> {
let layers = vec![
QNNLayerType::EncodingLayer { num_features: 4 },
QNNLayerType::VariationalLayer { num_params: 6 },
];
let progressive_module = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
self.progressive_modules.push(progressive_module);
Ok(())
}
fn allocate_parameters_for_task(
&mut self,
task: &ContinualTask,
strategy: &ParameterAllocationStrategy,
) -> Result<()> {
match strategy {
ParameterAllocationStrategy::Masking => {
let mask = Array1::from_elem(self.model.parameters.len(), true);
self.parameter_masks.insert(task.task_id.clone(), mask);
}
ParameterAllocationStrategy::Expansion => {
}
_ => {}
}
Ok(())
}
fn compute_gradient_memory(&mut self, task: &ContinualTask) -> Result<()> {
if self.memory_buffer.is_some() {
let mut buffer = self
.memory_buffer
.take()
.expect("memory_buffer verified to be Some above");
for i in 0..task.train_data.nrows().min(100) {
let input = task.train_data.row(i).to_owned();
let label = task.train_labels[i];
let gradient = self.compute_parameter_gradient(&input, label)?;
let experience = Experience {
input,
target: Array1::from_elem(task.num_classes, 0.0), task_id: task.task_id.clone(),
importance: 1.0,
gradient_info: Some(gradient),
uncertainty: None,
};
buffer.add_experience(experience);
}
self.memory_buffer = Some(buffer);
}
Ok(())
}
fn update_memory_buffer(&self, buffer: &mut MemoryBuffer, task: &ContinualTask) -> Result<()> {
for i in 0..task.train_data.nrows() {
let input = task.train_data.row(i).to_owned();
let target = Array1::from_elem(task.num_classes, 0.0);
let experience = Experience {
input,
target,
task_id: task.task_id.clone(),
importance: 1.0,
gradient_info: None,
uncertainty: None,
};
buffer.add_experience(experience);
}
Ok(())
}
fn evaluate_task(&self, task: &ContinualTask) -> Result<f64> {
let mut correct = 0;
let total = task.val_data.nrows();
for (input, &label) in task.val_data.outer_iter().zip(task.val_labels.iter()) {
let output = self.model.forward(&input.to_owned())?;
let predicted = output
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
if predicted == label {
correct += 1;
}
}
Ok(correct as f64 / total as f64)
}
pub fn evaluate_all_tasks(&mut self) -> Result<HashMap<String, f64>> {
let mut accuracies = HashMap::new();
for task in &self.task_history {
let accuracy = self.evaluate_task(task)?;
accuracies.insert(task.task_id.clone(), accuracy);
if let Some(metrics) = self.task_metrics.get_mut(&task.task_id) {
metrics.retained_accuracy = accuracy;
}
}
Ok(accuracies)
}
fn update_forgetting_metrics(&mut self) -> Result<()> {
if self.task_history.is_empty() {
return Ok(());
}
let accuracies = self.evaluate_all_tasks()?;
let avg_accuracy = accuracies.values().sum::<f64>() / accuracies.len() as f64;
self.forgetting_metrics.average_accuracy = avg_accuracy;
let mut total_forgetting = 0.0;
let mut num_comparisons = 0;
for (task_id, metrics) in &self.task_metrics {
let current_acc = accuracies.get(task_id).unwrap_or(&0.0);
let original_acc = metrics.current_accuracy;
if original_acc > 0.0 {
let forgetting = (original_acc - current_acc).max(0.0);
total_forgetting += forgetting;
num_comparisons += 1;
self.forgetting_metrics
.per_task_forgetting
.insert(task_id.clone(), forgetting);
}
}
if num_comparisons > 0 {
self.forgetting_metrics.forgetting_measure = total_forgetting / num_comparisons as f64;
}
self.forgetting_metrics.continual_learning_score =
avg_accuracy - self.forgetting_metrics.forgetting_measure;
Ok(())
}
fn compute_parameter_gradient(&self, input: &Array1<f64>, label: usize) -> Result<Array1<f64>> {
Ok(Array1::zeros(self.model.parameters.len()))
}
fn cross_entropy_loss(&self, output: &Array1<f64>, label: usize) -> f64 {
let predicted_prob = output[label].max(1e-10);
-predicted_prob.ln()
}
fn softmax_with_temperature(&self, logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
let scaled_logits = logits / temperature;
let max_logit = scaled_logits
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
let exp_logits = scaled_logits.mapv(|x| (x - max_logit).exp());
let sum_exp = exp_logits.sum();
exp_logits / sum_exp
}
pub fn get_forgetting_metrics(&self) -> &ForgettingMetrics {
&self.forgetting_metrics
}
pub fn get_task_metrics(&self) -> &HashMap<String, TaskMetrics> {
&self.task_metrics
}
pub fn get_model(&self) -> &QuantumNeuralNetwork {
&self.model
}
pub fn reset(&mut self) {
self.task_history.clear();
self.current_task = None;
self.fisher_information = None;
self.previous_parameters = None;
self.progressive_modules.clear();
self.parameter_masks.clear();
self.task_metrics.clear();
if let Some(ref mut buffer) = self.memory_buffer {
buffer.clear();
}
}
}
impl MemoryBuffer {
    /// Creates an empty buffer holding at most `max_size` experiences, sampled with `strategy`.
    pub fn new(max_size: usize, strategy: MemorySelectionStrategy) -> Self {
        Self {
            experiences: VecDeque::new(),
            max_size,
            selection_strategy: strategy,
            task_memories: HashMap::new(),
        }
    }

    /// Appends `experience`, evicting the oldest stored experience (FIFO) when full.
    ///
    /// A buffer created with `max_size == 0` stores nothing. Without this guard, the
    /// eviction step would try to pop from an empty deque and panic.
    pub fn add_experience(&mut self, experience: Experience) {
        if self.max_size == 0 {
            return;
        }
        if self.experiences.len() >= self.max_size {
            if let Some(removed) = self.experiences.pop_front() {
                self.remove_from_task_index(&removed);
            }
        }
        let experience_idx = self.experiences.len();
        self.task_memories
            .entry(experience.task_id.clone())
            .or_default()
            .push(experience_idx);
        self.experiences.push_back(experience);
    }

    /// Returns up to `num_samples` experiences (never more than are stored).
    ///
    /// - `GradientImportance`: highest `importance` first; ties keep insertion order
    ///   because the sort is stable.
    /// - Every other strategy: uniform sampling with replacement.
    pub fn sample(&self, num_samples: usize) -> Vec<Experience> {
        let available = self.experiences.len().min(num_samples);
        match self.selection_strategy {
            MemorySelectionStrategy::GradientImportance => {
                let mut ranked: Vec<&Experience> = self.experiences.iter().collect();
                ranked.sort_by(|a, b| {
                    b.importance
                        .partial_cmp(&a.importance)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                ranked.into_iter().take(available).cloned().collect()
            }
            _ => (0..available)
                .map(|_| self.experiences[fastrand::usize(0..self.experiences.len())].clone())
                .collect(),
        }
    }

    /// Keeps `task_memories` aligned with `experiences` after the front entry is evicted.
    ///
    /// `experience` must be the element just popped from the front, i.e. the one at
    /// position 0. The method then:
    /// 1. Drops that position from its task's list.
    /// 2. Shifts every remaining position down by one.
    /// 3. Removes tasks that no longer have any stored experience.
    fn remove_from_task_index(&mut self, experience: &Experience) {
        if let Some(indices) = self.task_memories.get_mut(&experience.task_id) {
            indices.retain(|&idx| idx != 0);
        }
        for indices in self.task_memories.values_mut() {
            for idx in indices.iter_mut() {
                *idx = idx.saturating_sub(1);
            }
        }
        self.task_memories.retain(|_, indices| !indices.is_empty());
    }

    /// Removes every stored experience and the per-task index.
    pub fn clear(&mut self) {
        self.experiences.clear();
        self.task_memories.clear();
    }

    /// Number of experiences currently stored.
    pub fn size(&self) -> usize {
        self.experiences.len()
    }
}
/// Splits `data`/`labels` into a training prefix and a validation suffix and wraps them in a
/// [`ContinualTask`].
///
/// The first `floor(nrows * train_ratio)` rows become training data and the remaining
/// rows validation data. The count is clamped to `[0, nrows]`, so ratios above 1.0 (or
/// negative/NaN ratios) no longer slice out of bounds. `num_classes` is the largest label
/// plus one, or 1 when `labels` is empty.
///
/// Assumes `labels.len() == data.nrows()`.
///
/// # Panics
/// Panics if `labels` is shorter than the computed training split.
pub fn create_continual_task(
    task_id: String,
    task_type: TaskType,
    data: Array2<f64>,
    labels: Array1<usize>,
    train_ratio: f64,
) -> ContinualTask {
    let num_rows = data.nrows();
    // `as usize` saturates (negative/NaN -> 0); `min` caps ratios above 1.0.
    let train_size = ((num_rows as f64 * train_ratio) as usize).min(num_rows);
    let train_data = data.slice(s![0..train_size, ..]).to_owned();
    let train_labels = labels.slice(s![0..train_size]).to_owned();
    let val_data = data.slice(s![train_size.., ..]).to_owned();
    let val_labels = labels.slice(s![train_size..]).to_owned();
    let num_classes = labels
        .iter()
        .copied()
        .max()
        .map_or(1, |max_label| max_label + 1);
    ContinualTask {
        task_id,
        task_type,
        train_data,
        train_labels,
        val_data,
        val_labels,
        num_classes,
        metadata: HashMap::new(),
    }
}
/// Builds `num_tasks` synthetic binary-classification tasks named `task_0`, `task_1`, ….
///
/// Each task has `samples_per_task` rows of `feature_dim` features:
/// - Features are a sinusoid of row/column position, phase-shifted by `0.5 * task_index`,
///   plus uniform noise in `[-0.05, 0.05)`.
/// - A row is labelled 1 when its feature sum exceeds `0.5 * feature_dim`, else 0.
/// - Every task uses an 80/20 train/validation split.
pub fn generate_task_sequence(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    (0..num_tasks)
        .map(|task_idx| {
            let phase = task_idx as f64 * 0.5;
            let features = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
                let position =
                    row as f64 / samples_per_task as f64 + col as f64 / feature_dim as f64;
                0.5 + 0.3 * (position * 2.0 * PI + phase).sin() + 0.1 * (fastrand::f64() - 0.5)
            });
            let threshold = feature_dim as f64 * 0.5;
            let labels = Array1::from_shape_fn(samples_per_task, |row| {
                usize::from(features.row(row).sum() > threshold)
            });
            create_continual_task(
                format!("task_{}", task_idx),
                TaskType::Classification { num_classes: 2 },
                features,
                labels,
                0.8,
            )
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(5, MemorySelectionStrategy::Random);
        // Ten experiences, three per task id, with importance equal to their index.
        (0..10usize)
            .map(|i| Experience {
                input: Array1::from_vec(vec![i as f64]),
                target: Array1::from_vec(vec![(i % 2) as f64]),
                task_id: format!("task_{}", i / 3),
                importance: i as f64,
                gradient_info: None,
                uncertainty: None,
            })
            .for_each(|experience| buffer.add_experience(experience));
        // Capacity is five, so the oldest five experiences were evicted.
        assert_eq!(buffer.size(), 5);
        assert_eq!(buffer.sample(3).len(), 3);
    }

    #[test]
    fn test_continual_task_creation() {
        let features = Array2::from_shape_fn((100, 4), |(i, j)| (i as f64 + j as f64) / 50.0);
        let labels = Array1::from_shape_fn(100, |i| i % 3);
        let task = create_continual_task(
            "test_task".to_string(),
            TaskType::Classification { num_classes: 3 },
            features,
            labels,
            0.7,
        );
        assert_eq!(task.task_id, "test_task");
        // 70/30 split of 100 rows.
        assert_eq!(task.train_data.nrows(), 70);
        assert_eq!(task.val_data.nrows(), 30);
        assert_eq!(task.num_classes, 3);
    }

    #[test]
    fn test_continual_learner_creation() {
        let model = QuantumNeuralNetwork::new(
            vec![
                QNNLayerType::EncodingLayer { num_features: 4 },
                QNNLayerType::VariationalLayer { num_params: 8 },
                QNNLayerType::MeasurementLayer {
                    measurement_basis: "computational".to_string(),
                },
            ],
            4,
            4,
            2,
        )
        .expect("Failed to create model");
        let learner = QuantumContinualLearner::new(
            model,
            ContinualLearningStrategy::ElasticWeightConsolidation {
                importance_weight: 1000.0,
                fisher_samples: 100,
            },
        );
        // A freshly built learner has seen no tasks.
        assert_eq!(learner.task_history.len(), 0);
        assert!(learner.current_task.is_none());
    }

    #[test]
    fn test_task_sequence_generation() {
        let tasks = generate_task_sequence(3, 50, 4);
        assert_eq!(tasks.len(), 3);
        tasks.iter().enumerate().for_each(|(i, task)| {
            assert_eq!(task.task_id, format!("task_{}", i));
            // 80/20 split of 50 samples with 4 features each.
            assert_eq!(task.train_data.nrows(), 40);
            assert_eq!(task.val_data.nrows(), 10);
            assert_eq!(task.train_data.ncols(), 4);
        });
    }

    #[test]
    fn test_forgetting_metrics() {
        let metrics = ForgettingMetrics {
            average_accuracy: 0.75,
            forgetting_measure: 0.15,
            backward_transfer: 0.05,
            forward_transfer: 0.1,
            continual_learning_score: 0.6,
            per_task_forgetting: HashMap::new(),
        };
        assert_eq!(metrics.average_accuracy, 0.75);
        assert_eq!(metrics.forgetting_measure, 0.15);
        assert!(metrics.continual_learning_score > 0.5);
    }
}