use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use std::collections::{HashMap, HashSet, VecDeque};
use std::f64::consts::PI;

/// Strategy used to mitigate catastrophic forgetting across tasks
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Penalize changes to parameters that were important for earlier tasks
    ElasticWeightConsolidation {
        importance_weight: f64,
        fisher_samples: usize,
    },

    /// Grow a new network column per task, optionally with lateral connections
    ProgressiveNetworks {
        lateral_connections: bool,
        adaptation_layers: usize,
    },

    /// Replay stored experiences from earlier tasks during training
    ExperienceReplay {
        buffer_size: usize,
        replay_ratio: f64,
        memory_selection: MemorySelectionStrategy,
    },

    /// Reserve separate parameter subsets for different tasks
    ParameterIsolation {
        allocation_strategy: ParameterAllocationStrategy,
        growth_threshold: f64,
    },

    /// Constrain updates so loss on stored episodic memories does not increase
    GradientEpisodicMemory {
        memory_strength: f64,
        violation_threshold: f64,
    },

    /// Distill the previous model's predictions while learning the new task
    LearningWithoutForgetting {
        distillation_weight: f64,
        temperature: f64,
    },

    /// Quantum-specific penalties on entanglement change and parameter drift
    QuantumRegularization {
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    },
}

/// How experiences are chosen for storage and replay
#[derive(Debug, Clone)]
pub enum MemorySelectionStrategy {
    Random,
    GradientImportance,
    Uncertainty,
    Diversity,
    QuantumMetrics,
}

/// How parameters are assigned to tasks under parameter isolation
#[derive(Debug, Clone)]
pub enum ParameterAllocationStrategy {
    Expansion,
    Masking,
    Hierarchical,
    QuantumAware,
}

/// A single task in a continual learning sequence
#[derive(Debug, Clone)]
pub struct ContinualTask {
    pub task_id: String,
    pub task_type: TaskType,
    pub train_data: Array2<f64>,
    pub train_labels: Array1<usize>,
    pub val_data: Array2<f64>,
    pub val_labels: Array1<usize>,
    pub num_classes: usize,
    pub metadata: HashMap<String, f64>,
}

/// The kind of problem a task represents
#[derive(Debug, Clone, PartialEq)]
pub enum TaskType {
    Classification { num_classes: usize },
    Regression { output_dim: usize },
    StatePreparation { target_states: usize },
    Optimization { problem_type: String },
}

/// Replay buffer holding experiences from previous tasks
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    experiences: VecDeque<Experience>,
    max_size: usize,
    selection_strategy: MemorySelectionStrategy,
    /// Indices into `experiences`, grouped by task
    task_memories: HashMap<String, Vec<usize>>,
}

/// A single stored training example with bookkeeping metadata
#[derive(Debug, Clone)]
pub struct Experience {
    pub input: Array1<f64>,
    pub target: Array1<f64>,
    pub task_id: String,
    pub importance: f64,
    pub gradient_info: Option<Array1<f64>>,
    pub uncertainty: Option<f64>,
}

/// Quantum neural network wrapper that learns a sequence of tasks
/// while mitigating catastrophic forgetting
pub struct QuantumContinualLearner {
    model: QuantumNeuralNetwork,
    strategy: ContinualLearningStrategy,
    task_history: Vec<ContinualTask>,
    current_task: Option<usize>,
    memory_buffer: Option<MemoryBuffer>,
    fisher_information: Option<Array1<f64>>,
    previous_parameters: Option<Array1<f64>>,
    progressive_modules: Vec<QuantumNeuralNetwork>,
    parameter_masks: HashMap<String, Array1<bool>>,
    task_metrics: HashMap<String, TaskMetrics>,
    forgetting_metrics: ForgettingMetrics,
}

/// Per-task learning and retention metrics
#[derive(Debug, Clone)]
pub struct TaskMetrics {
    pub current_accuracy: f64,
    pub retained_accuracy: f64,
    pub learning_speed: usize,
    pub backward_transfer: f64,
    pub forward_transfer: f64,
}

/// Aggregate metrics describing forgetting across the task sequence
#[derive(Debug, Clone)]
pub struct ForgettingMetrics {
    pub average_accuracy: f64,
    pub forgetting_measure: f64,
    pub backward_transfer: f64,
    pub forward_transfer: f64,
    pub continual_learning_score: f64,
    pub per_task_forgetting: HashMap<String, f64>,
}

impl QuantumContinualLearner {
    /// Create a new continual learner; strategies that rely on replay
    /// (ExperienceReplay, GradientEpisodicMemory) get a memory buffer up front.
    pub fn new(model: QuantumNeuralNetwork, strategy: ContinualLearningStrategy) -> Self {
        let memory_buffer = match &strategy {
            ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => Some(
                MemoryBuffer::new(*buffer_size, MemorySelectionStrategy::Random),
            ),
            ContinualLearningStrategy::GradientEpisodicMemory { .. } => Some(MemoryBuffer::new(
                1000,
                MemorySelectionStrategy::GradientImportance,
            )),
            _ => None,
        };

        Self {
            model,
            strategy,
            task_history: Vec::new(),
            current_task: None,
            memory_buffer,
            fisher_information: None,
            previous_parameters: None,
            progressive_modules: Vec::new(),
            parameter_masks: HashMap::new(),
            task_metrics: HashMap::new(),
            forgetting_metrics: ForgettingMetrics {
                average_accuracy: 0.0,
                forgetting_measure: 0.0,
                backward_transfer: 0.0,
                forward_transfer: 0.0,
                continual_learning_score: 0.0,
                per_task_forgetting: HashMap::new(),
            },
        }
    }

    /// Learn a new task, applying the configured strategy before and after training
    pub fn learn_task(
        &mut self,
        task: ContinualTask,
        optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<TaskMetrics> {
        println!("Learning task: {}", task.task_id);

        self.task_history.push(task.clone());
        self.current_task = Some(self.task_history.len() - 1);

        self.apply_pre_training_strategy(&task)?;

        let start_time = std::time::Instant::now();
        let learning_losses = self.train_on_task(&task, optimizer, epochs)?;
        let learning_time = start_time.elapsed();

        self.apply_post_training_strategy(&task)?;

        let current_accuracy = self.evaluate_task(&task)?;

        // Temporarily take the buffer to avoid borrowing `self` twice
        if self.memory_buffer.is_some() {
            let mut buffer = self
                .memory_buffer
                .take()
                .expect("memory_buffer verified to be Some above");
            self.update_memory_buffer(&mut buffer, &task)?;
            self.memory_buffer = Some(buffer);
        }

        let task_metrics = TaskMetrics {
            current_accuracy,
            retained_accuracy: current_accuracy,
            learning_speed: epochs,
            backward_transfer: 0.0,
            forward_transfer: 0.0,
        };

        self.task_metrics
            .insert(task.task_id.clone(), task_metrics.clone());

        self.update_forgetting_metrics()?;

        println!(
            "Task {} learned with accuracy: {:.3}",
            task.task_id, current_accuracy
        );

        Ok(task_metrics)
    }

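    // Typical usage (a sketch, not from the original file; `Adam::new` is a
    // hypothetical constructor for the optimizer imported in the tests below):
    //
    //     let mut learner = QuantumContinualLearner::new(model, strategy);
    //     let mut optimizer = Adam::new(0.01);
    //     for task in generate_task_sequence(3, 50, 4) {
    //         learner.learn_task(task, &mut optimizer, 20)?;
    //     }
    //     println!("{:?}", learner.get_forgetting_metrics());
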
    /// Train on a single task's data in mini-batches
    fn train_on_task(
        &mut self,
        task: &ContinualTask,
        optimizer: &mut dyn Optimizer,
        epochs: usize,
    ) -> Result<Vec<f64>> {
        let mut losses = Vec::new();
        let batch_size = 32;

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            // Ceiling division so a final partial batch is still processed
            let num_batches = (task.train_data.nrows() + batch_size - 1) / batch_size;

            for batch_idx in 0..num_batches {
                let batch_start = batch_idx * batch_size;
                let batch_end = (batch_start + batch_size).min(task.train_data.nrows());

                let batch_data = task
                    .train_data
                    .slice(s![batch_start..batch_end, ..])
                    .to_owned();
                let batch_labels = task
                    .train_labels
                    .slice(s![batch_start..batch_end])
                    .to_owned();

                // Mix in replayed experiences if the strategy calls for it
                let (final_data, final_labels) =
                    self.create_training_batch(&batch_data, &batch_labels, task)?;

                let batch_loss = self.compute_continual_loss(&final_data, &final_labels, task)?;
                epoch_loss += batch_loss;

                // A full implementation would compute gradients here and take
                // an optimizer step; the update is elided in this simplified loop.
            }

            epoch_loss /= num_batches as f64;
            losses.push(epoch_loss);

            if epoch % 10 == 0 {
                println!("  Epoch {}: Loss = {:.4}", epoch, epoch_loss);
            }
        }

        Ok(losses)
    }

    /// Strategy hooks that run before training on a new task
    fn apply_pre_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        // Clone to avoid borrowing `self` while mutating it below
        let strategy = self.strategy.clone();
        match strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation { .. } => {
                if !self.task_history.is_empty() {
                    // Snapshot parameters and estimate their importance
                    self.previous_parameters = Some(self.model.parameters.clone());
                    self.compute_fisher_information()?;
                }
            }

            ContinualLearningStrategy::ProgressiveNetworks {
                adaptation_layers, ..
            } => {
                self.create_progressive_column(adaptation_layers)?;
            }

            ContinualLearningStrategy::ParameterIsolation {
                allocation_strategy,
                ..
            } => {
                self.allocate_parameters_for_task(task, &allocation_strategy)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Strategy hooks that run after training on a task
    fn apply_post_training_strategy(&mut self, task: &ContinualTask) -> Result<()> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { .. } => {
                // Buffer updates are handled in `learn_task` after evaluation
            }

            ContinualLearningStrategy::GradientEpisodicMemory { .. } => {
                self.compute_gradient_memory(task)?;
            }

            _ => {}
        }

        Ok(())
    }

    /// Build a training batch, optionally mixing in replayed experiences
    fn create_training_batch(
        &self,
        current_data: &Array2<f64>,
        current_labels: &Array1<usize>,
        task: &ContinualTask,
    ) -> Result<(Array2<f64>, Array1<usize>)> {
        match &self.strategy {
            ContinualLearningStrategy::ExperienceReplay { replay_ratio, .. } => {
                if let Some(ref buffer) = self.memory_buffer {
                    let num_replay = (current_data.nrows() as f64 * replay_ratio) as usize;
                    let replay_experiences = buffer.sample(num_replay);

                    let mut combined_data = current_data.clone();
                    let mut combined_labels = current_labels.clone();

                    for experience in replay_experiences {
                        // Append the replayed input as an extra row, skipping
                        // any experience whose feature dimension does not match
                        if experience.input.len() != combined_data.ncols() {
                            continue;
                        }
                        combined_data
                            .push_row(experience.input.view())
                            .expect("row dimension checked above");

                        // Recover the class label from the stored one-hot target
                        let label = experience
                            .target
                            .iter()
                            .enumerate()
                            .max_by(|a, b| {
                                a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(i, _)| i)
                            .unwrap_or(0);
                        combined_labels
                            .append(Axis(0), Array1::from_elem(1, label).view())
                            .expect("1-D append on matching axis cannot fail");
                    }

                    Ok((combined_data, combined_labels))
                } else {
                    Ok((current_data.clone(), current_labels.clone()))
                }
            }
            _ => Ok((current_data.clone(), current_labels.clone())),
        }
    }

    fn compute_continual_loss(
        &self,
        data: &Array2<f64>,
        labels: &Array1<usize>,
        task: &ContinualTask,
    ) -> Result<f64> {
        let mut total_loss = 0.0;

        for (input, &label) in data.outer_iter().zip(labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            total_loss += self.cross_entropy_loss(&output, label);
        }

        let base_loss = total_loss / data.nrows() as f64;

        let regularization = match &self.strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => self.compute_ewc_regularization(*importance_weight),

            ContinualLearningStrategy::LearningWithoutForgetting {
                distillation_weight,
                temperature,
            } => self.compute_lwf_regularization(*distillation_weight, *temperature, data)?,

            ContinualLearningStrategy::QuantumRegularization {
                entanglement_preservation,
                parameter_drift_penalty,
            } => self.compute_quantum_regularization(
                *entanglement_preservation,
                *parameter_drift_penalty,
            ),

            _ => 0.0,
        };

        Ok(base_loss + regularization)
    }

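    // EWC penalty (Kirkpatrick et al., 2017):
    //     L_EWC = (λ / 2) · Σ_i F_i · (θ_i − θ*_i)²
    // where F_i is the diagonal Fisher information for parameter i, θ*_i the
    // parameter value recorded after the previous task, and λ the importance
    // weight.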
    fn compute_ewc_regularization(&self, importance_weight: f64) -> f64 {
        if let (Some(ref fisher), Some(ref prev_params)) =
            (&self.fisher_information, &self.previous_parameters)
        {
            let param_diff = &self.model.parameters - prev_params;
            let ewc_term = fisher * &param_diff.mapv(|x| x.powi(2));
            importance_weight * ewc_term.sum() / 2.0
        } else {
            0.0
        }
    }

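    // Learning-without-Forgetting distillation (Li & Hoiem, 2017): match the
    // previous model's temperature-softened predictions via a KL term,
    //     L_LwF = w · Σ_i t_i · ln(t_i / s_i),
    // averaged over the batch, where t and s are teacher and student
    // probabilities at temperature T and w is the distillation weight.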
    fn compute_lwf_regularization(
        &self,
        distillation_weight: f64,
        temperature: f64,
        data: &Array2<f64>,
    ) -> Result<f64> {
        if self.task_history.len() <= 1 {
            return Ok(0.0);
        }

        let mut distillation_loss = 0.0;

        for input in data.outer_iter() {
            let current_output = self.model.forward(&input.to_owned())?;

            // Placeholder teacher: a full implementation would use outputs from
            // a snapshot of the model taken before training on the new task.
            // With teacher == student, the KL term below evaluates to zero.
            let teacher_output = current_output.clone();

            let student_probs = self.softmax_with_temperature(&current_output, temperature);
            let teacher_probs = self.softmax_with_temperature(&teacher_output, temperature);

            // KL(teacher ‖ student), guarding against log of zero
            for (s, t) in student_probs.iter().zip(teacher_probs.iter()) {
                if *t > 1e-10 && *s > 1e-10 {
                    distillation_loss += t * (t / s).ln();
                }
            }
        }

        Ok(distillation_weight * distillation_loss / data.nrows() as f64)
    }

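    // Quantum-specific regularization: an L1 penalty on parameter drift serves
    // as an inexpensive proxy for changes to the circuit's entanglement
    // structure, combined with an L2 penalty on overall drift from the
    // previous task's parameters.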
    fn compute_quantum_regularization(
        &self,
        entanglement_preservation: f64,
        parameter_drift_penalty: f64,
    ) -> f64 {
        let mut regularization = 0.0;

        if let Some(ref prev_params) = self.previous_parameters {
            let param_diff = &self.model.parameters - prev_params;

            // L1 term: discourage large changes to circuit parameters
            let entanglement_penalty = param_diff.mapv(|x| x.abs()).sum();
            regularization += entanglement_preservation * entanglement_penalty;

            // L2 term: penalize overall parameter drift
            let drift = param_diff.mapv(|x| x.powi(2)).sum();
            regularization += parameter_drift_penalty * drift;
        }

        regularization
    }

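    // Diagonal Fisher information estimated empirically from the previous
    // task's data:
    //     F_i ≈ (1/N) · Σ_n (∂L_n / ∂θ_i)²
    // i.e. the average squared per-example gradient over `fisher_samples`
    // examples.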
    fn compute_fisher_information(&mut self) -> Result<()> {
        if let ContinualLearningStrategy::ElasticWeightConsolidation { fisher_samples, .. } =
            &self.strategy
        {
            let mut fisher = Array1::zeros(self.model.parameters.len());

            if let Some(current_task_idx) = self.current_task {
                if current_task_idx > 0 {
                    let prev_task = &self.task_history[current_task_idx - 1];

                    // Accumulate squared per-example gradients
                    for i in 0..*fisher_samples {
                        let idx = i % prev_task.train_data.nrows();
                        let input = prev_task.train_data.row(idx).to_owned();
                        let label = prev_task.train_labels[idx];

                        let gradient = self.compute_parameter_gradient(&input, label)?;
                        fisher = fisher + &gradient.mapv(|x| x.powi(2));
                    }

                    fisher = fisher / *fisher_samples as f64;
                }
            }

            self.fisher_information = Some(fisher);
        }

        Ok(())
    }

    /// Add a fresh network column for a new task (progressive networks).
    /// The column architecture is fixed here; `adaptation_layers` and lateral
    /// connections are not yet wired in.
    fn create_progressive_column(&mut self, adaptation_layers: usize) -> Result<()> {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let progressive_module = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
        self.progressive_modules.push(progressive_module);

        Ok(())
    }

    fn allocate_parameters_for_task(
        &mut self,
        task: &ContinualTask,
        strategy: &ParameterAllocationStrategy,
    ) -> Result<()> {
        match strategy {
            ParameterAllocationStrategy::Masking => {
                // Simplified: mark every parameter as available to this task
                let mask = Array1::from_elem(self.model.parameters.len(), true);
                self.parameter_masks.insert(task.task_id.clone(), mask);
            }

            ParameterAllocationStrategy::Expansion => {
                // Network expansion (growing parameters per task) is not
                // implemented in this simplified version
            }

            _ => {}
        }

        Ok(())
    }

    /// Store gradient-annotated experiences for gradient episodic memory
    fn compute_gradient_memory(&mut self, task: &ContinualTask) -> Result<()> {
        if self.memory_buffer.is_some() {
            let mut buffer = self
                .memory_buffer
                .take()
                .expect("memory_buffer verified to be Some above");

            // Cap the number of stored examples per task
            for i in 0..task.train_data.nrows().min(100) {
                let input = task.train_data.row(i).to_owned();
                let label = task.train_labels[i];

                let gradient = self.compute_parameter_gradient(&input, label)?;

                // One-hot encode the label as the stored target
                let mut target = Array1::zeros(task.num_classes);
                if label < task.num_classes {
                    target[label] = 1.0;
                }

                let experience = Experience {
                    input,
                    target,
                    task_id: task.task_id.clone(),
                    importance: 1.0,
                    gradient_info: Some(gradient),
                    uncertainty: None,
                };

                buffer.add_experience(experience);
            }

            self.memory_buffer = Some(buffer);
        }

        Ok(())
    }

    /// Add a task's training examples to the replay buffer
    fn update_memory_buffer(&self, buffer: &mut MemoryBuffer, task: &ContinualTask) -> Result<()> {
        for i in 0..task.train_data.nrows() {
            let input = task.train_data.row(i).to_owned();

            // One-hot encode the label as the stored target
            let label = task.train_labels[i];
            let mut target = Array1::zeros(task.num_classes);
            if label < task.num_classes {
                target[label] = 1.0;
            }

            let experience = Experience {
                input,
                target,
                task_id: task.task_id.clone(),
                importance: 1.0,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        Ok(())
    }

    fn evaluate_task(&self, task: &ContinualTask) -> Result<f64> {
        let mut correct = 0;
        let total = task.val_data.nrows();

        for (input, &label) in task.val_data.outer_iter().zip(task.val_labels.iter()) {
            let output = self.model.forward(&input.to_owned())?;
            let predicted = output
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(i, _)| i)
                .unwrap_or(0);

            if predicted == label {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    /// Re-evaluate every task seen so far and refresh retained accuracies
    pub fn evaluate_all_tasks(&mut self) -> Result<HashMap<String, f64>> {
        let mut accuracies = HashMap::new();

        for task in &self.task_history {
            let accuracy = self.evaluate_task(task)?;
            accuracies.insert(task.task_id.clone(), accuracy);

            if let Some(metrics) = self.task_metrics.get_mut(&task.task_id) {
                metrics.retained_accuracy = accuracy;
            }
        }

        Ok(accuracies)
    }

    /// Recompute aggregate forgetting statistics across all tasks
    fn update_forgetting_metrics(&mut self) -> Result<()> {
        if self.task_history.is_empty() {
            return Ok(());
        }

        let accuracies = self.evaluate_all_tasks()?;

        let avg_accuracy = accuracies.values().sum::<f64>() / accuracies.len() as f64;
        self.forgetting_metrics.average_accuracy = avg_accuracy;

        // Forgetting per task: drop from the accuracy measured right after
        // learning it to the accuracy measured now
        let mut total_forgetting = 0.0;
        let mut num_comparisons = 0;

        for (task_id, metrics) in &self.task_metrics {
            let current_acc = accuracies.get(task_id).unwrap_or(&0.0);
            let original_acc = metrics.current_accuracy;

            if original_acc > 0.0 {
                let forgetting = (original_acc - current_acc).max(0.0);
                total_forgetting += forgetting;
                num_comparisons += 1;

                self.forgetting_metrics
                    .per_task_forgetting
                    .insert(task_id.clone(), forgetting);
            }
        }

        if num_comparisons > 0 {
            self.forgetting_metrics.forgetting_measure = total_forgetting / num_comparisons as f64;
        }

        // Simple composite score: average accuracy penalized by forgetting
        self.forgetting_metrics.continual_learning_score =
            avg_accuracy - self.forgetting_metrics.forgetting_measure;

        Ok(())
    }

    /// Placeholder gradient: a full implementation would use the
    /// parameter-shift rule or autodiff over the quantum circuit
    fn compute_parameter_gradient(&self, input: &Array1<f64>, label: usize) -> Result<Array1<f64>> {
        Ok(Array1::zeros(self.model.parameters.len()))
    }

    /// Negative log-likelihood of the correct class, clamped for stability
    fn cross_entropy_loss(&self, output: &Array1<f64>, label: usize) -> f64 {
        let predicted_prob = output[label].max(1e-10);
        -predicted_prob.ln()
    }

    /// Numerically stable softmax with temperature:
    /// p_i = exp(z_i / T) / Σ_k exp(z_k / T), computed with a max-shift
    fn softmax_with_temperature(&self, logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
        let scaled_logits = logits / temperature;
        let max_logit = scaled_logits
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_logits = scaled_logits.mapv(|x| (x - max_logit).exp());
        let sum_exp = exp_logits.sum();
        exp_logits / sum_exp
    }

    pub fn get_forgetting_metrics(&self) -> &ForgettingMetrics {
        &self.forgetting_metrics
    }

    pub fn get_task_metrics(&self) -> &HashMap<String, TaskMetrics> {
        &self.task_metrics
    }

    pub fn get_model(&self) -> &QuantumNeuralNetwork {
        &self.model
    }

    pub fn reset(&mut self) {
        self.task_history.clear();
        self.current_task = None;
        self.fisher_information = None;
        self.previous_parameters = None;
        self.progressive_modules.clear();
        self.parameter_masks.clear();
        self.task_metrics.clear();

        if let Some(ref mut buffer) = self.memory_buffer {
            buffer.clear();
        }
    }
}

impl MemoryBuffer {
    pub fn new(max_size: usize, strategy: MemorySelectionStrategy) -> Self {
        Self {
            experiences: VecDeque::new(),
            max_size,
            selection_strategy: strategy,
            task_memories: HashMap::new(),
        }
    }

    /// Add an experience, evicting the oldest one when the buffer is full
    pub fn add_experience(&mut self, experience: Experience) {
        if self.experiences.len() >= self.max_size {
            let removed = self
                .experiences
                .pop_front()
                .expect("Buffer is non-empty when len >= max_size");
            self.remove_from_task_index(&removed);
        }

        // Note: indices recorded here become stale after evictions; the
        // simplified task index below is cleared rather than re-mapped
        let experience_idx = self.experiences.len();
        self.experiences.push_back(experience.clone());

        self.task_memories
            .entry(experience.task_id.clone())
            .or_insert_with(Vec::new)
            .push(experience_idx);
    }

    /// Sample up to `num_samples` experiences according to the selection strategy
    pub fn sample(&self, num_samples: usize) -> Vec<Experience> {
        let mut samples = Vec::new();

        let available = self.experiences.len().min(num_samples);

        match self.selection_strategy {
            MemorySelectionStrategy::Random => {
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }

            MemorySelectionStrategy::GradientImportance => {
                // Take the highest-importance experiences first
                let mut indexed_experiences: Vec<_> =
                    self.experiences.iter().enumerate().collect();

                indexed_experiences.sort_by(|a, b| {
                    let importance_a = a.1.importance;
                    let importance_b = b.1.importance;
                    importance_b
                        .partial_cmp(&importance_a)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });

                for (_, experience) in indexed_experiences.into_iter().take(available) {
                    samples.push(experience.clone());
                }
            }

            _ => {
                // Other strategies fall back to uniform random sampling
                for _ in 0..available {
                    let idx = fastrand::usize(0..self.experiences.len());
                    samples.push(self.experiences[idx].clone());
                }
            }
        }

        samples
    }

    fn remove_from_task_index(&mut self, experience: &Experience) {
        if let Some(indices) = self.task_memories.get_mut(&experience.task_id) {
            // Simplified: drop all indices for the task, since stored
            // positions shift after an eviction
            indices.clear();
        }
    }

    pub fn clear(&mut self) {
        self.experiences.clear();
        self.task_memories.clear();
    }

    pub fn size(&self) -> usize {
        self.experiences.len()
    }
}

/// Split a dataset into train/validation portions and wrap it as a task
pub fn create_continual_task(
    task_id: String,
    task_type: TaskType,
    data: Array2<f64>,
    labels: Array1<usize>,
    train_ratio: f64,
) -> ContinualTask {
    let train_size = (data.nrows() as f64 * train_ratio) as usize;

    let train_data = data.slice(s![0..train_size, ..]).to_owned();
    let train_labels = labels.slice(s![0..train_size]).to_owned();

    let val_data = data.slice(s![train_size.., ..]).to_owned();
    let val_labels = labels.slice(s![train_size..]).to_owned();

    let num_classes = labels.iter().max().unwrap_or(&0) + 1;

    ContinualTask {
        task_id,
        task_type,
        train_data,
        train_labels,
        val_data,
        val_labels,
        num_classes,
        metadata: HashMap::new(),
    }
}

/// Generate a sequence of synthetic binary classification tasks whose data
/// distribution shifts with the task index
pub fn generate_task_sequence(
    num_tasks: usize,
    samples_per_task: usize,
    feature_dim: usize,
) -> Vec<ContinualTask> {
    let mut tasks = Vec::new();

    for i in 0..num_tasks {
        // Sinusoidal features with a per-task phase shift plus noise
        let data = Array2::from_shape_fn((samples_per_task, feature_dim), |(row, col)| {
            let task_shift = i as f64 * 0.5;
            let base_value =
                row as f64 / samples_per_task as f64 + col as f64 / feature_dim as f64;
            0.5 + 0.3 * (base_value * 2.0 * PI + task_shift).sin() + 0.1 * (fastrand::f64() - 0.5)
        });

        // Label by thresholding the feature sum
        let labels = Array1::from_shape_fn(samples_per_task, |row| {
            let sum = data.row(row).sum();
            if sum > feature_dim as f64 * 0.5 {
                1
            } else {
                0
            }
        });

        let task = create_continual_task(
            format!("task_{}", i),
            TaskType::Classification { num_classes: 2 },
            data,
            labels,
            0.8, // train_ratio: 80% train / 20% validation
        );

        tasks.push(task);
    }

    tasks
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(5, MemorySelectionStrategy::Random);

        for i in 0..10 {
            let experience = Experience {
                input: Array1::from_vec(vec![i as f64]),
                target: Array1::from_vec(vec![(i % 2) as f64]),
                task_id: format!("task_{}", i / 3),
                importance: i as f64,
                gradient_info: None,
                uncertainty: None,
            };

            buffer.add_experience(experience);
        }

        assert_eq!(buffer.size(), 5);

        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);
    }
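
    // A minimal additional check (not in the original test suite): the
    // ExperienceReplay strategy should allocate a replay buffer at
    // construction time, per the `new` constructor above.
    #[test]
    fn test_replay_buffer_initialization() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).expect("Failed to create model");

        let strategy = ContinualLearningStrategy::ExperienceReplay {
            buffer_size: 50,
            replay_ratio: 0.3,
            memory_selection: MemorySelectionStrategy::Random,
        };

        let learner = QuantumContinualLearner::new(model, strategy);
        assert!(learner.memory_buffer.is_some());
    }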

    #[test]
    fn test_continual_task_creation() {
        let data = Array2::from_shape_fn((100, 4), |(i, j)| (i as f64 + j as f64) / 50.0);
        let labels = Array1::from_shape_fn(100, |i| i % 3);

        let task = create_continual_task(
            "test_task".to_string(),
            TaskType::Classification { num_classes: 3 },
            data,
            labels,
            0.7,
        );

        assert_eq!(task.task_id, "test_task");
        assert_eq!(task.train_data.nrows(), 70);
        assert_eq!(task.val_data.nrows(), 30);
        assert_eq!(task.num_classes, 3);
    }

    #[test]
    fn test_continual_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let model = QuantumNeuralNetwork::new(layers, 4, 4, 2).expect("Failed to create model");

        let strategy = ContinualLearningStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let learner = QuantumContinualLearner::new(model, strategy);

        assert_eq!(learner.task_history.len(), 0);
        assert!(learner.current_task.is_none());
    }

    #[test]
    fn test_task_sequence_generation() {
        let tasks = generate_task_sequence(3, 50, 4);

        assert_eq!(tasks.len(), 3);

        for (i, task) in tasks.iter().enumerate() {
            assert_eq!(task.task_id, format!("task_{}", i));
            assert_eq!(task.train_data.nrows(), 40); // 80% of 50
            assert_eq!(task.val_data.nrows(), 10); // 20% of 50
            assert_eq!(task.train_data.ncols(), 4);
        }
    }

    #[test]
    fn test_forgetting_metrics() {
        let metrics = ForgettingMetrics {
            average_accuracy: 0.75,
            forgetting_measure: 0.15,
            backward_transfer: 0.05,
            forward_transfer: 0.1,
            continual_learning_score: 0.6,
            per_task_forgetting: HashMap::new(),
        };

        assert_eq!(metrics.average_accuracy, 0.75);
        assert_eq!(metrics.forgetting_measure, 0.15);
        assert!(metrics.continual_learning_score > 0.5);
    }
}