quantrs2_ml/continual_learning.rs

//! Quantum Continual Learning
//!
//! This module implements continual learning algorithms for quantum neural networks,
//! enabling models to learn new tasks sequentially while preserving knowledge from
//! previous tasks and avoiding catastrophic forgetting.
//!
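//! A minimal usage sketch (hedged: the `Adam` constructor below is an
//! assumption about this crate's optimizer API, and the layer sizes are
//! illustrative, mirroring the tests at the bottom of this file):
//!
//! ```ignore
//! use quantrs2_ml::autodiff::optimizers::Adam;
//!
//! let layers = vec![
//!     QNNLayerType::EncodingLayer { num_features: 4 },
//!     QNNLayerType::VariationalLayer { num_params: 8 },
//! ];
//! let model = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
//! let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
//!     importance_weight: 1000.0,
//!     fisher_samples: 100,
//! };
//! let mut learner = QuantumContinualLearner::new(model, strategy);
//!
//! let mut optimizer = Adam::new(0.01); // constructor assumed
//! for task in generate_task_sequence(3, 50, 4) {
//!     learner.learn_task(task, &mut optimizer, 20)?;
//! }
//! println!("avg accuracy: {:.3}", learner.get_forgetting_metrics().average_accuracy);
//! ```
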
use crate::autodiff::optimizers::Optimizer;
use crate::error::Result;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::{HashMap, VecDeque};
use std::f64::consts::PI;

/// Continual learning strategies for quantum models
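///
/// For instance, an experience-replay configuration might look like this
/// (values are illustrative, not tuned defaults):
///
/// ```ignore
/// let strategy = ContinualLearningStrategy::ExperienceReplay {
///     buffer_size: 500,
///     replay_ratio: 0.3,
///     memory_selection: MemorySelectionStrategy::Random,
/// };
/// ```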
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation (EWC) for quantum circuits
    ElasticWeightConsolidation {
        importance_weight: f64,
        fisher_samples: usize,
    },

    /// Progressive Neural Networks with quantum modules
    ProgressiveNetworks {
        lateral_connections: bool,
        adaptation_layers: usize,
    },

    /// Memory replay with episodic buffer
    ExperienceReplay {
        buffer_size: usize,
        replay_ratio: f64,
        memory_selection: MemorySelectionStrategy,
    },

    /// Parameter isolation and expansion
    ParameterIsolation {
        allocation_strategy: ParameterAllocationStrategy,
        growth_threshold: f64,
    },

    /// Gradient episodic memory
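    ///
    /// Constrains new-task updates to keep non-negative inner products with
    /// gradients stored from earlier tasks (Lopez-Paz & Ranzato, 2017):
    /// `⟨g, g_k⟩ ≥ 0` for every stored task gradient `g_k`.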
    GradientEpisodicMemory {
        memory_strength: f64,
        violation_threshold: f64,
    },

    /// Learning without forgetting (LwF)
    LearningWithoutForgetting {
        distillation_weight: f64,
        temperature: f64,
    },

    /// Quantum-specific regularization
    QuantumRegularization {
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    },
}

/// Memory selection strategies for experience replay
#[derive(Debug, Clone)]
pub enum MemorySelectionStrategy {
    /// Random sampling
    Random,
    /// Gradient-based importance
    GradientImportance,
    /// Uncertainty-based selection
    Uncertainty,
    /// Diverse sampling
    Diversity,
    /// Quantum-specific metrics
    QuantumMetrics,
}

/// Parameter allocation strategies
#[derive(Debug, Clone)]
pub enum ParameterAllocationStrategy {
    /// Add new parameters for new tasks
    Expansion,
    /// Mask existing parameters for different tasks
    Masking,
    /// Hierarchical parameter sharing
    Hierarchical,
    /// Quantum-specific allocation
    QuantumAware,
}

/// Task definition for continual learning
#[derive(Debug, Clone)]
pub struct ContinualTask {
    /// Task identifier
    pub task_id: String,

    /// Task type/domain
    pub task_type: TaskType,

    /// Training data
    pub train_data: Array2<f64>,

    /// Training labels
    pub train_labels: Array1<usize>,

    /// Validation data
    pub val_data: Array2<f64>,

    /// Validation labels
    pub val_labels: Array1<usize>,

    /// Number of classes
    pub num_classes: usize,

    /// Task-specific metadata
    pub metadata: HashMap<String, f64>,
}

/// Task types for continual learning
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    /// Classification task
    Classification { num_classes: usize },
    /// Regression task
    Regression { output_dim: usize },
    /// Quantum state preparation
    StatePreparation { target_states: usize },
    /// Quantum optimization
    Optimization { problem_type: String },
}

/// Memory buffer for experience replay
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    /// Stored experiences
    experiences: VecDeque<Experience>,

    /// Maximum buffer size
    max_size: usize,

    /// Selection strategy
    selection_strategy: MemorySelectionStrategy,

    /// Task-wise organization
    task_memories: HashMap<String, Vec<usize>>,
}

/// Individual experience/memory
#[derive(Debug, Clone)]
pub struct Experience {
    /// Input data
    pub input: Array1<f64>,

    /// Target output
    pub target: Array1<f64>,

    /// Task identifier
    pub task_id: String,

    /// Importance score
    pub importance: f64,

    /// Gradient information (optional)
    pub gradient_info: Option<Array1<f64>>,

    /// Uncertainty measure
    pub uncertainty: Option<f64>,
}

/// Quantum continual learner
pub struct QuantumContinualLearner {
    /// Base quantum model
    model: QuantumNeuralNetwork,

    /// Continual learning strategy
    strategy: ContinualLearningStrategy,

    /// Task sequence and history
    task_history: Vec<ContinualTask>,

    /// Current task index
    current_task: Option<usize>,

    /// Memory buffer
    memory_buffer: Option<MemoryBuffer>,

    /// Fisher information (for EWC)
    fisher_information: Option<Array1<f64>>,

    /// Previous task parameters (for EWC)
    previous_parameters: Option<Array1<f64>>,

    /// Progressive modules (for Progressive Networks)
    progressive_modules: Vec<QuantumNeuralNetwork>,

    /// Parameter masks (for Parameter Isolation)
    parameter_masks: HashMap<String, Array1<bool>>,

    /// Performance metrics per task
    task_metrics: HashMap<String, TaskMetrics>,

    /// Forgetting metrics
    forgetting_metrics: ForgettingMetrics,
}

/// Metrics for individual tasks
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    /// Accuracy on current task
    pub current_accuracy: f64,

    /// Accuracy after learning subsequent tasks
    pub retained_accuracy: f64,

    /// Learning speed (epochs to convergence)
    pub learning_speed: usize,

    /// Backward transfer (improvement from future tasks)
    pub backward_transfer: f64,

    /// Forward transfer (help to future tasks)
    pub forward_transfer: f64,
}

/// Overall forgetting and transfer metrics
#[derive(Debug, Clone)]
pub struct ForgettingMetrics {
    /// Average accuracy across all seen tasks
    pub average_accuracy: f64,

    /// Catastrophic forgetting measure
    pub forgetting_measure: f64,

    /// Backward transfer coefficient
    pub backward_transfer: f64,

    /// Forward transfer coefficient
    pub forward_transfer: f64,

    /// Overall continual learning score
    pub continual_learning_score: f64,

    /// Per-task forgetting
    pub per_task_forgetting: HashMap<String, f64>,
}

impl QuantumContinualLearner {
    /// Create a new quantum continual learner
    pub fn new(model: QuantumNeuralNetwork, strategy: ContinualLearningStrategy) -> Self {
        let memory_buffer = match &strategy {
            ContinualLearningStrategy::ExperienceReplay {
                buffer_size,
                memory_selection,
                ..
            } => Some(MemoryBuffer::new(*buffer_size, memory_selection.clone())),
            ContinualLearningStrategy::GradientEpisodicMemory { .. } => Some(MemoryBuffer::new(
                1000,
                MemorySelectionStrategy::GradientImportance,
            )),
            _ => None,
        };

        Self {
            model,
            strategy,
            task_history: Vec::new(),
            current_task: None,
            memory_buffer,
            fisher_information: None,
            previous_parameters: None,
            progressive_modules: Vec::new(),
            parameter_masks: HashMap::new(),
            task_metrics: HashMap::new(),
            forgetting_metrics: ForgettingMetrics {
                average_accuracy: 0.0,
                forgetting_measure: 0.0,
                backward_transfer: 0.0,
                forward_transfer: 0.0,
                continual_learning_score: 0.0,
                per_task_forgetting: HashMap::new(),
            },
        }
    }

    /// Learn a new task
    pub fn learn_task(
        &mut self,
        task: ContinualTask,
        optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<TaskMetrics> {
        println!("Learning task: {}", task.task_id);

        // Store task in history
        self.task_history.push(task.clone());
        self.current_task = Some(self.task_history.len() - 1);

        // Apply continual learning strategy before training
        self.apply_pre_training_strategy(&task)?;

        // Train on the new task
        let start_time = std::time::Instant::now();
        let losses = self.train_on_task(&task, optimizer, epochs)?;
        let learning_time = start_time.elapsed();

        // Apply post-training strategy
        self.apply_post_training_strategy(&task)?;

        // Evaluate on current task
        let current_accuracy = self.evaluate_task(&task)?;

        // Update memory buffer if applicable (take/put-back so the buffer
        // can be mutated while `&self` helpers run)
        if let Some(mut buffer) = self.memory_buffer.take() {
            self.update_memory_buffer(&mut buffer, &task)?;
            self.memory_buffer = Some(buffer);
        }

        // Compute task metrics
        let task_metrics = TaskMetrics {
            current_accuracy,
            retained_accuracy: current_accuracy, // Will be updated later
            learning_speed: epochs,              // Simplified - could track convergence
            backward_transfer: 0.0,              // Will be computed later
            forward_transfer: 0.0,               // Will be computed when future tasks are learned
        };

        self.task_metrics
            .insert(task.task_id.clone(), task_metrics.clone());

        // Update overall metrics
        self.update_forgetting_metrics()?;

        println!(
            "Task {} learned with accuracy {:.3} (final loss {:.4}, {:.2?})",
            task.task_id,
            current_accuracy,
            losses.last().copied().unwrap_or(f64::NAN),
            learning_time
        );

        Ok(task_metrics)
    }

    /// Train on a specific task
    fn train_on_task(
        &mut self,
        task: &ContinualTask,
        _optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<Vec<f64>> {
        let mut losses = Vec::new();
        let batch_size = 32;

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let num_batches = (task.train_data.nrows() + batch_size - 1) / batch_size;

            for batch_idx in 0..num_batches {
                let batch_start = batch_idx * batch_size;
                let batch_end = (batch_start + batch_size).min(task.train_data.nrows());

                let batch_data = task
                    .train_data
                    .slice(s![batch_start..batch_end, ..])
                    .to_owned();
                let batch_labels = task
                    .train_labels
                    .slice(s![batch_start..batch_end])
                    .to_owned();

                // Create combined training batch with replay if applicable
                let (final_data, final_labels) =
                    self.create_training_batch(&batch_data, &batch_labels)?;

                // Compute loss with continual learning regularization
                let batch_loss = self.compute_continual_loss(&final_data, &final_labels)?;
                epoch_loss += batch_loss;

                // Update model parameters (simplified): a full implementation
                // would backpropagate the loss and apply `_optimizer` here
            }

            epoch_loss /= num_batches as f64;
            losses.push(epoch_loss);

            if epoch % 10 == 0 {
                println!("  Epoch {}: Loss = {:.4}", epoch, epoch_loss);
            }
        }

        Ok(losses)
    }

    /// Apply pre-training strategy
    fn apply_pre_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        let strategy = self.strategy.clone();
        match strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation { .. } => {
                if !self.task_history.is_empty() {
                    // Store current parameters and compute Fisher information
                    self.previous_parameters = Some(self.model.parameters.clone());
                    self.compute_fisher_information()?;
                }
            }

            ContinualLearningStrategy::ProgressiveNetworks {
                adaptation_layers, ..
            } => {
                // Create new column for the new task
                self.create_progressive_column(adaptation_layers)?;
            }

            ContinualLearningStrategy::ParameterIsolation {
                allocation_strategy,
                ..
            } => {
                // Allocate parameters for the new task
                self.allocate_parameters_for_task(task, &allocation_strategy)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Apply post-training strategy
    fn apply_post_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { .. } => {
                // Memory buffer already updated during training
            }

            ContinualLearningStrategy::GradientEpisodicMemory { .. } => {
                // Compute and store gradient information
                self.compute_gradient_memory(task)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Create training batch, mixing in replayed experiences if applicable
    fn create_training_batch(
        &self,
        current_data: &Array2<f64>,
        current_labels: &Array1<usize>,
    ) -> Result<(Array2<f64>, Array1<usize>)> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { replay_ratio, .. } => {
                if let Some(ref buffer) = self.memory_buffer {
                    let num_replay = (current_data.nrows() as f64 * replay_ratio) as usize;
                    let replay_experiences = buffer.sample(num_replay);

                    // Combine current and replay data by stacking replay rows
                    // below the current batch (row-major flattening)
                    let num_cols = current_data.ncols();
                    let mut data_vec: Vec<f64> = current_data.iter().cloned().collect();
                    let mut label_vec: Vec<usize> = current_labels.to_vec();

                    for experience in replay_experiences {
                        // Skip stored rows whose feature width does not match
                        if experience.input.len() != num_cols {
                            continue;
                        }
                        // Recover the class label as the arg-max of the
                        // one-hot target (see `update_memory_buffer`)
                        let label = experience
                            .target
                            .iter()
                            .enumerate()
                            .max_by(|a, b| {
                                a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(i, _)| i)
                            .unwrap_or(0);
                        data_vec.extend(experience.input.iter().cloned());
                        label_vec.push(label);
                    }

                    let num_rows = label_vec.len();
                    let combined_data = Array2::from_shape_vec((num_rows, num_cols), data_vec)
                        .expect("replay batch rows have consistent width");
                    let combined_labels = Array1::from_vec(label_vec);

                    Ok((combined_data, combined_labels))
                } else {
                    Ok((current_data.clone(), current_labels.clone()))
                }
            }
            _ => Ok((current_data.clone(), current_labels.clone())),
        }
    }

    /// Compute continual learning loss with regularization
    fn compute_continual_loss(&self, data: &Array2<f64>, labels: &Array1<usize>) -> Result<f64> {
        // Base loss (simplified)
        let mut total_loss = 0.0;

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            total_loss += self.cross_entropy_loss(&output, label);
        }

        let base_loss = total_loss / data.nrows() as f64;

        // Add continual learning regularization
        let regularization = match &self.strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => self.compute_ewc_regularization(*importance_weight),

            ContinualLearningStrategy::LearningWithoutForgetting {
                distillation_weight,
                temperature,
            } => self.compute_lwf_regularization(*distillation_weight, *temperature, data)?,

            ContinualLearningStrategy::QuantumRegularization {
                entanglement_preservation,
                parameter_drift_penalty,
            } => self.compute_quantum_regularization(
                *entanglement_preservation,
                *parameter_drift_penalty,
            ),

            _ => 0.0,
        };

        Ok(base_loss + regularization)
    }

    /// Compute EWC regularization term
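    ///
    /// A sketch of the standard diagonal-Fisher penalty (Kirkpatrick et al., 2017):
    ///
    /// ```text
    /// R_EWC(θ) = (λ / 2) · Σ_i F_i · (θ_i − θ*_i)²
    /// ```
    ///
    /// where `λ` is `importance_weight`, `F` is the diagonal Fisher information
    /// from `compute_fisher_information`, and `θ*` are the parameters stored
    /// before training on the current task.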
    fn compute_ewc_regularization(&self, importance_weight: f64) -> f64 {
        if let (Some(fisher), Some(prev_params)) = (
            self.fisher_information.as_ref(),
            self.previous_parameters.as_ref(),
        ) {
            let param_diff = &self.model.parameters - prev_params;
            let ewc_term = fisher * &param_diff.mapv(|x| x.powi(2));
            importance_weight * ewc_term.sum() / 2.0
        } else {
            0.0
        }
    }

    /// Compute Learning without Forgetting regularization
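    ///
    /// The distillation term follows the usual temperature-scaled form
    /// (Li & Hoiem, 2017):
    ///
    /// ```text
    /// R_LwF = α · KL(p_teacher^T ‖ p_student^T),   p^T = softmax(z / T)
    /// ```
    ///
    /// with `α = distillation_weight` and `T = temperature`. The teacher
    /// output below is a placeholder; a full implementation would keep a
    /// frozen copy of the pre-task model.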
    fn compute_lwf_regularization(
        &self,
        distillation_weight: f64,
        temperature: f64,
        data: &Array2<f64>,
    ) -> Result<f64> {
        if self.task_history.len() <= 1 {
            return Ok(0.0);
        }

        // Compute distillation loss (simplified)
        let mut distillation_loss = 0.0;

        for input in data.outer_iter() {
            let current_output = self.model.forward(&input.to_owned())?;

            // Get "teacher" output from previous model state (simplified);
            // in practice, a frozen copy of the pre-task model would be stored
            let teacher_output = current_output.clone(); // Placeholder

            // Compute KL divergence with temperature scaling
            let student_probs = self.softmax_with_temperature(&current_output, temperature);
            let teacher_probs = self.softmax_with_temperature(&teacher_output, temperature);

            for (s, t) in student_probs.iter().zip(teacher_probs.iter()) {
                if *t > 1e-10 {
                    // Guard the student probability to avoid division by zero
                    distillation_loss += t * (t / s.max(1e-10)).ln();
                }
            }
        }

        Ok(distillation_weight * distillation_loss / data.nrows() as f64)
    }

    /// Compute quantum-specific regularization
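    ///
    /// Combines an L1 drift term (a heuristic proxy for preserving
    /// entangling structure) with an L2 drift term:
    ///
    /// ```text
    /// R_Q = λ_ent · Σ_i |θ_i − θ*_i|  +  λ_drift · Σ_i (θ_i − θ*_i)²
    /// ```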
    fn compute_quantum_regularization(
        &self,
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    ) -> f64 {
        let mut regularization = 0.0;

        if let Some(ref prev_params) = self.previous_parameters {
            let param_diff = &self.model.parameters - prev_params;

            // Entanglement preservation penalty: penalize changes that
            // might reduce entanglement capability (L1 drift)
            let entanglement_penalty = param_diff.mapv(|x| x.abs()).sum();
            regularization += entanglement_preservation * entanglement_penalty;

            // Parameter drift penalty (encourage small changes, L2 drift)
            let drift = param_diff.mapv(|x| x.powi(2)).sum();
            regularization += parameter_drift_penalty * drift;
        }

        regularization
    }

    /// Compute Fisher information matrix for EWC
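    ///
    /// Uses the empirical diagonal approximation
    ///
    /// ```text
    /// F_i ≈ (1/N) · Σ_n (∂ log p(y_n | x_n; θ) / ∂θ_i)²
    /// ```
    ///
    /// accumulated over `fisher_samples` examples from the previous task.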
    fn compute_fisher_information(&mut self) -> Result<()> {
        if let ContinualLearningStrategy::ElasticWeightConsolidation { fisher_samples, .. } =
            &self.strategy
        {
            let mut fisher = Array1::zeros(self.model.parameters.len());

            // Sample data from previous tasks for Fisher computation
            if let Some(current_task_idx) = self.current_task {
                if current_task_idx > 0 {
                    // Use previous task data (simplified)
                    let prev_task = &self.task_history[current_task_idx - 1];

                    for i in 0..*fisher_samples {
                        let idx = i % prev_task.train_data.nrows();
                        let input = prev_task.train_data.row(idx).to_owned();
                        let label = prev_task.train_labels[idx];

                        // Compute gradient (simplified - would use automatic differentiation)
                        let gradient = self.compute_parameter_gradient(&input, label)?;
                        fisher = fisher + &gradient.mapv(|x| x.powi(2));
                    }

                    fisher = fisher / *fisher_samples as f64;
                }
            }

            self.fisher_information = Some(fisher);
        }

        Ok(())
    }

    /// Create progressive network column
    fn create_progressive_column(&mut self, _adaptation_layers: usize) -> Result<()> {
        // Create a new small network column for the new task (the requested
        // adaptation depth is not yet used; this simplified column is fixed)
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let progressive_module = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
        self.progressive_modules.push(progressive_module);

        Ok(())
    }

    /// Allocate parameters for new task
    fn allocate_parameters_for_task(
        &mut self,
        task: &ContinualTask,
        strategy: &ParameterAllocationStrategy,
    ) -> Result<()> {
        match strategy {
            ParameterAllocationStrategy::Masking => {
                // Create mask for this task (all-true placeholder;
                // in practice, would compute an optimal mask)
                let mask = Array1::from_elem(self.model.parameters.len(), true);
                self.parameter_masks.insert(task.task_id.clone(), mask);
            }

            ParameterAllocationStrategy::Expansion => {
                // Expand model capacity if needed
                // This would require modifying the model architecture
            }

            _ => {}
        }

        Ok(())
    }

    /// Compute gradient memory for GEM
    fn compute_gradient_memory(&mut self, task: &ContinualTask) -> Result<()> {
        if let Some(mut buffer) = self.memory_buffer.take() {
            // Store representative examples with their gradients
            for i in 0..task.train_data.nrows().min(100) {
                let input = task.train_data.row(i).to_owned();
                let label = task.train_labels[i];

                let gradient = self.compute_parameter_gradient(&input, label)?;

                // One-hot encode the label (consistent with `update_memory_buffer`)
                let mut target = Array1::from_elem(task.num_classes, 0.0);
                if label < task.num_classes {
                    target[label] = 1.0;
                }

                let experience = Experience {
                    input,
                    target,
                    task_id: task.task_id.clone(),
                    importance: 1.0,
                    gradient_info: Some(gradient),
                    uncertainty: None,
                };

                buffer.add_experience(experience);
            }

            self.memory_buffer = Some(buffer);
        }

        Ok(())
    }

    /// Update memory buffer with new experiences
    fn update_memory_buffer(&self, buffer: &mut MemoryBuffer, task: &ContinualTask) -> Result<()> {
        // Add experiences from the new task
        for i in 0..task.train_data.nrows() {
            let input = task.train_data.row(i).to_owned();

            // One-hot encode the label so replay can recover it via arg-max
            let label = task.train_labels[i];
            let mut target = Array1::from_elem(task.num_classes, 0.0);
            if label < task.num_classes {
                target[label] = 1.0;
            }

            let experience = Experience {
                input,
                target,
                task_id: task.task_id.clone(),
                importance: 1.0,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        Ok(())
    }

    /// Evaluate model on a specific task
    fn evaluate_task(&self, task: &ContinualTask) -> Result<f64> {
        let mut correct = 0;
        let total = task.val_data.nrows();
        if total == 0 {
            return Ok(0.0);
        }

        for (input, &label) in task.val_data.outer_iter().zip(task.val_labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            let predicted = output
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i)
                .unwrap_or(0);

            if predicted == label {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    /// Evaluate all previous tasks to measure forgetting
    pub fn evaluate_all_tasks(&mut self) -> Result<HashMap<String, f64>> {
        let mut accuracies = HashMap::new();

        for task in &self.task_history {
            let accuracy = self.evaluate_task(task)?;
            accuracies.insert(task.task_id.clone(), accuracy);

            // Update retained accuracy in task metrics
            if let Some(metrics) = self.task_metrics.get_mut(&task.task_id) {
                metrics.retained_accuracy = accuracy;
            }
        }

        Ok(accuracies)
    }

    /// Update forgetting metrics
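    ///
    /// Forgetting for task `k` is measured as
    ///
    /// ```text
    /// f_k = max(0, acc_k(after learning k) − acc_k(now))
    /// ```
    ///
    /// and `forgetting_measure` is the mean of `f_k` over all seen tasks;
    /// the continual learning score is `average_accuracy − forgetting_measure`.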
    fn update_forgetting_metrics(&mut self) -> Result<()> {
        if self.task_history.is_empty() {
            return Ok(());
        }

        // Evaluate all tasks
        let accuracies = self.evaluate_all_tasks()?;

        // Compute average accuracy
        let avg_accuracy = accuracies.values().sum::<f64>() / accuracies.len() as f64;
        self.forgetting_metrics.average_accuracy = avg_accuracy;

        // Compute forgetting measure
        let mut total_forgetting = 0.0;
        let mut num_comparisons = 0;

        for (task_id, metrics) in &self.task_metrics {
            let current_acc = accuracies.get(task_id).unwrap_or(&0.0);
            let original_acc = metrics.current_accuracy;

            if original_acc > 0.0 {
                let forgetting = (original_acc - current_acc).max(0.0);
                total_forgetting += forgetting;
                num_comparisons += 1;

                self.forgetting_metrics
                    .per_task_forgetting
                    .insert(task_id.clone(), forgetting);
            }
        }

        if num_comparisons > 0 {
            self.forgetting_metrics.forgetting_measure = total_forgetting / num_comparisons as f64;
        }

        // Compute continual learning score (simplified)
        self.forgetting_metrics.continual_learning_score =
            avg_accuracy - self.forgetting_metrics.forgetting_measure;

        Ok(())
    }

    /// Compute parameter gradient (simplified)
    fn compute_parameter_gradient(&self, _input: &Array1<f64>, _label: usize) -> Result<Array1<f64>> {
        // Placeholder for gradient computation
        // In practice, would use automatic differentiation
        Ok(Array1::zeros(self.model.parameters.len()))
    }

    /// Cross-entropy loss
    fn cross_entropy_loss(&self, output: &Array1<f64>, label: usize) -> f64 {
        // Guard against out-of-range labels and log(0)
        let predicted_prob = output.get(label).copied().unwrap_or(0.0).max(1e-10);
        -predicted_prob.ln()
    }

    /// Softmax with temperature
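    ///
    /// ```text
    /// softmax_T(z)_i = exp(z_i / T) / Σ_j exp(z_j / T)
    /// ```
    ///
    /// computed with max-subtraction for numerical stability.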
    fn softmax_with_temperature(&self, logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
        // Guard against zero or negative temperature
        let scaled_logits = logits / temperature.max(1e-10);
        let max_logit = scaled_logits
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_logits = scaled_logits.mapv(|x| (x - max_logit).exp());
        let sum_exp = exp_logits.sum();
        exp_logits / sum_exp
    }

    /// Get forgetting metrics
    pub fn get_forgetting_metrics(&self) -> &ForgettingMetrics {
        &self.forgetting_metrics
    }

    /// Get task metrics
    pub fn get_task_metrics(&self) -> &HashMap<String, TaskMetrics> {
        &self.task_metrics
    }

    /// Get current model
    pub fn get_model(&self) -> &QuantumNeuralNetwork {
        &self.model
    }

    /// Reset for new task sequence
    pub fn reset(&mut self) {
        self.task_history.clear();
        self.current_task = None;
        self.fisher_information = None;
        self.previous_parameters = None;
        self.progressive_modules.clear();
        self.parameter_masks.clear();
        self.task_metrics.clear();

        if let Some(ref mut buffer) = self.memory_buffer {
            buffer.clear();
        }
    }
}

impl MemoryBuffer {
    /// Create new memory buffer
    pub fn new(max_size: usize, strategy: MemorySelectionStrategy) -> Self {
        Self {
            experiences: VecDeque::new(),
            max_size,
            selection_strategy: strategy,
            task_memories: HashMap::new(),
        }
    }

    /// Add experience to buffer
    pub fn add_experience(&mut self, experience: Experience) {
        // Evict the oldest experience once the buffer is full
        if self.experiences.len() >= self.max_size {
            if let Some(removed) = self.experiences.pop_front() {
                self.remove_from_task_index(&removed);
            }
        }

        let experience_idx = self.experiences.len();
        self.experiences.push_back(experience.clone());

        // Update task index
        self.task_memories
            .entry(experience.task_id.clone())
            .or_insert_with(Vec::new)
            .push(experience_idx);
    }

    /// Sample experiences from buffer
    pub fn sample(&self, num_samples: usize) -> Vec<Experience> {
        let mut samples = Vec::new();

        if self.experiences.is_empty() {
            return samples;
        }

        let available = self.experiences.len().min(num_samples);

        match self.selection_strategy {
            MemorySelectionStrategy::Random => {
                // Random sampling (with replacement)
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }

            MemorySelectionStrategy::GradientImportance => {
                // Sort by importance (descending) and take the top experiences
                let mut indexed_experiences: Vec<_> = self.experiences.iter().enumerate().collect();

                indexed_experiences.sort_by(|a, b| {
                    b.1.importance
                        .partial_cmp(&a.1.importance)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });

                for (_, experience) in indexed_experiences.into_iter().take(available) {
                    samples.push(experience.clone());
                }
            }

            _ => {
                // Fallback to random sampling
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }
        }

        samples
    }

    /// Remove experience from task index
    fn remove_from_task_index(&mut self, experience: &Experience) {
        if let Some(indices) = self.task_memories.get_mut(&experience.task_id) {
            // Simplified: eviction invalidates stored positions, so the
            // task's index list is cleared rather than re-numbered
            indices.clear();
        }
    }

    /// Clear buffer
    pub fn clear(&mut self) {
        self.experiences.clear();
        self.task_memories.clear();
    }

    /// Get buffer size
    pub fn size(&self) -> usize {
        self.experiences.len()
    }
}

/// Helper function to create a simple continual task
pub fn create_continual_task(
    task_id: String,
    task_type: TaskType,
    data: Array2<f64>,
    labels: Array1<usize>,
    train_ratio: f64,
) -> ContinualTask {
    let train_size = (data.nrows() as f64 * train_ratio) as usize;

    let train_data = data.slice(s![0..train_size, ..]).to_owned();
    let train_labels = labels.slice(s![0..train_size]).to_owned();

    let val_data = data.slice(s![train_size.., ..]).to_owned();
    let val_labels = labels.slice(s![train_size..]).to_owned();

    let num_classes = labels.iter().max().unwrap_or(&0) + 1;

    ContinualTask {
        task_id,
        task_type,
        train_data,
        train_labels,
        val_data,
        val_labels,
        num_classes,
        metadata: HashMap::new(),
    }
}

/// Helper function to generate synthetic task sequence
pub fn generate_task_sequence(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        // Generate task-specific data with some variation
        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            let task_shift = i as f64 * 0.5;
            let base_value = row as f64 / samples_per_task as f64 + col as f64 / feature_dim as f64;
            0.5 + 0.3 * (base_value * 2.0 * PI + task_shift).sin() + 0.1 * (fastrand::f64() - 0.5)
        });

        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            // Binary classification based on sum of features
            let sum = data.row(row).sum();
            if sum > feature_dim as f64 * 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("task_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8, // 80% training, 20% validation
        );

        tasks.push(task);
    }

    tasks
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(5, MemorySelectionStrategy::Random);

        for i in 0..10 {
            let experience = Experience {
                input: Array1::from_vec(vec![i as f64]),
                target: Array1::from_vec(vec![(i % 2) as f64]),
                task_id: format!("task_{}", i / 3),
                importance: i as f64,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        assert_eq!(buffer.size(), 5);

        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);
    }

    #[test]
    fn test_continual_task_creation() {
        let data = Array2::from_shape_fn((100, 4), |(i, j)| (i as f64 + j as f64) / 50.0);
        let labels = Array1::from_shape_fn(100, |i| i % 3);

        let task = create_continual_task(
            "test_task".to_string(),
            TaskType::Classification { num_classes: 3 },
            data,
            labels,
            0.7,
        );

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.train_data.nrows(), 70);
        assert_eq!(task.val_data.nrows(), 30);
        assert_eq!(task.num_classes, 3);
    }

    #[test]
    fn test_continual_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let learner = QuantumContinualLearner::new(model, strategy);

        assert_eq!(learner.task_history.len(), 0);
        assert!(learner.current_task.is_none());
    }

    #[test]
    fn test_task_sequence_generation() {
        let tasks = generate_task_sequence(3, 50, 4);

        assert_eq!(tasks.len(), 3);

        for (i, task) in tasks.iter().enumerate() {
            assert_eq!(task.task_id, format!("task_{}", i));
            assert_eq!(task.train_data.nrows(), 40); // 80% of 50
            assert_eq!(task.val_data.nrows(), 10); // 20% of 50
            assert_eq!(task.train_data.ncols(), 4);
        }
    }
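
    // A small sanity check for the private temperature-scaled softmax
    // helper, with a learner built the same way as in
    // `test_continual_learner_creation` (layer sizes are illustrative)
    #[test]
    fn test_softmax_with_temperature() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let learner = QuantumContinualLearner::new(
            model,
            ContinualLearningStrategy::LearningWithoutForgetting {
                distillation_weight: 1.0,
                temperature: 2.0,
            },
        );

        let logits = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let probs = learner.softmax_with_temperature(&logits, 2.0);

        // Probabilities sum to one and preserve the ordering of the logits
        assert!((probs.sum() - 1.0).abs() < 1e-9);
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }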

    #[test]
    fn test_forgetting_metrics() {
        let metrics = ForgettingMetrics {
            average_accuracy: 0.75,
            forgetting_measure: 0.15,
            backward_transfer: 0.05,
            forward_transfer: 0.1,
            continual_learning_score: 0.6,
            per_task_forgetting: HashMap::new(),
        };

        assert_eq!(metrics.average_accuracy, 0.75);
        assert_eq!(metrics.forgetting_measure, 0.15);
        assert!(metrics.continual_learning_score > 0.5);
    }
}