oxirs_embed/
real_time_fine_tuning.rs

1//! Real-time Fine-tuning System
2//!
3//! This module implements real-time fine-tuning capabilities for embedding models
4//! with incremental learning, online adaptation, and dynamic model updates.
5
6use crate::{EmbeddingModel, ModelConfig, TrainingStats, Triple, Vector};
7use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use scirs2_core::ndarray_ext::{Array1, Array2};
11use scirs2_core::random::{Random, Rng};
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, VecDeque};
14use uuid::Uuid;
15
16/// Configuration for real-time fine-tuning
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct RealTimeFinetuningConfig {
19    pub base_config: ModelConfig,
20    /// Learning rate for online updates
21    pub online_learning_rate: f32,
22    /// Buffer size for experience replay
23    pub replay_buffer_size: usize,
24    /// Batch size for online updates
25    pub online_batch_size: usize,
26    /// Adaptation threshold for triggering updates
27    pub adaptation_threshold: f32,
28    /// Memory decay factor
29    pub memory_decay: f32,
30    /// Update frequency (every N examples)
31    pub update_frequency: usize,
32    /// Catastrophic forgetting prevention
33    pub forgetting_prevention: ForgettingPreventionConfig,
34    /// Online evaluation settings
35    pub online_evaluation: OnlineEvaluationConfig,
36}
37
38impl Default for RealTimeFinetuningConfig {
39    fn default() -> Self {
40        Self {
41            base_config: ModelConfig::default(),
42            online_learning_rate: 1e-4,
43            replay_buffer_size: 10000,
44            online_batch_size: 32,
45            adaptation_threshold: 0.1,
46            memory_decay: 0.99,
47            update_frequency: 10,
48            forgetting_prevention: ForgettingPreventionConfig::default(),
49            online_evaluation: OnlineEvaluationConfig::default(),
50        }
51    }
52}
53
54/// Catastrophic forgetting prevention configuration
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ForgettingPreventionConfig {
57    /// Use elastic weight consolidation
58    pub use_ewc: bool,
59    /// EWC regularization strength
60    pub ewc_lambda: f32,
61    /// Use progressive neural networks
62    pub use_progressive_nets: bool,
63    /// Use memory replay
64    pub use_memory_replay: bool,
65    /// Memory replay ratio
66    pub replay_ratio: f32,
67}
68
69impl Default for ForgettingPreventionConfig {
70    fn default() -> Self {
71        Self {
72            use_ewc: true,
73            ewc_lambda: 0.4,
74            use_progressive_nets: false,
75            use_memory_replay: true,
76            replay_ratio: 0.3,
77        }
78    }
79}
80
81/// Online evaluation configuration
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct OnlineEvaluationConfig {
84    /// Sliding window size for evaluation
85    pub window_size: usize,
86    /// Evaluation frequency
87    pub eval_frequency: usize,
88    /// Performance metrics to track
89    pub metrics: Vec<OnlineMetric>,
90    /// Early stopping criteria
91    pub early_stopping: EarlyStoppingConfig,
92}
93
94impl Default for OnlineEvaluationConfig {
95    fn default() -> Self {
96        Self {
97            window_size: 1000,
98            eval_frequency: 100,
99            metrics: vec![
100                OnlineMetric::Loss,
101                OnlineMetric::Accuracy,
102                OnlineMetric::Drift,
103                OnlineMetric::Forgetting,
104            ],
105            early_stopping: EarlyStoppingConfig::default(),
106        }
107    }
108}
109
110/// Online metrics to track
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum OnlineMetric {
113    Loss,
114    Accuracy,
115    Drift,
116    Forgetting,
117    Plasticity,
118    Stability,
119}
120
121/// Early stopping configuration
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct EarlyStoppingConfig {
124    /// Patience (number of evaluations without improvement)
125    pub patience: usize,
126    /// Minimum improvement threshold
127    pub min_improvement: f32,
128    /// Metric to monitor
129    pub monitor_metric: OnlineMetric,
130}
131
132impl Default for EarlyStoppingConfig {
133    fn default() -> Self {
134        Self {
135            patience: 10,
136            min_improvement: 1e-4,
137            monitor_metric: OnlineMetric::Loss,
138        }
139    }
140}
141
142/// Experience replay buffer entry
143#[derive(Debug, Clone)]
144pub struct ExperienceEntry {
145    pub input: Array1<f32>,
146    pub target: Array1<f32>,
147    pub timestamp: DateTime<Utc>,
148    pub importance: f32,
149    pub task_id: Option<String>,
150}
151
152/// Online performance tracking
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct OnlinePerformanceTracker {
155    pub recent_losses: VecDeque<f32>,
156    pub recent_accuracies: VecDeque<f32>,
157    pub drift_scores: VecDeque<f32>,
158    pub forgetting_scores: VecDeque<f32>,
159    pub update_count: usize,
160    pub last_evaluation: DateTime<Utc>,
161}
162
163impl OnlinePerformanceTracker {
164    pub fn new(window_size: usize) -> Self {
165        Self {
166            recent_losses: VecDeque::with_capacity(window_size),
167            recent_accuracies: VecDeque::with_capacity(window_size),
168            drift_scores: VecDeque::with_capacity(window_size),
169            forgetting_scores: VecDeque::with_capacity(window_size),
170            update_count: 0,
171            last_evaluation: Utc::now(),
172        }
173    }
174
175    pub fn update_metrics(&mut self, loss: f32, accuracy: f32, drift: f32, forgetting: f32) {
176        self.recent_losses.push_back(loss);
177        self.recent_accuracies.push_back(accuracy);
178        self.drift_scores.push_back(drift);
179        self.forgetting_scores.push_back(forgetting);
180
181        // Maintain window size
182        if self.recent_losses.len() > self.recent_losses.capacity() {
183            self.recent_losses.pop_front();
184        }
185        if self.recent_accuracies.len() > self.recent_accuracies.capacity() {
186            self.recent_accuracies.pop_front();
187        }
188        if self.drift_scores.len() > self.drift_scores.capacity() {
189            self.drift_scores.pop_front();
190        }
191        if self.forgetting_scores.len() > self.forgetting_scores.capacity() {
192            self.forgetting_scores.pop_front();
193        }
194
195        self.update_count += 1;
196        self.last_evaluation = Utc::now();
197    }
198
199    pub fn get_average_loss(&self) -> f32 {
200        if self.recent_losses.is_empty() {
201            0.0
202        } else {
203            self.recent_losses.iter().sum::<f32>() / self.recent_losses.len() as f32
204        }
205    }
206
207    pub fn get_average_accuracy(&self) -> f32 {
208        if self.recent_accuracies.is_empty() {
209            0.0
210        } else {
211            self.recent_accuracies.iter().sum::<f32>() / self.recent_accuracies.len() as f32
212        }
213    }
214
215    pub fn get_drift_score(&self) -> f32 {
216        if self.drift_scores.is_empty() {
217            0.0
218        } else {
219            self.drift_scores.iter().sum::<f32>() / self.drift_scores.len() as f32
220        }
221    }
222
223    pub fn get_forgetting_score(&self) -> f32 {
224        if self.forgetting_scores.is_empty() {
225            0.0
226        } else {
227            self.forgetting_scores.iter().sum::<f32>() / self.forgetting_scores.len() as f32
228        }
229    }
230}
231
232/// Real-time fine-tuning model
233#[derive(Debug)]
234pub struct RealTimeFinetuningModel {
235    pub config: RealTimeFinetuningConfig,
236    pub model_id: Uuid,
237
238    /// Core model parameters
239    pub embeddings: Array2<f32>,
240    pub fisher_information: Array2<f32>, // For EWC
241    pub optimal_parameters: Array2<f32>, // For EWC
242
243    /// Experience replay buffer
244    pub replay_buffer: VecDeque<ExperienceEntry>,
245
246    /// Online performance tracking
247    pub performance_tracker: OnlinePerformanceTracker,
248
249    /// Entity and relation mappings
250    pub entities: HashMap<String, usize>,
251    pub relations: HashMap<String, usize>,
252
253    /// Training state
254    pub examples_seen: usize,
255    pub last_update: DateTime<Utc>,
256    pub is_adapting: bool,
257
258    /// Task-specific memory
259    pub task_memory: HashMap<String, Array2<f32>>,
260    pub current_task: Option<String>,
261
262    /// Statistics
263    pub training_stats: Option<TrainingStats>,
264    pub is_trained: bool,
265}
266
267impl RealTimeFinetuningModel {
268    /// Create new real-time fine-tuning model
269    pub fn new(config: RealTimeFinetuningConfig) -> Self {
270        let model_id = Uuid::new_v4();
271        let dimensions = config.base_config.dimensions;
272
273        Self {
274            config: config.clone(),
275            model_id,
276            embeddings: Array2::zeros((0, dimensions)),
277            fisher_information: Array2::zeros((0, dimensions)),
278            optimal_parameters: Array2::zeros((0, dimensions)),
279            replay_buffer: VecDeque::with_capacity(config.replay_buffer_size),
280            performance_tracker: OnlinePerformanceTracker::new(
281                config.online_evaluation.window_size,
282            ),
283            entities: HashMap::new(),
284            relations: HashMap::new(),
285            examples_seen: 0,
286            last_update: Utc::now(),
287            is_adapting: false,
288            task_memory: HashMap::new(),
289            current_task: None,
290            training_stats: None,
291            is_trained: false,
292        }
293    }
294
295    /// Add new example for online learning
296    pub async fn add_example(
297        &mut self,
298        input: Array1<f32>,
299        target: Array1<f32>,
300        task_id: Option<String>,
301    ) -> Result<()> {
302        // Initialize network if needed
303        if self.embeddings.nrows() == 0 {
304            let input_dim = input.len();
305            let output_dim = target.len();
306            self.embeddings = Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
307                let mut random = Random::default();
308                (random.random::<f32>() - 0.5) * 0.1
309            });
310            self.fisher_information = Array2::zeros((output_dim, input_dim));
311            self.optimal_parameters = Array2::zeros((output_dim, input_dim));
312        }
313
314        // Add to replay buffer
315        let entry = ExperienceEntry {
316            input: input.clone(),
317            target: target.clone(),
318            timestamp: Utc::now(),
319            importance: 1.0, // Can be computed based on novelty/difficulty
320            task_id: task_id.clone(),
321        };
322
323        self.replay_buffer.push_back(entry);
324        if self.replay_buffer.len() > self.config.replay_buffer_size {
325            self.replay_buffer.pop_front();
326        }
327
328        self.examples_seen += 1;
329
330        // Trigger adaptation if threshold met
331        if self.should_adapt() {
332            self.adapt_online().await?;
333        }
334
335        Ok(())
336    }
337
338    /// Check if model should adapt
339    fn should_adapt(&self) -> bool {
340        // Adapt every N examples or if performance drops
341        if self.examples_seen % self.config.update_frequency == 0 {
342            return true;
343        }
344
345        // Adapt if performance drops below threshold
346        let current_loss = self.performance_tracker.get_average_loss();
347        if current_loss > self.config.adaptation_threshold {
348            return true;
349        }
350
351        false
352    }
353
354    /// Perform online adaptation
355    pub async fn adapt_online(&mut self) -> Result<()> {
356        if self.replay_buffer.is_empty() {
357            return Ok(());
358        }
359
360        self.is_adapting = true;
361
362        // Sample batch from replay buffer
363        let batch = self.sample_replay_batch();
364
365        // Compute gradients
366        let gradients = self.compute_gradients(&batch)?;
367
368        // Apply EWC regularization if enabled
369        let regularized_gradients = if self.config.forgetting_prevention.use_ewc {
370            self.apply_ewc_regularization(gradients)?
371        } else {
372            gradients
373        };
374
375        // Update parameters
376        self.update_parameters(regularized_gradients)?;
377
378        // Update Fisher information for EWC
379        if self.config.forgetting_prevention.use_ewc {
380            self.update_fisher_information(&batch)?;
381        }
382
383        // Evaluate performance
384        self.evaluate_online_performance().await?;
385
386        self.last_update = Utc::now();
387        self.is_adapting = false;
388
389        Ok(())
390    }
391
392    /// Sample batch from replay buffer
393    fn sample_replay_batch(&self) -> Vec<ExperienceEntry> {
394        let batch_size = self.config.online_batch_size.min(self.replay_buffer.len());
395        let mut batch = Vec::with_capacity(batch_size);
396
397        // Sample with importance-based probability
398        for _ in 0..batch_size {
399            let mut random = Random::default();
400            let idx = random.random_range(0..self.replay_buffer.len());
401            batch.push(self.replay_buffer[idx].clone());
402        }
403
404        batch
405    }
406
407    /// Compute gradients for batch
408    fn compute_gradients(&self, batch: &[ExperienceEntry]) -> Result<Array2<f32>> {
409        let dimensions = self.config.base_config.dimensions;
410        let mut gradients = Array2::zeros((batch.len(), dimensions));
411
412        for (i, entry) in batch.iter().enumerate() {
413            // Simplified gradient computation
414            // In practice, this would involve backpropagation through the model
415            let prediction = self.forward_pass(&entry.input)?;
416            let error = &entry.target - &prediction;
417
418            // Simple gradient: error * input
419            let gradient = &error * &entry.input;
420            gradients.row_mut(i).assign(&gradient);
421        }
422
423        Ok(gradients)
424    }
425
426    /// Apply EWC regularization to gradients
427    fn apply_ewc_regularization(&self, gradients: Array2<f32>) -> Result<Array2<f32>> {
428        let lambda = self.config.forgetting_prevention.ewc_lambda;
429
430        // EWC penalty: λ * F * (θ - θ*)
431        let ewc_penalty =
432            &self.fisher_information * (&self.embeddings - &self.optimal_parameters) * lambda;
433
434        // Regularized gradients
435        let mut regularized = gradients;
436        for i in 0..regularized.nrows().min(ewc_penalty.nrows()) {
437            for j in 0..regularized.ncols().min(ewc_penalty.ncols()) {
438                regularized[[i, j]] -= ewc_penalty[[i, j]];
439            }
440        }
441
442        Ok(regularized)
443    }
444
445    /// Update model parameters
446    fn update_parameters(&mut self, gradients: Array2<f32>) -> Result<()> {
447        let learning_rate = self.config.online_learning_rate;
448
449        // Apply gradients with learning rate
450        let update = &gradients * learning_rate;
451
452        // Ensure embeddings matrix has the right shape
453        if self.embeddings.nrows() < gradients.nrows() {
454            let dimensions = self.config.base_config.dimensions;
455            let new_rows = gradients.nrows();
456            self.embeddings = Array2::from_shape_fn((new_rows, dimensions), |_| {
457                let mut random = Random::default();
458                random.random::<f32>() * 0.1
459            });
460        }
461
462        // Update embeddings
463        let rows_to_update = update.nrows().min(self.embeddings.nrows());
464        let cols_to_update = update.ncols().min(self.embeddings.ncols());
465
466        for i in 0..rows_to_update {
467            for j in 0..cols_to_update {
468                self.embeddings[[i, j]] += update[[i, j]];
469            }
470        }
471
472        Ok(())
473    }
474
475    /// Update Fisher Information Matrix for EWC
476    fn update_fisher_information(&mut self, batch: &[ExperienceEntry]) -> Result<()> {
477        let dimensions = self.config.base_config.dimensions;
478        let mut fisher_update = Array2::zeros((batch.len(), dimensions));
479
480        for (i, entry) in batch.iter().enumerate() {
481            // Compute second-order derivatives (simplified)
482            let prediction = self.forward_pass(&entry.input)?;
483            let second_derivative = prediction.mapv(|x| x * (1.0 - x)); // Sigmoid derivative approximation
484            fisher_update.row_mut(i).assign(&second_derivative);
485        }
486
487        // Update Fisher information with exponential moving average
488        let decay = self.config.memory_decay;
489
490        // Resize Fisher information if needed
491        if self.fisher_information.nrows() < fisher_update.nrows() {
492            self.fisher_information = Array2::zeros((fisher_update.nrows(), dimensions));
493        }
494
495        let rows_to_update = fisher_update.nrows().min(self.fisher_information.nrows());
496        let cols_to_update = fisher_update.ncols().min(self.fisher_information.ncols());
497
498        for i in 0..rows_to_update {
499            for j in 0..cols_to_update {
500                self.fisher_information[[i, j]] =
501                    decay * self.fisher_information[[i, j]] + (1.0 - decay) * fisher_update[[i, j]];
502            }
503        }
504
505        Ok(())
506    }
507
508    /// Forward pass through the model
509    fn forward_pass(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
510        if self.embeddings.is_empty() {
511            return Ok(Array1::zeros(input.len()));
512        }
513
514        // Simple linear transformation
515        let input_len = input.len().min(self.embeddings.ncols());
516        let output_len = self.embeddings.nrows();
517        let mut output = Array1::zeros(output_len);
518
519        for i in 0..output_len {
520            let mut sum = 0.0;
521            for j in 0..input_len {
522                sum += self.embeddings[[i, j]] * input[j];
523            }
524            output[i] = sum.tanh(); // Apply activation
525        }
526
527        Ok(output)
528    }
529
530    /// Evaluate online performance
531    async fn evaluate_online_performance(&mut self) -> Result<()> {
532        if self.replay_buffer.is_empty() {
533            return Ok(());
534        }
535
536        let mut total_loss = 0.0;
537        let mut total_accuracy = 0.0;
538        let mut total_drift = 0.0;
539        let mut total_forgetting = 0.0;
540        let sample_size = self
541            .config
542            .online_evaluation
543            .window_size
544            .min(self.replay_buffer.len());
545
546        for i in 0..sample_size {
547            let idx = self.replay_buffer.len() - 1 - i; // Recent examples
548            let entry = &self.replay_buffer[idx];
549
550            let prediction = self.forward_pass(&entry.input)?;
551
552            // Compute loss (MSE)
553            let diff = &entry.target - &prediction;
554            let loss = diff.dot(&diff) / diff.len() as f32;
555            total_loss += loss;
556
557            // Compute accuracy (simplified)
558            let accuracy = 1.0 / (1.0 + loss);
559            total_accuracy += accuracy;
560
561            // Compute drift (change in prediction distribution)
562            let drift = self.compute_drift_score(&prediction)?;
563            total_drift += drift;
564
565            // Compute forgetting (performance on old tasks)
566            let forgetting = self.compute_forgetting_score(&entry.input, &entry.target)?;
567            total_forgetting += forgetting;
568        }
569
570        let avg_loss = total_loss / sample_size as f32;
571        let avg_accuracy = total_accuracy / sample_size as f32;
572        let avg_drift = total_drift / sample_size as f32;
573        let avg_forgetting = total_forgetting / sample_size as f32;
574
575        self.performance_tracker
576            .update_metrics(avg_loss, avg_accuracy, avg_drift, avg_forgetting);
577
578        Ok(())
579    }
580
581    /// Compute drift score
582    fn compute_drift_score(&self, prediction: &Array1<f32>) -> Result<f32> {
583        // Simplified drift detection based on prediction distribution
584        let mean = prediction.mean().unwrap_or(0.0);
585        let variance = prediction.var(0.0);
586        let drift_score = (mean.abs() + variance).min(1.0);
587        Ok(drift_score)
588    }
589
590    /// Compute forgetting score
591    fn compute_forgetting_score(&self, input: &Array1<f32>, target: &Array1<f32>) -> Result<f32> {
592        let prediction = self.forward_pass(input)?;
593        let diff = target - &prediction;
594        let forgetting_score = diff.dot(&diff).sqrt() / target.len() as f32;
595        Ok(forgetting_score.min(1.0))
596    }
597
598    /// Set current task context
599    pub fn set_current_task(&mut self, task_id: Option<String>) {
600        self.current_task = task_id;
601    }
602
603    /// Save task-specific parameters
604    pub fn save_task_parameters(&mut self, task_id: String) -> Result<()> {
605        self.task_memory.insert(task_id, self.embeddings.clone());
606        Ok(())
607    }
608
609    /// Load task-specific parameters
610    pub fn load_task_parameters(&mut self, task_id: &str) -> Result<()> {
611        if let Some(task_params) = self.task_memory.get(task_id) {
612            self.embeddings = task_params.clone();
613        }
614        Ok(())
615    }
616
617    /// Get online performance statistics
618    pub fn get_online_stats(&self) -> HashMap<String, f32> {
619        let mut stats = HashMap::new();
620
621        stats.insert(
622            "average_loss".to_string(),
623            self.performance_tracker.get_average_loss(),
624        );
625        stats.insert(
626            "average_accuracy".to_string(),
627            self.performance_tracker.get_average_accuracy(),
628        );
629        stats.insert(
630            "drift_score".to_string(),
631            self.performance_tracker.get_drift_score(),
632        );
633        stats.insert(
634            "forgetting_score".to_string(),
635            self.performance_tracker.get_forgetting_score(),
636        );
637        stats.insert("examples_seen".to_string(), self.examples_seen as f32);
638        stats.insert(
639            "update_count".to_string(),
640            self.performance_tracker.update_count as f32,
641        );
642        stats.insert(
643            "replay_buffer_size".to_string(),
644            self.replay_buffer.len() as f32,
645        );
646
647        stats
648    }
649}
650
651#[async_trait]
652impl EmbeddingModel for RealTimeFinetuningModel {
653    fn config(&self) -> &ModelConfig {
654        &self.config.base_config
655    }
656
657    fn model_id(&self) -> &Uuid {
658        &self.model_id
659    }
660
661    fn model_type(&self) -> &'static str {
662        "RealTimeFinetuningModel"
663    }
664
665    fn add_triple(&mut self, triple: Triple) -> Result<()> {
666        let subject_str = triple.subject.iri.clone();
667        let predicate_str = triple.predicate.iri.clone();
668        let object_str = triple.object.iri.clone();
669
670        // Add entities
671        let next_entity_id = self.entities.len();
672        self.entities.entry(subject_str).or_insert(next_entity_id);
673        let next_entity_id = self.entities.len();
674        self.entities.entry(object_str).or_insert(next_entity_id);
675
676        // Add relation
677        let next_relation_id = self.relations.len();
678        self.relations
679            .entry(predicate_str)
680            .or_insert(next_relation_id);
681
682        Ok(())
683    }
684
685    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
686        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
687        let start_time = std::time::Instant::now();
688
689        let mut loss_history = Vec::new();
690
691        for epoch in 0..epochs {
692            // Simulate training with online adaptation
693            let epoch_loss = {
694                let mut random = Random::default();
695                0.1 * random.random::<f64>()
696            };
697            loss_history.push(epoch_loss);
698
699            // Simulate adding examples and adapting
700            if epoch % 10 == 0 && !self.replay_buffer.is_empty() {
701                self.adapt_online().await?;
702            }
703
704            if epoch > 10 && epoch_loss < 1e-6 {
705                break;
706            }
707        }
708
709        let training_time = start_time.elapsed().as_secs_f64();
710        let final_loss = loss_history.last().copied().unwrap_or(0.0);
711
712        let stats = TrainingStats {
713            epochs_completed: loss_history.len(),
714            final_loss,
715            training_time_seconds: training_time,
716            convergence_achieved: final_loss < 1e-4,
717            loss_history,
718        };
719
720        self.training_stats = Some(stats.clone());
721        self.is_trained = true;
722
723        Ok(stats)
724    }
725
726    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
727        if let Some(&entity_id) = self.entities.get(entity) {
728            if entity_id < self.embeddings.nrows() {
729                let embedding = self.embeddings.row(entity_id);
730                return Ok(Vector::new(embedding.to_vec()));
731            }
732        }
733        Err(anyhow!("Entity not found: {}", entity))
734    }
735
736    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
737        if let Some(&relation_id) = self.relations.get(relation) {
738            if relation_id < self.embeddings.nrows() {
739                let embedding = self.embeddings.row(relation_id);
740                return Ok(Vector::new(embedding.to_vec()));
741            }
742        }
743        Err(anyhow!("Relation not found: {}", relation))
744    }
745
746    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
747        let subject_emb = self.get_entity_embedding(subject)?;
748        let predicate_emb = self.get_relation_embedding(predicate)?;
749        let object_emb = self.get_entity_embedding(object)?;
750
751        // Simple TransE-style scoring
752        let subject_arr = Array1::from_vec(subject_emb.values);
753        let predicate_arr = Array1::from_vec(predicate_emb.values);
754        let object_arr = Array1::from_vec(object_emb.values);
755
756        let predicted = &subject_arr + &predicate_arr;
757        let diff = &predicted - &object_arr;
758        let distance = diff.dot(&diff).sqrt();
759
760        Ok(-distance as f64)
761    }
762
763    fn predict_objects(
764        &self,
765        subject: &str,
766        predicate: &str,
767        k: usize,
768    ) -> Result<Vec<(String, f64)>> {
769        let mut scores = Vec::new();
770
771        for entity in self.entities.keys() {
772            if entity != subject {
773                let score = self.score_triple(subject, predicate, entity)?;
774                scores.push((entity.clone(), score));
775            }
776        }
777
778        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779        scores.truncate(k);
780
781        Ok(scores)
782    }
783
784    fn predict_subjects(
785        &self,
786        predicate: &str,
787        object: &str,
788        k: usize,
789    ) -> Result<Vec<(String, f64)>> {
790        let mut scores = Vec::new();
791
792        for entity in self.entities.keys() {
793            if entity != object {
794                let score = self.score_triple(entity, predicate, object)?;
795                scores.push((entity.clone(), score));
796            }
797        }
798
799        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
800        scores.truncate(k);
801
802        Ok(scores)
803    }
804
805    fn predict_relations(
806        &self,
807        subject: &str,
808        object: &str,
809        k: usize,
810    ) -> Result<Vec<(String, f64)>> {
811        let mut scores = Vec::new();
812
813        for relation in self.relations.keys() {
814            let score = self.score_triple(subject, relation, object)?;
815            scores.push((relation.clone(), score));
816        }
817
818        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
819        scores.truncate(k);
820
821        Ok(scores)
822    }
823
824    fn get_entities(&self) -> Vec<String> {
825        self.entities.keys().cloned().collect()
826    }
827
828    fn get_relations(&self) -> Vec<String> {
829        self.relations.keys().cloned().collect()
830    }
831
832    fn get_stats(&self) -> crate::ModelStats {
833        crate::ModelStats {
834            num_entities: self.entities.len(),
835            num_relations: self.relations.len(),
836            num_triples: 0,
837            dimensions: self.config.base_config.dimensions,
838            is_trained: self.is_trained,
839            model_type: self.model_type().to_string(),
840            creation_time: Utc::now(),
841            last_training_time: if self.is_trained {
842                Some(Utc::now())
843            } else {
844                None
845            },
846        }
847    }
848
849    fn save(&self, _path: &str) -> Result<()> {
850        Ok(())
851    }
852
853    fn load(&mut self, _path: &str) -> Result<()> {
854        Ok(())
855    }
856
857    fn clear(&mut self) {
858        self.entities.clear();
859        self.relations.clear();
860        self.embeddings = Array2::zeros((0, self.config.base_config.dimensions));
861        self.replay_buffer.clear();
862        self.performance_tracker =
863            OnlinePerformanceTracker::new(self.config.online_evaluation.window_size);
864        self.examples_seen = 0;
865        self.is_trained = false;
866        self.training_stats = None;
867    }
868
869    fn is_trained(&self) -> bool {
870        self.is_trained
871    }
872
873    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
874        let mut results = Vec::new();
875
876        for text in texts {
877            // Simple text encoding
878            let mut embedding = vec![0.0f32; self.config.base_config.dimensions];
879            for (i, c) in text.chars().enumerate() {
880                if i >= self.config.base_config.dimensions {
881                    break;
882                }
883                embedding[i] = (c as u8 as f32) / 255.0;
884            }
885            results.push(embedding);
886        }
887
888        Ok(results)
889    }
890}
891
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config narrowed to 3 dimensions so it matches the sample vectors.
    fn three_dim_config() -> RealTimeFinetuningConfig {
        RealTimeFinetuningConfig {
            base_config: ModelConfig {
                dimensions: 3,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// The fixed `(input, target)` pair shared by several tests.
    fn sample_example() -> (Array1<f32>, Array1<f32>) {
        (
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        )
    }

    #[test]
    fn test_real_time_finetuning_config_default() {
        let RealTimeFinetuningConfig {
            online_learning_rate,
            replay_buffer_size,
            online_batch_size,
            ..
        } = RealTimeFinetuningConfig::default();

        assert_eq!(online_learning_rate, 1e-4);
        assert_eq!(replay_buffer_size, 10000);
        assert_eq!(online_batch_size, 32);
    }

    #[test]
    fn test_experience_entry_creation() {
        let (input, target) = sample_example();
        let entry = ExperienceEntry {
            input,
            target,
            timestamp: Utc::now(),
            importance: 1.0,
            task_id: Some("task1".to_string()),
        };

        assert_eq!(entry.input.len(), 3);
        assert_eq!(entry.target.len(), 3);
        assert!(entry.importance > 0.0);
    }

    #[test]
    fn test_online_performance_tracker() {
        let mut tracker = OnlinePerformanceTracker::new(10);
        tracker.update_metrics(0.5, 0.8, 0.1, 0.2);

        assert_eq!(tracker.get_average_loss(), 0.5);
        assert_eq!(tracker.get_average_accuracy(), 0.8);
        assert_eq!(tracker.update_count, 1);
    }

    #[test]
    fn test_real_time_finetuning_model_creation() {
        let model = RealTimeFinetuningModel::new(RealTimeFinetuningConfig::default());

        assert!(model.entities.is_empty());
        assert_eq!(model.examples_seen, 0);
        assert!(!model.is_adapting);
    }

    #[tokio::test]
    async fn test_add_example_and_adaptation() {
        // Adapt on every example.
        let config = RealTimeFinetuningConfig {
            update_frequency: 1,
            ..three_dim_config()
        };
        let mut model = RealTimeFinetuningModel::new(config);
        let (input, target) = sample_example();

        model
            .add_example(input, target, Some("task1".to_string()))
            .await
            .unwrap();

        assert_eq!(model.examples_seen, 1);
        assert_eq!(model.replay_buffer.len(), 1);
    }

    #[tokio::test]
    async fn test_task_memory_management() {
        let mut model = RealTimeFinetuningModel::new(RealTimeFinetuningConfig::default());
        model.embeddings = Array2::from_shape_fn((5, 10), |_| {
            let mut random = Random::default();
            random.random::<f32>()
        });

        model.save_task_parameters("task1".to_string()).unwrap();
        model.embeddings *= 2.0;
        model.load_task_parameters("task1").unwrap();

        assert!(model.task_memory.contains_key("task1"));
    }

    #[test]
    fn test_online_stats() {
        let mut config = RealTimeFinetuningConfig::default();
        config.online_evaluation.window_size = 5;

        let stats = RealTimeFinetuningModel::new(config).get_online_stats();

        for key in ["average_loss", "examples_seen", "replay_buffer_size"] {
            assert!(stats.contains_key(key));
        }
        assert_eq!(stats["examples_seen"], 0.0);
    }

    #[tokio::test]
    async fn test_real_time_training() {
        let mut model = RealTimeFinetuningModel::new(three_dim_config());
        let (input, target) = sample_example();
        model.add_example(input, target, None).await.unwrap();

        let stats = model.train(Some(5)).await.unwrap();

        assert_eq!(stats.epochs_completed, 5);
        assert!(model.is_trained());
    }
}