// devsper_memory/router.rs

1use crate::{index::EmbeddingIndex, store::LocalMemoryStore};
2use devsper_core::{MemoryHit, MemoryStore};
3use anyhow::Result;
4use std::sync::Arc;
5
/// Selects how [`MemoryRouter`] answers a recall query.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetrievalStrategy {
    /// Keyword (BM25) search performed by the store; fast and embedding-free.
    Bm25,
    /// Similarity search over the TF-IDF embedding index; slightly slower.
    Semantic,
    /// Runs both searches, merges the candidates and re-ranks them by score.
    Hybrid,
}
15
16/// Routes memory retrieval to appropriate strategy
17pub struct MemoryRouter {
18    store: Arc<LocalMemoryStore>,
19    index: Arc<EmbeddingIndex>,
20    strategy: RetrievalStrategy,
21}
22
23impl MemoryRouter {
24    pub fn new(strategy: RetrievalStrategy) -> Self {
25        Self {
26            store: Arc::new(LocalMemoryStore::new()),
27            index: Arc::new(EmbeddingIndex::new()),
28            strategy,
29        }
30    }
31
32    pub fn store(&self) -> &Arc<LocalMemoryStore> {
33        &self.store
34    }
35
36    /// Store and index a memory fact
37    pub async fn remember(&self, namespace: &str, key: &str, value: serde_json::Value) -> Result<()> {
38        let text = value.to_string();
39        self.store.store(namespace, key, value).await?;
40        self.index.index(format!("{namespace}/{key}"), &text).await;
41        Ok(())
42    }
43
44    /// Retrieve relevant memories for a query
45    pub async fn recall(&self, namespace: &str, query: &str, top_k: usize) -> Result<Vec<MemoryHit>> {
46        match &self.strategy {
47            RetrievalStrategy::Bm25 => {
48                self.store.search(namespace, query, top_k).await
49            }
50            RetrievalStrategy::Semantic => {
51                let results = self.index.search(query, top_k * 2).await;
52                let ns_prefix = format!("{namespace}/");
53                let mut hits = Vec::new();
54                for (doc_id, score) in results {
55                    if let Some(key) = doc_id.strip_prefix(&ns_prefix) {
56                        if let Ok(Some(value)) = self.store.retrieve(namespace, key).await {
57                            hits.push(MemoryHit {
58                                key: key.to_string(),
59                                value,
60                                score,
61                            });
62                        }
63                    }
64                }
65                hits.truncate(top_k);
66                Ok(hits)
67            }
68            RetrievalStrategy::Hybrid => {
69                let mut bm25 = self.store.search(namespace, query, top_k).await?;
70                let sem_results = self.index.search(query, top_k).await;
71                let ns_prefix = format!("{namespace}/");
72                for (doc_id, score) in sem_results {
73                    if let Some(key) = doc_id.strip_prefix(&ns_prefix) {
74                        let already = bm25.iter().any(|h| h.key == key);
75                        if !already {
76                            if let Ok(Some(value)) = self.store.retrieve(namespace, key).await {
77                                bm25.push(MemoryHit {
78                                    key: key.to_string(),
79                                    value,
80                                    score,
81                                });
82                            }
83                        }
84                    }
85                }
86                bm25.sort_by(|a, b| {
87                    b.score
88                        .partial_cmp(&a.score)
89                        .unwrap_or(std::cmp::Ordering::Equal)
90                });
91                bm25.truncate(top_k);
92                Ok(bm25)
93            }
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router using `strategy` and remembers every
    /// `(namespace, key, text)` fact, in order.
    async fn router_with(strategy: RetrievalStrategy, facts: &[(&str, &str, &str)]) -> MemoryRouter {
        let router = MemoryRouter::new(strategy);
        for &(namespace, key, text) in facts {
            router
                .remember(namespace, key, serde_json::json!(text))
                .await
                .unwrap();
        }
        router
    }

    #[tokio::test]
    async fn bm25_recall() {
        let router = router_with(
            RetrievalStrategy::Bm25,
            &[("ns", "k1", "cats are fluffy"), ("ns", "k2", "dogs bark")],
        )
        .await;

        let hits = router.recall("ns", "fluffy cats", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].key, "k1");
    }

    #[tokio::test]
    async fn semantic_recall() {
        let router = router_with(
            RetrievalStrategy::Semantic,
            &[
                ("ns", "k1", "machine learning model training"),
                ("ns", "k2", "database query optimization"),
            ],
        )
        .await;

        let hits = router.recall("ns", "machine learning", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].key, "k1");
    }

    #[tokio::test]
    async fn semantic_recall_is_scoped_to_namespace() {
        // Both documents share "fluffy", but only the "pets" one may be returned.
        let router = router_with(
            RetrievalStrategy::Semantic,
            &[
                ("pets", "k1", "cats are fluffy"),
                ("food", "k2", "fluffy pancakes recipe"),
            ],
        )
        .await;

        let hits = router.recall("pets", "fluffy cats", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert!(hits.iter().all(|h| h.key == "k1"));
    }

    #[tokio::test]
    async fn hybrid_recall() {
        let router = router_with(
            RetrievalStrategy::Hybrid,
            &[
                ("ns", "k1", "rust programming language"),
                ("ns", "k2", "python scripting language"),
            ],
        )
        .await;

        let hits = router.recall("ns", "rust language", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);

        // Merging BM25 and semantic results must not duplicate a key.
        let mut keys: Vec<&str> = hits.iter().map(|h| h.key.as_str()).collect();
        keys.sort_unstable();
        keys.dedup();
        assert_eq!(keys.len(), hits.len());
    }
}