use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Meta-level learner holding one shared parameter vector, a list of named
/// learning strategies, and the embeddings of the tasks it has been updated with.
///
/// All fields are public and the type is serializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearner {
/// Identifier produced by `Uuid::new_v4()` when the learner is constructed.
pub id: Uuid,
/// Name passed to `MetaLearner::new`.
pub name: String,
/// Shared parameter vector: zero-filled by `new`, then moved toward each
/// task's final parameters by `meta_update`.
pub meta_params: Vec<f32>,
/// Registered strategies; `register_strategy` refuses a second entry with the same name.
pub strategies: Vec<LearningStrategy>,
/// Embedding stored per task id by `meta_update`; a later update with the
/// same id overwrites the earlier embedding.
pub task_embeddings: HashMap<String, Vec<f32>>,
/// Fraction of the gap to a task's final parameters that `meta_update`
/// closes on each call (`new` sets 0.1).
pub meta_lr: f32,
/// Learning rate copied into `FewShotLearner::from_meta` as its adaptation
/// rate (`new` sets 0.01).
pub inner_lr: f32,
/// Number of `meta_update` calls that passed the length check; repeated
/// updates for the same task id are each counted.
pub task_count: u64,
/// UTC time captured in `new`.
pub created_at: DateTime<Utc>,
/// UTC time of construction or of the most recent accepted `meta_update`.
pub updated_at: DateTime<Utc>,
}
impl MetaLearner {
    /// Creates a learner called `name` with `num_params` meta-parameters, all zero.
    ///
    /// The learner starts without strategies or task embeddings, with a meta
    /// learning rate of `0.1`, an inner learning rate of `0.01`, a task count of
    /// zero, and both timestamps set to the same `Utc::now()` reading.
    pub fn new(name: impl Into<String>, num_params: usize) -> Self {
        let timestamp = Utc::now();
        Self {
            id: Uuid::new_v4(),
            name: name.into(),
            meta_params: vec![0.0; num_params],
            strategies: Vec::default(),
            task_embeddings: HashMap::default(),
            meta_lr: 0.1,
            inner_lr: 0.01,
            task_count: 0,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    /// Builder: replaces the meta learning rate used by [`Self::meta_update`].
    pub fn with_meta_lr(self, lr: f32) -> Self {
        Self { meta_lr: lr, ..self }
    }

    /// Builder: replaces the inner learning rate handed to few-shot learners.
    pub fn with_inner_lr(self, lr: f32) -> Self {
        Self { inner_lr: lr, ..self }
    }

    /// Returns the starting parameters for a new task.
    ///
    /// The result begins as a copy of `meta_params`. If `task_embedding` is
    /// given and `find_similar_task` reports a match, every parameter becomes
    /// `0.7 * own + 0.3 * matched` over the common length of the two vectors.
    ///
    /// NOTE(review): `find_similar_task` hands back a clone of `meta_params` as
    /// the matched parameters, so this blend currently mixes the
    /// meta-parameters with themselves. Confirm whether per-task parameters
    /// were meant to be stored and used here.
    pub fn initialize_for_task(&self, task_embedding: Option<&[f32]>) -> Vec<f32> {
        let mut init = self.meta_params.clone();
        let neighbour = task_embedding.and_then(|emb| self.find_similar_task(emb));
        if let Some((_, neighbour_params)) = neighbour {
            for (slot, donor) in init.iter_mut().zip(neighbour_params.iter()) {
                *slot = 0.7 * *slot + 0.3 * *donor;
            }
        }
        init
    }

    /// Looks up the stored task whose embedding is most similar to `embedding`.
    ///
    /// Similarity is `cosine_similarity`; a match is reported only when the
    /// best similarity is strictly greater than `0.5`. Equal similarities are
    /// resolved by the map's iteration order, which `HashMap` leaves
    /// unspecified. The returned vector is a clone of `meta_params`.
    fn find_similar_task(&self, embedding: &[f32]) -> Option<(&str, Vec<f32>)> {
        let start: (f32, Option<&str>) = (-1.0, None);
        let (top_sim, top_task) =
            self.task_embeddings
                .iter()
                .fold(start, |(best, winner), (task_id, stored)| {
                    let sim = cosine_similarity(embedding, stored);
                    if sim > best {
                        (sim, Some(task_id.as_str()))
                    } else {
                        (best, winner)
                    }
                });
        top_task
            .filter(|_| top_sim > 0.5)
            .map(|task| (task, self.meta_params.clone()))
    }

    /// Folds the outcome of one task back into the meta-parameters.
    ///
    /// Each meta-parameter moves by `meta_lr * (final - current)`. When
    /// `task_embedding` is supplied it is stored under `task_id`, replacing any
    /// earlier embedding for that id. The task counter is incremented and
    /// `updated_at` refreshed.
    ///
    /// If `final_params` does not have the same length as `meta_params`, the
    /// call returns without changing anything (the update is not counted).
    pub fn meta_update(
        &mut self,
        task_id: &str,
        final_params: &[f32],
        task_embedding: Option<Vec<f32>>,
    ) {
        if self.meta_params.len() != final_params.len() {
            return;
        }

        let step = self.meta_lr;
        for (current, target) in self.meta_params.iter_mut().zip(final_params) {
            let gap = *target - *current;
            *current += step * gap;
        }

        if let Some(embedding) = task_embedding {
            self.task_embeddings.insert(task_id.to_owned(), embedding);
        }

        self.task_count += 1;
        self.updated_at = Utc::now();
    }

    /// Registers `strategy` unless one with the same name is already present,
    /// in which case the new value is dropped.
    pub fn register_strategy(&mut self, strategy: LearningStrategy) {
        let name_is_free = self
            .strategies
            .iter()
            .all(|known| known.name != strategy.name);
        if name_is_free {
            self.strategies.push(strategy);
        }
    }

    /// Picks the registered strategy with the highest `score_for_task` result.
    ///
    /// The earliest-registered strategy wins ties, and `None` is returned
    /// unless the winning score is strictly greater than `0.5`.
    ///
    /// NOTE(review): a strategy created with `LearningStrategy::new` starts at
    /// `success_rate` 0.5 and its score is the feature score (at most 1.0)
    /// times that rate, so it cannot clear this threshold until its success
    /// rate is raised. Confirm this cold-start behaviour is intended.
    pub fn select_strategy(&self, task_features: &TaskFeatures) -> Option<&LearningStrategy> {
        let (top_score, pick) = self.strategies.iter().fold(
            (0.0f32, None::<&LearningStrategy>),
            |(best, chosen), candidate| {
                let score = candidate.score_for_task(task_features);
                if score > best {
                    (score, Some(candidate))
                } else {
                    (best, chosen)
                }
            },
        );
        match pick {
            Some(strategy) if top_score > 0.5 => Some(strategy),
            _ => None,
        }
    }

    /// Number of registered strategies.
    pub fn num_strategies(&self) -> usize {
        self.strategies.len()
    }

    /// Number of accepted `meta_update` calls so far.
    pub fn num_tasks(&self) -> u64 {
        self.task_count
    }
}
/// A named learning strategy with hyperparameters, the task profile it
/// prefers, and a running record of how often it has succeeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStrategy {
/// Strategy name; `MetaLearner::register_strategy` uses it to reject duplicates.
pub name: String,
/// Free-form description; empty until `with_description` is called.
pub description: String,
/// Named `f32` hyperparameters added through `with_hyperparam`.
pub hyperparams: HashMap<String, f32>,
/// Task profile compared against a candidate task in `score_for_task`.
pub preferred_features: TaskFeatures,
/// Multiplier applied to the feature score in `score_for_task`. Starts at
/// 0.5 and is updated by `record_usage` as `0.9 * old + 0.1 * outcome`.
pub success_rate: f32,
/// Number of `record_usage` calls.
pub usage_count: u64,
}
impl LearningStrategy {
    /// Creates a strategy called `name` with an empty description, no
    /// hyperparameters, default preferred features, a success rate of `0.5`,
    /// and a usage count of zero.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: String::default(),
            hyperparams: HashMap::default(),
            preferred_features: TaskFeatures::default(),
            success_rate: 0.5,
            usage_count: 0,
        }
    }

    /// Builder: sets the description.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: desc.into(),
            ..self
        }
    }

    /// Builder: stores hyperparameter `name`, replacing any earlier value for it.
    pub fn with_hyperparam(mut self, name: impl Into<String>, value: f32) -> Self {
        let key = name.into();
        self.hyperparams.insert(key, value);
        self
    }

    /// Builder: sets the task profile this strategy prefers.
    pub fn with_preferred_features(self, features: TaskFeatures) -> Self {
        Self {
            preferred_features: features,
            ..self
        }
    }

    /// Scores how well this strategy suits `task`.
    ///
    /// Up to four terms are averaged and the average is multiplied by
    /// `success_rate`:
    /// - data size (when both sides specify it): `1 - |a - b| / (max(a, b) + 1)`;
    /// - noise level and complexity (each when both sides specify it): `1 - |a - b|`;
    /// - task kind: contributes `1.0` when `is_classification` matches.
    ///
    /// When no term contributes, the average is taken to be `0.5`.
    ///
    /// NOTE(review): a mismatching `is_classification` is left out of the
    /// average instead of contributing `0`, so a mismatch never lowers the
    /// score. Confirm that this bonus-only treatment is intended.
    pub fn score_for_task(&self, task: &TaskFeatures) -> f32 {
        let preferred = &self.preferred_features;

        let closeness = |want: Option<f32>, have: Option<f32>| -> Option<f32> {
            want.zip(have).map(|(w, h)| 1.0 - (w - h).abs())
        };
        let size_term = preferred
            .data_size
            .zip(task.data_size)
            .map(|(want, have)| {
                1.0 - (want as f32 - have as f32).abs() / (want.max(have) as f32 + 1.0)
            });
        let kind_term = (preferred.is_classification == task.is_classification).then_some(1.0f32);

        // Terms are summed in this fixed order so the floating-point result is stable.
        let terms = [
            size_term,
            closeness(preferred.noise_level, task.noise_level),
            closeness(preferred.complexity, task.complexity),
            kind_term,
        ];
        let (total, contributing) = terms
            .iter()
            .flatten()
            .fold((0.0f32, 0u32), |(sum, n), term| (sum + term, n + 1));

        let feature_score = if contributing == 0 {
            0.5
        } else {
            total / contributing as f32
        };
        feature_score * self.success_rate
    }

    /// Records one use of the strategy.
    ///
    /// Increments `usage_count` and updates `success_rate` as an exponential
    /// moving average: `0.9 * old + 0.1 * outcome`, where `outcome` is `1.0`
    /// for success and `0.0` for failure.
    pub fn record_usage(&mut self, succeeded: bool) {
        let outcome = f32::from(u8::from(succeeded));
        self.usage_count += 1;
        self.success_rate = 0.9 * self.success_rate + 0.1 * outcome;
    }
}
/// Descriptive features of a learning task, used to match tasks against
/// `LearningStrategy::preferred_features`.
///
/// The derived `Default` leaves every optional field `None` and
/// `is_classification` set to `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskFeatures {
/// Size of the task's data set (presumably a sample count; confirm with callers).
pub data_size: Option<usize>,
/// Noise level; `with_noise` clamps it to `[0.0, 1.0]`, direct assignment does not.
pub noise_level: Option<f32>,
/// Task complexity; `with_complexity` clamps it to `[0.0, 1.0]`, direct assignment does not.
pub complexity: Option<f32>,
/// `true` for classification tasks, `false` for regression (see the
/// `classification` / `regression` builders).
pub is_classification: bool,
/// Input dimensionality. No builder or scoring code in the visible source reads it.
pub input_dim: Option<usize>,
/// Output dimensionality. No builder or scoring code in the visible source reads it.
pub output_dim: Option<usize>,
/// Free-form domain label set by `with_domain`; not used by scoring in the visible source.
pub domain: Option<String>,
}
impl TaskFeatures {
    /// Creates a feature set equal to `TaskFeatures::default()`: every optional
    /// field unset and `is_classification` false.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder: records the data-set size.
    pub fn with_data_size(self, size: usize) -> Self {
        Self {
            data_size: Some(size),
            ..self
        }
    }

