use std::collections::VecDeque;
/// One recorded decision, as stored by `ReasoningBank`.
#[derive(Debug, Clone)]
pub struct Trajectory {
/// Query vector the decision was recorded for.
pub query: Vec<f32>,
/// Identifiers recorded together with the query (opaque to this module).
pub retrieved_ids: Vec<u64>,
/// Reward value recorded for the decision.
pub reward: f32,
/// Phi value supplied when the decision was recorded.
pub phi_at_decision: f64,
/// Insertion sequence number assigned by `ReasoningBank::record` from a
/// counter that starts at 0 and grows by one per call (not wall-clock time).
pub timestamp: u64,
}
/// Low-rank adapter made of two dense, row-major factor matrices.
///
/// The adapter represents the update `scale * A * B`, where `A` is
/// `m x rank` and `B` is `rank x n`. Every field is public, so the methods
/// never assume the buffers really have the advertised sizes: entries that
/// are missing from `a` or `b` are treated as zero instead of panicking.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Inner (bottleneck) dimension shared by `a` and `b`.
    pub rank: usize,
    /// Row-major `m x rank` factor; entry `(i, k)` lives at `i * rank + k`.
    pub a: Vec<f32>,
    /// Row-major `rank x n` factor; entry `(k, j)` lives at `k * n + j`.
    pub b: Vec<f32>,
    /// Number of rows of `a` (output dimension).
    pub m: usize,
    /// Number of columns of `b` (input dimension).
    pub n: usize,
    /// Multiplier applied to the low-rank product in [`LoraAdapter::apply`].
    pub scale: f32,
}

impl LoraAdapter {
    /// Step added to every entry of `a` per unit of `lr * reward`.
    const A_STEP_FACTOR: f32 = 0.01;

    /// Creates an adapter with both factors zero-initialised and
    /// `scale = 1 / rank`.
    ///
    /// A rank of 0 yields an adapter that never changes anything; its scale
    /// is set to 1 so that it stays finite.
    pub fn new(m: usize, n: usize, rank: usize) -> Self {
        // `1 / 0` is infinite, and `apply` would then compute `0 * inf = NaN`
        // for every output element, so the divisor is floored at 1.
        let scale = 1.0 / rank.max(1) as f32;
        Self {
            rank,
            a: vec![0.0; m * rank],
            b: vec![0.0; rank * n],
            m,
            n,
            scale,
        }
    }

    /// Adds the low-rank update to `output` in place.
    ///
    /// The incoming contents of `output` are the input vector `x`:
    ///
    /// `output[i] += scale * sum_k a[i, k] * (sum_j b[k, j] * x[j])`
    ///
    /// Only the first `min(n, output.len())` inputs are read and only the
    /// first `min(m, output.len())` outputs are written, so a slice shorter
    /// or longer than the adapter dimensions is handled without panicking.
    pub fn apply(&self, output: &mut [f32]) {
        let rank = self.rank;
        if rank == 0 {
            return;
        }

        // Down-projection `h = B x`. It is computed in full before any
        // element of `output` is written, because `output` doubles as `x`.
        let cols = self.n.min(output.len());
        let hidden: Vec<f32> = (0..rank)
            .map(|k| {
                let row = self.b.get(k.saturating_mul(self.n)..).unwrap_or(&[]);
                // `zip` stops after `cols <= n` items, i.e. at the row's end.
                row.iter()
                    .zip(output[..cols].iter())
                    .map(|(weight, x)| weight * x)
                    .sum::<f32>()
            })
            .collect();

        // Up-projection `A h`, scaled and accumulated into the output.
        let rows = self.m.min(output.len());
        for (i, out) in output.iter_mut().take(rows).enumerate() {
            let row = self.a.get(i.saturating_mul(rank)..).unwrap_or(&[]);
            // `zip` stops after `rank` items, i.e. at the row's end.
            let delta: f32 = row
                .iter()
                .zip(hidden.iter())
                .map(|(weight, h)| weight * h)
                .sum();
            *out += delta * self.scale;
        }
    }

