1use devsper_core::{MemoryStore, MemoryHit};
2use anyhow::Result;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MemoryEntry {
13 pub key: String,
14 pub value: serde_json::Value,
15 pub namespace: String,
16 pub created_at: u64,
17 pub tags: Vec<String>,
18}
19
20pub struct LocalMemoryStore {
23 data: Arc<RwLock<HashMap<String, HashMap<String, MemoryEntry>>>>,
25}
26
27impl LocalMemoryStore {
28 pub fn new() -> Self {
29 Self {
30 data: Arc::new(RwLock::new(HashMap::new())),
31 }
32 }
33}
34
35impl Default for LocalMemoryStore {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41#[async_trait]
42impl MemoryStore for LocalMemoryStore {
43 async fn store(&self, namespace: &str, key: &str, value: serde_json::Value) -> Result<()> {
44 debug!(namespace = %namespace, key = %key, "Memory store");
45 let entry = MemoryEntry {
46 key: key.to_string(),
47 value,
48 namespace: namespace.to_string(),
49 created_at: devsper_core::now_ms(),
50 tags: vec![],
51 };
52 let mut data = self.data.write().await;
53 data.entry(namespace.to_string())
54 .or_insert_with(HashMap::new)
55 .insert(key.to_string(), entry);
56 Ok(())
57 }
58
59 async fn retrieve(&self, namespace: &str, key: &str) -> Result<Option<serde_json::Value>> {
60 let data = self.data.read().await;
61 Ok(data
62 .get(namespace)
63 .and_then(|ns| ns.get(key))
64 .map(|e| e.value.clone()))
65 }
66
67 async fn search(&self, namespace: &str, query: &str, top_k: usize) -> Result<Vec<MemoryHit>> {
68 let data = self.data.read().await;
70 let ns_data = match data.get(namespace) {
71 Some(d) => d,
72 None => return Ok(vec![]),
73 };
74
75 let query_terms: Vec<String> = query
76 .to_lowercase()
77 .split_whitespace()
78 .map(str::to_string)
79 .collect();
80
81 let mut hits: Vec<MemoryHit> = ns_data
82 .values()
83 .map(|entry| {
84 let text = entry.value.to_string().to_lowercase();
85 let score = query_terms.iter().filter(|t| text.contains(t.as_str())).count()
86 as f32
87 / query_terms.len().max(1) as f32;
88 MemoryHit {
89 key: entry.key.clone(),
90 value: entry.value.clone(),
91 score,
92 }
93 })
94 .filter(|h| h.score > 0.0)
95 .collect();
96
97 hits.sort_by(|a, b| {
98 b.score
99 .partial_cmp(&a.score)
100 .unwrap_or(std::cmp::Ordering::Equal)
101 });
102 hits.truncate(top_k);
103 Ok(hits)
104 }
105
106 async fn delete(&self, namespace: &str, key: &str) -> Result<()> {
107 let mut data = self.data.write().await;
108 if let Some(ns) = data.get_mut(namespace) {
109 ns.remove(key);
110 }
111 Ok(())
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn store_and_retrieve() {
        let store = LocalMemoryStore::new();
        let ns = "run-1/agent-a";
        store
            .store(ns, "fact-1", json!({"text": "The sky is blue"}))
            .await
            .unwrap();

        let val = store.retrieve(ns, "fact-1").await.unwrap();
        assert_eq!(val, Some(json!({"text": "The sky is blue"})));
    }

    #[tokio::test]
    async fn store_overwrites_existing_key() {
        let store = LocalMemoryStore::new();
        store.store("ns", "key", json!("old")).await.unwrap();
        store.store("ns", "key", json!("new")).await.unwrap();

        let val = store.retrieve("ns", "key").await.unwrap();
        assert_eq!(val, Some(json!("new")));
    }

    #[tokio::test]
    async fn retrieve_missing_returns_none() {
        let store = LocalMemoryStore::new();
        let val = store.retrieve("ns", "missing").await.unwrap();
        assert!(val.is_none());
    }

    #[tokio::test]
    async fn search_returns_relevant_hits() {
        let store = LocalMemoryStore::new();
        let ns = "run-1/agent-a";
        store
            .store(ns, "k1", json!({"text": "cats are fluffy animals"}))
            .await
            .unwrap();
        store
            .store(ns, "k2", json!({"text": "dogs are loyal pets"}))
            .await
            .unwrap();
        store
            .store(ns, "k3", json!({"text": "the weather is nice today"}))
            .await
            .unwrap();

        let hits = store.search(ns, "cats fluffy", 2).await.unwrap();
        // Only k1 contains either term, and it contains both of them.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].key, "k1");
        assert!((hits[0].score - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn search_respects_top_k() {
        let store = LocalMemoryStore::new();
        for key in ["a", "b", "c"] {
            store
                .store("ns", key, json!({"text": "shared word"}))
                .await
                .unwrap();
        }

        let hits = store.search("ns", "shared", 2).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn search_unknown_namespace_is_empty() {
        let store = LocalMemoryStore::new();
        let hits = store.search("nowhere", "anything", 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_blank_query_is_empty() {
        let store = LocalMemoryStore::new();
        store
            .store("ns", "k", json!({"text": "something"}))
            .await
            .unwrap();

        let hits = store.search("ns", "   ", 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn delete_removes_entry() {
        let store = LocalMemoryStore::new();
        let ns = "ns";
        store.store(ns, "key", json!("value")).await.unwrap();
        store.delete(ns, "key").await.unwrap();

        let val = store.retrieve(ns, "key").await.unwrap();
        assert!(val.is_none());
    }

    #[tokio::test]
    async fn delete_missing_is_ok() {
        let store = LocalMemoryStore::new();
        store.delete("nowhere", "key").await.unwrap();

        store.store("ns", "key", json!(1)).await.unwrap();
        store.delete("ns", "other").await.unwrap();
        // Deleting an unrelated key must not disturb existing entries.
        assert_eq!(store.retrieve("ns", "key").await.unwrap(), Some(json!(1)));
    }

    #[tokio::test]
    async fn namespace_isolation() {
        let store = LocalMemoryStore::new();
        store.store("ns-a", "key", json!("a-value")).await.unwrap();
        store.store("ns-b", "key", json!("b-value")).await.unwrap();

        assert_eq!(
            store.retrieve("ns-a", "key").await.unwrap(),
            Some(json!("a-value"))
        );
        assert_eq!(
            store.retrieve("ns-b", "key").await.unwrap(),
            Some(json!("b-value"))
        );
    }
}