cersei_memory/
claudemd.rs1use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11
/// Where a memory file was discovered.
///
/// Variants compare in declaration order (`Managed < User < Project < Local`), which is also
/// the order in which `load_all_memory_files` emits them.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum MemoryScope {
    /// Rule files from `~/.claude/rules/*.md`.
    Managed = 0,
    /// The user-wide `~/.claude/CLAUDE.md`.
    User = 1,
    /// `<project_root>/CLAUDE.md`.
    Project = 2,
    /// `<project_root>/.claude/CLAUDE.md`.
    Local = 3,
}
22
23#[derive(Debug, Clone)]
25pub struct MemoryFileInfo {
26 pub path: PathBuf,
27 pub scope: MemoryScope,
28 pub content: String,
29 pub mtime: u64,
30}
31
32const MAX_INCLUDE_DEPTH: usize = 10;
35const MAX_INCLUDE_SIZE: usize = 40_000; pub fn load_all_memory_files(project_root: &Path) -> Vec<MemoryFileInfo> {
41 let mut files = Vec::new();
42 let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
43
44 let rules_dir = home.join(".claude").join("rules");
46 if rules_dir.exists() {
47 let mut rule_files: Vec<PathBuf> = std::fs::read_dir(&rules_dir)
48 .into_iter()
49 .flatten()
50 .flatten()
51 .filter(|e| e.path().extension().and_then(|x| x.to_str()) == Some("md"))
52 .map(|e| e.path())
53 .collect();
54 rule_files.sort();
55 for path in rule_files {
56 if let Some(info) = load_memory_file(&path, MemoryScope::Managed) {
57 files.push(info);
58 }
59 }
60 }
61
62 let user_path = home.join(".claude").join("CLAUDE.md");
64 if let Some(info) = load_memory_file(&user_path, MemoryScope::User) {
65 files.push(info);
66 }
67
68 let project_path = project_root.join("CLAUDE.md");
70 if let Some(info) = load_memory_file(&project_path, MemoryScope::Project) {
71 files.push(info);
72 }
73
74 let local_path = project_root.join(".claude").join("CLAUDE.md");
76 if let Some(info) = load_memory_file(&local_path, MemoryScope::Local) {
77 files.push(info);
78 }
79
80 files
81}
82
83fn load_memory_file(path: &Path, scope: MemoryScope) -> Option<MemoryFileInfo> {
85 let content = std::fs::read_to_string(path).ok()?;
86 if content.trim().is_empty() {
87 return None;
88 }
89
90 let mtime = std::fs::metadata(path)
91 .and_then(|m| m.modified())
92 .ok()
93 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
94 .map(|d| d.as_secs())
95 .unwrap_or(0);
96
97 let body = crate::strip_frontmatter(&content);
99
100 let base_dir = path.parent().unwrap_or(Path::new("."));
102 let mut visited = HashSet::new();
103 visited.insert(path.to_path_buf());
104 let expanded = expand_includes(&body, base_dir, &mut visited, 0);
105
106 Some(MemoryFileInfo {
107 path: path.to_path_buf(),
108 scope,
109 content: expanded,
110 mtime,
111 })
112}
113
114fn expand_includes(
122 content: &str,
123 base_dir: &Path,
124 visited: &mut HashSet<PathBuf>,
125 depth: usize,
126) -> String {
127 if depth >= MAX_INCLUDE_DEPTH {
128 return content.to_string();
129 }
130
131 let mut result = String::new();
132
133 for line in content.lines() {
134 let trimmed = line.trim();
135 if trimmed.starts_with("@include ") {
136 let include_path = trimmed.strip_prefix("@include ").unwrap().trim();
137
138 let expanded_path = if include_path.starts_with("~/") {
140 let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
141 home.join(&include_path[2..])
142 } else {
143 base_dir.join(include_path)
144 };
145
146 if visited.contains(&expanded_path) {
148 result.push_str(&format!("<!-- circular @include: {} -->\n", include_path));
149 continue;
150 }
151
152 if let Ok(meta) = std::fs::metadata(&expanded_path) {
154 if meta.len() > MAX_INCLUDE_SIZE as u64 {
155 result.push_str(&format!(
156 "<!-- @include too large: {} ({} bytes, max {}) -->\n",
157 include_path,
158 meta.len(),
159 MAX_INCLUDE_SIZE
160 ));
161 continue;
162 }
163 }
164
165 if let Ok(included_content) = std::fs::read_to_string(&expanded_path) {
167 visited.insert(expanded_path.clone());
168 let included_dir = expanded_path.parent().unwrap_or(base_dir);
169 let expanded = expand_includes(&included_content, included_dir, visited, depth + 1);
170 result.push_str(&expanded);
171 result.push('\n');
172 } else {
173 result.push_str(&format!("<!-- @include not found: {} -->\n", include_path));
174 }
175 } else {
176 result.push_str(line);
177 result.push('\n');
178 }
179 }
180
181 result
182}
183
184pub fn build_memory_prompt(files: &[MemoryFileInfo]) -> String {
186 files
187 .iter()
188 .filter(|f| !f.content.trim().is_empty())
189 .map(|f| f.content.as_str())
190 .collect::<Vec<_>>()
191 .join("\n\n")
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_project_claude_md() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "# Project Rules\n\nUse Rust.").unwrap();

        let files = load_all_memory_files(tmp.path());
        let project = files.iter().find(|f| f.scope == MemoryScope::Project);
        assert!(project.is_some());
        assert!(project.unwrap().content.contains("Use Rust"));
    }

    #[test]
    fn test_load_local_claude_md() {
        let tmp = tempfile::tempdir().unwrap();
        let claude_dir = tmp.path().join(".claude");
        std::fs::create_dir_all(&claude_dir).unwrap();
        std::fs::write(claude_dir.join("CLAUDE.md"), "Local overrides here.").unwrap();

        let files = load_all_memory_files(tmp.path());
        let local = files.iter().find(|f| f.scope == MemoryScope::Local);
        assert!(local.is_some());
        assert!(local.unwrap().content.contains("Local overrides"));
    }

    #[test]
    fn test_scope_ordering() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "Project").unwrap();
        let claude_dir = tmp.path().join(".claude");
        std::fs::create_dir_all(&claude_dir).unwrap();
        std::fs::write(claude_dir.join("CLAUDE.md"), "Local").unwrap();

        let files = load_all_memory_files(tmp.path());
        // Both files were written above, so both must be found; a missing one is a failure,
        // not a reason to skip the ordering check.
        let project_idx = files
            .iter()
            .position(|f| f.scope == MemoryScope::Project)
            .expect("project CLAUDE.md should be loaded");
        let local_idx = files
            .iter()
            .position(|f| f.scope == MemoryScope::Local)
            .expect("local .claude/CLAUDE.md should be loaded");
        assert!(project_idx < local_idx);
    }

    #[test]
    fn test_include_expansion() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(
            tmp.path().join("main.md"),
            "Before\n@include extra.md\nAfter",
        )
        .unwrap();
        std::fs::write(tmp.path().join("extra.md"), "INCLUDED CONTENT").unwrap();

        let mut visited = HashSet::new();
        visited.insert(tmp.path().join("main.md"));
        let content = std::fs::read_to_string(tmp.path().join("main.md")).unwrap();
        let expanded = expand_includes(&content, tmp.path(), &mut visited, 0);

        assert!(expanded.contains("Before"));
        assert!(expanded.contains("INCLUDED CONTENT"));
        assert!(expanded.contains("After"));
    }

    #[test]
    fn test_circular_include() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("a.md"), "@include b.md").unwrap();
        std::fs::write(tmp.path().join("b.md"), "@include a.md").unwrap();

        let mut visited = HashSet::new();
        visited.insert(tmp.path().join("a.md"));
        let content = std::fs::read_to_string(tmp.path().join("a.md")).unwrap();
        let expanded = expand_includes(&content, tmp.path(), &mut visited, 0);

        assert!(expanded.contains("circular"));
    }

    #[test]
    fn test_build_memory_prompt() {
        let files = vec![
            MemoryFileInfo {
                path: PathBuf::from("a.md"),
                scope: MemoryScope::Managed,
                content: "Rule 1".into(),
                mtime: 0,
            },
            MemoryFileInfo {
                path: PathBuf::from("blank.md"),
                scope: MemoryScope::User,
                content: "  \n".into(),
                mtime: 0,
            },
            MemoryFileInfo {
                path: PathBuf::from("b.md"),
                scope: MemoryScope::Project,
                content: "Rule 2".into(),
                mtime: 0,
            },
        ];
        // Blank entries are dropped and the rest joined by exactly one blank line.
        assert_eq!(build_memory_prompt(&files), "Rule 1\n\nRule 2");
    }

    #[test]
    fn test_empty_file_skipped() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "").unwrap();

        let files = load_all_memory_files(tmp.path());
        assert!(files.iter().all(|f| f.scope != MemoryScope::Project));
    }

    #[test]
    fn test_frontmatter_stripped() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(
            tmp.path().join("CLAUDE.md"),
            "---\nscope: project\n---\n\nActual content.",
        )
        .unwrap();

        let files = load_all_memory_files(tmp.path());
        let project = files
            .iter()
            .find(|f| f.scope == MemoryScope::Project)
            .unwrap();
        assert!(project.content.contains("Actual content"));
        assert!(!project.content.contains("scope: project"));
    }
}