oxirs-embed 0.2.4

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
//! Real-time adaptation and fine-tuning components for multi-modal embeddings

use super::model::MultiModalEmbedding;
use anyhow::Result;
use scirs2_core::ndarray_ext::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Real-time fine-tuning capabilities
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealTimeFinetuning {
    /// Learning rate for real-time updates
    pub learning_rate: f32,
    /// Buffer size for online learning
    pub buffer_size: usize,
    /// Update frequency
    pub update_frequency: usize,
    /// Elastic weight consolidation parameters
    pub ewc_config: EWCConfig,
    /// Online learning buffer
    pub online_buffer: Vec<(String, String, String)>,
    /// Current update count
    pub update_count: usize,
}

/// Elastic Weight Consolidation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    /// EWC lambda parameter
    pub lambda: f32,
    /// Fisher information matrix
    pub fisher_information: HashMap<String, Array2<f32>>,
    /// Optimal parameters from previous tasks
    pub optimal_params: HashMap<String, Array2<f32>>,
}

impl Default for RealTimeFinetuning {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            buffer_size: 1000,
            update_frequency: 10,
            ewc_config: EWCConfig::default(),
            online_buffer: Vec::new(),
            update_count: 0,
        }
    }
}

impl Default for EWCConfig {
    fn default() -> Self {
        Self {
            lambda: 0.1,
            fisher_information: HashMap::new(),
            optimal_params: HashMap::new(),
        }
    }
}

impl RealTimeFinetuning {
    /// Add new training example for real-time learning
    pub fn add_example(&mut self, text: String, entity: String, label: String) {
        self.online_buffer.push((text, entity, label));

        // Keep buffer size limited
        if self.online_buffer.len() > self.buffer_size {
            self.online_buffer.remove(0);
        }

        self.update_count += 1;
    }

    /// Check if model needs updating
    pub fn should_update(&self) -> bool {
        self.update_count % self.update_frequency == 0 && !self.online_buffer.is_empty()
    }

    /// Perform real-time model update
    pub async fn update_model(&mut self, model: &mut MultiModalEmbedding) -> Result<f32> {
        if !self.should_update() {
            return Ok(0.0);
        }

        let mut total_loss = 0.0;
        let batch_size = self.update_frequency.min(self.online_buffer.len());

        // Take recent examples for update
        let update_batch = &self.online_buffer[self.online_buffer.len() - batch_size..];

        for (text, entity, _label) in update_batch {
            // Generate unified embedding
            let unified = model.generate_unified_embedding(text, entity).await?;

            // Compute reconstruction loss
            let loss = unified.iter().map(|&x| x * x).sum::<f32>() / unified.len() as f32;
            total_loss += loss;

            // Apply EWC regularization
            let ewc_loss = self.compute_ewc_loss(&model.text_encoder.parameters)?;
            total_loss += ewc_loss * self.ewc_config.lambda;
        }

        total_loss /= batch_size as f32;

        // Update Fisher information (simplified)
        self.update_fisher_information(model)?;

        Ok(total_loss)
    }

    /// Compute EWC regularization loss
    fn compute_ewc_loss(&self, current_params: &HashMap<String, Array2<f32>>) -> Result<f32> {
        let mut ewc_loss = 0.0;

        for (param_name, current_param) in current_params {
            if let (Some(fisher), Some(optimal)) = (
                self.ewc_config.fisher_information.get(param_name),
                self.ewc_config.optimal_params.get(param_name),
            ) {
                let diff = current_param - optimal;
                let weighted_diff = &diff * fisher;
                ewc_loss += (&diff * &weighted_diff).sum();
            }
        }

        Ok(ewc_loss)
    }

    /// Update Fisher information matrix
    fn update_fisher_information(&mut self, model: &MultiModalEmbedding) -> Result<()> {
        for (param_name, param) in &model.text_encoder.parameters {
            // Simplified Fisher information computation
            let fisher = Array2::from_shape_fn(param.dim(), |(_, _)| {
                use scirs2_core::random::{Random, RngExt};
                let mut random = Random::default();
                random.random::<f32>() * 0.01
            });
            self.ewc_config
                .fisher_information
                .insert(param_name.clone(), fisher);
            self.ewc_config
                .optimal_params
                .insert(param_name.clone(), param.clone());
        }

        Ok(())
    }
}