oxirs-embed 0.2.4

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
//! Meta-Learning and Few-Shot Learning for Advanced Embedding Adaptation
//!
//! This module implements state-of-the-art meta-learning algorithms that enable
//! embedding models to quickly adapt to new domains and tasks with minimal data.
//! Features include MAML, Reptile, Prototypical Networks, Model-Agnostic Meta-Learning,
//! and advanced few-shot learning techniques for knowledge graph embeddings.
//!
//! Meta-learning capabilities enable:
//! - Rapid adaptation to new knowledge domains
//! - Few-shot entity and relation learning
//! - Transfer learning across knowledge graphs
//! - Continual learning without catastrophic forgetting
//! - Cross-domain knowledge transfer and adaptation

pub mod types;
pub mod maml;

// Re-export main types and structures
pub use types::*;
pub use maml::MAML;

use crate::{EmbeddingModel, Vector};
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::RwLock;
use tracing::{debug, info, warn, error};

/// Meta-Learning Engine for few-shot adaptation
pub struct MetaLearningEngine {
    /// Configuration for meta-learning
    config: MetaLearningConfig,
    /// MAML (Model-Agnostic Meta-Learning) implementation
    maml: MAML,
    /// Task distribution and sampling
    task_sampler: TaskSampler,
    /// Meta-learning history and statistics
    meta_history: MetaLearningHistory,
    /// Performance metrics
    performance_metrics: MetaPerformanceMetrics,
}

impl MetaLearningEngine {
    /// Create new meta-learning engine
    pub fn new(config: MetaLearningConfig) -> Result<Self> {
        let maml = MAML::new(config.maml_config.clone());
        let task_sampler = TaskSampler::new(config.task_config.clone());
        
        let meta_history = MetaLearningHistory {
            episodes: Vec::new(),
            performance_history: Vec::new(),
            task_statistics: TaskStatistics {
                domain_distribution: HashMap::new(),
                difficulty_distribution: HashMap::new(),
                success_rate_by_domain: HashMap::new(),
                avg_adaptation_time: HashMap::new(),
            },
        };
        
        Ok(Self {
            config,
            maml,
            task_sampler,
            meta_history,
            performance_metrics: MetaPerformanceMetrics::default(),
        })
    }

    /// Run meta-learning training
    pub async fn meta_train(&mut self, num_episodes: usize) -> Result<()> {
        info!("Starting meta-learning training for {} episodes", num_episodes);
        
        for episode in 0..num_episodes {
            info!("Meta-learning episode {}/{}", episode + 1, num_episodes);
            
            // Sample task batch
            let tasks = self.task_sampler.sample_task_batch(8)?; // Batch size of 8
            
            // Train with MAML
            let maml_result = self.maml.meta_learn_episode(&tasks).await?;
            
            // Track results
            let episode_result = EpisodeResult {
                episode,
                avg_loss: maml_result.average_loss,
                avg_accuracy: 1.0 - maml_result.average_loss, // Simplified accuracy
                task_results: maml_result.adaptation_results.into_iter().map(|ar| TaskResult {
                    task_id: ar.task_id,
                    loss: ar.final_loss,
                    accuracy: 1.0 - ar.final_loss, // Simplified accuracy
                    adaptation_steps: ar.adaptation_steps,
                    metadata: ar.task_metadata,
                }).collect(),
                duration: std::time::Duration::from_millis(100), // Simplified
            };
            
            self.meta_history.episodes.push(episode_result);
            self.update_performance_trends(&maml_result);
            
            // Evaluation
            if episode % self.config.global_settings.eval_frequency == 0 {
                let eval_result = self.evaluate_meta_learning().await?;
                info!("Episode {} evaluation: few-shot accuracy = {:.4}", episode, eval_result.few_shot_accuracy);
            }
            
            // Early stopping check
            if self.check_early_stopping() {
                info!("Early stopping triggered at episode {}", episode);
                break;
            }
        }
        
        info!("Meta-learning training completed");
        Ok(())
    }

