1use std::collections::HashMap;
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use storage::VectorStorage;
14
15use crate::distance::calculate_distance;
16use common::DistanceMetric;
17
/// One namespace returned by `SemanticRouter::route`, ranked by centroid similarity.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouteMatch {
    /// Name of the namespace whose cached centroid matched the query.
    pub namespace: String,
    /// Score from `calculate_distance(.., DistanceMetric::Cosine)` between the query
    /// and the namespace centroid; `route` orders results by this value, highest first.
    pub similarity: f32,
    /// Number of vectors the namespace held when its centroid was last computed.
    pub memory_count: usize,
}
25
/// Tuning knobs for `SemanticRouter`.
#[derive(Debug, Clone)]
pub struct SemanticRouterConfig {
    /// Maximum number of non-empty vectors per namespace averaged into its centroid.
    pub sample_size: usize,
    /// Seconds between periodic centroid refreshes (see `SemanticRouter::spawn_refresh`).
    pub refresh_interval_secs: u64,
}

impl SemanticRouterConfig {
    /// Default `sample_size`.
    const DEFAULT_SAMPLE_SIZE: usize = 20;
    /// Default `refresh_interval_secs` (30 minutes).
    const DEFAULT_REFRESH_INTERVAL_SECS: u64 = 1800;

    /// Builds a config from environment variables, falling back to the defaults.
    ///
    /// - `DAKERA_ROUTE_SAMPLE_SIZE` sets `sample_size`
    /// - `DAKERA_ROUTE_REFRESH_SECS` sets `refresh_interval_secs`
    ///
    /// Surrounding whitespace is ignored. A variable that is unset, not a valid
    /// non-negative integer, or zero is ignored and the default is used instead.
    /// Zero is rejected because a zero sample size leaves the router with no
    /// centroids at all, and a zero refresh period makes `tokio::time::interval` panic.
    pub fn from_env() -> Self {
        Self {
            sample_size: Self::nonzero_env("DAKERA_ROUTE_SAMPLE_SIZE")
                .unwrap_or(Self::DEFAULT_SAMPLE_SIZE),
            refresh_interval_secs: Self::nonzero_env("DAKERA_ROUTE_REFRESH_SECS")
                .unwrap_or(Self::DEFAULT_REFRESH_INTERVAL_SECS),
        }
    }

    /// Reads `key` from the environment and parses it as a non-zero number.
    ///
    /// Returns `None` when the variable is missing, unparsable, or equal to zero
    /// (`T::default()`).
    fn nonzero_env<T>(key: &str) -> Option<T>
    where
        T: std::str::FromStr + Default + PartialEq,
    {
        std::env::var(key)
            .ok()?
            .trim()
            .parse::<T>()
            .ok()
            .filter(|value| *value != T::default())
    }
}

