claude_agent/context/
memory_loader.rs

1//! CLAUDE.md Memory Loader
2//!
3//! Implements Claude Code CLI compatible memory loading:
4//! - Recursive loading from current directory to root
5//! - CLAUDE.local.md support
6//! - @import syntax (max 5 hops)
7//! - .claude/rules/ index-only loading for progressive disclosure
8
9use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11
12use super::provider::MAX_IMPORT_DEPTH;
13use super::rule_index::RuleIndex;
14use super::{ContextError, ContextResult};
15
/// Loads `CLAUDE.md`-style memory files and expands their `@path` imports.
///
/// The loader remembers every file it has read for as long as it lives, so a given
/// file is expanded at most once per loader. Reuse one loader only when that
/// cross-call de-duplication is wanted.
#[derive(Default, Debug)]
pub struct MemoryLoader {
    /// Paths of every file already read (canonicalized when possible); guards against
    /// import cycles and repeated inclusion.
    loaded_paths: HashSet<PathBuf>,
    /// Import nesting level of the file currently being expanded.
    current_depth: usize,
}
21
22impl MemoryLoader {
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub async fn load_all(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
28        let mut content = MemoryContent::default();
29
30        let claude_files = self.find_claude_files(start_dir);
31        for path in claude_files {
32            if let Ok(text) = self.load_file_with_imports(&path).await {
33                content.claude_md.push(text);
34            }
35        }
36
37        let local_files = self.find_local_files(start_dir);
38        for path in local_files {
39            if let Ok(text) = self.load_file_with_imports(&path).await {
40                content.local_md.push(text);
41            }
42        }
43
44        let rules_dir = start_dir.join(".claude").join("rules");
45        if rules_dir.exists() {
46            content.rule_indices = self.scan_rules_directory(&rules_dir).await?;
47        }
48
49        Ok(content)
50    }
51
52    pub async fn load_local_only(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
53        let mut content = MemoryContent::default();
54
55        let local_files = self.find_local_files(start_dir);
56        for path in local_files {
57            if let Ok(text) = self.load_file_with_imports(&path).await {
58                content.local_md.push(text);
59            }
60        }
61
62        Ok(content)
63    }
64
65    fn find_claude_files(&self, start_dir: &Path) -> Vec<PathBuf> {
66        let mut files = Vec::new();
67        let mut current = start_dir.to_path_buf();
68
69        loop {
70            let claude_md = current.join("CLAUDE.md");
71            if claude_md.exists() {
72                files.push(claude_md);
73            }
74
75            let claude_dir_md = current.join(".claude").join("CLAUDE.md");
76            if claude_dir_md.exists() {
77                files.push(claude_dir_md);
78            }
79
80            match current.parent() {
81                Some(parent) if parent != current && !parent.as_os_str().is_empty() => {
82                    current = parent.to_path_buf();
83                }
84                _ => break,
85            }
86        }
87
88        files.reverse();
89        files
90    }
91
92    fn find_local_files(&self, start_dir: &Path) -> Vec<PathBuf> {
93        let mut files = Vec::new();
94        let mut current = start_dir.to_path_buf();
95
96        loop {
97            let local_md = current.join("CLAUDE.local.md");
98            if local_md.exists() {
99                files.push(local_md);
100            }
101
102            let local_dir_md = current.join(".claude").join("CLAUDE.local.md");
103            if local_dir_md.exists() {
104                files.push(local_dir_md);
105            }
106
107            match current.parent() {
108                Some(parent) if parent != current && !parent.as_os_str().is_empty() => {
109                    current = parent.to_path_buf();
110                }
111                _ => break,
112            }
113        }
114
115        files.reverse();
116        files
117    }
118
119    async fn scan_rules_directory(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
120        let mut indices = Vec::new();
121
122        let mut entries = tokio::fs::read_dir(dir)
123            .await
124            .map_err(|e| ContextError::Source {
125                message: format!("Failed to read rules directory: {}", e),
126            })?;
127
128        while let Some(entry) = entries
129            .next_entry()
130            .await
131            .map_err(|e| ContextError::Source {
132                message: format!("Failed to read directory entry: {}", e),
133            })?
134        {
135            let path = entry.path();
136            if path.extension().is_some_and(|e| e == "md")
137                && let Some(index) = RuleIndex::from_file(&path)
138            {
139                indices.push(index);
140            }
141        }
142
143        indices.sort_by(|a, b| b.priority.cmp(&a.priority));
144        Ok(indices)
145    }
146
147    fn load_file_with_imports<'a>(
148        &'a mut self,
149        path: &'a Path,
150    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
151    {
152        Box::pin(async move {
153            if self.current_depth >= MAX_IMPORT_DEPTH {
154                tracing::warn!(
155                    "Import depth limit ({}) reached, skipping: {}",
156                    MAX_IMPORT_DEPTH,
157                    path.display()
158                );
159                return Ok(String::new());
160            }
161
162            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
163            if self.loaded_paths.contains(&canonical) {
164                return Ok(String::new());
165            }
166            self.loaded_paths.insert(canonical.clone());
167
168            let content =
169                tokio::fs::read_to_string(path)
170                    .await
171                    .map_err(|e| ContextError::Source {
172                        message: format!("Failed to read {}: {}", path.display(), e),
173                    })?;
174
175            self.current_depth += 1;
176            let result = self
177                .process_imports(&content, path.parent().unwrap_or(Path::new(".")))
178                .await;
179            self.current_depth -= 1;
180
181            result
182        })
183    }
184
185    fn expand_home(path: &str) -> PathBuf {
186        if let Some(rest) = path.strip_prefix("~/")
187            && let Some(home) = crate::common::home_dir()
188        {
189            return home.join(rest);
190        }
191        PathBuf::from(path)
192    }
193
194    fn process_imports<'a>(
195        &'a mut self,
196        content: &'a str,
197        base_dir: &'a Path,
198    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
199    {
200        Box::pin(async move {
201            let mut result = String::new();
202
203            for line in content.lines() {
204                let trimmed = line.trim();
205
206                if trimmed.starts_with('@') && !trimmed.starts_with("@@") {
207                    let import_path = trimmed.trim_start_matches('@').trim();
208                    if !import_path.is_empty() {
209                        let full_path = if import_path.starts_with("~/") {
210                            Self::expand_home(import_path)
211                        } else if import_path.starts_with('/') {
212                            PathBuf::from(import_path)
213                        } else {
214                            base_dir.join(import_path)
215                        };
216
217                        if full_path.exists() {
218                            match self.load_file_with_imports(&full_path).await {
219                                Ok(imported) => {
220                                    result.push_str(&imported);
221                                    result.push('\n');
222                                }
223                                Err(e) => {
224                                    tracing::warn!("Failed to import {}: {}", import_path, e);
225                                    result.push_str(line);
226                                    result.push('\n');
227                                }
228                            }
229                        } else {
230                            result.push_str(line);
231                            result.push('\n');
232                        }
233                    } else {
234                        result.push_str(line);
235                        result.push('\n');
236                    }
237                } else {
238                    result.push_str(line);
239                    result.push('\n');
240                }
241            }
242
243            Ok(result)
244        })
245    }
246}
247
248#[derive(Debug, Default)]
249pub struct MemoryContent {
250    pub claude_md: Vec<String>,
251    pub local_md: Vec<String>,
252    pub rule_indices: Vec<RuleIndex>,
253}
254
255impl MemoryContent {
256    pub fn combined_claude_md(&self) -> String {
257        let mut parts = Vec::new();
258
259        for content in &self.claude_md {
260            if !content.trim().is_empty() {
261                parts.push(content.clone());
262            }
263        }
264
265        for content in &self.local_md {
266            if !content.trim().is_empty() {
267                parts.push(content.clone());
268            }
269        }
270
271        parts.join("\n\n")
272    }
273
274    pub fn is_empty(&self) -> bool {
275        self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    #[tokio::test]
    async fn test_load_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();

        assert_eq!(content.claude_md.len(), 1);
        assert!(content.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local settings")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();

        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Local settings"));
    }

    #[tokio::test]
    async fn test_scan_rules_indices_only() {
        let dir = tempdir().unwrap();
        let rules_dir = dir.path().join(".claude").join("rules");
        fs::create_dir_all(&rules_dir).await.unwrap();

        fs::write(
            rules_dir.join("rust.md"),
            r#"---
paths: **/*.rs
priority: 10
---

# Rust Rules
Use snake_case"#,
        )
        .await
        .unwrap();

        fs::write(rules_dir.join("security.md"), "# Security\nNo secrets")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();

        assert_eq!(content.rule_indices.len(), 2);

        let rust_rule = content.rule_indices.iter().find(|r| r.name == "rust");
        assert!(rust_rule.is_some());
        assert_eq!(rust_rule.unwrap().priority, 10);
        assert!(rust_rule.unwrap().paths.is_some());
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@docs/guidelines.md\nEnd",
        )
        .await
        .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();

        assert!(content.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_import_cycle_terminates() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main marker\n@other.md")
            .await
            .unwrap();
        fs::write(dir.path().join("other.md"), "Other marker\n@CLAUDE.md")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Main marker"));
        assert!(combined.contains("Other marker"));
        // The back-reference to CLAUDE.md must not re-expand the file.
        assert_eq!(combined.matches("Main marker").count(), 1);
    }

    #[tokio::test]
    async fn test_unresolved_and_escaped_imports_kept_verbatim() {
        let dir = tempdir().unwrap();
        fs::write(
            dir.path().join("CLAUDE.md"),
            "@missing/file.md\n@@not-an-import",
        )
        .await
        .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("@missing/file.md"));
        assert!(combined.contains("@@not-an-import"));
    }

    #[tokio::test]
    async fn test_load_local_only_skips_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_local_only(dir.path()).await.unwrap();

        assert!(content.claude_md.is_empty());
        assert!(content.rule_indices.is_empty());
        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Local content"));
    }

    #[tokio::test]
    async fn test_combined_content() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let mut loader = MemoryLoader::new();
        let content = loader.load_all(dir.path()).await.unwrap();

        let combined = content.combined_claude_md();
        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }
}
387}