Skip to main content

devsper_memory/
store.rs

1use devsper_core::{MemoryStore, MemoryHit};
2use anyhow::Result;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::debug;
9
10/// A stored memory entry
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MemoryEntry {
13    pub key: String,
14    pub value: serde_json::Value,
15    pub namespace: String,
16    pub created_at: u64,
17    pub tags: Vec<String>,
18}
19
20/// In-memory store backed by a HashMap (no SQLite dep needed for initial impl).
21/// Replace with SQLite in production via the same trait.
22pub struct LocalMemoryStore {
23    /// namespace → key → entry
24    data: Arc<RwLock<HashMap<String, HashMap<String, MemoryEntry>>>>,
25}
26
27impl LocalMemoryStore {
28    pub fn new() -> Self {
29        Self {
30            data: Arc::new(RwLock::new(HashMap::new())),
31        }
32    }
33}
34
35impl Default for LocalMemoryStore {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41#[async_trait]
42impl MemoryStore for LocalMemoryStore {
43    async fn store(&self, namespace: &str, key: &str, value: serde_json::Value) -> Result<()> {
44        debug!(namespace = %namespace, key = %key, "Memory store");
45        let entry = MemoryEntry {
46            key: key.to_string(),
47            value,
48            namespace: namespace.to_string(),
49            created_at: devsper_core::now_ms(),
50            tags: vec![],
51        };
52        let mut data = self.data.write().await;
53        data.entry(namespace.to_string())
54            .or_insert_with(HashMap::new)
55            .insert(key.to_string(), entry);
56        Ok(())
57    }
58
59    async fn retrieve(&self, namespace: &str, key: &str) -> Result<Option<serde_json::Value>> {
60        let data = self.data.read().await;
61        Ok(data
62            .get(namespace)
63            .and_then(|ns| ns.get(key))
64            .map(|e| e.value.clone()))
65    }
66
67    async fn search(&self, namespace: &str, query: &str, top_k: usize) -> Result<Vec<MemoryHit>> {
68        // Simple text matching (BM25-lite): score by query term overlap
69        let data = self.data.read().await;
70        let ns_data = match data.get(namespace) {
71            Some(d) => d,
72            None => return Ok(vec![]),
73        };
74
75        let query_terms: Vec<String> = query
76            .to_lowercase()
77            .split_whitespace()
78            .map(str::to_string)
79            .collect();
80
81        let mut hits: Vec<MemoryHit> = ns_data
82            .values()
83            .map(|entry| {
84                let text = entry.value.to_string().to_lowercase();
85                let score = query_terms.iter().filter(|t| text.contains(t.as_str())).count()
86                    as f32
87                    / query_terms.len().max(1) as f32;
88                MemoryHit {
89                    key: entry.key.clone(),
90                    value: entry.value.clone(),
91                    score,
92                }
93            })
94            .filter(|h| h.score > 0.0)
95            .collect();
96
97        hits.sort_by(|a, b| {
98            b.score
99                .partial_cmp(&a.score)
100                .unwrap_or(std::cmp::Ordering::Equal)
101        });
102        hits.truncate(top_k);
103        Ok(hits)
104    }
105
106    async fn delete(&self, namespace: &str, key: &str) -> Result<()> {
107        let mut data = self.data.write().await;
108        if let Some(ns) = data.get_mut(namespace) {
109            ns.remove(key);
110        }
111        Ok(())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn store_and_retrieve() {
        let store = LocalMemoryStore::new();
        let ns = "run-1/agent-a";
        store
            .store(ns, "fact-1", json!({"text": "The sky is blue"}))
            .await
            .unwrap();
        let val = store.retrieve(ns, "fact-1").await.unwrap();
        assert_eq!(val, Some(json!({"text": "The sky is blue"})));
    }

    #[tokio::test]
    async fn store_overwrites_existing_key() {
        let store = LocalMemoryStore::new();
        store.store("ns", "key", json!("old")).await.unwrap();
        store.store("ns", "key", json!("new")).await.unwrap();
        let val = store.retrieve("ns", "key").await.unwrap();
        assert_eq!(val, Some(json!("new")));
    }

    #[tokio::test]
    async fn retrieve_missing_returns_none() {
        let store = LocalMemoryStore::new();
        // Unknown namespace.
        assert!(store.retrieve("ns", "missing").await.unwrap().is_none());
        // Known namespace, unknown key.
        store.store("ns", "present", json!(1)).await.unwrap();
        assert!(store.retrieve("ns", "missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn search_returns_relevant_hits() {
        let store = LocalMemoryStore::new();
        let ns = "run-1/agent-a";
        store
            .store(ns, "k1", json!({"text": "cats are fluffy animals"}))
            .await
            .unwrap();
        store
            .store(ns, "k2", json!({"text": "dogs are loyal pets"}))
            .await
            .unwrap();
        store
            .store(ns, "k3", json!({"text": "the weather is nice today"}))
            .await
            .unwrap();

        let hits = store.search(ns, "cats fluffy", 2).await.unwrap();
        // Only k1 shares any term with the query, and it matches both terms.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].key, "k1");
        assert_eq!(hits[0].score, 1.0);
    }

    #[tokio::test]
    async fn search_ranks_by_overlap_and_respects_top_k() {
        let store = LocalMemoryStore::new();
        let ns = "ns";
        store.store(ns, "k1", json!("red apple")).await.unwrap();
        store.store(ns, "k2", json!("red car")).await.unwrap();
        store.store(ns, "k3", json!("blue car")).await.unwrap();

        let hits = store.search(ns, "red car", 3).await.unwrap();
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].key, "k2");
        assert_eq!(hits[0].score, 1.0);
        assert_eq!(hits[1].score, 0.5);
        assert_eq!(hits[2].score, 0.5);

        let top = store.search(ns, "red car", 1).await.unwrap();
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].key, "k2");

        assert!(store.search(ns, "red car", 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn search_empty_query_or_unknown_namespace_returns_nothing() {
        let store = LocalMemoryStore::new();
        store.store("ns", "k", json!("anything")).await.unwrap();
        assert!(store.search("ns", "   ", 5).await.unwrap().is_empty());
        assert!(store.search("other", "anything", 5).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_removes_entry() {
        let store = LocalMemoryStore::new();
        let ns = "ns";
        store.store(ns, "key", json!("value")).await.unwrap();
        store.delete(ns, "key").await.unwrap();
        let val = store.retrieve(ns, "key").await.unwrap();
        assert!(val.is_none());
    }

    #[tokio::test]
    async fn delete_missing_key_is_noop() {
        let store = LocalMemoryStore::new();
        store.delete("absent-ns", "absent").await.unwrap();
        store.store("ns", "keep", json!("v")).await.unwrap();
        store.delete("ns", "absent").await.unwrap();
        assert_eq!(store.retrieve("ns", "keep").await.unwrap(), Some(json!("v")));
    }

    #[tokio::test]
    async fn namespace_isolation() {
        let store = LocalMemoryStore::new();
        store.store("ns-a", "key", json!("a-value")).await.unwrap();
        store.store("ns-b", "key", json!("b-value")).await.unwrap();

        let a = store.retrieve("ns-a", "key").await.unwrap();
        let b = store.retrieve("ns-b", "key").await.unwrap();
        assert_eq!(a, Some(json!("a-value")));
        assert_eq!(b, Some(json!("b-value")));
    }
}