    /// Builder: records the noise level, clamped to `[0.0, 1.0]`.
    pub fn with_noise(self, noise: f32) -> Self {
        let bounded = f32::clamp(noise, 0.0, 1.0);
        Self {
            noise_level: Some(bounded),
            ..self
        }
    }

    /// Builder: records the task complexity, clamped to `[0.0, 1.0]`.
    pub fn with_complexity(self, complexity: f32) -> Self {
        let bounded = f32::clamp(complexity, 0.0, 1.0);
        Self {
            complexity: Some(bounded),
            ..self
        }
    }

    /// Builder: marks the task as classification.
    pub fn classification(self) -> Self {
        Self {
            is_classification: true,
            ..self
        }
    }

    /// Builder: marks the task as regression (not classification).
    pub fn regression(self) -> Self {
        Self {
            is_classification: false,
            ..self
        }
    }

    /// Builder: records a free-form domain label.
    pub fn with_domain(self, domain: impl Into<String>) -> Self {
        Self {
            domain: Some(domain.into()),
            ..self
        }
    }
}
/// Linear model that adapts a parameter vector obtained from a `MetaLearner`
/// to a small set of `(features, target)` examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotLearner {
/// Starting parameters returned by `MetaLearner::initialize_for_task`;
/// `adapt` restarts from these on every call.
base_params: Vec<f32>,
/// Parameters after the most recent `adapt` call (equal to `base_params`
/// before the first one); `predict` uses these.
adapted_params: Vec<f32>,
/// Support examples as `(features, target)` pairs, in insertion order.
support_set: Vec<(Vec<f32>, f32)>,
/// Step size for the per-example updates in `adapt`; initialised from
/// `MetaLearner::inner_lr`.
adapt_lr: f32,
/// Number of passes `adapt` makes over the support set (`from_meta` sets 5).
adapt_steps: usize,
}
impl FewShotLearner {
    /// Builds a learner whose base and adapted parameters both start as
    /// `meta.initialize_for_task(task_embedding)`.
    ///
    /// The support set starts empty, the adaptation rate is `meta.inner_lr`,
    /// and the number of adaptation passes is `5`.
    pub fn from_meta(meta: &MetaLearner, task_embedding: Option<&[f32]>) -> Self {
        let base = meta.initialize_for_task(task_embedding);
        let adapted = base.clone();
        Self {
            base_params: base,
            adapted_params: adapted,
            support_set: Vec::default(),
            adapt_lr: meta.inner_lr,
            adapt_steps: 5,
        }
    }

