use super::{AgentType, ClaudeFlowTask, RoutingDecision};
use crate::error::{Result, RuvLLMError};
use crate::sona::{SonaIntegration, Trajectory};
use dashmap::DashMap;
use parking_lot::RwLock;
use ruvector_core::index::hnsw::HnswIndex;
use ruvector_core::index::VectorIndex;
use ruvector_core::types::{DistanceMetric, HnswConfig, SearchResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Tuning parameters for [`HnswRouter`]: HNSW index construction/search settings plus
/// the scoring and online-learning knobs used when routing.
///
/// NOTE: this struct is embedded in the bincode-encoded router state (`HnswRouterState`),
/// so field order and types are part of the persisted format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswRouterConfig {
/// HNSW `m` parameter, passed straight through to the index's `HnswConfig`.
pub m: usize,
/// HNSW `ef_construction` parameter, passed straight through to the index.
pub ef_construction: usize,
/// HNSW `ef_search` parameter, passed straight through to the index.
pub ef_search: usize,
/// Index capacity, forwarded as `HnswConfig::max_elements`.
pub max_patterns: usize,
/// Distance metric for the index. With `Cosine`, embeddings are L2-normalised before
/// insertion and before searching.
pub distance_metric: HnswDistanceMetric,
/// Required length of every pattern embedding and every query vector.
pub embedding_dim: usize,
/// Confidence reported by routing when no usable neighbour weight is available.
pub min_confidence: f32,
/// Number of nearest neighbours consulted by `route_by_similarity`.
pub top_k: usize,
/// Weight of the newest outcome in a pattern's exponential-moving-average success rate.
pub success_rate_decay: f32,
/// Usage count below which a pattern's confidence is ramped down instead of using its
/// success rate directly. It is also the threshold for inclusion in `avg_success_rate` stats.
pub min_usage_for_trust: u32,
/// When `false`, `learn_pattern` does nothing and returns `Ok(None)`.
pub enable_online_learning: bool,
}
impl Default for HnswRouterConfig {
fn default() -> Self {
Self {
m: 32,
ef_construction: 200,
ef_search: 100,
max_patterns: 100_000,
distance_metric: HnswDistanceMetric::Cosine,
embedding_dim: 384,
min_confidence: 0.5,
top_k: 10,
success_rate_decay: 0.01,
min_usage_for_trust: 5,
enable_online_learning: true,
}
}
}
impl HnswRouterConfig {
    /// Denser graph and wider search beams: better recall at a higher build/search cost.
    pub fn high_recall() -> Self {
        Self {
            top_k: 20,
            ef_search: 200,
            ef_construction: 400,
            m: 48,
            ..Self::default()
        }
    }

    /// Sparser graph and narrower search beams: lower latency at some cost in recall.
    pub fn fast() -> Self {
        Self {
            top_k: 5,
            ef_search: 50,
            ef_construction: 100,
            m: 16,
            ..Self::default()
        }
    }

    /// Defaults sized for 384-dimensional embeddings.
    pub fn for_small_model() -> Self {
        Self::with_embedding_dim(384)
    }

    /// Defaults sized for 768-dimensional embeddings.
    pub fn for_large_model() -> Self {
        Self::with_embedding_dim(768)
    }

