use std::collections::HashMap;
/// A single stored pattern: an embedding vector plus the bookkeeping the bank
/// keeps alongside it.
#[derive(Clone, Debug)]
pub struct DagPattern {
    /// Identifier handed out by the bank when the pattern is stored.
    pub id: u64,
    /// The pattern's embedding; compared against queries and centroids.
    pub vector: Vec<f32>,
    /// Caller-supplied quality; feeds the bank's eviction score.
    pub quality_score: f32,
    /// Usage counter; starts at 0 on store and feeds the eviction score.
    pub usage_count: usize,
    /// Free-form string key/value annotations; empty when first stored.
    pub metadata: HashMap<String, String>,
}
/// Limits and thresholds that govern a reasoning bank.
#[derive(Clone, Debug)]
pub struct ReasoningBankConfig {
    /// Upper bound on the number of clusters built when clustering; the bank
    /// uses fewer when it holds fewer patterns.
    pub num_clusters: usize,
    /// Nominal length of pattern vectors. NOTE(review): nothing in the code
    /// shown here reads or enforces this value — confirm against callers.
    pub pattern_dim: usize,
    /// Number of patterns the bank may hold before it starts evicting.
    pub max_patterns: usize,
    /// Minimum cosine similarity a pattern needs to appear in query results.
    pub similarity_threshold: f32,
}