    /// Applies one reward-weighted update to both factors.
    ///
    /// * every entry of `a` grows by `lr * reward * 0.01`;
    /// * row `k` of `b` grows by `lr * reward * query[j]` in column `j`, for
    ///   the first `min(query.len(), n)` columns.
    ///
    /// This is a heuristic update rule, not a gradient of an explicit loss.
    pub fn gradient_step(&mut self, query: &[f32], reward: f32, lr: f32) {
        if self.rank == 0 {
            return;
        }
        let signal = lr * reward;

        let a_step = signal * Self::A_STEP_FACTOR;
        let a_len = self.m.saturating_mul(self.rank);
        for weight in self.a.iter_mut().take(a_len) {
            *weight += a_step;
        }

        // `chunks_mut` panics on a chunk size of 0; with `n == 0` there are
        // no columns to update anyway.
        if self.n == 0 {
            return;
        }
        let cols = query.len().min(self.n);
        for row in self.b.chunks_mut(self.n).take(self.rank) {
            for (weight, q) in row.iter_mut().zip(&query[..cols]) {
                *weight += signal * q;
            }
        }
    }
}
/// Per-coordinate importance weights around a reference point, tagged with
/// the phi value that scales the resulting penalty.
#[derive(Debug, Clone)]
pub struct PhiWeightedFisher {
    /// Importance weight of each coordinate.
    pub fisher: Vec<f32>,
    /// Reference values that [`PhiWeightedFisher::penalty`] measures drift from.
    pub theta_star: Vec<f32>,
    /// Phi value; multiplies the penalty after being floored at `0.1`.
    pub phi: f64,
}

impl PhiWeightedFisher {
    /// Creates an entry of dimension `dim` with unit importance weights and
    /// an all-zero reference point.
    pub fn new(dim: usize, phi: f64) -> Self {
        let fisher = vec![1.0f32; dim];
        let theta_star = vec![0.0f32; dim];
        Self {
            fisher,
            theta_star,
            phi,
        }
    }

    /// Returns `lambda * max(phi, 0.1) * sum_i fisher[i] * (current[i] - theta_star[i])^2`.
    ///
    /// The sum runs over the coordinates present in all three of `fisher`,
    /// `theta_star` and `current`; surplus entries in any of them are ignored.
    pub fn penalty(&self, current: &[f32], lambda: f32) -> f32 {
        let weighted_drift = current
            .iter()
            .zip(&self.theta_star)
            .zip(&self.fisher)
            .map(|((value, anchor), importance)| importance * (value - anchor).powi(2))
            .sum::<f32>();
        let phi_scale = f32::max(self.phi as f32, 0.1);
        weighted_drift * lambda * phi_scale
    }
}
/// Bounded FIFO store of [`Trajectory`] records with similarity lookup.
#[derive(Debug, Clone)]
pub struct ReasoningBank {
    /// Stored records, oldest at the front.
    trajectories: VecDeque<Trajectory>,
    /// Maximum number of records kept at any time.
    max_size: usize,
    /// Sequence number handed to the next recorded trajectory.
    next_timestamp: u64,
}

impl ReasoningBank {
    /// Creates an empty bank that keeps at most `max_size` records.
    ///
    /// Storage for `max_size` records is reserved up front.
    pub fn new(max_size: usize) -> Self {
        Self {
            trajectories: VecDeque::with_capacity(max_size),
            max_size,
            next_timestamp: 0,
        }
    }

    /// Stores a new trajectory, evicting the oldest one when the bank is full.
    ///
    /// The sequence counter advances on every call. A bank created with
    /// `max_size == 0` stores nothing.
    pub fn record(&mut self, query: Vec<f32>, retrieved_ids: Vec<u64>, reward: f32, phi: f64) {
        let timestamp = self.next_timestamp;
        self.next_timestamp += 1;

        // Without this guard a zero-capacity bank would pop nothing, push
        // anyway, and permanently hold one record.
        if self.max_size == 0 {
            return;
        }
        if self.trajectories.len() >= self.max_size {
            self.trajectories.pop_front();
        }
        self.trajectories.push_back(Trajectory {
            query,
            retrieved_ids,
            reward,
            phi_at_decision: phi,
            timestamp,
        });
    }

    /// Returns up to `k` stored trajectories, most similar to `query` first.
    ///
    /// Similarity is the cosine similarity between the stored query and
    /// `query`. A similarity that evaluates to NaN (for example because a
    /// vector contains NaN) ranks below every real score. The relative order
    /// of trajectories with equal scores is unspecified.
    pub fn recall(&self, query: &[f32], k: usize) -> Vec<&Trajectory> {
        if k == 0 {
            return Vec::new();
        }
        let mut scored: Vec<(&Trajectory, f32)> = self
            .trajectories
            .iter()
            .map(|t| {
                let sim = cosine_sim(&t.query, query);
                // NaN has no place in a total order; demote it so that the
                // comparator below is consistent and the ranking is defined.
                let score = if sim.is_nan() { f32::NEG_INFINITY } else { sim };
                (t, score)
            })
            .collect();
        scored.sort_unstable_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
        scored.truncate(k);
        scored.into_iter().map(|(t, _)| t).collect()
    }

    /// Number of trajectories currently stored.
    pub fn len(&self) -> usize {
        self.trajectories.len()
    }

    /// Returns `true` when no trajectory is stored.
    pub fn is_empty(&self) -> bool {
        self.trajectories.is_empty()
    }

