// lmm_agent/cognition/learning/meta.rs
// Copyright 2026 Mahmoud Harmouch.
//
// Licensed under the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! # `MetaAdapter` - prototype-based meta-learning.
//!
//! Enables few-shot and zero-shot adaptation by matching new tasks to the
//! most similar previously seen task prototypes and applying their learned
//! Q-value offsets as a warm-start - without any gradient descent or matrices.
//!
//! ## Algorithm
//!
//! 1. After each `ThinkLoop` run, `record_episode` stores a `TaskPrototype`
//!    consisting of the goal's token set (centroid) and the final Q-value
//!    offset for each action (best Q minus baseline 0.0).
//!
//! 2. On a new task, `adapt` computes the Jaccard similarity between the new
//!    goal and every stored prototype, takes the top-K matches, and produces
//!    a weighted average of their offsets as the warm-start Q-adjustment.
//!
//! 3. These offsets are returned as a `HashMap<ActionKey, f64>` that the
//!    `LearningEngine` can apply as additive priors to the new task's Q-table
//!    rows before any TD updates.
//!
//! ## Complexity
//!
//! - `record_episode`: O(P) - P = number of stored prototypes.
//! - `adapt`:          O(P · T) - T = average goal token count.
//!
//! Both are entirely CPU-bound hash-set operations.
//!
//! ## Examples
//!
//! ```rust
//! use std::collections::HashMap;
//! use lmm_agent::cognition::learning::meta::MetaAdapter;
//! use lmm_agent::cognition::learning::q_table::ActionKey;
//!
//! let mut adapter = MetaAdapter::new(3);
//!
//! let mut offsets = HashMap::new();
//! offsets.insert(ActionKey::Narrow, 0.5);
//! adapter.record_episode("rust ownership borrow", offsets, 0.8);
//!
//! let adapt = adapter.adapt("rust memory borrow checker");
//! assert!(!adapt.is_empty());
//! ```

use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

use crate::cognition::learning::q_table::ActionKey;

/// A stored snapshot of a completed task episode used as a meta-learning prototype.
///
/// Prototypes are created by `MetaAdapter::record_episode` and matched against
/// new goals by Jaccard similarity in `MetaAdapter::adapt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPrototype {
    /// Lowercase token set of the task goal (see `tokenise`).
    pub tokens: HashSet<String>,

    /// Per-action Q-value offsets learned during this episode.
    pub offsets: HashMap<ActionKey, f64>,

    /// Average reward achieved during the episode; used as a weight in `adapt`.
    pub avg_reward: f64,

    /// Number of times this prototype has been matched and applied.
    pub match_count: usize,
}

73impl TaskPrototype {
74    /// Computes Jaccard similarity between this prototype and a query token set.
75    ///
76    /// Returns a value in [0, 1] where 1.0 = exact token-set match.
77    pub fn similarity(&self, other: &HashSet<String>) -> f64 {
78        let intersection = self.tokens.intersection(other).count();
79        let union = self.tokens.len() + other.len() - intersection;
80        if union == 0 {
81            1.0
82        } else {
83            intersection as f64 / union as f64
84        }
85    }
86}
87
/// Prototype store for task-level meta-adaptation.
///
/// Accumulates `TaskPrototype`s via `record_episode` and blends the most
/// similar ones into warm-start Q-offsets via `adapt`.
///
/// # Examples
///
/// ```rust
/// use std::collections::HashMap;
/// use lmm_agent::cognition::learning::meta::MetaAdapter;
/// use lmm_agent::cognition::learning::q_table::ActionKey;
///
/// let mut a = MetaAdapter::new(5);
/// let offsets = HashMap::from([(ActionKey::Expand, 0.7)]);
/// a.record_episode("rust async await", offsets, 1.0);
/// let r = a.adapt("async rust futures");
/// assert!(r.contains_key(&ActionKey::Expand));
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MetaAdapter {
    /// Stored episode prototypes (near-duplicates are blended, not appended).
    prototypes: Vec<TaskPrototype>,

    /// Maximum number of prototypes considered per `adapt` call (≥ 1).
    pub top_k: usize,
}

