1use std::collections::HashMap;
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use storage::VectorStorage;
14
15use crate::distance::calculate_distance;
16use common::DistanceMetric;
17
18#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
20pub struct RouteMatch {
21 pub namespace: String,
22 pub similarity: f32,
23 pub memory_count: usize,
24}
25
/// Tunable settings for [`SemanticRouter`].
///
/// `Debug` and `Clone` are derived so the configuration can be logged and
/// shared between components without hand-written boilerplate.
#[derive(Debug, Clone)]
pub struct SemanticRouterConfig {
    /// Upper bound on the number of non-empty embeddings averaged per namespace
    /// when a centroid is rebuilt.
    pub sample_size: usize,
    /// Seconds between two background centroid refreshes.
    pub refresh_interval_secs: u64,
}
33
34impl Default for SemanticRouterConfig {
35 fn default() -> Self {
36 Self {
37 sample_size: 20,
38 refresh_interval_secs: 1800, }
40 }
41}
42
43impl SemanticRouterConfig {
44 pub fn from_env() -> Self {
45 let sample_size: usize = std::env::var("DAKERA_ROUTE_SAMPLE_SIZE")
46 .ok()
47 .and_then(|v| v.parse().ok())
48 .unwrap_or(20);
49
50 let refresh_interval_secs: u64 = std::env::var("DAKERA_ROUTE_REFRESH_SECS")
51 .ok()
52 .and_then(|v| v.parse().ok())
53 .unwrap_or(1800);
54
55 Self {
56 sample_size,
57 refresh_interval_secs,
58 }
59 }
60}
61
/// Cached per-namespace summary stored in the router's centroid map.
#[derive(Clone)]
struct CentroidEntry {
    /// Averaged embedding for the namespace, as written by the refresh routine.
    centroid: Vec<f32>,
    /// Number of vectors the namespace contained when the entry was built.
    count: usize,
}
68
69pub struct SemanticRouter {
71 config: SemanticRouterConfig,
72 cache: RwLock<HashMap<String, CentroidEntry>>,
74}
75
76impl SemanticRouter {
77 pub fn new(config: SemanticRouterConfig) -> Self {
78 Self {
79 config,
80 cache: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub fn route(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<RouteMatch> {
89 let cache = self.cache.read();
90 let mut matches: Vec<RouteMatch> = cache
91 .iter()
92 .filter_map(|(ns, entry)| {
93 if entry.centroid.len() != query.len() {
94 return None; }
96 let sim = calculate_distance(query, &entry.centroid, DistanceMetric::Cosine);
97 if sim >= min_similarity {
98 Some(RouteMatch {
99 namespace: ns.clone(),
100 similarity: sim,
101 memory_count: entry.count,
102 })
103 } else {
104 None
105 }
106 })
107 .collect();
108
109 matches.sort_by(|a, b| {
110 b.similarity
111 .partial_cmp(&a.similarity)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 matches.truncate(top_k);
115 matches
116 }
117
118 pub async fn refresh_centroids(&self, storage: &Arc<dyn VectorStorage>) {
123 let namespaces = match storage.list_namespaces().await {
124 Ok(ns) => ns,
125 Err(e) => {
126 tracing::warn!(error = %e, "Failed to list namespaces for centroid refresh");
127 return;
128 }
129 };
130
131 let mut new_cache: HashMap<String, CentroidEntry> = HashMap::new();
132
133 for namespace in &namespaces {
134 if !namespace.starts_with("_dakera_agent_") {
135 continue;
136 }
137
138 let vectors = match storage.get_all(namespace).await {
139 Ok(v) => v,
140 Err(_) => continue,
141 };
142
143 if vectors.is_empty() {
144 continue;
145 }
146
147 let count = vectors.len();
148
149 let sample: Vec<&Vec<f32>> = vectors
151 .iter()
152 .filter(|v| !v.values.is_empty())
153 .take(self.config.sample_size)
154 .map(|v| &v.values)
155 .collect();
156
157 if sample.is_empty() {
158 continue;
159 }
160
161 let dim = sample[0].len();
163 let mut centroid = vec![0.0f32; dim];
164 let mut valid = 0usize;
165 for embedding in &sample {
166 if embedding.len() == dim {
167 for (i, val) in embedding.iter().enumerate() {
168 centroid[i] += val;
169 }
170 valid += 1;
171 }
172 }
173
174 if valid > 0 {
175 for val in &mut centroid {
176 *val /= valid as f32;
177 }
178 let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
180 if norm > 1e-8 {
181 for val in &mut centroid {
182 *val /= norm;
183 }
184 }
185 new_cache.insert(namespace.clone(), CentroidEntry { centroid, count });
186 }
187 }
188
189 let refreshed_count = new_cache.len();
190 *self.cache.write() = new_cache;
191
192 tracing::info!(
193 namespaces_cached = refreshed_count,
194 "Semantic router centroid cache refreshed"
195 );
196 }
197
198 pub fn spawn_refresh(
200 router: Arc<SemanticRouter>,
201 storage: Arc<dyn VectorStorage>,
202 ) -> tokio::task::JoinHandle<()> {
203 let interval_secs = router.config.refresh_interval_secs;
204 tokio::spawn(async move {
205 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
207 router.refresh_centroids(&storage).await;
208
209 let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
210 loop {
211 interval.tick().await;
212 router.refresh_centroids(&storage).await;
213 }
214 })
215 }
216}
217
/// Retrieval strategy selected for a query by [`QueryClassifier::classify`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryKind {
    /// Short non-question query (three words or fewer, no period).
    Keyword,
    /// Non-question query with eight or more words, or containing a period.
    Semantic,
    /// Question-like query, or a mid-length query matching no other category.
    Hybrid,
    /// Query starting with "when " or containing a time-related phrase such as
    /// "what year" or "how long ago".
    Temporal,
    /// Query containing a causal or sequential connector such as
    /// "as a result of" or "after she ".
    MultiHop,
}
250
251pub struct QueryClassifier;
254
255impl QueryClassifier {
256 pub fn classify(query: &str) -> QueryKind {
269 let trimmed = query.trim();
270 let word_count = trimmed.split_whitespace().count();
271 let lower = trimmed.to_lowercase();
272
273 let is_temporal = lower.starts_with("when ")
293 || lower.starts_with("when did")
294 || lower.starts_with("when was")
295 || lower.starts_with("when were")
296 || lower.starts_with("when is")
297 || lower.contains("what year")
298 || lower.contains("what date")
299 || lower.contains("what time did")
300 || lower.contains("what time was")
301 || lower.contains("how long ago")
302 || lower.contains("how long after ") || lower.contains("how long before ") || lower.contains("how long since ") || lower.contains("how soon after ") || lower.contains("how soon before ") || lower.contains("how many years")
308 || lower.contains("how many months")
309 || lower.contains("how many weeks") || lower.contains("how many days")
311 || lower.contains("how many hours") || lower.contains("how many minutes") || lower.contains("since when")
314 || lower.contains("at what age")
315 || lower.contains("how old was")
316 || lower.contains("how old were");
317
318 if is_temporal {
319 return QueryKind::Temporal;
320 }
321
322 let is_multi_hop = lower.contains("as a result of")
331 || lower.contains("as a consequence of")
332 || lower.contains("as a consequence")
333 || lower.contains("after she ")
335 || lower.contains("after he ")
336 || lower.contains("after they ")
337 || lower.contains("after it ")
338 || lower.contains("after we ")
339 || lower.contains("after her ")
340 || lower.contains("after his ")
341 || lower.contains("after their ")
342 || lower.contains("once she ")
344 || lower.contains("once he ")
345 || lower.contains("once they ")
346 || lower.contains("following the ")
348 || lower.contains("following her ")
349 || lower.contains("following his ")
350 || lower.contains("following their ")
351 || lower.contains("following a ");
352
353 if is_multi_hop {
354 return QueryKind::MultiHop;
355 }
356
357 let is_question = trimmed.contains('?')
360 || lower.starts_with("what ")
361 || lower.starts_with("how ")
362 || lower.starts_with("why ")
363 || lower.starts_with("when ")
364 || lower.starts_with("where ")
365 || lower.starts_with("who ")
366 || lower.starts_with("tell me")
367 || lower.starts_with("explain")
368 || lower.starts_with("describe");
369
370 if is_question {
371 QueryKind::Hybrid
372 } else if word_count >= 8 || trimmed.contains('.') {
373 QueryKind::Semantic
374 } else if word_count <= 3 {
375 QueryKind::Keyword
376 } else {
377 QueryKind::Hybrid
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router with default settings whose cache holds exactly the
    /// given `(namespace, centroid, count)` entries.
    fn router_with<N: Into<String>>(entries: Vec<(N, Vec<f32>, usize)>) -> SemanticRouter {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        {
            let mut cache = router.cache.write();
            for (namespace, centroid, count) in entries {
                cache.insert(namespace.into(), CentroidEntry { centroid, count });
            }
        }
        router
    }

    /// Asserts that every query in `queries` is classified as `expected`.
    fn assert_all(expected: QueryKind, queries: &[&str]) {
        for &query in queries {
            assert_eq!(
                QueryClassifier::classify(query),
                expected,
                "query: {:?}",
                query
            );
        }
    }

    /// Asserts that no query in `queries` is classified as `rejected`.
    fn assert_none(rejected: QueryKind, queries: &[&str]) {
        for &query in queries {
            assert_ne!(
                QueryClassifier::classify(query),
                rejected,
                "query: {:?}",
                query
            );
        }
    }

    // ---- SemanticRouter::route ------------------------------------------

    #[test]
    fn test_route_empty_cache() {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        let results = router.route(&[1.0, 0.0, 0.0], 3, 0.5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_route_with_cached_centroids() {
        let router = router_with(vec![
            ("_dakera_agent_dev", vec![1.0, 0.0, 0.0], 100),
            ("_dakera_agent_ops", vec![0.0, 1.0, 0.0], 50),
            ("_dakera_agent_sec", vec![0.707, 0.707, 0.0], 30),
        ]);

        let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].namespace, "_dakera_agent_dev");
        assert!(results[0].similarity > results[1].similarity);
    }

    #[test]
    fn test_route_min_similarity_filter() {
        let router = router_with(vec![
            ("_dakera_agent_a", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_b", vec![0.0, 1.0, 0.0], 10),
        ]);

        let results = router.route(&[1.0, 0.0, 0.0], 5, 0.9);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_a");
    }

    #[test]
    fn test_route_top_k_truncation() {
        // Ten unit vectors fanning out from the x axis towards the y axis.
        let entries: Vec<(String, Vec<f32>, usize)> = (0..10)
            .map(|i| {
                let x = 1.0 - (i as f32 * 0.05);
                let y = i as f32 * 0.05;
                let norm = (x * x + y * y).sqrt();
                (
                    format!("_dakera_agent_{}", i),
                    vec![x / norm, y / norm, 0.0f32],
                    10,
                )
            })
            .collect();
        let router = router_with(entries);

        let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_route_dimension_mismatch_skipped() {
        let router = router_with(vec![
            ("_dakera_agent_3d", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_5d", vec![1.0, 0.0, 0.0, 0.0, 0.0], 10),
        ]);

        let results = router.route(&[1.0, 0.0, 0.0], 5, 0.0);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_3d");
    }

    #[test]
    fn test_config_defaults() {
        let config = SemanticRouterConfig::default();
        assert_eq!(config.sample_size, 20);
        assert_eq!(config.refresh_interval_secs, 1800);
    }

    // ---- QueryClassifier: keyword / semantic / hybrid ---------------------

    #[test]
    fn test_classify_keyword_short() {
        assert_all(
            QueryKind::Keyword,
            &["rust async", "HNSW", "memory importance"],
        );
    }

    #[test]
    fn test_classify_question_routes_hybrid() {
        assert_all(
            QueryKind::Hybrid,
            &[
                "what is the best way to store long term memories in an AI system",
                "tell me about the agent memory architecture",
                "how does HNSW work?",
                "What sport did Sarah's brother play in high school?",
            ],
        );
    }

    #[test]
    fn test_classify_semantic_long_prose() {
        assert_all(
            QueryKind::Semantic,
            &["the agent memory platform stores embeddings with adaptive decay weighting"],
        );
    }

    #[test]
    fn test_classify_hybrid_middle() {
        assert_all(QueryKind::Hybrid, &["vector search memory agent"]);
    }

    // ---- QueryClassifier: temporal ----------------------------------------

    #[test]
    fn test_classify_temporal_when_prefix() {
        assert_all(
            QueryKind::Temporal,
            &[
                "when did Caroline go to the store?",
                "When was the last time they spoke?",
                "When were the siblings born?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_date_year_patterns() {
        assert_all(
            QueryKind::Temporal,
            &[
                "What year did they get married?",
                "what date did the conference take place?",
                "What time did the meeting start?",
                "How long ago did this happen?",
                "How many years have they been friends?",
                "How old was Sarah when she graduated?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_does_not_capture_non_temporal_what() {
        assert_all(
            QueryKind::Hybrid,
            &[
                "What sport did Sarah's brother play in high school?",
                "what is the best way to find old memories",
            ],
        );
    }

    // ---- QueryClassifier: multi-hop ---------------------------------------

    #[test]
    fn test_classify_multihop_pronoun_after_marker() {
        assert_all(
            QueryKind::MultiHop,
            &[
                "What did Sarah do after she got married?",
                "Where did they move after they sold the house?",
                "What happened after he graduated from college?",
                "What did Alice do once she moved to the new city?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_causative_phrases() {
        assert_all(
            QueryKind::MultiHop,
            &[
                "What changed as a result of their decision?",
                "What happened as a consequence of the accident?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_following_structural() {
        assert_all(
            QueryKind::MultiHop,
            &[
                "What did Bob do following the promotion?",
                "Where did they live following her diagnosis?",
            ],
        );
    }

    #[test]
    fn test_classify_multihop_bare_markers_do_not_fire() {
        let cases = [
            ("What did Sarah do after school?", QueryKind::Hybrid),
            ("What happened before the wedding?", QueryKind::Hybrid),
            ("Since when did they live there?", QueryKind::Temporal),
            ("How did they feel once settled?", QueryKind::Hybrid),
        ];
        for (query, expected) in cases.iter() {
            assert_all(*expected, &[*query]);
        }
    }

    #[test]
    fn test_classify_multihop_does_not_interfere_with_temporal() {
        assert_all(
            QueryKind::Temporal,
            &["when did she move after he graduated?"],
        );
    }

    // ---- QueryClassifier: "how long" / "how soon" / time units ------------

    #[test]
    fn test_classify_temporal_how_long_patterns() {
        assert_all(
            QueryKind::Temporal,
            &[
                "How long after she moved did they get married?",
                "How long after he graduated did she find a job?",
                "How long before the wedding did they meet?",
                "How long since they got married has she been working?",
            ],
        );
        assert_all(
            QueryKind::Hybrid,
            &[
                "How long did the relationship last?",
                "How long have they been friends?",
                "How long has she been working at the company?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_how_soon_patterns() {
        assert_all(
            QueryKind::Temporal,
            &[
                "How soon after she started the new job did they move?",
                "How soon before the trip did they pack?",
            ],
        );
        assert_all(QueryKind::Hybrid, &["How soon did they get back together?"]);
    }

    #[test]
    fn test_classify_temporal_new_time_units() {
        assert_all(
            QueryKind::Temporal,
            &[
                "How many weeks after the move did they settle in?",
                "How many hours did the procedure take?",
                "How many minutes before the event did she arrive?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_how_long_beats_multihop() {
        assert_none(
            QueryKind::MultiHop,
            &[
                "How long after she started did he propose?",
                "How long after they moved following the promotion did he get promoted again?",
                "How soon after she started the job did they move?",
            ],
        );
    }

    #[test]
    fn test_classify_ce36_cat1_cat2_not_over_captured_by_temporal() {
        assert_all(
            QueryKind::Hybrid,
            &[
                "How long have they been dating?",
                "How long did he live in New York?",
                "How long was the trip they took together?",
            ],
        );
        assert_all(
            QueryKind::MultiHop,
            &["How long did they stay after she got the promotion?"],
        );
        assert_all(QueryKind::Hybrid, &["How soon will they be ready?"]);
    }
}