oxirs-embed 0.3.1

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
use crate::continual_learning_types::{
    ArchitectureAdaptation, BoundaryDetection, ContinualLearningModel, MemoryEntry,
    MemoryUpdateStrategy, RegularizationMethod, ReplayMethod, TaskDetection,
};
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, RngExt};

impl ContinualLearningModel {
    /// Stores a `(data, target)` example tagged with `task_id` in episodic memory.
    ///
    /// How `memory_config.memory_capacity` is enforced depends on
    /// `memory_config.update_strategy`:
    /// - `FIFO`: evicts the oldest entries until there is room, then appends.
    /// - `Random`: evicts uniformly chosen entries until there is room, then appends.
    /// - `ReservoirSampling`: appends while below capacity. Once full, it draws `j` uniformly
    ///   from `0..=examples_seen` and overwrites slot `j` only when `j` falls inside the buffer.
    /// - `ImportanceBased`: delegates to [`Self::add_by_importance`].
    /// - any other strategy: appends without enforcing the capacity.
    ///
    /// With a capacity of zero, the `FIFO` and `Random` strategies store nothing.
    ///
    /// # Errors
    /// Propagates errors from [`Self::add_by_importance`].
    pub(crate) fn add_to_memory(
        &mut self,
        data: Array1<f32>,
        target: Array1<f32>,
        task_id: String,
    ) -> Result<()> {
        let capacity = self.config.memory_config.memory_capacity;
        let entry = MemoryEntry::new(data, target, task_id);

        match self.config.memory_config.update_strategy {
            // A zero-capacity buffer can hold nothing. Without this arm, FIFO would keep one
            // entry and Random would panic when sampling from the empty range `0..0`.
            MemoryUpdateStrategy::FIFO | MemoryUpdateStrategy::Random if capacity == 0 => {}
            MemoryUpdateStrategy::FIFO => {
                // `while` rather than `if`: this restores the bound even when memory is already
                // over capacity (e.g. the capacity was lowered or another strategy overfilled it).
                while self.episodic_memory.len() >= capacity {
                    self.episodic_memory.pop_front();
                }
                self.episodic_memory.push_back(entry);
            }
            MemoryUpdateStrategy::Random => {
                let mut random = Random::default();
                // capacity > 0 here, so the buffer is non-empty whenever the loop body runs
                // and the sampled range is never empty.
                while self.episodic_memory.len() >= capacity {
                    let victim = random.random_range(0..self.episodic_memory.len());
                    self.episodic_memory.remove(victim);
                }
                self.episodic_memory.push_back(entry);
            }
            MemoryUpdateStrategy::ReservoirSampling => {
                if self.episodic_memory.len() < capacity {
                    self.episodic_memory.push_back(entry);
                } else {
                    // NOTE(review): classic reservoir sampling draws from 0..=n, where n is the
                    // number of items seen *before* this one. Confirm whether callers bump
                    // `examples_seen` before or after calling this method.
                    let mut random = Random::default();
                    let slot = random.random_range(0..self.examples_seen + 1);
                    if slot < self.episodic_memory.len() {
                        self.episodic_memory[slot] = entry;
                    }
                }
            }
            MemoryUpdateStrategy::ImportanceBased => {
                self.add_by_importance(entry)?;
            }
            _ => {
                self.episodic_memory.push_back(entry);
            }
        }

        Ok(())
    }

    /// Inserts `entry`, keeping the most important examples once memory is full.
    ///
    /// Below capacity, the entry is simply appended. At capacity, the entry replaces the first
    /// entry with the smallest `importance`, and only if the new entry's importance is strictly
    /// greater. Otherwise the entry is dropped.
    pub(crate) fn add_by_importance(&mut self, entry: MemoryEntry) -> Result<()> {
        if self.episodic_memory.len() < self.config.memory_config.memory_capacity {
            self.episodic_memory.push_back(entry);
            return Ok(());
        }

        // Find the first entry with the strictly smallest importance. Starting from +inf means
        // entries whose importance is NaN or +inf are never selected for eviction.
        let mut min_importance = f32::INFINITY;
        let mut min_idx = 0;
        for (i, existing) in self.episodic_memory.iter().enumerate() {
            if existing.importance < min_importance {
                min_importance = existing.importance;
                min_idx = i;
            }
        }

        // Evict only when the newcomer is strictly more important than the weakest entry.
        // When nothing was selected, `min_importance` is +inf and this is always false.
        if entry.importance > min_importance {
            self.episodic_memory[min_idx] = entry;
        }

        Ok(())
    }

