Skip to main content

construct/agent/
memory_loader.rs

1use crate::memory::{self, Memory, decay};
2use async_trait::async_trait;
3use std::fmt::Write;
4
5#[async_trait]
6pub trait MemoryLoader: Send + Sync {
7    async fn load_context(
8        &self,
9        memory: &dyn Memory,
10        user_message: &str,
11        session_id: Option<&str>,
12    ) -> anyhow::Result<String>;
13}
14
/// [`MemoryLoader`] implementation driven by a recall limit and a relevance cutoff.
pub struct DefaultMemoryLoader {
    /// Number of entries requested from the memory backend on each recall.
    limit: usize,
    /// Entries that carry a score lower than this value are left out of the context.
    min_relevance_score: f64,
}
19
20impl Default for DefaultMemoryLoader {
21    fn default() -> Self {
22        Self {
23            limit: 5,
24            min_relevance_score: 0.4,
25        }
26    }
27}
28
29impl DefaultMemoryLoader {
30    pub fn new(limit: usize, min_relevance_score: f64) -> Self {
31        Self {
32            limit: limit.max(1),
33            min_relevance_score,
34        }
35    }
36}
37
38#[async_trait]
39impl MemoryLoader for DefaultMemoryLoader {
40    async fn load_context(
41        &self,
42        memory: &dyn Memory,
43        user_message: &str,
44        session_id: Option<&str>,
45    ) -> anyhow::Result<String> {
46        let mut entries = memory
47            .recall(user_message, self.limit, session_id, None, None)
48            .await?;
49        if entries.is_empty() {
50            return Ok(String::new());
51        }
52
53        // Apply time decay: older non-Core memories score lower
54        decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
55
56        let mut context = String::from("[Memory context]\n");
57        for entry in entries {
58            if memory::is_assistant_autosave_key(&entry.key) {
59                continue;
60            }
61            if memory::should_skip_autosave_content(&entry.content) {
62                continue;
63            }
64            if let Some(score) = entry.score {
65                if score < self.min_relevance_score {
66                    continue;
67                }
68            }
69            let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
70        }
71
72        // If all entries were below threshold, return empty
73        if context == "[Memory context]\n" {
74            return Ok(String::new());
75        }
76
77        context.push_str("[/Memory context]\n\n");
78        Ok(context)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Memory, MemoryCategory, MemoryEntry};
    use std::sync::Arc;

    /// Builds a `MemoryEntry` fixture; every field not taken as a parameter
    /// gets the same fixed value in all tests.
    fn fixture_entry(
        id: &str,
        key: &str,
        content: &str,
        category: MemoryCategory,
        score: Option<f64>,
    ) -> MemoryEntry {
        MemoryEntry {
            id: id.into(),
            key: key.into(),
            content: content.into(),
            category,
            timestamp: "now".into(),
            session_id: None,
            score,
            namespace: "default".into(),
            importance: None,
            superseded_by: None,
        }
    }

    /// Backend whose `recall` yields one unscored `k`/`v` entry (or nothing
    /// when asked for zero entries).
    struct MockMemory;

    /// Backend whose `recall` returns a caller-supplied list of entries.
    struct MockMemoryWithEntries {
        entries: Arc<Vec<MemoryEntry>>,
    }

    #[async_trait]
    impl Memory for MockMemory {
        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(
            &self,
            _query: &str,
            limit: usize,
            _session_id: Option<&str>,
            _since: Option<&str>,
            _until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            match limit {
                0 => Ok(Vec::new()),
                _ => Ok(vec![fixture_entry(
                    "1",
                    "k",
                    "v",
                    MemoryCategory::Conversation,
                    None,
                )]),
            }
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(true)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }

        async fn health_check(&self) -> bool {
            true
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    #[async_trait]
    impl Memory for MockMemoryWithEntries {
        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(
            &self,
            _query: &str,
            _limit: usize,
            _session_id: Option<&str>,
            _since: Option<&str>,
            _until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            // Hand back a fresh copy of the configured entries on every call.
            Ok((*self.entries).clone())
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(true)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(self.entries.len())
        }

        async fn health_check(&self) -> bool {
            true
        }

        fn name(&self) -> &str {
            "mock-with-entries"
        }
    }

    #[tokio::test]
    async fn default_loader_formats_context() {
        let rendered = DefaultMemoryLoader::default()
            .load_context(&MockMemory, "hello", None)
            .await
            .unwrap();

        assert!(rendered.contains("[Memory context]"));
        assert!(rendered.contains("- k: v"));
    }

    #[tokio::test]
    async fn default_loader_skips_legacy_assistant_autosave_entries() {
        let legacy_autosave = fixture_entry(
            "1",
            "assistant_resp_legacy",
            "fabricated detail",
            MemoryCategory::Daily,
            Some(0.95),
        );
        let user_fact = fixture_entry(
            "2",
            "user_fact",
            "User prefers concise answers",
            MemoryCategory::Conversation,
            Some(0.9),
        );
        let backend = MockMemoryWithEntries {
            entries: Arc::new(vec![legacy_autosave, user_fact]),
        };

        let rendered = DefaultMemoryLoader::new(5, 0.0)
            .load_context(&backend, "answer style", None)
            .await
            .unwrap();

        assert!(rendered.contains("user_fact"));
        assert!(!rendered.contains("assistant_resp_legacy"));
        assert!(!rendered.contains("fabricated detail"));
    }
}