    /// Builder: replaces the adaptation step size.
    pub fn with_adapt_lr(self, lr: f32) -> Self {
        Self {
            adapt_lr: lr,
            ..self
        }
    }

    /// Builder: replaces the number of passes `adapt` makes over the support set.
    pub fn with_adapt_steps(self, steps: usize) -> Self {
        Self {
            adapt_steps: steps,
            ..self
        }
    }

    /// Appends one `(features, target)` example to the support set.
    ///
    /// The feature length is not checked here; `adapt` skips examples whose
    /// length differs from the parameter count.
    pub fn add_example(&mut self, features: Vec<f32>, target: f32) {
        let sample = (features, target);
        self.support_set.push(sample);
    }

    /// Re-fits the adapted parameters to the support set.
    ///
    /// Adaptation restarts from `base_params`, then makes `adapt_steps` passes
    /// over the examples in insertion order. For each example with a matching
    /// feature length the prediction is the dot product of features and
    /// parameters, and every parameter is moved by
    /// `-adapt_lr * 2 * (prediction - target) * feature` (a squared-error
    /// gradient step applied one example at a time).
    pub fn adapt(&mut self) {
        self.adapted_params.clone_from(&self.base_params);

        let step_size = self.adapt_lr;
        // The parameter count never changes during adaptation, so read it once.
        let dim = self.adapted_params.len();

        for _ in 0..self.adapt_steps {
            let usable = self.support_set.iter().filter(|(f, _)| f.len() == dim);
            for (features, target) in usable {
                let residual = Self::dot(features, &self.adapted_params) - target;
                let scale = 2.0 * residual;
                for (weight, feat) in self.adapted_params.iter_mut().zip(features) {
                    *weight -= step_size * (scale * feat);
                }
            }
        }
    }