    /// Default configuration with only the embedding dimension overridden.
    fn with_embedding_dim(embedding_dim: usize) -> Self {
        Self {
            embedding_dim,
            ..Self::default()
        }
    }
}
/// Distance metric selectable in [`HnswRouterConfig`].
///
/// Each variant converts one-to-one into ruvector's `DistanceMetric` via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HnswDistanceMetric {
/// Maps to `DistanceMetric::Cosine`; the router L2-normalises embeddings for this metric.
Cosine,
/// Maps to `DistanceMetric::Euclidean`; embeddings are indexed as supplied.
Euclidean,
/// Maps to `DistanceMetric::DotProduct`; embeddings are indexed as supplied.
DotProduct,
}
impl From<HnswDistanceMetric> for DistanceMetric {
    /// Maps each router metric onto the ruvector metric of the same name.
    #[inline]
    fn from(metric: HnswDistanceMetric) -> Self {
        match metric {
            HnswDistanceMetric::Cosine => Self::Cosine,
            HnswDistanceMetric::Euclidean => Self::Euclidean,
            HnswDistanceMetric::DotProduct => Self::DotProduct,
        }
    }
}
/// A stored example task used for routing by similarity.
///
/// It holds the task's embedding, the agent and task type it maps to, and running
/// success statistics that weight it during routing. Serialized with serde (and bincode
/// via the router state), so field order is part of the persisted format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPattern {
/// Unique identifier (a random UUID v4 string when created with `TaskPattern::new`).
/// Also used as the vector id in the HNSW index.
pub id: String,
/// Embedding as supplied by the caller. The router normalises a separate copy for the
/// index when the metric is cosine; this field itself is not modified.
pub embedding: Vec<f32>,
/// Agent this pattern recommends.
pub agent_type: AgentType,
/// Task category associated with the pattern.
pub task_type: ClaudeFlowTask,
/// Running success estimate. `update_success` updates it as an exponential moving
/// average of outcomes (1.0 = success, 0.0 = failure). Pattern consolidation may reset
/// it to `success_count / usage_count`.
pub success_rate: f32,
/// Number of recorded outcomes.
pub usage_count: u32,
/// Number of recorded successful outcomes.
pub success_count: u32,
/// Human-readable description of the task.
pub task_description: String,
/// Creation time, taken from `chrono::Utc::now().timestamp()`.
pub created_at: i64,
/// Time of the most recent outcome update, in the same units as `created_at`.
pub last_used_at: i64,
/// Free-form key/value annotations; not consulted by the router logic in this module.
pub metadata: HashMap<String, String>,
}
impl TaskPattern {
    /// Creates a never-used pattern.
    ///
    /// The pattern gets a fresh UUID v4 id, a neutral 0.5 success rate, zeroed counters,
    /// and both timestamps set to "now".
    pub fn new(
        embedding: Vec<f32>,
        agent_type: AgentType,
        task_type: ClaudeFlowTask,
        task_description: String,
    ) -> Self {
        let timestamp = chrono::Utc::now().timestamp();
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            embedding,
            agent_type,
            task_type,
            // Neutral prior: no evidence for or against this pattern yet.
            success_rate: 0.5,
            usage_count: 0,
            success_count: 0,
            task_description,
            created_at: timestamp,
            last_used_at: timestamp,
            metadata: HashMap::new(),
        }
    }

    /// Records one outcome.
    ///
    /// Bumps the counters, refreshes `last_used_at`, and blends the outcome into
    /// `success_rate` with weight `decay`.
    #[inline]
    pub fn update_success(&mut self, success: bool, decay: f32) {
        self.usage_count += 1;
        self.success_count += u32::from(success);
        self.last_used_at = chrono::Utc::now().timestamp();
        let outcome = if success { 1.0 } else { 0.0 };
        self.success_rate = (1.0 - decay) * self.success_rate + decay * outcome;
    }

    /// Trust score for this pattern.
    ///
    /// Returns `success_rate` once it has at least `min_usage` uses. Before that, the score
    /// ramps linearly from 0 to 0.5 with usage.
    #[inline]
    pub fn confidence(&self, min_usage: u32) -> f32 {
        if self.usage_count >= min_usage {
            return self.success_rate;
        }
        let warmup_fraction = self.usage_count as f32 / min_usage as f32;
        0.5 * warmup_fraction
    }

    /// `true` when more than `max_age_secs` have elapsed since `last_used_at`.
    #[inline]
    pub fn is_stale(&self, max_age_secs: i64) -> bool {
        let age_secs = chrono::Utc::now().timestamp() - self.last_used_at;
        age_secs > max_age_secs
    }
}
/// Outcome of [`HnswRouter::route_by_similarity`], including diagnostics about the search.
#[derive(Debug, Clone)]
pub struct HnswRoutingResult {
/// Agent with the highest accumulated neighbour weight, or the fallback agent when no
/// neighbours were found.
pub primary_agent: AgentType,
/// The primary agent's share of the total weight, capped at 0.99. Falls back to the
/// configured `min_confidence` when there is no usable weight.
pub confidence: f32,
/// Task type with the highest accumulated neighbour weight.
pub task_type: ClaudeFlowTask,
/// Number of neighbours returned by the similarity search.
pub patterns_considered: usize,
/// Up to three other agents with their normalised scores, highest first.
pub alternatives: Vec<(AgentType, f32)>,
/// NOTE: despite the name, this holds each neighbour's *similarity* (not a distance),
/// in search-result order.
pub neighbor_distances: Vec<f32>,
/// Wall-clock microseconds spent in `route_by_similarity` (search plus scoring).
pub search_latency_us: u64,
/// Human-readable explanation of the decision.
pub reasoning: String,
}
impl From<HnswRoutingResult> for RoutingDecision {
fn from(result: HnswRoutingResult) -> Self {
RoutingDecision {
primary_agent: result.primary_agent,
confidence: result.confidence,
alternatives: result.alternatives,
task_type: result.task_type,
reasoning: result.reasoning,
learned_patterns: result.patterns_considered,
}
}
}
/// Snapshot that `HnswRouter::serialize` and `HnswRouter::deserialize` round-trip through
/// bincode.
///
/// Field order and types are part of the encoded format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HnswRouterState {
/// Configuration used to rebuild the index on load.
config: HnswRouterConfig,
/// Every stored pattern; they are re-indexed on load.
patterns: Vec<TaskPattern>,
/// Query counter carried across save/load.
total_queries: u64,
/// Hit counter carried across save/load.
total_hits: u64,
}
/// Routes tasks to agents by nearest-neighbour search over stored [`TaskPattern`]s.
///
/// Patterns live in a concurrent `DashMap` keyed by id. Their (optionally normalised)
/// embeddings live in an HNSW index behind a `parking_lot::RwLock`.
pub struct HnswRouter {
/// Configuration the router was built with.
config: HnswRouterConfig,
/// Vector index; vector ids are the pattern ids.
index: Arc<RwLock<HnswIndex>>,
/// Pattern store keyed by pattern id. Search results not present here are ignored.
patterns: DashMap<String, TaskPattern>,
/// Index-id to pattern-id map; currently every entry maps an id to itself.
index_to_pattern: DashMap<String, String>,
/// Number of successful `search_similar` calls (restored by `deserialize`).
total_queries: AtomicU64,
/// Number of searches that returned at least one stored pattern.
total_hits: AtomicU64,
/// Running count of inserted patterns; never decremented on removal.
total_patterns_added: AtomicU64,
/// Optional SONA integration that receives a trajectory on each success-rate update.
sona: Option<Arc<RwLock<SonaIntegration>>>,
}
impl HnswRouter {
    /// Creates an empty router backed by a new HNSW index built from `config`.
    ///
    /// # Errors
    ///
    /// Returns [`RuvLLMError::Ruvector`] if the index cannot be constructed.
    pub fn new(config: HnswRouterConfig) -> Result<Self> {
        let hnsw_config = HnswConfig {
            m: config.m,
            ef_construction: config.ef_construction,
            ef_search: config.ef_search,
            max_elements: config.max_patterns,
        };
        let index = HnswIndex::new(
            config.embedding_dim,
            config.distance_metric.into(),
            hnsw_config,
        )
        .map_err(|e| RuvLLMError::Ruvector(e.to_string()))?;
        Ok(Self {
            config,
            index: Arc::new(RwLock::new(index)),
            patterns: DashMap::new(),
            index_to_pattern: DashMap::new(),
            total_queries: AtomicU64::new(0),
            total_hits: AtomicU64::new(0),
            total_patterns_added: AtomicU64::new(0),
            sona: None,
        })
    }

