1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use super::error::{MemoryError, MemoryResult};
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum MemoryEntryKind {
13 RunTrace,
14 Rationale,
15 Diff,
16 Snapshot,
17 ToolResult,
18}
19
20impl std::fmt::Display for MemoryEntryKind {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 Self::RunTrace => write!(f, "run_trace"),
24 Self::Rationale => write!(f, "rationale"),
25 Self::Diff => write!(f, "diff"),
26 Self::Snapshot => write!(f, "snapshot"),
27 Self::ToolResult => write!(f, "tool_result"),
28 }
29 }
30}
31
/// A single record stored in a [`MemoryIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Key under which the index stores this entry; `MemoryIndex::insert`
    /// rejects a second entry with the same id.
    pub id: String,
    /// Category of the entry; matched exactly by `IndexQuery::with_kind`.
    pub kind: MemoryEntryKind,
    /// Free-text summary; not interpreted by the index.
    pub summary: String,
    /// Digest of the entry's content; not interpreted by the index.
    /// NOTE(review): digest algorithm/format is defined by the producer — confirm.
    pub content_digest: String,
    /// Creation time (UTC). Queries return entries newest-first by this field
    /// and `IndexQuery::after` filters on it.
    pub created_at: DateTime<Utc>,
    /// Free-form tags; `IndexQuery::with_tag` keeps entries containing an
    /// exactly equal tag.
    pub tags: Vec<String>,
    /// Not used by the index itself.
    /// NOTE(review): presumably an approximate token count of the content — confirm.
    pub token_estimate: usize,
    /// Not used by the index itself.
    /// NOTE(review): expected range/scale (e.g. 0.0..=1.0) is not enforced here — confirm.
    pub relevance: f64,
}
44
45#[derive(Debug, Clone, Default)]
47pub struct IndexQuery {
48 pub kind: Option<MemoryEntryKind>,
49 pub tag: Option<String>,
50 pub after: Option<DateTime<Utc>>,
51 pub limit: Option<usize>,
52}
53
54impl IndexQuery {
55 pub fn all() -> Self {
57 Self::default()
58 }
59
60 pub fn with_kind(mut self, kind: MemoryEntryKind) -> Self {
61 self.kind = Some(kind);
62 self
63 }
64
65 pub fn with_tag(mut self, tag: &str) -> Self {
66 self.tag = Some(tag.to_string());
67 self
68 }
69
70 pub fn after(mut self, after: DateTime<Utc>) -> Self {
71 self.after = Some(after);
72 self
73 }
74
75 pub fn with_limit(mut self, limit: usize) -> Self {
76 self.limit = Some(limit);
77 self
78 }
79}
80
/// Outcome of [`MemoryIndex::query`].
#[derive(Debug, Clone)]
pub struct IndexResult {
    /// Matching entries, newest first, truncated to the query's `limit` if set.
    pub entries: Vec<MemoryEntry>,
    /// Number of entries that matched the filters before `limit` was applied.
    pub total_matches: usize,
}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct MemoryIndex {
91 entries: HashMap<String, MemoryEntry>,
92}
93
94impl MemoryIndex {
95 pub fn new() -> Self {
96 Self {
97 entries: HashMap::new(),
98 }
99 }
100
101 pub fn len(&self) -> usize {
102 self.entries.len()
103 }
104
105 pub fn is_empty(&self) -> bool {
106 self.entries.is_empty()
107 }
108
109 pub fn insert(&mut self, entry: MemoryEntry) -> MemoryResult<()> {
111 if self.entries.contains_key(&entry.id) {
112 return Err(MemoryError::DuplicateEntry { id: entry.id });
113 }
114 self.entries.insert(entry.id.clone(), entry);
115 Ok(())
116 }
117
118 pub fn get(&self, id: &str) -> MemoryResult<&MemoryEntry> {
120 self.entries
121 .get(id)
122 .ok_or_else(|| MemoryError::EntryNotFound { id: id.into() })
123 }
124
125 pub fn remove(&mut self, id: &str) -> MemoryResult<MemoryEntry> {
127 self.entries
128 .remove(id)
129 .ok_or_else(|| MemoryError::EntryNotFound { id: id.into() })
130 }
131
132 pub fn entries_mut(&mut self) -> &mut HashMap<String, MemoryEntry> {
134 &mut self.entries
135 }
136
137 pub fn query(&self, q: &IndexQuery) -> IndexResult {
139 let mut matches: Vec<MemoryEntry> = self
140 .entries
141 .values()
142 .filter(|e| {
143 if let Some(ref kind) = q.kind {
144 if &e.kind != kind {
145 return false;
146 }
147 }
148 if let Some(ref tag) = q.tag {
149 if !e.tags.contains(tag) {
150 return false;
151 }
152 }
153 if let Some(after) = q.after {
154 if e.created_at < after {
155 return false;
156 }
157 }
158 true
159 })
160 .cloned()
161 .collect();
162
163 matches.sort_by(|a, b| b.created_at.cmp(&a.created_at));
165
166 let total_matches = matches.len();
167
168 if let Some(limit) = q.limit {
169 matches.truncate(limit);
170 }
171
172 IndexResult {
173 entries: matches,
174 total_matches,
175 }
176 }
177}
178
179impl Default for MemoryIndex {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;

    /// Builds an untagged entry created "now" with fixed filler fields.
    fn make_entry(id: &str, kind: MemoryEntryKind) -> MemoryEntry {
        MemoryEntry {
            id: id.into(),
            kind,
            summary: format!("summary {id}"),
            content_digest: format!("digest_{id}"),
            created_at: Utc::now(),
            tags: Vec::new(),
            token_estimate: 100,
            relevance: 0.5,
        }
    }

    #[test]
    fn test_insert_and_get() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("a", MemoryEntryKind::RunTrace))
            .unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.get("a").unwrap().kind, MemoryEntryKind::RunTrace);
    }

    #[test]
    fn test_remove() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("a", MemoryEntryKind::RunTrace))
            .unwrap();
        idx.remove("a").unwrap();
        assert!(idx.is_empty());
    }

    #[test]
    fn test_remove_not_found() {
        let mut idx = MemoryIndex::new();
        let err = idx.remove("nope").expect_err("missing id should fail");
        assert!(matches!(err, MemoryError::EntryNotFound { .. }));
    }

    #[test]
    fn test_get_not_found() {
        let idx = MemoryIndex::new();
        assert!(idx.get("nope").is_err());
    }

    #[test]
    fn test_query_all() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("a", MemoryEntryKind::RunTrace))
            .unwrap();
        idx.insert(make_entry("b", MemoryEntryKind::Diff)).unwrap();
        let r = idx.query(&IndexQuery::all());
        assert_eq!(r.total_matches, 2);
    }

    #[test]
    fn test_query_by_kind() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("a", MemoryEntryKind::RunTrace))
            .unwrap();
        idx.insert(make_entry("b", MemoryEntryKind::Diff)).unwrap();
        let r = idx.query(&IndexQuery::all().with_kind(MemoryEntryKind::Diff));
        assert_eq!(r.total_matches, 1);
        assert_eq!(r.entries[0].id, "b");
    }

    #[test]
    fn test_query_by_tag() {
        let mut idx = MemoryIndex::new();
        let mut tagged = make_entry("a", MemoryEntryKind::Rationale);
        tagged.tags = vec!["keep".to_string()];
        idx.insert(tagged).unwrap();
        idx.insert(make_entry("b", MemoryEntryKind::Rationale))
            .unwrap();
        let r = idx.query(&IndexQuery::all().with_tag("keep"));
        assert_eq!(r.total_matches, 1);
        assert_eq!(r.entries[0].id, "a");
    }

    #[test]
    fn test_query_after_is_inclusive() {
        let mut idx = MemoryIndex::new();
        let now = Utc::now();
        for i in 0..5 {
            let mut e = make_entry(&format!("e{i}"), MemoryEntryKind::Diff);
            e.created_at = now - Duration::hours(i);
            idx.insert(e).unwrap();
        }
        // e0, e1 and e2 (exactly on the boundary) match; e3 and e4 are older.
        let r = idx.query(&IndexQuery::all().after(now - Duration::hours(2)));
        assert_eq!(r.total_matches, 3);
    }

    #[test]
    fn test_query_with_limit() {
        let mut idx = MemoryIndex::new();
        let now = Utc::now();
        for i in 0..10 {
            let mut e = make_entry(&format!("e{i}"), MemoryEntryKind::RunTrace);
            e.created_at = now - Duration::hours(i);
            idx.insert(e).unwrap();
        }
        let r = idx.query(&IndexQuery::all().with_limit(3));
        assert_eq!(r.total_matches, 10);
        // The limit keeps the newest entries, in newest-first order.
        let ids: Vec<&str> = r.entries.iter().map(|e| e.id.as_str()).collect();
        assert_eq!(ids, ["e0", "e1", "e2"]);
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("x", MemoryEntryKind::Snapshot))
            .unwrap();
        let json = serde_json::to_string(&idx).unwrap();
        let back: MemoryIndex = serde_json::from_str(&json).unwrap();
        assert_eq!(back.len(), 1);
    }

    #[test]
    fn test_insert_duplicate_id_rejected() {
        let mut idx = MemoryIndex::new();
        idx.insert(make_entry("dup", MemoryEntryKind::RunTrace))
            .unwrap();
        let err = idx
            .insert(make_entry("dup", MemoryEntryKind::Diff))
            .expect_err("duplicate id should fail");
        assert!(matches!(err, MemoryError::DuplicateEntry { .. }));
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.get("dup").unwrap().kind, MemoryEntryKind::RunTrace);
    }
}