1use std::collections::HashMap;
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use storage::VectorStorage;
14
15use crate::distance::calculate_distance;
16use common::DistanceMetric;
17
18#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
20pub struct RouteMatch {
21 pub namespace: String,
22 pub similarity: f32,
23 pub memory_count: usize,
24}
25
/// Tuning knobs for the semantic router.
#[derive(Debug, Clone)]
pub struct SemanticRouterConfig {
    /// Maximum number of embeddings per namespace that are averaged into its centroid.
    pub sample_size: usize,
    /// Number of seconds between two background centroid refreshes.
    pub refresh_interval_secs: u64,
}

impl Default for SemanticRouterConfig {
    /// Returns the built-in configuration: 20 sampled embeddings per namespace
    /// and a refresh every 1800 seconds.
    fn default() -> Self {
        Self {
            sample_size: 20,
            refresh_interval_secs: 1800,
        }
    }
}

impl SemanticRouterConfig {
    /// Builds a configuration from the process environment.
    ///
    /// Reads `DAKERA_ROUTE_SAMPLE_SIZE` and `DAKERA_ROUTE_REFRESH_SECS`. A variable
    /// that is unset, is not valid Unicode, or does not parse as an unsigned
    /// integer falls back to the corresponding [`Default`] value. Whitespace
    /// around the value is ignored.
    pub fn from_env() -> Self {
        // `Default` is the single source of truth for the fallback values.
        let defaults = Self::default();
        Self {
            sample_size: Self::env_or("DAKERA_ROUTE_SAMPLE_SIZE", defaults.sample_size),
            refresh_interval_secs: Self::env_or(
                "DAKERA_ROUTE_REFRESH_SECS",
                defaults.refresh_interval_secs,
            ),
        }
    }