    /// Returns, oldest first, every stored trajectory whose
    /// `phi_at_decision` is at least `threshold`.
    pub fn high_phi_trajectories(&self, threshold: f64) -> Vec<&Trajectory> {
        self.trajectories
            .iter()
            .filter(|t| t.phi_at_decision >= threshold)
            .collect()
    }
}
/// Cosine similarity of `a` and `b` over their common prefix.
///
/// Only the first `min(a.len(), b.len())` components of each slice take part,
/// both in the dot product and in the norms. Each norm is floored at `1e-8`,
/// so empty or all-zero inputs give a zero result instead of dividing by zero.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let shared = usize::min(a.len(), b.len());
    let (lhs, rhs) = (&a[..shared], &b[..shared]);

    let floored_norm = |v: &[f32]| -> f32 {
        let squared = v.iter().map(|c| c * c).sum::<f32>();
        squared.sqrt().max(1e-8)
    };

    let dot = lhs.iter().zip(rhs).map(|(p, q)| p * q).sum::<f32>();
    dot / (floored_norm(lhs) * floored_norm(rhs))
}
/// Tunable parameters of [`ExoLearner`].
#[derive(Debug, Clone, PartialEq)]
pub struct LearnerConfig {
    /// Rank of the learner's low-rank adapter.
    pub lora_rank: usize,
    /// Embedding dimension; used for both dimensions of the adapter.
    pub embedding_dim: usize,
    /// Strength multiplier passed to the protected-pattern penalty.
    pub ewc_lambda: f32,
    /// Maximum number of trajectories kept in the reasoning bank.
    pub reasoning_bank_size: usize,
    /// Protected patterns with a phi of at least this value contribute to the
    /// penalty reported by `ExoLearner::adapt`.
    pub high_phi_threshold: f64,
    /// Learning rate of the per-call adapter update in `ExoLearner::adapt`.
    pub lr_instant: f32,
}

