Skip to main content

engine/
routing.rs

1//! Semantic Routing Engine for Dakera AI Agent Memory Platform.
2//!
3//! Agents query Dakera without knowing which namespace holds the answer.
4//! Dakera figures it out by comparing the query embedding against cached
5//! namespace centroids (averaged embeddings sampled from each namespace).
6//!
7//! The centroid cache is refreshed periodically in the background.
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use storage::VectorStorage;
14
15use crate::distance::calculate_distance;
16use common::DistanceMetric;
17
18/// A route result: which namespace matched and how strongly.
19#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
20pub struct RouteMatch {
21    pub namespace: String,
22    pub similarity: f32,
23    pub memory_count: usize,
24}
25
/// Configuration for the semantic router.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SemanticRouterConfig {
    /// Maximum number of memories to sample per namespace for centroid calculation.
    pub sample_size: usize,
    /// How often to refresh centroids, in seconds.
    pub refresh_interval_secs: u64,
}

impl Default for SemanticRouterConfig {
    /// Built-in defaults: 20 samples per namespace, refreshed every 30 minutes.
    fn default() -> Self {
        Self {
            sample_size: Self::DEFAULT_SAMPLE_SIZE,
            refresh_interval_secs: Self::DEFAULT_REFRESH_INTERVAL_SECS,
        }
    }
}

impl SemanticRouterConfig {
    /// Default for `sample_size`; shared by `default()` and `from_env()`.
    const DEFAULT_SAMPLE_SIZE: usize = 20;
    /// Default for `refresh_interval_secs` (30 minutes).
    const DEFAULT_REFRESH_INTERVAL_SECS: u64 = 1800;

    /// Build a configuration from environment variables.
    ///
    /// * `DAKERA_ROUTE_SAMPLE_SIZE` → `sample_size`
    /// * `DAKERA_ROUTE_REFRESH_SECS` → `refresh_interval_secs`
    ///
    /// A variable that is unset or does not parse as an unsigned integer
    /// falls back to the corresponding [`Default`] value. A refresh interval
    /// of `0` is treated as invalid as well: `tokio::time::interval` panics on
    /// a zero period, so accepting it would crash the refresh task.
    pub fn from_env() -> Self {
        let sample_size =
            Self::parse_env("DAKERA_ROUTE_SAMPLE_SIZE").unwrap_or(Self::DEFAULT_SAMPLE_SIZE);

        let refresh_interval_secs = Self::parse_env::<u64>("DAKERA_ROUTE_REFRESH_SECS")
            .filter(|&secs| secs > 0)
            .unwrap_or(Self::DEFAULT_REFRESH_INTERVAL_SECS);

        Self {
            sample_size,
            refresh_interval_secs,
        }
    }

