quantrs2_ml/continual_learning.rs

//! Quantum Continual Learning
//!
//! This module implements continual learning algorithms for quantum neural networks,
//! enabling models to learn new tasks sequentially while preserving knowledge from
//! previous tasks and avoiding catastrophic forgetting.
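//!
//! # Example
//!
//! A minimal end-to-end sketch (marked `ignore`; it assumes the module is
//! exposed as `quantrs2_ml::continual_learning` and that `Adam` exposes an
//! `Adam::new(learning_rate)` constructor):
//!
//! ```ignore
//! use quantrs2_ml::autodiff::optimizers::Adam;
//! use quantrs2_ml::continual_learning::*;
//! use quantrs2_ml::qnn::{QNNLayerType, QuantumNeuralNetwork};
//!
//! // Build a small QNN and wrap it in a continual learner using EWC.
//! let layers = vec![
//!     QNNLayerType::EncodingLayer { num_features: 4 },
//!     QNNLayerType::VariationalLayer { num_params: 8 },
//! ];
//! let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
//! let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
//!     importance_weight: 1000.0,
//!     fisher_samples: 100,
//! };
//! let mut learner = QuantumContinualLearner::new(model, strategy);
//!
//! // Learn a sequence of synthetic tasks, then inspect forgetting.
//! let mut optimizer = Adam::new(0.01);
//! for task in generate_task_sequence(3, 50, 4) {
//!     learner.learn_task(task, &mut optimizer, 20)?;
//! }
//! println!("{:?}", learner.get_forgetting_metrics());
//! ```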

use crate::autodiff::optimizers::Optimizer;
use crate::error::Result;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::{HashMap, VecDeque};
use std::f64::consts::PI;

/// Continual learning strategies for quantum models
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation (EWC) for quantum circuits
    ElasticWeightConsolidation {
        importance_weight: f64,
        fisher_samples: usize,
    },

    /// Progressive Neural Networks with quantum modules
    ProgressiveNetworks {
        lateral_connections: bool,
        adaptation_layers: usize,
    },

    /// Memory replay with episodic buffer
    ExperienceReplay {
        buffer_size: usize,
        replay_ratio: f64,
        memory_selection: MemorySelectionStrategy,
    },

    /// Parameter isolation and expansion
    ParameterIsolation {
        allocation_strategy: ParameterAllocationStrategy,
        growth_threshold: f64,
    },

    /// Gradient episodic memory
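    ///
    /// GEM-style methods keep an episodic memory of past-task examples and
    /// constrain each update so the loss on those memories does not increase,
    /// projecting the gradient whenever it conflicts with a stored task
    /// gradient.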
    GradientEpisodicMemory {
        memory_strength: f64,
        violation_threshold: f64,
    },

    /// Learning without forgetting (LwF)
    LearningWithoutForgetting {
        distillation_weight: f64,
        temperature: f64,
    },

    /// Quantum-specific regularization
    QuantumRegularization {
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    },
}

/// Memory selection strategies for experience replay
#[derive(Debug, Clone)]
pub enum MemorySelectionStrategy {
    /// Random sampling
    Random,
    /// Gradient-based importance
    GradientImportance,
    /// Uncertainty-based selection
    Uncertainty,
    /// Diverse sampling
    Diversity,
    /// Quantum-specific metrics
    QuantumMetrics,
}

/// Parameter allocation strategies
#[derive(Debug, Clone)]
pub enum ParameterAllocationStrategy {
    /// Add new parameters for new tasks
    Expansion,
    /// Mask existing parameters for different tasks
    Masking,
    /// Hierarchical parameter sharing
    Hierarchical,
    /// Quantum-specific allocation
    QuantumAware,
}

/// Task definition for continual learning
#[derive(Debug, Clone)]
pub struct ContinualTask {
    /// Task identifier
    pub task_id: String,

    /// Task type/domain
    pub task_type: TaskType,

    /// Training data
    pub train_data: Array2<f64>,

    /// Training labels
    pub train_labels: Array1<usize>,

    /// Validation data
    pub val_data: Array2<f64>,

    /// Validation labels
    pub val_labels: Array1<usize>,

    /// Number of classes
    pub num_classes: usize,

    /// Task-specific metadata
    pub metadata: HashMap<String, f64>,
}

/// Task types for continual learning
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification task
    Classification { num_classes: usize },
    /// Regression task
    Regression { output_dim: usize },
    /// Quantum state preparation
    StatePreparation { target_states: usize },
    /// Quantum optimization
    Optimization { problem_type: String },
}

/// Memory buffer for experience replay
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    /// Stored experiences
    experiences: VecDeque<Experience>,

    /// Maximum buffer size
    max_size: usize,

    /// Selection strategy
    selection_strategy: MemorySelectionStrategy,

    /// Task-wise organization
    task_memories: HashMap<String, Vec<usize>>,
}

/// Individual experience/memory
#[derive(Debug, Clone)]
pub struct Experience {
    /// Input data
    pub input: Array1<f64>,

    /// Target output
    pub target: Array1<f64>,

    /// Task identifier
    pub task_id: String,

    /// Importance score
    pub importance: f64,

    /// Gradient information (optional)
    pub gradient_info: Option<Array1<f64>>,

    /// Uncertainty measure
    pub uncertainty: Option<f64>,
}

/// Quantum continual learner
pub struct QuantumContinualLearner {
    /// Base quantum model
    model: QuantumNeuralNetwork,

    /// Continual learning strategy
    strategy: ContinualLearningStrategy,

    /// Task sequence and history
    task_history: Vec<ContinualTask>,

    /// Current task index
    current_task: Option<usize>,

    /// Memory buffer
    memory_buffer: Option<MemoryBuffer>,

    /// Fisher information (for EWC)
    fisher_information: Option<Array1<f64>>,

    /// Previous task parameters (for EWC)
    previous_parameters: Option<Array1<f64>>,

    /// Progressive modules (for Progressive Networks)
    progressive_modules: Vec<QuantumNeuralNetwork>,

    /// Parameter masks (for Parameter Isolation)
    parameter_masks: HashMap<String, Array1<bool>>,

    /// Performance metrics per task
    task_metrics: HashMap<String, TaskMetrics>,

    /// Forgetting metrics
    forgetting_metrics: ForgettingMetrics,
}

/// Metrics for individual tasks
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Accuracy on current task
    pub current_accuracy: f64,

    /// Accuracy after learning subsequent tasks
    pub retained_accuracy: f64,

    /// Learning speed (epochs to convergence)
    pub learning_speed: usize,

    /// Backward transfer (improvement from future tasks)
    pub backward_transfer: f64,

    /// Forward transfer (help to future tasks)
    pub forward_transfer: f64,
}

/// Overall forgetting and transfer metrics
#[derive(Debug, Clone)]
pub struct ForgettingMetrics {
    /// Average accuracy across all seen tasks
    pub average_accuracy: f64,

    /// Catastrophic forgetting measure
    pub forgetting_measure: f64,

    /// Backward transfer coefficient
    pub backward_transfer: f64,

    /// Forward transfer coefficient
    pub forward_transfer: f64,

    /// Overall continual learning score
    pub continual_learning_score: f64,

    /// Per-task forgetting
    pub per_task_forgetting: HashMap<String, f64>,
}

impl QuantumContinualLearner {
    /// Create a new quantum continual learner
    pub fn new(model: QuantumNeuralNetwork, strategy: ContinualLearningStrategy) -> Self {
        let memory_buffer = match &strategy {
            ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => Some(
                MemoryBuffer::new(*buffer_size, MemorySelectionStrategy::Random),
            ),
            ContinualLearningStrategy::GradientEpisodicMemory { .. } => Some(MemoryBuffer::new(
                1000,
                MemorySelectionStrategy::GradientImportance,
            )),
            _ => None,
        };

        Self {
            model,
            strategy,
            task_history: Vec::new(),
            current_task: None,
            memory_buffer,
            fisher_information: None,
            previous_parameters: None,
            progressive_modules: Vec::new(),
            parameter_masks: HashMap::new(),
            task_metrics: HashMap::new(),
            forgetting_metrics: ForgettingMetrics {
                average_accuracy: 0.0,
                forgetting_measure: 0.0,
                backward_transfer: 0.0,
                forward_transfer: 0.0,
                continual_learning_score: 0.0,
                per_task_forgetting: HashMap::new(),
            },
        }
    }

    /// Learn a new task
    pub fn learn_task(
        &mut self,
        task: ContinualTask,
        optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<TaskMetrics> {
        println!("Learning task: {}", task.task_id);

        // Store task in history
        self.task_history.push(task.clone());
        self.current_task = Some(self.task_history.len() - 1);

        // Apply continual learning strategy before training
        self.apply_pre_training_strategy(&task)?;

        // Train on the new task
        let start_time = std::time::Instant::now();
        let losses = self.train_on_task(&task, optimizer, epochs)?;
        let learning_time = start_time.elapsed();

        // Apply post-training strategy
        self.apply_post_training_strategy(&task)?;

        // Evaluate on current task
        let current_accuracy = self.evaluate_task(&task)?;

        // Update memory buffer if applicable
        if let Some(mut buffer) = self.memory_buffer.take() {
            self.update_memory_buffer(&mut buffer, &task)?;
            self.memory_buffer = Some(buffer);
        }

        // Compute task metrics
        let task_metrics = TaskMetrics {
            current_accuracy,
            retained_accuracy: current_accuracy, // Updated once later tasks are learned
            learning_speed: epochs,              // Simplified - could track convergence instead
            backward_transfer: 0.0,              // Computed later
            forward_transfer: 0.0,               // Computed when future tasks are learned
        };

        self.task_metrics
            .insert(task.task_id.clone(), task_metrics.clone());

        // Update overall metrics
        self.update_forgetting_metrics()?;

        println!(
            "Task {} learned with accuracy {:.3} in {:.2?} (final loss {:.4})",
            task.task_id,
            current_accuracy,
            learning_time,
            losses.last().copied().unwrap_or(f64::NAN)
        );

        Ok(task_metrics)
    }

    /// Train on a specific task
    fn train_on_task(
        &mut self,
        task: &ContinualTask,
        _optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<Vec<f64>> {
        let mut losses = Vec::new();
        let batch_size = 32;

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let num_batches = (task.train_data.nrows() + batch_size - 1) / batch_size;

            for batch_idx in 0..num_batches {
                let batch_start = batch_idx * batch_size;
                let batch_end = (batch_start + batch_size).min(task.train_data.nrows());

                let batch_data = task
                    .train_data
                    .slice(s![batch_start..batch_end, ..])
                    .to_owned();
                let batch_labels = task
                    .train_labels
                    .slice(s![batch_start..batch_end])
                    .to_owned();

                // Create combined training batch with replay if applicable
                let (final_data, final_labels) =
                    self.create_training_batch(&batch_data, &batch_labels, task)?;

                // Compute loss with continual learning regularization
                let batch_loss = self.compute_continual_loss(&final_data, &final_labels, task)?;
                epoch_loss += batch_loss;

                // Update model parameters (simplified)
                // In practice, this would backpropagate and step `_optimizer`
            }

            epoch_loss /= num_batches as f64;
            losses.push(epoch_loss);

            if epoch % 10 == 0 {
                println!("  Epoch {}: Loss = {:.4}", epoch, epoch_loss);
            }
        }

        Ok(losses)
    }

    /// Apply pre-training strategy
    fn apply_pre_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        let strategy = self.strategy.clone();
        match strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation { .. } => {
                if !self.task_history.is_empty() {
                    // Store current parameters and compute Fisher information
                    self.previous_parameters = Some(self.model.parameters.clone());
                    self.compute_fisher_information()?;
                }
            }

            ContinualLearningStrategy::ProgressiveNetworks {
                adaptation_layers, ..
            } => {
                // Create a new column for the new task
                self.create_progressive_column(adaptation_layers)?;
            }

            ContinualLearningStrategy::ParameterIsolation {
                allocation_strategy,
                ..
            } => {
                // Allocate parameters for the new task
                self.allocate_parameters_for_task(task, &allocation_strategy)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Apply post-training strategy
    fn apply_post_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { .. } => {
                // Memory buffer already updated during training
            }

            ContinualLearningStrategy::GradientEpisodicMemory { .. } => {
                // Compute and store gradient information
                self.compute_gradient_memory(task)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Create training batch with replay if applicable
    fn create_training_batch(
        &self,
        current_data: &Array2<f64>,
        current_labels: &Array1<usize>,
        _task: &ContinualTask,
    ) -> Result<(Array2<f64>, Array1<usize>)> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { replay_ratio, .. } => {
                if let Some(ref buffer) = self.memory_buffer {
                    let num_replay = (current_data.nrows() as f64 * replay_ratio) as usize;
                    let replay_experiences = buffer.sample(num_replay);

                    // Combine current and replay data by stacking replayed
                    // experiences beneath the current batch
                    let mut combined_values: Vec<f64> = current_data.iter().cloned().collect();
                    let mut combined_labels: Vec<usize> = current_labels.to_vec();

                    for experience in replay_experiences {
                        // Skip experiences whose feature dimension does not match
                        if experience.input.len() != current_data.ncols() {
                            continue;
                        }
                        combined_values.extend(experience.input.iter().cloned());
                        // Recover the class index from the one-hot target
                        // (argmax; defaults to 0 for an all-zero target)
                        let label = experience
                            .target
                            .iter()
                            .enumerate()
                            .max_by(|a, b| {
                                a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(i, _)| i)
                            .unwrap_or(0);
                        combined_labels.push(label);
                    }

                    let combined_data = Array2::from_shape_vec(
                        (combined_labels.len(), current_data.ncols()),
                        combined_values,
                    )
                    .expect("row length checked against ncols above");

                    Ok((combined_data, Array1::from_vec(combined_labels)))
                } else {
                    Ok((current_data.clone(), current_labels.clone()))
                }
            }
            _ => Ok((current_data.clone(), current_labels.clone())),
        }
    }

    /// Compute continual learning loss with regularization
    fn compute_continual_loss(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        _task: &ContinualTask,
    ) -> Result<f64> {
        // Base loss (simplified)
        let mut total_loss = 0.0;

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            total_loss += self.cross_entropy_loss(&output, label);
        }

        let base_loss = total_loss / data.nrows() as f64;

        // Add continual learning regularization
        let regularization = match &self.strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => self.compute_ewc_regularization(*importance_weight),

            ContinualLearningStrategy::LearningWithoutForgetting {
                distillation_weight,
                temperature,
            } => self.compute_lwf_regularization(*distillation_weight, *temperature, data)?,

            ContinualLearningStrategy::QuantumRegularization {
                entanglement_preservation,
                parameter_drift_penalty,
            } => self.compute_quantum_regularization(
                *entanglement_preservation,
                *parameter_drift_penalty,
            ),

            _ => 0.0,
        };

        Ok(base_loss + regularization)
    }

    /// Compute EWC regularization term
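    ///
    /// Uses the standard diagonal-Fisher EWC penalty,
    /// `L_EWC = (λ / 2) · Σ_i F_i · (θ_i − θ*_i)²`,
    /// where `λ` is `importance_weight`, `F` the stored Fisher information,
    /// and `θ*` the parameters saved after the previous task.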
    fn compute_ewc_regularization(&self, importance_weight: f64) -> f64 {
        if let (Some(fisher), Some(prev_params)) =
            (&self.fisher_information, &self.previous_parameters)
        {
            let param_diff = &self.model.parameters - prev_params;
            let ewc_term = fisher * &param_diff.mapv(|x| x.powi(2));
            importance_weight * ewc_term.sum() / 2.0
        } else {
            0.0
        }
    }

    /// Compute Learning without Forgetting regularization
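    ///
    /// Distills the previous model's outputs into the current one via a
    /// temperature-scaled KL divergence,
    /// `L_KD = Σ_i t_i · ln(t_i / s_i)` with
    /// `t = softmax(z_teacher / T)` and `s = softmax(z_student / T)`.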
    fn compute_lwf_regularization(
        &self,
        distillation_weight: f64,
        temperature: f64,
        data: &Array2<f64>,
    ) -> Result<f64> {
        if self.task_history.len() <= 1 {
            return Ok(0.0);
        }

        // Compute distillation loss (simplified)
        let mut distillation_loss = 0.0;

        for input in data.outer_iter() {
            let current_output = self.model.forward(&input.to_owned())?;

            // Get "teacher" output from previous model state (simplified)
            // In practice, would store the previous model or compute with masked parameters
            let teacher_output = current_output.clone(); // Placeholder

            // Compute KL divergence with temperature scaling
            let student_probs = self.softmax_with_temperature(&current_output, temperature);
            let teacher_probs = self.softmax_with_temperature(&teacher_output, temperature);

            for (s, t) in student_probs.iter().zip(teacher_probs.iter()) {
                if *t > 1e-10 {
                    distillation_loss += t * (t / s).ln();
                }
            }
        }

        Ok(distillation_weight * distillation_loss / data.nrows() as f64)
    }

    /// Compute quantum-specific regularization
    fn compute_quantum_regularization(
        &self,
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    ) -> f64 {
        let mut regularization = 0.0;

        if let Some(ref prev_params) = self.previous_parameters {
            let param_diff = &self.model.parameters - prev_params;

            // Entanglement preservation penalty: discourage changes that
            // might reduce entanglement capability (L1 on parameter shifts)
            let entanglement_penalty = param_diff.mapv(|x| x.abs()).sum();
            regularization += entanglement_preservation * entanglement_penalty;

            // Parameter drift penalty: encourage small changes (L2 on shifts)
            let drift = param_diff.mapv(|x| x.powi(2)).sum();
            regularization += parameter_drift_penalty * drift;
        }

        regularization
    }

    /// Compute Fisher information matrix for EWC
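    ///
    /// Uses the diagonal empirical Fisher approximation,
    /// `F_i ≈ (1/N) · Σ_n (∂ log p_θ(y_n | x_n) / ∂θ_i)²`,
    /// estimated from `fisher_samples` examples of the previous task.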
    fn compute_fisher_information(&mut self) -> Result<()> {
        if let ContinualLearningStrategy::ElasticWeightConsolidation { fisher_samples, .. } =
            &self.strategy
        {
            let mut fisher = Array1::zeros(self.model.parameters.len());

            // Sample data from previous tasks for Fisher computation
            if let Some(current_task_idx) = self.current_task {
                if current_task_idx > 0 {
                    // Use previous task data (simplified)
                    let prev_task = &self.task_history[current_task_idx - 1];

                    for i in 0..*fisher_samples {
                        let idx = i % prev_task.train_data.nrows();
                        let input = prev_task.train_data.row(idx).to_owned();
                        let label = prev_task.train_labels[idx];

                        // Compute gradient (simplified - would use automatic differentiation)
                        let gradient = self.compute_parameter_gradient(&input, label)?;
                        fisher = fisher + &gradient.mapv(|x| x.powi(2));
                    }

                    fisher = fisher / *fisher_samples as f64;
                }
            }

            self.fisher_information = Some(fisher);
        }

        Ok(())
    }

    /// Create progressive network column
    fn create_progressive_column(&mut self, adaptation_layers: usize) -> Result<()> {
        // Create a new small network column for the new task, with one
        // variational block per requested adaptation layer
        let mut layers = vec![QNNLayerType::EncodingLayer { num_features: 4 }];
        for _ in 0..adaptation_layers {
            layers.push(QNNLayerType::VariationalLayer { num_params: 6 });
        }

        let progressive_module = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
        self.progressive_modules.push(progressive_module);

        Ok(())
    }

    /// Allocate parameters for new task
    fn allocate_parameters_for_task(
        &mut self,
        task: &ContinualTask,
        strategy: &ParameterAllocationStrategy,
    ) -> Result<()> {
        match strategy {
            ParameterAllocationStrategy::Masking => {
                // Create mask for this task
                let mask = Array1::from_elem(self.model.parameters.len(), true);
                // In practice, would compute an optimal mask
                self.parameter_masks.insert(task.task_id.clone(), mask);
            }

            ParameterAllocationStrategy::Expansion => {
                // Expand model capacity if needed
                // This would require modifying the model architecture
            }

            _ => {}
        }

        Ok(())
    }

    /// Compute gradient memory for GEM
    fn compute_gradient_memory(&mut self, task: &ContinualTask) -> Result<()> {
        if let Some(mut buffer) = self.memory_buffer.take() {
            // Store representative examples with their gradients
            for i in 0..task.train_data.nrows().min(100) {
                let input = task.train_data.row(i).to_owned();
                let label = task.train_labels[i];

                let gradient = self.compute_parameter_gradient(&input, label)?;

                // One-hot encode the label as the stored target
                let mut target = Array1::zeros(task.num_classes);
                if label < task.num_classes {
                    target[label] = 1.0;
                }

                let experience = Experience {
                    input,
                    target,
                    task_id: task.task_id.clone(),
                    importance: 1.0,
                    gradient_info: Some(gradient),
                    uncertainty: None,
                };

                buffer.add_experience(experience);
            }

            self.memory_buffer = Some(buffer);
        }

        Ok(())
    }

    /// Update memory buffer with new experiences
    fn update_memory_buffer(&self, buffer: &mut MemoryBuffer, task: &ContinualTask) -> Result<()> {
        // Add experiences from the new task
        for i in 0..task.train_data.nrows() {
            let input = task.train_data.row(i).to_owned();
            let label = task.train_labels[i];

            // One-hot encode the label as the stored target
            let mut target = Array1::zeros(task.num_classes);
            if label < task.num_classes {
                target[label] = 1.0;
            }

            let experience = Experience {
                input,
                target,
                task_id: task.task_id.clone(),
                importance: 1.0,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        Ok(())
    }

    /// Evaluate model on a specific task
    fn evaluate_task(&self, task: &ContinualTask) -> Result<f64> {
        let mut correct = 0;
        let total = task.val_data.nrows();

        for (input, &label) in task.val_data.outer_iter().zip(task.val_labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            let predicted = output
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i)
                .unwrap_or(0);

            if predicted == label {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    /// Evaluate all previous tasks to measure forgetting
    pub fn evaluate_all_tasks(&mut self) -> Result<HashMap<String, f64>> {
        let mut accuracies = HashMap::new();

        for task in &self.task_history {
            let accuracy = self.evaluate_task(task)?;
            accuracies.insert(task.task_id.clone(), accuracy);

            // Update retained accuracy in task metrics
            if let Some(metrics) = self.task_metrics.get_mut(&task.task_id) {
                metrics.retained_accuracy = accuracy;
            }
        }

        Ok(accuracies)
    }

    /// Update forgetting metrics
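    ///
    /// Forgetting for task `k` is `max(0, a_k(orig) − a_k(now))`, i.e. the
    /// drop from the accuracy reached right after learning the task to its
    /// accuracy now; the overall measure averages this over all seen tasks.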
    fn update_forgetting_metrics(&mut self) -> Result<()> {
        if self.task_history.is_empty() {
            return Ok(());
        }

        // Evaluate all tasks
        let accuracies = self.evaluate_all_tasks()?;

        // Compute average accuracy
        let avg_accuracy = accuracies.values().sum::<f64>() / accuracies.len() as f64;
        self.forgetting_metrics.average_accuracy = avg_accuracy;

        // Compute forgetting measure
        let mut total_forgetting = 0.0;
        let mut num_comparisons = 0;

        for (task_id, metrics) in &self.task_metrics {
            let current_acc = accuracies.get(task_id).unwrap_or(&0.0);
            let original_acc = metrics.current_accuracy;

            if original_acc > 0.0 {
                let forgetting = (original_acc - current_acc).max(0.0);
                total_forgetting += forgetting;
                num_comparisons += 1;

                self.forgetting_metrics
                    .per_task_forgetting
                    .insert(task_id.clone(), forgetting);
            }
        }

        if num_comparisons > 0 {
            self.forgetting_metrics.forgetting_measure = total_forgetting / num_comparisons as f64;
        }

        // Compute continual learning score (simplified)
        self.forgetting_metrics.continual_learning_score =
            avg_accuracy - self.forgetting_metrics.forgetting_measure;

        Ok(())
    }

    /// Compute parameter gradient (simplified)
    fn compute_parameter_gradient(
        &self,
        _input: &Array1<f64>,
        _label: usize,
    ) -> Result<Array1<f64>> {
        // Placeholder for gradient computation
        // In practice, would use automatic differentiation
        Ok(Array1::zeros(self.model.parameters.len()))
    }

    /// Cross-entropy loss
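    ///
    /// `−ln p(label)`, assuming `output` already holds class probabilities;
    /// clamped at `1e-10` for numerical stability.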
    fn cross_entropy_loss(&self, output: &Array1<f64>, label: usize) -> f64 {
        let predicted_prob = output[label].max(1e-10);
        -predicted_prob.ln()
    }

    /// Softmax with temperature
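    ///
    /// `softmax_T(z)_i = exp(z_i / T) / Σ_j exp(z_j / T)`, computed with the
    /// usual max-subtraction trick for numerical stability.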
    fn softmax_with_temperature(&self, logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
        let scaled_logits = logits / temperature;
        let max_logit = scaled_logits
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_logits = scaled_logits.mapv(|x| (x - max_logit).exp());
        let sum_exp = exp_logits.sum();
        exp_logits / sum_exp
    }

    /// Get forgetting metrics
    pub fn get_forgetting_metrics(&self) -> &ForgettingMetrics {
        &self.forgetting_metrics
    }

    /// Get task metrics
    pub fn get_task_metrics(&self) -> &HashMap<String, TaskMetrics> {
        &self.task_metrics
    }

    /// Get current model
    pub fn get_model(&self) -> &QuantumNeuralNetwork {
        &self.model
    }

    /// Reset for new task sequence
    pub fn reset(&mut self) {
        self.task_history.clear();
        self.current_task = None;
        self.fisher_information = None;
        self.previous_parameters = None;
        self.progressive_modules.clear();
        self.parameter_masks.clear();
        self.task_metrics.clear();

        if let Some(ref mut buffer) = self.memory_buffer {
            buffer.clear();
        }
    }
}

impl MemoryBuffer {
    /// Create new memory buffer
    pub fn new(max_size: usize, strategy: MemorySelectionStrategy) -> Self {
        Self {
            experiences: VecDeque::new(),
            max_size,
            selection_strategy: strategy,
            task_memories: HashMap::new(),
        }
    }

    /// Add experience to buffer
    pub fn add_experience(&mut self, experience: Experience) {
        // Evict the oldest experience once the buffer is full
        if self.experiences.len() >= self.max_size {
            let removed = self
                .experiences
                .pop_front()
                .expect("Buffer is non-empty when len >= max_size");
            self.remove_from_task_index(&removed);
        }

        let experience_idx = self.experiences.len();
        self.experiences.push_back(experience.clone());

        // Update task index
        self.task_memories
            .entry(experience.task_id.clone())
            .or_insert_with(Vec::new)
            .push(experience_idx);
    }

    /// Sample experiences from buffer
    pub fn sample(&self, num_samples: usize) -> Vec<Experience> {
        let mut samples = Vec::new();

        let available = self.experiences.len().min(num_samples);

        match self.selection_strategy {
            MemorySelectionStrategy::Random => {
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }

            MemorySelectionStrategy::GradientImportance => {
                // Sort by importance (descending) and take the top experiences
                let mut indexed_experiences: Vec<_> = self.experiences.iter().enumerate().collect();

                indexed_experiences.sort_by(|a, b| {
                    let importance_a = a.1.importance;
                    let importance_b = b.1.importance;
                    importance_b
                        .partial_cmp(&importance_a)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });

                for (_, experience) in indexed_experiences.into_iter().take(available) {
                    samples.push(experience.clone());
                }
            }

            _ => {
                // Fall back to random sampling
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }
        }

        samples
    }

    /// Remove experience from task index
    fn remove_from_task_index(&mut self, experience: &Experience) {
        if let Some(indices) = self.task_memories.get_mut(&experience.task_id) {
            // Simplified: eviction invalidates stored positions, so drop the
            // whole index for this task rather than re-shifting every entry
            indices.clear();
        }
    }

    /// Clear buffer
    pub fn clear(&mut self) {
        self.experiences.clear();
        self.task_memories.clear();
    }

    /// Get buffer size
    pub fn size(&self) -> usize {
        self.experiences.len()
    }
}

/// Helper function to create a simple continual task
pub fn create_continual_task(
    task_id: String,
    task_type: TaskType,
    data: Array2<f64>,
    labels: Array1<usize>,
    train_ratio: f64,
) -> ContinualTask {
    let train_size = (data.nrows() as f64 * train_ratio) as usize;

    let train_data = data.slice(s![0..train_size, ..]).to_owned();
    let train_labels = labels.slice(s![0..train_size]).to_owned();

    let val_data = data.slice(s![train_size.., ..]).to_owned();
    let val_labels = labels.slice(s![train_size..]).to_owned();

    let num_classes = labels.iter().max().unwrap_or(&0) + 1;

    ContinualTask {
        task_id,
        task_type,
        train_data,
        train_labels,
        val_data,
        val_labels,
        num_classes,
        metadata: HashMap::new(),
    }
}

/// Helper function to generate synthetic task sequence
pub fn generate_task_sequence(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        // Generate task-specific data with some variation
        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            let task_shift = i as f64 * 0.5;
            let base_value = row as f64 / samples_per_task as f64 + col as f64 / feature_dim as f64;
            0.5 + 0.3 * (base_value * 2.0 * PI + task_shift).sin() + 0.1 * (fastrand::f64() - 0.5)
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            // Binary classification based on the sum of features
            let sum = data.row(row).sum();
            if sum > feature_dim as f64 * 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("task_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8, // 80% training, 20% validation
        );

        tasks.push(task);
    }

    tasks
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(5, MemorySelectionStrategy::Random);

        for i in 0..10 {
            let experience = Experience {
                input: Array1::from_vec(vec![i as f64]),
                target: Array1::from_vec(vec![(i % 2) as f64]),
                task_id: format!("task_{}", i / 3),
                importance: i as f64,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        assert_eq!(buffer.size(), 5);

        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);
    }

    #[test]
    fn test_continual_task_creation() {
        let data = Array2::from_shape_fn((100, 4), |(i, j)| (i as f64 + j as f64) / 50.0);
        let labels = Array1::from_shape_fn(100, |i| i % 3);

        let task = create_continual_task(
            "test_task".to_string(),
            TaskType::Classification { num_classes: 3 },
            data,
            labels,
            0.7,
        );

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.train_data.nrows(), 70);
        assert_eq!(task.val_data.nrows(), 30);
        assert_eq!(task.num_classes, 3);
    }

    #[test]
    fn test_continual_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).expect("Failed to create model");

        let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let learner = QuantumContinualLearner::new(model, strategy);

        assert_eq!(learner.task_history.len(), 0);
        assert!(learner.current_task.is_none());
    }

    #[test]
    fn test_task_sequence_generation() {
        let tasks = generate_task_sequence(3, 50, 4);

        assert_eq!(tasks.len(), 3);

        for (i, task) in tasks.iter().enumerate() {
            assert_eq!(task.task_id, format!("task_{}", i));
            assert_eq!(task.train_data.nrows(), 40); // 80% of 50
            assert_eq!(task.val_data.nrows(), 10); // 20% of 50
            assert_eq!(task.train_data.ncols(), 4);
        }
    }

    #[test]
    fn test_forgetting_metrics() {
        let metrics = ForgettingMetrics {
            average_accuracy: 0.75,
            forgetting_measure: 0.15,
            backward_transfer: 0.05,
            forward_transfer: 0.1,
            continual_learning_score: 0.6,
            per_task_forgetting: HashMap::new(),
        };

        assert_eq!(metrics.average_accuracy, 0.75);
        assert_eq!(metrics.forgetting_measure, 0.15);
        assert!(metrics.continual_learning_score > 0.5);
    }
}