Skip to main content

oxirs_embed/
continual_learning_types.rs

1use crate::{ModelConfig, TrainingStats};
2use chrono::{DateTime, Utc};
3use scirs2_core::ndarray_ext::{Array1, Array2};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct ContinualLearningConfig {
10    pub base_config: ModelConfig,
11    pub memory_config: MemoryConfig,
12    pub regularization_config: RegularizationConfig,
13    pub architecture_config: ArchitectureConfig,
14    pub task_config: TaskConfig,
15    pub replay_config: ReplayConfig,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct MemoryConfig {
20    pub memory_type: MemoryType,
21    pub memory_capacity: usize,
22    pub update_strategy: MemoryUpdateStrategy,
23    pub consolidation: ConsolidationConfig,
24}
25
26impl Default for MemoryConfig {
27    fn default() -> Self {
28        Self {
29            memory_type: MemoryType::EpisodicMemory,
30            memory_capacity: 10000,
31            update_strategy: MemoryUpdateStrategy::ReservoirSampling,
32            consolidation: ConsolidationConfig::default(),
33        }
34    }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub enum MemoryType {
39    EpisodicMemory,
40    SemanticMemory,
41    WorkingMemory,
42    ProceduralMemory,
43    HybridMemory,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum MemoryUpdateStrategy {
48    FIFO,
49    Random,
50    ReservoirSampling,
51    ImportanceBased,
52    GradientBased,
53    ClusteringBased,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ConsolidationConfig {
58    pub enabled: bool,
59    pub frequency: usize,
60    pub strength: f32,
61    pub sleep_consolidation: bool,
62}
63
64impl Default for ConsolidationConfig {
65    fn default() -> Self {
66        Self {
67            enabled: true,
68            frequency: 1000,
69            strength: 0.1,
70            sleep_consolidation: false,
71        }
72    }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RegularizationConfig {
77    pub methods: Vec<RegularizationMethod>,
78    pub ewc_config: EWCConfig,
79    pub si_config: SynapticIntelligenceConfig,
80    pub lwf_config: LwFConfig,
81}
82
83impl Default for RegularizationConfig {
84    fn default() -> Self {
85        Self {
86            methods: vec![
87                RegularizationMethod::EWC,
88                RegularizationMethod::SynapticIntelligence,
89            ],
90            ewc_config: EWCConfig::default(),
91            si_config: SynapticIntelligenceConfig::default(),
92            lwf_config: LwFConfig::default(),
93        }
94    }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
98pub enum RegularizationMethod {
99    EWC,
100    SynapticIntelligence,
101    LwF,
102    MAS,
103    RiemannianWalk,
104    PackNet,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct EWCConfig {
109    pub lambda: f32,
110    pub fisher_method: FisherMethod,
111    pub online: bool,
112    pub gamma: f32,
113}
114
115impl Default for EWCConfig {
116    fn default() -> Self {
117        Self {
118            lambda: 0.4,
119            fisher_method: FisherMethod::Empirical,
120            online: true,
121            gamma: 1.0,
122        }
123    }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub enum FisherMethod {
128    Empirical,
129    True,
130    Diagonal,
131    BlockDiagonal,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SynapticIntelligenceConfig {
136    pub c: f32,
137    pub xi: f32,
138    pub damping: f32,
139}
140
141impl Default for SynapticIntelligenceConfig {
142    fn default() -> Self {
143        Self {
144            c: 0.1,
145            xi: 1.0,
146            damping: 0.1,
147        }
148    }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct LwFConfig {
153    pub alpha: f32,
154    pub temperature: f32,
155    pub attention_transfer: bool,
156}
157
158impl Default for LwFConfig {
159    fn default() -> Self {
160        Self {
161            alpha: 1.0,
162            temperature: 4.0,
163            attention_transfer: false,
164        }
165    }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ArchitectureConfig {
170    pub adaptation_method: ArchitectureAdaptation,
171    pub progressive_config: ProgressiveConfig,
172    pub dynamic_config: DynamicConfig,
173}
174
175impl Default for ArchitectureConfig {
176    fn default() -> Self {
177        Self {
178            adaptation_method: ArchitectureAdaptation::Progressive,
179            progressive_config: ProgressiveConfig::default(),
180            dynamic_config: DynamicConfig::default(),
181        }
182    }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub enum ArchitectureAdaptation {
187    Progressive,
188    Dynamic,
189    PackNet,
190    HAT,
191    Supermasks,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct ProgressiveConfig {
196    pub columns_per_task: usize,
197    pub lateral_strength: f32,
198    pub column_capacity: usize,
199}
200
201impl Default for ProgressiveConfig {
202    fn default() -> Self {
203        Self {
204            columns_per_task: 1,
205            lateral_strength: 0.5,
206            column_capacity: 1000,
207        }
208    }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct DynamicConfig {
213    pub expansion_threshold: f32,
214    pub pruning_threshold: f32,
215    pub growth_rate: f32,
216    pub max_size: usize,
217}
218
219impl Default for DynamicConfig {
220    fn default() -> Self {
221        Self {
222            expansion_threshold: 0.9,
223            pruning_threshold: 0.1,
224            growth_rate: 0.1,
225            max_size: 100000,
226        }
227    }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct TaskConfig {
232    pub detection_method: TaskDetection,
233    pub boundary_detection: BoundaryDetection,
234    pub switching_strategy: TaskSwitching,
235}
236
237impl Default for TaskConfig {
238    fn default() -> Self {
239        Self {
240            detection_method: TaskDetection::Automatic,
241            boundary_detection: BoundaryDetection::ChangePoint,
242            switching_strategy: TaskSwitching::Soft,
243        }
244    }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub enum TaskDetection {
249    Manual,
250    Automatic,
251    Oracle,
252    Clustering,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub enum BoundaryDetection {
257    ChangePoint,
258    DistributionShift,
259    LossBased,
260    GradientBased,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub enum TaskSwitching {
265    Hard,
266    Soft,
267    Attention,
268    Gating,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ReplayConfig {
273    pub methods: Vec<ReplayMethod>,
274    pub buffer_size: usize,
275    pub replay_ratio: f32,
276    pub generative_config: GenerativeReplayConfig,
277}
278
279impl Default for ReplayConfig {
280    fn default() -> Self {
281        Self {
282            methods: vec![
283                ReplayMethod::ExperienceReplay,
284                ReplayMethod::GenerativeReplay,
285            ],
286            buffer_size: 5000,
287            replay_ratio: 0.5,
288            generative_config: GenerativeReplayConfig::default(),
289        }
290    }
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
294pub enum ReplayMethod {
295    ExperienceReplay,
296    GenerativeReplay,
297    PseudoRehearsal,
298    MetaReplay,
299    GradientEpisodicMemory,
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct GenerativeReplayConfig {
304    pub generator_type: GeneratorType,
305    pub quality_threshold: f32,
306    pub diversity_weight: f32,
307}
308
309impl Default for GenerativeReplayConfig {
310    fn default() -> Self {
311        Self {
312            generator_type: GeneratorType::VAE,
313            quality_threshold: 0.8,
314            diversity_weight: 0.1,
315        }
316    }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub enum GeneratorType {
321    VAE,
322    GAN,
323    Flow,
324    Diffusion,
325}
326
327#[derive(Debug, Clone)]
328pub struct TaskInfo {
329    pub task_id: String,
330    pub task_type: String,
331    pub start_time: DateTime<Utc>,
332    pub end_time: Option<DateTime<Utc>>,
333    pub examples_seen: usize,
334    pub performance: f32,
335    pub task_embedding: Option<Array1<f32>>,
336}
337
338impl TaskInfo {
339    pub fn new(task_id: String, task_type: String) -> Self {
340        Self {
341            task_id,
342            task_type,
343            start_time: Utc::now(),
344            end_time: None,
345            examples_seen: 0,
346            performance: 0.0,
347            task_embedding: None,
348        }
349    }
350}
351
352#[derive(Debug, Clone)]
353pub struct MemoryEntry {
354    pub data: Array1<f32>,
355    pub target: Array1<f32>,
356    pub task_id: String,
357    pub timestamp: DateTime<Utc>,
358    pub importance: f32,
359    pub access_count: usize,
360}
361
362impl MemoryEntry {
363    pub fn new(data: Array1<f32>, target: Array1<f32>, task_id: String) -> Self {
364        Self {
365            data,
366            target,
367            task_id,
368            timestamp: Utc::now(),
369            importance: 1.0,
370            access_count: 0,
371        }
372    }
373}
374
375#[derive(Debug, Clone)]
376pub struct EWCState {
377    pub fisher_information: Array2<f32>,
378    pub optimal_parameters: Array2<f32>,
379    pub task_id: String,
380    pub importance: f32,
381}
382
383#[derive(Debug)]
384pub struct ContinualLearningModel {
385    pub config: ContinualLearningConfig,
386    pub model_id: Uuid,
387    pub embeddings: Array2<f32>,
388    pub task_specific_embeddings: HashMap<String, Array2<f32>>,
389    pub episodic_memory: VecDeque<MemoryEntry>,
390    pub semantic_memory: HashMap<String, Array1<f32>>,
391    pub ewc_states: Vec<EWCState>,
392    pub synaptic_importance: Array2<f32>,
393    pub parameter_trajectory: Array2<f32>,
394    pub current_task: Option<TaskInfo>,
395    pub task_history: Vec<TaskInfo>,
396    pub task_boundaries: Vec<usize>,
397    pub network_columns: Vec<Array2<f32>>,
398    pub lateral_connections: Vec<Array2<f32>>,
399    pub generator: Option<Array2<f32>>,
400    pub discriminator: Option<Array2<f32>>,
401    pub entities: HashMap<String, usize>,
402    pub relations: HashMap<String, usize>,
403    pub examples_seen: usize,
404    pub training_stats: Option<TrainingStats>,
405    pub is_trained: bool,
406}