    /// Returns the configuration this router was built with.
    pub fn config(&self) -> &HnswRouterConfig {
        &self.config
    }

    /// Like [`HnswRouter::new`], but success-rate updates are also reported to `sona`.
    ///
    /// # Errors
    ///
    /// Same as [`HnswRouter::new`].
    pub fn with_sona(config: HnswRouterConfig, sona: Arc<RwLock<SonaIntegration>>) -> Result<Self> {
        let mut router = Self::new(config)?;
        router.sona = Some(sona);
        Ok(router)
    }

    /// Inserts one pattern into the index and then into the pattern store.
    ///
    /// # Errors
    ///
    /// - [`RuvLLMError::Config`] if the embedding length differs from `embedding_dim`.
    /// - [`RuvLLMError::Ruvector`] if the index rejects the vector; the pattern is then
    ///   not stored.
    pub fn add_pattern(&self, pattern: TaskPattern) -> Result<()> {
        self.check_dimension("Embedding", pattern.embedding.len())?;
        let embedding = self.normalize_embedding(&pattern.embedding);
        // Scope the write guard so it is released before the DashMaps are touched.
        {
            let mut index = self.index.write();
            index
                .add(pattern.id.clone(), embedding)
                .map_err(|e| RuvLLMError::Ruvector(e.to_string()))?;
        }
        self.index_to_pattern
            .insert(pattern.id.clone(), pattern.id.clone());
        self.patterns.insert(pattern.id.clone(), pattern);
        self.total_patterns_added.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }

    /// Bulk-inserts patterns, silently skipping any whose embedding has the wrong dimension.
    ///
    /// Returns the number of patterns actually added.
    ///
    /// # Errors
    ///
    /// Returns [`RuvLLMError::Ruvector`] if the batch insert fails; the pattern store is
    /// then left untouched.
    pub fn add_patterns(&self, patterns: Vec<TaskPattern>) -> Result<usize> {
        let valid: Vec<TaskPattern> = patterns
            .into_iter()
            .filter(|p| p.embedding.len() == self.config.embedding_dim)
            .collect();
        if valid.is_empty() {
            return Ok(0);
        }
        let entries: Vec<_> = valid
            .iter()
            .map(|p| (p.id.clone(), self.normalize_embedding(&p.embedding)))
            .collect();
        // Index first: record patterns only once their vectors are searchable. Otherwise
        // a failed batch would leave store entries with no index counterpart.
        {
            let mut index = self.index.write();
            index
                .add_batch(entries)
                .map_err(|e| RuvLLMError::Ruvector(e.to_string()))?;
        }
        let added = valid.len();
        for pattern in valid {
            self.index_to_pattern
                .insert(pattern.id.clone(), pattern.id.clone());
            self.patterns.insert(pattern.id.clone(), pattern);
        }
        self.total_patterns_added
            .fetch_add(added as u64, Ordering::SeqCst);
        Ok(added)
    }

    /// Returns up to `k` stored patterns nearest to `query`, each with a similarity score.
    ///
    /// The similarity comes from [`HnswRouter::score_to_similarity`]. Index hits whose id
    /// is no longer in the pattern store are dropped. Every successful search bumps the
    /// query counter, and searches returning at least one pattern also bump the hit counter.
    ///
    /// # Errors
    ///
    /// - [`RuvLLMError::Config`] on a dimension mismatch.
    /// - [`RuvLLMError::Ruvector`] if the index search fails.
    pub fn search_similar(&self, query: &[f32], k: usize) -> Result<Vec<(TaskPattern, f32)>> {
        self.check_dimension("Query", query.len())?;
        let normalized_query = self.normalize_embedding(query);
        let results: Vec<SearchResult> = {
            let index = self.index.read();
            index
                .search(&normalized_query, k)
                .map_err(|e| RuvLLMError::Ruvector(e.to_string()))?
        };
        self.total_queries.fetch_add(1, Ordering::SeqCst);
        let pattern_results: Vec<(TaskPattern, f32)> = results
            .into_iter()
            .filter_map(|result| {
                let pattern = self.patterns.get(&result.id)?;
                Some((
                    pattern.value().clone(),
                    Self::score_to_similarity(result.score),
                ))
            })
            .collect();
        if !pattern_results.is_empty() {
            self.total_hits.fetch_add(1, Ordering::SeqCst);
        }
        Ok(pattern_results)
    }

