// claude_agent/context/memory_loader.rs

//! CLAUDE.md and CLAUDE.local.md loader with CLI-compatible `@import` processing.
//!
//! This module provides a memory loader that reads CLAUDE.md and CLAUDE.local.md files
//! with support for recursive `@import` directives. It implements the same import behavior
//! as Claude Code CLI 2.1.12.

7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Deepest `@import` nesting that will be expanded (CLI: ZH5 = 5).
///
/// The root memory file sits at depth 0, so a chain of up to five nested imports is
/// followed; anything deeper is skipped to guard against runaway recursion.
pub const MAX_IMPORT_DEPTH: usize = 5;
16
17/// Memory loader with CLI-compatible @import processing.
18///
19/// # Features
20/// - Loads CLAUDE.md and CLAUDE.local.md from project directories
21/// - Supports recursive @import with depth limiting
22/// - Circular import detection using canonical path tracking
23/// - Scans .claude/rules/ directory for rule files
24///
25/// # CLI Compatibility
26/// This implementation matches Claude Code CLI 2.1.12 behavior:
27/// - Maximum import depth of 5
28/// - Same path validation rules
29/// - Same circular import prevention
30pub struct MemoryLoader {
31    extractor: ImportExtractor,
32}
33
34impl MemoryLoader {
35    /// Creates a new MemoryLoader with CLI-compatible import extraction.
36    pub fn new() -> Self {
37        Self {
38            extractor: ImportExtractor::new(),
39        }
40    }
41
42    /// Loads all memory content (CLAUDE.md + CLAUDE.local.md + rules) from a directory.
43    ///
44    /// # Arguments
45    /// * `start_dir` - The project root directory to load from
46    ///
47    /// # Returns
48    /// Combined MemoryContent with all loaded files and rules
49    pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
50        let mut content = self.load_shared(start_dir).await?;
51        let local = self.load_local(start_dir).await?;
52        content.merge(local);
53        Ok(content)
54    }
55
56    /// Loads shared CLAUDE.md and rules (visible to all team members).
57    pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
58        let mut content = MemoryContent::default();
59        let mut visited = HashSet::new();
60
61        for path in Self::find_claude_files(start_dir) {
62            match self.load_with_imports(&path, 0, &mut visited).await {
63                Ok(text) => content.claude_md.push(text),
64                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
65            }
66        }
67
68        let rules_dir = start_dir.join(".claude").join("rules");
69        if rules_dir.exists() {
70            content.rule_indices = self.scan_rules(&rules_dir).await?;
71        }
72
73        Ok(content)
74    }
75
76    /// Loads local CLAUDE.local.md (private to the user, not in version control).
77    pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
78        let mut content = MemoryContent::default();
79        let mut visited = HashSet::new();
80
81        for path in Self::find_local_files(start_dir) {
82            match self.load_with_imports(&path, 0, &mut visited).await {
83                Ok(text) => content.local_md.push(text),
84                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
85            }
86        }
87
88        Ok(content)
89    }
90
91    /// Loads a file with recursive @import expansion.
92    ///
93    /// # Arguments
94    /// * `path` - Path to the file to load
95    /// * `depth` - Current import depth (0 = root)
96    /// * `visited` - Set of canonical paths already loaded (for circular detection)
97    ///
98    /// # Returns
99    /// File content with all imports expanded inline
100    fn load_with_imports<'a>(
101        &'a self,
102        path: &'a Path,
103        depth: usize,
104        visited: &'a mut HashSet<PathBuf>,
105    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
106    {
107        Box::pin(async move {
108            // Depth limit check (CLI: ZH5 = 5)
109            if depth > MAX_IMPORT_DEPTH {
110                tracing::warn!(
111                    "Import depth limit ({}) exceeded, skipping: {}",
112                    MAX_IMPORT_DEPTH,
113                    path.display()
114                );
115                return Ok(String::new());
116            }
117
118            // Circular import detection using canonical paths
119            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
120            if visited.contains(&canonical) {
121                tracing::debug!("Circular import detected, skipping: {}", path.display());
122                return Ok(String::new());
123            }
124            visited.insert(canonical);
125
126            // Read file content
127            let content =
128                tokio::fs::read_to_string(path)
129                    .await
130                    .map_err(|e| ContextError::Source {
131                        message: format!("Failed to read {}: {}", path.display(), e),
132                    })?;
133
134            // Extract and process imports
135            let base_dir = path.parent().unwrap_or(Path::new("."));
136            let imports = self.extractor.extract(&content, base_dir);
137
138            // Build result with imported content appended
139            let mut result = content;
140            for import_path in imports {
141                if import_path.exists() {
142                    if let Ok(imported) = self
143                        .load_with_imports(&import_path, depth + 1, visited)
144                        .await
145                        && !imported.is_empty()
146                    {
147                        result.push_str("\n\n");
148                        result.push_str(&imported);
149                    }
150                } else {
151                    tracing::debug!("Import not found, skipping: {}", import_path.display());
152                }
153            }
154
155            Ok(result)
156        })
157    }
158
159    /// Finds CLAUDE.md files in standard locations.
160    fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
161        let mut files = Vec::new();
162
163        // Project root CLAUDE.md
164        let claude_md = start_dir.join("CLAUDE.md");
165        if claude_md.exists() {
166            files.push(claude_md);
167        }
168
169        // .claude/CLAUDE.md (alternative location)
170        let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
171        if claude_dir_md.exists() {
172            files.push(claude_dir_md);
173        }
174
175        files
176    }
177
178    /// Finds CLAUDE.local.md files in standard locations.
179    fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
180        let mut files = Vec::new();
181
182        // Project root CLAUDE.local.md
183        let local_md = start_dir.join("CLAUDE.local.md");
184        if local_md.exists() {
185            files.push(local_md);
186        }
187
188        // .claude/CLAUDE.local.md (alternative location)
189        let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
190        if local_dir_md.exists() {
191            files.push(local_dir_md);
192        }
193
194        files
195    }
196
197    /// Scans .claude/rules/ directory recursively for rule files.
198    async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
199        let mut indices = Vec::new();
200        self.scan_rules_recursive(dir, &mut indices).await?;
201        indices.sort_by(|a, b| b.priority.cmp(&a.priority));
202        Ok(indices)
203    }
204
205    fn scan_rules_recursive<'a>(
206        &'a self,
207        dir: &'a Path,
208        indices: &'a mut Vec<RuleIndex>,
209    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
210        Box::pin(async move {
211            let mut entries = tokio::fs::read_dir(dir)
212                .await
213                .map_err(|e| ContextError::Source {
214                    message: format!("Failed to read rules directory: {}", e),
215                })?;
216
217            while let Some(entry) =
218                entries
219                    .next_entry()
220                    .await
221                    .map_err(|e| ContextError::Source {
222                        message: format!("Failed to read directory entry: {}", e),
223                    })?
224            {
225                let path = entry.path();
226
227                if path.is_dir() {
228                    self.scan_rules_recursive(&path, indices).await?;
229                } else if path.extension().is_some_and(|e| e == "md")
230                    && let Some(index) = RuleIndex::from_file(&path)
231                {
232                    indices.push(index);
233                }
234            }
235
236            Ok(())
237        })
238    }
239}
240
241impl Default for MemoryLoader {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
247/// Loaded memory content from CLAUDE.md files and rules.
248#[derive(Debug, Default, Clone)]
249pub struct MemoryContent {
250    /// Content from CLAUDE.md files (shared/team config).
251    pub claude_md: Vec<String>,
252    /// Content from CLAUDE.local.md files (user-specific config).
253    pub local_md: Vec<String>,
254    /// Rule indices from .claude/rules/ directory.
255    pub rule_indices: Vec<RuleIndex>,
256}
257
258impl MemoryContent {
259    /// Combines all CLAUDE.md and CLAUDE.local.md content into a single string.
260    pub fn combined_claude_md(&self) -> String {
261        self.claude_md
262            .iter()
263            .chain(self.local_md.iter())
264            .filter(|c| !c.trim().is_empty())
265            .cloned()
266            .collect::<Vec<_>>()
267            .join("\n\n")
268    }
269
270    /// Returns true if no content was loaded.
271    pub fn is_empty(&self) -> bool {
272        self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
273    }
274
275    /// Merges another MemoryContent into this one.
276    pub fn merge(&mut self, other: MemoryContent) {
277        self.claude_md.extend(other.claude_md);
278        self.local_md.extend(other.local_md);
279        self.rule_indices.extend(other.rule_indices);
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    #[tokio::test]
    async fn test_load_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.claude_md.len(), 1);
        assert!(content.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let dir = tempdir().unwrap();
        let rules_dir = dir.path().join(".claude").join("rules");
        let sub_dir = rules_dir.join("frontend");
        fs::create_dir_all(&sub_dir).await.unwrap();

        fs::write(
            rules_dir.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await
        .unwrap();

        fs::write(
            sub_dir.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.rule_indices.len(), 2);
        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
        // Rules are ordered by descending priority.
        assert_eq!(content.rule_indices[0].name, "rust");
        assert_eq!(content.rule_indices[1].name, "react");
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@docs/guidelines.md\nEnd",
        )
        .await
        .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        let combined = content.combined_claude_md();
        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }

    #[tokio::test]
    async fn test_recursive_import() {
        let dir = tempdir().unwrap();

        // CLAUDE.md → docs/guide.md → docs/detail.md
        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
            .await
            .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
            .await
            .unwrap();
        fs::write(docs_dir.join("detail.md"), "Detail content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Root content"));
        assert!(combined.contains("Guide content"));
        assert!(combined.contains("Detail content"));
    }

    #[tokio::test]
    async fn test_depth_limit() {
        let dir = tempdir().unwrap();

        // Create chain: CLAUDE.md → level1.md → level2.md → ... → level6.md
        // Should stop at level 5 (depth = 5 means 6 files deep: root + 5 imports)
        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();

        for i in 1..=6 {
            let content = if i < 6 {
                format!("Level {} @level{}.md", i, i + 1)
            } else {
                format!("Level {}", i)
            };
            fs::write(dir.path().join(format!("level{}.md", i)), content)
                .await
                .unwrap();
        }

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Should have levels 0-5 but NOT level 6 (depth limit)
        assert!(combined.contains("Level 0"));
        assert!(combined.contains("Level 5"));
        assert!(!combined.contains("Level 6"));
    }

    #[tokio::test]
    async fn test_circular_import() {
        let dir = tempdir().unwrap();

        // CLAUDE.md → a.md → b.md → a.md (circular)
        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
            .await
            .unwrap();
        fs::write(dir.path().join("a.md"), "A content @b.md")
            .await
            .unwrap();
        fs::write(dir.path().join("b.md"), "B content @a.md")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let result = loader.load(dir.path()).await;

        // Should not infinite loop and should succeed
        assert!(result.is_ok());
        let combined = result.unwrap().combined_claude_md();
        assert!(combined.contains("A content"));
        assert!(combined.contains("B content"));
        // The cycle back to a.md must not re-include its content.
        assert_eq!(combined.matches("A content").count(), 1);
    }

    #[tokio::test]
    async fn test_import_in_code_block_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
        )
        .await
        .unwrap();

        let should_dir = dir.path().join("should");
        fs::create_dir_all(&should_dir).await.unwrap();
        fs::write(should_dir.join("import.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Imported content"));
        // The @should/not/import.md in code block should remain as-is, not be processed
        assert!(combined.contains("@should/not/import.md"));
    }

    #[tokio::test]
    async fn test_missing_import_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@nonexistent/file.md\nRest of content",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Should still load the main content even if import doesn't exist
        assert!(combined.contains("# Main"));
        assert!(combined.contains("Rest of content"));
    }

    #[tokio::test]
    async fn test_empty_content() {
        let dir = tempdir().unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.is_empty());
        assert!(content.combined_claude_md().is_empty());
    }

    #[tokio::test]
    async fn test_memory_content_merge() {
        let mut content1 = MemoryContent {
            claude_md: vec!["content1".to_string()],
            local_md: vec!["local1".to_string()],
            rule_indices: vec![],
        };

        let content2 = MemoryContent {
            claude_md: vec!["content2".to_string()],
            local_md: vec!["local2".to_string()],
            rule_indices: vec![],
        };

        content1.merge(content2);

        assert_eq!(content1.claude_md.len(), 2);
        assert_eq!(content1.local_md.len(), 2);
    }
}