Skip to main content

adk_memory/
inmemory.rs

1use crate::service::*;
2use adk_core::Result;
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Partition key of the store: one bucket per (application, user) pair.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct MemoryKey {
    /// Application the memories belong to.
    app_name: String,
    /// User (within `app_name`) the memories belong to.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15    entry: MemoryEntry,
16    words: HashSet<String>,
17    project_id: Option<String>,
18}
19
20type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
21
22pub struct InMemoryMemoryService {
23    store: Arc<RwLock<MemoryStore>>,
24}
25
26impl InMemoryMemoryService {
27    pub fn new() -> Self {
28        Self { store: Arc::new(RwLock::new(HashMap::new())) }
29    }
30
31    fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
32        if set1.is_empty() || set2.is_empty() {
33            return false;
34        }
35        set1.iter().any(|word| set2.contains(word))
36    }
37}
38
39impl Default for InMemoryMemoryService {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45#[async_trait]
46impl MemoryService for InMemoryMemoryService {
47    async fn add_session(
48        &self,
49        app_name: &str,
50        user_id: &str,
51        session_id: &str,
52        entries: Vec<MemoryEntry>,
53    ) -> Result<()> {
54        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
55
56        let stored_entries: Vec<StoredEntry> = entries
57            .into_iter()
58            .map(|entry| {
59                let words = crate::text::extract_words_from_content(&entry.content);
60                StoredEntry { entry, words, project_id: None }
61            })
62            .filter(|e| !e.words.is_empty())
63            .collect();
64
65        if stored_entries.is_empty() {
66            return Ok(());
67        }
68
69        let mut store = self.store.write().unwrap();
70        let sessions = store.entry(key).or_default();
71        sessions.insert(session_id.to_string(), stored_entries);
72
73        Ok(())
74    }
75
76    async fn add_session_to_project(
77        &self,
78        app_name: &str,
79        user_id: &str,
80        session_id: &str,
81        project_id: &str,
82        entries: Vec<MemoryEntry>,
83    ) -> Result<()> {
84        validate_project_id(project_id)?;
85
86        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
87
88        let stored_entries: Vec<StoredEntry> = entries
89            .into_iter()
90            .map(|entry| {
91                let words = crate::text::extract_words_from_content(&entry.content);
92                StoredEntry { entry, words, project_id: Some(project_id.to_string()) }
93            })
94            .filter(|e| !e.words.is_empty())
95            .collect();
96
97        if stored_entries.is_empty() {
98            return Ok(());
99        }
100
101        let mut store = self.store.write().unwrap();
102        let sessions = store.entry(key).or_default();
103        sessions.insert(session_id.to_string(), stored_entries);
104
105        Ok(())
106    }
107
108    async fn add_entry(&self, app_name: &str, user_id: &str, entry: MemoryEntry) -> Result<()> {
109        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
110        let words = crate::text::extract_words_from_content(&entry.content);
111        let stored = StoredEntry { entry, words, project_id: None };
112
113        let mut store = self.store.write().unwrap();
114        let sessions = store.entry(key).or_default();
115        sessions.entry("__direct__".to_string()).or_default().push(stored);
116
117        Ok(())
118    }
119
120    async fn add_entry_to_project(
121        &self,
122        app_name: &str,
123        user_id: &str,
124        project_id: &str,
125        entry: MemoryEntry,
126    ) -> Result<()> {
127        validate_project_id(project_id)?;
128
129        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
130        let words = crate::text::extract_words_from_content(&entry.content);
131        let stored = StoredEntry { entry, words, project_id: Some(project_id.to_string()) };
132
133        let mut store = self.store.write().unwrap();
134        let sessions = store.entry(key).or_default();
135        sessions.entry("__direct__".to_string()).or_default().push(stored);
136
137        Ok(())
138    }
139
140    async fn delete_entries(&self, app_name: &str, user_id: &str, query: &str) -> Result<u64> {
141        let query_words = crate::text::extract_words(query);
142        if query_words.is_empty() {
143            return Ok(0);
144        }
145
146        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
147
148        let mut store = self.store.write().unwrap();
149        let sessions = match store.get_mut(&key) {
150            Some(s) => s,
151            None => return Ok(0),
152        };
153
154        let mut removed: u64 = 0;
155        for entries in sessions.values_mut() {
156            let before = entries.len();
157            entries.retain(|stored| {
158                // Only delete global entries (project_id is None)
159                stored.project_id.is_some() || !Self::has_intersection(&stored.words, &query_words)
160            });
161            removed += (before - entries.len()) as u64;
162        }
163
164        Ok(removed)
165    }
166
167    async fn delete_entries_in_project(
168        &self,
169        app_name: &str,
170        user_id: &str,
171        project_id: &str,
172        query: &str,
173    ) -> Result<u64> {
174        let query_words = crate::text::extract_words(query);
175        if query_words.is_empty() {
176            return Ok(0);
177        }
178
179        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
180
181        let mut store = self.store.write().unwrap();
182        let sessions = match store.get_mut(&key) {
183            Some(s) => s,
184            None => return Ok(0),
185        };
186
187        let mut removed: u64 = 0;
188        for entries in sessions.values_mut() {
189            let before = entries.len();
190            entries.retain(|stored| {
191                // Only delete entries matching the given project
192                stored.project_id.as_deref() != Some(project_id)
193                    || !Self::has_intersection(&stored.words, &query_words)
194            });
195            removed += (before - entries.len()) as u64;
196        }
197
198        Ok(removed)
199    }
200
201    async fn delete_project(&self, app_name: &str, user_id: &str, project_id: &str) -> Result<u64> {
202        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
203
204        let mut store = self.store.write().unwrap();
205        let sessions = match store.get_mut(&key) {
206            Some(s) => s,
207            None => return Ok(0),
208        };
209
210        let mut removed: u64 = 0;
211        for entries in sessions.values_mut() {
212            let before = entries.len();
213            entries.retain(|stored| stored.project_id.as_deref() != Some(project_id));
214            removed += (before - entries.len()) as u64;
215        }
216
217        Ok(removed)
218    }
219
220    async fn delete_user(&self, app_name: &str, user_id: &str) -> Result<()> {
221        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
222
223        let mut store = self.store.write().unwrap();
224        store.remove(&key);
225
226        Ok(())
227    }
228
229    async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
230        let query_words = crate::text::extract_words(&req.query);
231        let limit = req.limit.unwrap_or(10);
232
233        let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
234
235        let store = self.store.read().unwrap();
236        let sessions = match store.get(&key) {
237            Some(s) => s,
238            None => return Ok(SearchResponse { memories: Vec::new() }),
239        };
240
241        let mut memories = Vec::new();
242        for stored_entries in sessions.values() {
243            for stored in stored_entries {
244                if !Self::has_intersection(&stored.words, &query_words) {
245                    continue;
246                }
247
248                match &req.project_id {
249                    // Global search: only include global entries
250                    None => {
251                        if stored.project_id.is_none() {
252                            memories.push(stored.entry.clone());
253                        }
254                    }
255                    // Project search: include global + matching project entries
256                    Some(pid) => {
257                        if stored.project_id.is_none()
258                            || stored.project_id.as_deref() == Some(pid.as_str())
259                        {
260                            memories.push(stored.entry.clone());
261                        }
262                    }
263                }
264            }
265        }
266
267        memories.truncate(limit);
268
269        Ok(SearchResponse { memories })
270    }
271}