    /// Routes a query to an agent by weighted vote of its `top_k` nearest patterns.
    ///
    /// Each neighbour votes with weight `max(similarity, 0) * pattern.confidence(..)`.
    /// Confidence is the winner's share of the total weight, capped at 0.99. With no
    /// neighbours, the fallback is `Coder` / `CodeGeneration` at `min_confidence`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HnswRouter::search_similar`].
    pub fn route_by_similarity(&self, query_embedding: &[f32]) -> Result<HnswRoutingResult> {
        let start = std::time::Instant::now();
        let similar_patterns = self.search_similar(query_embedding, self.config.top_k)?;
        if similar_patterns.is_empty() {
            return Ok(HnswRoutingResult {
                primary_agent: AgentType::Coder,
                confidence: self.config.min_confidence,
                task_type: ClaudeFlowTask::CodeGeneration,
                patterns_considered: 0,
                alternatives: Vec::new(),
                neighbor_distances: Vec::new(),
                search_latency_us: start.elapsed().as_micros() as u64,
                reasoning: "No similar patterns found, using default".to_string(),
            });
        }
        let patterns_considered = similar_patterns.len();
        let min_usage = self.config.min_usage_for_trust;
        let mut agent_scores: HashMap<AgentType, f32> = HashMap::with_capacity(8);
        let mut task_type_scores: HashMap<ClaudeFlowTask, f32> = HashMap::with_capacity(8);
        let mut neighbor_distances = Vec::with_capacity(patterns_considered);
        for (pattern, similarity) in &similar_patterns {
            let similarity = *similarity;
            // Clamp at zero: dissimilar neighbours must not subtract weight. Otherwise the
            // normalised confidence/alternative scores below could leave [0, 1].
            let weight = similarity.max(0.0) * pattern.confidence(min_usage);
            *agent_scores.entry(pattern.agent_type).or_insert(0.0) += weight;
            *task_type_scores.entry(pattern.task_type).or_insert(0.0) += weight;
            neighbor_distances.push(similarity);
        }
        // `total_cmp` gives a total order, so a NaN score cannot panic the router.
        let (primary_agent, primary_score) = agent_scores
            .iter()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(a, s)| (*a, *s))
            .unwrap_or((AgentType::Coder, 0.0));
        let total_score: f32 = agent_scores.values().sum();
        let confidence = if total_score > 0.0 {
            (primary_score / total_score).min(0.99)
        } else {
            self.config.min_confidence
        };
        let mut alternatives: Vec<(AgentType, f32)> = agent_scores
            .into_iter()
            .filter(|(a, _)| *a != primary_agent)
            .map(|(a, s)| (a, s / total_score.max(0.01)))
            .collect();
        alternatives.sort_by(|a, b| b.1.total_cmp(&a.1));
        alternatives.truncate(3);
        let task_type = task_type_scores
            .into_iter()
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(t, _)| t)
            .unwrap_or(ClaudeFlowTask::CodeGeneration);
        let latency_us = start.elapsed().as_micros() as u64;
        Ok(HnswRoutingResult {
            primary_agent,
            confidence,
            task_type,
            patterns_considered,
            alternatives,
            neighbor_distances,
            search_latency_us: latency_us,
            reasoning: format!(
                "HNSW semantic match: {} patterns, confidence {:.2}, latency {}us",
                patterns_considered, confidence, latency_us
            ),
        })
    }

    /// Records an outcome for the pattern with `pattern_id`.
    ///
    /// When SONA is attached, a trajectory describing the update is also forwarded.
    /// Returns `Ok(false)` if no such pattern exists.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; SONA recording failures are ignored (best-effort).
    pub fn update_success_rate(&self, pattern_id: &str, success: bool) -> Result<bool> {
        // Mutate under the DashMap shard lock and copy out only what the SONA trajectory
        // needs. The guard drops at the end of this block, before SONA is called.
        let trajectory_inputs = {
            let Some(mut pattern) = self.patterns.get_mut(pattern_id) else {
                return Ok(false);
            };
            pattern.update_success(success, self.config.success_rate_decay);
            if self.sona.is_some() {
                Some((
                    pattern.embedding.clone(),
                    pattern.agent_type,
                    pattern.success_rate,
                ))
            } else {
                None
            }
        };
        if let (Some(sona), Some((embedding, agent_type, success_rate))) =
            (&self.sona, trajectory_inputs)
        {
            let trajectory = Trajectory {
                request_id: uuid::Uuid::new_v4().to_string(),
                session_id: "hnsw-router".to_string(),
                query_embedding: embedding.clone(),
                response_embedding: embedding,
                quality_score: if success { 0.9 } else { 0.3 },
                routing_features: vec![agent_type as u8 as f32 / 10.0, success_rate],
                model_index: agent_type as usize,
                timestamp: chrono::Utc::now(),
            };
            // Best-effort: failing to record telemetry must not fail the update itself.
            let _ = sona.read().record_trajectory(trajectory);
        }
        Ok(true)
    }

    /// Records an outcome for the single nearest pattern, but only if its similarity
    /// exceeds 0.8.
    ///
    /// Returns whether a pattern was updated.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HnswRouter::search_similar`].
    pub fn update_nearest_success(&self, query_embedding: &[f32], success: bool) -> Result<bool> {
        let similar = self.search_similar(query_embedding, 1)?;
        match similar.first() {
            Some((pattern, similarity)) if *similarity > 0.8 => {
                self.update_success_rate(&pattern.id, success)
            }
            _ => Ok(false),
        }
    }

    /// Online learning: reinforces or creates a pattern from an observed outcome.
    ///
    /// If a stored pattern has similarity above 0.95, it is reinforced and its id is
    /// returned. Otherwise a new pattern seeded with this single outcome is stored.
    /// Returns `Ok(None)` when online learning is disabled.
    ///
    /// # Errors
    ///
    /// Propagates errors from searching and inserting.
    pub fn learn_pattern(
        &self,
        embedding: Vec<f32>,
        agent_type: AgentType,
        task_type: ClaudeFlowTask,
        task_description: String,
        success: bool,
    ) -> Result<Option<String>> {
        if !self.config.enable_online_learning {
            return Ok(None);
        }
        let similar = self.search_similar(&embedding, 1)?;
        if let Some((existing, similarity)) = similar.first() {
            if *similarity > 0.95 {
                self.update_success_rate(&existing.id, success)?;
                return Ok(Some(existing.id.clone()));
            }
        }
        let mut pattern = TaskPattern::new(embedding, agent_type, task_type, task_description);
        // Seed with one observed outcome, leaning the prior towards that outcome.
        pattern.usage_count = 1;
        if success {
            pattern.success_count = 1;
            pattern.success_rate = 0.75;
        } else {
            pattern.success_rate = 0.25;
        }
        let pattern_id = pattern.id.clone();
        self.add_pattern(pattern)?;
        Ok(Some(pattern_id))
    }

    /// Removes a pattern from the store and (best-effort) from the index.
    ///
    /// Returns `Ok(false)` if the pattern was not present.
    pub fn remove_pattern(&self, pattern_id: &str) -> Result<bool> {
        if self.patterns.remove(pattern_id).is_none() {
            return Ok(false);
        }
        self.index_to_pattern.remove(pattern_id);
        let mut index = self.index.write();
        // Searches ignore ids missing from the store, so a failed index removal is
        // harmless for routing.
        let _ = index.remove(&pattern_id.to_string());
        Ok(true)
    }

    /// Removes poorly-performing and stale, never-used patterns.
    ///
    /// A pattern is removed if either:
    /// - it has at least `min_usage` uses and a success rate below `min_success_rate`, or
    /// - it has never been used and was last touched more than `max_age_secs` ago.
    ///
    /// Returns the number of patterns actually removed.
    pub fn prune_patterns(
        &self,
        min_success_rate: f32,
        min_usage: u32,
        max_age_secs: i64,
    ) -> Result<usize> {
        // Collect ids first: removing while iterating would contend on DashMap shard locks.
        let to_remove: Vec<String> = self
            .patterns
            .iter()
            .filter(|entry| {
                let pattern = entry.value();
                let underperforming =
                    pattern.usage_count >= min_usage && pattern.success_rate < min_success_rate;
                let abandoned = pattern.is_stale(max_age_secs) && pattern.usage_count == 0;
                underperforming || abandoned
            })
            .map(|entry| entry.key().clone())
            .collect();
        let mut removed = 0;
        for id in to_remove {
            if self.remove_pattern(&id)? {
                removed += 1;
            }
        }
        Ok(removed)
    }

    /// Merges near-duplicate patterns that recommend the same agent.
    ///
    /// Each pattern's 5 nearest neighbours with similarity above `similarity_threshold`
    /// are candidates. Of each pair, the pattern with the larger usage count survives,
    /// absorbs the other's counters (its success rate becomes `success_count /
    /// usage_count`), and the other is removed. Returns the number of patterns removed.
    ///
    /// # Errors
    ///
    /// Propagates errors from searching and removal.
    pub fn consolidate_patterns(&self, similarity_threshold: f32) -> Result<usize> {
        let mut consolidated = 0;
        let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
        let pattern_ids: Vec<String> = self.patterns.iter().map(|e| e.key().clone()).collect();
        for id in pattern_ids {
            if processed.contains(&id) {
                continue;
            }
            // Work on an owned snapshot: holding a DashMap guard here while calling
            // `get_mut`/`remove` on the same shard below would deadlock.
            let Some(mut current) = self.get_pattern(&id) else {
                continue;
            };
            let similar = self.search_similar(&current.embedding, 5)?;
            for (other, similarity) in similar {
                let is_candidate = other.id != id
                    && similarity > similarity_threshold
                    && !processed.contains(&other.id)
                    && other.agent_type == current.agent_type;
                if !is_candidate {
                    continue;
                }
                if other.usage_count > current.usage_count {
                    // The neighbour is better established: fold this pattern into it.
                    self.merge_counts_into(&other.id, &current);
                    processed.insert(id.clone());
                    self.remove_pattern(&id)?;
                    consolidated += 1;
                    break;
                }
                // This pattern is better established: absorb the neighbour and refresh
                // the snapshot so later comparisons see the updated counts.
                self.merge_counts_into(&id, &other);
                if let Some(updated) = self.get_pattern(&id) {
                    current = updated;
                }
                processed.insert(other.id.clone());
                self.remove_pattern(&other.id)?;
                consolidated += 1;
            }
            processed.insert(id);
        }
        Ok(consolidated)
    }

    /// Adds `source`'s usage/success counters to the pattern `target_id` and recomputes its
    /// success rate as the empirical ratio.
    fn merge_counts_into(&self, target_id: &str, source: &TaskPattern) {
        if let Some(mut target) = self.patterns.get_mut(target_id) {
            target.usage_count = target.usage_count.saturating_add(source.usage_count);
            target.success_count = target.success_count.saturating_add(source.success_count);
            if target.usage_count > 0 {
                target.success_rate = target.success_count as f32 / target.usage_count as f32;
            }
        }
    }

    /// Returns a point-in-time snapshot of router statistics.
    pub fn stats(&self) -> HnswRouterStats {
        // Load each counter once so `hit_rate` agrees with the reported totals.
        let total_queries = self.total_queries.load(Ordering::SeqCst);
        let total_hits = self.total_hits.load(Ordering::SeqCst);
        let hit_rate = if total_queries > 0 {
            total_hits as f32 / total_queries as f32
        } else {
            0.0
        };
        HnswRouterStats {
            total_patterns: self.patterns.len(),
            total_queries,
            total_hits,
            hit_rate,
            patterns_by_agent: self.count_patterns_by_agent(),
            avg_success_rate: self.calculate_avg_success_rate(),
            config: self.config.clone(),
        }
    }

    /// Clones every stored pattern (in unspecified order).
    pub fn get_all_patterns(&self) -> Vec<TaskPattern> {
        self.patterns
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Clones the pattern with `id`, if present.
    pub fn get_pattern(&self, id: &str) -> Option<TaskPattern> {
        self.patterns.get(id).map(|p| p.value().clone())
    }

    /// Encodes the config, all patterns and the query/hit counters with bincode.
    ///
    /// # Errors
    ///
    /// Returns [`RuvLLMError::Serialization`] if encoding fails.
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let state = HnswRouterState {
            config: self.config.clone(),
            patterns: self.get_all_patterns(),
            total_queries: self.total_queries.load(Ordering::SeqCst),
            total_hits: self.total_hits.load(Ordering::SeqCst),
        };
        bincode::serde::encode_to_vec(&state, bincode::config::standard())
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))
    }

    /// Rebuilds a router from bytes produced by [`HnswRouter::serialize`].
    ///
    /// The index is rebuilt by re-inserting every pattern.
    ///
    /// # Errors
    ///
    /// - [`RuvLLMError::Serialization`] on malformed input.
    /// - Index errors from rebuilding.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let (state, _): (HnswRouterState, usize) =
            bincode::serde::decode_from_slice(bytes, bincode::config::standard())
                .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
        let router = Self::new(state.config)?;
        router.add_patterns(state.patterns)?;
        router
            .total_queries
            .store(state.total_queries, Ordering::SeqCst);
        router.total_hits.store(state.total_hits, Ordering::SeqCst);
        Ok(router)
    }

    /// Fails with [`RuvLLMError::Config`] unless `actual` equals the configured dimension.
    ///
    /// `label` prefixes the message, e.g. "Embedding" or "Query".
    fn check_dimension(&self, label: &str, actual: usize) -> Result<()> {
        if actual == self.config.embedding_dim {
            return Ok(());
        }
        Err(RuvLLMError::Config(format!(
            "{} dimension mismatch: expected {}, got {}",
            label, self.config.embedding_dim, actual
        )))
    }

    /// Converts an index distance score into a similarity: `1 - clamp(score, 0, 2)`.
    ///
    /// The result always lies in `[-1, 1]`, and zero distance maps to 1.
    /// NOTE(review): `f32::max` ignores NaN, so a NaN score maps to 1.0 (a "perfect"
    /// match). Confirm whether the index can emit NaN, e.g. for zero vectors.
    #[inline]
    fn score_to_similarity(score: f32) -> f32 {
        1.0 - score.max(0.0_f32).min(2.0_f32)
    }

    /// L2-normalises `embedding` when the metric is cosine.
    ///
    /// Returns a copy unchanged for other metrics or for (near-)zero vectors.
    #[inline]
    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
        if self.config.distance_metric != HnswDistanceMetric::Cosine {
            return embedding.to_vec();
        }
        let norm = embedding.iter().fold(0.0_f32, |acc, &x| acc + x * x).sqrt();
        if norm > 1e-8 {
            let inv_norm = 1.0 / norm;
            embedding.iter().map(|&x| x * inv_norm).collect()
        } else {
            embedding.to_vec()
        }
    }

    /// Counts stored patterns per agent type.
    #[inline]
    fn count_patterns_by_agent(&self) -> HashMap<AgentType, usize> {
        let mut counts = HashMap::with_capacity(16);
        for entry in self.patterns.iter() {
            *counts.entry(entry.value().agent_type).or_insert(0) += 1;
        }
        counts
    }

    /// Mean success rate over patterns with at least `min_usage_for_trust` uses
    /// (0.0 when there are none).
    #[inline]
    fn calculate_avg_success_rate(&self) -> f32 {
        let min_usage = self.config.min_usage_for_trust;
        let (total, count) = self
            .patterns
            .iter()
            .filter(|entry| entry.value().usage_count >= min_usage)
            .fold((0.0_f32, 0_usize), |(sum, n), entry| {
                (sum + entry.value().success_rate, n + 1)
            });
        if count > 0 {
            total / count as f32
        } else {
            0.0
        }
    }
}
/// Point-in-time statistics returned by [`HnswRouter::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswRouterStats {
/// Number of patterns currently stored.
pub total_patterns: usize,
/// Number of successful similarity searches.
pub total_queries: u64,
/// Number of searches that returned at least one stored pattern.
pub total_hits: u64,
/// `total_hits / total_queries`, or 0.0 before any query.
pub hit_rate: f32,
/// Stored pattern count per agent type.
pub patterns_by_agent: HashMap<AgentType, usize>,
/// Mean success rate over patterns with at least `min_usage_for_trust` uses
/// (0.0 if there are none).
pub avg_success_rate: f32,
/// Copy of the router configuration.
pub config: HnswRouterConfig,
}
/// Combines [`HnswRouter`] similarity routing with an externally supplied keyword-based
/// [`RoutingDecision`].
pub struct HybridRouter {
/// The semantic (HNSW) router.
hnsw: HnswRouter,
/// Weight given to the keyword decision when blending; HNSW gets `1 - keyword_weight`.
keyword_weight: f32,
/// HNSW confidence above which (with enough neighbours) the keyword decision is ignored.
min_hnsw_confidence: f32,
}
impl HybridRouter {
    /// Creates a hybrid router around a fresh [`HnswRouter`].
    ///
    /// The keyword weight starts at 0.3 and the HNSW short-circuit confidence at 0.6.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HnswRouter::new`].
    pub fn new(config: HnswRouterConfig) -> Result<Self> {
        let hnsw = HnswRouter::new(config)?;
        Ok(Self {
            min_hnsw_confidence: 0.6,
            keyword_weight: 0.3,
            hnsw,
        })
    }