    /// Returns whether `data` marks the start of a new task, using the detector selected by
    /// `task_config.boundary_detection`.
    pub(crate) fn detect_task_boundary(&self, data: &Array1<f32>) -> Result<bool> {
        match self.config.task_config.boundary_detection {
            BoundaryDetection::ChangePoint => self.detect_change_point(data),
            BoundaryDetection::DistributionShift => self.detect_distribution_shift(data),
            BoundaryDetection::LossBased => self.detect_loss_change(data),
            BoundaryDetection::GradientBased => self.detect_gradient_change(data),
        }
    }

    /// Heuristic change point: reports a boundary every `CHANGE_POINT_INTERVAL` examples.
    ///
    /// `data` is ignored. Zero examples never counts as a boundary.
    fn detect_change_point(&self, _data: &Array1<f32>) -> Result<bool> {
        const CHANGE_POINT_INTERVAL: usize = 1000;
        Ok(self.examples_seen > 0 && self.examples_seen % CHANGE_POINT_INTERVAL == 0)
    }

    /// Reports a shift when the mean Euclidean distance from `data` to the most recently stored
    /// memory entries exceeds a fixed threshold.
    ///
    /// Up to `RECENT_WINDOW` of the newest entries are compared. Empty memory never reports a
    /// shift.
    fn detect_distribution_shift(&self, data: &Array1<f32>) -> Result<bool> {
        const RECENT_WINDOW: usize = 100;
        const SHIFT_THRESHOLD: f32 = 2.0;

        if self.episodic_memory.is_empty() {
            return Ok(false);
        }

        let recent_count = self.episodic_memory.len().min(RECENT_WINDOW);
        // Newest-first, which matches the original back-to-front summation order.
        let total_distance: f32 = self
            .episodic_memory
            .iter()
            .rev()
            .take(recent_count)
            .map(|entry| self.euclidean_distance(data, &entry.data))
            .sum();

        Ok(total_distance / recent_count as f32 > SHIFT_THRESHOLD)
    }

    /// Placeholder loss-based detector; it currently never reports a boundary.
    fn detect_loss_change(&self, _data: &Array1<f32>) -> Result<bool> {
        Ok(false)
    }

    /// Placeholder gradient-based detector; it currently never reports a boundary.
    fn detect_gradient_change(&self, _data: &Array1<f32>) -> Result<bool> {
        Ok(false)
    }

    /// Applies each configured regularizer to `gradients`, in the order listed in
    /// `regularization_config.methods`.
    ///
    /// Only EWC, Synaptic Intelligence and LwF are handled here. Any other method is skipped.
    pub(crate) fn apply_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        for method in &self.config.regularization_config.methods {
            match method {
                RegularizationMethod::EWC => {
                    gradients = self.apply_ewc_regularization(gradients)?;
                }
                RegularizationMethod::SynapticIntelligence => {
                    gradients = self.apply_si_regularization(gradients)?;
                }
                RegularizationMethod::LwF => {
                    gradients = self.apply_lwf_regularization(gradients)?;
                }
                _ => {}
            }
        }

