1use crate::{ModelConfig, TrainingStats};
2use chrono::{DateTime, Utc};
3use scirs2_core::ndarray_ext::{Array1, Array2};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct ContinualLearningConfig {
10 pub base_config: ModelConfig,
11 pub memory_config: MemoryConfig,
12 pub regularization_config: RegularizationConfig,
13 pub architecture_config: ArchitectureConfig,
14 pub task_config: TaskConfig,
15 pub replay_config: ReplayConfig,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct MemoryConfig {
20 pub memory_type: MemoryType,
21 pub memory_capacity: usize,
22 pub update_strategy: MemoryUpdateStrategy,
23 pub consolidation: ConsolidationConfig,
24}
25
26impl Default for MemoryConfig {
27 fn default() -> Self {
28 Self {
29 memory_type: MemoryType::EpisodicMemory,
30 memory_capacity: 10000,
31 update_strategy: MemoryUpdateStrategy::ReservoirSampling,
32 consolidation: ConsolidationConfig::default(),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub enum MemoryType {
39 EpisodicMemory,
40 SemanticMemory,
41 WorkingMemory,
42 ProceduralMemory,
43 HybridMemory,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum MemoryUpdateStrategy {
48 FIFO,
49 Random,
50 ReservoirSampling,
51 ImportanceBased,
52 GradientBased,
53 ClusteringBased,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ConsolidationConfig {
58 pub enabled: bool,
59 pub frequency: usize,
60 pub strength: f32,
61 pub sleep_consolidation: bool,
62}
63
64impl Default for ConsolidationConfig {
65 fn default() -> Self {
66 Self {
67 enabled: true,
68 frequency: 1000,
69 strength: 0.1,
70 sleep_consolidation: false,
71 }
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RegularizationConfig {
77 pub methods: Vec<RegularizationMethod>,
78 pub ewc_config: EWCConfig,
79 pub si_config: SynapticIntelligenceConfig,
80 pub lwf_config: LwFConfig,
81}
82
83impl Default for RegularizationConfig {
84 fn default() -> Self {
85 Self {
86 methods: vec![
87 RegularizationMethod::EWC,
88 RegularizationMethod::SynapticIntelligence,
89 ],
90 ewc_config: EWCConfig::default(),
91 si_config: SynapticIntelligenceConfig::default(),
92 lwf_config: LwFConfig::default(),
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
98pub enum RegularizationMethod {
99 EWC,
100 SynapticIntelligence,
101 LwF,
102 MAS,
103 RiemannianWalk,
104 PackNet,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct EWCConfig {
109 pub lambda: f32,
110 pub fisher_method: FisherMethod,
111 pub online: bool,
112 pub gamma: f32,
113}
114
115impl Default for EWCConfig {
116 fn default() -> Self {
117 Self {
118 lambda: 0.4,
119 fisher_method: FisherMethod::Empirical,
120 online: true,
121 gamma: 1.0,
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub enum FisherMethod {
128 Empirical,
129 True,
130 Diagonal,
131 BlockDiagonal,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SynapticIntelligenceConfig {
136 pub c: f32,
137 pub xi: f32,
138 pub damping: f32,
139}
140
141impl Default for SynapticIntelligenceConfig {
142 fn default() -> Self {
143 Self {
144 c: 0.1,
145 xi: 1.0,
146 damping: 0.1,
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct LwFConfig {
153 pub alpha: f32,
154 pub temperature: f32,
155 pub attention_transfer: bool,
156}
157
158impl Default for LwFConfig {
159 fn default() -> Self {
160 Self {
161 alpha: 1.0,
162 temperature: 4.0,
163 attention_transfer: false,
164 }
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ArchitectureConfig {
170 pub adaptation_method: ArchitectureAdaptation,
171 pub progressive_config: ProgressiveConfig,
172 pub dynamic_config: DynamicConfig,
173}
174
175impl Default for ArchitectureConfig {
176 fn default() -> Self {
177 Self {
178 adaptation_method: ArchitectureAdaptation::Progressive,
179 progressive_config: ProgressiveConfig::default(),
180 dynamic_config: DynamicConfig::default(),
181 }
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub enum ArchitectureAdaptation {
187 Progressive,
188 Dynamic,
189 PackNet,
190 HAT,
191 Supermasks,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct ProgressiveConfig {
196 pub columns_per_task: usize,
197 pub lateral_strength: f32,
198 pub column_capacity: usize,
199}
200
201impl Default for ProgressiveConfig {
202 fn default() -> Self {
203 Self {
204 columns_per_task: 1,
205 lateral_strength: 0.5,
206 column_capacity: 1000,
207 }
208 }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct DynamicConfig {
213 pub expansion_threshold: f32,
214 pub pruning_threshold: f32,
215 pub growth_rate: f32,
216 pub max_size: usize,
217}
218
219impl Default for DynamicConfig {
220 fn default() -> Self {
221 Self {
222 expansion_threshold: 0.9,
223 pruning_threshold: 0.1,
224 growth_rate: 0.1,
225 max_size: 100000,
226 }
227 }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct TaskConfig {
232 pub detection_method: TaskDetection,
233 pub boundary_detection: BoundaryDetection,
234 pub switching_strategy: TaskSwitching,
235}
236
237impl Default for TaskConfig {
238 fn default() -> Self {
239 Self {
240 detection_method: TaskDetection::Automatic,
241 boundary_detection: BoundaryDetection::ChangePoint,
242 switching_strategy: TaskSwitching::Soft,
243 }
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub enum TaskDetection {
249 Manual,
250 Automatic,
251 Oracle,
252 Clustering,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub enum BoundaryDetection {
257 ChangePoint,
258 DistributionShift,
259 LossBased,
260 GradientBased,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub enum TaskSwitching {
265 Hard,
266 Soft,
267 Attention,
268 Gating,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ReplayConfig {
273 pub methods: Vec<ReplayMethod>,
274 pub buffer_size: usize,
275 pub replay_ratio: f32,
276 pub generative_config: GenerativeReplayConfig,
277}
278
279impl Default for ReplayConfig {
280 fn default() -> Self {
281 Self {
282 methods: vec![
283 ReplayMethod::ExperienceReplay,
284 ReplayMethod::GenerativeReplay,
285 ],
286 buffer_size: 5000,
287 replay_ratio: 0.5,
288 generative_config: GenerativeReplayConfig::default(),
289 }
290 }
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
294pub enum ReplayMethod {
295 ExperienceReplay,
296 GenerativeReplay,
297 PseudoRehearsal,
298 MetaReplay,
299 GradientEpisodicMemory,
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct GenerativeReplayConfig {
304 pub generator_type: GeneratorType,
305 pub quality_threshold: f32,
306 pub diversity_weight: f32,
307}
308
309impl Default for GenerativeReplayConfig {
310 fn default() -> Self {
311 Self {
312 generator_type: GeneratorType::VAE,
313 quality_threshold: 0.8,
314 diversity_weight: 0.1,
315 }
316 }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub enum GeneratorType {
321 VAE,
322 GAN,
323 Flow,
324 Diffusion,
325}
326
327#[derive(Debug, Clone)]
328pub struct TaskInfo {
329 pub task_id: String,
330 pub task_type: String,
331 pub start_time: DateTime<Utc>,
332 pub end_time: Option<DateTime<Utc>>,
333 pub examples_seen: usize,
334 pub performance: f32,
335 pub task_embedding: Option<Array1<f32>>,
336}
337
338impl TaskInfo {
339 pub fn new(task_id: String, task_type: String) -> Self {
340 Self {
341 task_id,
342 task_type,
343 start_time: Utc::now(),
344 end_time: None,
345 examples_seen: 0,
346 performance: 0.0,
347 task_embedding: None,
348 }
349 }
350}
351
352#[derive(Debug, Clone)]
353pub struct MemoryEntry {
354 pub data: Array1<f32>,
355 pub target: Array1<f32>,
356 pub task_id: String,
357 pub timestamp: DateTime<Utc>,
358 pub importance: f32,
359 pub access_count: usize,
360}
361
362impl MemoryEntry {
363 pub fn new(data: Array1<f32>, target: Array1<f32>, task_id: String) -> Self {
364 Self {
365 data,
366 target,
367 task_id,
368 timestamp: Utc::now(),
369 importance: 1.0,
370 access_count: 0,
371 }
372 }
373}
374
/// Per-task snapshot kept for the EWC penalty.
///
/// NOTE(review): the pairing of the two matrices below is inferred from their
/// names — confirm against the code that builds and consumes `EWCState`.
#[derive(Debug, Clone)]
pub struct EWCState {
    /// 2-D array; presumably the Fisher-information estimate for the task.
    pub fisher_information: Array2<f32>,
    /// 2-D array; presumably the parameter values saved at the end of the task.
    pub optimal_parameters: Array2<f32>,
    /// Identifier of the task this snapshot belongs to.
    pub task_id: String,
    /// Scalar weight for this snapshot; how it is applied is not visible here.
    pub importance: f32,
}
382
/// State container for a continual-learning model.
///
/// Holds the configuration, embedding matrices, memory stores,
/// regularization state, task bookkeeping, architecture components and
/// training status. Only `Debug` is derived (no `Clone`/`Default`).
#[derive(Debug)]
pub struct ContinualLearningModel {
    /// Configuration the model was built with.
    pub config: ContinualLearningConfig,
    /// Unique identifier of this model instance.
    pub model_id: Uuid,
    /// Shared embedding matrix. NOTE(review): row/column meaning not visible here — confirm.
    pub embeddings: Array2<f32>,
    /// Per-task embedding matrices, keyed by a task id string (presumably `TaskInfo::task_id`).
    pub task_specific_embeddings: HashMap<String, Array2<f32>>,
    /// Queue of stored [`MemoryEntry`] samples.
    pub episodic_memory: VecDeque<MemoryEntry>,
    /// Vectors in semantic memory, keyed by string; key meaning not defined here.
    pub semantic_memory: HashMap<String, Array1<f32>>,
    /// Stored [`EWCState`] snapshots (presumably one per completed task).
    pub ewc_states: Vec<EWCState>,
    /// Presumably per-parameter importance for Synaptic Intelligence — TODO confirm shape.
    pub synaptic_importance: Array2<f32>,
    /// Presumably a record of parameter values during training — layout not visible here.
    pub parameter_trajectory: Array2<f32>,
    /// Active task record, if any.
    pub current_task: Option<TaskInfo>,
    /// Records of earlier tasks.
    pub task_history: Vec<TaskInfo>,
    /// Task-boundary positions; presumably example counts — confirm.
    pub task_boundaries: Vec<usize>,
    /// Presumably the weight matrices of progressive-network columns.
    pub network_columns: Vec<Array2<f32>>,
    /// Presumably the lateral connections between those columns.
    pub lateral_connections: Vec<Array2<f32>>,
    /// Optional generator weights (presumably for generative replay).
    pub generator: Option<Array2<f32>>,
    /// Optional discriminator weights.
    pub discriminator: Option<Array2<f32>>,
    /// Presumably entity name → index mapping.
    pub entities: HashMap<String, usize>,
    /// Presumably relation name → index mapping.
    pub relations: HashMap<String, usize>,
    /// Model-wide example counter.
    pub examples_seen: usize,
    /// Training statistics, if any have been recorded.
    pub training_stats: Option<TrainingStats>,
    /// Training-status flag; presumably set once training has run.
    pub is_trained: bool,
}