    /// Routes a task by combining HNSW similarity routing with an optional keyword decision.
    ///
    /// Decision order:
    /// 1. No keyword decision: use the HNSW result.
    /// 2. HNSW confidence above `min_hnsw_confidence` with at least 3 neighbours: use HNSW.
    /// 3. Both pick the same agent: that agent, with a weighted-blend confidence (capped
    ///    at 0.99).
    /// 4. Otherwise: whichever side has the larger weighted confidence (ties go to
    ///    keyword).
    ///
    /// `task_description` is not read by this method.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HnswRouter::route_by_similarity`].
    pub fn route(
        &self,
        task_description: &str,
        embedding: &[f32],
        keyword_decision: Option<RoutingDecision>,
    ) -> Result<RoutingDecision> {
        let _ = task_description;
        let hnsw_result = self.hnsw.route_by_similarity(embedding)?;
        let Some(keyword) = keyword_decision else {
            return Ok(hnsw_result.into());
        };
        let hnsw_is_decisive = hnsw_result.confidence > self.min_hnsw_confidence
            && hnsw_result.patterns_considered >= 3;
        if hnsw_is_decisive {
            return Ok(hnsw_result.into());
        }
        let hnsw_weight = 1.0 - self.keyword_weight;
        if hnsw_result.primary_agent != keyword.primary_agent {
            let hnsw_score = hnsw_result.confidence * hnsw_weight;
            let keyword_score = keyword.confidence * self.keyword_weight;
            return Ok(if hnsw_score > keyword_score {
                hnsw_result.into()
            } else {
                keyword
            });
        }
        let agreed_agent = hnsw_result.primary_agent;
        let blended_confidence = (hnsw_result.confidence * hnsw_weight
            + keyword.confidence * self.keyword_weight)
            .min(0.99);
        Ok(RoutingDecision {
            primary_agent: agreed_agent,
            confidence: blended_confidence,
            task_type: hnsw_result.task_type,
            alternatives: hnsw_result.alternatives,
            reasoning: format!("Hybrid: keyword + HNSW agree on {:?}", agreed_agent),
            learned_patterns: hnsw_result.patterns_considered,
        })
    }

