use crate::{EmbeddingModel, ModelConfig, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use uuid::Uuid;

/// Configuration for continual learning over a stream of tasks.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContinualLearningConfig {
    /// Base embedding model settings (dimensions, epochs, ...).
    pub base_config: ModelConfig,
    /// Memory subsystem settings.
    pub memory_config: MemoryConfig,
    /// Weight-regularization settings (EWC, SI, LwF, ...).
    pub regularization_config: RegularizationConfig,
    /// Architecture adaptation settings.
    pub architecture_config: ArchitectureConfig,
    /// Task detection and switching settings.
    pub task_config: TaskConfig,
    /// Replay (rehearsal) settings.
    pub replay_config: ReplayConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    pub memory_type: MemoryType,
    pub memory_capacity: usize,
    pub update_strategy: MemoryUpdateStrategy,
    pub consolidation: ConsolidationConfig,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            memory_type: MemoryType::EpisodicMemory,
            memory_capacity: 10000,
            update_strategy: MemoryUpdateStrategy::ReservoirSampling,
            consolidation: ConsolidationConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryType {
    EpisodicMemory,
    SemanticMemory,
    WorkingMemory,
    ProceduralMemory,
    HybridMemory,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryUpdateStrategy {
    FIFO,
    Random,
    ReservoirSampling,
    ImportanceBased,
    GradientBased,
    ClusteringBased,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationConfig {
    pub enabled: bool,
    pub frequency: usize,
    pub strength: f32,
    pub sleep_consolidation: bool,
}

impl Default for ConsolidationConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            frequency: 1000,
            strength: 0.1,
            sleep_consolidation: false,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationConfig {
    pub methods: Vec<RegularizationMethod>,
    pub ewc_config: EWCConfig,
    pub si_config: SynapticIntelligenceConfig,
    pub lwf_config: LwFConfig,
}

impl Default for RegularizationConfig {
    fn default() -> Self {
        Self {
            methods: vec![
                RegularizationMethod::EWC,
                RegularizationMethod::SynapticIntelligence,
            ],
            ewc_config: EWCConfig::default(),
            si_config: SynapticIntelligenceConfig::default(),
            lwf_config: LwFConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RegularizationMethod {
    EWC,
    SynapticIntelligence,
    LwF,
    MAS,
    RiemannianWalk,
    PackNet,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    pub lambda: f32,
    pub fisher_method: FisherMethod,
    pub online: bool,
    pub gamma: f32,
}

impl Default for EWCConfig {
    fn default() -> Self {
        Self {
            lambda: 0.4,
            fisher_method: FisherMethod::Empirical,
            online: true,
            gamma: 1.0,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FisherMethod {
    Empirical,
    True,
    Diagonal,
    BlockDiagonal,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynapticIntelligenceConfig {
    pub c: f32,
    pub xi: f32,
    pub damping: f32,
}

impl Default for SynapticIntelligenceConfig {
    fn default() -> Self {
        Self {
            c: 0.1,
            xi: 1.0,
            damping: 0.1,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LwFConfig {
    pub alpha: f32,
    pub temperature: f32,
    pub attention_transfer: bool,
}

impl Default for LwFConfig {
    fn default() -> Self {
        Self {
            alpha: 1.0,
            temperature: 4.0,
            attention_transfer: false,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureConfig {
    pub adaptation_method: ArchitectureAdaptation,
    pub progressive_config: ProgressiveConfig,
    pub dynamic_config: DynamicConfig,
}

impl Default for ArchitectureConfig {
    fn default() -> Self {
        Self {
            adaptation_method: ArchitectureAdaptation::Progressive,
            progressive_config: ProgressiveConfig::default(),
            dynamic_config: DynamicConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArchitectureAdaptation {
    Progressive,
    Dynamic,
    PackNet,
    HAT,
    Supermasks,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveConfig {
    pub columns_per_task: usize,
    pub lateral_strength: f32,
    pub column_capacity: usize,
}

impl Default for ProgressiveConfig {
    fn default() -> Self {
        Self {
            columns_per_task: 1,
            lateral_strength: 0.5,
            column_capacity: 1000,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicConfig {
    pub expansion_threshold: f32,
    pub pruning_threshold: f32,
    pub growth_rate: f32,
    pub max_size: usize,
}

impl Default for DynamicConfig {
    fn default() -> Self {
        Self {
            expansion_threshold: 0.9,
            pruning_threshold: 0.1,
            growth_rate: 0.1,
            max_size: 100000,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
    pub detection_method: TaskDetection,
    pub boundary_detection: BoundaryDetection,
    pub switching_strategy: TaskSwitching,
}

impl Default for TaskConfig {
    fn default() -> Self {
        Self {
            detection_method: TaskDetection::Automatic,
            boundary_detection: BoundaryDetection::ChangePoint,
            switching_strategy: TaskSwitching::Soft,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskDetection {
    Manual,
    Automatic,
    Oracle,
    Clustering,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BoundaryDetection {
    ChangePoint,
    DistributionShift,
    LossBased,
    GradientBased,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskSwitching {
    Hard,
    Soft,
    Attention,
    Gating,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplayConfig {
    pub methods: Vec<ReplayMethod>,
    pub buffer_size: usize,
    pub replay_ratio: f32,
    pub generative_config: GenerativeReplayConfig,
}

impl Default for ReplayConfig {
    fn default() -> Self {
        Self {
            methods: vec![
                ReplayMethod::ExperienceReplay,
                ReplayMethod::GenerativeReplay,
            ],
            buffer_size: 5000,
            replay_ratio: 0.5,
            generative_config: GenerativeReplayConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ReplayMethod {
    ExperienceReplay,
    GenerativeReplay,
    PseudoRehearsal,
    MetaReplay,
    GradientEpisodicMemory,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerativeReplayConfig {
    pub generator_type: GeneratorType,
    pub quality_threshold: f32,
    pub diversity_weight: f32,
}

impl Default for GenerativeReplayConfig {
    fn default() -> Self {
        Self {
            generator_type: GeneratorType::VAE,
            quality_threshold: 0.8,
            diversity_weight: 0.1,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GeneratorType {
    VAE,
    GAN,
    Flow,
    Diffusion,
}

/// Metadata describing a single task in the learning sequence.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    pub task_id: String,
    pub task_type: String,
    pub start_time: DateTime<Utc>,
    pub end_time: Option<DateTime<Utc>>,
    pub examples_seen: usize,
    pub performance: f32,
    pub task_embedding: Option<Array1<f32>>,
}

impl TaskInfo {
    pub fn new(task_id: String, task_type: String) -> Self {
        Self {
            task_id,
            task_type,
            start_time: Utc::now(),
            end_time: None,
            examples_seen: 0,
            performance: 0.0,
            task_embedding: None,
        }
    }
}

/// A single stored example in episodic memory.
#[derive(Debug, Clone)]
pub struct MemoryEntry {
    pub data: Array1<f32>,
    pub target: Array1<f32>,
    pub task_id: String,
    pub timestamp: DateTime<Utc>,
    pub importance: f32,
    pub access_count: usize,
}

impl MemoryEntry {
    pub fn new(data: Array1<f32>, target: Array1<f32>, task_id: String) -> Self {
        Self {
            data,
            target,
            task_id,
            timestamp: Utc::now(),
            importance: 1.0,
            access_count: 0,
        }
    }
}

/// Per-task snapshot used by Elastic Weight Consolidation.
#[derive(Debug, Clone)]
pub struct EWCState {
    pub fisher_information: Array2<f32>,
    pub optimal_parameters: Array2<f32>,
    pub task_id: String,
    pub importance: f32,
}

/// Knowledge graph embedding model that learns tasks sequentially while
/// mitigating catastrophic forgetting via memory replay, weight
/// regularization, and architecture adaptation.
#[derive(Debug)]
pub struct ContinualLearningModel {
    pub config: ContinualLearningConfig,
    pub model_id: Uuid,

    // Shared and task-specific embedding matrices.
    pub embeddings: Array2<f32>,
    pub task_specific_embeddings: HashMap<String, Array2<f32>>,

    // Memory systems.
    pub episodic_memory: VecDeque<MemoryEntry>,
    pub semantic_memory: HashMap<String, Array1<f32>>,

    // Regularization state (EWC, Synaptic Intelligence).
    pub ewc_states: Vec<EWCState>,
    pub synaptic_importance: Array2<f32>,
    pub parameter_trajectory: Array2<f32>,

    // Task tracking.
    pub current_task: Option<TaskInfo>,
    pub task_history: Vec<TaskInfo>,
    pub task_boundaries: Vec<usize>,

    // Progressive network columns and their lateral connections.
    pub network_columns: Vec<Array2<f32>>,
    pub lateral_connections: Vec<Array2<f32>>,

    // Generative replay components.
    pub generator: Option<Array2<f32>>,
    pub discriminator: Option<Array2<f32>>,

    // Entity and relation vocabularies.
    pub entities: HashMap<String, usize>,
    pub relations: HashMap<String, usize>,

    // Training bookkeeping.
    pub examples_seen: usize,
    pub training_stats: Option<TrainingStats>,
    pub is_trained: bool,
}

impl ContinualLearningModel {
    /// Create a new, untrained continual learning model from `config`.
    pub fn new(config: ContinualLearningConfig) -> Self {
        let model_id = Uuid::new_v4();
        let dimensions = config.base_config.dimensions;
        let mut random = Random::default();

        Self {
            config: config.clone(),
            model_id,
            embeddings: Array2::zeros((0, dimensions)),
            task_specific_embeddings: HashMap::new(),
            episodic_memory: VecDeque::with_capacity(config.memory_config.memory_capacity),
            semantic_memory: HashMap::new(),
            ewc_states: Vec::new(),
            synaptic_importance: Array2::zeros((0, dimensions)),
            parameter_trajectory: Array2::zeros((0, dimensions)),
            current_task: None,
            task_history: Vec::new(),
            task_boundaries: Vec::new(),
            // Start with a single randomly initialized network column.
            network_columns: vec![Array2::from_shape_fn((dimensions, dimensions), |_| {
                random.random::<f64>() as f32 * 0.1
            })],
            lateral_connections: Vec::new(),
            generator: Some(Array2::from_shape_fn((dimensions, dimensions), |_| {
                random.random::<f64>() as f32 * 0.1
            })),
            discriminator: Some(Array2::from_shape_fn((dimensions, dimensions), |_| {
                random.random::<f64>() as f32 * 0.1
            })),
            entities: HashMap::new(),
            relations: HashMap::new(),
            examples_seen: 0,
            training_stats: None,
            is_trained: false,
        }
    }

    /// Begin a new task: close out the current one, consolidate memory,
    /// snapshot EWC state, and grow the architecture if configured.
    pub fn start_task(&mut self, task_id: String, task_type: String) -> Result<()> {
        // Finalize the current task before switching.
        if let Some(ref mut current_task) = self.current_task {
            current_task.end_time = Some(Utc::now());
            self.task_history.push(current_task.clone());
            self.task_boundaries.push(self.examples_seen);
        }

        if self.config.memory_config.consolidation.enabled {
            self.consolidate_memory()?;
        }

        if self
            .config
            .regularization_config
            .methods
            .contains(&RegularizationMethod::EWC)
        {
            self.compute_ewc_state()?;
        }

        if matches!(
            self.config.architecture_config.adaptation_method,
            ArchitectureAdaptation::Progressive
        ) {
            self.add_network_column()?;
        }

        let mut new_task = TaskInfo::new(task_id.clone(), task_type);
        new_task.task_embedding = Some(self.generate_task_embedding(&task_id)?);
        self.current_task = Some(new_task);

        Ok(())
    }

    /// Add a single training example, detecting task boundaries and
    /// triggering replay and regularized updates as configured.
    pub async fn add_example(
        &mut self,
        data: Array1<f32>,
        target: Array1<f32>,
        task_id: Option<String>,
    ) -> Result<()> {
        let task_id = task_id.unwrap_or_else(|| {
            self.current_task
                .as_ref()
                .map(|t| t.task_id.clone())
                .unwrap_or_else(|| "default".to_string())
        });

        // Automatic task-boundary detection.
        if matches!(
            self.config.task_config.detection_method,
            TaskDetection::Automatic
        ) && self.detect_task_boundary(&data)?
        {
            let task_num = self.task_history.len() + 1;
            let new_task_id = format!("task_{task_num}");
            self.start_task(new_task_id, "automatic".to_string())?;
        }

        // Lazily initialize parameters on the first example.
        if self.embeddings.nrows() == 0 {
            let input_dim = data.len();
            let output_dim = target.len();
            let mut random = Random::default();
            self.embeddings = Array2::from_shape_fn((output_dim, input_dim), |_| {
                (random.random::<f64>() as f32 - 0.5) * 0.1
            });
            self.synaptic_importance = Array2::zeros((output_dim, input_dim));
            self.parameter_trajectory = Array2::zeros((output_dim, input_dim));
        }

        self.add_to_memory(data.clone(), target.clone(), task_id.clone())?;

        if let Some(ref mut current_task) = self.current_task {
            current_task.examples_seen += 1;
        }

        self.examples_seen += 1;

        self.continual_update(data, target, task_id).await?;

        Ok(())
    }

    fn add_to_memory(
        &mut self,
        data: Array1<f32>,
        target: Array1<f32>,
        task_id: String,
    ) -> Result<()> {
        let mut random = Random::default();
        let entry = MemoryEntry::new(data, target, task_id);

        match self.config.memory_config.update_strategy {
            MemoryUpdateStrategy::FIFO => {
                if self.episodic_memory.len() >= self.config.memory_config.memory_capacity {
                    self.episodic_memory.pop_front();
                }
                self.episodic_memory.push_back(entry);
            }
            MemoryUpdateStrategy::Random => {
                if self.episodic_memory.len() >= self.config.memory_config.memory_capacity {
                    let idx = random.random_range(0..self.episodic_memory.len());
                    self.episodic_memory.remove(idx);
                }
                self.episodic_memory.push_back(entry);
            }
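            // Reservoir sampling: once the buffer is full, the n-th incoming
            // example replaces a uniformly chosen slot with probability k/n,
            // so every example seen so far is retained with equal probability.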
            MemoryUpdateStrategy::ReservoirSampling => {
                if self.episodic_memory.len() < self.config.memory_config.memory_capacity {
                    self.episodic_memory.push_back(entry);
                } else {
                    let k = self.episodic_memory.len();
                    let j = random.random_range(0..self.examples_seen + 1);
                    if j < k {
                        self.episodic_memory[j] = entry;
                    }
                }
            }
            MemoryUpdateStrategy::ImportanceBased => {
                self.add_by_importance(entry)?;
            }
            _ => {
                self.episodic_memory.push_back(entry);
            }
        }

        Ok(())
    }

    fn add_by_importance(&mut self, entry: MemoryEntry) -> Result<()> {
        if self.episodic_memory.len() < self.config.memory_config.memory_capacity {
            self.episodic_memory.push_back(entry);
        } else {
            // Find the least important stored entry.
            let mut min_importance = f32::INFINITY;
            let mut min_idx = 0;

            for (i, existing_entry) in self.episodic_memory.iter().enumerate() {
                if existing_entry.importance < min_importance {
                    min_importance = existing_entry.importance;
                    min_idx = i;
                }
            }

            // Replace it only if the new entry is more important.
            if entry.importance > min_importance {
                self.episodic_memory[min_idx] = entry;
            }
        }

        Ok(())
    }

    /// Dispatch to the configured boundary-detection heuristic.
    fn detect_task_boundary(&self, data: &Array1<f32>) -> Result<bool> {
        match self.config.task_config.boundary_detection {
            BoundaryDetection::ChangePoint => self.detect_change_point(data),
            BoundaryDetection::DistributionShift => self.detect_distribution_shift(data),
            BoundaryDetection::LossBased => self.detect_loss_change(data),
            BoundaryDetection::GradientBased => self.detect_gradient_change(data),
        }
    }

    /// Simplified change-point heuristic: flag a boundary every 1000 examples.
    fn detect_change_point(&self, _data: &Array1<f32>) -> Result<bool> {
        if self.examples_seen % 1000 == 0 && self.examples_seen > 0 {
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Flag a boundary when the new example is far (in Euclidean distance)
    /// from recently stored examples.
    fn detect_distribution_shift(&self, data: &Array1<f32>) -> Result<bool> {
        if self.episodic_memory.is_empty() {
            return Ok(false);
        }

        // Compare against up to the 100 most recent memory entries.
        let recent_count = 100.min(self.episodic_memory.len());
        let mut total_distance = 0.0;

        for i in 0..recent_count {
            let idx = self.episodic_memory.len() - 1 - i;
            let recent_data = &self.episodic_memory[idx].data;
            total_distance += self.euclidean_distance(data, recent_data);
        }

        let average_distance = total_distance / recent_count as f32;
        // Fixed heuristic threshold on the average distance.
        let threshold = 2.0;
        Ok(average_distance > threshold)
    }

    /// Loss-based detection is not yet implemented.
    fn detect_loss_change(&self, _data: &Array1<f32>) -> Result<bool> {
        Ok(false)
    }

    /// Gradient-based detection is not yet implemented.
    fn detect_gradient_change(&self, _data: &Array1<f32>) -> Result<bool> {
        Ok(false)
    }
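
    // One continual-learning step: compute gradients, apply the configured
    // regularizers, update parameters, refresh synaptic importance, and
    // interleave replay of stored (or generated) past data.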
    async fn continual_update(
        &mut self,
        data: Array1<f32>,
        target: Array1<f32>,
        _task_id: String,
    ) -> Result<()> {
        let gradients = self.compute_gradients(&data, &target)?;

        let regularized_gradients = self.apply_regularization(gradients)?;

        self.update_parameters(regularized_gradients)?;

        if self
            .config
            .regularization_config
            .methods
            .contains(&RegularizationMethod::SynapticIntelligence)
        {
            self.update_synaptic_importance(&data, &target)?;
        }

        if self
            .config
            .replay_config
            .methods
            .contains(&ReplayMethod::ExperienceReplay)
        {
            self.experience_replay().await?;
        }

        if self
            .config
            .replay_config
            .methods
            .contains(&ReplayMethod::GenerativeReplay)
        {
            self.generative_replay().await?;
        }

        Ok(())
    }

    /// Delta-rule gradient of the linear predictor: grad = error * input.
    fn compute_gradients(&self, data: &Array1<f32>, target: &Array1<f32>) -> Result<Array2<f32>> {
        let dimensions = self.config.base_config.dimensions;
        let mut gradients = Array2::zeros((1, dimensions));

        // Nothing to do before the parameters are initialized.
        if self.embeddings.nrows() == 0 {
            return Ok(gradients);
        }

        let prediction = self.forward_pass(data)?;
        // Assumes `target` and the prediction have matching lengths.
        let error = target - &prediction;

        for i in 0..dimensions.min(data.len()).min(error.len()) {
            gradients[[0, i]] = error[i] * data[i];
        }

        Ok(gradients)
    }

    /// Apply each configured regularization method to the gradients in turn.
    fn apply_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        for method in &self.config.regularization_config.methods {
            match method {
                RegularizationMethod::EWC => {
                    gradients = self.apply_ewc_regularization(gradients)?;
                }
                RegularizationMethod::SynapticIntelligence => {
                    gradients = self.apply_si_regularization(gradients)?;
                }
                RegularizationMethod::LwF => {
                    gradients = self.apply_lwf_regularization(gradients)?;
                }
                _ => {}
            }
        }

        Ok(gradients)
    }
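
    // EWC penalty (Kirkpatrick et al., 2017): for each stored task i, pull
    // parameters toward that task's optimum theta*_i with force
    // lambda * importance_i * F_i * (theta - theta*_i), where F_i is a
    // diagonal Fisher information estimate.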
    fn apply_ewc_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        let lambda = self.config.regularization_config.ewc_config.lambda;

        for ewc_state in &self.ewc_states {
            // Element-wise, bounds-checked penalty: the parameter matrix may
            // have been resized since this EWC state was snapshotted.
            let rows_to_update = gradients
                .nrows()
                .min(ewc_state.fisher_information.nrows())
                .min(self.embeddings.nrows())
                .min(ewc_state.optimal_parameters.nrows());
            let cols_to_update = gradients
                .ncols()
                .min(ewc_state.fisher_information.ncols())
                .min(self.embeddings.ncols())
                .min(ewc_state.optimal_parameters.ncols());

            for i in 0..rows_to_update {
                for j in 0..cols_to_update {
                    let penalty = ewc_state.fisher_information[[i, j]]
                        * (self.embeddings[[i, j]] - ewc_state.optimal_parameters[[i, j]])
                        * lambda
                        * ewc_state.importance;
                    gradients[[i, j]] -= penalty;
                }
            }
        }

        Ok(gradients)
    }
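
    // Synaptic Intelligence (Zenke et al., 2017): penalize gradient updates
    // in proportion to each parameter's accumulated importance.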
    fn apply_si_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        let c = self.config.regularization_config.si_config.c;

        if !self.synaptic_importance.is_empty() {
            let penalty = &self.synaptic_importance * c;

            let rows_to_update = gradients.nrows().min(penalty.nrows());
            let cols_to_update = gradients.ncols().min(penalty.ncols());

            for i in 0..rows_to_update {
                for j in 0..cols_to_update {
                    gradients[[i, j]] -= penalty[[i, j]];
                }
            }
        }

        Ok(gradients)
    }

    /// Learning-without-Forgetting distillation is not yet implemented;
    /// gradients pass through unchanged.
    fn apply_lwf_regularization(&self, gradients: Array2<f32>) -> Result<Array2<f32>> {
        Ok(gradients)
    }

    fn update_parameters(&mut self, gradients: Array2<f32>) -> Result<()> {
        // Fixed learning rate for the simple delta-rule update below.
        let learning_rate = 0.01;

        // Grow the parameter matrix if the gradients cover more rows.
        if self.embeddings.nrows() < gradients.nrows() {
            let dimensions = self.config.base_config.dimensions;
            let new_rows = gradients.nrows();
            let mut random = Random::default();
            self.embeddings =
                Array2::from_shape_fn((new_rows, dimensions), |_| random.random::<f32>() * 0.1);
        }

        let rows_to_update = gradients.nrows().min(self.embeddings.nrows());
        let cols_to_update = gradients.ncols().min(self.embeddings.ncols());

        for i in 0..rows_to_update {
            for j in 0..cols_to_update {
                self.embeddings[[i, j]] += learning_rate * gradients[[i, j]];
            }
        }

        Ok(())
    }
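
    // Online importance estimate for SI: a damped accumulation of absolute
    // gradient magnitudes, omega <- damping * omega + xi * |g|.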
    fn update_synaptic_importance(
        &mut self,
        data: &Array1<f32>,
        target: &Array1<f32>,
    ) -> Result<()> {
        let xi = self.config.regularization_config.si_config.xi;
        let damping = self.config.regularization_config.si_config.damping;

        let gradients = self.compute_gradients(data, target)?;

        if self.synaptic_importance.is_empty() {
            self.synaptic_importance = Array2::zeros(gradients.dim());
        }

        let rows_to_update = gradients.nrows().min(self.synaptic_importance.nrows());
        let cols_to_update = gradients.ncols().min(self.synaptic_importance.ncols());

        for i in 0..rows_to_update {
            for j in 0..cols_to_update {
                self.synaptic_importance[[i, j]] =
                    damping * self.synaptic_importance[[i, j]] + xi * gradients[[i, j]].abs();
            }
        }

        Ok(())
    }
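
    // Forward pass through a single linear layer with tanh activation. Under
    // progressive adaptation the newest network column provides the weights;
    // otherwise the shared embedding matrix does.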
    fn forward_pass(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
        if self.embeddings.is_empty() {
            return Ok(Array1::zeros(input.len()));
        }

        let network = if matches!(
            self.config.architecture_config.adaptation_method,
            ArchitectureAdaptation::Progressive
        ) {
            &self.network_columns[self.network_columns.len() - 1]
        } else {
            &self.embeddings
        };

        let input_len = input.len().min(network.ncols());
        let output_len = network.nrows();
        let mut output = Array1::zeros(output_len);

        for i in 0..output_len {
            let mut sum = 0.0;
            for j in 0..input_len {
                sum += network[[i, j]] * input[j];
            }
            output[i] = sum.tanh();
        }

        Ok(output)
    }

    async fn experience_replay(&mut self) -> Result<()> {
        if self.episodic_memory.is_empty() {
            return Ok(());
        }

        let mut random = Random::default();
        let replay_batch_size = (self.config.replay_config.replay_ratio * 32.0) as usize;
        let batch_size = replay_batch_size.min(self.episodic_memory.len());

        for _ in 0..batch_size {
            let idx = random.random_range(0..self.episodic_memory.len());

            let (data, target) = {
                let entry = &self.episodic_memory[idx];
                (entry.data.clone(), entry.target.clone())
            };

            self.episodic_memory[idx].access_count += 1;

            let gradients = self.compute_gradients(&data, &target)?;
            let regularized_gradients = self.apply_regularization(gradients)?;
            self.update_parameters(regularized_gradients)?;
        }

        Ok(())
    }

    async fn generative_replay(&mut self) -> Result<()> {
        if let Some(generator) = self.generator.clone() {
            let replay_batch_size = (self.config.replay_config.replay_ratio * 32.0) as usize;
            let mut random = Random::default();

            for _ in 0..replay_batch_size {
                // Draw noise and map it through the generator to synthesize a
                // pseudo-example, then rehearse on it.
                let noise = Array1::from_shape_fn(generator.ncols(), |_| random.random::<f32>());
                let generated_data = generator.dot(&noise);
                let generated_target = generated_data.mapv(|x| x.tanh());

                let gradients = self.compute_gradients(&generated_data, &generated_target)?;
                let regularized_gradients = self.apply_regularization(gradients)?;
                self.update_parameters(regularized_gradients)?;
            }
        }

        Ok(())
    }
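
    // At a task boundary, snapshot EWC state: estimate the diagonal Fisher
    // information from squared gradients over the task's stored examples and
    // record the current parameters as that task's optimum.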
    fn compute_ewc_state(&mut self) -> Result<()> {
        if let Some(ref current_task) = self.current_task {
            let mut fisher_information = Array2::zeros(self.embeddings.dim());

            for entry in &self.episodic_memory {
                if entry.task_id == current_task.task_id {
                    let gradients = self.compute_gradients(&entry.data, &entry.target)?;

                    let rows_to_update = gradients.nrows().min(fisher_information.nrows());
                    let cols_to_update = gradients.ncols().min(fisher_information.ncols());

                    for i in 0..rows_to_update {
                        for j in 0..cols_to_update {
                            fisher_information[[i, j]] += gradients[[i, j]] * gradients[[i, j]];
                        }
                    }
                }
            }

            let task_examples = self
                .episodic_memory
                .iter()
                .filter(|entry| entry.task_id == current_task.task_id)
                .count() as f32;

            if task_examples > 0.0 {
                fisher_information /= task_examples;
            }

            let ewc_state = EWCState {
                fisher_information,
                optimal_parameters: self.embeddings.clone(),
                task_id: current_task.task_id.clone(),
                importance: 1.0,
            };

            self.ewc_states.push(ewc_state);
        }

        Ok(())
    }

    fn add_network_column(&mut self) -> Result<()> {
        let dimensions = self.config.base_config.dimensions;
        let mut random = Random::default();
        let new_column =
            Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1);
        self.network_columns.push(new_column);

        if self.network_columns.len() > 1 {
            let lateral_connection = Array2::from_shape_fn((dimensions, dimensions), |_| {
                random.random::<f32>()
                    * self
                        .config
                        .architecture_config
                        .progressive_config
                        .lateral_strength
            });
            self.lateral_connections.push(lateral_connection);
        }

        Ok(())
    }

    /// Derive a deterministic task embedding from the task id's bytes
    /// (a lightweight placeholder for a learned task representation).
    fn generate_task_embedding(&self, task_id: &str) -> Result<Array1<f32>> {
        let dimensions = self.config.base_config.dimensions;
        let mut task_embedding = Array1::zeros(dimensions);

        for (i, byte) in task_id.bytes().enumerate() {
            if i >= dimensions {
                break;
            }
            task_embedding[i] = (byte as f32) / 255.0;
        }

        Ok(task_embedding)
    }

    /// Strengthen frequently accessed memories and replay a small sample
    /// with down-weighted gradients ("weak rehearsal").
    fn consolidate_memory(&mut self) -> Result<()> {
        if !self.config.memory_config.consolidation.enabled {
            return Ok(());
        }

        let mut random = Random::default();
        let strength = self.config.memory_config.consolidation.strength;

        // Boost importance in proportion to how often an entry was accessed.
        for entry in &mut self.episodic_memory {
            entry.importance *= 1.0 + strength * entry.access_count as f32;
        }

        let consolidation_steps = 100;
        for _ in 0..consolidation_steps {
            if !self.episodic_memory.is_empty() {
                let idx = random.random_range(0..self.episodic_memory.len());
                let entry = &self.episodic_memory[idx];

                let weak_gradients = self.compute_gradients(&entry.data, &entry.target)? * 0.1;
                self.update_parameters(weak_gradients)?;
            }
        }

        Ok(())
    }

    pub fn get_task_performance(&self) -> HashMap<String, f32> {
        let mut performance = HashMap::new();

        for task in &self.task_history {
            performance.insert(task.task_id.clone(), task.performance);
        }

        if let Some(ref current_task) = self.current_task {
            performance.insert(current_task.task_id.clone(), current_task.performance);
        }

        performance
    }

    /// Average drop from each task's recorded performance to its current
    /// (re-)evaluated performance; 0.0 when fewer than two tasks exist.
    pub fn evaluate_forgetting(&self) -> f32 {
        if self.task_history.len() < 2 {
            return 0.0;
        }

        let mut total_forgetting = 0.0;
        let mut task_count = 0;

        for (i, task) in self.task_history.iter().enumerate() {
            if i > 0 {
                let initial_performance = task.performance;
                let current_performance = self.evaluate_task_performance(&task.task_id);
                let forgetting = initial_performance - current_performance;
                total_forgetting += forgetting;
                task_count += 1;
            }
        }

        if task_count > 0 {
            total_forgetting / task_count as f32
        } else {
            0.0
        }
    }

    /// Placeholder evaluation: returns a random score in [0.8, 0.9) until a
    /// real per-task evaluation is wired in.
    fn evaluate_task_performance(&self, _task_id: &str) -> f32 {
        let mut random = Random::default();
        random.random::<f32>() * 0.1 + 0.8
    }

    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        let min_len = a.len().min(b.len());
        let mut sum = 0.0;

        for i in 0..min_len {
            let diff = a[i] - b[i];
            sum += diff * diff;
        }

        sum.sqrt()
    }
}

#[async_trait]
impl EmbeddingModel for ContinualLearningModel {
    fn config(&self) -> &ModelConfig {
        &self.config.base_config
    }

    fn model_id(&self) -> &Uuid {
        &self.model_id
    }

    fn model_type(&self) -> &'static str {
        "ContinualLearningModel"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        let subject_str = triple.subject.iri.clone();
        let predicate_str = triple.predicate.iri.clone();
        let object_str = triple.object.iri.clone();

        // Assign the next sequential id to each previously unseen term.
        let next_entity_id = self.entities.len();
        self.entities.entry(subject_str).or_insert(next_entity_id);
        let next_entity_id = self.entities.len();
        self.entities.entry(object_str).or_insert(next_entity_id);

        let next_relation_id = self.relations.len();
        self.relations
            .entry(predicate_str)
            .or_insert(next_relation_id);

        Ok(())
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
        let start_time = std::time::Instant::now();

        let mut loss_history = Vec::new();

        for epoch in 0..epochs {
            // Placeholder loss until a full training objective is plugged in.
            let mut random = Random::default();
            let epoch_loss = 0.1 * random.random::<f64>();
            loss_history.push(epoch_loss);

            // Simulate a task boundary every five epochs.
            if epoch % 5 == 0 && epoch > 0 {
                let task_num = epoch / 5;
                let task_id = format!("task_{task_num}");
                self.start_task(task_id, "training".to_string())?;
            }

            if epoch > 10 && epoch_loss < 1e-6 {
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        let final_loss = loss_history.last().copied().unwrap_or(0.0);

        let stats = TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss,
            training_time_seconds: training_time,
            convergence_achieved: final_loss < 1e-4,
            loss_history,
        };

        self.training_stats = Some(stats.clone());
        self.is_trained = true;

        Ok(stats)
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if let Some(&entity_id) = self.entities.get(entity) {
            if entity_id < self.embeddings.nrows() {
                let embedding = self.embeddings.row(entity_id);
                return Ok(Vector::new(embedding.to_vec()));
            }
        }
        Err(anyhow!("Entity not found: {}", entity))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if let Some(&relation_id) = self.relations.get(relation) {
            if relation_id < self.embeddings.nrows() {
                let embedding = self.embeddings.row(relation_id);
                return Ok(Vector::new(embedding.to_vec()));
            }
        }
        Err(anyhow!("Relation not found: {}", relation))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_emb = self.get_entity_embedding(subject)?;
        let predicate_emb = self.get_relation_embedding(predicate)?;
        let object_emb = self.get_entity_embedding(object)?;

        let subject_arr = Array1::from_vec(subject_emb.values);
        let predicate_arr = Array1::from_vec(predicate_emb.values);
        let object_arr = Array1::from_vec(object_emb.values);

        // TransE-style score: the negative distance between (subject +
        // predicate) and object, so closer triples score higher.
        let predicted = &subject_arr + &predicate_arr;
        let diff = &predicted - &object_arr;
        let distance = diff.dot(&diff).sqrt();

        Ok(-distance as f64)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for entity in self.entities.keys() {
            if entity != subject {
                let score = self.score_triple(subject, predicate, entity)?;
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for entity in self.entities.keys() {
            if entity != object {
                let score = self.score_triple(entity, predicate, object)?;
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for relation in self.relations.keys() {
            let score = self.score_triple(subject, relation, object)?;
            scores.push((relation.clone(), score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.entities.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        self.relations.keys().cloned().collect()
    }

    fn get_stats(&self) -> crate::ModelStats {
        crate::ModelStats {
            num_entities: self.entities.len(),
            num_relations: self.relations.len(),
            num_triples: 0,
            dimensions: self.config.base_config.dimensions,
            is_trained: self.is_trained,
            model_type: self.model_type().to_string(),
            creation_time: Utc::now(),
            last_training_time: if self.is_trained {
                Some(Utc::now())
            } else {
                None
            },
        }
    }

    /// Persistence is not yet implemented; `save` and `load` are no-ops.
    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn load(&mut self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn clear(&mut self) {
        self.entities.clear();
        self.relations.clear();
        self.embeddings = Array2::zeros((0, self.config.base_config.dimensions));
        self.episodic_memory.clear();
        self.semantic_memory.clear();
        self.ewc_states.clear();
        self.task_history.clear();
        self.current_task = None;
        self.examples_seen = 0;
        self.is_trained = false;
        self.training_stats = None;
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::new();

        // Simple character-based placeholder encoding: one normalized byte
        // per dimension.
        for text in texts {
            let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
            for (i, c) in text.chars().enumerate() {
                if i >= self.config.base_config.dimensions {
                    break;
                }
                embedding[i] = (c as u8 as f32) / 255.0;
            }
            results.push(embedding);
        }

        Ok(results)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continual_learning_config_default() {
        let config = ContinualLearningConfig::default();
        assert!(matches!(
            config.memory_config.memory_type,
            MemoryType::EpisodicMemory
        ));
        assert_eq!(config.memory_config.memory_capacity, 10000);
    }

    #[test]
    fn test_task_info_creation() {
        let task = TaskInfo::new("task1".to_string(), "classification".to_string());
        assert_eq!(task.task_id, "task1");
        assert_eq!(task.task_type, "classification");
        assert_eq!(task.examples_seen, 0);
    }

    #[test]
    fn test_memory_entry_creation() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let target = Array1::from_vec(vec![0.0, 1.0]);
        let entry = MemoryEntry::new(data, target, "task1".to_string());

        assert_eq!(entry.task_id, "task1");
        assert_eq!(entry.importance, 1.0);
        assert_eq!(entry.access_count, 0);
    }

    #[test]
    fn test_continual_learning_model_creation() {
        let config = ContinualLearningConfig::default();
        let model = ContinualLearningModel::new(config);

        assert_eq!(model.entities.len(), 0);
        assert_eq!(model.examples_seen, 0);
        assert!(model.current_task.is_none());
    }

    #[tokio::test]
    async fn test_task_management() {
        let config = ContinualLearningConfig::default();
        let mut model = ContinualLearningModel::new(config);

        model
            .start_task("task1".to_string(), "test".to_string())
            .unwrap();
        assert!(model.current_task.is_some());
        assert_eq!(model.current_task.as_ref().unwrap().task_id, "task1");

        model
            .start_task("task2".to_string(), "test".to_string())
            .unwrap();
        assert_eq!(model.task_history.len(), 1);
        assert_eq!(model.current_task.as_ref().unwrap().task_id, "task2");
    }

    #[tokio::test]
    async fn test_add_example() {
        let config = ContinualLearningConfig {
            base_config: ModelConfig {
                dimensions: 3,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = ContinualLearningModel::new(config);

        model
            .start_task("task1".to_string(), "test".to_string())
            .unwrap();

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let target = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        model
            .add_example(data, target, Some("task1".to_string()))
            .await
            .unwrap();

        assert_eq!(model.examples_seen, 1);
        assert_eq!(model.episodic_memory.len(), 1);
        assert_eq!(model.current_task.as_ref().unwrap().examples_seen, 1);
    }

    #[tokio::test]
    async fn test_memory_management() {
        let config = ContinualLearningConfig {
            memory_config: MemoryConfig {
                memory_capacity: 3,
                update_strategy: MemoryUpdateStrategy::FIFO,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut model = ContinualLearningModel::new(config);
        model
            .start_task("task1".to_string(), "test".to_string())
            .unwrap();

        for i in 0..5 {
            let data = Array1::from_vec(vec![i as f32]);
            let target = Array1::from_vec(vec![i as f32]);
            model
                .add_example(data, target, Some("task1".to_string()))
                .await
                .unwrap();
        }

        // FIFO eviction keeps the buffer at its configured capacity.
        assert_eq!(model.episodic_memory.len(), 3);
    }

    #[tokio::test]
    async fn test_continual_training() {
        let config = ContinualLearningConfig {
            base_config: ModelConfig {
                dimensions: 3,
                max_epochs: 10,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = ContinualLearningModel::new(config);

        model
            .start_task("initial_task".to_string(), "training".to_string())
            .unwrap();

        let stats = model.train(Some(10)).await.unwrap();
        assert_eq!(stats.epochs_completed, 10);
        assert!(model.is_trained());
        // Training spawns a new task every five epochs, so history is non-empty.
        assert!(!model.task_history.is_empty());
    }

    #[test]
    fn test_forgetting_evaluation() {
        let config = ContinualLearningConfig::default();
        let model = ContinualLearningModel::new(config);

        let forgetting = model.evaluate_forgetting();
        // No forgetting can be measured with fewer than two tasks.
        assert_eq!(forgetting, 0.0);
    }

    #[test]
    fn test_ewc_state_creation() {
        let mut random = Random::default();
        let fisher = Array2::from_shape_fn((5, 5), |_| random.random::<f32>());
        let params = Array2::from_shape_fn((5, 5), |_| random.random::<f32>());

        let ewc_state = EWCState {
            fisher_information: fisher,
            optimal_parameters: params,
            task_id: "task1".to_string(),
            importance: 1.0,
        };

        assert_eq!(ewc_state.task_id, "task1");
        assert_eq!(ewc_state.importance, 1.0);
    }
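
    // A sketch of an end-to-end flow (not part of the original test suite):
    // feed a few examples on one task, switch tasks, and check that boundary
    // bookkeeping holds together. Assumes the default configuration plus a
    // small embedding dimension.
    #[tokio::test]
    async fn test_two_task_flow_sketch() {
        let config = ContinualLearningConfig {
            base_config: ModelConfig {
                dimensions: 3,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = ContinualLearningModel::new(config);

        model
            .start_task("task1".to_string(), "test".to_string())
            .unwrap();
        for i in 0..3 {
            let data = Array1::from_vec(vec![i as f32, 1.0, 0.0]);
            let target = Array1::from_vec(vec![0.0, 1.0, i as f32]);
            model
                .add_example(data, target, Some("task1".to_string()))
                .await
                .unwrap();
        }

        // Switching tasks archives task1 and records its boundary.
        model
            .start_task("task2".to_string(), "test".to_string())
            .unwrap();
        assert_eq!(model.task_history.len(), 1);
        assert_eq!(model.task_boundaries, vec![3]);
        assert_eq!(model.examples_seen, 3);
    }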
}