// converge_knowledge/agentic/meta.rs

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
16
/// Meta-learner that maintains a shared parameter initialization across
/// tasks, nudging it toward each task's final parameters, plus a registry
/// of learning strategies selectable per task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearner {
    /// Unique identifier, generated in `new`.
    pub id: Uuid,

    /// Human-readable name.
    pub name: String,

    /// Shared initialization parameters; moved toward each task's final
    /// parameters by `meta_update`.
    pub meta_params: Vec<f32>,

    /// Registered strategies, chosen from in `select_strategy`.
    pub strategies: Vec<LearningStrategy>,

    /// Embedding stored per task id; used to find a similar past task when
    /// initializing parameters for a new one.
    pub task_embeddings: HashMap<String, Vec<f32>>,

    /// Outer-loop learning rate applied in `meta_update` (default 0.1).
    pub meta_lr: f32,

    /// Inner-loop learning rate handed to `FewShotLearner::from_meta`
    /// (default 0.01).
    pub inner_lr: f32,

    /// Number of `meta_update` calls performed.
    pub task_count: u64,

    /// Creation timestamp.
    pub created_at: DateTime<Utc>,

    /// Timestamp of the most recent `meta_update`.
    pub updated_at: DateTime<Utc>,
}
53
54impl MetaLearner {
55 pub fn new(name: impl Into<String>, num_params: usize) -> Self {
57 let now = Utc::now();
58 Self {
59 id: Uuid::new_v4(),
60 name: name.into(),
61 meta_params: vec![0.0; num_params],
62 strategies: Vec::new(),
63 task_embeddings: HashMap::new(),
64 meta_lr: 0.1,
65 inner_lr: 0.01,
66 task_count: 0,
67 created_at: now,
68 updated_at: now,
69 }
70 }
71
72 pub fn with_meta_lr(mut self, lr: f32) -> Self {
74 self.meta_lr = lr;
75 self
76 }
77
78 pub fn with_inner_lr(mut self, lr: f32) -> Self {
80 self.inner_lr = lr;
81 self
82 }
83
84 pub fn initialize_for_task(&self, task_embedding: Option<&[f32]>) -> Vec<f32> {
89 let mut params = self.meta_params.clone();
90
91 if let Some(emb) = task_embedding {
93 if let Some((_, similar_params)) = self.find_similar_task(emb) {
94 for i in 0..params.len().min(similar_params.len()) {
96 params[i] = 0.7 * params[i] + 0.3 * similar_params[i];
97 }
98 }
99 }
100
101 params
102 }
103
104 fn find_similar_task(&self, embedding: &[f32]) -> Option<(&str, Vec<f32>)> {
106 let mut best_sim = -1.0f32;
107 let mut best_task: Option<&str> = None;
108
109 for (task_id, task_emb) in &self.task_embeddings {
110 let sim = cosine_similarity(embedding, task_emb);
111 if sim > best_sim {
112 best_sim = sim;
113 best_task = Some(task_id);
114 }
115 }
116
117 if best_sim > 0.5 {
119 best_task.map(|t| (t, self.meta_params.clone()))
120 } else {
121 None
122 }
123 }
124
125 pub fn meta_update(
130 &mut self,
131 task_id: &str,
132 final_params: &[f32],
133 task_embedding: Option<Vec<f32>>,
134 ) {
135 if final_params.len() != self.meta_params.len() {
136 return;
137 }
138
139 for i in 0..self.meta_params.len() {
141 let delta = final_params[i] - self.meta_params[i];
142 self.meta_params[i] += self.meta_lr * delta;
143 }
144
145 if let Some(emb) = task_embedding {
147 self.task_embeddings.insert(task_id.to_string(), emb);
148 }
149
150 self.task_count += 1;
151 self.updated_at = Utc::now();
152 }
153
154 pub fn register_strategy(&mut self, strategy: LearningStrategy) {
156 let exists = self.strategies.iter().any(|s| s.name == strategy.name);
158 if !exists {
159 self.strategies.push(strategy);
160 }
161 }
162
163 pub fn select_strategy(&self, task_features: &TaskFeatures) -> Option<&LearningStrategy> {
165 let mut best_score = 0.0f32;
166 let mut best_strategy: Option<&LearningStrategy> = None;
167
168 for strategy in &self.strategies {
169 let score = strategy.score_for_task(task_features);
170 if score > best_score {
171 best_score = score;
172 best_strategy = Some(strategy);
173 }
174 }
175
176 if best_score > 0.5 {
178 best_strategy
179 } else {
180 None
181 }
182 }
183
184 pub fn num_strategies(&self) -> usize {
186 self.strategies.len()
187 }
188
189 pub fn num_tasks(&self) -> u64 {
191 self.task_count
192 }
193}
194
/// A named learning approach with hyperparameters, a preferred task
/// profile, and a running success-rate estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStrategy {
    /// Unique name; used for de-duplication in `MetaLearner::register_strategy`.
    pub name: String,

    /// Human-readable description.
    pub description: String,

    /// Named hyperparameter values (e.g. "lr").
    pub hyperparams: HashMap<String, f32>,

    /// Task profile this strategy suits best; compared against a concrete
    /// task in `score_for_task`.
    pub preferred_features: TaskFeatures,

    /// Exponential moving average of success outcomes (decay 0.9), updated
    /// by `record_usage`; starts at a neutral 0.5.
    pub success_rate: f32,

    /// Number of times `record_usage` has been called.
    pub usage_count: u64,
}
216
217impl LearningStrategy {
218 pub fn new(name: impl Into<String>) -> Self {
220 Self {
221 name: name.into(),
222 description: String::new(),
223 hyperparams: HashMap::new(),
224 preferred_features: TaskFeatures::default(),
225 success_rate: 0.5,
226 usage_count: 0,
227 }
228 }
229
230 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
232 self.description = desc.into();
233 self
234 }
235
236 pub fn with_hyperparam(mut self, name: impl Into<String>, value: f32) -> Self {
238 self.hyperparams.insert(name.into(), value);
239 self
240 }
241
242 pub fn with_preferred_features(mut self, features: TaskFeatures) -> Self {
244 self.preferred_features = features;
245 self
246 }
247
248 pub fn score_for_task(&self, task: &TaskFeatures) -> f32 {
250 let mut score = 0.0f32;
251 let mut count = 0;
252
253 if let (Some(a), Some(b)) = (self.preferred_features.data_size, task.data_size) {
255 score += 1.0 - (a as f32 - b as f32).abs() / (a.max(b) as f32 + 1.0);
256 count += 1;
257 }
258
259 if let (Some(a), Some(b)) = (self.preferred_features.noise_level, task.noise_level) {
260 score += 1.0 - (a - b).abs();
261 count += 1;
262 }
263
264 if let (Some(a), Some(b)) = (self.preferred_features.complexity, task.complexity) {
265 score += 1.0 - (a - b).abs();
266 count += 1;
267 }
268
269 if self.preferred_features.is_classification == task.is_classification {
270 score += 1.0;
271 count += 1;
272 }
273
274 let feature_score = if count > 0 { score / count as f32 } else { 0.5 };
276
277 feature_score * self.success_rate
278 }
279
280 pub fn record_usage(&mut self, succeeded: bool) {
282 self.usage_count += 1;
283 let outcome = if succeeded { 1.0 } else { 0.0 };
284 self.success_rate = 0.9 * self.success_rate + 0.1 * outcome;
286 }
287}
288
/// Coarse description of a learning task, used to match tasks to strategies.
/// Unspecified (`None`) features are simply skipped during comparison.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskFeatures {
    /// Number of training examples, if known.
    pub data_size: Option<usize>,

    /// Noise level in [0, 1], if known (clamped by `with_noise`).
    pub noise_level: Option<f32>,

    /// Complexity estimate in [0, 1], if known (clamped by `with_complexity`).
    pub complexity: Option<f32>,

    /// True for classification tasks, false for regression (default false).
    pub is_classification: bool,

    /// Input dimensionality, if known.
    pub input_dim: Option<usize>,

    /// Output dimensionality, if known.
    pub output_dim: Option<usize>,

    /// Free-form domain label (e.g. "nlp"), if known.
    pub domain: Option<String>,
}
313
314impl TaskFeatures {
315 pub fn new() -> Self {
317 Self::default()
318 }
319
320 pub fn with_data_size(mut self, size: usize) -> Self {
322 self.data_size = Some(size);
323 self
324 }
325
326 pub fn with_noise(mut self, noise: f32) -> Self {
328 self.noise_level = Some(noise.clamp(0.0, 1.0));
329 self
330 }
331
332 pub fn with_complexity(mut self, complexity: f32) -> Self {
334 self.complexity = Some(complexity.clamp(0.0, 1.0));
335 self
336 }
337
338 pub fn classification(mut self) -> Self {
340 self.is_classification = true;
341 self
342 }
343
344 pub fn regression(mut self) -> Self {
346 self.is_classification = false;
347 self
348 }
349
350 pub fn with_domain(mut self, domain: impl Into<String>) -> Self {
352 self.domain = Some(domain.into());
353 self
354 }
355}
356
/// Linear model adapted to a new task from a meta-learned initialization
/// using a handful of support examples (inner-loop gradient descent).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotLearner {
    /// Initialization taken from the meta-learner; never mutated by `adapt`.
    base_params: Vec<f32>,

    /// Working weights; reset to `base_params` and re-fit on each `adapt` call.
    adapted_params: Vec<f32>,

    /// Support examples as (features, target) pairs.
    support_set: Vec<(Vec<f32>, f32)>,

    /// Inner-loop learning rate (seeded from `MetaLearner::inner_lr`).
    adapt_lr: f32,

    /// Number of passes over the support set per `adapt` call (default 5).
    adapt_steps: usize,
}
378
379impl FewShotLearner {
380 pub fn from_meta(meta: &MetaLearner, task_embedding: Option<&[f32]>) -> Self {
382 let params = meta.initialize_for_task(task_embedding);
383 Self {
384 base_params: params.clone(),
385 adapted_params: params,
386 support_set: Vec::new(),
387 adapt_lr: meta.inner_lr,
388 adapt_steps: 5,
389 }
390 }
391
392 pub fn with_adapt_lr(mut self, lr: f32) -> Self {
394 self.adapt_lr = lr;
395 self
396 }
397
398 pub fn with_adapt_steps(mut self, steps: usize) -> Self {
400 self.adapt_steps = steps;
401 self
402 }
403
404 pub fn add_example(&mut self, features: Vec<f32>, target: f32) {
406 self.support_set.push((features, target));
407 }
408
409 pub fn adapt(&mut self) {
414 self.adapted_params = self.base_params.clone();
415
416 for _ in 0..self.adapt_steps {
417 for (features, target) in &self.support_set {
418 if features.len() != self.adapted_params.len() {
419 continue;
420 }
421
422 let pred: f32 = features
424 .iter()
425 .zip(self.adapted_params.iter())
426 .map(|(f, p)| f * p)
427 .sum();
428
429 let error = pred - target;
431 for i in 0..self.adapted_params.len() {
432 let grad = 2.0 * error * features[i];
433 self.adapted_params[i] -= self.adapt_lr * grad;
434 }
435 }
436 }
437 }
438
439 pub fn predict(&self, features: &[f32]) -> f32 {
441 if features.len() != self.adapted_params.len() {
442 return 0.0;
443 }
444
445 features
446 .iter()
447 .zip(self.adapted_params.iter())
448 .map(|(f, p)| f * p)
449 .sum()
450 }
451
452 pub fn get_adapted_params(&self) -> &[f32] {
454 &self.adapted_params
455 }
456
457 pub fn support_size(&self) -> usize {
459 self.support_set.len()
460 }
461}
462
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths, empty input, or when either vector
/// has zero norm (instead of NaN from a 0/0 division).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass, in the
    // same element order as separate sums would use.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Meta updates should pull `meta_params` toward the cluster of task
    /// solutions around (2.0, 3.0).
    #[test]
    fn test_meta_learning() {
        let mut meta = MetaLearner::new("test_meta", 2)
            .with_meta_lr(0.3)
            .with_inner_lr(0.1);

        for task_idx in 0..5 {
            let task_id = format!("task_{}", task_idx);
            // Each task's solution varies slightly around (2.0, 3.0).
            let noise = (task_idx as f32 - 2.0) * 0.1;
            let final_params = vec![2.0 + noise, 3.0 - noise];

            meta.meta_update(&task_id, &final_params, None);
        }

        // Starting from zeros with meta_lr = 0.3, five updates should land
        // within 1.5 of each target component.
        assert!(
            (meta.meta_params[0] - 2.0).abs() < 1.5,
            "param[0] = {}",
            meta.meta_params[0]
        );
        assert!(
            (meta.meta_params[1] - 3.0).abs() < 1.5,
            "param[1] = {}",
            meta.meta_params[1]
        );
        assert_eq!(meta.num_tasks(), 5);
    }

    /// Adapting a 1-D linear model on y = 2x examples should predict ~2x.
    #[test]
    fn test_few_shot_learning() {
        let mut meta = MetaLearner::new("few_shot_meta", 1);
        // Start the inner loop near (but not at) the true weight 2.0.
        meta.meta_params = vec![1.5];

        let mut few_shot = FewShotLearner::from_meta(&meta, None)
            .with_adapt_lr(0.5)
            .with_adapt_steps(10);

        // Support set drawn from y = 2x.
        few_shot.add_example(vec![1.0], 2.0);
        few_shot.add_example(vec![2.0], 4.0);
        few_shot.add_example(vec![0.5], 1.0);

        few_shot.adapt();

        let pred = few_shot.predict(&[3.0]);

        assert!((pred - 6.0).abs() < 1.0, "Expected ~6.0, got {}", pred);
    }

    /// A small-data task should score well against the small-data strategy's
    /// preferred profile, so some strategy gets selected.
    #[test]
    fn test_strategy_selection() {
        let mut meta = MetaLearner::new("strategy_meta", 1);

        let mut small_data_strategy = LearningStrategy::new("few_shot")
            .with_description("For small datasets")
            .with_hyperparam("lr", 0.1)
            .with_preferred_features(TaskFeatures {
                data_size: Some(10),
                noise_level: Some(0.1),
                ..Default::default()
            });
        // Boost success rate so the combined score clears the 0.5 threshold.
        small_data_strategy.success_rate = 0.9;

        let mut large_data_strategy = LearningStrategy::new("batch_gd")
            .with_description("For large datasets")
            .with_hyperparam("lr", 0.01)
            .with_preferred_features(TaskFeatures {
                data_size: Some(10000),
                noise_level: Some(0.0),
                ..Default::default()
            });
        large_data_strategy.success_rate = 0.9;

        meta.register_strategy(small_data_strategy);
        meta.register_strategy(large_data_strategy);

        assert_eq!(meta.num_strategies(), 2);

        let small_task = TaskFeatures::new().with_data_size(15).with_noise(0.1);
        let selected = meta.select_strategy(&small_task);
        assert!(selected.is_some(), "Should select a strategy for the task");
    }

    /// Builder methods should set flags, store fields, and keep noise in range.
    #[test]
    fn test_task_features() {
        let classification_task = TaskFeatures::new()
            .with_data_size(1000)
            .with_noise(0.05)
            .with_complexity(0.7)
            .classification()
            .with_domain("nlp");

        assert!(classification_task.is_classification);
        assert_eq!(classification_task.data_size, Some(1000));
        assert!(classification_task.noise_level.unwrap() < 0.1);

        let regression_task = TaskFeatures::new()
            .with_data_size(500)
            .with_noise(0.2)
            .with_complexity(0.3)
            .regression()
            .with_domain("timeseries");

        assert!(!regression_task.is_classification);
        assert_eq!(regression_task.domain.as_deref(), Some("timeseries"));
    }
}