1use crate::service::*;
2use adk_core::Result;
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Identifies one user's memory space inside one application.
///
/// Used as the outer key of the in-memory store, so two keys are the same
/// bucket only when both the application name and the user id match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct MemoryKey {
    /// Name of the application the memories belong to.
    app_name: String,
    /// Id of the user the memories belong to.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15 entry: MemoryEntry,
16 words: HashSet<String>,
17 project_id: Option<String>,
18}
19
20type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
21
22pub struct InMemoryMemoryService {
23 store: Arc<RwLock<MemoryStore>>,
24}
25
26impl InMemoryMemoryService {
27 pub fn new() -> Self {
28 Self { store: Arc::new(RwLock::new(HashMap::new())) }
29 }
30
31 fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
32 if set1.is_empty() || set2.is_empty() {
33 return false;
34 }
35 set1.iter().any(|word| set2.contains(word))
36 }
37}
38
39impl Default for InMemoryMemoryService {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45#[async_trait]
46impl MemoryService for InMemoryMemoryService {
47 async fn add_session(
48 &self,
49 app_name: &str,
50 user_id: &str,
51 session_id: &str,
52 entries: Vec<MemoryEntry>,
53 ) -> Result<()> {
54 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
55
56 let stored_entries: Vec<StoredEntry> = entries
57 .into_iter()
58 .map(|entry| {
59 let words = crate::text::extract_words_from_content(&entry.content);
60 StoredEntry { entry, words, project_id: None }
61 })
62 .filter(|e| !e.words.is_empty())
63 .collect();
64
65 if stored_entries.is_empty() {
66 return Ok(());
67 }
68
69 let mut store = self.store.write().unwrap();
70 let sessions = store.entry(key).or_default();
71 sessions.insert(session_id.to_string(), stored_entries);
72
73 Ok(())
74 }
75
76 async fn add_session_to_project(
77 &self,
78 app_name: &str,
79 user_id: &str,
80 session_id: &str,
81 project_id: &str,
82 entries: Vec<MemoryEntry>,
83 ) -> Result<()> {
84 validate_project_id(project_id)?;
85
86 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
87
88 let stored_entries: Vec<StoredEntry> = entries
89 .into_iter()
90 .map(|entry| {
91 let words = crate::text::extract_words_from_content(&entry.content);
92 StoredEntry { entry, words, project_id: Some(project_id.to_string()) }
93 })
94 .filter(|e| !e.words.is_empty())
95 .collect();
96
97 if stored_entries.is_empty() {
98 return Ok(());
99 }
100
101 let mut store = self.store.write().unwrap();
102 let sessions = store.entry(key).or_default();
103 sessions.insert(session_id.to_string(), stored_entries);
104
105 Ok(())
106 }
107
108 async fn add_entry(&self, app_name: &str, user_id: &str, entry: MemoryEntry) -> Result<()> {
109 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
110 let words = crate::text::extract_words_from_content(&entry.content);
111 let stored = StoredEntry { entry, words, project_id: None };
112
113 let mut store = self.store.write().unwrap();
114 let sessions = store.entry(key).or_default();
115 sessions.entry("__direct__".to_string()).or_default().push(stored);
116
117 Ok(())
118 }
119
120 async fn add_entry_to_project(
121 &self,
122 app_name: &str,
123 user_id: &str,
124 project_id: &str,
125 entry: MemoryEntry,
126 ) -> Result<()> {
127 validate_project_id(project_id)?;
128
129 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
130 let words = crate::text::extract_words_from_content(&entry.content);
131 let stored = StoredEntry { entry, words, project_id: Some(project_id.to_string()) };
132
133 let mut store = self.store.write().unwrap();
134 let sessions = store.entry(key).or_default();
135 sessions.entry("__direct__".to_string()).or_default().push(stored);
136
137 Ok(())
138 }
139
140 async fn delete_entries(&self, app_name: &str, user_id: &str, query: &str) -> Result<u64> {
141 let query_words = crate::text::extract_words(query);
142 if query_words.is_empty() {
143 return Ok(0);
144 }
145
146 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
147
148 let mut store = self.store.write().unwrap();
149 let sessions = match store.get_mut(&key) {
150 Some(s) => s,
151 None => return Ok(0),
152 };
153
154 let mut removed: u64 = 0;
155 for entries in sessions.values_mut() {
156 let before = entries.len();
157 entries.retain(|stored| {
158 stored.project_id.is_some() || !Self::has_intersection(&stored.words, &query_words)
160 });
161 removed += (before - entries.len()) as u64;
162 }
163
164 Ok(removed)
165 }
166
167 async fn delete_entries_in_project(
168 &self,
169 app_name: &str,
170 user_id: &str,
171 project_id: &str,
172 query: &str,
173 ) -> Result<u64> {
174 let query_words = crate::text::extract_words(query);
175 if query_words.is_empty() {
176 return Ok(0);
177 }
178
179 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
180
181 let mut store = self.store.write().unwrap();
182 let sessions = match store.get_mut(&key) {
183 Some(s) => s,
184 None => return Ok(0),
185 };
186
187 let mut removed: u64 = 0;
188 for entries in sessions.values_mut() {
189 let before = entries.len();
190 entries.retain(|stored| {
191 stored.project_id.as_deref() != Some(project_id)
193 || !Self::has_intersection(&stored.words, &query_words)
194 });
195 removed += (before - entries.len()) as u64;
196 }
197
198 Ok(removed)
199 }
200
201 async fn delete_project(&self, app_name: &str, user_id: &str, project_id: &str) -> Result<u64> {
202 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
203
204 let mut store = self.store.write().unwrap();
205 let sessions = match store.get_mut(&key) {
206 Some(s) => s,
207 None => return Ok(0),
208 };
209
210 let mut removed: u64 = 0;
211 for entries in sessions.values_mut() {
212 let before = entries.len();
213 entries.retain(|stored| stored.project_id.as_deref() != Some(project_id));
214 removed += (before - entries.len()) as u64;
215 }
216
217 Ok(removed)
218 }
219
220 async fn delete_user(&self, app_name: &str, user_id: &str) -> Result<()> {
221 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
222
223 let mut store = self.store.write().unwrap();
224 store.remove(&key);
225
226 Ok(())
227 }
228
229 async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
230 let query_words = crate::text::extract_words(&req.query);
231 let limit = req.limit.unwrap_or(10);
232
233 let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
234
235 let store = self.store.read().unwrap();
236 let sessions = match store.get(&key) {
237 Some(s) => s,
238 None => return Ok(SearchResponse { memories: Vec::new() }),
239 };
240
241 let mut memories = Vec::new();
242 for stored_entries in sessions.values() {
243 for stored in stored_entries {
244 if !Self::has_intersection(&stored.words, &query_words) {
245 continue;
246 }
247
248 match &req.project_id {
249 None => {
251 if stored.project_id.is_none() {
252 memories.push(stored.entry.clone());
253 }
254 }
255 Some(pid) => {
257 if stored.project_id.is_none()
258 || stored.project_id.as_deref() == Some(pid.as_str())
259 {
260 memories.push(stored.entry.clone());
261 }
262 }
263 }
264 }
265 }
266
267 memories.truncate(limit);
268
269 Ok(SearchResponse { memories })
270 }
271}