Skip to main content

cersei_memory/
claudemd.rs

//! CLAUDE.md hierarchical loading.
//!
//! Loads instruction files from four scopes and merges them in priority order
//! (highest priority first):
//! 1. Managed: ~/.claude/rules/*.md
//! 2. User: ~/.claude/CLAUDE.md
//! 3. Project: {root}/CLAUDE.md
//! 4. Local: {root}/.claude/CLAUDE.md
//!
//! Each file may pull in other files with `@include <path>` directives.
8
9use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11
12// ─── Types ───────────────────────────────────────────────────────────────────
13
/// Where a CLAUDE.md / rules file was discovered.
///
/// Variants are declared highest priority first, and their explicit
/// discriminants increase in declaration order, so the derived ordering is
/// `Managed < User < Project < Local`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum MemoryScope {
    /// `~/.claude/rules/*.md`
    Managed = 0,
    /// `~/.claude/CLAUDE.md`
    User = 1,
    /// `{root}/CLAUDE.md`
    Project = 2,
    /// `{root}/.claude/CLAUDE.md`
    Local = 3,
}
22
23/// A loaded CLAUDE.md / rules file with metadata.
24#[derive(Debug, Clone)]
25pub struct MemoryFileInfo {
26    pub path: PathBuf,
27    pub scope: MemoryScope,
28    pub content: String,
29    pub mtime: u64,
30}
31
32// ─── Constants ───────────────────────────────────────────────────────────────
33
/// Nesting depth at which `@include` expansion stops.
const MAX_INCLUDE_DEPTH: usize = 10;
/// Largest file, in bytes, that an `@include` will inline (40 KB).
const MAX_INCLUDE_SIZE: usize = 40 * 1_000;
36
37// ─── Loading ─────────────────────────────────────────────────────────────────
38
39/// Load all CLAUDE.md / rules files for a project, in priority order.
40pub fn load_all_memory_files(project_root: &Path) -> Vec<MemoryFileInfo> {
41    let mut files = Vec::new();
42    let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
43
44    // 1. Managed: ~/.claude/rules/*.md (sorted alphabetically)
45    let rules_dir = home.join(".claude").join("rules");
46    if rules_dir.exists() {
47        let mut rule_files: Vec<PathBuf> = std::fs::read_dir(&rules_dir)
48            .into_iter()
49            .flatten()
50            .flatten()
51            .filter(|e| e.path().extension().and_then(|x| x.to_str()) == Some("md"))
52            .map(|e| e.path())
53            .collect();
54        rule_files.sort();
55        for path in rule_files {
56            if let Some(info) = load_memory_file(&path, MemoryScope::Managed) {
57                files.push(info);
58            }
59        }
60    }
61
62    // 2. User: ~/.claude/CLAUDE.md
63    let user_path = home.join(".claude").join("CLAUDE.md");
64    if let Some(info) = load_memory_file(&user_path, MemoryScope::User) {
65        files.push(info);
66    }
67
68    // 3. Project: {root}/CLAUDE.md
69    let project_path = project_root.join("CLAUDE.md");
70    if let Some(info) = load_memory_file(&project_path, MemoryScope::Project) {
71        files.push(info);
72    }
73
74    // 4. Local: {root}/.claude/CLAUDE.md
75    let local_path = project_root.join(".claude").join("CLAUDE.md");
76    if let Some(info) = load_memory_file(&local_path, MemoryScope::Local) {
77        files.push(info);
78    }
79
80    files
81}
82
83/// Load a single memory file with @include expansion.
84fn load_memory_file(path: &Path, scope: MemoryScope) -> Option<MemoryFileInfo> {
85    let content = std::fs::read_to_string(path).ok()?;
86    if content.trim().is_empty() {
87        return None;
88    }
89
90    let mtime = std::fs::metadata(path)
91        .and_then(|m| m.modified())
92        .ok()
93        .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
94        .map(|d| d.as_secs())
95        .unwrap_or(0);
96
97    // Strip frontmatter
98    let body = crate::strip_frontmatter(&content);
99
100    // Expand @include directives
101    let base_dir = path.parent().unwrap_or(Path::new("."));
102    let mut visited = HashSet::new();
103    visited.insert(path.to_path_buf());
104    let expanded = expand_includes(&body, base_dir, &mut visited, 0);
105
106    Some(MemoryFileInfo {
107        path: path.to_path_buf(),
108        scope,
109        content: expanded,
110        mtime,
111    })
112}
113
114/// Expand `@include <path>` directives in content.
115/// Supports:
116/// - Relative paths (resolved from base_dir)
117/// - ~ home directory expansion
118/// - Max depth of 10 levels
119/// - Max include size of 40KB
120/// - Circular reference detection
121fn expand_includes(
122    content: &str,
123    base_dir: &Path,
124    visited: &mut HashSet<PathBuf>,
125    depth: usize,
126) -> String {
127    if depth >= MAX_INCLUDE_DEPTH {
128        return content.to_string();
129    }
130
131    let mut result = String::new();
132
133    for line in content.lines() {
134        let trimmed = line.trim();
135        if trimmed.starts_with("@include ") {
136            let include_path = trimmed.strip_prefix("@include ").unwrap().trim();
137
138            // Expand ~ to home directory
139            let expanded_path = if include_path.starts_with("~/") {
140                let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
141                home.join(&include_path[2..])
142            } else {
143                base_dir.join(include_path)
144            };
145
146            // Circular reference check
147            if visited.contains(&expanded_path) {
148                result.push_str(&format!("<!-- circular @include: {} -->\n", include_path));
149                continue;
150            }
151
152            // Size check
153            if let Ok(meta) = std::fs::metadata(&expanded_path) {
154                if meta.len() > MAX_INCLUDE_SIZE as u64 {
155                    result.push_str(&format!(
156                        "<!-- @include too large: {} ({} bytes, max {}) -->\n",
157                        include_path,
158                        meta.len(),
159                        MAX_INCLUDE_SIZE
160                    ));
161                    continue;
162                }
163            }
164
165            // Load and recurse
166            if let Ok(included_content) = std::fs::read_to_string(&expanded_path) {
167                visited.insert(expanded_path.clone());
168                let included_dir = expanded_path.parent().unwrap_or(base_dir);
169                let expanded = expand_includes(&included_content, included_dir, visited, depth + 1);
170                result.push_str(&expanded);
171                result.push('\n');
172            } else {
173                result.push_str(&format!("<!-- @include not found: {} -->\n", include_path));
174            }
175        } else {
176            result.push_str(line);
177            result.push('\n');
178        }
179    }
180
181    result
182}
183
184/// Build a merged memory prompt from all loaded files.
185pub fn build_memory_prompt(files: &[MemoryFileInfo]) -> String {
186    files
187        .iter()
188        .filter(|f| !f.content.trim().is_empty())
189        .map(|f| f.content.as_str())
190        .collect::<Vec<_>>()
191        .join("\n\n")
192}
193
194// ─── Tests ───────────────────────────────────────────────────────────────────
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_project_claude_md() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "# Project Rules\n\nUse Rust.").unwrap();

        let files = load_all_memory_files(tmp.path());
        let project = files
            .iter()
            .find(|f| f.scope == MemoryScope::Project)
            .expect("project CLAUDE.md should be loaded");
        assert!(project.content.contains("Use Rust"));
    }

    #[test]
    fn test_load_local_claude_md() {
        let tmp = tempfile::tempdir().unwrap();
        let claude_dir = tmp.path().join(".claude");
        std::fs::create_dir_all(&claude_dir).unwrap();
        std::fs::write(claude_dir.join("CLAUDE.md"), "Local overrides here.").unwrap();

        let files = load_all_memory_files(tmp.path());
        let local = files
            .iter()
            .find(|f| f.scope == MemoryScope::Local)
            .expect("local .claude/CLAUDE.md should be loaded");
        assert!(local.content.contains("Local overrides"));
    }

    #[test]
    fn test_scope_ordering() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "Project").unwrap();
        let claude_dir = tmp.path().join(".claude");
        std::fs::create_dir_all(&claude_dir).unwrap();
        std::fs::write(claude_dir.join("CLAUDE.md"), "Local").unwrap();

        let files = load_all_memory_files(tmp.path());
        // Both files exist, so both must be present; otherwise the ordering
        // assertion below would pass vacuously.
        let project_idx = files
            .iter()
            .position(|f| f.scope == MemoryScope::Project)
            .expect("project file should be loaded");
        let local_idx = files
            .iter()
            .position(|f| f.scope == MemoryScope::Local)
            .expect("local file should be loaded");
        // Project comes before Local in the list (lower scope number = higher priority)
        assert!(project_idx < local_idx);
    }

    #[test]
    fn test_include_expansion() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(
            tmp.path().join("main.md"),
            "Before\n@include extra.md\nAfter",
        )
        .unwrap();
        std::fs::write(tmp.path().join("extra.md"), "INCLUDED CONTENT").unwrap();

        let mut visited = HashSet::new();
        visited.insert(tmp.path().join("main.md"));
        let content = std::fs::read_to_string(tmp.path().join("main.md")).unwrap();
        let expanded = expand_includes(&content, tmp.path(), &mut visited, 0);

        assert!(expanded.contains("Before"));
        assert!(expanded.contains("INCLUDED CONTENT"));
        assert!(expanded.contains("After"));
    }

    #[test]
    fn test_circular_include() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("a.md"), "@include b.md").unwrap();
        std::fs::write(tmp.path().join("b.md"), "@include a.md").unwrap();

        let mut visited = HashSet::new();
        visited.insert(tmp.path().join("a.md"));
        let content = std::fs::read_to_string(tmp.path().join("a.md")).unwrap();
        let expanded = expand_includes(&content, tmp.path(), &mut visited, 0);

        assert!(expanded.contains("circular"));
    }

    #[test]
    fn test_self_include_is_circular() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("a.md"), "A\n@include a.md").unwrap();

        let mut visited = HashSet::new();
        visited.insert(tmp.path().join("a.md"));
        let content = std::fs::read_to_string(tmp.path().join("a.md")).unwrap();
        let expanded = expand_includes(&content, tmp.path(), &mut visited, 0);

        assert!(expanded.contains("circular @include: a.md"));
    }

    #[test]
    fn test_missing_include_reported() {
        let tmp = tempfile::tempdir().unwrap();
        let mut visited = HashSet::new();
        let expanded = expand_includes("@include missing.md", tmp.path(), &mut visited, 0);

        assert!(expanded.contains("<!-- @include not found: missing.md -->"));
    }

    #[test]
    fn test_oversized_include_skipped() {
        let tmp = tempfile::tempdir().unwrap();
        let big = "x".repeat(MAX_INCLUDE_SIZE + 1);
        std::fs::write(tmp.path().join("big.md"), &big).unwrap();

        let mut visited = HashSet::new();
        let expanded = expand_includes("@include big.md", tmp.path(), &mut visited, 0);

        assert!(expanded.contains("@include too large: big.md"));
        assert!(!expanded.contains(&big));
    }

    #[test]
    fn test_build_memory_prompt() {
        let files = vec![
            MemoryFileInfo {
                path: PathBuf::from("a.md"),
                scope: MemoryScope::Managed,
                content: "Rule 1".into(),
                mtime: 0,
            },
            MemoryFileInfo {
                path: PathBuf::from("b.md"),
                scope: MemoryScope::Project,
                content: "Rule 2".into(),
                mtime: 0,
            },
        ];
        let prompt = build_memory_prompt(&files);
        assert_eq!(prompt, "Rule 1\n\nRule 2");
    }

    #[test]
    fn test_empty_file_skipped() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "").unwrap();

        let files = load_all_memory_files(tmp.path());
        assert!(files.iter().all(|f| f.scope != MemoryScope::Project));
    }

    #[test]
    fn test_frontmatter_stripped() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(
            tmp.path().join("CLAUDE.md"),
            "---\nscope: project\n---\n\nActual content.",
        )
        .unwrap();

        let files = load_all_memory_files(tmp.path());
        let project = files
            .iter()
            .find(|f| f.scope == MemoryScope::Project)
            .expect("project CLAUDE.md should be loaded");
        assert!(project.content.contains("Actual content"));
        assert!(!project.content.contains("scope: project"));
    }
}