harness_context/
memory_file.rs1use harness_core::{Memory, MemoryEntry, MemoryError};
13use std::collections::HashSet;
14use std::path::{Path, PathBuf};
15use std::sync::Mutex;
16
/// A memory store backed by a single file holding one JSON-encoded entry per line.
///
/// Create one with [`FileMemory::open`]. All mutations made through a handle are
/// serialized by an internal lock; the lock only coordinates callers that share
/// this particular handle.
#[derive(Debug)]
pub struct FileMemory {
    /// Location of the backing file.
    path: PathBuf,
    /// Serializes file mutations performed through this handle. It guards no
    /// data of its own, hence the unit payload.
    write_lock: Mutex<()>,
}
26
27impl FileMemory {
28 pub fn open(path: impl Into<PathBuf>) -> Result<Self, MemoryError> {
31 let path = path.into();
32 if let Some(parent) = path.parent()
33 && !parent.as_os_str().is_empty()
34 {
35 std::fs::create_dir_all(parent)
36 .map_err(|e| MemoryError::Io(format!("create parent: {e}")))?;
37 }
38 if !path.exists() {
40 std::fs::OpenOptions::new()
41 .create(true)
42 .append(true)
43 .open(&path)
44 .map_err(|e| MemoryError::Io(format!("create {}: {e}", path.display())))?;
45 }
46 Ok(Self {
47 path,
48 write_lock: Mutex::new(()),
49 })
50 }
51
52 pub fn path(&self) -> &Path {
55 &self.path
56 }
57
58 pub fn compact(&self) -> Result<u32, MemoryError> {
65 let entries = self.read_all()?;
66 let now = now_ms();
67 let original_len = entries.len();
68 let kept: Vec<MemoryEntry> = entries.into_iter().filter(|e| !e.is_expired(now)).collect();
69 let removed = original_len - kept.len();
70 self.rewrite(&kept)?;
71 Ok(removed as u32)
72 }
73
74 pub fn delete_by_id(&self, id: &str) -> Result<bool, MemoryError> {
77 let entries = self.read_all()?;
78 let original_len = entries.len();
79 let kept: Vec<MemoryEntry> = entries.into_iter().filter(|e| e.id != id).collect();
80 if kept.len() == original_len {
81 return Ok(false);
82 }
83 self.rewrite(&kept)?;
84 Ok(true)
85 }
86
87 pub fn delete_all(&self) -> Result<u32, MemoryError> {
90 let entries = self.read_all()?;
91 let n = entries.len() as u32;
92 self.rewrite(&[])?;
93 Ok(n)
94 }
95
96 fn rewrite(&self, entries: &[MemoryEntry]) -> Result<(), MemoryError> {
97 let _guard = self
98 .write_lock
99 .lock()
100 .map_err(|e| MemoryError::Backend(format!("poisoned mutex: {e}")))?;
101 let mut buf = String::new();
102 for e in entries {
103 let line = serde_json::to_string(e).map_err(|e| MemoryError::Serde(e.to_string()))?;
104 buf.push_str(&line);
105 buf.push('\n');
106 }
107 let tmp = self.path.with_extension("jsonl.tmp");
110 std::fs::write(&tmp, buf.as_bytes())
111 .map_err(|e| MemoryError::Io(format!("write tmp: {e}")))?;
112 std::fs::rename(&tmp, &self.path).map_err(|e| MemoryError::Io(format!("rename: {e}")))?;
113 Ok(())
114 }
115
116 fn read_all(&self) -> Result<Vec<MemoryEntry>, MemoryError> {
117 let content = std::fs::read_to_string(&self.path)
118 .map_err(|e| MemoryError::Io(format!("read {}: {e}", self.path.display())))?;
119 let mut out = Vec::new();
120 for (i, line) in content.lines().enumerate() {
121 let line = line.trim();
122 if line.is_empty() {
123 continue;
124 }
125 match serde_json::from_str::<MemoryEntry>(line) {
126 Ok(e) => out.push(e),
127 Err(err) => {
128 tracing::warn!(line = i + 1, error = %err, "memory line skipped");
132 }
133 }
134 }
135 Ok(out)
136 }
137}
138
139#[async_trait::async_trait]
140impl Memory for FileMemory {
141 async fn recall(&self, query: &str, k: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
142 let entries = self.read_all()?;
143 if entries.is_empty() || k == 0 {
144 return Ok(Vec::new());
145 }
146 let now_ms = now_ms();
147 let entries: Vec<MemoryEntry> = entries
148 .into_iter()
149 .filter(|e| !e.is_expired(now_ms))
150 .collect();
151 if entries.is_empty() {
152 return Ok(Vec::new());
153 }
154
155 let q_tokens = tokenise(query);
156 if q_tokens.is_empty() {
157 let mut all = entries;
160 all.sort_by_key(|e| std::cmp::Reverse(e.created_ms));
161 all.truncate(k);
162 return Ok(all);
163 }
164
165 let mut scored: Vec<(u32, &MemoryEntry)> = entries
168 .iter()
169 .map(|e| {
170 let mut hay = e.content.to_lowercase();
171 if !e.tags.is_empty() {
172 hay.push(' ');
173 hay.push_str(&e.tags.join(" ").to_lowercase());
174 }
175 let hits: u32 = q_tokens
176 .iter()
177 .map(|t| if hay.contains(t.as_str()) { 1 } else { 0 })
178 .sum();
179 (hits, e)
180 })
181 .filter(|(hits, _)| *hits > 0)
182 .collect();
183 scored.sort_by(|a, b| b.0.cmp(&a.0).then(b.1.created_ms.cmp(&a.1.created_ms)));
184
185 Ok(scored.into_iter().take(k).map(|(_, e)| e.clone()).collect())
186 }
187
188 async fn write(&self, mut entry: MemoryEntry) -> Result<(), MemoryError> {
189 if entry.id.is_empty() {
190 entry.id = short_id();
191 }
192 if entry.created_ms == 0 {
193 entry.created_ms = std::time::SystemTime::now()
194 .duration_since(std::time::UNIX_EPOCH)
195 .map(|d| d.as_millis() as i64)
196 .unwrap_or(0);
197 }
198 let line = serde_json::to_string(&entry).map_err(|e| MemoryError::Serde(e.to_string()))?;
199
200 let _guard = self
201 .write_lock
202 .lock()
203 .map_err(|e| MemoryError::Backend(format!("poisoned mutex: {e}")))?;
204 let mut file = std::fs::OpenOptions::new()
205 .create(true)
206 .append(true)
207 .open(&self.path)
208 .map_err(|e| MemoryError::Io(format!("open {}: {e}", self.path.display())))?;
209 use std::io::Write;
210 writeln!(file, "{line}").map_err(|e| MemoryError::Io(format!("write: {e}")))?;
211 Ok(())
212 }
213}
214
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as i64,
        Err(_) => 0,
    }
}
221
/// Breaks `s` into a set of lower-cased search tokens.
///
/// Tokens are maximal runs of alphanumeric characters; runs shorter than three
/// bytes (not characters) are discarded.
fn tokenise(s: &str) -> HashSet<String> {
    let lowered = s.to_lowercase();
    let mut tokens = HashSet::new();
    for word in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if word.len() < 3 {
            continue;
        }
        tokens.insert(word.to_owned());
    }
    tokens
}
229
/// Generates a short identifier: eight lower-case hexadecimal digits.
///
/// The value is the low 32 bits of the current time in nanoseconds plus a
/// process-wide sequence number. The sequence number keeps ids distinct when
/// several are generated within one tick of the system clock, which the
/// timestamp alone cannot do. Ids are still only 32 bits wide, so they are not
/// globally unique.
fn short_id() -> String {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Only uniqueness of the counter value matters, so relaxed ordering suffices.
    static SEQ: AtomicU32 = AtomicU32::new(0);

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    // Truncation to the low 32 bits is intentional.
    let low_bits = (nanos & 0xFFFF_FFFF) as u32;
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{:08x}", low_bits.wrapping_add(seq))
}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Per-process counter that keeps scratch paths distinct across tests.
    static N: AtomicU64 = AtomicU64::new(0);

    /// Builds a scratch file path in the system temp directory that is unique
    /// per process, timestamp and call.
    fn tmp() -> PathBuf {
        let pid = std::process::id();
        let n = N.fetch_add(1, Ordering::SeqCst);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let mut scratch = std::env::temp_dir();
        scratch.push(format!("harness-mem-test-{pid}-{nanos}-{n}.jsonl"));
        scratch
    }

    /// A keyword query returns only the entry whose content or tags match.
    #[tokio::test]
    async fn write_then_recall_round_trips() {
        let path = tmp();
        let store = FileMemory::open(&path).unwrap();

        let coffee = MemoryEntry::new("user prefers dark roast coffee").with_tags(["coffee"]);
        store.write(coffee).await.unwrap();
        let city = MemoryEntry::new("user lives in Beijing");
        store.write(city).await.unwrap();

        let found = store.recall("coffee preferences", 5).await.unwrap();
        assert_eq!(found.len(), 1);
        assert!(found[0].content.contains("dark roast"));

        let _ = std::fs::remove_file(&path);
    }

    /// An empty query yields the stored entries instead of nothing.
    #[tokio::test]
    async fn empty_query_falls_back_to_recent() {
        let path = tmp();
        let store = FileMemory::open(&path).unwrap();
        for fact in ["fact one", "fact two"] {
            store.write(MemoryEntry::new(fact)).await.unwrap();
        }

        let found = store.recall("", 5).await.unwrap();
        assert_eq!(found.len(), 2);

        let _ = std::fs::remove_file(&path);
    }

    /// A line that is not valid JSON does not prevent valid lines from loading.
    #[tokio::test]
    async fn malformed_lines_are_skipped() {
        use std::io::Write;

        let path = tmp();
        // Seed the file by hand; the handle is dropped (closed) before opening
        // the store.
        let mut seed = std::fs::File::create(&path).unwrap();
        writeln!(seed, "{{not valid json").unwrap();
        writeln!(
            seed,
            r#"{{"id":"abc","content":"valid fact","tags":[],"source":null,"created_ms":0}}"#
        )
        .unwrap();
        drop(seed);

        let store = FileMemory::open(&path).unwrap();
        let found = store.recall("valid", 10).await.unwrap();
        assert_eq!(found.len(), 1);

        let _ = std::fs::remove_file(&path);
    }
}