use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use storage::VectorStorage;
use crate::distance::calculate_distance;
use common::DistanceMetric;
/// One routing candidate produced by `SemanticRouter::route`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouteMatch {
/// Namespace whose cached centroid matched the query embedding.
pub namespace: String,
/// Score of the query against the namespace centroid (cosine metric;
/// higher ranks first — see `SemanticRouter::route`).
pub similarity: f32,
/// Total vectors in the namespace when its centroid was last computed.
pub memory_count: usize,
}
/// Tunables for the semantic router.
pub struct SemanticRouterConfig {
    /// Maximum number of embeddings sampled per namespace when computing
    /// its centroid.
    pub sample_size: usize,
    /// Seconds between background centroid refreshes.
    pub refresh_interval_secs: u64,
}

impl Default for SemanticRouterConfig {
    fn default() -> Self {
        Self {
            sample_size: 20,
            refresh_interval_secs: 1800,
        }
    }
}

impl SemanticRouterConfig {
    /// Builds a config from `DAKERA_ROUTE_SAMPLE_SIZE` and
    /// `DAKERA_ROUTE_REFRESH_SECS`, falling back to the [`Default`] values
    /// for any variable that is unset or unparsable.
    ///
    /// The fallbacks are taken from `Self::default()` rather than repeated
    /// as literals, so the two construction paths cannot drift apart.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        Self {
            sample_size: std::env::var("DAKERA_ROUTE_SAMPLE_SIZE")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(defaults.sample_size),
            refresh_interval_secs: std::env::var("DAKERA_ROUTE_REFRESH_SECS")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(defaults.refresh_interval_secs),
        }
    }
}
/// Cached mean embedding for a single namespace (L2-normalized by
/// `refresh_centroids` unless the mean is degenerate / near-zero).
#[derive(Clone)]
struct CentroidEntry {
/// Centroid vector; its length fixes the dimensionality `route` will accept.
centroid: Vec<f32>,
/// Number of vectors in the namespace at refresh time (not just the sample).
count: usize,
}
/// Routes a query embedding to agent memory namespaces by comparing it
/// against cached per-namespace centroids (see `refresh_centroids`).
pub struct SemanticRouter {
config: SemanticRouterConfig,
/// namespace -> centroid entry; replaced wholesale on each refresh, so
/// readers never observe a partially built cache.
cache: RwLock<HashMap<String, CentroidEntry>>,
}
impl SemanticRouter {
    /// Creates a router with an empty centroid cache. The cache stays empty
    /// (and `route` returns nothing) until `refresh_centroids` runs — see
    /// `spawn_refresh`.
    pub fn new(config: SemanticRouterConfig) -> Self {
        Self {
            config,
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Scores `query` against every cached namespace centroid and returns up
    /// to `top_k` matches scoring at least `min_similarity`, best first.
    ///
    /// Centroids whose dimensionality differs from `query` (e.g. produced by
    /// a different embedding model) are skipped rather than mis-scored.
    pub fn route(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<RouteMatch> {
        let cache = self.cache.read();
        let mut matches: Vec<RouteMatch> = cache
            .iter()
            .filter_map(|(ns, entry)| {
                if entry.centroid.len() != query.len() {
                    return None;
                }
                // NOTE(review): `calculate_distance` is treated as a
                // similarity here (higher = closer) — confirm that is its
                // contract for `DistanceMetric::Cosine`.
                let sim = calculate_distance(query, &entry.centroid, DistanceMetric::Cosine);
                if sim >= min_similarity {
                    Some(RouteMatch {
                        namespace: ns.clone(),
                        similarity: sim,
                        memory_count: entry.count,
                    })
                } else {
                    None
                }
            })
            .collect();
        // `total_cmp` yields a deterministic descending order even if a score
        // is NaN, unlike the old `partial_cmp(..).unwrap_or(Equal)` fallback,
        // and `sort_unstable_by` avoids the stable sort's allocation.
        matches.sort_unstable_by(|a, b| b.similarity.total_cmp(&a.similarity));
        matches.truncate(top_k);
        matches
    }

    /// Rebuilds the per-namespace centroid cache from storage.
    ///
    /// Only namespaces prefixed `_dakera_agent_` participate. For each, up to
    /// `config.sample_size` non-empty embeddings (the first ones storage
    /// returns — not a random sample) are averaged and L2-normalized. The
    /// finished map is swapped in atomically, so concurrent `route` calls see
    /// either the old cache or the new one, never a mix.
    pub async fn refresh_centroids(&self, storage: &Arc<dyn VectorStorage>) {
        let namespaces = match storage.list_namespaces().await {
            Ok(ns) => ns,
            Err(e) => {
                tracing::warn!(error = %e, "Failed to list namespaces for centroid refresh");
                return;
            }
        };
        let mut new_cache: HashMap<String, CentroidEntry> = HashMap::new();
        for namespace in &namespaces {
            if !namespace.starts_with("_dakera_agent_") {
                continue;
            }
            // Best-effort: a namespace that fails to load is skipped this
            // cycle instead of aborting the whole refresh.
            let vectors = match storage.get_all(namespace).await {
                Ok(v) => v,
                Err(_) => continue,
            };
            if vectors.is_empty() {
                continue;
            }
            // Reported count covers the whole namespace, not just the sample.
            let count = vectors.len();
            let sample: Vec<&Vec<f32>> = vectors
                .iter()
                .filter(|v| !v.values.is_empty())
                .take(self.config.sample_size)
                .map(|v| &v.values)
                .collect();
            if sample.is_empty() {
                continue;
            }
            // Dimensionality comes from the first sampled vector; vectors of
            // any other length are excluded from the average below.
            let dim = sample[0].len();
            let mut centroid = vec![0.0f32; dim];
            let mut valid = 0usize;
            for embedding in &sample {
                if embedding.len() == dim {
                    for (i, val) in embedding.iter().enumerate() {
                        centroid[i] += val;
                    }
                    valid += 1;
                }
            }
            if valid > 0 {
                for val in &mut centroid {
                    *val /= valid as f32;
                }
                // L2-normalize so cosine comparisons behave; the epsilon
                // guards against a degenerate (near-zero) mean vector.
                let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-8 {
                    for val in &mut centroid {
                        *val /= norm;
                    }
                }
                new_cache.insert(namespace.clone(), CentroidEntry { centroid, count });
            }
        }
        let refreshed_count = new_cache.len();
        // Atomic swap of the whole map under the write lock.
        *self.cache.write() = new_cache;
        tracing::info!(
            namespaces_cached = refreshed_count,
            "Semantic router centroid cache refreshed"
        );
    }

    /// Spawns a background task that performs an initial centroid refresh
    /// shortly after startup and then refreshes every
    /// `config.refresh_interval_secs` seconds.
    pub fn spawn_refresh(
        router: Arc<SemanticRouter>,
        storage: Arc<dyn VectorStorage>,
    ) -> tokio::task::JoinHandle<()> {
        // `tokio::time::interval` panics on a zero period, and the period can
        // come from an env var (`SemanticRouterConfig::from_env`) — clamp it.
        let interval_secs = router.config.refresh_interval_secs.max(1);
        tokio::spawn(async move {
            // Give storage a moment to come up before touching it.
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            let mut interval =
                tokio::time::interval(std::time::Duration::from_secs(interval_secs));
            // A tokio interval's first tick completes immediately, so the
            // first loop iteration performs the initial refresh. (Previously
            // the code also refreshed once *before* the loop, hitting storage
            // twice back-to-back at startup.)
            loop {
                interval.tick().await;
                router.refresh_centroids(&storage).await;
            }
        })
    }
}
/// Coarse query categories used to pick a retrieval strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Short, keyword-style lookup.
    Keyword,
    /// Natural-language question or long query.
    Semantic,
    /// In between: neither clearly keyword nor clearly natural language.
    Hybrid,
}

/// Purely heuristic classifier — buckets a query string by word count and
/// surface structure; no model inference involved.
pub struct QueryClassifier;

impl QueryClassifier {
    /// Classifies `query`: question-like or long queries map to
    /// [`QueryKind::Semantic`], very short ones (≤ 3 words) to
    /// [`QueryKind::Keyword`], and everything else to [`QueryKind::Hybrid`].
    pub fn classify(query: &str) -> QueryKind {
        // Lowercased openers that mark a natural-language question/request.
        const OPENERS: [&str; 9] = [
            "what ", "how ", "why ", "when ", "where ", "who ",
            "tell me", "explain", "describe",
        ];
        let text = query.trim();
        let words = text.split_whitespace().count();
        let sentence_like = text.contains('?') || text.contains('.') || {
            let lowered = text.to_lowercase();
            OPENERS.iter().any(|opener| lowered.starts_with(opener))
        };
        if sentence_like || words >= 8 {
            QueryKind::Semantic
        } else if words <= 3 {
            // Only reachable when `sentence_like` is false, so no extra
            // structure check is needed here.
            QueryKind::Keyword
        } else {
            QueryKind::Hybrid
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// An empty centroid cache must yield no matches at all.
#[test]
fn test_route_empty_cache() {
let router = SemanticRouter::new(SemanticRouterConfig::default());
let results = router.route(&[1.0, 0.0, 0.0], 3, 0.5);
assert!(results.is_empty());
}
// Three seeded centroids, min_similarity 0.0: all pass the filter and the
// exact match ranks strictly above the partial match.
#[test]
fn test_route_with_cached_centroids() {
let router = SemanticRouter::new(SemanticRouterConfig::default());
{
let mut cache = router.cache.write();
cache.insert(
"_dakera_agent_dev".to_string(),
CentroidEntry {
centroid: vec![1.0, 0.0, 0.0],
count: 100,
},
);
cache.insert(
"_dakera_agent_ops".to_string(),
CentroidEntry {
centroid: vec![0.0, 1.0, 0.0],
count: 50,
},
);
cache.insert(
"_dakera_agent_sec".to_string(),
CentroidEntry {
centroid: vec![0.707, 0.707, 0.0],
count: 30,
},
);
}
let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
assert_eq!(results.len(), 3);
assert_eq!(results[0].namespace, "_dakera_agent_dev");
assert!(results[0].similarity > results[1].similarity);
}
// A centroid orthogonal to the query scores ~0 and is dropped by the
// min_similarity = 0.9 threshold; only the aligned centroid survives.
#[test]
fn test_route_min_similarity_filter() {
let router = SemanticRouter::new(SemanticRouterConfig::default());
{
let mut cache = router.cache.write();
cache.insert(
"_dakera_agent_a".to_string(),
CentroidEntry {
centroid: vec![1.0, 0.0, 0.0],
count: 10,
},
);
cache.insert(
"_dakera_agent_b".to_string(),
CentroidEntry {
centroid: vec![0.0, 1.0, 0.0],
count: 10,
},
);
}
let results = router.route(&[1.0, 0.0, 0.0], 5, 0.9);
assert_eq!(results.len(), 1);
assert_eq!(results[0].namespace, "_dakera_agent_a");
}
// Ten eligible candidates with top_k = 3: the result must be truncated.
#[test]
fn test_route_top_k_truncation() {
let router = SemanticRouter::new(SemanticRouterConfig::default());
{
let mut cache = router.cache.write();
for i in 0..10 {
let mut centroid = vec![0.0f32; 3];
centroid[0] = 1.0 - (i as f32 * 0.05);
centroid[1] = i as f32 * 0.05;
// Normalize so each centroid is unit length, as refresh_centroids would.
let norm = (centroid[0] * centroid[0] + centroid[1] * centroid[1]).sqrt();
centroid[0] /= norm;
centroid[1] /= norm;
cache.insert(
format!("_dakera_agent_{}", i),
CentroidEntry {
centroid,
count: 10,
},
);
}
}
let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
assert_eq!(results.len(), 3);
}
// A centroid whose dimensionality differs from the query (5-d vs 3-d) must
// be skipped entirely, not scored.
#[test]
fn test_route_dimension_mismatch_skipped() {
let router = SemanticRouter::new(SemanticRouterConfig::default());
{
let mut cache = router.cache.write();
cache.insert(
"_dakera_agent_3d".to_string(),
CentroidEntry {
centroid: vec![1.0, 0.0, 0.0],
count: 10,
},
);
cache.insert(
"_dakera_agent_5d".to_string(),
CentroidEntry {
centroid: vec![1.0, 0.0, 0.0, 0.0, 0.0],
count: 10,
},
);
}
let results = router.route(&[1.0, 0.0, 0.0], 5, 0.0);
assert_eq!(results.len(), 1);
assert_eq!(results[0].namespace, "_dakera_agent_3d");
}
// Pins the documented default config values.
#[test]
fn test_config_defaults() {
let config = SemanticRouterConfig::default();
assert_eq!(config.sample_size, 20);
assert_eq!(config.refresh_interval_secs, 1800);
}
// Short queries (<= 3 words, no sentence structure) classify as Keyword.
#[test]
fn test_classify_keyword_short() {
assert_eq!(QueryClassifier::classify("rust async"), QueryKind::Keyword);
assert_eq!(QueryClassifier::classify("HNSW"), QueryKind::Keyword);
assert_eq!(
QueryClassifier::classify("memory importance"),
QueryKind::Keyword
);
}
// Long queries (>= 8 words) or question-opener phrasing classify as Semantic.
#[test]
fn test_classify_semantic_long() {
assert_eq!(
QueryClassifier::classify(
"what is the best way to store long term memories in an AI system"
),
QueryKind::Semantic
);
assert_eq!(
QueryClassifier::classify("tell me about the agent memory architecture"),
QueryKind::Semantic
);
}
// A trailing '?' alone is enough to force Semantic classification.
#[test]
fn test_classify_semantic_question_mark() {
assert_eq!(
QueryClassifier::classify("how does HNSW work?"),
QueryKind::Semantic
);
}
// 4-7 words without sentence structure fall into the Hybrid bucket.
#[test]
fn test_classify_hybrid_middle() {
assert_eq!(
QueryClassifier::classify("vector search memory agent"),
QueryKind::Hybrid
);
}
}