    /// Parses the environment variable `key`, returning `fallback` when it is
    /// missing or malformed.
    fn env_or<T: std::str::FromStr>(key: &str, fallback: T) -> T {
        std::env::var(key)
            .ok()
            // Values coming from env files often carry stray spaces or newlines.
            .and_then(|raw| raw.trim().parse().ok())
            .unwrap_or(fallback)
    }
}
61
/// Cached routing data for a single namespace.
#[derive(Debug, Clone)]
struct CentroidEntry {
    /// Mean of the sampled embeddings, scaled to unit length when its norm is
    /// large enough to divide by.
    centroid: Vec<f32>,
    /// Total number of vectors the namespace held when the centroid was computed.
    count: usize,
}
68
69pub struct SemanticRouter {
71 config: SemanticRouterConfig,
72 cache: RwLock<HashMap<String, CentroidEntry>>,
74}
75
76impl SemanticRouter {
77 pub fn new(config: SemanticRouterConfig) -> Self {
78 Self {
79 config,
80 cache: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub fn route(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<RouteMatch> {
89 let cache = self.cache.read();
90 let mut matches: Vec<RouteMatch> = cache
91 .iter()
92 .filter_map(|(ns, entry)| {
93 if entry.centroid.len() != query.len() {
94 return None; }
96 let sim = calculate_distance(query, &entry.centroid, DistanceMetric::Cosine);
97 if sim >= min_similarity {
98 Some(RouteMatch {
99 namespace: ns.clone(),
100 similarity: sim,
101 memory_count: entry.count,
102 })
103 } else {
104 None
105 }
106 })
107 .collect();
108
109 matches.sort_by(|a, b| {
110 b.similarity
111 .partial_cmp(&a.similarity)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 matches.truncate(top_k);
115 matches
116 }
117
118 pub async fn refresh_centroids(&self, storage: &Arc<dyn VectorStorage>) {
123 let namespaces = match storage.list_namespaces().await {
124 Ok(ns) => ns,
125 Err(e) => {
126 tracing::warn!(error = %e, "Failed to list namespaces for centroid refresh");
127 return;
128 }
129 };
130
131 let mut new_cache: HashMap<String, CentroidEntry> = HashMap::new();
132
133 for namespace in &namespaces {
134 if !namespace.starts_with("_dakera_agent_") {
135 continue;
136 }
137
138 let vectors = match storage.get_all(namespace).await {
139 Ok(v) => v,
140 Err(_) => continue,
141 };
142
143 if vectors.is_empty() {
144 continue;
145 }
146
147 let count = vectors.len();
148
149 let sample: Vec<&Vec<f32>> = vectors
151 .iter()
152 .filter(|v| !v.values.is_empty())
153 .take(self.config.sample_size)
154 .map(|v| &v.values)
155 .collect();
156
157 if sample.is_empty() {
158 continue;
159 }
160
161 let dim = sample[0].len();
163 let mut centroid = vec![0.0f32; dim];
164 let mut valid = 0usize;
165 for embedding in &sample {
166 if embedding.len() == dim {
167 for (i, val) in embedding.iter().enumerate() {
168 centroid[i] += val;
169 }
170 valid += 1;
171 }
172 }
173
174 if valid > 0 {
175 for val in &mut centroid {
176 *val /= valid as f32;
177 }
178 let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
180 if norm > 1e-8 {
181 for val in &mut centroid {
182 *val /= norm;
183 }
184 }
185 new_cache.insert(namespace.clone(), CentroidEntry { centroid, count });
186 }
187 }
188
189 let refreshed_count = new_cache.len();
190 *self.cache.write() = new_cache;
191
192 tracing::info!(
193 namespaces_cached = refreshed_count,
194 "Semantic router centroid cache refreshed"
195 );
196 }
197
198 pub fn spawn_refresh(
200 router: Arc<SemanticRouter>,
201 storage: Arc<dyn VectorStorage>,
202 ) -> tokio::task::JoinHandle<()> {
203 let interval_secs = router.config.refresh_interval_secs;
204 tokio::spawn(async move {
205 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
207 router.refresh_centroids(&storage).await;
208
209 let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
210 loop {
211 interval.tick().await;
212 router.refresh_centroids(&storage).await;
213 }
214 })
215 }
216}
217
/// Category assigned to a query by [`QueryClassifier::classify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Three words or fewer, with no question, temporal, or multi-hop cue.
    Keyword,
    /// Non-question prose of eight or more words, or containing a `.`.
    Semantic,
    /// Questions, and everything not claimed by another variant.
    Hybrid,
    /// The query asks about a point in time or a duration.
    Temporal,
    /// The query ties one event to another ("after she ...", "as a result of ...").
    MultiHop,
}

/// Stateless, heuristic query classifier.
pub struct QueryClassifier;

impl QueryClassifier {
    /// Phrases that mark a temporal query when found at a word boundary.
    const TEMPORAL_PHRASES: &'static [&'static str] = &[
        "what year",
        "what date",
        "what time did",
        "what time was",
        "how long",
        "how soon",
        "how many years",
        "how many months",
        "how many weeks",
        "how many days",
        "how many hours",
        "how many minutes",
        "since when",
        "at what age",
        "how old was",
        "how old were",
    ];

    /// Phrases that mark a multi-hop query when found at a word boundary.
    const MULTI_HOP_PHRASES: &'static [&'static str] = &[
        "as a result of",
        "as a consequence",
        "after she ",
        "after he ",
        "after they ",
        "after it ",
        "after we ",
        "after her ",
        "after his ",
        "after their ",
        "once she ",
        "once he ",
        "once they ",
        "following the ",
        "following her ",
        "following his ",
        "following their ",
        "following a ",
    ];

    /// Leading words that mark a query as a question.
    const QUESTION_PREFIXES: &'static [&'static str] = &[
        "what ", "how ", "why ", "where ", "who ", "tell me", "explain", "describe",
    ];

    /// Classifies `query` using case-insensitive textual cues.
    ///
    /// Rules are applied in this order; the first that matches wins:
    /// 1. Starts with `"when "` or contains a temporal phrase -> [`QueryKind::Temporal`].
    /// 2. Contains a multi-hop phrase -> [`QueryKind::MultiHop`].
    /// 3. Contains `?` or starts with a question word -> [`QueryKind::Hybrid`].
    /// 4. Eight or more words, or contains `.` -> [`QueryKind::Semantic`].
    /// 5. Three words or fewer (including the empty string) -> [`QueryKind::Keyword`].
    /// 6. Anything else -> [`QueryKind::Hybrid`].
    ///
    /// Phrases only count when they begin at a word boundary, so
    /// "show long running jobs" is not mistaken for a "how long" question.
    pub fn classify(query: &str) -> QueryKind {
        let trimmed = query.trim();
        let lower = trimmed.to_lowercase();
        let mentions_any =
            |phrases: &[&str]| phrases.iter().any(|&phrase| Self::has_phrase(&lower, phrase));

        // Temporal is checked first so that "how long after she ..." stays
        // temporal even though it also contains a multi-hop marker.
        if lower.starts_with("when ") || mentions_any(Self::TEMPORAL_PHRASES) {
            return QueryKind::Temporal;
        }

        if mentions_any(Self::MULTI_HOP_PHRASES) {
            return QueryKind::MultiHop;
        }

        let is_question = trimmed.contains('?')
            || Self::QUESTION_PREFIXES
                .iter()
                .any(|&prefix| lower.starts_with(prefix));
        if is_question {
            return QueryKind::Hybrid;
        }

        let word_count = trimmed.split_whitespace().count();
        if word_count >= 8 || trimmed.contains('.') {
            QueryKind::Semantic
        } else if word_count <= 3 {
            QueryKind::Keyword
        } else {
            QueryKind::Hybrid
        }
    }

    /// Returns `true` when `phrase` occurs in `text` at the very start or right
    /// after a non-alphanumeric character.
    fn has_phrase(text: &str, phrase: &str) -> bool {
        text.match_indices(phrase).any(|(start, _)| {
            text[..start]
                .chars()
                .next_back()
                .map_or(true, |prev| !prev.is_alphanumeric())
        })
    }
}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Query embedding shared by the routing tests.
    const QUERY: [f32; 3] = [1.0, 0.0, 0.0];

    /// Builds a default-configured router whose cache holds the given
    /// `(namespace, centroid, count)` entries.
    fn router_with(entries: Vec<(&str, Vec<f32>, usize)>) -> SemanticRouter {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        {
            // Scope the write guard so it is released before the router is used.
            let mut cache = router.cache.write();
            for (namespace, centroid, count) in entries {
                cache.insert(namespace.to_string(), CentroidEntry { centroid, count });
            }
        }
        router
    }

    /// Asserts that every query in `queries` is classified as `expected`.
    #[track_caller]
    fn assert_kind(expected: QueryKind, queries: &[&str]) {
        for query in queries {
            assert_eq!(QueryClassifier::classify(query), expected);
        }
    }

    #[test]
    fn test_route_empty_cache() {
        let router = router_with(Vec::new());
        assert!(router.route(&QUERY, 3, 0.5).is_empty());
    }

    #[test]
    fn test_route_with_cached_centroids() {
        let router = router_with(vec![
            ("_dakera_agent_dev", vec![1.0, 0.0, 0.0], 100),
            ("_dakera_agent_ops", vec![0.0, 1.0, 0.0], 50),
            ("_dakera_agent_sec", vec![0.707, 0.707, 0.0], 30),
        ]);

        let results = router.route(&QUERY, 3, 0.0);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].namespace, "_dakera_agent_dev");
        assert!(results[0].similarity > results[1].similarity);
    }

    #[test]
    fn test_route_min_similarity_filter() {
        let router = router_with(vec![
            ("_dakera_agent_a", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_b", vec![0.0, 1.0, 0.0], 10),
        ]);

        let results = router.route(&QUERY, 5, 0.9);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_a");
    }

    #[test]
    fn test_route_top_k_truncation() {
        let router = SemanticRouter::new(SemanticRouterConfig::default());

        {
            let mut cache = router.cache.write();
            for i in 0..10 {
                // Unit vectors fanning away from the x axis.
                let step = i as f32 * 0.05;
                let mut centroid = vec![0.0f32; 3];
                centroid[0] = 1.0 - step;
                centroid[1] = step;
                let norm = (centroid[0] * centroid[0] + centroid[1] * centroid[1]).sqrt();
                centroid[0] /= norm;
                centroid[1] /= norm;
                cache.insert(
                    format!("_dakera_agent_{}", i),
                    CentroidEntry {
                        centroid,
                        count: 10,
                    },
                );
            }
        }

        assert_eq!(router.route(&QUERY, 3, 0.0).len(), 3);
    }

    #[test]
    fn test_route_dimension_mismatch_skipped() {
        let router = router_with(vec![
            ("_dakera_agent_3d", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_5d", vec![1.0, 0.0, 0.0, 0.0, 0.0], 10),
        ]);

        let results = router.route(&QUERY, 5, 0.0);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_3d");
    }

    #[test]
    fn test_config_defaults() {
        let SemanticRouterConfig {
            sample_size,
            refresh_interval_secs,
        } = SemanticRouterConfig::default();
        assert_eq!(sample_size, 20);
        assert_eq!(refresh_interval_secs, 1800);
    }

    #[test]
    fn test_classify_keyword_short() {
        assert_kind(
            QueryKind::Keyword,
            &["rust async", "HNSW", "memory importance"],
        );
    }

    #[test]
    fn test_classify_question_routes_hybrid() {
        assert_kind(
            QueryKind::Hybrid,
            &[
                "what is the best way to store long term memories in an AI system",
                "tell me about the agent memory architecture",
                "how does HNSW work?",
                "What sport did Sarah's brother play in high school?",
            ],
        );
    }

    #[test]
    fn test_classify_semantic_long_prose() {
        assert_kind(
            QueryKind::Semantic,
            &["the agent memory platform stores embeddings with adaptive decay weighting"],
        );
    }

    #[test]
    fn test_classify_hybrid_middle() {
        assert_kind(QueryKind::Hybrid, &["vector search memory agent"]);
    }

    #[test]
    fn test_classify_temporal_when_prefix() {
        assert_kind(
            QueryKind::Temporal,
            &[
                "when did Caroline go to the store?",
                "When was the last time they spoke?",
                "When were the siblings born?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_date_year_patterns() {
        assert_kind(
            QueryKind::Temporal,
            &[
                "What year did they get married?",
                "what date did the conference take place?",
                "What time did the meeting start?",
                "How long ago did this happen?",
                "How many years have they been friends?",
                "How old was Sarah when she graduated?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_does_not_capture_non_temporal_what() {
        assert_kind(
            QueryKind::Hybrid,
            &[
                "What sport did Sarah's brother play in high school?",
                "what is the best way to find old memories",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_pronoun_after_marker() {
        assert_kind(
            QueryKind::MultiHop,
            &[
                "What did Sarah do after she got married?",
                "Where did they move after they sold the house?",
                "What happened after he graduated from college?",
                "What did Alice do once she moved to the new city?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_causative_phrases() {
        assert_kind(
            QueryKind::MultiHop,
            &[
                "What changed as a result of their decision?",
                "What happened as a consequence of the accident?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_following_structural() {
        assert_kind(
            QueryKind::MultiHop,
            &[
                "What did Bob do following the promotion?",
                "Where did they live following her diagnosis?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_bare_markers_do_not_fire() {
        // Bare "after"/"before"/"once" without a pronoun must not be multi-hop.
        let cases = [
            ("What did Sarah do after school?", QueryKind::Hybrid),
            ("What happened before the wedding?", QueryKind::Hybrid),
            ("Since when did they live there?", QueryKind::Temporal),
            ("How did they feel once settled?", QueryKind::Hybrid),
        ];
        for (query, expected) in cases.iter() {
            assert_eq!(QueryClassifier::classify(query), *expected);
        }
    }

    #[test]
    fn test_classify_multihop_does_not_interfere_with_temporal() {
        assert_kind(
            QueryKind::Temporal,
            &["when did she move after he graduated?"],
        );
    }

    #[test]
    fn test_classify_temporal_how_long_patterns() {
        assert_kind(
            QueryKind::Temporal,
            &[
                "How long after she moved did they get married?",
                "How long after he graduated did she find a job?",
                "How long before the wedding did they meet?",
                "How long did the relationship last?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_how_soon_patterns() {
        assert_kind(
            QueryKind::Temporal,
            &[
                "How soon after she started the new job did they move?",
                "How soon did they get back together?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_new_time_units() {
        assert_kind(
            QueryKind::Temporal,
            &[
                "How many weeks after the move did they settle in?",
                "How many hours did the procedure take?",
                "How many minutes before the event did she arrive?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_how_long_beats_multihop() {
        let queries = [
            "How long after she started did he propose?",
            "How long after they moved following the promotion did he get promoted again?",
        ];
        for query in queries.iter() {
            assert_ne!(QueryClassifier::classify(query), QueryKind::MultiHop);
        }
    }
}