impl Default for ReasoningBankConfig {
    /// Stock settings: 100 clusters, 256-dimensional patterns, room for
    /// 10,000 patterns and a similarity threshold of 0.7.
    fn default() -> Self {
        ReasoningBankConfig {
            similarity_threshold: 0.7,
            max_patterns: 10_000,
            pattern_dim: 256,
            num_clusters: 100,
        }
    }
}
/// In-memory store of [`DagPattern`]s with similarity search and clustering.
pub struct DagReasoningBank {
    /// Settings supplied when the bank was constructed.
    config: ReasoningBankConfig,
    /// Stored patterns, in the order they were added (evicted ones removed).
    patterns: Vec<DagPattern>,
    /// Cluster centres from the most recent clustering run; empty before one.
    centroids: Vec<Vec<f32>>,
    /// Centroid index chosen for each pattern during the most recent
    /// clustering run.
    cluster_assignments: Vec<usize>,
    /// Id that will be given to the next stored pattern.
    next_id: u64,
}
impl DagReasoningBank {
pub fn new(config: ReasoningBankConfig) -> Self {
Self {
config,
patterns: Vec::new(),
centroids: Vec::new(),
cluster_assignments: Vec::new(),
next_id: 0,
}
}
pub fn store_pattern(&mut self, vector: Vec<f32>, quality: f32) -> u64 {
let id = self.next_id;
self.next_id += 1;
let pattern = DagPattern {
id,
vector,
quality_score: quality,
usage_count: 0,
metadata: HashMap::new(),
};
self.patterns.push(pattern);
if self.patterns.len() > self.config.max_patterns {
self.evict_lowest_quality();
}
id
}
pub fn query_similar(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
let mut similarities: Vec<(u64, f32)> = self
.patterns
.iter()
.map(|p| (p.id, cosine_similarity(&p.vector, query)))
.filter(|(_, sim)| *sim >= self.config.similarity_threshold)
.collect();
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
similarities.truncate(k);
similarities
}
pub fn recompute_clusters(&mut self) {
if self.patterns.is_empty() {
return;
}
let k = self.config.num_clusters.min(self.patterns.len());
self.centroids = kmeans_pp_init(&self.patterns, k);
for _ in 0..10 {
self.cluster_assignments = self
.patterns
.iter()
.map(|p| self.nearest_centroid(&p.vector))
.collect();
self.update_centroids();
}
}
fn nearest_centroid(&self, point: &[f32]) -> usize {
self.centroids
.iter()
.enumerate()
.map(|(i, c)| (i, euclidean_distance(point, c)))
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.map(|(i, _)| i)
.unwrap_or(0)
}
fn update_centroids(&mut self) {
let k = self.centroids.len();
let dim = if !self.centroids.is_empty() {
self.centroids[0].len()
} else {
return;
};
let mut new_centroids = vec![vec![0.0; dim]; k];
let mut counts = vec![0usize; k];
for (pattern, &cluster) in self.patterns.iter().zip(self.cluster_assignments.iter()) {
if cluster < k {
for (i, &val) in pattern.vector.iter().enumerate() {
new_centroids[cluster][i] += val;
}
counts[cluster] += 1;
}
}
for (centroid, count) in new_centroids.iter_mut().zip(counts.iter()) {
if *count > 0 {
for val in centroid.iter_mut() {
*val /= *count as f32;
}
}
}
self.centroids = new_centroids;
}
fn evict_lowest_quality(&mut self) {
if let Some(min_idx) = self
.patterns
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let score_a = a.quality_score * (a.usage_count as f32 + 1.0).ln();
let score_b = b.quality_score * (b.usage_count as f32 + 1.0).ln();
score_a.partial_cmp(&score_b).unwrap()
})
.map(|(i, _)| i)
{
self.patterns.remove(min_idx);
}
}
pub fn pattern_count(&self) -> usize {
self.patterns.len()
}
pub fn cluster_count(&self) -> usize {
self.centroids.len()
}
}
/// Cosine of the angle between `a` and `b`.
///
/// Returns 0.0 when either vector's Euclidean norm is not strictly positive
/// (which includes empty input and NaN norms). The dot product pairs elements
/// up to the shorter slice, while each norm is taken over its whole slice.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };
    let (len_a, len_b) = (magnitude(a), magnitude(b));

    // Guard clause: a zero-length (or NaN) norm has no meaningful direction.
    if !(len_a > 0.0 && len_b > 0.0) {
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner / (len_a * len_b)
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired up to the length of the shorter slice; any extra
/// trailing elements of the longer one do not contribute.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(p, q)| (p - q).powi(2));
    let total: f32 = squared_gaps.sum();
    total.sqrt()
}
/// Picks `k` starting centroids from `patterns` using k-means++ seeding.
///
/// The first centroid is a uniformly random pattern. Each further centroid is
/// drawn with probability proportional to the squared Euclidean distance from
/// a pattern to its nearest already-chosen centroid. When no pattern has a
/// usable positive weight (for example all remaining patterns coincide with
/// chosen centroids), the draw falls back to a uniformly random pattern, so
/// duplicate centroids are possible.
///
/// Returns an empty vector when `patterns` is empty or `k` is 0; otherwise
/// exactly `k` centroids, each a clone of some pattern's vector.
fn kmeans_pp_init(patterns: &[DagPattern], k: usize) -> Vec<Vec<f32>> {
    use rand::Rng;

    if patterns.is_empty() || k == 0 {
        return Vec::new();
    }

    let mut rng = rand::thread_rng();
    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
    let first_idx = rng.gen_range(0..patterns.len());
    centroids.push(patterns[first_idx].vector.clone());

    // NaN or infinite squared distances cannot be used as sampling weights;
    // they count as 0 so they neither panic nor poison the running total.
    let sampling_weight = |sq: f32| if sq.is_finite() { sq } else { 0.0 };

    // Squared distance from each pattern to its nearest chosen centroid. It is
    // kept across rounds, so each round only measures against the newest
    // centroid instead of all of them.
    let mut nearest_sq = vec![f32::INFINITY; patterns.len()];

    while centroids.len() < k {
        let mut total = 0.0f32;
        if let Some(newest) = centroids.last() {
            for (nearest, pattern) in nearest_sq.iter_mut().zip(patterns) {
                let dist = euclidean_distance(&pattern.vector, newest);
                // `f32::min` returns the non-NaN operand, so a NaN distance
                // leaves the running minimum untouched.
                *nearest = (*nearest).min(dist * dist);
                total += sampling_weight(*nearest);
            }
        }

        let next_idx = if total > 0.0 && total.is_finite() {
            let mut remaining = rng.gen::<f32>() * total;
            let mut candidate = 0;
            for (idx, &sq) in nearest_sq.iter().enumerate() {
                let weight = sampling_weight(sq);
                if weight <= 0.0 {
                    // Zero weight: this pattern is already a centroid.
                    continue;
                }
                // Remember the last positive-weight pattern: if rounding keeps
                // `remaining` above zero for the whole scan, it is still a
                // valid draw and the round never comes up empty.
                candidate = idx;
                remaining -= weight;
                if remaining <= 0.0 {
                    break;
                }
            }
            candidate
        } else {
            rng.gen_range(0..patterns.len())
        };
        centroids.push(patterns[next_idx].vector.clone());
    }

    centroids
}