1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Identifies one user's memory bucket within one application.
///
/// Stored memories are partitioned by this key, and a search only looks
/// inside the bucket matching its request's app and user.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct MemoryKey {
    /// Application the memories belong to.
    app_name: String,
    /// User (within that application) the memories belong to.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15 entry: MemoryEntry,
16 words: HashSet<String>,
17}
18
19type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
20
21pub struct InMemoryMemoryService {
22 store: Arc<RwLock<MemoryStore>>,
23}
24
25impl InMemoryMemoryService {
26 pub fn new() -> Self {
27 Self { store: Arc::new(RwLock::new(HashMap::new())) }
28 }
29
30 fn extract_words(text: &str) -> HashSet<String> {
31 text.split_whitespace().filter(|s| !s.is_empty()).map(|s| s.to_lowercase()).collect()
32 }
33
34 fn extract_words_from_content(content: &adk_core::Content) -> HashSet<String> {
35 let mut words = HashSet::new();
36 for part in &content.parts {
37 if let Part::Text { text } = part {
38 words.extend(Self::extract_words(text));
39 }
40 }
41 words
42 }
43
44 fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
45 if set1.is_empty() || set2.is_empty() {
46 return false;
47 }
48 set1.iter().any(|word| set2.contains(word))
49 }
50}
51
52impl Default for InMemoryMemoryService {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58#[async_trait]
59impl MemoryService for InMemoryMemoryService {
60 async fn add_session(
61 &self,
62 app_name: &str,
63 user_id: &str,
64 session_id: &str,
65 entries: Vec<MemoryEntry>,
66 ) -> Result<()> {
67 let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
68
69 let stored_entries: Vec<StoredEntry> = entries
70 .into_iter()
71 .map(|entry| {
72 let words = Self::extract_words_from_content(&entry.content);
73 StoredEntry { entry, words }
74 })
75 .filter(|e| !e.words.is_empty())
76 .collect();
77
78 if stored_entries.is_empty() {
79 return Ok(());
80 }
81
82 let mut store = self.store.write().unwrap();
83 let sessions = store.entry(key).or_default();
84 sessions.insert(session_id.to_string(), stored_entries);
85
86 Ok(())
87 }
88
89 async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
90 let query_words = Self::extract_words(&req.query);
91
92 let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
93
94 let store = self.store.read().unwrap();
95 let sessions = match store.get(&key) {
96 Some(s) => s,
97 None => return Ok(SearchResponse { memories: Vec::new() }),
98 };
99
100 let mut memories = Vec::new();
101 for stored_entries in sessions.values() {
102 for stored in stored_entries {
103 if Self::has_intersection(&stored.words, &query_words) {
104 memories.push(stored.entry.clone());
105 }
106 }
107 }
108
109 Ok(SearchResponse { memories })
110 }
111}