    /// Borrows the underlying HNSW router.
    pub fn hnsw(&self) -> &HnswRouter {
        &self.hnsw
    }

    /// Sets the keyword blend weight, clamped to `[0, 1]`.
    pub fn set_keyword_weight(&mut self, weight: f32) {
        self.keyword_weight = f32::clamp(weight, 0.0, 1.0);
    }

    /// Sets the HNSW short-circuit confidence threshold, clamped to `[0, 1]`.
    pub fn set_min_hnsw_confidence(&mut self, confidence: f32) {
        self.min_hnsw_confidence = f32::clamp(confidence, 0.0, 1.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-embedding with coordinates `sin((i + seed) / dim)`.
    fn create_test_embedding(seed: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|i| ((i + seed) as f32 / dim as f32).sin())
            .collect()
    }

    /// Default config with a smaller embedding dimension to keep tests fast.
    fn small_config() -> HnswRouterConfig {
        HnswRouterConfig {
            embedding_dim: 128,
            ..Default::default()
        }
    }

    #[test]
    fn test_hnsw_router_creation() {
        let config = HnswRouterConfig::default();
        let router = HnswRouter::new(config).unwrap();
        let stats = router.stats();
        assert_eq!(stats.total_patterns, 0);
        assert_eq!(stats.total_queries, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_add_and_search_pattern() {
        let router = HnswRouter::new(small_config()).unwrap();
        let embedding = create_test_embedding(42, 128);
        let pattern = TaskPattern::new(
            embedding.clone(),
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "implement a function".to_string(),
        );
        router.add_pattern(pattern).unwrap();
        let results = router.search_similar(&embedding, 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0.agent_type, AgentType::Coder);
        // Searching with the stored embedding itself should be a near-perfect match.
        assert!(results[0].1 > 0.99);

        let stats = router.stats();
        assert_eq!(stats.total_queries, 1);
        assert_eq!(stats.total_hits, 1);
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let router = HnswRouter::new(small_config()).unwrap();
        let wrong = TaskPattern::new(
            vec![0.5; 3],
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "wrong dim".to_string(),
        );
        assert!(router.add_pattern(wrong).is_err());
        assert!(router.search_similar(&[0.5; 3], 1).is_err());
        assert_eq!(router.stats().total_patterns, 0);
    }

    #[test]
    fn test_add_patterns_skips_wrong_dimension() {
        let router = HnswRouter::new(small_config()).unwrap();
        let good = TaskPattern::new(
            create_test_embedding(7, 128),
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "good".to_string(),
        );
        let bad = TaskPattern::new(
            vec![1.0; 4],
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "bad".to_string(),
        );
        assert_eq!(router.add_patterns(vec![good, bad]).unwrap(), 1);
        assert_eq!(router.stats().total_patterns, 1);
    }

    #[test]
    fn test_route_by_similarity() {
        let config = HnswRouterConfig {
            embedding_dim: 128,
            min_usage_for_trust: 1,
            ..Default::default()
        };
        let router = HnswRouter::new(config).unwrap();
        for i in 0..10 {
            let embedding = create_test_embedding(i * 100, 128);
            let (agent_type, task_type) = if i < 5 {
                (AgentType::Coder, ClaudeFlowTask::CodeGeneration)
            } else {
                (AgentType::Tester, ClaudeFlowTask::Testing)
            };
            let mut pattern =
                TaskPattern::new(embedding, agent_type, task_type, format!("task {}", i));
            pattern.usage_count = 10;
            pattern.success_count = 8;
            pattern.success_rate = 0.8;
            router.add_pattern(pattern).unwrap();
        }
        let query = create_test_embedding(150, 128);
        let result = router.route_by_similarity(&query).unwrap();
        assert!(result.patterns_considered > 0);
        assert_eq!(result.neighbor_distances.len(), result.patterns_considered);
        assert!(result.confidence > 0.0 && result.confidence <= 0.99);
        assert!(matches!(
            result.primary_agent,
            AgentType::Coder | AgentType::Tester
        ));
        // Deliberately generous: tight wall-clock limits are flaky under debug/coverage
        // builds and loaded CI machines.
        assert!(result.search_latency_us < 1_000_000);
    }

    #[test]
    fn test_update_success_rate() {
        let config = HnswRouterConfig {
            embedding_dim: 128,
            success_rate_decay: 0.1,
            ..Default::default()
        };
        let router = HnswRouter::new(config).unwrap();
        let pattern = TaskPattern::new(
            create_test_embedding(42, 128),
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "test task".to_string(),
        );
        let pattern_id = pattern.id.clone();
        router.add_pattern(pattern).unwrap();
        assert!(router.update_success_rate(&pattern_id, true).unwrap());
        assert!(router.update_success_rate(&pattern_id, true).unwrap());
        assert!(router.update_success_rate(&pattern_id, false).unwrap());
        let updated_pattern = router.get_pattern(&pattern_id).unwrap();
        assert_eq!(updated_pattern.usage_count, 3);
        assert_eq!(updated_pattern.success_count, 2);
        // EMA with decay 0.1 from 0.5: 0.55 -> 0.595 -> 0.5355.
        assert!((updated_pattern.success_rate - 0.5355).abs() < 1e-4);

        assert!(!router.update_success_rate("no-such-pattern", true).unwrap());
    }

    #[test]
    fn test_learn_pattern() {
        let config = HnswRouterConfig {
            embedding_dim: 128,
            enable_online_learning: true,
            ..Default::default()
        };
        let router = HnswRouter::new(config).unwrap();
        let embedding = create_test_embedding(42, 128);
        let first = router
            .learn_pattern(
                embedding.clone(),
                AgentType::Researcher,
                ClaudeFlowTask::Research,
                "research best practices".to_string(),
                true,
            )
            .unwrap();
        assert!(first.is_some());
        let stats = router.stats();
        assert_eq!(stats.total_patterns, 1);
        assert_eq!(
            *stats.patterns_by_agent.get(&AgentType::Researcher).unwrap(),
            1
        );

        // Learning the same embedding again reinforces the existing pattern.
        let second = router
            .learn_pattern(
                embedding,
                AgentType::Researcher,
                ClaudeFlowTask::Research,
                "research best practices again".to_string(),
                false,
            )
            .unwrap();
        assert_eq!(second, first);
        assert_eq!(router.stats().total_patterns, 1);
        let learned = router.get_pattern(first.as_deref().unwrap()).unwrap();
        assert_eq!(learned.usage_count, 2);
        assert_eq!(learned.success_count, 1);
    }

    #[test]
    fn test_learn_pattern_disabled() {
        let config = HnswRouterConfig {
            embedding_dim: 128,
            enable_online_learning: false,
            ..Default::default()
        };
        let router = HnswRouter::new(config).unwrap();
        let learned = router
            .learn_pattern(
                create_test_embedding(1, 128),
                AgentType::Coder,
                ClaudeFlowTask::CodeGeneration,
                "ignored".to_string(),
                true,
            )
            .unwrap();
        assert!(learned.is_none());
        assert_eq!(router.stats().total_patterns, 0);
    }

    #[test]
    fn test_prune_patterns() {
        let router = HnswRouter::new(small_config()).unwrap();
        let mut bad = TaskPattern::new(
            create_test_embedding(42, 128),
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "bad task".to_string(),
        );
        bad.usage_count = 100;
        bad.success_count = 10;
        bad.success_rate = 0.1;
        router.add_pattern(bad).unwrap();

        let mut good = TaskPattern::new(
            create_test_embedding(100, 128),
            AgentType::Coder,
            ClaudeFlowTask::CodeGeneration,
            "good task".to_string(),
        );
        good.usage_count = 100;
        good.success_count = 90;
        good.success_rate = 0.9;
        let good_id = good.id.clone();
        router.add_pattern(good).unwrap();

        let pruned = router.prune_patterns(0.3, 50, 86400).unwrap();
        assert_eq!(pruned, 1);
        assert_eq!(router.stats().total_patterns, 1);
        assert!(router.get_pattern(&good_id).is_some());
    }

    #[test]
    fn test_serialization() {
        let router = HnswRouter::new(small_config()).unwrap();
        let mut ids = Vec::new();
        for i in 0..5 {
            let pattern = TaskPattern::new(
                create_test_embedding(i * 10, 128),
                AgentType::Coder,
                ClaudeFlowTask::CodeGeneration,
                format!("task {}", i),
            );
            ids.push(pattern.id.clone());
            router.add_pattern(pattern).unwrap();
        }
        let bytes = router.serialize().unwrap();
        let restored = HnswRouter::deserialize(&bytes).unwrap();
        assert_eq!(restored.stats().total_patterns, 5);
        for id in &ids {
            let original = router.get_pattern(id).unwrap();
            let copy = restored
                .get_pattern(id)
                .expect("pattern should survive a serialize/deserialize round-trip");
            assert_eq!(copy.task_description, original.task_description);
            assert_eq!(copy.embedding, original.embedding);
        }
    }

    #[test]
    fn test_config_presets() {
        let fast = HnswRouterConfig::fast();
        assert_eq!(fast.m, 16);
        assert_eq!(fast.ef_search, 50);
        let high_recall = HnswRouterConfig::high_recall();
        assert_eq!(high_recall.m, 48);
        assert_eq!(high_recall.ef_search, 200);
        assert_eq!(HnswRouterConfig::for_small_model().embedding_dim, 384);
        assert_eq!(HnswRouterConfig::for_large_model().embedding_dim, 768);
    }

    #[test]
    fn test_hybrid_router() {
        let mut router = HybridRouter::new(small_config()).unwrap();
        for i in 0..5 {
            let pattern = TaskPattern::new(
                create_test_embedding(i * 10, 128),
                AgentType::Coder,
                ClaudeFlowTask::CodeGeneration,
                format!("coding task {}", i),
            );
            router.hnsw.add_pattern(pattern).unwrap();
        }
        let query = create_test_embedding(25, 128);
        let keyword_decision = RoutingDecision {
            primary_agent: AgentType::Coder,
            confidence: 0.8,
            alternatives: vec![],
            task_type: ClaudeFlowTask::CodeGeneration,
            reasoning: "keyword match".to_string(),
            learned_patterns: 0,
        };
        let result = router
            .route("implement a function", &query, Some(keyword_decision))
            .unwrap();
        assert_eq!(result.primary_agent, AgentType::Coder);

        // Setters store in-range values and clamp out-of-range ones to [0, 1].
        router.set_keyword_weight(0.9);
        assert_eq!(router.keyword_weight, 0.9);
        router.set_min_hnsw_confidence(0.9);
        assert_eq!(router.min_hnsw_confidence, 0.9);
        router.set_keyword_weight(1.5);
        assert_eq!(router.keyword_weight, 1.0);
        router.set_min_hnsw_confidence(-0.2);
        assert_eq!(router.min_hnsw_confidence, 0.0);
    }
}