use super::patterns::LearnedPattern;
use dashmap::DashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::SystemTime;
#[derive(Debug, Clone)]
struct PatternEntry {
pattern: LearnedPattern,
usage_count: usize,
last_used: SystemTime,
}
pub struct ReasoningBank {
patterns: DashMap<usize, PatternEntry>,
next_id: AtomicUsize,
}
impl ReasoningBank {
pub fn new() -> Self {
Self {
patterns: DashMap::new(),
next_id: AtomicUsize::new(0),
}
}
pub fn store(&self, pattern: LearnedPattern) -> usize {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let entry = PatternEntry {
pattern,
usage_count: 0,
last_used: SystemTime::now(),
};
self.patterns.insert(id, entry);
id
}
pub fn lookup(&self, query: &[f32], k: usize) -> Vec<(usize, LearnedPattern, f64)> {
let mut similarities: Vec<(usize, LearnedPattern, f64)> = self
.patterns
.iter()
.map(|entry| {
let id = *entry.key();
let pattern = &entry.value().pattern;
let similarity = pattern.similarity(query);
(id, pattern.clone(), similarity)
})
.collect();
similarities.sort_by(|a, b| {
let score_a = a.2 * a.1.confidence;
let score_b = b.2 * b.1.confidence;
score_b.partial_cmp(&score_a).unwrap()
});
similarities.truncate(k);
for (id, _, _) in &similarities {
if let Some(mut entry) = self.patterns.get_mut(id) {
entry.usage_count += 1;
entry.last_used = SystemTime::now();
}
}
similarities
}
pub fn get(&self, id: usize) -> Option<LearnedPattern> {
self.patterns.get_mut(&id).map(|mut entry| {
entry.usage_count += 1;
entry.last_used = SystemTime::now();
entry.pattern.clone()
})
}
pub fn consolidate(&self, similarity_threshold: f64) -> usize {
let patterns: Vec<(usize, LearnedPattern)> = self
.patterns
.iter()
.map(|entry| (*entry.key(), entry.value().pattern.clone()))
.collect();
if patterns.len() < 2 {
return 0;
}
let mut to_remove = Vec::new();
let mut merged = 0;
for i in 0..patterns.len() {
if to_remove.contains(&patterns[i].0) {
continue;
}
for j in (i + 1)..patterns.len() {
if to_remove.contains(&patterns[j].0) {
continue;
}
let sim = patterns[i].1.similarity(&patterns[j].1.centroid);
if sim >= similarity_threshold {
if let Some(mut entry_i) = self.patterns.get_mut(&patterns[i].0) {
if let Some(entry_j) = self.patterns.get(&patterns[j].0) {
let total_samples =
entry_i.pattern.sample_count + entry_j.pattern.sample_count;
let weight_i =
entry_i.pattern.sample_count as f64 / total_samples as f64;
let weight_j =
entry_j.pattern.sample_count as f64 / total_samples as f64;
for k in 0..entry_i.pattern.centroid.len() {
entry_i.pattern.centroid[k] = (entry_i.pattern.centroid[k] as f64
* weight_i
+ entry_j.pattern.centroid[k] as f64 * weight_j)
as f32;
}
entry_i.pattern.optimal_ef = (entry_i.pattern.optimal_ef as f64
* weight_i
+ entry_j.pattern.optimal_ef as f64 * weight_j)
as usize;
entry_i.pattern.optimal_probes = (entry_i.pattern.optimal_probes as f64
* weight_i
+ entry_j.pattern.optimal_probes as f64 * weight_j)
as usize;
entry_i.pattern.sample_count += entry_j.pattern.sample_count;
entry_i.pattern.avg_latency_us = entry_i.pattern.avg_latency_us
* weight_i
+ entry_j.pattern.avg_latency_us * weight_j;
entry_i.pattern.confidence = (entry_i.pattern.confidence * weight_i
+ entry_j.pattern.confidence * weight_j)
.min(1.0);
entry_i.usage_count += entry_j.usage_count;
}
}
to_remove.push(patterns[j].0);
merged += 1;
}
}
}
for id in to_remove {
self.patterns.remove(&id);
}
merged
}
pub fn prune(&self, min_usage: usize, min_confidence: f64) -> usize {
let to_remove: Vec<usize> = self
.patterns
.iter()
.filter(|entry| {
entry.value().usage_count < min_usage
|| entry.value().pattern.confidence < min_confidence
})
.map(|entry| *entry.key())
.collect();
let count = to_remove.len();
for id in to_remove {
self.patterns.remove(&id);
}
count
}
pub fn len(&self) -> usize {
self.patterns.len()
}
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
pub fn stats(&self) -> BankStats {
if self.patterns.is_empty() {
return BankStats::default();
}
let total = self.patterns.len();
let total_samples: usize = self
.patterns
.iter()
.map(|e| e.value().pattern.sample_count)
.sum();
let avg_confidence: f64 = self
.patterns
.iter()
.map(|e| e.value().pattern.confidence)
.sum::<f64>()
/ total as f64;
let total_usage: usize = self.patterns.iter().map(|e| e.value().usage_count).sum();
BankStats {
total_patterns: total,
total_samples,
avg_confidence,
total_usage,
}
}
pub fn clear(&self) {
self.patterns.clear();
self.next_id.store(0, Ordering::SeqCst);
}
}
impl Default for ReasoningBank {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, Default)]
pub struct BankStats {
pub total_patterns: usize,
pub total_samples: usize,
pub avg_confidence: f64,
pub total_usage: usize,
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_pattern(centroid: Vec<f32>, ef: usize) -> LearnedPattern {
LearnedPattern::new(centroid, ef, 10, 0.9, 100, 1000.0, Some(0.95))
}
#[test]
fn test_store_and_lookup() {
let bank = ReasoningBank::new();
let pattern1 = create_test_pattern(vec![1.0, 0.0, 0.0], 50);
let pattern2 = create_test_pattern(vec![0.0, 1.0, 0.0], 60);
bank.store(pattern1);
bank.store(pattern2);
assert_eq!(bank.len(), 2);
let query = vec![0.9, 0.1, 0.0];
let results = bank.lookup(&query, 2);
assert_eq!(results.len(), 2);
assert!(results[0].2 > results[1].2); }
#[test]
fn test_consolidate() {
let bank = ReasoningBank::new();
let pattern1 = create_test_pattern(vec![1.0, 0.0], 50);
let pattern2 = create_test_pattern(vec![0.99, 0.01], 50);
let pattern3 = create_test_pattern(vec![0.0, 1.0], 60);
bank.store(pattern1);
bank.store(pattern2);
bank.store(pattern3);
assert_eq!(bank.len(), 3);
let merged = bank.consolidate(0.95);
assert!(merged > 0);
assert!(bank.len() < 3);
}
#[test]
fn test_prune() {
let bank = ReasoningBank::new();
let mut pattern_low_conf = create_test_pattern(vec![1.0, 0.0], 50);
pattern_low_conf.confidence = 0.3;
bank.store(pattern_low_conf);
bank.store(create_test_pattern(vec![0.0, 1.0], 60));
assert_eq!(bank.len(), 2);
let pruned = bank.prune(0, 0.5);
assert_eq!(pruned, 1);
assert_eq!(bank.len(), 1);
}
#[test]
fn test_stats() {
let bank = ReasoningBank::new();
bank.store(create_test_pattern(vec![1.0], 50));
bank.store(create_test_pattern(vec![2.0], 60));
let stats = bank.stats();
assert_eq!(stats.total_patterns, 2);
assert_eq!(stats.total_samples, 200);
assert_eq!(stats.avg_confidence, 0.9);
}
}