112impl MetaAdapter {
113    /// Constructs a new `MetaAdapter` with the given top-k lookup limit.
114    pub fn new(top_k: usize) -> Self {
115        Self {
116            prototypes: Vec::new(),
117            top_k: top_k.max(1),
118        }
119    }
120
121    /// Returns the number of stored prototypes.
122    pub fn len(&self) -> usize {
123        self.prototypes.len()
124    }
125
126    /// Returns `true` when no prototypes have been stored.
127    pub fn is_empty(&self) -> bool {
128        self.prototypes.is_empty()
129    }
130
131    /// Stores a new episode as a task prototype.
132    ///
133    /// `goal` is tokenised into a `HashSet<String>`. If a nearly identical
134    /// prototype already exists (Jaccard ≥ 0.9), its offsets are blended
135    /// rather than creating a duplicate.
136    ///
137    /// # Arguments
138    ///
139    /// * `goal`       - Natural-language task goal.
140    /// * `offsets`    - Per-action Q-value offsets from the completed episode.
141    /// * `avg_reward` - Mean reward across the episode steps.
142    pub fn record_episode(
143        &mut self,
144        goal: &str,
145        offsets: HashMap<ActionKey, f64>,
146        avg_reward: f64,
147    ) {
148        let tokens = tokenise(goal);
149
150        if let Some(existing) = self
151            .prototypes
152            .iter_mut()
153            .find(|p| p.similarity(&tokens) >= 0.9)
154        {
155            for (action, val) in &offsets {
156                let e = existing.offsets.entry(*action).or_insert(0.0);
157                *e = (*e + val) / 2.0;
158            }
159            existing.avg_reward = (existing.avg_reward + avg_reward) / 2.0;
160            return;
161        }
162
163        self.prototypes.push(TaskPrototype {
164            tokens,
165            offsets,
166            avg_reward,
167            match_count: 0,
168        });
169    }
170
171    /// Returns weighted Q-offset priors for a new `goal`.
172    ///
173    /// Finds the top-K most similar prototypes, weights their offsets by
174    /// `similarity × avg_reward`, and returns the normalised blend.
175    ///
176    /// Returns an empty map when no prototypes exist.
177    ///
178    /// # Examples
179    ///
180    /// ```rust
181    /// use std::collections::HashMap;
182    /// use lmm_agent::cognition::learning::meta::MetaAdapter;
183    /// use lmm_agent::cognition::learning::q_table::ActionKey;
184    ///
185    /// let mut a = MetaAdapter::new(3);
186    /// a.record_episode("machine learning", HashMap::from([(ActionKey::Expand, 0.6)]), 0.7);
187    /// let out = a.adapt("deep learning models");
188    /// assert!(!out.is_empty());
189    /// ```
190    pub fn adapt(&mut self, goal: &str) -> HashMap<ActionKey, f64> {
191        if self.prototypes.is_empty() {
192            return HashMap::new();
193        }
194
195        let tokens = tokenise(goal);
196
197        let mut scored: Vec<(usize, f64)> = self
198            .prototypes
199            .iter()
200            .enumerate()
201            .map(|(i, p)| (i, p.similarity(&tokens) * p.avg_reward))
202            .collect();
203
204        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
205        scored.truncate(self.top_k);
206
207        let total_weight: f64 = scored.iter().map(|(_, w)| w).sum();
208        if total_weight <= 0.0 {
209            return HashMap::new();
210        }
211
212        let mut blended: HashMap<ActionKey, f64> = HashMap::new();
213        for (idx, weight) in &scored {
214            self.prototypes[*idx].match_count += 1;
215            for (&action, &val) in &self.prototypes[*idx].offsets {
216                *blended.entry(action).or_insert(0.0) += (weight / total_weight) * val;
217            }
218        }
219
220        blended
221    }
222
223    /// Returns a slice over all stored prototypes.
224    pub fn prototypes(&self) -> &[TaskPrototype] {
225        &self.prototypes
226    }
227}
228
/// Splits `text` into a set of lowercase alphanumeric tokens.
///
/// Each whitespace-separated word is stripped of non-alphanumeric characters,
/// lowercased with full Unicode case mapping (so "ÜBER" and "über" produce the
/// same token — `to_ascii_lowercase` would not), and kept only when it has at
/// least three *characters*. Counting characters instead of bytes ensures the
/// minimum-length filter behaves the same for multi-byte tokens.
fn tokenise(text: &str) -> HashSet<String> {
    text.split_whitespace()
        .map(|word| {
            word.chars()
                .filter(|c| c.is_alphanumeric())
                .collect::<String>()
                .to_lowercase()
        })
        .filter(|token| token.chars().count() >= 3)
        .collect()
}

// Copyright 2026 Mahmoud Harmouch.
//
// Licensed under the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.