impl Default for LearnerConfig {
    /// Rank-2 adapter on 512-dimensional embeddings, `ewc_lambda = 5.0`, a
    /// bank of 10 000 trajectories, phi threshold `2.0` and learning rate
    /// `0.001`.
    fn default() -> Self {
        Self {
            lora_rank: 2,
            embedding_dim: 512,
            ewc_lambda: 5.0,
            reasoning_bank_size: 10_000,
            high_phi_threshold: 2.0,
            lr_instant: 0.001,
        }
    }
}
/// Online learner combining a low-rank adapter, a list of protected
/// patterns and a bank of recorded trajectories.
pub struct ExoLearner {
/// Parameters the learner was created with.
pub config: LearnerConfig,
/// Adapter updated on every `adapt` call and used by `apply_adapter`.
lora: LoraAdapter,
/// Entries appended by `consolidate_high_phi`.
protected_patterns: Vec<PhiWeightedFisher>,
/// Record of the decisions passed to `adapt`.
pub bank: ReasoningBank,
/// Number of `adapt` calls made so far.
total_updates: u64,
/// Exponential moving average of observed rewards (weights 0.99 / 0.01),
/// starting at 0.5.
avg_reward: f32,
}
/// Summary returned by `ExoLearner::adapt`.
#[derive(Debug, Clone)]
pub struct LearnerUpdate {
/// Euclidean norm of the adapter's `a` factor after the update.
pub lora_delta_norm: f32,
/// Sum of the penalties of the protected patterns whose phi reaches the
/// configured threshold, divided by the total number of protected patterns.
pub ewc_penalty: f32,
/// Number of trajectories in the bank after recording this call.
pub bank_size: usize,
/// Moving average of rewards after including this call's reward.
pub avg_reward: f32,
/// `true` when at least one protected pattern has a phi that reaches the
/// configured threshold.
pub phi_protection_applied: bool,
}
impl ExoLearner {
pub fn new(config: LearnerConfig) -> Self {
let dim = config.embedding_dim;
let rank = config.lora_rank;
let bank_size = config.reasoning_bank_size;
Self {
lora: LoraAdapter::new(dim, dim, rank),
protected_patterns: Vec::new(),
bank: ReasoningBank::new(bank_size),
total_updates: 0,
avg_reward: 0.5,
config,
}
}
pub fn adapt(
&mut self,
query: &[f32],
retrieved_ids: Vec<u64>,
reward: f32,
phi: f64,
) -> LearnerUpdate {
self.lora.gradient_step(query, reward - self.avg_reward, self.config.lr_instant);
let ewc_penalty: f32 = self.protected_patterns.iter()
.filter(|p| p.phi >= self.config.high_phi_threshold)
.map(|p| {
let padded: Vec<f32> = query.iter().chain(std::iter::repeat(&0.0))
.take(p.fisher.len()).copied().collect();
p.penalty(&padded, self.config.ewc_lambda)
})
.sum::<f32>() / self.protected_patterns.len().max(1) as f32;
self.avg_reward = 0.99 * self.avg_reward + 0.01 * reward;
self.total_updates += 1;
self.bank.record(query.to_vec(), retrieved_ids, reward, phi);
let phi_protection = !self.protected_patterns.is_empty() &&
self.protected_patterns.iter().any(|p| p.phi >= self.config.high_phi_threshold);
let delta_norm = self.lora.a.iter().map(|x| x * x).sum::<f32>().sqrt();
LearnerUpdate {
lora_delta_norm: delta_norm,
ewc_penalty,
bank_size: self.bank.len(),
avg_reward: self.avg_reward,
phi_protection_applied: phi_protection,
}
}
pub fn consolidate_high_phi(&mut self, weights: Vec<f32>, phi: f64) {
let mut entry = PhiWeightedFisher::new(weights.len(), phi);
entry.theta_star = weights;
let high_phi_trajs = self.bank.high_phi_trajectories(phi * 0.5);
for traj in high_phi_trajs.iter().take(100) {
for (i, f) in entry.fisher.iter_mut().enumerate() {
let g = traj.query.get(i).copied().unwrap_or(0.0);
*f = 0.9 * *f + 0.1 * g * g;
}
}
self.protected_patterns.push(entry);
}
pub fn apply_adapter(&self, embedding: &[f32]) -> Vec<f32> {
let mut output = embedding.to_vec();
self.lora.apply(&mut output);
output
}
pub fn n_protected(&self) -> usize { self.protected_patterns.len() }
pub fn total_updates(&self) -> u64 { self.total_updates }
}
impl Default for ExoLearner {
fn default() -> Self { Self::new(LearnerConfig::default()) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single `adapt` call records the decision and keeps a positive
    /// reward average.
    #[test]
    fn test_exo_learner_instant_update() {
        let mut learner = ExoLearner::new(LearnerConfig {
            embedding_dim: 64,
            lora_rank: 2,
            ..Default::default()
        });
        let query = vec![0.5f32; 64];
        let update = learner.adapt(&query, vec![1, 2], 0.8, 2.5);
        assert_eq!(update.bank_size, 1);
        assert!(update.avg_reward > 0.0);
        assert_eq!(learner.total_updates(), 1);
    }

    /// After one update step the adapter changes the vector it is applied to.
    #[test]
    fn test_lora_adapter_applies() {
        let mut adapter = LoraAdapter::new(8, 8, 2);
        adapter.gradient_step(&[0.5f32; 8], 0.9, 0.01);
        let mut output = vec![1.0f32; 8];
        adapter.apply(&mut output);
        let changed = output.iter().any(|&v| (v - 1.0).abs() > 1e-8);
        assert!(changed, "LoRA should modify output");
    }

    /// Recall returns the stored trajectory closest to the query.
    #[test]
    fn test_reasoning_bank_recall() {
        let mut bank = ReasoningBank::new(100);
        let q1 = vec![1.0f32, 0.0, 0.0];
        let q2 = vec![0.0f32, 1.0, 0.0];
        bank.record(q1.clone(), vec![1], 0.9, 3.0);
        bank.record(q2.clone(), vec![2], 0.5, 1.0);
        let recalled = bank.recall(&q1, 1);
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].retrieved_ids, vec![1]);
    }

    /// The same drift costs more under a higher phi.
    #[test]
    fn test_phi_weighted_ewc_penalty() {
        let mut fisher = PhiWeightedFisher::new(8, 5.0);
        fisher.theta_star = vec![0.0f32; 8];
        let drifted = vec![2.0f32; 8];
        let penalty = fisher.penalty(&drifted, 1.0);
        assert!(
            penalty > 0.0,
            "High-Phi pattern far from optimal should have penalty"
        );

        let mut low_phi = PhiWeightedFisher::new(8, 0.1);
        low_phi.theta_star = vec![0.0f32; 8];
        let low_penalty = low_phi.penalty(&drifted, 1.0);
        assert!(penalty > low_penalty, "High Phi should incur larger penalty");
    }

    /// A consolidated pattern above the phi threshold is reported as
    /// protected and yields a positive penalty for a drifted query.
    #[test]
    fn test_consolidate_protects_pattern() {
        let mut learner = ExoLearner::new(LearnerConfig {
            embedding_dim: 32,
            lora_rank: 1,
            ..Default::default()
        });
        learner.consolidate_high_phi(vec![0.5f32; 32], 4.0);
        assert_eq!(learner.n_protected(), 1);

        let query = vec![2.0f32; 32];
        let update = learner.adapt(&query, vec![], 0.5, 4.0);
        // The previous `|| learner.n_protected() > 0` alternative was always
        // true here, so the flag itself was never checked.
        assert!(update.phi_protection_applied);
        assert!(update.ewc_penalty > 0.0);
    }
}