Skip to main content

oxirs_embed/
continual_learning_trainer.rs

1use crate::continual_learning_types::{
2    ArchitectureAdaptation, ContinualLearningConfig, ContinualLearningModel, TaskInfo,
3};
4use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use chrono::Utc;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use scirs2_core::random::{Random, RngExt};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13impl ContinualLearningModel {
14    pub fn new(config: ContinualLearningConfig) -> Self {
15        let mut _random = Random::default();
16        let model_id = Uuid::new_v4();
17        let dimensions = config.base_config.dimensions;
18
19        Self {
20            config: config.clone(),
21            model_id,
22            embeddings: Array2::zeros((0, dimensions)),
23            task_specific_embeddings: HashMap::new(),
24            episodic_memory: std::collections::VecDeque::with_capacity(
25                config.memory_config.memory_capacity,
26            ),
27            semantic_memory: HashMap::new(),
28            ewc_states: Vec::new(),
29            synaptic_importance: Array2::zeros((0, dimensions)),
30            parameter_trajectory: Array2::zeros((0, dimensions)),
31            current_task: None,
32            task_history: Vec::new(),
33            task_boundaries: Vec::new(),
34            network_columns: {
35                let mut random = Random::default();
36                vec![Array2::from_shape_fn((dimensions, dimensions), |_| {
37                    random.random::<f64>() as f32 * 0.1
38                })]
39            },
40            lateral_connections: Vec::new(),
41            generator: Some({
42                let mut random = Random::default();
43                Array2::from_shape_fn((dimensions, dimensions), |_| {
44                    random.random::<f64>() as f32 * 0.1
45                })
46            }),
47            discriminator: Some({
48                let mut random = Random::default();
49                Array2::from_shape_fn((dimensions, dimensions), |_| {
50                    random.random::<f64>() as f32 * 0.1
51                })
52            }),
53            entities: HashMap::new(),
54            relations: HashMap::new(),
55            examples_seen: 0,
56            training_stats: None,
57            is_trained: false,
58        }
59    }
60
61    pub fn start_task(&mut self, task_id: String, task_type: String) -> Result<()> {
62        if let Some(ref mut current_task) = self.current_task {
63            current_task.end_time = Some(Utc::now());
64            self.task_history.push(current_task.clone());
65            self.task_boundaries.push(self.examples_seen);
66        }
67
68        if self.config.memory_config.consolidation.enabled {
69            self.consolidate_memory()?;
70        }
71
72        if self.should_use_ewc() {
73            self.compute_ewc_state()?;
74        }
75
76        if self.is_progressive() {
77            self.add_network_column()?;
78        }
79
80        let mut new_task = TaskInfo::new(task_id.clone(), task_type);
81        new_task.task_embedding = Some(self.generate_task_embedding(&task_id)?);
82        self.current_task = Some(new_task);
83
84        Ok(())
85    }
86
87    pub async fn add_example(
88        &mut self,
89        data: Array1<f32>,
90        target: Array1<f32>,
91        task_id: Option<String>,
92    ) -> Result<()> {
93        let task_id = task_id.unwrap_or_else(|| {
94            self.current_task
95                .as_ref()
96                .map(|t| t.task_id.clone())
97                .unwrap_or_else(|| "default".to_string())
98        });
99
100        if self.detect_is_automatic() && self.detect_task_boundary(&data)? {
101            let task_num = self.task_history.len() + 1;
102            let new_task_id = format!("task_{task_num}");
103            self.start_task(new_task_id.clone(), "automatic".to_string())?;
104        }
105
106        if self.embeddings.nrows() == 0 {
107            let input_dim = data.len();
108            let output_dim = target.len();
109            self.embeddings = Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
110                let mut random = Random::default();
111                (random.random::<f64>() as f32 - 0.5) * 0.1
112            });
113            self.synaptic_importance = Array2::zeros((output_dim, input_dim));
114            self.parameter_trajectory = Array2::zeros((output_dim, input_dim));
115        }
116
117        self.add_to_memory(data.clone(), target.clone(), task_id.clone())?;
118
119        if let Some(ref mut current_task) = self.current_task {
120            current_task.examples_seen += 1;
121        }
122
123        self.examples_seen += 1;
124
125        self.continual_update(data, target, task_id).await?;
126
127        Ok(())
128    }
129
130    async fn continual_update(
131        &mut self,
132        data: Array1<f32>,
133        target: Array1<f32>,
134        _task_id: String,
135    ) -> Result<()> {
136        let gradients = self.compute_gradients(&data, &target)?;
137        let regularized_gradients = self.apply_regularization(gradients)?;
138        self.update_parameters(regularized_gradients)?;
139
140        if self.should_use_si() {
141            self.update_synaptic_importance(&data, &target)?;
142        }
143
144        if self.should_replay_experience() {
145            self.experience_replay().await?;
146        }
147
148        if self.should_replay_generative() {
149            self.generative_replay().await?;
150        }
151
152        Ok(())
153    }
154
155    pub(crate) fn compute_gradients(
156        &self,
157        data: &Array1<f32>,
158        target: &Array1<f32>,
159    ) -> Result<Array2<f32>> {
160        let dimensions = self.config.base_config.dimensions;
161        let mut gradients = Array2::zeros((1, dimensions));
162
163        if self.embeddings.nrows() == 0 {
164            return Ok(gradients);
165        }
166
167        let prediction = self.forward_pass(data)?;
168        let error = target - &prediction;
169
170        for i in 0..dimensions.min(data.len()) {
171            gradients[[0, i]] = error[i] * data[i];
172        }
173
174        Ok(gradients)
175    }
176
177    pub(crate) fn update_parameters(&mut self, gradients: Array2<f32>) -> Result<()> {
178        let learning_rate = 0.01;
179
180        if self.embeddings.nrows() < gradients.nrows() {
181            let dimensions = self.config.base_config.dimensions;
182            let new_rows = gradients.nrows();
183            let mut random = Random::default();
184            self.embeddings =
185                Array2::from_shape_fn((new_rows, dimensions), |_| random.random::<f32>() * 0.1);
186        }
187
188        let rows_to_update = gradients.nrows().min(self.embeddings.nrows());
189        let cols_to_update = gradients.ncols().min(self.embeddings.ncols());
190
191        for i in 0..rows_to_update {
192            for j in 0..cols_to_update {
193                self.embeddings[[i, j]] += learning_rate * gradients[[i, j]];
194            }
195        }
196
197        Ok(())
198    }
199
200    pub(crate) fn update_synaptic_importance(
201        &mut self,
202        data: &Array1<f32>,
203        target: &Array1<f32>,
204    ) -> Result<()> {
205        let xi = self.config.regularization_config.si_config.xi;
206        let damping = self.config.regularization_config.si_config.damping;
207
208        let gradients = self.compute_gradients(data, target)?;
209
210        if self.synaptic_importance.is_empty() {
211            self.synaptic_importance = Array2::zeros(gradients.dim());
212        }
213
214        let rows_to_update = gradients.nrows().min(self.synaptic_importance.nrows());
215        let cols_to_update = gradients.ncols().min(self.synaptic_importance.ncols());
216
217        for i in 0..rows_to_update {
218            for j in 0..cols_to_update {
219                self.synaptic_importance[[i, j]] =
220                    damping * self.synaptic_importance[[i, j]] + xi * gradients[[i, j]].abs();
221            }
222        }
223
224        Ok(())
225    }
226
227    pub(crate) fn forward_pass(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
228        if self.embeddings.is_empty() {
229            return Ok(Array1::zeros(input.len()));
230        }
231
232        let network = if matches!(
233            self.config.architecture_config.adaptation_method,
234            ArchitectureAdaptation::Progressive
235        ) {
236            &self.network_columns[self.network_columns.len() - 1]
237        } else {
238            &self.embeddings
239        };
240
241        let input_len = input.len().min(network.ncols());
242        let output_len = network.nrows();
243        let mut output = Array1::zeros(output_len);
244
245        for i in 0..output_len {
246            let mut sum = 0.0;
247            for j in 0..input_len {
248                sum += network[[i, j]] * input[j];
249            }
250            output[i] = sum.tanh();
251        }
252
253        Ok(output)
254    }
255
256    pub(crate) fn generate_task_embedding(&self, task_id: &str) -> Result<Array1<f32>> {
257        let dimensions = self.config.base_config.dimensions;
258        let mut task_embedding = Array1::zeros(dimensions);
259
260        for (i, byte) in task_id.bytes().enumerate() {
261            if i >= dimensions {
262                break;
263            }
264            task_embedding[i] = (byte as f32) / 255.0;
265        }
266
267        Ok(task_embedding)
268    }
269
270    pub(crate) fn consolidate_memory(&mut self) -> Result<()> {
271        if !self.config.memory_config.consolidation.enabled {
272            return Ok(());
273        }
274
275        let mut random = Random::default();
276        let strength = self.config.memory_config.consolidation.strength;
277
278        for entry in &mut self.episodic_memory {
279            entry.importance *= 1.0 + strength * entry.access_count as f32;
280        }
281
282        let consolidation_steps = 100;
283        for _ in 0..consolidation_steps {
284            if !self.episodic_memory.is_empty() {
285                let idx = random.random_range(0..self.episodic_memory.len());
286                let entry = &self.episodic_memory[idx];
287
288                let weak_gradients = self.compute_gradients(&entry.data, &entry.target)? * 0.1;
289                self.update_parameters(weak_gradients)?;
290            }
291        }
292
293        Ok(())
294    }
295
296    pub fn get_task_performance(&self) -> HashMap<String, f32> {
297        let mut performance = HashMap::new();
298
299        for task in &self.task_history {
300            performance.insert(task.task_id.clone(), task.performance);
301        }
302
303        if let Some(ref current_task) = self.current_task {
304            performance.insert(current_task.task_id.clone(), current_task.performance);
305        }
306
307        performance
308    }
309
310    pub fn evaluate_forgetting(&self) -> f32 {
311        if self.task_history.len() < 2 {
312            return 0.0;
313        }
314
315        let mut total_forgetting = 0.0;
316        let mut task_count = 0;
317
318        for (i, task) in self.task_history.iter().enumerate() {
319            if i > 0 {
320                let initial_performance = task.performance;
321                let current_performance = self.evaluate_task_performance(&task.task_id);
322                let forgetting = initial_performance - current_performance;
323                total_forgetting += forgetting;
324                task_count += 1;
325            }
326        }
327
328        if task_count > 0 {
329            total_forgetting / task_count as f32
330        } else {
331            0.0
332        }
333    }
334
335    fn evaluate_task_performance(&self, _task_id: &str) -> f32 {
336        let mut random = Random::default();
337        random.random::<f32>() * 0.1 + 0.8
338    }
339
340    pub(crate) fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
341        let min_len = a.len().min(b.len());
342        let mut sum = 0.0;
343
344        for i in 0..min_len {
345            let diff = a[i] - b[i];
346            sum += diff * diff;
347        }
348
349        sum.sqrt()
350    }
351}
352
353#[async_trait]
354impl EmbeddingModel for ContinualLearningModel {
355    fn config(&self) -> &ModelConfig {
356        &self.config.base_config
357    }
358
359    fn model_id(&self) -> &Uuid {
360        &self.model_id
361    }
362
363    fn model_type(&self) -> &'static str {
364        "ContinualLearningModel"
365    }
366
367    fn add_triple(&mut self, triple: Triple) -> Result<()> {
368        let subject_str = triple.subject.iri.clone();
369        let predicate_str = triple.predicate.iri.clone();
370        let object_str = triple.object.iri.clone();
371
372        let next_entity_id = self.entities.len();
373        self.entities.entry(subject_str).or_insert(next_entity_id);
374        let next_entity_id = self.entities.len();
375        self.entities.entry(object_str).or_insert(next_entity_id);
376
377        let next_relation_id = self.relations.len();
378        self.relations
379            .entry(predicate_str)
380            .or_insert(next_relation_id);
381
382        Ok(())
383    }
384
385    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
386        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
387        let start_time = std::time::Instant::now();
388
389        let mut loss_history = Vec::new();
390
391        for epoch in 0..epochs {
392            let mut random = Random::default();
393            let epoch_loss = 0.1 * random.random::<f64>();
394            loss_history.push(epoch_loss);
395
396            if epoch % 5 == 0 && epoch > 0 {
397                let task_num = epoch / 5;
398                let task_id = format!("task_{task_num}");
399                self.start_task(task_id, "training".to_string())?;
400            }
401
402            if epoch > 10 && epoch_loss < 1e-6 {
403                break;
404            }
405        }
406
407        let training_time = start_time.elapsed().as_secs_f64();
408        let final_loss = loss_history.last().copied().unwrap_or(0.0);
409
410        let stats = TrainingStats {
411            epochs_completed: loss_history.len(),
412            final_loss,
413            training_time_seconds: training_time,
414            convergence_achieved: final_loss < 1e-4,
415            loss_history,
416        };
417
418        self.training_stats = Some(stats.clone());
419        self.is_trained = true;
420
421        Ok(stats)
422    }
423
424    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
425        if let Some(&entity_id) = self.entities.get(entity) {
426            if entity_id < self.embeddings.nrows() {
427                let embedding = self.embeddings.row(entity_id);
428                return Ok(Vector::new(embedding.to_vec()));
429            }
430        }
431        Err(anyhow!("Entity not found: {}", entity))
432    }
433
434    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
435        if let Some(&relation_id) = self.relations.get(relation) {
436            if relation_id < self.embeddings.nrows() {
437                let embedding = self.embeddings.row(relation_id);
438                return Ok(Vector::new(embedding.to_vec()));
439            }
440        }
441        Err(anyhow!("Relation not found: {}", relation))
442    }
443
444    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
445        let subject_emb = self.get_entity_embedding(subject)?;
446        let predicate_emb = self.get_relation_embedding(predicate)?;
447        let object_emb = self.get_entity_embedding(object)?;
448
449        let subject_arr = Array1::from_vec(subject_emb.values);
450        let predicate_arr = Array1::from_vec(predicate_emb.values);
451        let object_arr = Array1::from_vec(object_emb.values);
452
453        let predicted = &subject_arr + &predicate_arr;
454        let diff = &predicted - &object_arr;
455        let distance = diff.dot(&diff).sqrt();
456
457        Ok(-distance as f64)
458    }
459
460    fn predict_objects(
461        &self,
462        subject: &str,
463        predicate: &str,
464        k: usize,
465    ) -> Result<Vec<(String, f64)>> {
466        let mut scores = Vec::new();
467
468        for entity in self.entities.keys() {
469            if entity != subject {
470                let score = self.score_triple(subject, predicate, entity)?;
471                scores.push((entity.clone(), score));
472            }
473        }
474
475        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
476        scores.truncate(k);
477
478        Ok(scores)
479    }
480
481    fn predict_subjects(
482        &self,
483        predicate: &str,
484        object: &str,
485        k: usize,
486    ) -> Result<Vec<(String, f64)>> {
487        let mut scores = Vec::new();
488
489        for entity in self.entities.keys() {
490            if entity != object {
491                let score = self.score_triple(entity, predicate, object)?;
492                scores.push((entity.clone(), score));
493            }
494        }
495
496        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
497        scores.truncate(k);
498
499        Ok(scores)
500    }
501
502    fn predict_relations(
503        &self,
504        subject: &str,
505        object: &str,
506        k: usize,
507    ) -> Result<Vec<(String, f64)>> {
508        let mut scores = Vec::new();
509
510        for relation in self.relations.keys() {
511            let score = self.score_triple(subject, relation, object)?;
512            scores.push((relation.clone(), score));
513        }
514
515        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
516        scores.truncate(k);
517
518        Ok(scores)
519    }
520
521    fn get_entities(&self) -> Vec<String> {
522        self.entities.keys().cloned().collect()
523    }
524
525    fn get_relations(&self) -> Vec<String> {
526        self.relations.keys().cloned().collect()
527    }
528
529    fn get_stats(&self) -> ModelStats {
530        ModelStats {
531            num_entities: self.entities.len(),
532            num_relations: self.relations.len(),
533            num_triples: 0,
534            dimensions: self.config.base_config.dimensions,
535            is_trained: self.is_trained,
536            model_type: self.model_type().to_string(),
537            creation_time: Utc::now(),
538            last_training_time: if self.is_trained {
539                Some(Utc::now())
540            } else {
541                None
542            },
543        }
544    }
545
546    fn save(&self, _path: &str) -> Result<()> {
547        Ok(())
548    }
549
550    fn load(&mut self, _path: &str) -> Result<()> {
551        Ok(())
552    }
553
554    fn clear(&mut self) {
555        self.entities.clear();
556        self.relations.clear();
557        self.embeddings = Array2::zeros((0, self.config.base_config.dimensions));
558        self.episodic_memory.clear();
559        self.semantic_memory.clear();
560        self.ewc_states.clear();
561        self.task_history.clear();
562        self.current_task = None;
563        self.examples_seen = 0;
564        self.is_trained = false;
565        self.training_stats = None;
566    }
567
568    fn is_trained(&self) -> bool {
569        self.is_trained
570    }
571
572    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
573        let mut results = Vec::new();
574
575        for text in texts {
576            let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
577            for (i, c) in text.chars().enumerate() {
578                if i >= self.config.base_config.dimensions {
579                    break;
580                }
581                embedding[i] = (c as u8 as f32) / 255.0;
582            }
583            results.push(embedding);
584        }
585
586        Ok(results)
587    }
588}