    /// Read `key` from the environment and parse it, ignoring surrounding
    /// whitespace. Returns `None` when the variable is missing or malformed.
    fn parse_env<T: std::str::FromStr>(key: &str) -> Option<T> {
        std::env::var(key).ok()?.trim().parse().ok()
    }
}
61
/// Per-namespace cache record: the averaged embedding together with the
/// number of vectors counted for that namespace.
#[derive(Clone)]
struct CentroidEntry {
    /// Averaged embedding that `route` compares against the query.
    centroid: Vec<f32>,
    /// Vector count reported back to callers as `RouteMatch::memory_count`.
    count: usize,
}
68
69/// Semantic router that maintains a centroid cache per namespace.
70pub struct SemanticRouter {
71    config: SemanticRouterConfig,
72    /// Namespace → averaged centroid embedding + count
73    cache: RwLock<HashMap<String, CentroidEntry>>,
74}
75
76impl SemanticRouter {
77    pub fn new(config: SemanticRouterConfig) -> Self {
78        Self {
79            config,
80            cache: RwLock::new(HashMap::new()),
81        }
82    }
83
84    /// Route a query embedding to the most relevant namespaces.
85    ///
86    /// Returns namespaces sorted by similarity (descending), filtered
87    /// by `min_similarity`.
88    pub fn route(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<RouteMatch> {
89        let cache = self.cache.read();
90        let mut matches: Vec<RouteMatch> = cache
91            .iter()
92            .filter_map(|(ns, entry)| {
93                if entry.centroid.len() != query.len() {
94                    return None; // dimension mismatch, skip
95                }
96                let sim = calculate_distance(query, &entry.centroid, DistanceMetric::Cosine);
97                if sim >= min_similarity {
98                    Some(RouteMatch {
99                        namespace: ns.clone(),
100                        similarity: sim,
101                        memory_count: entry.count,
102                    })
103                } else {
104                    None
105                }
106            })
107            .collect();
108
109        matches.sort_by(|a, b| {
110            b.similarity
111                .partial_cmp(&a.similarity)
112                .unwrap_or(std::cmp::Ordering::Equal)
113        });
114        matches.truncate(top_k);
115        matches
116    }
117
118    /// Refresh the centroid cache by sampling memories from each agent namespace.
119    ///
120    /// For each `_dakera_agent_*` namespace, sample up to `sample_size` vectors,
121    /// average their embeddings into a single centroid.
122    pub async fn refresh_centroids(&self, storage: &Arc<dyn VectorStorage>) {
123        let namespaces = match storage.list_namespaces().await {
124            Ok(ns) => ns,
125            Err(e) => {
126                tracing::warn!(error = %e, "Failed to list namespaces for centroid refresh");
127                return;
128            }
129        };
130
131        let mut new_cache: HashMap<String, CentroidEntry> = HashMap::new();
132
133        for namespace in &namespaces {
134            if !namespace.starts_with("_dakera_agent_") {
135                continue;
136            }
137
138            let vectors = match storage.get_all(namespace).await {
139                Ok(v) => v,
140                Err(_) => continue,
141            };
142
143            if vectors.is_empty() {
144                continue;
145            }
146
147            let count = vectors.len();
148
149            // Sample up to sample_size vectors (take first N — they're stored in insertion order)
150            let sample: Vec<&Vec<f32>> = vectors
151                .iter()
152                .filter(|v| !v.values.is_empty())
153                .take(self.config.sample_size)
154                .map(|v| &v.values)
155                .collect();
156
157            if sample.is_empty() {
158                continue;
159            }
160
161            // Compute centroid (average embedding)
162            let dim = sample[0].len();
163            let mut centroid = vec![0.0f32; dim];
164            let mut valid = 0usize;
165            for embedding in &sample {
166                if embedding.len() == dim {
167                    for (i, val) in embedding.iter().enumerate() {
168                        centroid[i] += val;
169                    }
170                    valid += 1;
171                }
172            }
173
174            if valid > 0 {
175                for val in &mut centroid {
176                    *val /= valid as f32;
177                }
178                // Normalize centroid for cosine comparison
179                let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
180                if norm > 1e-8 {
181                    for val in &mut centroid {
182                        *val /= norm;
183                    }
184                }
185                new_cache.insert(namespace.clone(), CentroidEntry { centroid, count });
186            }
187        }
188
189        let refreshed_count = new_cache.len();
190        *self.cache.write() = new_cache;
191
192        tracing::info!(
193            namespaces_cached = refreshed_count,
194            "Semantic router centroid cache refreshed"
195        );
196    }
197
198    /// Spawn the centroid refresh as a background tokio task.
199    pub fn spawn_refresh(
200        router: Arc<SemanticRouter>,
201        storage: Arc<dyn VectorStorage>,
202    ) -> tokio::task::JoinHandle<()> {
203        let interval_secs = router.config.refresh_interval_secs;
204        tokio::spawn(async move {
205            // Initial refresh on startup (small delay to let storage warm up)
206            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
207            router.refresh_centroids(&storage).await;
208
209            let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
210            loop {
211                interval.tick().await;
212                router.refresh_centroids(&storage).await;
213            }
214        })
215    }
216}
217
218// ============================================================================
219// CE-12a: Query Classifier for smart routing
220// ============================================================================
221
/// Inferred query kind used for smart routing decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Short / keyword-based query → prefer BM25 full-text search
    Keyword,
    /// Long / natural-language query → prefer vector similarity search
    Semantic,
    /// Mixed signal → hybrid (vector + BM25)
    Hybrid,
    /// Temporal question (when/what year/what date/how long ago) → hybrid
    /// with heavily BM25-weighted fusion (vector_weight=0.20).
    ///
    /// Rationale: date-prefixed memories ("2022-03-15: …") rank BM25 rank-1
    /// via exact token match but score near-zero in vector space. A balanced
    /// 0.50/0.50 split starves the BM25 signal. 0.20 vector / 0.80 BM25
    /// restores temporal rank-1 precision while retaining vector for entity
    /// disambiguation.
    Temporal,
}

/// Heuristic classifier that determines the best retrieval strategy for a
/// free-text query without any model inference.
pub struct QueryClassifier;

