adk_memory/
inmemory.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, RwLock};
6
/// Identifies the memories of one user within one application.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct MemoryKey {
    /// Application the memories belong to.
    app_name: String,
    /// User the memories belong to, scoped to `app_name`.
    user_id: String,
}
12
13#[derive(Clone)]
14struct StoredEntry {
15    entry: MemoryEntry,
16    words: HashSet<String>,
17}
18
19type MemoryStore = HashMap<MemoryKey, HashMap<String, Vec<StoredEntry>>>;
20
21pub struct InMemoryMemoryService {
22    store: Arc<RwLock<MemoryStore>>,
23}
24
25impl InMemoryMemoryService {
26    pub fn new() -> Self {
27        Self { store: Arc::new(RwLock::new(HashMap::new())) }
28    }
29
30    fn extract_words(text: &str) -> HashSet<String> {
31        text.split_whitespace().filter(|s| !s.is_empty()).map(|s| s.to_lowercase()).collect()
32    }
33
34    fn extract_words_from_content(content: &adk_core::Content) -> HashSet<String> {
35        let mut words = HashSet::new();
36        for part in &content.parts {
37            if let Part::Text { text } = part {
38                words.extend(Self::extract_words(text));
39            }
40        }
41        words
42    }
43
44    fn has_intersection(set1: &HashSet<String>, set2: &HashSet<String>) -> bool {
45        if set1.is_empty() || set2.is_empty() {
46            return false;
47        }
48        set1.iter().any(|word| set2.contains(word))
49    }
50}
51
52impl Default for InMemoryMemoryService {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58#[async_trait]
59impl MemoryService for InMemoryMemoryService {
60    async fn add_session(
61        &self,
62        app_name: &str,
63        user_id: &str,
64        session_id: &str,
65        entries: Vec<MemoryEntry>,
66    ) -> Result<()> {
67        let key = MemoryKey { app_name: app_name.to_string(), user_id: user_id.to_string() };
68
69        let stored_entries: Vec<StoredEntry> = entries
70            .into_iter()
71            .map(|entry| {
72                let words = Self::extract_words_from_content(&entry.content);
73                StoredEntry { entry, words }
74            })
75            .filter(|e| !e.words.is_empty())
76            .collect();
77
78        if stored_entries.is_empty() {
79            return Ok(());
80        }
81
82        let mut store = self.store.write().unwrap();
83        let sessions = store.entry(key).or_default();
84        sessions.insert(session_id.to_string(), stored_entries);
85
86        Ok(())
87    }
88
89    async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
90        let query_words = Self::extract_words(&req.query);
91
92        let key = MemoryKey { app_name: req.app_name, user_id: req.user_id };
93
94        let store = self.store.read().unwrap();
95        let sessions = match store.get(&key) {
96            Some(s) => s,
97            None => return Ok(SearchResponse { memories: Vec::new() }),
98        };
99
100        let mut memories = Vec::new();
101        for stored_entries in sessions.values() {
102            for stored in stored_entries {
103                if Self::has_intersection(&stored.words, &query_words) {
104                    memories.push(stored.entry.clone());
105                }
106            }
107        }
108
109        Ok(SearchResponse { memories })
110    }
111}