    /// Evaluate meta-learning performance
    pub async fn evaluate_meta_learning(&mut self) -> Result<MetaPerformanceMetrics> {
        let eval_tasks = self.task_sampler.sample_task_batch(20)?;
        let mut total_accuracy = 0.0;
        let mut total_adaptation_time = 0.0;

        for task in &eval_tasks {
            let adaptation_result = self.maml.adapt_to_task(task).await?;
            let accuracy = 1.0 - adaptation_result.final_loss; // Simplified accuracy calculation
            total_accuracy += accuracy;
            total_adaptation_time += adaptation_result.duration.as_millis() as f32;
        }

        let avg_accuracy = total_accuracy / eval_tasks.len() as f32;
        let avg_adaptation_time = total_adaptation_time / eval_tasks.len() as f32;

        self.performance_metrics.few_shot_accuracy = avg_accuracy;
        self.performance_metrics.avg_adaptation_time_ms = avg_adaptation_time;

        Ok(self.performance_metrics.clone())
    }

    fn update_performance_trends(&mut self, result: &MetaLearningResult) {
        use std::time::Instant;
        
        let snapshot = PerformanceSnapshot {
            timestamp: Instant::now(),
            avg_loss: result.average_loss,
            avg_accuracy: 1.0 - result.average_loss, // Simplified
            learning_rate: self.config.global_settings.meta_learning_rate,
            memory_usage: 0, // Simplified
        };
        
        self.meta_history.performance_history.push(snapshot);
    }

    fn check_early_stopping(&self) -> bool {
        if !self.config.global_settings.enable_early_stopping {
            return false;
        }

        if self.meta_history.episodes.len() < self.config.global_settings.early_stopping_patience {
            return false;
        }

        // Check if performance hasn't improved in the last N episodes
        let recent_episodes = &self.meta_history.episodes[
            self.meta_history.episodes.len() - self.config.global_settings.early_stopping_patience..
        ];

        let recent_losses: Vec<f32> = recent_episodes.iter().map(|e| e.avg_loss).collect();
        let min_loss = recent_losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        // If the difference between min and max loss in recent episodes is very small,
        // consider it as no improvement
        (max_loss - min_loss) < 0.001
    }

    /// Get meta-learning statistics
    pub fn get_statistics(&self) -> &MetaLearningHistory {
        &self.meta_history
    }

    /// Get current performance metrics
    pub fn get_performance_metrics(&self) -> &MetaPerformanceMetrics {
        &self.performance_metrics
    }

    /// Adapt to a new task
    pub async fn adapt_to_task(&mut self, task: &Task) -> Result<AdaptationResult> {
        self.maml.adapt_to_task(task).await
    }
}

/// Task sampler for meta-learning
pub struct TaskSampler {
    config: TaskSamplingConfig,
    sampling_statistics: SamplingStatistics,
}

impl TaskSampler {
    pub fn new(config: TaskSamplingConfig) -> Self {
        Self {
            config,
            sampling_statistics: SamplingStatistics {
                total_tasks: 0,
                domain_counts: HashMap::new(),
                difficulty_histogram: vec![0; 10], // 10 difficulty bins
                avg_generation_time_ms: 0.0,
            },
        }
    }

    /// Sample a batch of tasks
    pub fn sample_task_batch(&mut self, batch_size: usize) -> Result<Vec<Task>> {
        let mut tasks = Vec::new();
        
        for _ in 0..batch_size {
            let task = self.sample_single_task()?;
            tasks.push(task);
        }
        
        Ok(tasks)
    }

    /// Sample a single task
    pub fn sample_single_task(&mut self) -> Result<Task> {
        use std::time::Instant;
        use scirs2_core::random::{seq::SliceRandom, thread_rng, RngExt};
        use uuid::Uuid;
        use scirs2_core::ndarray_ext::Array1;
        
        let start_time = Instant::now();
        
        // Select domain
        let domain = self.select_domain()?;
        
        // Select difficulty
        let difficulty = self.sample_difficulty();
        
        // Generate task
        let task = self.generate_task(&domain, difficulty)?;
        
        // Update statistics
        let generation_time = start_time.elapsed().as_millis() as f32;
        self.update_statistics(&domain, difficulty, generation_time);
        
        Ok(task)
    }

    fn select_domain(&self) -> Result<String> {
        if self.config.domains.is_empty() {
            return Err(anyhow!("No domains configured"));
        }
        
        use scirs2_core::random::{seq::SliceRandom, thread_rng};
        let mut rng = rand::rng();
        Ok(self.config.domains.choose(&mut rng).expect("domains should not be empty").clone())
    }

