use std::collections::HashMap;
use anyhow::Result;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::embedding::{EmbeddingProvider, EmbeddingVector};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuidancePattern {
pub id: String,
pub strategy: String,
pub domain: String,
#[serde(skip)]
pub embedding: Option<EmbeddingVector>,
pub quality: f32,
pub usage_count: u32,
pub success_count: u32,
pub created_at: DateTime<Utc>,
pub is_long_term: bool,
}
impl GuidancePattern {
pub fn success_rate(&self) -> f32 {
if self.usage_count == 0 {
0.0
} else {
self.success_count as f32 / self.usage_count as f32
}
}
pub fn combined_score(&self) -> f32 {
let recency = {
let age_hours = (Utc::now() - self.created_at).num_hours().max(1) as f32;
1.0 / (1.0 + age_hours * 0.01) };
self.quality * self.success_rate().max(0.1) * recency
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternMatch {
pub pattern: GuidancePattern,
pub similarity: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
pub agent: String,
pub confidence: f32,
pub reasoning: String,
}
struct RoutingEntry {
keywords: &'static [&'static str],
agent: &'static str,
}
impl RoutingEntry {
fn default_table() -> &'static [RoutingEntry] {
&[
RoutingEntry {
keywords: &["security", "auth", "password", "token", "vulnerability", "csrf", "xss", "injection"],
agent: "security-auditor",
},
RoutingEntry {
keywords: &["test", "spec", "mock", "coverage", "unit test", "integration test"],
agent: "tester",
},
RoutingEntry {
keywords: &["perf", "optimize", "slow", "memory leak", "latency", "throughput", "benchmark"],
agent: "performance-engineer",
},
RoutingEntry {
keywords: &["fix", "bug", "error", "debug", "crash", "traceback", "panic"],
agent: "researcher",
},
RoutingEntry {
keywords: &["refactor", "architect", "design", "restructure", "reorganize"],
agent: "system-architect",
},
RoutingEntry {
keywords: &["deploy", "ci", "cd", "pipeline", "release", "build"],
agent: "devops",
},
RoutingEntry {
keywords: &["document", "docs", "readme", "comment", "explain"],
agent: "documenter",
},
]
}
}
pub struct ReasoningBank {
short_term: RwLock<HashMap<String, GuidancePattern>>,
long_term: RwLock<HashMap<String, GuidancePattern>>,
embedding: std::sync::Arc<dyn EmbeddingProvider>,
}
impl std::fmt::Debug for ReasoningBank {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ReasoningBank")
.field("short_term_count", &self.short_term.read().len())
.field("long_term_count", &self.long_term.read().len())
.finish()
}
}
impl ReasoningBank {
pub fn new(embedding: std::sync::Arc<dyn EmbeddingProvider>) -> Self {
Self {
short_term: RwLock::new(HashMap::new()),
long_term: RwLock::new(HashMap::new()),
embedding,
}
}
pub async fn store_pattern(&self, mut pattern: GuidancePattern) -> Result<String> {
if pattern.id.is_empty() {
pattern.id = Uuid::new_v4().to_string();
}
let embedding = self.embedding.embed(&pattern.strategy).await?;
pattern.embedding = Some(embedding);
let id = pattern.id.clone();
self.short_term.write().insert(id.clone(), pattern);
tracing::debug!(id = %id, domain = %self.short_term.read().get(&id).map(|p| p.domain.clone()).unwrap_or_default(), "Pattern stored");
Ok(id)
}
pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<PatternMatch>> {
let query_embedding = self.embedding.embed(query).await?;
let mut matches = Vec::new();
{
let bank = self.short_term.read();
for pattern in bank.values() {
if let Some(ref emb) = pattern.embedding {
let sim = query_embedding.cosine_similarity(emb);
if sim > 0.1 {
matches.push(PatternMatch {
pattern: pattern.clone(),
similarity: sim,
});
}
}
}
}
{
let bank = self.long_term.read();
for pattern in bank.values() {
if let Some(ref emb) = pattern.embedding {
let sim = query_embedding.cosine_similarity(emb);
if sim > 0.1 {
matches.push(PatternMatch {
pattern: pattern.clone(),
similarity: sim,
});
}
}
}
}
matches.sort_by(|a, b| {
let score_a = a.similarity * (a.pattern.combined_score() as f64);
let score_b = b.similarity * (b.pattern.combined_score() as f64);
score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
});
matches.truncate(limit);
Ok(matches)
}
pub async fn search_by_domain(
&self,
query: &str,
domain: &str,
limit: usize,
) -> Result<Vec<PatternMatch>> {
let all = self.search(query, limit * 3).await?;
let filtered: Vec<_> = all
.into_iter()
.filter(|m| m.pattern.domain == domain)
.take(limit)
.collect();
Ok(filtered)
}
pub async fn route_task(&self, task: &str) -> Result<RoutingResult> {
let task_lower = task.to_lowercase();
let mut best_agent = "coder";
let mut best_keyword_count = 0usize;
for entry in RoutingEntry::default_table() {
let count = entry
.keywords
.iter()
.filter(|kw| task_lower.contains(&kw.to_lowercase()))
.count();
if count > best_keyword_count {
best_keyword_count = count;
best_agent = entry.agent;
}
}
let pattern_matches = self.search(task, 5).await.unwrap_or_default();
let (agent, confidence, reasoning) = if !pattern_matches.is_empty() {
let best_match = &pattern_matches[0];
let pattern_confidence = (best_match.similarity * best_match.pattern.quality as f64) as f32;
if pattern_confidence > 0.7 && best_match.pattern.domain != best_agent {
let pattern_agent = domain_to_agent(&best_match.pattern.domain);
if pattern_confidence > (best_keyword_count as f32 * 0.2).min(1.0) {
(
pattern_agent.to_string(),
pattern_confidence,
format!(
"Pattern '{}' (domain: {}, quality: {:.2}) suggests {} agent",
best_match.pattern.strategy,
best_match.pattern.domain,
best_match.pattern.quality,
pattern_agent,
),
)
} else {
(
best_agent.to_string(),
(best_keyword_count as f32 * 0.25).min(1.0),
format!(
"Keyword routing to {} ({} keyword matches), pattern alternative: {}",
best_agent,
best_keyword_count,
best_match.pattern.strategy,
),
)
}
} else {
(
best_agent.to_string(),
(best_keyword_count as f32 * 0.25).min(0.9),
format!(
"Keyword routing to {} ({} matches)",
best_agent, best_keyword_count,
),
)
}
} else {
(
best_agent.to_string(),
if best_keyword_count > 0 {
(best_keyword_count as f32 * 0.25).min(0.9)
} else {
0.5 },
format!(
"Default keyword routing to {} ({} matches, no patterns)",
best_agent, best_keyword_count,
),
)
};
Ok(RoutingResult {
agent,
confidence,
reasoning,
})
}
pub fn promote(&self, pattern_id: &str, min_quality: f32) -> Result<bool> {
let mut short = self.short_term.write();
if let Some(pattern) = short.remove(pattern_id) {
if pattern.quality >= min_quality {
let mut long = self.long_term.write();
long.insert(pattern_id.to_string(), pattern);
tracing::info!(id = %pattern_id, "Pattern promoted to long-term");
Ok(true)
} else {
short.insert(pattern_id.to_string(), pattern);
tracing::debug!(id = %pattern_id, quality = %short.get(pattern_id).map(|p| p.quality).unwrap_or(0.0), "Pattern not promoted (quality too low)");
Ok(false)
}
} else {
let long = self.long_term.read();
Ok(long.contains_key(pattern_id))
}
}
pub fn auto_promote(&self, min_quality: f32, min_usage: u32) -> usize {
let mut short = self.short_term.write();
let mut long = self.long_term.write();
let candidates: Vec<String> = short
.iter()
.filter(|(_, p)| p.quality >= min_quality && p.usage_count >= min_usage)
.map(|(id, _)| id.clone())
.collect();
let count = candidates.len();
for id in candidates {
if let Some(pattern) = short.remove(&id) {
long.insert(id, pattern);
}
}
if count > 0 {
tracing::info!(promoted = count, "Auto-promoted patterns");
}
count
}
pub fn record_usage(&self, pattern_id: &str, success: bool) {
{
let mut short = self.short_term.write();
if let Some(pattern) = short.get_mut(pattern_id) {
pattern.usage_count += 1;
if success {
pattern.success_count += 1;
}
return;
}
}
let mut long = self.long_term.write();
if let Some(pattern) = long.get_mut(pattern_id) {
pattern.usage_count += 1;
if success {
pattern.success_count += 1;
}
}
}
pub fn counts(&self) -> (usize, usize) {
let short = self.short_term.read().len();
let long = self.long_term.read().len();
(short, long)
}
pub fn all_patterns(&self) -> Vec<GuidancePattern> {
let short = self.short_term.read();
let long = self.long_term.read();
short.values().chain(long.values()).cloned().collect()
}
pub fn load_patterns(&self, patterns: Vec<GuidancePattern>) {
let mut short = self.short_term.write();
let mut long = self.long_term.write();
for pattern in patterns {
if pattern.is_long_term {
long.insert(pattern.id.clone(), pattern);
} else {
short.insert(pattern.id.clone(), pattern);
}
}
}
}
fn domain_to_agent(domain: &str) -> &'static str {
match domain {
"security" => "security-auditor",
"testing" => "tester",
"performance" => "performance-engineer",
"debugging" => "researcher",
"architecture" => "system-architect",
"devops" => "devops",
"documentation" => "documenter",
_ => "coder",
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::embedding::TfIdfEmbeddingProvider;
fn make_pattern(domain: &str, strategy: &str, quality: f32) -> GuidancePattern {
GuidancePattern {
id: Uuid::new_v4().to_string(),
strategy: strategy.to_string(),
domain: domain.to_string(),
embedding: None,
quality,
usage_count: 1,
success_count: 1,
created_at: Utc::now(),
is_long_term: false,
}
}
#[tokio::test]
async fn test_store_and_search() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let pattern = make_pattern("security", "Use parameterized queries to prevent SQL injection", 0.95);
let id = bank.store_pattern(pattern).await.unwrap();
let results = bank.search("SQL injection prevention", 10).await.unwrap();
assert!(!results.is_empty(), "Should find the security pattern");
assert_eq!(results[0].pattern.id, id);
}
#[tokio::test]
async fn test_route_task_security() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let result = bank.route_task("Fix the authentication security vulnerability").await.unwrap();
assert_eq!(result.agent, "security-auditor");
assert!(result.confidence > 0.0);
}
#[tokio::test]
async fn test_route_task_testing() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let result = bank.route_task("Write unit test coverage for the module").await.unwrap();
assert_eq!(result.agent, "tester");
}
#[tokio::test]
async fn test_route_task_default() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let result = bank.route_task("implement a new feature").await.unwrap();
assert_eq!(result.agent, "coder");
}
#[tokio::test]
async fn test_promote_pattern() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let pattern = make_pattern("testing", "Use property-based testing for parsers", 0.9);
let id = bank.store_pattern(pattern).await.unwrap();
let promoted = bank.promote(&id, 0.8).unwrap();
assert!(promoted, "High-quality pattern should be promoted");
let (short, long) = bank.counts();
assert_eq!(short, 0);
assert_eq!(long, 1);
}
#[tokio::test]
async fn test_promote_below_threshold() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let pattern = make_pattern("testing", "Try running tests twice", 0.3);
let id = bank.store_pattern(pattern).await.unwrap();
let promoted = bank.promote(&id, 0.8).unwrap();
assert!(!promoted, "Low-quality pattern should not be promoted");
let (short, long) = bank.counts();
assert_eq!(short, 1);
assert_eq!(long, 0);
}
#[tokio::test]
async fn test_auto_promote() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let mut p1 = make_pattern("security", "Scan dependencies for CVEs", 0.9);
p1.usage_count = 5;
bank.store_pattern(p1).await.unwrap();
let p2 = make_pattern("testing", "Run tests once", 0.3);
bank.store_pattern(p2).await.unwrap();
let count = bank.auto_promote(0.8, 3);
assert_eq!(count, 1);
let (short, long) = bank.counts();
assert_eq!(short, 1);
assert_eq!(long, 1);
}
#[tokio::test]
async fn test_record_usage() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
let pattern = make_pattern("security", "Use prepared statements", 0.8);
let id = bank.store_pattern(pattern).await.unwrap();
bank.record_usage(&id, true);
bank.record_usage(&id, false);
bank.record_usage(&id, true);
let patterns = bank.all_patterns();
let found = patterns.iter().find(|p| p.id == id).unwrap();
assert_eq!(found.usage_count, 4); assert_eq!(found.success_count, 3); }
#[tokio::test]
async fn test_search_by_domain() {
let bank = ReasoningBank::new(std::sync::Arc::new(TfIdfEmbeddingProvider));
bank.store_pattern(make_pattern("security", "Use HTTPS for all connections and encrypt data in transit", 0.9)).await.unwrap();
bank.store_pattern(make_pattern("testing", "Write integration tests for API endpoints", 0.85)).await.unwrap();
bank.store_pattern(make_pattern("security", "Validate all user inputs to prevent injection attacks", 0.88)).await.unwrap();
let all_results = bank.search("security", 10).await.unwrap();
let results = bank.search_by_domain("security", "security", 10).await.unwrap();
for m in &results {
assert_eq!(m.pattern.domain, "security");
}
}
#[test]
fn test_domain_to_agent() {
assert_eq!(domain_to_agent("security"), "security-auditor");
assert_eq!(domain_to_agent("testing"), "tester");
assert_eq!(domain_to_agent("performance"), "performance-engineer");
assert_eq!(domain_to_agent("unknown"), "coder");
}
#[test]
fn test_guidance_pattern_combined_score() {
let pattern = make_pattern("security", "test", 0.9);
let score = pattern.combined_score();
assert!(score > 0.0 && score <= 1.0);
}
#[test]
fn test_guidance_pattern_success_rate() {
let mut pattern = make_pattern("security", "test", 0.9);
pattern.usage_count = 10;
pattern.success_count = 8;
assert!((pattern.success_rate() - 0.8).abs() < 0.01);
}
}