1use crate::continual_learning_types::{
2 ArchitectureAdaptation, ContinualLearningConfig, ContinualLearningModel, TaskInfo,
3};
4use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use scirs2_core::random::{Random, RngExt};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13impl ContinualLearningModel {
14 pub fn new(config: ContinualLearningConfig) -> Self {
15 let mut _random = Random::default();
16 let model_id = Uuid::new_v4();
17 let dimensions = config.base_config.dimensions;
18
19 Self {
20 config: config.clone(),
21 model_id,
22 embeddings: Array2::zeros((0, dimensions)),
23 task_specific_embeddings: HashMap::new(),
24 episodic_memory: std::collections::VecDeque::with_capacity(
25 config.memory_config.memory_capacity,
26 ),
27 semantic_memory: HashMap::new(),
28 ewc_states: Vec::new(),
29 synaptic_importance: Array2::zeros((0, dimensions)),
30 parameter_trajectory: Array2::zeros((0, dimensions)),
31 current_task: None,
32 task_history: Vec::new(),
33 task_boundaries: Vec::new(),
34 network_columns: {
35 let mut random = Random::default();
36 vec![Array2::from_shape_fn((dimensions, dimensions), |_| {
37 random.random::<f64>() as f32 * 0.1
38 })]
39 },
40 lateral_connections: Vec::new(),
41 generator: Some({
42 let mut random = Random::default();
43 Array2::from_shape_fn((dimensions, dimensions), |_| {
44 random.random::<f64>() as f32 * 0.1
45 })
46 }),
47 discriminator: Some({
48 let mut random = Random::default();
49 Array2::from_shape_fn((dimensions, dimensions), |_| {
50 random.random::<f64>() as f32 * 0.1
51 })
52 }),
53 entities: HashMap::new(),
54 relations: HashMap::new(),
55 examples_seen: 0,
56 training_stats: None,
57 is_trained: false,
58 }
59 }
60
61 pub fn start_task(&mut self, task_id: String, task_type: String) -> Result<()> {
62 if let Some(ref mut current_task) = self.current_task {
63 current_task.end_time = Some(Utc::now());
64 self.task_history.push(current_task.clone());
65 self.task_boundaries.push(self.examples_seen);
66 }
67
68 if self.config.memory_config.consolidation.enabled {
69 self.consolidate_memory()?;
70 }
71
72 if self.should_use_ewc() {
73 self.compute_ewc_state()?;
74 }
75
76 if self.is_progressive() {
77 self.add_network_column()?;
78 }
79
80 let mut new_task = TaskInfo::new(task_id.clone(), task_type);
81 new_task.task_embedding = Some(self.generate_task_embedding(&task_id)?);
82 self.current_task = Some(new_task);
83
84 Ok(())
85 }
86
87 pub async fn add_example(
88 &mut self,
89 data: Array1<f32>,
90 target: Array1<f32>,
91 task_id: Option<String>,
92 ) -> Result<()> {
93 let task_id = task_id.unwrap_or_else(|| {
94 self.current_task
95 .as_ref()
96 .map(|t| t.task_id.clone())
97 .unwrap_or_else(|| "default".to_string())
98 });
99
100 if self.detect_is_automatic() && self.detect_task_boundary(&data)? {
101 let task_num = self.task_history.len() + 1;
102 let new_task_id = format!("task_{task_num}");
103 self.start_task(new_task_id.clone(), "automatic".to_string())?;
104 }
105
106 if self.embeddings.nrows() == 0 {
107 let input_dim = data.len();
108 let output_dim = target.len();
109 self.embeddings = Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
110 let mut random = Random::default();
111 (random.random::<f64>() as f32 - 0.5) * 0.1
112 });
113 self.synaptic_importance = Array2::zeros((output_dim, input_dim));
114 self.parameter_trajectory = Array2::zeros((output_dim, input_dim));
115 }
116
117 self.add_to_memory(data.clone(), target.clone(), task_id.clone())?;
118
119 if let Some(ref mut current_task) = self.current_task {
120 current_task.examples_seen += 1;
121 }
122
123 self.examples_seen += 1;
124
125 self.continual_update(data, target, task_id).await?;
126
127 Ok(())
128 }
129
130 async fn continual_update(
131 &mut self,
132 data: Array1<f32>,
133 target: Array1<f32>,
134 _task_id: String,
135 ) -> Result<()> {
136 let gradients = self.compute_gradients(&data, &target)?;
137 let regularized_gradients = self.apply_regularization(gradients)?;
138 self.update_parameters(regularized_gradients)?;
139
140 if self.should_use_si() {
141 self.update_synaptic_importance(&data, &target)?;
142 }
143
144 if self.should_replay_experience() {
145 self.experience_replay().await?;
146 }
147
148 if self.should_replay_generative() {
149 self.generative_replay().await?;
150 }
151
152 Ok(())
153 }
154
155 pub(crate) fn compute_gradients(
156 &self,
157 data: &Array1<f32>,
158 target: &Array1<f32>,
159 ) -> Result<Array2<f32>> {
160 let dimensions = self.config.base_config.dimensions;
161 let mut gradients = Array2::zeros((1, dimensions));
162
163 if self.embeddings.nrows() == 0 {
164 return Ok(gradients);
165 }
166
167 let prediction = self.forward_pass(data)?;
168 let error = target - &prediction;
169
170 for i in 0..dimensions.min(data.len()) {
171 gradients[[0, i]] = error[i] * data[i];
172 }
173
174 Ok(gradients)
175 }
176
177 pub(crate) fn update_parameters(&mut self, gradients: Array2<f32>) -> Result<()> {
178 let learning_rate = 0.01;
179
180 if self.embeddings.nrows() < gradients.nrows() {
181 let dimensions = self.config.base_config.dimensions;
182 let new_rows = gradients.nrows();
183 let mut random = Random::default();
184 self.embeddings =
185 Array2::from_shape_fn((new_rows, dimensions), |_| random.random::<f32>() * 0.1);
186 }
187
188 let rows_to_update = gradients.nrows().min(self.embeddings.nrows());
189 let cols_to_update = gradients.ncols().min(self.embeddings.ncols());
190
191 for i in 0..rows_to_update {
192 for j in 0..cols_to_update {
193 self.embeddings[[i, j]] += learning_rate * gradients[[i, j]];
194 }
195 }
196
197 Ok(())
198 }
199
200 pub(crate) fn update_synaptic_importance(
201 &mut self,
202 data: &Array1<f32>,
203 target: &Array1<f32>,
204 ) -> Result<()> {
205 let xi = self.config.regularization_config.si_config.xi;
206 let damping = self.config.regularization_config.si_config.damping;
207
208 let gradients = self.compute_gradients(data, target)?;
209
210 if self.synaptic_importance.is_empty() {
211 self.synaptic_importance = Array2::zeros(gradients.dim());
212 }
213
214 let rows_to_update = gradients.nrows().min(self.synaptic_importance.nrows());
215 let cols_to_update = gradients.ncols().min(self.synaptic_importance.ncols());
216
217 for i in 0..rows_to_update {
218 for j in 0..cols_to_update {
219 self.synaptic_importance[[i, j]] =
220 damping * self.synaptic_importance[[i, j]] + xi * gradients[[i, j]].abs();
221 }
222 }
223
224 Ok(())
225 }
226
227 pub(crate) fn forward_pass(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
228 if self.embeddings.is_empty() {
229 return Ok(Array1::zeros(input.len()));
230 }
231
232 let network = if matches!(
233 self.config.architecture_config.adaptation_method,
234 ArchitectureAdaptation::Progressive
235 ) {
236 &self.network_columns[self.network_columns.len() - 1]
237 } else {
238 &self.embeddings
239 };
240
241 let input_len = input.len().min(network.ncols());
242 let output_len = network.nrows();
243 let mut output = Array1::zeros(output_len);
244
245 for i in 0..output_len {
246 let mut sum = 0.0;
247 for j in 0..input_len {
248 sum += network[[i, j]] * input[j];
249 }
250 output[i] = sum.tanh();
251 }
252
253 Ok(output)
254 }
255
256 pub(crate) fn generate_task_embedding(&self, task_id: &str) -> Result<Array1<f32>> {
257 let dimensions = self.config.base_config.dimensions;
258 let mut task_embedding = Array1::zeros(dimensions);
259
260 for (i, byte) in task_id.bytes().enumerate() {
261 if i >= dimensions {
262 break;
263 }
264 task_embedding[i] = (byte as f32) / 255.0;
265 }
266
267 Ok(task_embedding)
268 }
269
270 pub(crate) fn consolidate_memory(&mut self) -> Result<()> {
271 if !self.config.memory_config.consolidation.enabled {
272 return Ok(());
273 }
274
275 let mut random = Random::default();
276 let strength = self.config.memory_config.consolidation.strength;
277
278 for entry in &mut self.episodic_memory {
279 entry.importance *= 1.0 + strength * entry.access_count as f32;
280 }
281
282 let consolidation_steps = 100;
283 for _ in 0..consolidation_steps {
284 if !self.episodic_memory.is_empty() {
285 let idx = random.random_range(0..self.episodic_memory.len());
286 let entry = &self.episodic_memory[idx];
287
288 let weak_gradients = self.compute_gradients(&entry.data, &entry.target)? * 0.1;
289 self.update_parameters(weak_gradients)?;
290 }
291 }
292
293 Ok(())
294 }
295
296 pub fn get_task_performance(&self) -> HashMap<String, f32> {
297 let mut performance = HashMap::new();
298
299 for task in &self.task_history {
300 performance.insert(task.task_id.clone(), task.performance);
301 }
302
303 if let Some(ref current_task) = self.current_task {
304 performance.insert(current_task.task_id.clone(), current_task.performance);
305 }
306
307 performance
308 }
309
310 pub fn evaluate_forgetting(&self) -> f32 {
311 if self.task_history.len() < 2 {
312 return 0.0;
313 }
314
315 let mut total_forgetting = 0.0;
316 let mut task_count = 0;
317
318 for (i, task) in self.task_history.iter().enumerate() {
319 if i > 0 {
320 let initial_performance = task.performance;
321 let current_performance = self.evaluate_task_performance(&task.task_id);
322 let forgetting = initial_performance - current_performance;
323 total_forgetting += forgetting;
324 task_count += 1;
325 }
326 }
327
328 if task_count > 0 {
329 total_forgetting / task_count as f32
330 } else {
331 0.0
332 }
333 }
334
335 fn evaluate_task_performance(&self, _task_id: &str) -> f32 {
336 let mut random = Random::default();
337 random.random::<f32>() * 0.1 + 0.8
338 }
339
340 pub(crate) fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
341 let min_len = a.len().min(b.len());
342 let mut sum = 0.0;
343
344 for i in 0..min_len {
345 let diff = a[i] - b[i];
346 sum += diff * diff;
347 }
348
349 sum.sqrt()
350 }
351}
352
353#[async_trait]
354impl EmbeddingModel for ContinualLearningModel {
355 fn config(&self) -> &ModelConfig {
356 &self.config.base_config
357 }
358
359 fn model_id(&self) -> &Uuid {
360 &self.model_id
361 }
362
363 fn model_type(&self) -> &'static str {
364 "ContinualLearningModel"
365 }
366
367 fn add_triple(&mut self, triple: Triple) -> Result<()> {
368 let subject_str = triple.subject.iri.clone();
369 let predicate_str = triple.predicate.iri.clone();
370 let object_str = triple.object.iri.clone();
371
372 let next_entity_id = self.entities.len();
373 self.entities.entry(subject_str).or_insert(next_entity_id);
374 let next_entity_id = self.entities.len();
375 self.entities.entry(object_str).or_insert(next_entity_id);
376
377 let next_relation_id = self.relations.len();
378 self.relations
379 .entry(predicate_str)
380 .or_insert(next_relation_id);
381
382 Ok(())
383 }
384
385 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
386 let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
387 let start_time = std::time::Instant::now();
388
389 let mut loss_history = Vec::new();
390
391 for epoch in 0..epochs {
392 let mut random = Random::default();
393 let epoch_loss = 0.1 * random.random::<f64>();
394 loss_history.push(epoch_loss);
395
396 if epoch % 5 == 0 && epoch > 0 {
397 let task_num = epoch / 5;
398 let task_id = format!("task_{task_num}");
399 self.start_task(task_id, "training".to_string())?;
400 }
401
402 if epoch > 10 && epoch_loss < 1e-6 {
403 break;
404 }
405 }
406
407 let training_time = start_time.elapsed().as_secs_f64();
408 let final_loss = loss_history.last().copied().unwrap_or(0.0);
409
410 let stats = TrainingStats {
411 epochs_completed: loss_history.len(),
412 final_loss,
413 training_time_seconds: training_time,
414 convergence_achieved: final_loss < 1e-4,
415 loss_history,
416 };
417
418 self.training_stats = Some(stats.clone());
419 self.is_trained = true;
420
421 Ok(stats)
422 }
423
424 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
425 if let Some(&entity_id) = self.entities.get(entity) {
426 if entity_id < self.embeddings.nrows() {
427 let embedding = self.embeddings.row(entity_id);
428 return Ok(Vector::new(embedding.to_vec()));
429 }
430 }
431 Err(anyhow!("Entity not found: {}", entity))
432 }
433
434 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
435 if let Some(&relation_id) = self.relations.get(relation) {
436 if relation_id < self.embeddings.nrows() {
437 let embedding = self.embeddings.row(relation_id);
438 return Ok(Vector::new(embedding.to_vec()));
439 }
440 }
441 Err(anyhow!("Relation not found: {}", relation))
442 }
443
444 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
445 let subject_emb = self.get_entity_embedding(subject)?;
446 let predicate_emb = self.get_relation_embedding(predicate)?;
447 let object_emb = self.get_entity_embedding(object)?;
448
449 let subject_arr = Array1::from_vec(subject_emb.values);
450 let predicate_arr = Array1::from_vec(predicate_emb.values);
451 let object_arr = Array1::from_vec(object_emb.values);
452
453 let predicted = &subject_arr + &predicate_arr;
454 let diff = &predicted - &object_arr;
455 let distance = diff.dot(&diff).sqrt();
456
457 Ok(-distance as f64)
458 }
459
460 fn predict_objects(
461 &self,
462 subject: &str,
463 predicate: &str,
464 k: usize,
465 ) -> Result<Vec<(String, f64)>> {
466 let mut scores = Vec::new();
467
468 for entity in self.entities.keys() {
469 if entity != subject {
470 let score = self.score_triple(subject, predicate, entity)?;
471 scores.push((entity.clone(), score));
472 }
473 }
474
475 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
476 scores.truncate(k);
477
478 Ok(scores)
479 }
480
481 fn predict_subjects(
482 &self,
483 predicate: &str,
484 object: &str,
485 k: usize,
486 ) -> Result<Vec<(String, f64)>> {
487 let mut scores = Vec::new();
488
489 for entity in self.entities.keys() {
490 if entity != object {
491 let score = self.score_triple(entity, predicate, object)?;
492 scores.push((entity.clone(), score));
493 }
494 }
495
496 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
497 scores.truncate(k);
498
499 Ok(scores)
500 }
501
502 fn predict_relations(
503 &self,
504 subject: &str,
505 object: &str,
506 k: usize,
507 ) -> Result<Vec<(String, f64)>> {
508 let mut scores = Vec::new();
509
510 for relation in self.relations.keys() {
511 let score = self.score_triple(subject, relation, object)?;
512 scores.push((relation.clone(), score));
513 }
514
515 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
516 scores.truncate(k);
517
518 Ok(scores)
519 }
520
521 fn get_entities(&self) -> Vec<String> {
522 self.entities.keys().cloned().collect()
523 }
524
525 fn get_relations(&self) -> Vec<String> {
526 self.relations.keys().cloned().collect()
527 }
528
529 fn get_stats(&self) -> ModelStats {
530 ModelStats {
531 num_entities: self.entities.len(),
532 num_relations: self.relations.len(),
533 num_triples: 0,
534 dimensions: self.config.base_config.dimensions,
535 is_trained: self.is_trained,
536 model_type: self.model_type().to_string(),
537 creation_time: Utc::now(),
538 last_training_time: if self.is_trained {
539 Some(Utc::now())
540 } else {
541 None
542 },
543 }
544 }
545
546 fn save(&self, _path: &str) -> Result<()> {
547 Ok(())
548 }
549
550 fn load(&mut self, _path: &str) -> Result<()> {
551 Ok(())
552 }
553
554 fn clear(&mut self) {
555 self.entities.clear();
556 self.relations.clear();
557 self.embeddings = Array2::zeros((0, self.config.base_config.dimensions));
558 self.episodic_memory.clear();
559 self.semantic_memory.clear();
560 self.ewc_states.clear();
561 self.task_history.clear();
562 self.current_task = None;
563 self.examples_seen = 0;
564 self.is_trained = false;
565 self.training_stats = None;
566 }
567
568 fn is_trained(&self) -> bool {
569 self.is_trained
570 }
571
572 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
573 let mut results = Vec::new();
574
575 for text in texts {
576 let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
577 for (i, c) in text.chars().enumerate() {
578 if i >= self.config.base_config.dimensions {
579 break;
580 }
581 embedding[i] = (c as u8 as f32) / 255.0;
582 }
583 results.push(embedding);
584 }
585
586 Ok(results)
587 }
588}