Skip to main content

atomcode_core/config/
memory.rs

1use std::fs;
2use std::io::{self, Write};
3use std::path::{Path, PathBuf};
4
/// Byte budget for parsing a memory file (64 KiB); larger files are read from their tail.
const MAX_MEMORY_FILE_SIZE: u64 = 64 * 1024;
/// Character budget for the prompt text produced by `MemoryStore::merged_for_prompt`.
const DEFAULT_CHAR_LIMIT: usize = 4_000;
7
/// A file-backed list of memory entries, stored one per line as `- text`.
///
/// The store only remembers its path; every operation reads or writes the file directly.
#[derive(Debug, Clone)]
pub struct MemoryStore {
    /// Location of the backing memory file.
    path: PathBuf,
}
11
12impl MemoryStore {
13    pub fn new(path: PathBuf) -> Self {
14        Self { path }
15    }
16
17    pub fn global() -> Self {
18        let dir = super::Config::config_dir();
19        Self::new(dir.join("memory.md"))
20    }
21
22    pub fn project(project_root: &Path) -> Self {
23        Self::new(project_root.join(".atomcode").join("memory.md"))
24    }
25
26    pub fn path(&self) -> &Path {
27        &self.path
28    }
29
30    pub fn load(&self) -> Vec<String> {
31        let content = match fs::metadata(&self.path) {
32            Ok(meta) => {
33                if meta.len() > MAX_MEMORY_FILE_SIZE {
34                    let bytes = fs::read(&self.path).unwrap_or_default();
35                    let start = bytes.len().saturating_sub(MAX_MEMORY_FILE_SIZE as usize);
36                    // Scan forward to the next newline to avoid splitting UTF-8 chars
37                    let safe_start = bytes[start..].iter()
38                        .position(|&b| b == b'\n')
39                        .map(|pos| start + pos + 1)
40                        .unwrap_or(start);
41                    String::from_utf8_lossy(&bytes[safe_start..]).to_string()
42                } else {
43                    fs::read_to_string(&self.path).unwrap_or_default()
44                }
45            }
46            Err(_) => return Vec::new(),
47        };
48        content
49            .lines()
50            .filter_map(|line| {
51                let trimmed = line.trim();
52                if trimmed.starts_with("- ") {
53                    Some(trimmed[2..].to_string())
54                } else {
55                    None
56                }
57            })
58            .collect()
59    }
60
61    pub fn append(&self, content: &str) -> io::Result<()> {
62        if let Some(parent) = self.path.parent() {
63            fs::create_dir_all(parent)?;
64        }
65
66        // Read existing content to check if we need a leading newline
67        let existing = fs::read_to_string(&self.path).unwrap_or_default();
68        let needs_newline = !existing.is_empty() && !existing.ends_with('\n');
69
70        let mut file = fs::OpenOptions::new()
71            .create(true)
72            .append(true)
73            .open(&self.path)?;
74
75        if needs_newline {
76            writeln!(file)?;
77        }
78        writeln!(file, "- {}", content.trim())
79    }
80
81    pub fn remove_matching(&self, keyword: &str) -> io::Result<Vec<String>> {
82        let content = fs::read_to_string(&self.path).unwrap_or_default();
83        let keyword_lower = keyword.to_lowercase();
84        let mut removed = Vec::new();
85        let mut kept = Vec::new();
86
87        for line in content.lines() {
88            let trimmed = line.trim();
89            if trimmed.starts_with("- ") && trimmed.to_lowercase().contains(&keyword_lower) {
90                removed.push(trimmed[2..].to_string());
91            } else {
92                kept.push(line.to_string());
93            }
94        }
95
96        if !removed.is_empty() {
97            let mut out = kept.join("\n");
98            if !out.is_empty() && !out.ends_with('\n') {
99                out.push('\n');
100            }
101            fs::write(&self.path, out)?;
102        }
103
104        Ok(removed)
105    }
106
107    pub fn find_matching(&self, keyword: &str) -> Vec<String> {
108        let keyword_lower = keyword.to_lowercase();
109        self.load()
110            .into_iter()
111            .filter(|entry| entry.to_lowercase().contains(&keyword_lower))
112            .collect()
113    }
114
115    pub fn merged_for_prompt(global: &MemoryStore, project: &MemoryStore, project_name: &str) -> String {
116        let global_entries = global.load();
117        let project_entries = project.load();
118
119        if global_entries.is_empty() && project_entries.is_empty() {
120            return String::new();
121        }
122
123        let mut result = String::from("=== MEMORY ===\nThe user has asked you to remember these facts and preferences:\n");
124
125        if !global_entries.is_empty() {
126            result.push_str("\n[Global]\n");
127            for entry in &global_entries {
128                result.push_str(&format!("- {}\n", entry));
129            }
130        }
131
132        if !project_entries.is_empty() {
133            result.push_str(&format!("\n[Project: {}]\n", project_name));
134            for entry in &project_entries {
135                result.push_str(&format!("- {}\n", entry));
136            }
137        }
138
139        if result.chars().count() > DEFAULT_CHAR_LIMIT {
140            let truncated: String = result.chars().take(DEFAULT_CHAR_LIMIT).collect();
141            format!("{}\n[...truncated, run /memory to review]", truncated)
142        } else {
143            result
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_append_creates_file() {
        let dir = tempfile::tempdir().unwrap();
        let store = MemoryStore::new(dir.path().join("sub").join("memory.md"));
        store.append("test entry").unwrap();
        let content = fs::read_to_string(store.path()).unwrap();
        assert_eq!(content, "- test entry\n");
    }

    #[test]
    fn test_append_to_existing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "- first\n").unwrap();
        let store = MemoryStore::new(path);
        store.append("second").unwrap();
        let entries = store.load();
        assert_eq!(entries, vec!["first", "second"]);
    }

    /// A file whose last line lacks a newline must not get the new entry glued onto it.
    #[test]
    fn test_append_repairs_missing_trailing_newline() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "- first").unwrap();
        let store = MemoryStore::new(path.clone());
        store.append("second").unwrap();
        assert_eq!(fs::read_to_string(&path).unwrap(), "- first\n- second\n");
    }

    #[test]
    fn test_load_skips_non_entries() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "# Header\n\n- real entry\nnot an entry\n- another\n").unwrap();
        let store = MemoryStore::new(path);
        assert_eq!(store.load(), vec!["real entry", "another"]);
    }

    #[test]
    fn test_load_empty_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "").unwrap();
        let store = MemoryStore::new(path);
        assert!(store.load().is_empty());
    }

    #[test]
    fn test_load_nonexistent() {
        let store = MemoryStore::new(PathBuf::from("/nonexistent/memory.md"));
        assert!(store.load().is_empty());
    }

    /// Oversized files are read from the tail: the newest entry survives, the oldest does
    /// not, and no clipped partial line shows up as an entry.
    #[test]
    fn test_load_large_file_keeps_recent_tail() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        let content: String = (0..10_000).map(|i| format!("- entry {}\n", i)).collect();
        assert!(content.len() as u64 > MAX_MEMORY_FILE_SIZE);
        fs::write(&path, &content).unwrap();
        let entries = MemoryStore::new(path).load();
        assert_eq!(entries.last().map(String::as_str), Some("entry 9999"));
        assert!(!entries.contains(&"entry 0".to_string()));
        assert!(entries.iter().all(|e| e.starts_with("entry ")));
    }

    #[test]
    fn test_remove_matching_case_insensitive() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "- Use tabs\n- use spaces\n- pnpm only\n").unwrap();
        let store = MemoryStore::new(path);
        let removed = store.remove_matching("use").unwrap();
        assert_eq!(removed, vec!["Use tabs", "use spaces"]);
        assert_eq!(store.load(), vec!["pnpm only"]);
    }

    #[test]
    fn test_remove_matching_no_match() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "- keep this\n").unwrap();
        let store = MemoryStore::new(path.clone());
        let removed = store.remove_matching("nonexistent").unwrap();
        assert!(removed.is_empty());
        assert_eq!(fs::read_to_string(&path).unwrap(), "- keep this\n");
    }

    #[test]
    fn test_find_matching_case_insensitive() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        fs::write(&path, "- Use tabs\n- pnpm only\n").unwrap();
        let store = MemoryStore::new(path);
        assert_eq!(store.find_matching("USE"), vec!["Use tabs"]);
        assert!(store.find_matching("yarn").is_empty());
    }

    #[test]
    fn test_merged_for_prompt_empty() {
        let dir = tempfile::tempdir().unwrap();
        let global = MemoryStore::new(dir.path().join("global.md"));
        let project = MemoryStore::new(dir.path().join("project.md"));
        assert_eq!(MemoryStore::merged_for_prompt(&global, &project, "p"), "");
    }

    #[test]
    fn test_merged_for_prompt_sections() {
        let dir = tempfile::tempdir().unwrap();
        let global_path = dir.path().join("global.md");
        let project_path = dir.path().join("project.md");
        fs::write(&global_path, "- g1\n").unwrap();
        fs::write(&project_path, "- p1\n- p2\n").unwrap();
        let result = MemoryStore::merged_for_prompt(
            &MemoryStore::new(global_path),
            &MemoryStore::new(project_path),
            "demo",
        );
        assert!(result.starts_with("=== MEMORY ===\n"));
        assert!(result.contains("\n[Global]\n- g1\n"));
        assert!(result.ends_with("\n[Project: demo]\n- p1\n- p2\n"));
        assert!(!result.contains("[...truncated"));
    }

    #[test]
    fn test_merged_for_prompt_truncation() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.md");
        let long_entry = "x".repeat(5000);
        fs::write(&path, format!("- {}\n", long_entry)).unwrap();
        let store = MemoryStore::new(path);
        let empty = MemoryStore::new(PathBuf::from("/none"));
        let result = MemoryStore::merged_for_prompt(&store, &empty, "p");
        assert!(result.contains("[...truncated"));
        assert!(result.chars().count() < 5000);
    }
}
230}