impl Default for SemanticRouterConfig {
    fn default() -> Self {
        Self {
            sample_size: Self::DEFAULT_SAMPLE_SIZE,
            refresh_interval_secs: Self::DEFAULT_REFRESH_INTERVAL_SECS,
        }
    }
}
61
/// Cached routing data for one namespace.
#[derive(Clone)]
struct CentroidEntry {
    /// Mean of the sampled (up to `sample_size`) non-empty embeddings that share the
    /// first sample's dimension, L2-normalised unless its norm is at most 1e-8.
    centroid: Vec<f32>,
    /// Total number of vectors in the namespace at the time the centroid was computed.
    count: usize,
}
68
/// Routes a query embedding to the agent namespaces whose centroid is most similar.
///
/// Centroids live in an in-memory cache that `refresh_centroids` rebuilds
/// (periodically, when driven by `spawn_refresh`).
pub struct SemanticRouter {
    /// Sampling and refresh settings.
    config: SemanticRouterConfig,
    /// Namespace name -> cached centroid; replaced wholesale on every refresh.
    cache: RwLock<HashMap<String, CentroidEntry>>,
}
75
76impl SemanticRouter {
77 pub fn new(config: SemanticRouterConfig) -> Self {
78 Self {
79 config,
80 cache: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub fn route(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<RouteMatch> {
89 let cache = self.cache.read();
90 let mut matches: Vec<RouteMatch> = cache
91 .iter()
92 .filter_map(|(ns, entry)| {
93 if entry.centroid.len() != query.len() {
94 return None; }
96 let sim = calculate_distance(query, &entry.centroid, DistanceMetric::Cosine);
97 if sim >= min_similarity {
98 Some(RouteMatch {
99 namespace: ns.clone(),
100 similarity: sim,
101 memory_count: entry.count,
102 })
103 } else {
104 None
105 }
106 })
107 .collect();
108
109 matches.sort_by(|a, b| {
110 b.similarity
111 .partial_cmp(&a.similarity)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 matches.truncate(top_k);
115 matches
116 }
117
118 pub async fn refresh_centroids(&self, storage: &Arc<dyn VectorStorage>) {
123 let namespaces = match storage.list_namespaces().await {
124 Ok(ns) => ns,
125 Err(e) => {
126 tracing::warn!(error = %e, "Failed to list namespaces for centroid refresh");
127 return;
128 }
129 };
130
131 let mut new_cache: HashMap<String, CentroidEntry> = HashMap::new();
132
133 for namespace in &namespaces {
134 if !namespace.starts_with("_dakera_agent_") {
135 continue;
136 }
137
138 let vectors = match storage.get_all(namespace).await {
139 Ok(v) => v,
140 Err(_) => continue,
141 };
142
143 if vectors.is_empty() {
144 continue;
145 }
146
147 let count = vectors.len();
148
149 let sample: Vec<&Vec<f32>> = vectors
151 .iter()
152 .filter(|v| !v.values.is_empty())
153 .take(self.config.sample_size)
154 .map(|v| &v.values)
155 .collect();
156
157 if sample.is_empty() {
158 continue;
159 }
160
161 let dim = sample[0].len();
163 let mut centroid = vec![0.0f32; dim];
164 let mut valid = 0usize;
165 for embedding in &sample {
166 if embedding.len() == dim {
167 for (i, val) in embedding.iter().enumerate() {
168 centroid[i] += val;
169 }
170 valid += 1;
171 }
172 }
173
174 if valid > 0 {
175 for val in &mut centroid {
176 *val /= valid as f32;
177 }
178 let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
180 if norm > 1e-8 {
181 for val in &mut centroid {
182 *val /= norm;
183 }
184 }
185 new_cache.insert(namespace.clone(), CentroidEntry { centroid, count });
186 }
187 }
188
189 let refreshed_count = new_cache.len();
190 *self.cache.write() = new_cache;
191
192 tracing::info!(
193 namespaces_cached = refreshed_count,
194 "Semantic router centroid cache refreshed"
195 );
196 }
197
198 pub fn spawn_refresh(
200 router: Arc<SemanticRouter>,
201 storage: Arc<dyn VectorStorage>,
202 ) -> tokio::task::JoinHandle<()> {
203 let interval_secs = router.config.refresh_interval_secs;
204 tokio::spawn(async move {
205 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
207 router.refresh_centroids(&storage).await;
208
209 let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
210 loop {
211 interval.tick().await;
212 router.refresh_centroids(&storage).await;
213 }
214 })
215 }
216}
217
/// Retrieval category assigned to a free-text query by `QueryClassifier::classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Non-question query of three words or fewer.
    Keyword,
    /// Non-question query of eight or more words, or one containing a `.`.
    Semantic,
    /// A question, or a non-question query of four to seven words.
    Hybrid,
    /// A question about time ("when ...", "what year", "how long ago", ...).
    Temporal,
}

/// Stateless, heuristic classifier that maps a query string to a [`QueryKind`].
pub struct QueryClassifier;

impl QueryClassifier {
    /// Classifies `query` using cheap lexical heuristics.
    ///
    /// Rules are applied in order to the trimmed query (phrase checks are
    /// case-insensitive):
    /// 1. `Temporal` if it starts with "when " or contains a temporal phrase such
    ///    as "what year" or "how long ago". The phrase must begin at a word boundary,
    ///    so "somewhat dated" does not count.
    /// 2. `Hybrid` if it looks like a question: contains `?`, or starts with
    ///    what/how/why/where/who, "tell me", "explain" or "describe".
    /// 3. `Semantic` if it has eight or more words or contains a `.`.
    /// 4. `Keyword` if it has three words or fewer.
    /// 5. `Hybrid` otherwise.
    pub fn classify(query: &str) -> QueryKind {
        const TEMPORAL_PHRASES: &[&str] = &[
            "what year",
            "what date",
            "what time did",
            "what time was",
            "how long ago",
            "how many years",
            "how many months",
            "how many days",
            "since when",
            "at what age",
            "how old was",
            "how old were",
        ];
        // "when " is intentionally absent: such queries are already Temporal.
        const QUESTION_PREFIXES: &[&str] = &[
            "what ", "how ", "why ", "where ", "who ", "tell me", "explain", "describe",
        ];

        let trimmed = query.trim();
        let lower = trimmed.to_lowercase();

        let is_temporal = lower.starts_with("when ")
            || TEMPORAL_PHRASES
                .iter()
                .any(|&phrase| Self::contains_at_word_start(&lower, phrase));
        if is_temporal {
            return QueryKind::Temporal;
        }

        let is_question = trimmed.contains('?')
            || QUESTION_PREFIXES
                .iter()
                .any(|&prefix| lower.starts_with(prefix));
        if is_question {
            return QueryKind::Hybrid;
        }

        let word_count = trimmed.split_whitespace().count();
        if word_count >= 8 || trimmed.contains('.') {
            QueryKind::Semantic
        } else if word_count <= 3 {
            QueryKind::Keyword
        } else {
            QueryKind::Hybrid
        }
    }

    /// Returns `true` if `phrase` occurs in `haystack` either at the very start or
    /// immediately after a non-alphanumeric character (i.e. at a word start).
    fn contains_at_word_start(haystack: &str, phrase: &str) -> bool {
        haystack.match_indices(phrase).any(|(idx, _)| {
            !matches!(haystack[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
        })
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default-configured router whose centroid cache is pre-populated
    /// with the given `(namespace, centroid, memory_count)` entries.
    fn seeded_router<S, I>(entries: I) -> SemanticRouter
    where
        S: Into<String>,
        I: IntoIterator<Item = (S, Vec<f32>, usize)>,
    {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        {
            let mut cache = router.cache.write();
            for (namespace, centroid, count) in entries {
                cache.insert(namespace.into(), CentroidEntry { centroid, count });
            }
        }
        router
    }

    /// Asserts that every query in `cases` is classified as its paired kind.
    fn assert_classified(cases: &[(&str, QueryKind)]) {
        for &(query, expected) in cases {
            assert_eq!(QueryClassifier::classify(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn test_route_empty_cache() {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        assert!(router.route(&[1.0, 0.0, 0.0], 3, 0.5).is_empty());
    }

    #[test]
    fn test_route_with_cached_centroids() {
        let router = seeded_router(vec![
            ("_dakera_agent_dev", vec![1.0, 0.0, 0.0], 100),
            ("_dakera_agent_ops", vec![0.0, 1.0, 0.0], 50),
            ("_dakera_agent_sec", vec![0.707, 0.707, 0.0], 30),
        ]);

        let ranked = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].namespace, "_dakera_agent_dev");
        assert!(ranked[0].similarity > ranked[1].similarity);
    }

    #[test]
    fn test_route_min_similarity_filter() {
        let router = seeded_router(vec![
            ("_dakera_agent_a", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_b", vec![0.0, 1.0, 0.0], 10),
        ]);

        let ranked = router.route(&[1.0, 0.0, 0.0], 5, 0.9);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].namespace, "_dakera_agent_a");
    }

    #[test]
    fn test_route_top_k_truncation() {
        let router = seeded_router((0..10).map(|i| {
            let tilt = i as f32 * 0.05;
            let (x, y) = (1.0 - tilt, tilt);
            let norm = (x * x + y * y).sqrt();
            (format!("_dakera_agent_{}", i), vec![x / norm, y / norm, 0.0], 10)
        }));

        assert_eq!(router.route(&[1.0, 0.0, 0.0], 3, 0.0).len(), 3);
    }

    #[test]
    fn test_route_dimension_mismatch_skipped() {
        let router = seeded_router(vec![
            ("_dakera_agent_3d", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_5d", vec![1.0, 0.0, 0.0, 0.0, 0.0], 10),
        ]);

        let ranked = router.route(&[1.0, 0.0, 0.0], 5, 0.0);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].namespace, "_dakera_agent_3d");
    }

    #[test]
    fn test_config_defaults() {
        let defaults = SemanticRouterConfig::default();
        assert_eq!(defaults.sample_size, 20);
        assert_eq!(defaults.refresh_interval_secs, 1800);
    }

    #[test]
    fn test_classify_keyword_short() {
        assert_classified(&[
            ("rust async", QueryKind::Keyword),
            ("HNSW", QueryKind::Keyword),
            ("memory importance", QueryKind::Keyword),
        ]);
    }

    #[test]
    fn test_classify_question_routes_hybrid() {
        assert_classified(&[
            (
                "what is the best way to store long term memories in an AI system",
                QueryKind::Hybrid,
            ),
            ("tell me about the agent memory architecture", QueryKind::Hybrid),
            ("how does HNSW work?", QueryKind::Hybrid),
            ("What sport did Sarah's brother play in high school?", QueryKind::Hybrid),
        ]);
    }

    #[test]
    fn test_classify_semantic_long_prose() {
        assert_classified(&[(
            "the agent memory platform stores embeddings with adaptive decay weighting",
            QueryKind::Semantic,
        )]);
    }

    #[test]
    fn test_classify_hybrid_middle() {
        assert_classified(&[("vector search memory agent", QueryKind::Hybrid)]);
    }

    #[test]
    fn test_classify_temporal_when_prefix() {
        assert_classified(&[
            ("when did Caroline go to the store?", QueryKind::Temporal),
            ("When was the last time they spoke?", QueryKind::Temporal),
            ("When were the siblings born?", QueryKind::Temporal),
        ]);
    }

    #[test]
    fn test_classify_temporal_date_year_patterns() {
        assert_classified(&[
            ("What year did they get married?", QueryKind::Temporal),
            ("what date did the conference take place?", QueryKind::Temporal),
            ("What time did the meeting start?", QueryKind::Temporal),
            ("How long ago did this happen?", QueryKind::Temporal),
            ("How many years have they been friends?", QueryKind::Temporal),
            ("How old was Sarah when she graduated?", QueryKind::Temporal),
        ]);
    }

    #[test]
    fn test_classify_temporal_does_not_capture_non_temporal_what() {
        assert_classified(&[
            ("What sport did Sarah's brother play in high school?", QueryKind::Hybrid),
            ("what is the best way to find old memories", QueryKind::Hybrid),
        ]);
    }
}