impl QueryClassifier {
    /// Phrases that mark a query as temporal wherever they occur (matched
    /// against the lower-cased query).
    const TEMPORAL_PHRASES: &'static [&'static str] = &[
        "what year",
        "what date",
        "what time did",
        "what time was",
        "how long ago",
        "how many years",
        "how many months",
        "how many days",
        "since when",
        "at what age",
        "how old was",
        "how old were",
    ];

    /// Lower-cased prefixes that mark a query as a natural-language question
    /// or request. `"when "` is absent on purpose: it is handled by the
    /// temporal check, which runs first.
    const QUESTION_PREFIXES: &'static [&'static str] = &[
        "what ", "how ", "why ", "where ", "who ", "tell me", "explain", "describe",
    ];

    /// Classify a raw query string into a [`QueryKind`].
    ///
    /// Heuristics (in priority order):
    /// 1. Temporal question (starts with `when `, or contains a phrase such as
    ///    "what year" / "how long ago") → [`QueryKind::Temporal`].
    ///    Routed to hybrid with BM25-heavy weighting (vector_weight=0.20) because
    ///    date-prefixed memories rank BM25 rank-1 but score near-zero in vector space.
    /// 2. Natural language question (has `?` or starts with a question word) → [`QueryKind::Hybrid`].
    ///    BM25 finds exact names/dates while vector captures semantic intent.
    /// 3. ≥ 8 words **or** a word ends in `.` (prose, no question) → [`QueryKind::Semantic`].
    ///    A period inside a token (`config.yaml`, `v1.2`, `3.14`) is not
    ///    sentence punctuation and does not count.
    /// 4. ≤ 3 words with no sentence structure → [`QueryKind::Keyword`]
    /// 5. Everything else → [`QueryKind::Hybrid`]
    ///
    /// Matching is case-insensitive and ignores leading/trailing whitespace.
    /// An empty query has zero words and is therefore classified as `Keyword`.
    pub fn classify(query: &str) -> QueryKind {
        let trimmed = query.trim();
        let lower = trimmed.to_lowercase();

        // Temporal detection must run before the generic question check,
        // otherwise "when …" questions would be routed to balanced Hybrid.
        let is_temporal = lower.starts_with("when ")
            || Self::TEMPORAL_PHRASES
                .iter()
                .any(|&phrase| lower.contains(phrase));
        if is_temporal {
            return QueryKind::Temporal;
        }

        // Natural language questions benefit from both BM25 (named entities,
        // dates) and vector search (semantic meaning).
        let is_question = trimmed.contains('?')
            || Self::QUESTION_PREFIXES
                .iter()
                .any(|&prefix| lower.starts_with(prefix));
        if is_question {
            return QueryKind::Hybrid;
        }

        let word_count = trimmed.split_whitespace().count();
        let has_sentence_period = trimmed.split_whitespace().any(Self::ends_sentence);

        if word_count >= 8 || has_sentence_period {
            QueryKind::Semantic
        } else if word_count <= 3 {
            QueryKind::Keyword
        } else {
            QueryKind::Hybrid
        }
    }

    /// True when `word` ends with a period, optionally followed by closing
    /// quotes or brackets (e.g. `fast.` or `fast."`).
    fn ends_sentence(word: &str) -> bool {
        word.trim_end_matches(|c: char| matches!(c, '"' | '\'' | ')' | ']'))
            .ends_with('.')
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a router with default config whose cache is pre-seeded with the
    /// given `(namespace, centroid, count)` entries.
    fn router_with<N: Into<String>>(entries: Vec<(N, Vec<f32>, usize)>) -> SemanticRouter {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        {
            let mut cache = router.cache.write();
            for (namespace, centroid, count) in entries {
                cache.insert(namespace.into(), CentroidEntry { centroid, count });
            }
        }
        router
    }

    /// Assert that every query in `queries` classifies as `expected`.
    fn assert_all_classified(expected: QueryKind, queries: &[&str]) {
        for query in queries {
            assert_eq!(QueryClassifier::classify(query), expected);
        }
    }

    // --- SemanticRouter tests ---

    #[test]
    fn test_route_empty_cache() {
        let router = SemanticRouter::new(SemanticRouterConfig::default());
        assert!(router.route(&[1.0, 0.0, 0.0], 3, 0.5).is_empty());
    }

    #[test]
    fn test_route_with_cached_centroids() {
        let router = router_with(vec![
            ("_dakera_agent_dev", vec![1.0, 0.0, 0.0], 100),
            ("_dakera_agent_ops", vec![0.0, 1.0, 0.0], 50),
            ("_dakera_agent_sec", vec![0.707, 0.707, 0.0], 30),
        ]);

        // The query points along the "dev" centroid, so "dev" must rank first.
        let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].namespace, "_dakera_agent_dev");
        assert!(results[0].similarity > results[1].similarity);
    }

    #[test]
    fn test_route_min_similarity_filter() {
        let router = router_with(vec![
            ("_dakera_agent_a", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_b", vec![0.0, 1.0, 0.0], 10),
        ]);

        // A high threshold must drop the namespace orthogonal to the query.
        let results = router.route(&[1.0, 0.0, 0.0], 5, 0.9);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_a");
    }

    #[test]
    fn test_route_top_k_truncation() {
        // Ten unit-length centroids fanned out between the x and y axes.
        let entries: Vec<(String, Vec<f32>, usize)> = (0..10)
            .map(|i| {
                let x = 1.0 - (i as f32 * 0.05);
                let y = i as f32 * 0.05;
                let norm = (x * x + y * y).sqrt();
                (
                    format!("_dakera_agent_{}", i),
                    vec![x / norm, y / norm, 0.0f32],
                    10,
                )
            })
            .collect();
        let router = router_with(entries);

        let results = router.route(&[1.0, 0.0, 0.0], 3, 0.0);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_route_dimension_mismatch_skipped() {
        let router = router_with(vec![
            ("_dakera_agent_3d", vec![1.0, 0.0, 0.0], 10),
            ("_dakera_agent_5d", vec![1.0, 0.0, 0.0, 0.0, 0.0], 10),
        ]);

        // A 3D query can only be compared with the 3D centroid.
        let results = router.route(&[1.0, 0.0, 0.0], 5, 0.0);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].namespace, "_dakera_agent_3d");
    }

    #[test]
    fn test_config_defaults() {
        let SemanticRouterConfig {
            sample_size,
            refresh_interval_secs,
        } = SemanticRouterConfig::default();
        assert_eq!(sample_size, 20);
        assert_eq!(refresh_interval_secs, 1800);
    }

    // --- QueryClassifier tests ---

    #[test]
    fn test_classify_keyword_short() {
        assert_all_classified(
            QueryKind::Keyword,
            &["rust async", "HNSW", "memory importance"],
        );
    }

    #[test]
    fn test_classify_question_routes_hybrid() {
        // Natural language questions → Hybrid (BM25 + vector, covers LoCoMo benchmark queries)
        assert_all_classified(
            QueryKind::Hybrid,
            &[
                "what is the best way to store long term memories in an AI system",
                "tell me about the agent memory architecture",
                "how does HNSW work?",
                "What sport did Sarah's brother play in high school?",
            ],
        );
    }

    #[test]
    fn test_classify_semantic_long_prose() {
        // Long prose without question structure → Semantic
        assert_all_classified(
            QueryKind::Semantic,
            &["the agent memory platform stores embeddings with adaptive decay weighting"],
        );
    }

    #[test]
    fn test_classify_hybrid_middle() {
        assert_all_classified(QueryKind::Hybrid, &["vector search memory agent"]);
    }

    // --- CE-15: Temporal classifier tests ---

    #[test]
    fn test_classify_temporal_when_prefix() {
        // "when " prefix → Temporal (BM25-heavy, not balanced Hybrid)
        assert_all_classified(
            QueryKind::Temporal,
            &[
                "when did Caroline go to the store?",
                "When was the last time they spoke?",
                "When were the siblings born?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_date_year_patterns() {
        assert_all_classified(
            QueryKind::Temporal,
            &[
                "What year did they get married?",
                "what date did the conference take place?",
                "What time did the meeting start?",
                "How long ago did this happen?",
                "How many years have they been friends?",
                "How old was Sarah when she graduated?",
            ],
        );
    }

    #[test]
    fn test_classify_temporal_does_not_capture_non_temporal_what() {
        // "what sport" / "what color" / "what is" should NOT route to Temporal
        assert_all_classified(
            QueryKind::Hybrid,
            &[
                "What sport did Sarah's brother play in high school?",
                "what is the best way to find old memories",
            ],
        );
    }
}
562            QueryClassifier::classify("what is the best way to find old memories"),
563            QueryKind::Hybrid
564        );
565    }
566}