Skip to main content

adk_memory/
inmemory.rs

1use crate::service::*;
2use adk_core::Result;
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Identifies one user's memories within one application; this is the outer key of the store.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct MemoryKey {
    /// Application the memories belong to.
    app_name: String,
    /// User the memories belong to.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15    entry: MemoryEntry,
16    words: HashSet<String>,
17}
18
19type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
20
21pub struct InMemoryMemoryService {
22    store: Arc<RwLock<MemoryStore>>,
23}
24
25impl InMemoryMemoryService {
26    pub fn new() -> Self {
27        Self { store: Arc::new(RwLock::new(HashMap::new())) }
28    }
29
30    fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
31        if set1.is_empty() || set2.is_empty() {
32            return false;
33        }
34        set1.iter().any(|word| set2.contains(word))
35    }
36}
37
38impl Default for InMemoryMemoryService {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44#[async_trait]
45impl MemoryService for InMemoryMemoryService {
46    async fn add_session(
47        &self,
48        app_name: &str,
49        user_id: &str,
50        session_id: &str,
51        entries: Vec<MemoryEntry>,
52    ) -> Result<()> {
53        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
54
55        let stored_entries: Vec<StoredEntry> = entries
56            .into_iter()
57            .map(|entry| {
58                let words = crate::text::extract_words_from_content(&entry.content);
59                StoredEntry { entry, words }
60            })
61            .filter(|e| !e.words.is_empty())
62            .collect();
63
64        if stored_entries.is_empty() {
65            return Ok(());
66        }
67
68        let mut store = self.store.write().unwrap();
69        let sessions = store.entry(key).or_default();
70        sessions.insert(session_id.to_string(), stored_entries);
71
72        Ok(())
73    }
74
75    async fn add_entry(&self, app_name: &str, user_id: &str, entry: MemoryEntry) -> Result<()> {
76        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
77        let words = crate::text::extract_words_from_content(&entry.content);
78        let stored = StoredEntry { entry, words };
79
80        let mut store = self.store.write().unwrap();
81        let sessions = store.entry(key).or_default();
82        sessions.entry("__direct__".to_string()).or_default().push(stored);
83
84        Ok(())
85    }
86
87    async fn delete_entries(&self, app_name: &str, user_id: &str, query: &str) -> Result<u64> {
88        let query_words = crate::text::extract_words(query);
89        if query_words.is_empty() {
90            return Ok(0);
91        }
92
93        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
94
95        let mut store = self.store.write().unwrap();
96        let sessions = match store.get_mut(&key) {
97            Some(s) => s,
98            None => return Ok(0),
99        };
100
101        let mut removed: u64 = 0;
102        for entries in sessions.values_mut() {
103            let before = entries.len();
104            entries.retain(|stored| !Self::has_intersection(&stored.words, &query_words));
105            removed += (before - entries.len()) as u64;
106        }
107
108        Ok(removed)
109    }
110
111    async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
112        let query_words = crate::text::extract_words(&req.query);
113        let limit = req.limit.unwrap_or(10);
114
115        let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
116
117        let store = self.store.read().unwrap();
118        let sessions = match store.get(&key) {
119            Some(s) => s,
120            None => return Ok(SearchResponse { memories: Vec::new() }),
121        };
122
123        let mut memories = Vec::new();
124        for stored_entries in sessions.values() {
125            for stored in stored_entries {
126                if Self::has_intersection(&stored.words, &query_words) {
127                    memories.push(stored.entry.clone());
128                }
129            }
130        }
131
132        memories.truncate(limit);
133
134        Ok(SearchResponse { memories })
135    }
136}