1use crate::{index::EmbeddingIndex, store::LocalMemoryStore};
2use devsper_core::{MemoryHit, MemoryStore};
3use anyhow::Result;
4use std::sync::Arc;
5
/// Selects which retrieval path [`MemoryRouter::recall`] takes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetrievalStrategy {
    /// Keyword search delegated to the local store's `search` method.
    Bm25,
    /// Search of the embedding index, with matching keys resolved from the store.
    Semantic,
    /// Store keyword hits merged with semantic hits and ranked by score.
    Hybrid,
}
15
16pub struct MemoryRouter {
18 store: Arc<LocalMemoryStore>,
19 index: Arc<EmbeddingIndex>,
20 strategy: RetrievalStrategy,
21}
22
23impl MemoryRouter {
24 pub fn new(strategy: RetrievalStrategy) -> Self {
25 Self {
26 store: Arc::new(LocalMemoryStore::new()),
27 index: Arc::new(EmbeddingIndex::new()),
28 strategy,
29 }
30 }
31
32 pub fn store(&self) -> &Arc<LocalMemoryStore> {
33 &self.store
34 }
35
36 pub async fn remember(&self, namespace: &str, key: &str, value: serde_json::Value) -> Result<()> {
38 let text = value.to_string();
39 self.store.store(namespace, key, value).await?;
40 self.index.index(format!("{namespace}/{key}"), &text).await;
41 Ok(())
42 }
43
44 pub async fn recall(&self, namespace: &str, query: &str, top_k: usize) -> Result<Vec<MemoryHit>> {
46 match &self.strategy {
47 RetrievalStrategy::Bm25 => {
48 self.store.search(namespace, query, top_k).await
49 }
50 RetrievalStrategy::Semantic => {
51 let results = self.index.search(query, top_k * 2).await;
52 let ns_prefix = format!("{namespace}/");
53 let mut hits = Vec::new();
54 for (doc_id, score) in results {
55 if let Some(key) = doc_id.strip_prefix(&ns_prefix) {
56 if let Ok(Some(value)) = self.store.retrieve(namespace, key).await {
57 hits.push(MemoryHit {
58 key: key.to_string(),
59 value,
60 score,
61 });
62 }
63 }
64 }
65 hits.truncate(top_k);
66 Ok(hits)
67 }
68 RetrievalStrategy::Hybrid => {
69 let mut bm25 = self.store.search(namespace, query, top_k).await?;
70 let sem_results = self.index.search(query, top_k).await;
71 let ns_prefix = format!("{namespace}/");
72 for (doc_id, score) in sem_results {
73 if let Some(key) = doc_id.strip_prefix(&ns_prefix) {
74 let already = bm25.iter().any(|h| h.key == key);
75 if !already {
76 if let Ok(Some(value)) = self.store.retrieve(namespace, key).await {
77 bm25.push(MemoryHit {
78 key: key.to_string(),
79 value,
80 score,
81 });
82 }
83 }
84 }
85 }
86 bm25.sort_by(|a, b| {
87 b.score
88 .partial_cmp(&a.score)
89 .unwrap_or(std::cmp::Ordering::Equal)
90 });
91 bm25.truncate(top_k);
92 Ok(bm25)
93 }
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router using `strategy` and remembers each `(key, text)` pair
    /// in namespace `"ns"` as a JSON string value.
    async fn seeded(strategy: RetrievalStrategy, docs: &[(&str, &str)]) -> MemoryRouter {
        let router = MemoryRouter::new(strategy);
        for &(key, text) in docs {
            router
                .remember("ns", key, serde_json::json!(text))
                .await
                .unwrap();
        }
        router
    }

    #[tokio::test]
    async fn bm25_recall() {
        let router = seeded(
            RetrievalStrategy::Bm25,
            &[("k1", "cats are fluffy"), ("k2", "dogs bark")],
        )
        .await;

        let hits = router.recall("ns", "fluffy cats", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].key, "k1");
    }

    #[tokio::test]
    async fn semantic_recall() {
        let router = seeded(
            RetrievalStrategy::Semantic,
            &[
                ("k1", "machine learning model training"),
                ("k2", "database query optimization"),
            ],
        )
        .await;

        let hits = router.recall("ns", "machine learning", 5).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].key, "k1");
    }

    #[tokio::test]
    async fn hybrid_recall() {
        let router = seeded(
            RetrievalStrategy::Hybrid,
            &[
                ("k1", "rust programming language"),
                ("k2", "python scripting language"),
            ],
        )
        .await;

        let hits = router.recall("ns", "rust language", 5).await.unwrap();
        assert!(!hits.is_empty());
    }
}