// converge_knowledge/agentic/meta.rs

//! Meta-Learning Primitives
//!
//! Implements meta-learning ("learning to learn") mechanisms that allow agents to:
//!
//! 1. Learn task priors from past experience
//! 2. Adapt quickly to new tasks with few examples
//! 3. Maintain a repertoire of learning strategies
//! 4. Select suitable learning approaches for new tasks
//!
//! Based on MAML, Reptile, and related meta-learning research.
11
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use uuid::Uuid;
16
17/// Meta-learner that learns how to learn.
18///
19/// Maintains initialization parameters that enable fast
20/// adaptation to new tasks (inspired by MAML/Reptile).
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct MetaLearner {
23    /// Unique identifier.
24    pub id: Uuid,
25
26    /// Name/description.
27    pub name: String,
28
29    /// Meta-parameters (good initialization for fast adaptation).
30    pub meta_params: Vec<f32>,
31
32    /// Learning strategies discovered.
33    pub strategies: Vec<LearningStrategy>,
34
35    /// Task embeddings for similarity.
36    pub task_embeddings: HashMap<String, Vec<f32>>,
37
38    /// Meta-learning rate (outer loop).
39    pub meta_lr: f32,
40
41    /// Inner loop learning rate.
42    pub inner_lr: f32,
43
44    /// Number of tasks learned.
45    pub task_count: u64,
46
47    /// When created.
48    pub created_at: DateTime<Utc>,
49
50    /// When last updated.
51    pub updated_at: DateTime<Utc>,
52}
53
54impl MetaLearner {
55    /// Create a new meta-learner.
56    pub fn new(name: impl Into<String>, num_params: usize) -> Self {
57        let now = Utc::now();
58        Self {
59            id: Uuid::new_v4(),
60            name: name.into(),
61            meta_params: vec![0.0; num_params],
62            strategies: Vec::new(),
63            task_embeddings: HashMap::new(),
64            meta_lr: 0.1,
65            inner_lr: 0.01,
66            task_count: 0,
67            created_at: now,
68            updated_at: now,
69        }
70    }
71
72    /// Set meta learning rate.
73    pub fn with_meta_lr(mut self, lr: f32) -> Self {
74        self.meta_lr = lr;
75        self
76    }
77
78    /// Set inner learning rate.
79    pub fn with_inner_lr(mut self, lr: f32) -> Self {
80        self.inner_lr = lr;
81        self
82    }
83
84    /// Get initialization parameters for a new task.
85    ///
86    /// Uses meta-params as starting point, potentially adjusted
87    /// based on task similarity.
88    pub fn initialize_for_task(&self, task_embedding: Option<&[f32]>) -> Vec<f32> {
89        let mut params = self.meta_params.clone();
90
91        // If we have a task embedding, adjust based on similar tasks
92        if let Some(emb) = task_embedding {
93            if let Some((_, similar_params)) = self.find_similar_task(emb) {
94                // Blend meta-params with similar task's successful params
95                for i in 0..params.len().min(similar_params.len()) {
96                    params[i] = 0.7 * params[i] + 0.3 * similar_params[i];
97                }
98            }
99        }
100
101        params
102    }
103
104    /// Find most similar past task.
105    fn find_similar_task(&self, embedding: &[f32]) -> Option<(&str, Vec<f32>)> {
106        let mut best_sim = -1.0f32;
107        let mut best_task: Option<&str> = None;
108
109        for (task_id, task_emb) in &self.task_embeddings {
110            let sim = cosine_similarity(embedding, task_emb);
111            if sim > best_sim {
112                best_sim = sim;
113                best_task = Some(task_id);
114            }
115        }
116
117        // Only return if similarity is meaningful
118        if best_sim > 0.5 {
119            best_task.map(|t| (t, self.meta_params.clone()))
120        } else {
121            None
122        }
123    }
124
125    /// Meta-update after completing a task (Reptile-style).
126    ///
127    /// Updates meta-params to be a better initialization
128    /// for future tasks.
129    pub fn meta_update(
130        &mut self,
131        task_id: &str,
132        final_params: &[f32],
133        task_embedding: Option<Vec<f32>>,
134    ) {
135        if final_params.len() != self.meta_params.len() {
136            return;
137        }
138
139        // Reptile update: move meta-params towards task solution
140        for i in 0..self.meta_params.len() {
141            let delta = final_params[i] - self.meta_params[i];
142            self.meta_params[i] += self.meta_lr * delta;
143        }
144
145        // Store task embedding for similarity lookup
146        if let Some(emb) = task_embedding {
147            self.task_embeddings.insert(task_id.to_string(), emb);
148        }
149
150        self.task_count += 1;
151        self.updated_at = Utc::now();
152    }
153
154    /// Register a learning strategy that worked well.
155    pub fn register_strategy(&mut self, strategy: LearningStrategy) {
156        // Check if similar strategy exists
157        let exists = self.strategies.iter().any(|s| s.name == strategy.name);
158        if !exists {
159            self.strategies.push(strategy);
160        }
161    }
162
163    /// Select best strategy for a new task.
164    pub fn select_strategy(&self, task_features: &TaskFeatures) -> Option<&LearningStrategy> {
165        let mut best_score = 0.0f32;
166        let mut best_strategy: Option<&LearningStrategy> = None;
167
168        for strategy in &self.strategies {
169            let score = strategy.score_for_task(task_features);
170            if score > best_score {
171                best_score = score;
172                best_strategy = Some(strategy);
173            }
174        }
175
176        // Only recommend if confident
177        if best_score > 0.5 {
178            best_strategy
179        } else {
180            None
181        }
182    }
183
184    /// Get number of strategies.
185    pub fn num_strategies(&self) -> usize {
186        self.strategies.len()
187    }
188
189    /// Get number of tasks learned.
190    pub fn num_tasks(&self) -> u64 {
191        self.task_count
192    }
193}
194
195/// A learning strategy discovered through meta-learning.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct LearningStrategy {
198    /// Strategy name.
199    pub name: String,
200
201    /// Description of when to use this strategy.
202    pub description: String,
203
204    /// Hyperparameters for this strategy.
205    pub hyperparams: HashMap<String, f32>,
206
207    /// Which task features favor this strategy.
208    pub preferred_features: TaskFeatures,
209
210    /// Success rate when applied.
211    pub success_rate: f32,
212
213    /// Number of times used.
214    pub usage_count: u64,
215}
216
217impl LearningStrategy {
218    /// Create a new strategy.
219    pub fn new(name: impl Into<String>) -> Self {
220        Self {
221            name: name.into(),
222            description: String::new(),
223            hyperparams: HashMap::new(),
224            preferred_features: TaskFeatures::default(),
225            success_rate: 0.5,
226            usage_count: 0,
227        }
228    }
229
230    /// Set description.
231    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
232        self.description = desc.into();
233        self
234    }
235
236    /// Set hyperparameter.
237    pub fn with_hyperparam(mut self, name: impl Into<String>, value: f32) -> Self {
238        self.hyperparams.insert(name.into(), value);
239        self
240    }
241
242    /// Set preferred features.
243    pub fn with_preferred_features(mut self, features: TaskFeatures) -> Self {
244        self.preferred_features = features;
245        self
246    }
247
248    /// Score how well this strategy fits a task.
249    pub fn score_for_task(&self, task: &TaskFeatures) -> f32 {
250        let mut score = 0.0f32;
251        let mut count = 0;
252
253        // Compare features
254        if let (Some(a), Some(b)) = (self.preferred_features.data_size, task.data_size) {
255            score += 1.0 - (a as f32 - b as f32).abs() / (a.max(b) as f32 + 1.0);
256            count += 1;
257        }
258
259        if let (Some(a), Some(b)) = (self.preferred_features.noise_level, task.noise_level) {
260            score += 1.0 - (a - b).abs();
261            count += 1;
262        }
263
264        if let (Some(a), Some(b)) = (self.preferred_features.complexity, task.complexity) {
265            score += 1.0 - (a - b).abs();
266            count += 1;
267        }
268
269        if self.preferred_features.is_classification == task.is_classification {
270            score += 1.0;
271            count += 1;
272        }
273
274        // Weight by success rate
275        let feature_score = if count > 0 { score / count as f32 } else { 0.5 };
276
277        feature_score * self.success_rate
278    }
279
280    /// Record usage outcome.
281    pub fn record_usage(&mut self, succeeded: bool) {
282        self.usage_count += 1;
283        let outcome = if succeeded { 1.0 } else { 0.0 };
284        // Exponential moving average
285        self.success_rate = 0.9 * self.success_rate + 0.1 * outcome;
286    }
287}
288
289/// Features describing a learning task.
290#[derive(Debug, Clone, Default, Serialize, Deserialize)]
291pub struct TaskFeatures {
292    /// Number of training examples.
293    pub data_size: Option<usize>,
294
295    /// Estimated noise level (0.0 to 1.0).
296    pub noise_level: Option<f32>,
297
298    /// Task complexity (0.0 to 1.0).
299    pub complexity: Option<f32>,
300
301    /// Whether this is classification (vs regression).
302    pub is_classification: bool,
303
304    /// Number of input features.
305    pub input_dim: Option<usize>,
306
307    /// Number of output classes/values.
308    pub output_dim: Option<usize>,
309
310    /// Domain identifier.
311    pub domain: Option<String>,
312}
313
314impl TaskFeatures {
315    /// Create new task features.
316    pub fn new() -> Self {
317        Self::default()
318    }
319
320    /// Set data size.
321    pub fn with_data_size(mut self, size: usize) -> Self {
322        self.data_size = Some(size);
323        self
324    }
325
326    /// Set noise level.
327    pub fn with_noise(mut self, noise: f32) -> Self {
328        self.noise_level = Some(noise.clamp(0.0, 1.0));
329        self
330    }
331
332    /// Set complexity.
333    pub fn with_complexity(mut self, complexity: f32) -> Self {
334        self.complexity = Some(complexity.clamp(0.0, 1.0));
335        self
336    }
337
338    /// Set classification flag.
339    pub fn classification(mut self) -> Self {
340        self.is_classification = true;
341        self
342    }
343
344    /// Set regression flag.
345    pub fn regression(mut self) -> Self {
346        self.is_classification = false;
347        self
348    }
349
350    /// Set domain.
351    pub fn with_domain(mut self, domain: impl Into<String>) -> Self {
352        self.domain = Some(domain.into());
353        self
354    }
355}
356
357/// Few-shot learner for quick adaptation.
358///
359/// Given a small number of examples, adapts quickly
360/// using meta-learned priors.
361#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct FewShotLearner {
363    /// Base parameters from meta-learner.
364    base_params: Vec<f32>,
365
366    /// Adapted parameters.
367    adapted_params: Vec<f32>,
368
369    /// Support set (few examples).
370    support_set: Vec<(Vec<f32>, f32)>,
371
372    /// Learning rate for adaptation.
373    adapt_lr: f32,
374
375    /// Number of adaptation steps.
376    adapt_steps: usize,
377}
378
379impl FewShotLearner {
380    /// Create from meta-learner initialization.
381    pub fn from_meta(meta: &MetaLearner, task_embedding: Option<&[f32]>) -> Self {
382        let params = meta.initialize_for_task(task_embedding);
383        Self {
384            base_params: params.clone(),
385            adapted_params: params,
386            support_set: Vec::new(),
387            adapt_lr: meta.inner_lr,
388            adapt_steps: 5,
389        }
390    }
391
392    /// Set adaptation learning rate.
393    pub fn with_adapt_lr(mut self, lr: f32) -> Self {
394        self.adapt_lr = lr;
395        self
396    }
397
398    /// Set number of adaptation steps.
399    pub fn with_adapt_steps(mut self, steps: usize) -> Self {
400        self.adapt_steps = steps;
401        self
402    }
403
404    /// Add an example to the support set.
405    pub fn add_example(&mut self, features: Vec<f32>, target: f32) {
406        self.support_set.push((features, target));
407    }
408
409    /// Adapt to the support set.
410    ///
411    /// Performs gradient descent on support set to adapt
412    /// parameters from meta-learned initialization.
413    pub fn adapt(&mut self) {
414        self.adapted_params = self.base_params.clone();
415
416        for _ in 0..self.adapt_steps {
417            for (features, target) in &self.support_set {
418                if features.len() != self.adapted_params.len() {
419                    continue;
420                }
421
422                // Forward pass
423                let pred: f32 = features
424                    .iter()
425                    .zip(self.adapted_params.iter())
426                    .map(|(f, p)| f * p)
427                    .sum();
428
429                // Backward pass
430                let error = pred - target;
431                for i in 0..self.adapted_params.len() {
432                    let grad = 2.0 * error * features[i];
433                    self.adapted_params[i] -= self.adapt_lr * grad;
434                }
435            }
436        }
437    }
438
439    /// Predict for new input.
440    pub fn predict(&self, features: &[f32]) -> f32 {
441        if features.len() != self.adapted_params.len() {
442            return 0.0;
443        }
444
445        features
446            .iter()
447            .zip(self.adapted_params.iter())
448            .map(|(f, p)| f * p)
449            .sum()
450    }
451
452    /// Get final adapted parameters.
453    pub fn get_adapted_params(&self) -> &[f32] {
454        &self.adapted_params
455    }
456
457    /// Number of support examples.
458    pub fn support_size(&self) -> usize {
459        self.support_set.len()
460    }
461}
462
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the inputs are empty, have different lengths, or when
/// either vector has zero magnitude (where the similarity is undefined).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Test: Meta-learning accumulates knowledge across tasks.
    ///
    /// What happens:
    /// 1. Train on multiple similar tasks
    /// 2. Meta-params move towards common solution
    /// 3. New tasks can adapt faster from meta-params
    #[test]
    fn test_meta_learning() {
        let mut meta = MetaLearner::new("test_meta", 2)
            .with_meta_lr(0.3)
            .with_inner_lr(0.1);

        // Train on multiple tasks where y = 2*x1 + 3*x2 + noise
        for task_idx in 0..5 {
            // Simulate task-specific adaptation
            let task_id = format!("task_{}", task_idx);
            let noise = (task_idx as f32 - 2.0) * 0.1; // Small variation per task

            // Final params after task training (simulated)
            let final_params = vec![2.0 + noise, 3.0 - noise];

            meta.meta_update(&task_id, &final_params, None);
        }

        // Meta-params should have moved towards [2, 3]
        // After 5 updates with meta_lr=0.3, they should be in a reasonable range
        assert!(
            (meta.meta_params[0] - 2.0).abs() < 1.5,
            "param[0] = {}",
            meta.meta_params[0]
        );
        assert!(
            (meta.meta_params[1] - 3.0).abs() < 1.5,
            "param[1] = {}",
            meta.meta_params[1]
        );
        assert_eq!(meta.num_tasks(), 5);
    }

    /// Test: Few-shot learning with meta initialization.
    ///
    /// What happens:
    /// 1. Create meta-learner with good initialization
    /// 2. Create few-shot learner from meta
    /// 3. Adapt quickly to new task with few examples
    #[test]
    fn test_few_shot_learning() {
        // Set up meta-learner with good prior
        let mut meta = MetaLearner::new("few_shot_meta", 1);
        meta.meta_params = vec![1.5]; // Good starting point

        // Create few-shot learner
        let mut few_shot = FewShotLearner::from_meta(&meta, None)
            .with_adapt_lr(0.5)
            .with_adapt_steps(10);

        // Add few examples for y = 2*x
        few_shot.add_example(vec![1.0], 2.0);
        few_shot.add_example(vec![2.0], 4.0);
        few_shot.add_example(vec![0.5], 1.0);

        // Adapt
        few_shot.adapt();

        // Predict
        let pred = few_shot.predict(&[3.0]);

        // Should be close to 6.0 (3 * 2)
        assert!((pred - 6.0).abs() < 1.0, "Expected ~6.0, got {}", pred);
    }

    /// Test: Learning strategy selection.
    ///
    /// What happens:
    /// 1. Register a small-data and a large-data strategy
    /// 2. A small, slightly noisy task selects the small-data strategy
    /// 3. A large, clean task selects the large-data strategy
    #[test]
    fn test_strategy_selection() {
        let mut meta = MetaLearner::new("strategy_meta", 1);

        // Register strategies with high success rates to pass the 0.5 threshold
        let mut small_data_strategy = LearningStrategy::new("few_shot")
            .with_description("For small datasets")
            .with_hyperparam("lr", 0.1)
            .with_preferred_features(TaskFeatures {
                data_size: Some(10),
                noise_level: Some(0.1),
                ..Default::default()
            });
        small_data_strategy.success_rate = 0.9; // High success rate

        let mut large_data_strategy = LearningStrategy::new("batch_gd")
            .with_description("For large datasets")
            .with_hyperparam("lr", 0.01)
            .with_preferred_features(TaskFeatures {
                data_size: Some(10000),
                noise_level: Some(0.0),
                ..Default::default()
            });
        large_data_strategy.success_rate = 0.9;

        meta.register_strategy(small_data_strategy);
        meta.register_strategy(large_data_strategy);

        assert_eq!(meta.num_strategies(), 2);

        // Small data task: few_shot scores ~0.81, batch_gd ~0.57.
        let small_task = TaskFeatures::new().with_data_size(15).with_noise(0.1);
        let selected = meta.select_strategy(&small_task);
        assert_eq!(selected.map(|s| s.name.as_str()), Some("few_shot"));

        // Large clean task: batch_gd scores 0.9, few_shot ~0.57.
        let large_task = TaskFeatures::new().with_data_size(10_000).with_noise(0.0);
        let selected = meta.select_strategy(&large_task);
        assert_eq!(selected.map(|s| s.name.as_str()), Some("batch_gd"));
    }

    /// Test: Registering a strategy whose name already exists is a no-op.
    #[test]
    fn test_register_strategy_ignores_duplicate_names() {
        let mut meta = MetaLearner::new("dedup_meta", 1);
        meta.register_strategy(LearningStrategy::new("sgd"));
        meta.register_strategy(LearningStrategy::new("sgd").with_description("duplicate"));

        assert_eq!(meta.num_strategies(), 1);
        assert_eq!(meta.strategies[0].description, "");
    }

    /// Test: Success rate follows an exponential moving average of outcomes.
    #[test]
    fn test_record_usage_updates_success_rate() {
        let mut strategy = LearningStrategy::new("ema");

        strategy.record_usage(true);
        assert_eq!(strategy.usage_count, 1);
        assert!((strategy.success_rate - 0.55).abs() < 1e-6);

        strategy.record_usage(false);
        assert_eq!(strategy.usage_count, 2);
        assert!((strategy.success_rate - 0.495).abs() < 1e-6);
    }

    /// Test: Cosine similarity handles degenerate inputs.
    #[test]
    fn test_cosine_similarity() {
        assert!((cosine_similarity(&[1.0, 2.0], &[2.0, 4.0]) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&[1.0, 2.0], &[-1.0, -2.0]) + 1.0).abs() < 1e-6);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }

    /// Test: Without a task embedding, initialization is exactly the meta-params.
    #[test]
    fn test_initialize_without_embedding_uses_meta_params() {
        let mut meta = MetaLearner::new("init_meta", 2);
        meta.meta_params = vec![0.5, -1.5];

        assert_eq!(meta.initialize_for_task(None), vec![0.5, -1.5]);
    }

    /// Test: Task features describe learning problems.
    ///
    /// What happens:
    /// 1. Create task features for different problems
    /// 2. Features help select appropriate strategies
    #[test]
    fn test_task_features() {
        let classification_task = TaskFeatures::new()
            .with_data_size(1000)
            .with_noise(0.05)
            .with_complexity(0.7)
            .classification()
            .with_domain("nlp");

        assert!(classification_task.is_classification);
        assert_eq!(classification_task.data_size, Some(1000));
        assert!(classification_task.noise_level.unwrap() < 0.1);

        let regression_task = TaskFeatures::new()
            .with_data_size(500)
            .with_noise(0.2)
            .with_complexity(0.3)
            .regression()
            .with_domain("timeseries");

        assert!(!regression_task.is_classification);
        assert_eq!(regression_task.domain.as_deref(), Some("timeseries"));
    }
}