use super::AgentType;
use crate::error::{Result, RuvLLMError};
use crate::sona::{SonaConfig, SonaIntegration, Trajectory as SonaTrajectory};
use parking_lot::RwLock;
use ruvector_sona::{
EwcConfig, EwcPlusPlus, LearnedPattern, PatternConfig, PatternType, ReasoningBank,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Final outcome of a completed task trajectory.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Verdict {
    /// Task finished successfully.
    Success {
        reason: String,
    },
    /// Task failed outright.
    Failure {
        reason: String,
        /// Machine-readable error code, when one is available.
        error_code: Option<String>,
    },
    /// Task was only partially completed.
    Partial {
        /// Fraction of the task completed, expected in [0, 1].
        completion: f32,
        reason: String,
    },
    /// Task initially failed but succeeded after reflection-based retries.
    RecoveredViaReflection {
        original_error: String,
        recovery_strategy: String,
        /// Number of recovery attempts taken.
        attempts: u32,
    },
}
impl Verdict {
    /// Numeric quality in [0, 1] for this verdict. Recovered verdicts are
    /// penalized 0.05 per extra attempt, clamped to [0.7, 0.95].
    #[inline]
    pub fn quality_score(&self) -> f32 {
        match self {
            Self::Success { .. } => 1.0,
            Self::Failure { .. } => 0.0,
            Self::Partial { completion, .. } => *completion,
            Self::RecoveredViaReflection { attempts, .. } => {
                let penalty = (*attempts as f32 - 1.0) * 0.05;
                (1.0 - penalty).clamp(0.7, 0.95)
            }
        }
    }
    /// A verdict counts as successful when its quality reaches 0.5.
    #[inline]
    pub fn is_successful(&self) -> bool {
        self.quality_score() >= 0.5
    }
    /// Human-readable explanation. For recovered verdicts this is the
    /// recovery strategy that worked.
    #[inline]
    pub fn reason(&self) -> &str {
        match self {
            Self::Success { reason }
            | Self::Failure { reason, .. }
            | Self::Partial { reason, .. } => reason,
            Self::RecoveredViaReflection {
                recovery_strategy, ..
            } => recovery_strategy,
        }
    }
    /// True when the task succeeded only after reflection-based recovery.
    #[inline]
    pub fn is_recovered(&self) -> bool {
        matches!(self, Self::RecoveredViaReflection { .. })
    }
    /// The error that triggered recovery, if this verdict was recovered.
    #[inline]
    pub fn original_error(&self) -> Option<&str> {
        if let Self::RecoveredViaReflection { original_error, .. } = self {
            Some(original_error)
        } else {
            None
        }
    }
    /// The number of recovery attempts, if this verdict was recovered.
    #[inline]
    pub fn recovery_attempts(&self) -> Option<u32> {
        if let Self::RecoveredViaReflection { attempts, .. } = self {
            Some(*attempts)
        } else {
            None
        }
    }
}
impl Default for Verdict {
fn default() -> Self {
Verdict::Partial {
completion: 0.5,
reason: "Unknown".to_string(),
}
}
}
/// One step within a task trajectory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryStep {
    /// Step label.
    pub name: String,
    /// Step quality in [0, 1] (clamped by `TrajectoryStep::new`).
    pub quality: f32,
    /// Agent that executed the step, if known.
    pub agent: Option<AgentType>,
    /// Wall-clock duration of the step, if measured.
    pub duration_ms: Option<u64>,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
impl TrajectoryStep {
pub fn new(name: impl Into<String>, quality: f32) -> Self {
Self {
name: name.into(),
quality: quality.clamp(0.0, 1.0),
agent: None,
duration_ms: None,
metadata: HashMap::new(),
}
}
pub fn with_agent(mut self, agent: AgentType) -> Self {
self.agent = Some(agent);
self
}
pub fn with_duration(mut self, duration_ms: u64) -> Self {
self.duration_ms = Some(duration_ms);
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
}
/// A recorded task execution: query embedding, steps taken, and final verdict.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trajectory {
    pub task_id: String,
    /// Embedding of the task this trajectory belongs to.
    pub embedding: Vec<f32>,
    pub steps: Vec<TrajectoryStep>,
    pub verdict: Verdict,
    /// Blended quality: 70% mean step quality, 30% verdict quality.
    pub quality_score: f32,
    /// Agent of the first step that named one.
    pub primary_agent: Option<AgentType>,
    pub task_type: Option<String>,
    /// Unix timestamp (seconds) at creation.
    pub timestamp: u64,
    /// Sum of step durations; `None` when no step reported a duration.
    pub total_duration_ms: Option<u64>,
}
impl Trajectory {
    /// Builds a trajectory, deriving the quality score, primary agent,
    /// total duration, and creation timestamp from the inputs.
    pub fn new(
        task_id: impl Into<String>,
        embedding: Vec<f32>,
        steps: Vec<TrajectoryStep>,
        verdict: Verdict,
    ) -> Self {
        let quality_score = Self::compute_quality(&steps, &verdict);
        // The first step that names an agent becomes the primary agent.
        let primary_agent = steps.iter().find_map(|s| s.agent);
        let duration_sum: u64 = steps.iter().filter_map(|s| s.duration_ms).sum();
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        // A zero total means no step reported a duration.
        let total_duration_ms = if duration_sum > 0 {
            Some(duration_sum)
        } else {
            None
        };
        Self {
            task_id: task_id.into(),
            embedding,
            steps,
            verdict,
            quality_score,
            primary_agent,
            task_type: None,
            timestamp,
            total_duration_ms,
        }
    }
    /// Blend of mean step quality (70%) and verdict quality (30%); with no
    /// steps, the verdict quality alone is used.
    fn compute_quality(steps: &[TrajectoryStep], verdict: &Verdict) -> f32 {
        if steps.is_empty() {
            verdict.quality_score()
        } else {
            let mean = steps.iter().map(|s| s.quality).sum::<f32>() / steps.len() as f32;
            mean * 0.7 + verdict.quality_score() * 0.3
        }
    }
    /// Builder: tags the trajectory with a task type.
    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
        self.task_type = Some(task_type.into());
        self
    }
    /// True when the overall quality meets `threshold`.
    pub fn is_high_quality(&self, threshold: f32) -> bool {
        self.quality_score >= threshold
    }
}
/// Tuning knobs for the reasoning bank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningBankConfig {
    /// Maximum trajectories kept in the in-memory buffer (FIFO eviction).
    pub capacity: usize,
    /// Minimum trajectory quality considered for distillation.
    pub distillation_threshold: f32,
    /// Initial EWC regularization strength.
    pub ewc_lambda: f32,
    /// Minimum qualifying trajectories before distillation produces patterns.
    pub min_trajectories_for_distillation: usize,
    /// Cosine similarity above which two patterns are merged during consolidation.
    pub consolidation_similarity: f32,
    /// Dimensionality of trajectory embeddings.
    pub embedding_dim: usize,
    /// Target number of k-means clusters during distillation.
    pub num_clusters: usize,
    /// Patterns below this quality become prune candidates.
    pub min_pattern_quality: f32,
    // NOTE(review): not referenced anywhere in this file — presumably applied
    // by other code; confirm before relying on it.
    pub pattern_decay: f32,
    /// Age (seconds since last access) after which a pattern may be pruned.
    pub max_pattern_age_secs: u64,
    /// When true, distill automatically every `distill_interval` trajectories.
    pub auto_distill: bool,
    pub distill_interval: usize,
}
impl Default for ReasoningBankConfig {
fn default() -> Self {
Self {
capacity: 10000,
distillation_threshold: 0.6,
ewc_lambda: 2000.0,
min_trajectories_for_distillation: 50,
consolidation_similarity: 0.85,
embedding_dim: 384,
num_clusters: 100,
min_pattern_quality: 0.3,
pattern_decay: 0.99,
max_pattern_age_secs: 604800, auto_distill: true,
distill_interval: 100,
}
}
}
impl ReasoningBankConfig {
pub fn for_ruvltra_small() -> Self {
Self {
capacity: 5000,
distillation_threshold: 0.6,
ewc_lambda: 500.0,
min_trajectories_for_distillation: 30,
consolidation_similarity: 0.9,
embedding_dim: 384,
num_clusters: 50,
min_pattern_quality: 0.4,
pattern_decay: 0.995,
max_pattern_age_secs: 259200, auto_distill: true,
distill_interval: 50,
}
}
pub fn for_edge() -> Self {
Self {
capacity: 1000,
distillation_threshold: 0.7,
ewc_lambda: 1000.0,
min_trajectories_for_distillation: 20,
consolidation_similarity: 0.95,
embedding_dim: 256,
num_clusters: 20,
min_pattern_quality: 0.5,
pattern_decay: 0.99,
max_pattern_age_secs: 86400, auto_distill: true,
distill_interval: 30,
}
}
}
/// Agent routing suggestion derived from learned patterns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRecommendation {
    /// Recommended agent.
    pub agent: AgentType,
    /// Confidence in [0, 1].
    pub confidence: f32,
    /// How many patterns contributed to the vote.
    pub patterns_used: usize,
    /// Mean quality of the contributing patterns.
    pub avg_pattern_quality: f32,
    /// Runner-up agents with their vote scores, best first.
    pub alternatives: Vec<(AgentType, f32)>,
    /// Human-readable explanation of the recommendation.
    pub reasoning: String,
}
impl Default for RoutingRecommendation {
fn default() -> Self {
Self {
agent: AgentType::Coder,
confidence: 0.3,
patterns_used: 0,
avg_pattern_quality: 0.0,
alternatives: Vec::new(),
reasoning: "No patterns available, using default agent".to_string(),
}
}
}
/// Aggregate counters describing bank activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReasoningBankStats {
    pub total_trajectories: u64,
    /// Includes trajectories recovered via reflection.
    pub successful_trajectories: u64,
    pub failed_trajectories: u64,
    pub partial_trajectories: u64,
    /// Current number of buffered trajectories (derived gauge).
    pub buffer_size: usize,
    /// Current number of stored patterns (derived gauge).
    pub patterns_learned: usize,
    pub distillation_runs: u64,
    pub consolidation_runs: u64,
    /// Running mean of trajectory quality scores.
    pub avg_quality: f32,
    /// Number of EWC tasks (derived gauge).
    pub ewc_tasks: usize,
}
/// A consolidated pattern distilled from a cluster of similar trajectories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistilledPattern {
    pub id: u64,
    /// L2-normalized cluster centroid in embedding space.
    pub centroid: Vec<f32>,
    /// Highest-scoring agent at distillation time.
    pub primary_agent: AgentType,
    /// Normalized per-agent quality shares (sum to 1 when non-empty).
    pub agent_scores: HashMap<AgentType, f32>,
    pub avg_quality: f32,
    /// Number of trajectories merged into this pattern.
    pub trajectory_count: usize,
    pub task_type: Option<String>,
    /// Unix timestamps, in seconds.
    pub created_at: u64,
    pub last_accessed: u64,
    /// Times this pattern contributed to a recommendation.
    pub access_count: u32,
}
impl DistilledPattern {
    /// Cosine similarity between this pattern's centroid and `embedding`.
    /// Returns 0.0 on dimension mismatch or when either norm is (near-)zero.
    #[inline]
    pub fn similarity(&self, embedding: &[f32]) -> f32 {
        if self.centroid.len() != embedding.len() {
            return 0.0;
        }
        let mut dot = 0.0f32;
        let mut norm_a_sq = 0.0f32;
        let mut norm_b_sq = 0.0f32;
        for (&a, &b) in self.centroid.iter().zip(embedding) {
            dot += a * b;
            norm_a_sq += a * a;
            norm_b_sq += b * b;
        }
        let norm_a = norm_a_sq.sqrt();
        let norm_b = norm_b_sq.sqrt();
        if norm_a > 1e-8 && norm_b > 1e-8 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }
    /// Agent with the highest learned score, falling back to the primary agent
    /// when no scores are recorded.
    #[inline]
    pub fn best_agent(&self) -> AgentType {
        let top = self
            .agent_scores
            .iter()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));
        match top {
            Some((agent, _)) => *agent,
            None => self.primary_agent,
        }
    }
    /// Whether this pattern is prunable: low quality AND older than
    /// `max_age_secs` since last access AND rarely accessed (all three).
    #[inline]
    pub fn should_prune(&self, min_quality: f32, max_age_secs: u64) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        let age_secs = now.saturating_sub(self.last_accessed);
        self.avg_quality < min_quality && age_secs > max_age_secs && self.access_count < 5
    }
}
/// Learns routing patterns from recorded trajectories and serves agent
/// recommendations, with EWC-based consolidation and an optional SONA bridge.
pub struct ReasoningBankIntegration {
    config: ReasoningBankConfig,
    /// FIFO buffer of recent trajectories, bounded by `config.capacity`.
    trajectory_buffer: Arc<RwLock<Vec<Trajectory>>>,
    /// Distilled patterns keyed by id.
    patterns: Arc<RwLock<HashMap<u64, DistilledPattern>>>,
    /// Elastic weight consolidation state.
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Underlying ruvector-sona reasoning bank.
    core_bank: Arc<RwLock<ReasoningBank>>,
    /// Optional SONA integration mirror.
    sona: Option<Arc<RwLock<SonaIntegration>>>,
    /// Monotonic id source for new patterns.
    next_pattern_id: AtomicU64,
    stats: RwLock<ReasoningBankStats>,
    /// Trajectories recorded since the last auto-distillation.
    trajectories_since_distill: AtomicU64,
}
impl ReasoningBankIntegration {
/// Builds a bank with fresh EWC state, an empty pattern store, and no SONA link.
pub fn new(config: ReasoningBankConfig) -> Self {
    let ewc_config = EwcConfig {
        param_count: config.embedding_dim,
        initial_lambda: config.ewc_lambda,
        // Lambda is allowed to grow to 5x its configured value.
        max_lambda: config.ewc_lambda * 5.0,
        ..Default::default()
    };
    let pattern_config = PatternConfig {
        k_clusters: config.num_clusters,
        // Core-bank embeddings are capped at 256 dimensions.
        embedding_dim: config.embedding_dim.min(256),
        max_trajectories: config.capacity,
        quality_threshold: config.min_pattern_quality,
        ..Default::default()
    };
    Self {
        config,
        trajectory_buffer: Arc::new(RwLock::new(Vec::new())),
        patterns: Arc::new(RwLock::new(HashMap::new())),
        ewc: Arc::new(RwLock::new(EwcPlusPlus::new(ewc_config))),
        core_bank: Arc::new(RwLock::new(ReasoningBank::new(pattern_config))),
        sona: None,
        next_pattern_id: AtomicU64::new(0),
        stats: RwLock::new(ReasoningBankStats::default()),
        trajectories_since_distill: AtomicU64::new(0),
    }
}
pub fn with_sona(config: ReasoningBankConfig, sona_config: SonaConfig) -> Self {
let mut bank = Self::new(config);
bank.sona = Some(Arc::new(RwLock::new(SonaIntegration::new(sona_config))));
bank
}
/// Records a finished trajectory: updates stats, buffers it, mirrors it to
/// SONA and the core bank, and triggers auto-distillation when configured.
pub fn record_trajectory(
    &self,
    task_id: impl Into<String>,
    embedding: &[f32],
    steps: Vec<TrajectoryStep>,
    verdict: Verdict,
) -> Result<()> {
    let trajectory = Trajectory::new(task_id, embedding.to_vec(), steps, verdict.clone());
    {
        // Update counters; a recovered verdict counts as a success.
        let mut stats = self.stats.write();
        stats.total_trajectories += 1;
        match &verdict {
            Verdict::Success { .. } => stats.successful_trajectories += 1,
            Verdict::Failure { .. } => stats.failed_trajectories += 1,
            Verdict::Partial { .. } => stats.partial_trajectories += 1,
            Verdict::RecoveredViaReflection { .. } => {
                stats.successful_trajectories += 1;
            }
        }
        // Incremental running mean of quality.
        let n = stats.total_trajectories as f32;
        stats.avg_quality = stats.avg_quality * (n - 1.0) / n + trajectory.quality_score / n;
    }
    {
        // Bounded FIFO: evict the oldest entry once capacity is reached.
        let mut buffer = self.trajectory_buffer.write();
        if buffer.len() >= self.config.capacity {
            buffer.remove(0);
        }
        buffer.push(trajectory.clone());
    }
    if let Some(ref sona) = self.sona {
        // Mirror to SONA; bridge errors are intentionally ignored (best effort).
        let sona_trajectory = SonaTrajectory {
            request_id: trajectory.task_id.clone(),
            session_id: "reasoning-bank".to_string(),
            query_embedding: embedding.to_vec(),
            response_embedding: embedding.to_vec(),
            quality_score: trajectory.quality_score,
            routing_features: vec![
                trajectory.quality_score,
                verdict.quality_score(),
                trajectory.steps.len() as f32 / 10.0,
            ],
            model_index: trajectory.primary_agent.map(|a| a as usize).unwrap_or(0),
            timestamp: chrono::Utc::now(),
        };
        let sona_guard = sona.read();
        let _ = sona_guard.record_trajectory(sona_trajectory);
    }
    {
        // Also feed the underlying ruvector-sona bank.
        let mut core = self.core_bank.write();
        let query_traj =
            ruvector_sona::QueryTrajectory::new(trajectory.timestamp, embedding.to_vec());
        core.add_trajectory(&query_traj);
    }
    // fetch_add returns the PREVIOUS value; +1 gives the post-increment count.
    let count = self
        .trajectories_since_distill
        .fetch_add(1, Ordering::SeqCst)
        + 1;
    if self.config.auto_distill && count >= self.config.distill_interval as u64 {
        self.distill_patterns()?;
        self.trajectories_since_distill.store(0, Ordering::SeqCst);
    }
    Ok(())
}
/// Clusters buffered high-quality trajectories and stores one distilled
/// pattern per cluster. Returns the newly created patterns (empty when too
/// few trajectories qualify).
pub fn distill_patterns(&self) -> Result<Vec<DistilledPattern>> {
    // Snapshot qualifying trajectories so no lock is held during clustering.
    let trajectories: Vec<Trajectory> = {
        let buffer = self.trajectory_buffer.read();
        buffer
            .iter()
            .filter(|t| t.quality_score >= self.config.distillation_threshold)
            .cloned()
            .collect()
    };
    if trajectories.len() < self.config.min_trajectories_for_distillation {
        return Ok(Vec::new());
    }
    {
        // Let the core bank extract its own patterns as well.
        let mut core = self.core_bank.write();
        core.extract_patterns();
    }
    let clusters = self.cluster_trajectories(&trajectories);
    let mut new_patterns = Vec::new();
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    for cluster in clusters {
        if cluster.is_empty() {
            continue;
        }
        // Centroid = mean of member embeddings (dimension fixed by the first member)...
        let dim = cluster[0].embedding.len();
        let mut centroid = vec![0.0f32; dim];
        for traj in &cluster {
            for (i, &e) in traj.embedding.iter().enumerate() {
                if i < dim {
                    centroid[i] += e;
                }
            }
        }
        for c in &mut centroid {
            *c /= cluster.len() as f32;
        }
        // ...then L2-normalized, guarding against a zero vector.
        let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for c in &mut centroid {
                *c /= norm;
            }
        }
        // Accumulate per-agent quality mass and pick the first task type seen.
        let mut agent_scores: HashMap<AgentType, f32> = HashMap::new();
        let mut total_quality = 0.0f32;
        let mut task_type: Option<String> = None;
        for traj in &cluster {
            if let Some(agent) = traj.primary_agent {
                *agent_scores.entry(agent).or_insert(0.0) += traj.quality_score;
            }
            total_quality += traj.quality_score;
            if task_type.is_none() {
                task_type = traj.task_type.clone();
            }
        }
        // Normalize agent scores into shares that sum to 1.
        let total_agent_score: f32 = agent_scores.values().sum();
        if total_agent_score > 0.0 {
            for score in agent_scores.values_mut() {
                *score /= total_agent_score;
            }
        }
        let primary_agent = agent_scores
            .iter()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(agent, _)| *agent)
            .unwrap_or(AgentType::Coder);
        let pattern_id = self.next_pattern_id.fetch_add(1, Ordering::SeqCst);
        let pattern = DistilledPattern {
            id: pattern_id,
            centroid,
            primary_agent,
            agent_scores,
            avg_quality: total_quality / cluster.len() as f32,
            trajectory_count: cluster.len(),
            task_type,
            created_at: now,
            last_accessed: now,
            access_count: 0,
        };
        {
            let mut patterns = self.patterns.write();
            patterns.insert(pattern_id, pattern.clone());
        }
        new_patterns.push(pattern);
    }
    self.update_ewc_from_patterns(&new_patterns);
    {
        let mut stats = self.stats.write();
        stats.distillation_runs += 1;
        stats.patterns_learned = self.patterns.read().len();
    }
    Ok(new_patterns)
}
/// Simple k-means (cosine similarity, at most 10 iterations) over trajectory
/// embeddings. Clusters with fewer than 2 members are dropped.
fn cluster_trajectories(&self, trajectories: &[Trajectory]) -> Vec<Vec<Trajectory>> {
    if trajectories.is_empty() {
        return Vec::new();
    }
    // At most one cluster per ~3 trajectories, and never zero.
    let k = self.config.num_clusters.min(trajectories.len() / 3).max(1);
    let dim = trajectories[0].embedding.len();
    // Deterministic seeding: the first k embeddings become the initial centroids.
    let mut centroids: Vec<Vec<f32>> = trajectories
        .iter()
        .take(k)
        .map(|t| t.embedding.clone())
        .collect();
    let mut assignments = vec![0usize; trajectories.len()];
    for _ in 0..10 {
        // Assignment step: each trajectory joins its most similar centroid.
        let mut changed = false;
        for (i, traj) in trajectories.iter().enumerate() {
            let nearest = centroids
                .iter()
                .enumerate()
                .map(|(j, c)| (j, self.cosine_similarity(&traj.embedding, c)))
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(j, _)| j)
                .unwrap_or(0);
            if assignments[i] != nearest {
                assignments[i] = nearest;
                changed = true;
            }
        }
        // Converged: no assignment moved.
        if !changed {
            break;
        }
        // Update step: recompute centroids as cluster means (empty clusters
        // keep a zero centroid, which scores 0 similarity against everything).
        let mut new_centroids = vec![vec![0.0f32; dim]; k];
        let mut counts = vec![0usize; k];
        for (i, traj) in trajectories.iter().enumerate() {
            let cluster = assignments[i];
            counts[cluster] += 1;
            for (j, &e) in traj.embedding.iter().enumerate() {
                if j < dim {
                    new_centroids[cluster][j] += e;
                }
            }
        }
        for (i, centroid) in new_centroids.iter_mut().enumerate() {
            if counts[i] > 0 {
                for c in centroid.iter_mut() {
                    *c /= counts[i] as f32;
                }
            }
        }
        centroids = new_centroids;
    }
    let mut clusters: Vec<Vec<Trajectory>> = vec![Vec::new(); k];
    for (i, traj) in trajectories.iter().enumerate() {
        clusters[assignments[i]].push(traj.clone());
    }
    // Singleton clusters carry too little signal to distill.
    clusters.into_iter().filter(|c| c.len() >= 2).collect()
}
/// Cosine similarity between `a` and `b`; 0.0 on dimension mismatch or
/// (near-)zero norms.
#[inline]
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions are treated as completely dissimilar.
    if a.len() != b.len() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut a_sq = 0.0f32;
    let mut b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        a_sq += x * x;
        b_sq += y * y;
    }
    let norm_a = a_sq.sqrt();
    let norm_b = b_sq.sqrt();
    // Guard against division by (near-)zero norms.
    if norm_a > 1e-8 && norm_b > 1e-8 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
