1use crate::service::*;
2use adk_core::Result;
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Identifies one user's memories within one application.
///
/// Used as the outer key of the store, so it must be hashable and comparable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct MemoryKey {
    /// Name of the application the memories belong to.
    app_name: String,
    /// Identifier of the user the memories belong to.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15 entry: MemoryEntry,
16 words: HashSet<String>,
17}
18
19type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
20
21pub struct InMemoryMemoryService {
22 store: Arc<RwLock<MemoryStore>>,
23}
24
25impl InMemoryMemoryService {
26 pub fn new() -> Self {
27 Self { store: Arc::new(RwLock::new(HashMap::new())) }
28 }
29
30 fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
31 if set1.is_empty() || set2.is_empty() {
32 return false;
33 }
34 set1.iter().any(|word| set2.contains(word))
35 }
36}
37
38impl Default for InMemoryMemoryService {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44#[async_trait]
45impl MemoryService for InMemoryMemoryService {
46 async fn add_session(
47 &self,
48 app_name: &str,
49 user_id: &str,
50 session_id: &str,
51 entries: Vec<MemoryEntry>,
52 ) -> Result<()> {
53 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
54
55 let stored_entries: Vec<StoredEntry> = entries
56 .into_iter()
57 .map(|entry| {
58 let words = crate::text::extract_words_from_content(&entry.content);
59 StoredEntry { entry, words }
60 })
61 .filter(|e| !e.words.is_empty())
62 .collect();
63
64 if stored_entries.is_empty() {
65 return Ok(());
66 }
67
68 let mut store = self.store.write().unwrap();
69 let sessions = store.entry(key).or_default();
70 sessions.insert(session_id.to_string(), stored_entries);
71
72 Ok(())
73 }
74
75 async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
76 let query_words = crate::text::extract_words(&req.query);
77 let limit = req.limit.unwrap_or(10);
78
79 let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
80
81 let store = self.store.read().unwrap();
82 let sessions = match store.get(&key) {
83 Some(s) => s,
84 None => return Ok(SearchResponse { memories: Vec::new() }),
85 };
86
87 let mut memories = Vec::new();
88 for stored_entries in sessions.values() {
89 for stored in stored_entries {
90 if Self::has_intersection(&stored.words, &query_words) {
91 memories.push(stored.entry.clone());
92 }
93 }
94 }
95
96 memories.truncate(limit);
97
98 Ok(SearchResponse { memories })
99 }
100}