    fn sample_difficulty(&self) -> f32 {
        use scirs2_core::random::Random;
        let mut rng = Random::default();

        match &self.config.difficulty_sampling {
            s if s == "uniform" => rng.uniform(0.0, 1.0),
            s if s == "normal" => {
                // Simplified normal distribution
                let val: f32 = rng.random();
                (val - 0.5) * 2.0 + 0.5 // Center around 0.5
            }
            _ => rng.uniform(0.0, 1.0),
        }
    }

    fn generate_task(&self, domain: &str, difficulty: f32) -> Result<Task> {
        use scirs2_core::random::{thread_rng, RngExt};
        use uuid::Uuid;
        use scirs2_core::ndarray_ext::Array1;
        use std::time::Instant;
        
        let mut rng = rand::rng();
        
        // Generate synthetic data for now
        let mut support_set = Vec::new();
        let mut query_set = Vec::new();
        
        let n_way = 5; // Number of classes
        let k_shot = 1; // Support examples per class
        let n_query = 15; // Query examples per class
        
        for class_idx in 0..n_way {
            // Generate support examples
            for _ in 0..k_shot {
                let data_point = self.generate_data_point(class_idx, difficulty, n_way, &mut rng)?;
                support_set.push(data_point);
            }
            
            // Generate query examples
            for _ in 0..n_query {
                let data_point = self.generate_data_point(class_idx, difficulty, n_way, &mut rng)?;
                query_set.push(data_point);
            }
        }
        
        Ok(Task {
            id: Uuid::new_v4(),
            task_type: format!("{}_way_{}_shot", n_way, k_shot),
            support_set,
            query_set,
            metadata: TaskMetadata {
                domain: domain.to_string(),
                difficulty,
                support_size: support_set.len(),
                query_size: query_set.len(),
                created_at: Instant::now(),
            },
        })
    }

    fn generate_data_point(&self, class_idx: usize, difficulty: f32, n_way: usize, rng: &mut impl Rng) -> Result<DataPoint> {
        use scirs2_core::ndarray_ext::Array1;
        
        let input_dim = 128;
        
        // Generate synthetic input with class-specific pattern
        let mut input = Array1::zeros(input_dim);
        for i in 0..input_dim {
            let base_value = (class_idx as f32 * 2.0 + difficulty) * (i as f32 / input_dim as f32);
            let noise = rng.uniform(-0.1, 0.1) * difficulty;
            input[i] = base_value + noise;
        }
        
        // Generate one-hot target
        let mut target = Array1::zeros(n_way);
        target[class_idx] = 1.0;
        
        Ok(DataPoint {
            input,
            target,
            metadata: None,
        })
    }

    fn update_statistics(&mut self, domain: &str, difficulty: f32, generation_time: f32) {
        self.sampling_statistics.total_tasks += 1;
        
        *self.sampling_statistics.domain_counts.entry(domain.to_string()).or_insert(0) += 1;
        
        let difficulty_bin = (difficulty * 10.0) as usize;
        if difficulty_bin < self.sampling_statistics.difficulty_histogram.len() {
            self.sampling_statistics.difficulty_histogram[difficulty_bin] += 1;
        }
        
        let alpha = 0.1; // Exponential moving average
        self.sampling_statistics.avg_generation_time_ms = 
            alpha * generation_time + (1.0 - alpha) * self.sampling_statistics.avg_generation_time_ms;
    }

    /// Get sampling statistics
    pub fn get_statistics(&self) -> &SamplingStatistics {
        &self.sampling_statistics
    }
}

/// Sampling statistics
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SamplingStatistics {
    /// Total tasks sampled
    pub total_tasks: usize,
    /// Tasks per domain
    pub domain_counts: HashMap<String, usize>,
    /// Difficulty distribution
    pub difficulty_histogram: Vec<usize>,
    /// Average task generation time
    pub avg_generation_time_ms: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_meta_learning_engine_creation() {
        let config = MetaLearningConfig::default();
        let engine = MetaLearningEngine::new(config);
        assert!(engine.is_ok());
    }

    #[test]
    fn test_task_sampler() {
        let config = TaskSamplingConfig::default();
        let mut sampler = TaskSampler::new(config);
        
        let task = sampler.sample_single_task();
        assert!(task.is_ok());
        
        let task = task.expect("should succeed");
        assert_eq!(task.support_set.len(), 5); // n_way * k_shot
        assert_eq!(task.query_set.len(), 75); // n_way * n_query
    }
}