    /// Predicts the target for `features` as the dot product with the adapted
    /// parameters, or returns `0.0` when the lengths differ.
    pub fn predict(&self, features: &[f32]) -> f32 {
        if self.adapted_params.len() == features.len() {
            Self::dot(features, &self.adapted_params)
        } else {
            0.0
        }
    }

    /// Borrows the current adapted parameters.
    pub fn get_adapted_params(&self) -> &[f32] {
        self.adapted_params.as_slice()
    }

    /// Number of examples in the support set, including any that `adapt` skips.
    pub fn support_size(&self) -> usize {
        self.support_set.len()
    }

    /// Dot product of two slices; callers make sure the lengths agree.
    fn dot(features: &[f32], weights: &[f32]) -> f32 {
        features.iter().zip(weights).map(|(f, w)| f * w).sum()
    }
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude; otherwise `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Empty or mismatched inputs have no defined angle between them.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    // Guard the division: a zero vector has no direction.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (mag_a * mag_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated meta-updates pull the zero-initialised parameters toward task
    /// optima spread around (2.0, 3.0), and every accepted update is counted.
    #[test]
    fn test_meta_learning() {
        let mut learner = MetaLearner::new("test_meta", 2)
            .with_meta_lr(0.3)
            .with_inner_lr(0.1);

        for task_idx in 0..5 {
            // Offsets of -0.2..=0.2 keep the optima symmetric around (2.0, 3.0).
            let offset = (task_idx as f32 - 2.0) * 0.1;
            let optimum = [2.0 + offset, 3.0 - offset];
            learner.meta_update(&format!("task_{}", task_idx), &optimum, None);
        }

        let first = learner.meta_params[0];
        let second = learner.meta_params[1];
        assert!((first - 2.0).abs() < 1.5, "param[0] = {}", first);
        assert!((second - 3.0).abs() < 1.5, "param[1] = {}", second);
        assert_eq!(learner.num_tasks(), 5);
    }

    /// A few samples of y = 2x are enough to adapt a single weight so that the
    /// prediction at x = 3 lands near 6.
    #[test]
    fn test_few_shot_learning() {
        let mut meta = MetaLearner::new("few_shot_meta", 1);
        meta.meta_params = vec![1.5];

        let mut learner = FewShotLearner::from_meta(&meta, None)
            .with_adapt_lr(0.5)
            .with_adapt_steps(10);
        for &(x, y) in [(1.0f32, 2.0f32), (2.0, 4.0), (0.5, 1.0)].iter() {
            learner.add_example(vec![x], y);
        }
        learner.adapt();

        let pred = learner.predict(&[3.0]);
        assert!((pred - 6.0).abs() < 1.0, "Expected ~6.0, got {}", pred);
    }

    /// With two registered strategies, a small low-noise task yields a selection.
    #[test]
    fn test_strategy_selection() {
        /// Builds a strategy with an `lr` hyperparameter, a preferred data
        /// size and noise level, and a success rate of 0.9.
        fn strategy(
            name: &str,
            description: &str,
            lr: f32,
            data_size: usize,
            noise_level: f32,
        ) -> LearningStrategy {
            let mut built = LearningStrategy::new(name)
                .with_description(description)
                .with_hyperparam("lr", lr)
                .with_preferred_features(TaskFeatures {
                    data_size: Some(data_size),
                    noise_level: Some(noise_level),
                    ..Default::default()
                });
            built.success_rate = 0.9;
            built
        }

        let mut meta = MetaLearner::new("strategy_meta", 1);
        meta.register_strategy(strategy("few_shot", "For small datasets", 0.1, 10, 0.1));
        meta.register_strategy(strategy("batch_gd", "For large datasets", 0.01, 10000, 0.0));
        assert_eq!(meta.num_strategies(), 2);

        let small_task = TaskFeatures::new().with_data_size(15).with_noise(0.1);
        let choice = meta.select_strategy(&small_task);
        assert!(choice.is_some(), "Should select a strategy for the task");
    }

    /// The builder methods populate the fields they are named after.
    #[test]
    fn test_task_features() {
        let nlp_task = TaskFeatures::new()
            .with_data_size(1000)
            .with_noise(0.05)
            .with_complexity(0.7)
            .classification()
            .with_domain("nlp");
        assert!(nlp_task.is_classification);
        assert_eq!(nlp_task.data_size, Some(1000));
        assert!(nlp_task.noise_level.unwrap() < 0.1);

        let series_task = TaskFeatures::new()
            .with_data_size(500)
            .with_noise(0.2)
            .with_complexity(0.3)
            .regression()
            .with_domain("timeseries");
        assert!(!series_task.is_classification);
        assert_eq!(series_task.domain.as_deref(), Some("timeseries"));
    }
}