Skip to main content

claude_agent/context/
memory_loader.rs

1//! CLAUDE.md and CLAUDE.local.md loader with CLI-compatible @import processing.
2//!
3//! This module provides a memory loader that reads CLAUDE.md and CLAUDE.local.md files
4//! with support for recursive @import directives. It implements the same import behavior
5//! as Claude Code CLI 2.1.12.
6
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Import depth applied when the caller does not configure one.
///
/// The CLI technically allows depth 5 but ends up loading roughly 24K tokens of
/// memory; depth 2 yields roughly 31K tokens, which is the closest match.
pub(crate) const DEFAULT_IMPORT_DEPTH: usize = 2;

/// Deepest import level used for full expansion (the CLI's technical limit).
pub(crate) const MAX_IMPORT_DEPTH: usize = 5;
21
/// Tunable settings for `MemoryLoader`.
#[derive(Clone, Debug)]
pub struct MemoryLoaderConfig {
    /// Deepest `@import` nesting level that is still expanded.
    ///
    /// The default is 2, which approximates CLI token counts; use
    /// `MAX_IMPORT_DEPTH` (5) for full expansion.
    pub max_depth: usize,
}
29
30impl Default for MemoryLoaderConfig {
31    fn default() -> Self {
32        Self {
33            max_depth: DEFAULT_IMPORT_DEPTH,
34        }
35    }
36}
37
38impl MemoryLoaderConfig {
39    /// Creates config with full import expansion (depth 5).
40    pub fn full_expansion() -> Self {
41        Self {
42            max_depth: MAX_IMPORT_DEPTH,
43        }
44    }
45
46    /// Creates config with specified max depth.
47    pub fn max_depth(max_depth: usize) -> Self {
48        Self { max_depth }
49    }
50}
51
52/// Memory loader with CLI-compatible @import processing.
53///
54/// # Features
55/// - Loads CLAUDE.md and CLAUDE.local.md from project directories
56/// - Supports recursive @import with depth limiting
57/// - Circular import detection using canonical path tracking
58/// - Scans .claude/rules/ directory for rule files
59///
60/// # CLI Compatibility
61/// This implementation matches Claude Code CLI 2.1.12 behavior:
62/// - Maximum import depth of 5 (configurable, default 3 for similar token counts)
63/// - Same path validation rules
64/// - Same circular import prevention
65pub struct MemoryLoader {
66    extractor: ImportExtractor,
67    config: MemoryLoaderConfig,
68}
69
70impl MemoryLoader {
71    /// Creates a new MemoryLoader with default configuration (depth 3).
72    pub fn new() -> Self {
73        Self::from_config(MemoryLoaderConfig::default())
74    }
75
76    /// Creates a new MemoryLoader with custom configuration.
77    pub fn from_config(config: MemoryLoaderConfig) -> Self {
78        Self {
79            extractor: ImportExtractor::new(),
80            config,
81        }
82    }
83
84    /// Creates a new MemoryLoader with full import expansion (depth 5).
85    pub fn full_expansion() -> Self {
86        Self::from_config(MemoryLoaderConfig::full_expansion())
87    }
88
89    /// Loads all memory content (CLAUDE.md + CLAUDE.local.md + rules) from a directory.
90    ///
91    /// # Arguments
92    /// * `start_dir` - The project root directory to load from
93    ///
94    /// # Returns
95    /// Combined MemoryContent with all loaded files and rules
96    pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
97        let mut content = self.load_shared(start_dir).await?;
98        let local = self.load_local(start_dir).await?;
99        content.merge(local);
100        Ok(content)
101    }
102
103    /// Loads shared CLAUDE.md and rules (visible to all team members).
104    pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
105        let mut content = MemoryContent::default();
106        let mut visited = HashSet::new();
107
108        for path in Self::find_claude_files(start_dir) {
109            match self
110                .load_with_imports(&path, start_dir, 0, &mut visited)
111                .await
112            {
113                Ok(text) => content.claude_md.push(text),
114                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
115            }
116        }
117
118        let rules_dir = start_dir.join(".claude").join("rules");
119        if rules_dir.exists() {
120            content.rule_indices = self.scan_rules(&rules_dir).await?;
121        }
122
123        Ok(content)
124    }
125
126    /// Loads local CLAUDE.local.md (private to the user, not in version control).
127    pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
128        let mut content = MemoryContent::default();
129        let mut visited = HashSet::new();
130
131        for path in Self::find_local_files(start_dir) {
132            match self
133                .load_with_imports(&path, start_dir, 0, &mut visited)
134                .await
135            {
136                Ok(text) => content.local_md.push(text),
137                Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
138            }
139        }
140
141        Ok(content)
142    }
143
144    /// Loads a file with recursive @import expansion.
145    ///
146    /// # Arguments
147    /// * `path` - Path to the file to load
148    /// * `project_root` - Project root directory for resolving @.agents/... style imports
149    /// * `depth` - Current import depth (0 = root)
150    /// * `visited` - Set of canonical paths already loaded (for circular detection)
151    ///
152    /// # Returns
153    /// File content with all imports expanded inline
154    fn load_with_imports<'a>(
155        &'a self,
156        path: &'a Path,
157        project_root: &'a Path,
158        depth: usize,
159        visited: &'a mut HashSet<PathBuf>,
160    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
161    {
162        Box::pin(async move {
163            // Depth limit check (configurable, default 3)
164            if depth > self.config.max_depth {
165                tracing::warn!(
166                    "Import depth limit ({}) exceeded, skipping: {}",
167                    self.config.max_depth,
168                    path.display()
169                );
170                return Ok(String::new());
171            }
172
173            // Circular import detection using canonical paths
174            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
175            if visited.contains(&canonical) {
176                tracing::debug!("Circular import detected, skipping: {}", path.display());
177                return Ok(String::new());
178            }
179            visited.insert(canonical);
180
181            let content =
182                tokio::fs::read_to_string(path)
183                    .await
184                    .map_err(|e| ContextError::Source {
185                        message: format!("Failed to read {}: {}", path.display(), e),
186                    })?;
187
188            // Use current file's directory for relative path resolution
189            let file_dir = path.parent().unwrap_or(Path::new("."));
190            let imports = self.extractor.extract(&content, file_dir);
191
192            // Post-process: fix duplicated .agents/ or .claude/ paths
193            // e.g., /project/.agents/guides/.agents/patterns -> .agents/patterns
194            let imports: Vec<PathBuf> = imports
195                .into_iter()
196                .map(|p| Self::normalize_project_relative_path(&p, project_root))
197                .collect();
198
199            let mut result = content;
200            for import_path in imports {
201                if import_path.exists() {
202                    if let Ok(imported) = self
203                        .load_with_imports(&import_path, project_root, depth + 1, visited)
204                        .await
205                        && !imported.is_empty()
206                    {
207                        result.push_str("\n\n");
208                        result.push_str(&imported);
209                    }
210                } else {
211                    tracing::debug!("Import not found, skipping: {}", import_path.display());
212                }
213            }
214
215            Ok(result)
216        })
217    }
218
219    /// Finds CLAUDE.md files in standard locations.
220    fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
221        let mut files = Vec::new();
222
223        // Project root CLAUDE.md
224        let claude_md = start_dir.join("CLAUDE.md");
225        if claude_md.exists() {
226            files.push(claude_md);
227        }
228
229        // .claude/CLAUDE.md (alternative location)
230        let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
231        if claude_dir_md.exists() {
232            files.push(claude_dir_md);
233        }
234
235        files
236    }
237
238    /// Finds CLAUDE.local.md files in standard locations.
239    fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
240        let mut files = Vec::new();
241
242        // Project root CLAUDE.local.md
243        let local_md = start_dir.join("CLAUDE.local.md");
244        if local_md.exists() {
245            files.push(local_md);
246        }
247
248        // .claude/CLAUDE.local.md (alternative location)
249        let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
250        if local_dir_md.exists() {
251            files.push(local_dir_md);
252        }
253
254        files
255    }
256
257    /// Normalizes paths with duplicated .agents/ or .claude/ segments.
258    ///
259    /// When imports are resolved from nested files (e.g., .agents/guides/workflow.md),
260    /// relative paths like @.agents/patterns/... get incorrectly expanded to
261    /// /project/.agents/guides/.agents/patterns/... (duplicated .agents/).
262    ///
263    /// This function detects such duplications and re-resolves from project root.
264    fn normalize_project_relative_path(path: &Path, project_root: &Path) -> PathBuf {
265        const MARKERS: [&str; 2] = ["/.agents/", "/.claude/"];
266
267        let path_str = path.to_string_lossy();
268
269        // Find the last occurrence of any project-relative marker
270        let last_marker_pos = MARKERS
271            .iter()
272            .filter_map(|marker| {
273                let count = path_str.matches(marker).count();
274                // Only fix if duplicated (count > 1) or mixed markers exist
275                if count > 1 || MARKERS.iter().filter(|m| path_str.contains(*m)).count() > 1 {
276                    path_str.rfind(marker).map(|pos| (pos, *marker))
277                } else {
278                    None
279                }
280            })
281            .max_by_key(|(pos, _)| *pos);
282
283        if let Some((idx, _)) = last_marker_pos {
284            let relative_part = &path_str[idx + 1..]; // Skip leading "/"
285            project_root.join(relative_part)
286        } else {
287            path.to_path_buf()
288        }
289    }
290
291    /// Scans .claude/rules/ directory recursively for rule files.
292    async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
293        let mut indices = Vec::new();
294        self.scan_rules_recursive(dir, &mut indices).await?;
295        indices.sort_by(|a, b| b.priority.cmp(&a.priority));
296        Ok(indices)
297    }
298
299    fn scan_rules_recursive<'a>(
300        &'a self,
301        dir: &'a Path,
302        indices: &'a mut Vec<RuleIndex>,
303    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
304        Box::pin(async move {
305            let mut entries = tokio::fs::read_dir(dir)
306                .await
307                .map_err(|e| ContextError::Source {
308                    message: format!("Failed to read rules directory: {}", e),
309                })?;
310
311            while let Some(entry) =
312                entries
313                    .next_entry()
314                    .await
315                    .map_err(|e| ContextError::Source {
316                        message: format!("Failed to read directory entry: {}", e),
317                    })?
318            {
319                let path = entry.path();
320
321                if path.is_dir() {
322                    self.scan_rules_recursive(&path, indices).await?;
323                } else if path.extension().is_some_and(|e| e == "md")
324                    && let Some(index) = RuleIndex::from_file(&path)
325                {
326                    indices.push(index);
327                }
328            }
329
330            Ok(())
331        })
332    }
333}
334
335impl Default for MemoryLoader {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341/// Loaded memory content from CLAUDE.md files and rules.
342#[derive(Debug, Default, Clone)]
343pub struct MemoryContent {
344    /// Content from CLAUDE.md files (shared/team config).
345    pub claude_md: Vec<String>,
346    /// Content from CLAUDE.local.md files (user-specific config).
347    pub local_md: Vec<String>,
348    /// Rule indices from .claude/rules/ directory.
349    pub rule_indices: Vec<RuleIndex>,
350}
351
352impl MemoryContent {
353    /// Combines all CLAUDE.md and CLAUDE.local.md content into a single string.
354    pub fn combined_claude_md(&self) -> String {
355        self.claude_md
356            .iter()
357            .chain(self.local_md.iter())
358            .filter(|c| !c.trim().is_empty())
359            .cloned()
360            .collect::<Vec<_>>()
361            .join("\n\n")
362    }
363
364    /// Returns true if no content was loaded.
365    pub fn is_empty(&self) -> bool {
366        self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
367    }
368
369    /// Merges another MemoryContent into this one.
370    pub fn merge(&mut self, other: MemoryContent) {
371        self.claude_md.extend(other.claude_md);
372        self.local_md.extend(other.local_md);
373        self.rule_indices.extend(other.rule_indices);
374    }
375}
376
377#[cfg(test)]
378mod tests {
379    use super::*;
380    use tempfile::tempdir;
381    use tokio::fs;
382
383    #[tokio::test]
384    async fn test_load_claude_md() {
385        let dir = tempdir().unwrap();
386        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
387            .await
388            .unwrap();
389
390        let loader = MemoryLoader::new();
391        let content = loader.load(dir.path()).await.unwrap();
392
393        assert_eq!(content.claude_md.len(), 1);
394        assert!(content.claude_md[0].contains("Test content"));
395    }
396
397    #[tokio::test]
398    async fn test_load_local_md() {
399        let dir = tempdir().unwrap();
400        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
401            .await
402            .unwrap();
403
404        let loader = MemoryLoader::new();
405        let content = loader.load(dir.path()).await.unwrap();
406
407        assert_eq!(content.local_md.len(), 1);
408        assert!(content.local_md[0].contains("Private"));
409    }
410
411    #[tokio::test]
412    async fn test_scan_rules_recursive() {
413        let dir = tempdir().unwrap();
414        let rules_dir = dir.path().join(".claude").join("rules");
415        let sub_dir = rules_dir.join("frontend");
416        fs::create_dir_all(&sub_dir).await.unwrap();
417
418        fs::write(
419            rules_dir.join("rust.md"),
420            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
421        )
422        .await
423        .unwrap();
424
425        fs::write(
426            sub_dir.join("react.md"),
427            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
428        )
429        .await
430        .unwrap();
431
432        let loader = MemoryLoader::new();
433        let content = loader.load(dir.path()).await.unwrap();
434
435        assert_eq!(content.rule_indices.len(), 2);
436        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
437        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
438    }
439
440    #[tokio::test]
441    async fn test_import_syntax() {
442        let dir = tempdir().unwrap();
443
444        fs::write(
445            dir.path().join("CLAUDE.md"),
446            "# Main\n@docs/guidelines.md\nEnd",
447        )
448        .await
449        .unwrap();
450
451        let docs_dir = dir.path().join("docs");
452        fs::create_dir_all(&docs_dir).await.unwrap();
453        fs::write(docs_dir.join("guidelines.md"), "Imported content")
454            .await
455            .unwrap();
456
457        let loader = MemoryLoader::new();
458        let content = loader.load(dir.path()).await.unwrap();
459
460        assert!(content.combined_claude_md().contains("Imported content"));
461    }
462
463    #[tokio::test]
464    async fn test_combined_includes_local() {
465        let dir = tempdir().unwrap();
466        fs::write(dir.path().join("CLAUDE.md"), "Main content")
467            .await
468            .unwrap();
469        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
470            .await
471            .unwrap();
472
473        let loader = MemoryLoader::new();
474        let content = loader.load(dir.path()).await.unwrap();
475
476        let combined = content.combined_claude_md();
477        assert!(combined.contains("Main content"));
478        assert!(combined.contains("Local content"));
479    }
480
481    #[tokio::test]
482    async fn test_recursive_import() {
483        let dir = tempdir().unwrap();
484
485        // CLAUDE.md → docs/guide.md → docs/detail.md
486        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
487            .await
488            .unwrap();
489
490        let docs_dir = dir.path().join("docs");
491        fs::create_dir_all(&docs_dir).await.unwrap();
492        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
493            .await
494            .unwrap();
495        fs::write(docs_dir.join("detail.md"), "Detail content")
496            .await
497            .unwrap();
498
499        let loader = MemoryLoader::new();
500        let content = loader.load(dir.path()).await.unwrap();
501        let combined = content.combined_claude_md();
502
503        assert!(combined.contains("Root content"));
504        assert!(combined.contains("Guide content"));
505        assert!(combined.contains("Detail content"));
506    }
507
508    #[tokio::test]
509    async fn test_depth_limit_default() {
510        let dir = tempdir().unwrap();
511
512        // Create chain: CLAUDE.md → level1.md → level2.md → level3.md
513        // With default depth 2, should stop at level 2 (0,1,2 loaded, 3 not)
514        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
515            .await
516            .unwrap();
517
518        for i in 1..=3 {
519            let content = if i < 3 {
520                format!("Level {} @level{}.md", i, i + 1)
521            } else {
522                format!("Level {}", i)
523            };
524            fs::write(dir.path().join(format!("level{}.md", i)), content)
525                .await
526                .unwrap();
527        }
528
529        let loader = MemoryLoader::new(); // Default depth 2
530        let content = loader.load(dir.path()).await.unwrap();
531        let combined = content.combined_claude_md();
532
533        // Should have levels 0-2 but NOT level 3 (default depth limit = 2)
534        assert!(combined.contains("Level 0"));
535        assert!(combined.contains("Level 2"));
536        assert!(!combined.contains("Level 3"));
537    }
538
539    #[tokio::test]
540    async fn test_depth_limit_full_expansion() {
541        let dir = tempdir().unwrap();
542
543        // Create chain: CLAUDE.md → level1.md → level2.md → ... → level6.md
544        // With full expansion (depth 5), should stop at level 5 (0-5 loaded, 6 not)
545        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
546            .await
547            .unwrap();
548
549        for i in 1..=6 {
550            let content = if i < 6 {
551                format!("Level {} @level{}.md", i, i + 1)
552            } else {
553                format!("Level {}", i)
554            };
555            fs::write(dir.path().join(format!("level{}.md", i)), content)
556                .await
557                .unwrap();
558        }
559
560        let loader = MemoryLoader::full_expansion(); // Depth 5
561        let content = loader.load(dir.path()).await.unwrap();
562        let combined = content.combined_claude_md();
563
564        // Should have levels 0-5 but NOT level 6 (max depth limit = 5)
565        assert!(combined.contains("Level 0"));
566        assert!(combined.contains("Level 5"));
567        assert!(!combined.contains("Level 6"));
568    }
569
570    #[tokio::test]
571    async fn test_depth_limit_custom() {
572        let dir = tempdir().unwrap();
573
574        // Create chain up to level 3
575        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
576            .await
577            .unwrap();
578
579        for i in 1..=3 {
580            let content = if i < 3 {
581                format!("Level {} @level{}.md", i, i + 1)
582            } else {
583                format!("Level {}", i)
584            };
585            fs::write(dir.path().join(format!("level{}.md", i)), content)
586                .await
587                .unwrap();
588        }
589
590        // Custom depth 1: should only load levels 0-1
591        let loader = MemoryLoader::from_config(MemoryLoaderConfig::max_depth(1));
592        let content = loader.load(dir.path()).await.unwrap();
593        let combined = content.combined_claude_md();
594
595        assert!(combined.contains("Level 0"));
596        assert!(combined.contains("Level 1"));
597        assert!(!combined.contains("Level 2"));
598    }
599
600    #[tokio::test]
601    async fn test_circular_import() {
602        let dir = tempdir().unwrap();
603
604        // CLAUDE.md → a.md → b.md → a.md (circular)
605        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
606            .await
607            .unwrap();
608        fs::write(dir.path().join("a.md"), "A content @b.md")
609            .await
610            .unwrap();
611        fs::write(dir.path().join("b.md"), "B content @a.md")
612            .await
613            .unwrap();
614
615        let loader = MemoryLoader::new();
616        let result = loader.load(dir.path()).await;
617
618        // Should not infinite loop and should succeed
619        assert!(result.is_ok());
620        let combined = result.unwrap().combined_claude_md();
621        assert!(combined.contains("A content"));
622        assert!(combined.contains("B content"));
623    }
624
625    #[tokio::test]
626    async fn test_import_in_code_block_ignored() {
627        let dir = tempdir().unwrap();
628
629        fs::write(
630            dir.path().join("CLAUDE.md"),
631            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
632        )
633        .await
634        .unwrap();
635
636        fs::write(
637            dir.path().join("should").join("import.md"),
638            "This is imported",
639        )
640        .await
641        .ok();
642        let should_dir = dir.path().join("should");
643        fs::create_dir_all(&should_dir).await.unwrap();
644        fs::write(should_dir.join("import.md"), "Imported content")
645            .await
646            .unwrap();
647
648        let loader = MemoryLoader::new();
649        let content = loader.load(dir.path()).await.unwrap();
650        let combined = content.combined_claude_md();
651
652        assert!(combined.contains("Imported content"));
653        // The @should/not/import.md in code block should remain as-is, not be processed
654        assert!(combined.contains("@should/not/import.md"));
655    }
656
657    #[tokio::test]
658    async fn test_missing_import_ignored() {
659        let dir = tempdir().unwrap();
660
661        fs::write(
662            dir.path().join("CLAUDE.md"),
663            "# Main\n@nonexistent/file.md\nRest of content",
664        )
665        .await
666        .unwrap();
667
668        let loader = MemoryLoader::new();
669        let content = loader.load(dir.path()).await.unwrap();
670        let combined = content.combined_claude_md();
671
672        // Should still load the main content even if import doesn't exist
673        assert!(combined.contains("# Main"));
674        assert!(combined.contains("Rest of content"));
675    }
676
677    #[tokio::test]
678    async fn test_empty_content() {
679        let dir = tempdir().unwrap();
680
681        let loader = MemoryLoader::new();
682        let content = loader.load(dir.path()).await.unwrap();
683
684        assert!(content.is_empty());
685        assert!(content.combined_claude_md().is_empty());
686    }
687
688    #[tokio::test]
689    async fn test_memory_content_merge() {
690        let mut content1 = MemoryContent {
691            claude_md: vec!["content1".to_string()],
692            local_md: vec!["local1".to_string()],
693            rule_indices: vec![],
694        };
695
696        let content2 = MemoryContent {
697            claude_md: vec!["content2".to_string()],
698            local_md: vec!["local2".to_string()],
699            rule_indices: vec![],
700        };
701
702        content1.merge(content2);
703
704        assert_eq!(content1.claude_md.len(), 2);
705        assert_eq!(content1.local_md.len(), 2);
706    }
707}