pub mod consolidation;
pub mod distillation;
pub mod pattern_store;
pub mod trajectory;
pub mod verdicts;
pub use consolidation::{
ConsolidationConfig, ConsolidationResult, FisherInformation, ImportanceScore,
PatternConsolidator,
};
pub use distillation::{
CompressedTrajectory, DistillationConfig, DistillationResult, KeyLesson, MemoryDistiller,
};
pub use pattern_store::{
Pattern, PatternCategory, PatternSearchResult, PatternStats, PatternStore, PatternStoreConfig,
};
pub use trajectory::{
StepOutcome, Trajectory, TrajectoryId, TrajectoryMetadata, TrajectoryRecorder, TrajectoryStep,
};
pub use verdicts::{
FailurePattern, RecoveryStrategy, RootCause, Verdict, VerdictAnalysis, VerdictAnalyzer,
};
use crate::error::Result;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Configuration for a [`ReasoningBank`] instance.
///
/// Aggregates the top-level tuning knobs plus the sub-configs that are
/// forwarded verbatim to the pattern store, consolidator, and distiller
/// in [`ReasoningBank::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningBankConfig {
    /// Filesystem path for persisted bank data.
    /// NOTE(review): not read directly in this module — presumably consumed
    /// by the pattern store or callers; confirm.
    pub storage_path: String,
    /// Dimensionality of query/pattern embeddings (default 768).
    pub embedding_dim: usize,
    /// HNSW-style index parameter: candidate list size during construction.
    /// NOTE(review): assumed to be forwarded to the vector index — confirm.
    pub ef_construction: usize,
    /// HNSW-style index parameter: candidate list size during search.
    pub ef_search: usize,
    /// HNSW-style index parameter: max links per node.
    pub m: usize,
    /// Cap on retained trajectories; exceeding it triggers distillation
    /// (see `ReasoningBank::store_trajectory`).
    pub max_trajectories: usize,
    /// Minimum trajectory quality for pattern extraction (0.0..=1.0 assumed).
    pub min_quality_threshold: f32,
    /// Interval between automatic consolidations, in seconds.
    /// NOTE(review): no scheduler is visible in this module — presumably used
    /// by an external driver together with `auto_consolidate`; confirm.
    pub consolidation_interval_secs: u64,
    /// Whether consolidation should run automatically.
    pub auto_consolidate: bool,
    /// Sub-config for the pattern store.
    pub pattern_config: PatternStoreConfig,
    /// Sub-config for the pattern consolidator.
    pub consolidation_config: ConsolidationConfig,
    /// Sub-config for the memory distiller.
    pub distillation_config: DistillationConfig,
}
impl Default for ReasoningBankConfig {
    /// Sensible defaults: 768-dim embeddings, HNSW parameters
    /// (ef_construction=200, ef_search=100, m=32), up to 100k trajectories,
    /// hourly auto-consolidation, and default sub-configs.
    fn default() -> Self {
        Self {
            storage_path: ".ruvllm/reasoning_bank".to_string(),
            embedding_dim: 768,
            ef_construction: 200,
            ef_search: 100,
            m: 32,
            max_trajectories: 100_000,
            min_quality_threshold: 0.3,
            // 3600 s = 1 hour between automatic consolidation passes.
            consolidation_interval_secs: 3600,
            auto_consolidate: true,
            pattern_config: PatternStoreConfig::default(),
            consolidation_config: ConsolidationConfig::default(),
            distillation_config: DistillationConfig::default(),
        }
    }
}
/// Central store for agent reasoning experience: records trajectories,
/// extracts/searches patterns, and periodically consolidates and distills
/// stored memory.
///
/// Interior mutability via `parking_lot::RwLock` behind `Arc` lets a single
/// bank be shared across threads; all public methods take `&self`.
pub struct ReasoningBank {
    /// Immutable configuration captured at construction.
    config: ReasoningBankConfig,
    /// Vector-searchable store of extracted patterns.
    pattern_store: Arc<RwLock<PatternStore>>,
    /// Classifies trajectory outcomes (success / failure / recovered).
    verdict_analyzer: VerdictAnalyzer,
    /// Merges and prunes patterns during consolidation.
    consolidator: PatternConsolidator,
    /// Compresses trajectories into key lessons during distillation.
    distiller: MemoryDistiller,
    /// Raw retained trajectories, bounded by `config.max_trajectories`.
    trajectories: Arc<RwLock<Vec<Trajectory>>>,
    /// Running counters and timestamps; see [`ReasoningBankStats`].
    stats: Arc<RwLock<ReasoningBankStats>>,
}
/// Aggregate counters and timestamps describing bank activity.
///
/// Updated incrementally by `ReasoningBank` methods; snapshot via
/// `ReasoningBank::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReasoningBankStats {
    /// Total trajectories ever stored (monotonic, not reduced by distillation).
    pub total_trajectories: u64,
    /// Patterns stored via trajectory extraction and imports.
    pub total_patterns: u64,
    /// Trajectories with `Verdict::Success`.
    pub success_count: u64,
    /// Trajectories with `Verdict::Failure(_)`.
    pub failure_count: u64,
    /// Trajectories with `Verdict::RecoveredViaReflection { .. }`.
    pub recovered_count: u64,
    /// Completed consolidation passes.
    pub consolidation_count: u64,
    /// Completed distillation passes.
    pub distillation_count: u64,
    /// Running mean of trajectory quality over all stored trajectories.
    pub avg_quality: f32,
    /// Unix timestamp (secs) of the last consolidation; 0 if never run.
    pub last_consolidation: u64,
    /// Unix timestamp (secs) of the last distillation; 0 if never run.
    pub last_distillation: u64,
}
impl ReasoningBank {
    /// Constructs a bank from `config`, building the pattern store, verdict
    /// analyzer, consolidator, and distiller from their sub-configs.
    ///
    /// # Errors
    /// Propagates any error from `PatternStore::new`.
    pub fn new(config: ReasoningBankConfig) -> Result<Self> {
        let pattern_store = PatternStore::new(config.pattern_config.clone())?;
        let verdict_analyzer = VerdictAnalyzer::new();
        let consolidator = PatternConsolidator::new(config.consolidation_config.clone());
        let distiller = MemoryDistiller::new(config.distillation_config.clone());
        Ok(Self {
            config,
            pattern_store: Arc::new(RwLock::new(pattern_store)),
            verdict_analyzer,
            consolidator,
            distiller,
            trajectories: Arc::new(RwLock::new(Vec::new())),
            stats: Arc::new(RwLock::new(ReasoningBankStats::default())),
        })
    }
    /// Begins recording a new trajectory for the given query embedding.
    /// The recorder is detached from the bank; pass the completed trajectory
    /// back via [`Self::store_trajectory`].
    pub fn start_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryRecorder {
        TrajectoryRecorder::new(query_embedding)
    }
    /// Stores a completed trajectory: updates stats, appends it to the
    /// retained list (distilling if over `max_trajectories`), and — when its
    /// quality meets `min_quality_threshold` — extracts and stores a pattern.
    ///
    /// Locks are taken one at a time in short scopes (stats, then
    /// trajectories, then pattern store + stats) and never nested across
    /// the three `RwLock`s except pattern-store-then-stats at the end.
    ///
    /// # Errors
    /// Propagates errors from `distill` or `store_pattern`.
    pub fn store_trajectory(&self, trajectory: Trajectory) -> Result<()> {
        {
            let mut stats = self.stats.write();
            stats.total_trajectories += 1;
            match &trajectory.verdict {
                Verdict::Success => stats.success_count += 1,
                Verdict::Failure(_) => stats.failure_count += 1,
                Verdict::RecoveredViaReflection { .. } => stats.recovered_count += 1,
                // Other verdict variants (if any) are counted in the total only.
                _ => {}
            }
            // Incremental running mean: avg' = avg*(n-1)/n + q/n.
            // n >= 1 here because total_trajectories was just incremented.
            let n = stats.total_trajectories as f32;
            stats.avg_quality = stats.avg_quality * ((n - 1.0) / n) + trajectory.quality / n;
        }
        {
            let mut trajectories = self.trajectories.write();
            trajectories.push(trajectory.clone());
            if trajectories.len() > self.config.max_trajectories {
                // Release the write lock before distill(), which re-acquires
                // it internally — holding it here would self-deadlock.
                drop(trajectories);
                self.distill()?;
            }
        }
        if trajectory.quality >= self.config.min_quality_threshold {
            let pattern = Pattern::from_trajectory(&trajectory);
            let mut store = self.pattern_store.write();
            store.store_pattern(pattern)?;
            let mut stats = self.stats.write();
            stats.total_patterns += 1;
        }
        Ok(())
    }
    /// Analyzes a trajectory's outcome with the bank's verdict analyzer.
    pub fn analyze_verdict(&self, trajectory: &Trajectory) -> VerdictAnalysis {
        self.verdict_analyzer.analyze(trajectory)
    }
    /// Returns up to `limit` patterns most similar to `query_embedding`.
    ///
    /// # Errors
    /// Propagates errors from the underlying pattern store search.
    pub fn search_similar(
        &self,
        query_embedding: &[f32],
        limit: usize,
    ) -> Result<Vec<PatternSearchResult>> {
        let store = self.pattern_store.read();
        store.search_similar(query_embedding, limit)
    }
    /// Returns up to `limit` patterns in the given category.
    ///
    /// # Errors
    /// Propagates errors from the underlying pattern store lookup.
    pub fn search_by_category(
        &self,
        category: PatternCategory,
        limit: usize,
    ) -> Result<Vec<Pattern>> {
        let store = self.pattern_store.read();
        store.get_by_category(category, limit)
    }
    /// Runs one consolidation pass: merges/prunes patterns per the
    /// consolidator's decision, removes the affected patterns from the store,
    /// and records the pass in stats.
    ///
    /// Holds the pattern-store write lock for the whole pass; the stats lock
    /// is taken only after it is released.
    ///
    /// # Errors
    /// Propagates errors from pattern retrieval, consolidation, or removal.
    pub fn consolidate(&self) -> Result<ConsolidationResult> {
        let mut store = self.pattern_store.write();
        let patterns = store.get_all_patterns()?;
        let result = self.consolidator.consolidate_patterns(&patterns)?;
        for pattern_id in &result.merged_pattern_ids {
            store.remove_pattern(*pattern_id)?;
        }
        for pattern_id in &result.pruned_pattern_ids {
            store.remove_pattern(*pattern_id)?;
        }
        {
            let mut stats = self.stats.write();
            stats.consolidation_count += 1;
            // Wall-clock seconds since Unix epoch; 0 if the clock is before
            // the epoch (unwrap_or_default on the SystemTimeError path).
            stats.last_consolidation = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
        }
        Ok(result)
    }
    /// Runs one distillation pass: drains all retained trajectories, extracts
    /// key lessons, re-inserts compressed stand-ins for the trajectories,
    /// stores each lesson as a pattern, and records the pass in stats.
    ///
    /// NOTE(review): the trajectories lock is released between the drain and
    /// the re-insert, so trajectories stored concurrently in that window are
    /// kept alongside the compressed ones — appears intentional; confirm.
    ///
    /// # Errors
    /// Propagates errors from lesson extraction or pattern storage.
    pub fn distill(&self) -> Result<DistillationResult> {
        let trajectories = {
            let mut traj = self.trajectories.write();
            // Take ownership of the whole Vec, leaving an empty one in place.
            std::mem::take(&mut *traj)
        };
        let result = self.distiller.extract_key_lessons(&trajectories)?;
        {
            let mut traj = self.trajectories.write();
            for compressed in &result.compressed_trajectories {
                let minimal = Trajectory::from_compressed(compressed);
                traj.push(minimal);
            }
        }
        {
            let mut store = self.pattern_store.write();
            for lesson in &result.key_lessons {
                let pattern = Pattern::from_lesson(lesson);
                store.store_pattern(pattern)?;
            }
        }
        {
            let mut stats = self.stats.write();
            stats.distillation_count += 1;
            stats.last_distillation = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
        }
        Ok(result)
    }
    /// Removes patterns below `min_quality`; returns how many were removed.
    pub fn prune_low_quality(&self, min_quality: f32) -> Result<usize> {
        let mut store = self.pattern_store.write();
        store.prune_low_quality(min_quality)
    }
    /// Merges patterns whose similarity exceeds the threshold; returns the
    /// number of merges performed.
    pub fn merge_similar_patterns(&self, similarity_threshold: f32) -> Result<usize> {
        let mut store = self.pattern_store.write();
        store.merge_similar(similarity_threshold)
    }
    /// Returns a snapshot of the bank-level statistics.
    pub fn stats(&self) -> ReasoningBankStats {
        self.stats.read().clone()
    }
    /// Returns statistics from the underlying pattern store.
    pub fn pattern_stats(&self) -> PatternStats {
        self.pattern_store.read().stats()
    }
    /// Returns the configuration this bank was built with.
    pub fn config(&self) -> &ReasoningBankConfig {
        &self.config
    }
    /// Returns a copy of every stored pattern (for export/backup).
    pub fn export_patterns(&self) -> Result<Vec<Pattern>> {
        let store = self.pattern_store.read();
        store.get_all_patterns()
    }
    /// Imports patterns, skipping any that fail to store; returns the count
    /// actually imported. `total_patterns` is bumped by that count only.
    ///
    /// NOTE(review): per-pattern store errors are silently dropped
    /// (best-effort import) — appears intentional.
    pub fn import_patterns(&self, patterns: Vec<Pattern>) -> Result<usize> {
        let mut store = self.pattern_store.write();
        let mut imported = 0;
        for pattern in patterns {
            if store.store_pattern(pattern).is_ok() {
                imported += 1;
            }
        }
        let mut stats = self.stats.write();
        stats.total_patterns += imported as u64;
        Ok(imported)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must match the documented HNSW/embedding values.
    #[test]
    fn test_reasoning_bank_config_default() {
        let cfg = ReasoningBankConfig::default();
        assert_eq!(cfg.embedding_dim, 768);
        assert_eq!(cfg.ef_construction, 200);
        assert_eq!(cfg.ef_search, 100);
        assert_eq!(cfg.m, 32);
    }

    /// A bank built from a valid config should construct without error.
    #[test]
    fn test_reasoning_bank_creation() {
        let cfg = ReasoningBankConfig {
            storage_path: "/tmp/test_reasoning_bank".to_string(),
            ..Default::default()
        };
        assert!(ReasoningBank::new(cfg).is_ok());
    }

    /// Recording a step and completing should yield a non-empty trajectory.
    #[test]
    fn test_trajectory_recording() {
        let bank = ReasoningBank::new(ReasoningBankConfig::default()).unwrap();
        let mut rec = bank.start_trajectory(vec![0.1; 768]);
        rec.add_step(
            "analyze".to_string(),
            "Need to understand the problem".to_string(),
            StepOutcome::Success,
            0.9,
        );
        let traj = rec.complete(Verdict::Success);
        assert!(!traj.steps.is_empty());
    }

    /// A fresh bank starts with zeroed counters.
    #[test]
    fn test_stats_tracking() {
        let bank = ReasoningBank::new(ReasoningBankConfig::default()).unwrap();
        assert_eq!(bank.stats().total_trajectories, 0);
    }
}