        Ok(gradients)
    }

    /// Subtracts one EWC penalty per stored task state from `gradients`.
    ///
    /// Each penalty is `fisher ⊙ (embeddings − optimal) · lambda · importance`, computed
    /// element-wise. Only the region where the penalty and `gradients` overlap is updated.
    fn apply_ewc_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        let lambda = self.config.regularization_config.ewc_config.lambda;

        for ewc_state in &self.ewc_states {
            let penalty = &ewc_state.fisher_information
                * (&self.embeddings - &ewc_state.optimal_parameters)
                * lambda
                * ewc_state.importance;
            Self::subtract_overlapping_region(&mut gradients, &penalty);
        }

        Ok(gradients)
    }

    /// Subtracts `synaptic_importance · c` from `gradients` over their overlapping region.
    ///
    /// Does nothing while no synaptic importance has been recorded.
    fn apply_si_regularization(&self, mut gradients: Array2<f32>) -> Result<Array2<f32>> {
        let c = self.config.regularization_config.si_config.c;

        if !self.synaptic_importance.is_empty() {
            let penalty = &self.synaptic_importance * c;
            Self::subtract_overlapping_region(&mut gradients, &penalty);
        }

        Ok(gradients)
    }

    /// Subtracts `penalty` from `target` element-wise over the leading
    /// `min(rows) x min(cols)` sub-matrix. Cells outside that overlap are left unchanged.
    fn subtract_overlapping_region(target: &mut Array2<f32>, penalty: &Array2<f32>) {
        let rows = target.nrows().min(penalty.nrows());
        let cols = target.ncols().min(penalty.ncols());
        for i in 0..rows {
            for j in 0..cols {
                target[[i, j]] -= penalty[[i, j]];
            }
        }
    }

    /// Learning-without-Forgetting hook; it currently returns `gradients` unchanged.
    fn apply_lwf_regularization(&self, gradients: Array2<f32>) -> Result<Array2<f32>> {
        Ok(gradients)
    }

    /// Records an EWC snapshot for the current task and appends it to `ewc_states`.
    ///
    /// The Fisher matrix is the element-wise mean of squared gradients over the episodic-memory
    /// entries whose `task_id` matches the current task. This is an empirical, diagonal
    /// estimate. If no entry matches, the matrix stays all zeros. The current embeddings are
    /// stored as the anchor parameters with importance `1.0`. Does nothing when there is no
    /// current task.
    ///
    /// # Errors
    /// Propagates errors from `compute_gradients`.
    pub(crate) fn compute_ewc_state(&mut self) -> Result<()> {
        use crate::continual_learning_types::EWCState;

        if let Some(ref current_task) = self.current_task {
            let mut fisher_information = Array2::zeros(self.embeddings.dim());
            let mut task_examples = 0usize;

            // Accumulate squared gradients and count matching examples in a single pass.
            for entry in self
                .episodic_memory
                .iter()
                .filter(|entry| entry.task_id == current_task.task_id)
            {
                let gradients = self.compute_gradients(&entry.data, &entry.target)?;
                task_examples += 1;

                let rows_to_update = gradients.nrows().min(fisher_information.nrows());
                let cols_to_update = gradients.ncols().min(fisher_information.ncols());

                for i in 0..rows_to_update {
                    for j in 0..cols_to_update {
                        fisher_information[[i, j]] += gradients[[i, j]] * gradients[[i, j]];
                    }
                }
            }

            if task_examples > 0 {
                fisher_information /= task_examples as f32;
            }

            let ewc_state = EWCState {
                fisher_information,
                optimal_parameters: self.embeddings.clone(),
                task_id: current_task.task_id.clone(),
                importance: 1.0,
            };

            self.ewc_states.push(ewc_state);
        }

        Ok(())
    }

    /// Appends a new `dimensions x dimensions` network column with weights drawn as
    /// `random::<f32>() * 0.1`.
    ///
    /// From the second column on, it also appends a lateral connection matrix scaled by
    /// `progressive_config.lateral_strength`.
    pub(crate) fn add_network_column(&mut self) -> Result<()> {
        let dimensions = self.config.base_config.dimensions;
        let lateral_strength = self
            .config
            .architecture_config
            .progressive_config
            .lateral_strength;
        let mut random = Random::default();

        let new_column =
            Array2::from_shape_fn((dimensions, dimensions), |_| random.random::<f32>() * 0.1);
        self.network_columns.push(new_column);

        if self.network_columns.len() > 1 {
            let lateral_connection = Array2::from_shape_fn((dimensions, dimensions), |_| {
                random.random::<f32>() * lateral_strength
            });
            self.lateral_connections.push(lateral_connection);
        }

        Ok(())
    }

    /// Number of replay samples per call: `replay_ratio` scaled by a base batch of 32.
    ///
    /// The float-to-int `as` cast truncates toward zero and saturates, so a negative or NaN
    /// ratio yields 0.
    fn scaled_replay_batch_size(&self) -> usize {
        (self.config.replay_config.replay_ratio * 32.0) as usize
    }

    /// Replays randomly chosen episodic-memory entries through a regularized update step.
    ///
    /// Samples are drawn uniformly with replacement, at most one per stored entry per call. Each
    /// sampled entry's `access_count` is incremented. Does nothing when memory is empty.
    ///
    /// # Errors
    /// Propagates errors from gradient computation, regularization, or the parameter update.
    pub(crate) async fn experience_replay(&mut self) -> Result<()> {
        if self.episodic_memory.is_empty() {
            return Ok(());
        }

        let mut random = Random::default();
        let batch_size = self
            .scaled_replay_batch_size()
            .min(self.episodic_memory.len());

        for _ in 0..batch_size {
            let idx = random.random_range(0..self.episodic_memory.len());

            // Clone the sample so the mutable borrow of memory ends before `self` is used for
            // gradient computation and parameter updates.
            let (data, target) = {
                let entry = &mut self.episodic_memory[idx];
                entry.access_count += 1;
                (entry.data.clone(), entry.target.clone())
            };

            let gradients = self.compute_gradients(&data, &target)?;
            let regularized_gradients = self.apply_regularization(gradients)?;
            self.update_parameters(regularized_gradients)?;
        }

        Ok(())
    }

    /// Replays synthetic samples produced by the generator through a regularized update step.
    ///
    /// Each sample is `generator · noise`, where `noise` has one `random::<f32>()` entry per
    /// generator column. Its target is the element-wise `tanh` of the sample. Does nothing when
    /// no generator is set.
    ///
    /// # Errors
    /// Propagates errors from gradient computation, regularization, or the parameter update.
    pub(crate) async fn generative_replay(&mut self) -> Result<()> {
        // Snapshot the generator. The loop calls methods on `self`, including
        // `update_parameters`, so it cannot keep borrowing `self.generator`.
        if let Some(generator) = self.generator.clone() {
            let replay_batch_size = self.scaled_replay_batch_size();
            let noise_dim = generator.ncols();
            // One RNG for the whole batch instead of re-creating it for every sample.
            let mut random = Random::default();

            for _ in 0..replay_batch_size {
                let noise = Array1::from_shape_fn(noise_dim, |_| random.random::<f32>());
                let generated_data = generator.dot(&noise);
                let generated_target = generated_data.mapv(|x| x.tanh());

                let gradients = self.compute_gradients(&generated_data, &generated_target)?;
                let regularized_gradients = self.apply_regularization(gradients)?;
                self.update_parameters(regularized_gradients)?;
            }
        }

        Ok(())
    }

    /// True when tasks are detected automatically (`TaskDetection::Automatic`).
    pub(crate) fn detect_is_automatic(&self) -> bool {
        matches!(
            self.config.task_config.detection_method,
            TaskDetection::Automatic
        )
    }

    /// True when the architecture adapts by adding progressive columns.
    pub(crate) fn is_progressive(&self) -> bool {
        matches!(
            self.config.architecture_config.adaptation_method,
            ArchitectureAdaptation::Progressive
        )
    }

    /// True when EWC is among the configured regularization methods.
    pub(crate) fn should_use_ewc(&self) -> bool {
        self.config
            .regularization_config
            .methods
            .contains(&RegularizationMethod::EWC)
    }

    /// True when Synaptic Intelligence is among the configured regularization methods.
    pub(crate) fn should_use_si(&self) -> bool {
        self.config
            .regularization_config
            .methods
            .contains(&RegularizationMethod::SynapticIntelligence)
    }

    /// True when experience replay is among the configured replay methods.
    pub(crate) fn should_replay_experience(&self) -> bool {
        self.config
            .replay_config
            .methods
            .contains(&ReplayMethod::ExperienceReplay)
    }

    /// True when generative replay is among the configured replay methods.
    pub(crate) fn should_replay_generative(&self) -> bool {
        self.config
            .replay_config
            .methods
            .contains(&ReplayMethod::GenerativeReplay)
    }
}