/// Feeds each new pattern's centroid into EWC as a pseudo-gradient, and
/// marks a task boundary after a sizeable batch.
fn update_ewc_from_patterns(&self, patterns: &[DistilledPattern]) {
    let mut ewc = self.ewc.write();
    let dim = self.config.embedding_dim;
    for pattern in patterns {
        // Copy the centroid into a buffer of exactly `embedding_dim` entries,
        // zero-padding (or truncating) as needed.
        let mut gradients = vec![0.0f32; dim];
        for (slot, &value) in gradients.iter_mut().zip(pattern.centroid.iter()) {
            *slot = value;
        }
        ewc.update_fisher(&gradients);
    }
    // Ten or more new patterns counts as a new EWC task.
    if patterns.len() >= 10 {
        ewc.start_new_task();
    }
}
/// Recommends an agent for a task embedding via similarity-weighted voting
/// over the (up to) 5 most similar distilled patterns.
///
/// Returns [`RoutingRecommendation::default`] when no patterns exist.
pub fn get_recommendation(&self, embedding: &[f32]) -> RoutingRecommendation {
    // Snapshot the top matches and RELEASE the read guard before taking the
    // write lock below. The previous version kept borrows into the read
    // guard alive (via `scored`/`top_patterns`) while calling
    // `self.patterns.write()` on the same thread — parking_lot's RwLock is
    // not reentrant, so that self-deadlocks.
    let top_patterns: Vec<(DistilledPattern, f32)> = {
        let patterns = self.patterns.read();
        if patterns.is_empty() {
            return RoutingRecommendation::default();
        }
        // Score every pattern by cosine similarity to the query.
        let mut scored: Vec<(&DistilledPattern, f32)> = Vec::with_capacity(patterns.len());
        for pattern in patterns.values() {
            scored.push((pattern, pattern.similarity(embedding)));
        }
        // Highest similarity first.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored
            .into_iter()
            .take(5)
            .map(|(pattern, similarity)| (pattern.clone(), similarity))
            .collect()
    };
    if top_patterns.is_empty() {
        return RoutingRecommendation::default();
    }
    {
        // Touch the chosen patterns: bump access counts and recency.
        let mut patterns_mut = self.patterns.write();
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        for (pattern, _) in &top_patterns {
            if let Some(p) = patterns_mut.get_mut(&pattern.id) {
                p.access_count += 1;
                p.last_accessed = now;
            }
        }
    }
    // Similarity-and-quality-weighted vote across agents.
    let mut agent_votes: HashMap<AgentType, f32> = HashMap::with_capacity(16);
    let mut total_weight = 0.0f32;
    let mut total_quality = 0.0f32;
    for (pattern, similarity) in &top_patterns {
        let weight = similarity * pattern.avg_quality;
        total_weight += weight;
        total_quality += pattern.avg_quality;
        for (agent, score) in &pattern.agent_scores {
            *agent_votes.entry(*agent).or_insert(0.0) += weight * score;
        }
        // The pattern's primary agent gets an extra half-weight boost.
        *agent_votes.entry(pattern.primary_agent).or_insert(0.0) += weight * 0.5;
    }
    if total_weight > 0.0 {
        for vote in agent_votes.values_mut() {
            *vote /= total_weight;
        }
    }
    let (best_agent, best_score) = agent_votes
        .iter()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(agent, score)| (*agent, *score))
        .unwrap_or((AgentType::Coder, 0.0));
    // Up to three runner-up agents, best first.
    let mut alternatives: Vec<(AgentType, f32)> = agent_votes
        .into_iter()
        .filter(|(agent, _)| *agent != best_agent)
        .collect();
    alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    alternatives.truncate(3);
    // top_patterns is non-empty here, so these are well-defined.
    let max_similarity = top_patterns[0].1;
    let confidence = (best_score * max_similarity).min(1.0);
    let avg_pattern_quality = total_quality / top_patterns.len() as f32;
    let reasoning = format!(
        "Based on {} similar patterns with avg quality {:.2}; best match similarity: {:.2}",
        top_patterns.len(),
        avg_pattern_quality,
        max_similarity
    );
    RoutingRecommendation {
        agent: best_agent,
        confidence,
        patterns_used: top_patterns.len(),
        avg_pattern_quality,
        alternatives,
        reasoning,
    }
}
/// Maintenance pass: prunes stale patterns, merges near-duplicates,
/// consolidates EWC tasks, and refreshes derived stats.
pub fn consolidate(&self) -> Result<()> {
    {
        // Pass 1: drop stale, low-quality, rarely-used patterns.
        let mut patterns = self.patterns.write();
        let to_remove: Vec<u64> = patterns
            .iter()
            .filter(|(_, p)| {
                p.should_prune(
                    self.config.min_pattern_quality,
                    self.config.max_pattern_age_secs,
                )
            })
            .map(|(id, _)| *id)
            .collect();
        for id in to_remove {
            patterns.remove(&id);
        }
    }
    {
        // Pass 2: O(n^2) pairwise merge of near-duplicate patterns. A merge
        // keeps the first pattern's id; the second id is removed afterwards.
        let mut patterns = self.patterns.write();
        let pattern_ids: Vec<u64> = patterns.keys().copied().collect();
        let mut merged_ids = Vec::new();
        for i in 0..pattern_ids.len() {
            for j in i + 1..pattern_ids.len() {
                let id1 = pattern_ids[i];
                let id2 = pattern_ids[j];
                // Skip anything already consumed by an earlier merge.
                if merged_ids.contains(&id1) || merged_ids.contains(&id2) {
                    continue;
                }
                if let (Some(p1), Some(p2)) = (patterns.get(&id1), patterns.get(&id2)) {
                    let similarity = p1.similarity(&p2.centroid);
                    if similarity > self.config.consolidation_similarity {
                        let merged = self.merge_patterns(p1, p2);
                        patterns.insert(id1, merged);
                        merged_ids.push(id2);
                    }
                }
            }
        }
        for id in merged_ids {
            patterns.remove(&id);
        }
    }
    {
        let mut ewc = self.ewc.write();
        ewc.consolidate_all_tasks();
    }
    {
        let mut stats = self.stats.write();
        stats.consolidation_runs += 1;
        stats.patterns_learned = self.patterns.read().len();
        stats.ewc_tasks = self.ewc.read().task_count();
    }
    Ok(())
}
/// Merges two similar patterns into one, weighting each side by its
/// trajectory count. The result keeps `p1`'s id, the count-weighted centroid
/// and quality, the earlier creation time, and the combined access count.
fn merge_patterns(&self, p1: &DistilledPattern, p2: &DistilledPattern) -> DistilledPattern {
    let total_count = p1.trajectory_count + p2.trajectory_count;
    let w1 = p1.trajectory_count as f32 / total_count as f32;
    let w2 = p2.trajectory_count as f32 / total_count as f32;
    // Count-weighted average of the centroids (truncated to the shorter one).
    let centroid: Vec<f32> = p1
        .centroid
        .iter()
        .zip(&p2.centroid)
        .map(|(&a, &b)| a * w1 + b * w2)
        .collect();
    // Weight BOTH sides by their trajectory share. The previous code copied
    // p1's scores at full weight and only scaled p2 by w2, so after
    // renormalization p1's influence was 1/(1+w2) instead of w1 — skewing the
    // merged distribution toward p1 regardless of how many trajectories each
    // side actually contributed.
    let mut agent_scores: HashMap<AgentType, f32> = HashMap::new();
    for (agent, score) in &p1.agent_scores {
        *agent_scores.entry(*agent).or_insert(0.0) += score * w1;
    }
    for (agent, score) in &p2.agent_scores {
        *agent_scores.entry(*agent).or_insert(0.0) += score * w2;
    }
    // Renormalize so the shares sum to 1.
    let total: f32 = agent_scores.values().sum();
    if total > 0.0 {
        for score in agent_scores.values_mut() {
            *score /= total;
        }
    }
    let primary_agent = agent_scores
        .iter()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(agent, _)| *agent)
        .unwrap_or(p1.primary_agent);
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    DistilledPattern {
        id: p1.id,
        centroid,
        primary_agent,
        agent_scores,
        avg_quality: p1.avg_quality * w1 + p2.avg_quality * w2,
        trajectory_count: total_count,
        task_type: p1.task_type.clone().or_else(|| p2.task_type.clone()),
        created_at: p1.created_at.min(p2.created_at),
        last_accessed: now,
        access_count: p1.access_count + p2.access_count,
    }
}
/// Snapshot of bank statistics, with the derived gauges refreshed.
pub fn stats(&self) -> ReasoningBankStats {
    // Start from the recorded counters, then refresh the live gauges.
    let mut snapshot = self.stats.read().clone();
    snapshot.buffer_size = self.trajectory_buffer.read().len();
    snapshot.patterns_learned = self.patterns.read().len();
    snapshot.ewc_tasks = self.ewc.read().task_count();
    snapshot
}
/// Returns a cloned snapshot of all learned patterns.
pub fn get_patterns(&self) -> Vec<DistilledPattern> {
    let guard = self.patterns.read();
    guard.values().cloned().collect()
}
/// Number of trajectories currently buffered.
pub fn trajectory_count(&self) -> usize {
    let buffer = self.trajectory_buffer.read();
    buffer.len()
}
/// Number of patterns currently stored.
pub fn pattern_count(&self) -> usize {
    let guard = self.patterns.read();
    guard.len()
}
/// Resets the bank: drops all trajectories and patterns and zeroes the
/// statistics and the auto-distill counter.
pub fn clear(&self) {
    self.trajectories_since_distill.store(0, Ordering::SeqCst);
    *self.stats.write() = ReasoningBankStats::default();
    self.patterns.write().clear();
    self.trajectory_buffer.write().clear();
}
/// Returns a snapshot of all learned patterns for external persistence.
///
/// Delegates to [`Self::get_patterns`] — the two accessors previously
/// duplicated the same body and could drift apart.
pub fn export_patterns(&self) -> Vec<DistilledPattern> {
    self.get_patterns()
}
/// Installs externally persisted patterns, advancing `next_pattern_id` past
/// both the imported ids and its current value so future ids cannot collide.
pub fn import_patterns(&self, patterns: Vec<DistilledPattern>) {
    let mut pattern_map = self.patterns.write();
    for pattern in patterns {
        let candidate = pattern.id.max(self.next_pattern_id.load(Ordering::SeqCst));
        self.next_pattern_id.fetch_max(candidate + 1, Ordering::SeqCst);
        pattern_map.insert(pattern.id, pattern);
    }
    self.stats.write().patterns_learned = pattern_map.len();
}
}
// Manual Debug impl: the lock-guarded internals are summarized via counters
// rather than dumped in full (EwcPlusPlus/ReasoningBank may not be Debug).
impl std::fmt::Debug for ReasoningBankIntegration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReasoningBankIntegration")
            .field("config", &self.config)
            .field("trajectory_count", &self.trajectory_count())
            .field("pattern_count", &self.pattern_count())
            .field("stats", &self.stats())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Verdict variants map to the expected scalar quality scores.
    #[test]
    fn test_verdict_quality_scores() {
        assert_eq!(
            Verdict::Success {
                reason: "ok".into()
            }
            .quality_score(),
            1.0
        );
        assert_eq!(
            Verdict::Failure {
                reason: "err".into(),
                error_code: None
            }
            .quality_score(),
            0.0
        );
        assert_eq!(
            Verdict::Partial {
                completion: 0.7,
                reason: "partial".into()
            }
            .quality_score(),
            0.7
        );
    }
    // Builder methods populate every optional field.
    #[test]
    fn test_trajectory_step_creation() {
        let step = TrajectoryStep::new("test_step", 0.8)
            .with_agent(AgentType::Coder)
            .with_duration(100)
            .with_metadata("key", "value");
        assert_eq!(step.name, "test_step");
        assert_eq!(step.quality, 0.8);
        assert_eq!(step.agent, Some(AgentType::Coder));
        assert_eq!(step.duration_ms, Some(100));
        assert_eq!(step.metadata.get("key"), Some(&"value".to_string()));
    }
    // Quality blends step mean (0.8) at 70% with verdict (1.0) at 30% = 0.86.
    #[test]
    fn test_trajectory_creation() {
        let steps = vec![
            TrajectoryStep::new("step1", 0.7).with_agent(AgentType::Researcher),
            TrajectoryStep::new("step2", 0.9).with_agent(AgentType::Coder),
        ];
        let traj = Trajectory::new(
            "task-1",
            vec![0.1, 0.2, 0.3],
            steps,
            Verdict::Success {
                reason: "done".into(),
            },
        );
        assert_eq!(traj.task_id, "task-1");
        assert_eq!(traj.steps.len(), 2);
        assert!((traj.quality_score - 0.86).abs() < 0.01);
    }
    // A fresh bank starts empty.
    #[test]
    fn test_reasoning_bank_creation() {
        let config = ReasoningBankConfig::default();
        let bank = ReasoningBankIntegration::new(config);
        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }
    // Recording one success buffers it and bumps the success counter.
    #[test]
    fn test_record_trajectory() {
        let config = ReasoningBankConfig {
            auto_distill: false,
            ..Default::default()
        };
        let bank = ReasoningBankIntegration::new(config);
        let steps = vec![TrajectoryStep::new("step1", 0.8).with_agent(AgentType::Coder)];
        bank.record_trajectory(
            "task-1",
            &vec![0.1; 384],
            steps,
            Verdict::Success {
                reason: "done".into(),
            },
        )
        .unwrap();
        assert_eq!(bank.trajectory_count(), 1);
        let stats = bank.stats();
        assert_eq!(stats.total_trajectories, 1);
        assert_eq!(stats.successful_trajectories, 1);
    }
    // Two well-separated embedding groups should yield at least one pattern.
    #[test]
    fn test_distill_patterns() {
        let config = ReasoningBankConfig {
            min_trajectories_for_distillation: 5,
            distillation_threshold: 0.0,
            num_clusters: 2,
            auto_distill: false,
            ..Default::default()
        };
        let bank = ReasoningBankIntegration::new(config);
        for i in 0..10 {
            // First half points along axis 0, second half along axis 1.
            let embedding: Vec<f32> = if i < 5 {
                vec![1.0, 0.0, 0.0]
                    .into_iter()
                    .chain(std::iter::repeat(0.0))
                    .take(384)
                    .collect()
            } else {
                vec![0.0, 1.0, 0.0]
                    .into_iter()
                    .chain(std::iter::repeat(0.0))
                    .take(384)
                    .collect()
            };
            let steps = vec![TrajectoryStep::new("step", 0.8).with_agent(AgentType::Coder)];
            bank.record_trajectory(
                format!("task-{}", i),
                &embedding,
                steps,
                Verdict::Success {
                    reason: "done".into(),
                },
            )
            .unwrap();
        }
        let patterns = bank.distill_patterns().unwrap();
        assert!(!patterns.is_empty());
    }
    // A query close to the learned centroid should produce a confident match.
    #[test]
    fn test_get_recommendation() {
        let config = ReasoningBankConfig {
            min_trajectories_for_distillation: 2,
            distillation_threshold: 0.0,
            num_clusters: 1,
            auto_distill: false,
            ..Default::default()
        };
        let bank = ReasoningBankIntegration::new(config);
        for i in 0..5 {
            let embedding: Vec<f32> = vec![1.0, 0.0, 0.0]
                .into_iter()
                .chain(std::iter::repeat(0.0))
                .take(384)
                .collect();
            let steps = vec![TrajectoryStep::new("step", 0.9).with_agent(AgentType::Tester)];
            bank.record_trajectory(
                format!("task-{}", i),
                &embedding,
                steps,
                Verdict::Success {
                    reason: "done".into(),
                },
            )
            .unwrap();
        }
        bank.distill_patterns().unwrap();
        let query: Vec<f32> = vec![0.9, 0.1, 0.0]
            .into_iter()
            .chain(std::iter::repeat(0.0))
            .take(384)
            .collect();
        let rec = bank.get_recommendation(&query);
        assert!(rec.patterns_used > 0);
        assert!(rec.confidence > 0.0);
    }
    // Near-identical patterns should merge (or at least not multiply).
    #[test]
    fn test_consolidate() {
        let config = ReasoningBankConfig {
            min_trajectories_for_distillation: 2,
            distillation_threshold: 0.0,
            num_clusters: 2,
            consolidation_similarity: 0.99,
            auto_distill: false,
            ..Default::default()
        };
        let bank = ReasoningBankIntegration::new(config);
        for i in 0..6 {
            let embedding: Vec<f32> = vec![1.0 + (i as f32 * 0.001), 0.0, 0.0]
                .into_iter()
                .chain(std::iter::repeat(0.0))
                .take(384)
                .collect();
            let steps = vec![TrajectoryStep::new("step", 0.8).with_agent(AgentType::Coder)];
            bank.record_trajectory(
                format!("task-{}", i),
                &embedding,
                steps,
                Verdict::Success {
                    reason: "done".into(),
                },
            )
            .unwrap();
        }
        bank.distill_patterns().unwrap();
        let before = bank.pattern_count();
        bank.consolidate().unwrap();
        let after = bank.pattern_count();
        assert!(after <= before);
    }
    // Cosine similarity: identical vectors score ~1, orthogonal vectors ~0.
    #[test]
    fn test_distilled_pattern_similarity() {
        let pattern = DistilledPattern {
            id: 1,
            centroid: vec![1.0, 0.0, 0.0, 0.0],
            primary_agent: AgentType::Coder,
            agent_scores: HashMap::new(),
            avg_quality: 0.9,
            trajectory_count: 10,
            task_type: None,
            created_at: 0,
            last_accessed: 0,
            access_count: 0,
        };
        let same = vec![1.0, 0.0, 0.0, 0.0];
        let orthogonal = vec![0.0, 1.0, 0.0, 0.0];
        assert!((pattern.similarity(&same) - 1.0).abs() < 0.01);
        assert!(pattern.similarity(&orthogonal).abs() < 0.01);
    }
    // Import then export round-trips a pattern, preserving id and agent.
    #[test]
    fn test_export_import_patterns() {
        let config = ReasoningBankConfig::default();
        let bank = ReasoningBankIntegration::new(config.clone());
        let pattern = DistilledPattern {
            id: 42,
            centroid: vec![0.5; 384],
            primary_agent: AgentType::Researcher,
            agent_scores: HashMap::from([(AgentType::Researcher, 0.8), (AgentType::Coder, 0.2)]),
            avg_quality: 0.85,
            trajectory_count: 50,
            task_type: Some("research".to_string()),
            created_at: 1000,
            last_accessed: 2000,
            access_count: 10,
        };
        bank.import_patterns(vec![pattern.clone()]);
        assert_eq!(bank.pattern_count(), 1);
        let exported = bank.export_patterns();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].id, 42);
        assert_eq!(exported[0].primary_agent, AgentType::Researcher);
    }
    // Presets shrink monotonically: default > small > edge.
    #[test]
    fn test_config_presets() {
        let default = ReasoningBankConfig::default();
        let small = ReasoningBankConfig::for_ruvltra_small();
        let edge = ReasoningBankConfig::for_edge();
        assert!(default.capacity > small.capacity);
        assert!(small.capacity > edge.capacity);
        assert!(edge.num_clusters < small.num_clusters);
    }
}