pub mod types;
pub mod maml;
pub use types::*;
pub use maml::MAML;
use crate::{EmbeddingModel, Vector};
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::RwLock;
use tracing::{debug, info, warn, error};
pub struct MetaLearningEngine {
config: MetaLearningConfig,
maml: MAML,
task_sampler: TaskSampler,
meta_history: MetaLearningHistory,
performance_metrics: MetaPerformanceMetrics,
}
impl MetaLearningEngine {
pub fn new(config: MetaLearningConfig) -> Result<Self> {
let maml = MAML::new(config.maml_config.clone());
let task_sampler = TaskSampler::new(config.task_config.clone());
let meta_history = MetaLearningHistory {
episodes: Vec::new(),
performance_history: Vec::new(),
task_statistics: TaskStatistics {
domain_distribution: HashMap::new(),
difficulty_distribution: HashMap::new(),
success_rate_by_domain: HashMap::new(),
avg_adaptation_time: HashMap::new(),
},
};
Ok(Self {
config,
maml,
task_sampler,
meta_history,
performance_metrics: MetaPerformanceMetrics::default(),
})
}
pub async fn meta_train(&mut self, num_episodes: usize) -> Result<()> {
info!("Starting meta-learning training for {} episodes", num_episodes);
for episode in 0..num_episodes {
info!("Meta-learning episode {}/{}", episode + 1, num_episodes);
let tasks = self.task_sampler.sample_task_batch(8)?;
let maml_result = self.maml.meta_learn_episode(&tasks).await?;
let episode_result = EpisodeResult {
episode,
avg_loss: maml_result.average_loss,
avg_accuracy: 1.0 - maml_result.average_loss, task_results: maml_result.adaptation_results.into_iter().map(|ar| TaskResult {
task_id: ar.task_id,
loss: ar.final_loss,
accuracy: 1.0 - ar.final_loss, adaptation_steps: ar.adaptation_steps,
metadata: ar.task_metadata,
}).collect(),
duration: std::time::Duration::from_millis(100), };
self.meta_history.episodes.push(episode_result);
self.update_performance_trends(&maml_result);
if episode % self.config.global_settings.eval_frequency == 0 {
let eval_result = self.evaluate_meta_learning().await?;
info!("Episode {} evaluation: few-shot accuracy = {:.4}", episode, eval_result.few_shot_accuracy);
}
if self.check_early_stopping() {
info!("Early stopping triggered at episode {}", episode);
break;
}
}
info!("Meta-learning training completed");
Ok(())
}
pub async fn evaluate_meta_learning(&mut self) -> Result<MetaPerformanceMetrics> {
let eval_tasks = self.task_sampler.sample_task_batch(20)?;
let mut total_accuracy = 0.0;
let mut total_adaptation_time = 0.0;
for task in &eval_tasks {
let adaptation_result = self.maml.adapt_to_task(task).await?;
let accuracy = 1.0 - adaptation_result.final_loss; total_accuracy += accuracy;
total_adaptation_time += adaptation_result.duration.as_millis() as f32;
}
let avg_accuracy = total_accuracy / eval_tasks.len() as f32;
let avg_adaptation_time = total_adaptation_time / eval_tasks.len() as f32;
self.performance_metrics.few_shot_accuracy = avg_accuracy;
self.performance_metrics.avg_adaptation_time_ms = avg_adaptation_time;
Ok(self.performance_metrics.clone())
}
fn update_performance_trends(&mut self, result: &MetaLearningResult) {
use std::time::Instant;
let snapshot = PerformanceSnapshot {
timestamp: Instant::now(),
avg_loss: result.average_loss,
avg_accuracy: 1.0 - result.average_loss, learning_rate: self.config.global_settings.meta_learning_rate,
memory_usage: 0, };
self.meta_history.performance_history.push(snapshot);
}
fn check_early_stopping(&self) -> bool {
if !self.config.global_settings.enable_early_stopping {
return false;
}
if self.meta_history.episodes.len() < self.config.global_settings.early_stopping_patience {
return false;
}
let recent_episodes = &self.meta_history.episodes[
self.meta_history.episodes.len() - self.config.global_settings.early_stopping_patience..
];
let recent_losses: Vec<f32> = recent_episodes.iter().map(|e| e.avg_loss).collect();
let min_loss = recent_losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
(max_loss - min_loss) < 0.001
}
pub fn get_statistics(&self) -> &MetaLearningHistory {
&self.meta_history
}
pub fn get_performance_metrics(&self) -> &MetaPerformanceMetrics {
&self.performance_metrics
}
pub async fn adapt_to_task(&mut self, task: &Task) -> Result<AdaptationResult> {
self.maml.adapt_to_task(task).await
}
}
pub struct TaskSampler {
config: TaskSamplingConfig,
sampling_statistics: SamplingStatistics,
}
impl TaskSampler {
pub fn new(config: TaskSamplingConfig) -> Self {
Self {
config,
sampling_statistics: SamplingStatistics {
total_tasks: 0,
domain_counts: HashMap::new(),
difficulty_histogram: vec![0; 10], avg_generation_time_ms: 0.0,
},
}
}
pub fn sample_task_batch(&mut self, batch_size: usize) -> Result<Vec<Task>> {
let mut tasks = Vec::new();
for _ in 0..batch_size {
let task = self.sample_single_task()?;
tasks.push(task);
}
Ok(tasks)
}
pub fn sample_single_task(&mut self) -> Result<Task> {
use std::time::Instant;
use scirs2_core::random::{seq::SliceRandom, thread_rng, RngExt};
use uuid::Uuid;
use scirs2_core::ndarray_ext::Array1;
let start_time = Instant::now();
let domain = self.select_domain()?;
let difficulty = self.sample_difficulty();
let task = self.generate_task(&domain, difficulty)?;
let generation_time = start_time.elapsed().as_millis() as f32;
self.update_statistics(&domain, difficulty, generation_time);
Ok(task)
}
fn select_domain(&self) -> Result<String> {
if self.config.domains.is_empty() {
return Err(anyhow!("No domains configured"));
}
use scirs2_core::random::{seq::SliceRandom, thread_rng};
let mut rng = rand::rng();
Ok(self.config.domains.choose(&mut rng).expect("domains should not be empty").clone())
}
fn sample_difficulty(&self) -> f32 {
use scirs2_core::random::Random;
let mut rng = Random::default();
match &self.config.difficulty_sampling {
s if s == "uniform" => rng.uniform(0.0, 1.0),
s if s == "normal" => {
let val: f32 = rng.random();
(val - 0.5) * 2.0 + 0.5 }
_ => rng.uniform(0.0, 1.0),
}
}
fn generate_task(&self, domain: &str, difficulty: f32) -> Result<Task> {
use scirs2_core::random::{thread_rng, RngExt};
use uuid::Uuid;
use scirs2_core::ndarray_ext::Array1;
use std::time::Instant;
let mut rng = rand::rng();
let mut support_set = Vec::new();
let mut query_set = Vec::new();
let n_way = 5; let k_shot = 1; let n_query = 15;
for class_idx in 0..n_way {
for _ in 0..k_shot {
let data_point = self.generate_data_point(class_idx, difficulty, n_way, &mut rng)?;
support_set.push(data_point);
}
for _ in 0..n_query {
let data_point = self.generate_data_point(class_idx, difficulty, n_way, &mut rng)?;
query_set.push(data_point);
}
}
Ok(Task {
id: Uuid::new_v4(),
task_type: format!("{}_way_{}_shot", n_way, k_shot),
support_set,
query_set,
metadata: TaskMetadata {
domain: domain.to_string(),
difficulty,
support_size: support_set.len(),
query_size: query_set.len(),
created_at: Instant::now(),
},
})
}
fn generate_data_point(&self, class_idx: usize, difficulty: f32, n_way: usize, rng: &mut impl Rng) -> Result<DataPoint> {
use scirs2_core::ndarray_ext::Array1;
let input_dim = 128;
let mut input = Array1::zeros(input_dim);
for i in 0..input_dim {
let base_value = (class_idx as f32 * 2.0 + difficulty) * (i as f32 / input_dim as f32);
let noise = rng.uniform(-0.1, 0.1) * difficulty;
input[i] = base_value + noise;
}
let mut target = Array1::zeros(n_way);
target[class_idx] = 1.0;
Ok(DataPoint {
input,
target,
metadata: None,
})
}
fn update_statistics(&mut self, domain: &str, difficulty: f32, generation_time: f32) {
self.sampling_statistics.total_tasks += 1;
*self.sampling_statistics.domain_counts.entry(domain.to_string()).or_insert(0) += 1;
let difficulty_bin = (difficulty * 10.0) as usize;
if difficulty_bin < self.sampling_statistics.difficulty_histogram.len() {
self.sampling_statistics.difficulty_histogram[difficulty_bin] += 1;
}
let alpha = 0.1; self.sampling_statistics.avg_generation_time_ms =
alpha * generation_time + (1.0 - alpha) * self.sampling_statistics.avg_generation_time_ms;
}
pub fn get_statistics(&self) -> &SamplingStatistics {
&self.sampling_statistics
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SamplingStatistics {
pub total_tasks: usize,
pub domain_counts: HashMap<String, usize>,
pub difficulty_histogram: Vec<usize>,
pub avg_generation_time_ms: f32,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_meta_learning_engine_creation() {
let config = MetaLearningConfig::default();
let engine = MetaLearningEngine::new(config);
assert!(engine.is_ok());
}
#[test]
fn test_task_sampler() {
let config = TaskSamplingConfig::default();
let mut sampler = TaskSampler::new(config);
let task = sampler.sample_single_task();
assert!(task.is_ok());
let task = task.expect("should succeed");
assert_eq!(task.support_set.len(), 5); assert_eq!(task.query_set.len(), 75); }
}