claude_agent/context/
memory_loader.rs

1//! CLAUDE.md and CLAUDE.local.md loader with CLI-compatible @import processing.
2//!
3//! This module provides a memory loader that reads CLAUDE.md and CLAUDE.local.md files
4//! with support for recursive @import directives. It implements the same import behavior
5//! as Claude Code CLI 2.1.12.
6
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Default maximum import depth for CLI-like behavior.
///
/// The CLI technically allows depth 5 but loads roughly 24K tokens of memory.
/// Depth 2 yields roughly 31K tokens, which is the closest match.
pub const DEFAULT_IMPORT_DEPTH: usize = 2;

/// Maximum import depth when full expansion is needed (the CLI's technical limit).
pub const MAX_IMPORT_DEPTH: usize = 5;
21
/// Configuration for MemoryLoader.
#[derive(Debug, Clone)]
pub struct MemoryLoaderConfig {
    /// Maximum import depth that will be followed; the root file counts as depth 0
    /// and any file deeper than this value is skipped.
    ///
    /// Defaults to DEFAULT_IMPORT_DEPTH (2) for CLI-like token counts.
    /// Use MAX_IMPORT_DEPTH (5) for full expansion.
    pub max_depth: usize,
}
29
30impl Default for MemoryLoaderConfig {
31    fn default() -> Self {
32        Self {
33            max_depth: DEFAULT_IMPORT_DEPTH,
34        }
35    }
36}
37
38impl MemoryLoaderConfig {
39    /// Creates config with full import expansion (depth 5).
40    pub fn full_expansion() -> Self {
41        Self {
42            max_depth: MAX_IMPORT_DEPTH,
43        }
44    }
45
46    /// Creates config with specified max depth.
47    pub fn with_max_depth(max_depth: usize) -> Self {
48        Self { max_depth }
49    }
50}
51
/// Memory loader with CLI-compatible @import processing.
///
/// # Features
/// - Loads CLAUDE.md and CLAUDE.local.md from project directories
/// - Supports recursive @import with depth limiting
/// - Circular import detection using canonical path tracking
/// - Scans .claude/rules/ directory for rule files
///
/// # CLI Compatibility
/// This implementation matches Claude Code CLI 2.1.12 behavior:
/// - Maximum import depth of 5 (configurable; the default is 2 for similar token counts)
/// - Same path validation rules
/// - Same circular import prevention
pub struct MemoryLoader {
    /// Extracts @import targets from file content.
    extractor: ImportExtractor,
    /// Loader settings (the maximum import depth).
    config: MemoryLoaderConfig,
}
69
70impl MemoryLoader {
71    /// Creates a new MemoryLoader with default configuration (depth 3).
72    pub fn new() -> Self {
73        Self::with_config(MemoryLoaderConfig::default())
74    }
75
76    /// Creates a new MemoryLoader with custom configuration.
77    pub fn with_config(config: MemoryLoaderConfig) -> Self {
78        Self {
79            extractor: ImportExtractor::new(),
80            config,
81        }
82    }
83
84    /// Creates a new MemoryLoader with full import expansion (depth 5).
85    pub fn full_expansion() -> Self {
86        Self::with_config(MemoryLoaderConfig::full_expansion())
87    }
88
89    /// Loads all memory content (CLAUDE.md + CLAUDE.local.md + rules) from a directory.
90    ///
91    /// # Arguments
92    /// * `start_dir` - The project root directory to load from
93    ///
94    /// # Returns
95    /// Combined MemoryContent with all loaded files and rules
96    pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
97        let mut content = self.load_shared(start_dir).await?;
98        let local = self.load_local(start_dir).await?;
99        content.merge(local);
100        Ok(content)
101    }
102
103    /// Loads shared CLAUDE.md and rules (visible to all team members).
104    pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
105        let mut content = MemoryContent::default();
106        let mut visited = HashSet::new();
107
108        for path in Self::find_claude_files(start_dir) {
109            match self
110                .load_with_imports(&path, start_dir, 0, &mut visited)
111                .await
112            {
113                Ok(text) => content.claude_md.push(text),
114                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
115            }
116        }
117
118        let rules_dir = start_dir.join(".claude").join("rules");
119        if rules_dir.exists() {
120            content.rule_indices = self.scan_rules(&rules_dir).await?;
121        }
122
123        Ok(content)
124    }
125
126    /// Loads local CLAUDE.local.md (private to the user, not in version control).
127    pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
128        let mut content = MemoryContent::default();
129        let mut visited = HashSet::new();
130
131        for path in Self::find_local_files(start_dir) {
132            match self
133                .load_with_imports(&path, start_dir, 0, &mut visited)
134                .await
135            {
136                Ok(text) => content.local_md.push(text),
137                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
138            }
139        }
140
141        Ok(content)
142    }
143
144    /// Loads a file with recursive @import expansion.
145    ///
146    /// # Arguments
147    /// * `path` - Path to the file to load
148    /// * `project_root` - Project root directory for resolving @.agents/... style imports
149    /// * `depth` - Current import depth (0 = root)
150    /// * `visited` - Set of canonical paths already loaded (for circular detection)
151    ///
152    /// # Returns
153    /// File content with all imports expanded inline
154    fn load_with_imports<'a>(
155        &'a self,
156        path: &'a Path,
157        project_root: &'a Path,
158        depth: usize,
159        visited: &'a mut HashSet<PathBuf>,
160    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
161    {
162        Box::pin(async move {
163            // Depth limit check (configurable, default 3)
164            if depth > self.config.max_depth {
165                tracing::warn!(
166                    "Import depth limit ({}) exceeded, skipping: {}",
167                    self.config.max_depth,
168                    path.display()
169                );
170                return Ok(String::new());
171            }
172
173            // Circular import detection using canonical paths
174            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
175            if visited.contains(&canonical) {
176                tracing::debug!("Circular import detected, skipping: {}", path.display());
177                return Ok(String::new());
178            }
179            visited.insert(canonical);
180
181            // Read file content
182            let content =
183                tokio::fs::read_to_string(path)
184                    .await
185                    .map_err(|e| ContextError::Source {
186                        message: format!("Failed to read {}: {}", path.display(), e),
187                    })?;
188
189            // Extract and process imports
190            // Use current file's directory for relative path resolution
191            let file_dir = path.parent().unwrap_or(Path::new("."));
192            let imports = self.extractor.extract(&content, file_dir);
193
194            // Post-process: fix duplicated .agents/ or .claude/ paths
195            // e.g., /project/.agents/guides/.agents/patterns -> .agents/patterns
196            let imports: Vec<PathBuf> = imports
197                .into_iter()
198                .map(|p| Self::normalize_project_relative_path(&p, project_root))
199                .collect();
200
201            // Build result with imported content appended
202            let mut result = content;
203            for import_path in imports {
204                if import_path.exists() {
205                    if let Ok(imported) = self
206                        .load_with_imports(&import_path, project_root, depth + 1, visited)
207                        .await
208                        && !imported.is_empty()
209                    {
210                        result.push_str("\n\n");
211                        result.push_str(&imported);
212                    }
213                } else {
214                    tracing::debug!("Import not found, skipping: {}", import_path.display());
215                }
216            }
217
218            Ok(result)
219        })
220    }
221
222    /// Finds CLAUDE.md files in standard locations.
223    fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
224        let mut files = Vec::new();
225
226        // Project root CLAUDE.md
227        let claude_md = start_dir.join("CLAUDE.md");
228        if claude_md.exists() {
229            files.push(claude_md);
230        }
231
232        // .claude/CLAUDE.md (alternative location)
233        let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
234        if claude_dir_md.exists() {
235            files.push(claude_dir_md);
236        }
237
238        files
239    }
240
241    /// Finds CLAUDE.local.md files in standard locations.
242    fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
243        let mut files = Vec::new();
244
245        // Project root CLAUDE.local.md
246        let local_md = start_dir.join("CLAUDE.local.md");
247        if local_md.exists() {
248            files.push(local_md);
249        }
250
251        // .claude/CLAUDE.local.md (alternative location)
252        let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
253        if local_dir_md.exists() {
254            files.push(local_dir_md);
255        }
256
257        files
258    }
259
260    /// Normalizes paths with duplicated .agents/ or .claude/ segments.
261    ///
262    /// When imports are resolved from nested files (e.g., .agents/guides/workflow.md),
263    /// relative paths like @.agents/patterns/... get incorrectly expanded to
264    /// /project/.agents/guides/.agents/patterns/... (duplicated .agents/).
265    ///
266    /// This function detects such duplications and re-resolves from project root.
267    fn normalize_project_relative_path(path: &Path, project_root: &Path) -> PathBuf {
268        const MARKERS: [&str; 2] = ["/.agents/", "/.claude/"];
269
270        let path_str = path.to_string_lossy();
271
272        // Find the last occurrence of any project-relative marker
273        let last_marker_pos = MARKERS
274            .iter()
275            .filter_map(|marker| {
276                let count = path_str.matches(marker).count();
277                // Only fix if duplicated (count > 1) or mixed markers exist
278                if count > 1 || MARKERS.iter().filter(|m| path_str.contains(*m)).count() > 1 {
279                    path_str.rfind(marker).map(|pos| (pos, *marker))
280                } else {
281                    None
282                }
283            })
284            .max_by_key(|(pos, _)| *pos);
285
286        if let Some((idx, _)) = last_marker_pos {
287            let relative_part = &path_str[idx + 1..]; // Skip leading "/"
288            project_root.join(relative_part)
289        } else {
290            path.to_path_buf()
291        }
292    }
293
294    /// Scans .claude/rules/ directory recursively for rule files.
295    async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
296        let mut indices = Vec::new();
297        self.scan_rules_recursive(dir, &mut indices).await?;
298        indices.sort_by(|a, b| b.priority.cmp(&a.priority));
299        Ok(indices)
300    }
301
302    fn scan_rules_recursive<'a>(
303        &'a self,
304        dir: &'a Path,
305        indices: &'a mut Vec<RuleIndex>,
306    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
307        Box::pin(async move {
308            let mut entries = tokio::fs::read_dir(dir)
309                .await
310                .map_err(|e| ContextError::Source {
311                    message: format!("Failed to read rules directory: {}", e),
312                })?;
313
314            while let Some(entry) =
315                entries
316                    .next_entry()
317                    .await
318                    .map_err(|e| ContextError::Source {
319                        message: format!("Failed to read directory entry: {}", e),
320                    })?
321            {
322                let path = entry.path();
323
324                if path.is_dir() {
325                    self.scan_rules_recursive(&path, indices).await?;
326                } else if path.extension().is_some_and(|e| e == "md")
327                    && let Some(index) = RuleIndex::from_file(&path)
328                {
329                    indices.push(index);
330                }
331            }
332
333            Ok(())
334        })
335    }
336}
337
338impl Default for MemoryLoader {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
/// Loaded memory content from CLAUDE.md files and rules.
#[derive(Debug, Default, Clone)]
pub struct MemoryContent {
    /// Content from CLAUDE.md files (shared/team config), one entry per loaded
    /// top-level file with its imports already appended.
    pub claude_md: Vec<String>,
    /// Content from CLAUDE.local.md files (user-specific config), one entry per
    /// loaded top-level file with its imports already appended.
    pub local_md: Vec<String>,
    /// Rule indices from .claude/rules/ directory.
    pub rule_indices: Vec<RuleIndex>,
}
354
355impl MemoryContent {
356    /// Combines all CLAUDE.md and CLAUDE.local.md content into a single string.
357    pub fn combined_claude_md(&self) -> String {
358        self.claude_md
359            .iter()
360            .chain(self.local_md.iter())
361            .filter(|c| !c.trim().is_empty())
362            .cloned()
363            .collect::<Vec<_>>()
364            .join("\n\n")
365    }
366
367    /// Returns true if no content was loaded.
368    pub fn is_empty(&self) -> bool {
369        self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
370    }
371
372    /// Merges another MemoryContent into this one.
373    pub fn merge(&mut self, other: MemoryContent) {
374        self.claude_md.extend(other.claude_md);
375        self.local_md.extend(other.local_md);
376        self.rule_indices.extend(other.rule_indices);
377    }
378}
379
380#[cfg(test)]
381mod tests {
382    use super::*;
383    use tempfile::tempdir;
384    use tokio::fs;
385
386    #[tokio::test]
387    async fn test_load_claude_md() {
388        let dir = tempdir().unwrap();
389        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
390            .await
391            .unwrap();
392
393        let loader = MemoryLoader::new();
394        let content = loader.load(dir.path()).await.unwrap();
395
396        assert_eq!(content.claude_md.len(), 1);
397        assert!(content.claude_md[0].contains("Test content"));
398    }
399
400    #[tokio::test]
401    async fn test_load_local_md() {
402        let dir = tempdir().unwrap();
403        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
404            .await
405            .unwrap();
406
407        let loader = MemoryLoader::new();
408        let content = loader.load(dir.path()).await.unwrap();
409
410        assert_eq!(content.local_md.len(), 1);
411        assert!(content.local_md[0].contains("Private"));
412    }
413
414    #[tokio::test]
415    async fn test_scan_rules_recursive() {
416        let dir = tempdir().unwrap();
417        let rules_dir = dir.path().join(".claude").join("rules");
418        let sub_dir = rules_dir.join("frontend");
419        fs::create_dir_all(&sub_dir).await.unwrap();
420
421        fs::write(
422            rules_dir.join("rust.md"),
423            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
424        )
425        .await
426        .unwrap();
427
428        fs::write(
429            sub_dir.join("react.md"),
430            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
431        )
432        .await
433        .unwrap();
434
435        let loader = MemoryLoader::new();
436        let content = loader.load(dir.path()).await.unwrap();
437
438        assert_eq!(content.rule_indices.len(), 2);
439        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
440        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
441    }
442
443    #[tokio::test]
444    async fn test_import_syntax() {
445        let dir = tempdir().unwrap();
446
447        fs::write(
448            dir.path().join("CLAUDE.md"),
449            "# Main\n@docs/guidelines.md\nEnd",
450        )
451        .await
452        .unwrap();
453
454        let docs_dir = dir.path().join("docs");
455        fs::create_dir_all(&docs_dir).await.unwrap();
456        fs::write(docs_dir.join("guidelines.md"), "Imported content")
457            .await
458            .unwrap();
459
460        let loader = MemoryLoader::new();
461        let content = loader.load(dir.path()).await.unwrap();
462
463        assert!(content.combined_claude_md().contains("Imported content"));
464    }
465
466    #[tokio::test]
467    async fn test_combined_includes_local() {
468        let dir = tempdir().unwrap();
469        fs::write(dir.path().join("CLAUDE.md"), "Main content")
470            .await
471            .unwrap();
472        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
473            .await
474            .unwrap();
475
476        let loader = MemoryLoader::new();
477        let content = loader.load(dir.path()).await.unwrap();
478
479        let combined = content.combined_claude_md();
480        assert!(combined.contains("Main content"));
481        assert!(combined.contains("Local content"));
482    }
483
484    #[tokio::test]
485    async fn test_recursive_import() {
486        let dir = tempdir().unwrap();
487
488        // CLAUDE.md → docs/guide.md → docs/detail.md
489        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
490            .await
491            .unwrap();
492
493        let docs_dir = dir.path().join("docs");
494        fs::create_dir_all(&docs_dir).await.unwrap();
495        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
496            .await
497            .unwrap();
498        fs::write(docs_dir.join("detail.md"), "Detail content")
499            .await
500            .unwrap();
501
502        let loader = MemoryLoader::new();
503        let content = loader.load(dir.path()).await.unwrap();
504        let combined = content.combined_claude_md();
505
506        assert!(combined.contains("Root content"));
507        assert!(combined.contains("Guide content"));
508        assert!(combined.contains("Detail content"));
509    }
510
511    #[tokio::test]
512    async fn test_depth_limit_default() {
513        let dir = tempdir().unwrap();
514
515        // Create chain: CLAUDE.md → level1.md → level2.md → level3.md
516        // With default depth 2, should stop at level 2 (0,1,2 loaded, 3 not)
517        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
518            .await
519            .unwrap();
520
521        for i in 1..=3 {
522            let content = if i < 3 {
523                format!("Level {} @level{}.md", i, i + 1)
524            } else {
525                format!("Level {}", i)
526            };
527            fs::write(dir.path().join(format!("level{}.md", i)), content)
528                .await
529                .unwrap();
530        }
531
532        let loader = MemoryLoader::new(); // Default depth 2
533        let content = loader.load(dir.path()).await.unwrap();
534        let combined = content.combined_claude_md();
535
536        // Should have levels 0-2 but NOT level 3 (default depth limit = 2)
537        assert!(combined.contains("Level 0"));
538        assert!(combined.contains("Level 2"));
539        assert!(!combined.contains("Level 3"));
540    }
541
542    #[tokio::test]
543    async fn test_depth_limit_full_expansion() {
544        let dir = tempdir().unwrap();
545
546        // Create chain: CLAUDE.md → level1.md → level2.md → ... → level6.md
547        // With full expansion (depth 5), should stop at level 5 (0-5 loaded, 6 not)
548        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
549            .await
550            .unwrap();
551
552        for i in 1..=6 {
553            let content = if i < 6 {
554                format!("Level {} @level{}.md", i, i + 1)
555            } else {
556                format!("Level {}", i)
557            };
558            fs::write(dir.path().join(format!("level{}.md", i)), content)
559                .await
560                .unwrap();
561        }
562
563        let loader = MemoryLoader::full_expansion(); // Depth 5
564        let content = loader.load(dir.path()).await.unwrap();
565        let combined = content.combined_claude_md();
566
567        // Should have levels 0-5 but NOT level 6 (max depth limit = 5)
568        assert!(combined.contains("Level 0"));
569        assert!(combined.contains("Level 5"));
570        assert!(!combined.contains("Level 6"));
571    }
572
573    #[tokio::test]
574    async fn test_depth_limit_custom() {
575        let dir = tempdir().unwrap();
576
577        // Create chain up to level 3
578        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
579            .await
580            .unwrap();
581
582        for i in 1..=3 {
583            let content = if i < 3 {
584                format!("Level {} @level{}.md", i, i + 1)
585            } else {
586                format!("Level {}", i)
587            };
588            fs::write(dir.path().join(format!("level{}.md", i)), content)
589                .await
590                .unwrap();
591        }
592
593        // Custom depth 1: should only load levels 0-1
594        let loader = MemoryLoader::with_config(MemoryLoaderConfig::with_max_depth(1));
595        let content = loader.load(dir.path()).await.unwrap();
596        let combined = content.combined_claude_md();
597
598        assert!(combined.contains("Level 0"));
599        assert!(combined.contains("Level 1"));
600        assert!(!combined.contains("Level 2"));
601    }
602
603    #[tokio::test]
604    async fn test_circular_import() {
605        let dir = tempdir().unwrap();
606
607        // CLAUDE.md → a.md → b.md → a.md (circular)
608        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
609            .await
610            .unwrap();
611        fs::write(dir.path().join("a.md"), "A content @b.md")
612            .await
613            .unwrap();
614        fs::write(dir.path().join("b.md"), "B content @a.md")
615            .await
616            .unwrap();
617
618        let loader = MemoryLoader::new();
619        let result = loader.load(dir.path()).await;
620
621        // Should not infinite loop and should succeed
622        assert!(result.is_ok());
623        let combined = result.unwrap().combined_claude_md();
624        assert!(combined.contains("A content"));
625        assert!(combined.contains("B content"));
626    }
627
628    #[tokio::test]
629    async fn test_import_in_code_block_ignored() {
630        let dir = tempdir().unwrap();
631
632        fs::write(
633            dir.path().join("CLAUDE.md"),
634            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
635        )
636        .await
637        .unwrap();
638
639        fs::write(
640            dir.path().join("should").join("import.md"),
641            "This is imported",
642        )
643        .await
644        .ok();
645        let should_dir = dir.path().join("should");
646        fs::create_dir_all(&should_dir).await.unwrap();
647        fs::write(should_dir.join("import.md"), "Imported content")
648            .await
649            .unwrap();
650
651        let loader = MemoryLoader::new();
652        let content = loader.load(dir.path()).await.unwrap();
653        let combined = content.combined_claude_md();
654
655        assert!(combined.contains("Imported content"));
656        // The @should/not/import.md in code block should remain as-is, not be processed
657        assert!(combined.contains("@should/not/import.md"));
658    }
659
660    #[tokio::test]
661    async fn test_missing_import_ignored() {
662        let dir = tempdir().unwrap();
663
664        fs::write(
665            dir.path().join("CLAUDE.md"),
666            "# Main\n@nonexistent/file.md\nRest of content",
667        )
668        .await
669        .unwrap();
670
671        let loader = MemoryLoader::new();
672        let content = loader.load(dir.path()).await.unwrap();
673        let combined = content.combined_claude_md();
674
675        // Should still load the main content even if import doesn't exist
676        assert!(combined.contains("# Main"));
677        assert!(combined.contains("Rest of content"));
678    }
679
680    #[tokio::test]
681    async fn test_empty_content() {
682        let dir = tempdir().unwrap();
683
684        let loader = MemoryLoader::new();
685        let content = loader.load(dir.path()).await.unwrap();
686
687        assert!(content.is_empty());
688        assert!(content.combined_claude_md().is_empty());
689    }
690
691    #[tokio::test]
692    async fn test_memory_content_merge() {
693        let mut content1 = MemoryContent {
694            claude_md: vec!["content1".to_string()],
695            local_md: vec!["local1".to_string()],
696            rule_indices: vec![],
697        };
698
699        let content2 = MemoryContent {
700            claude_md: vec!["content2".to_string()],
701            local_md: vec!["local2".to_string()],
702            rule_indices: vec![],
703        };
704
705        content1.merge(content2);
706
707        assert_eq!(content1.claude_md.len(), 2);
708        assert_eq!(content1.local_md.len(), 2);
709    }
710}