claude_agent/context/
memory_loader.rs

1//! CLAUDE.md and CLAUDE.local.md loader with @import support.
2
3use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5
6use super::provider::MAX_IMPORT_DEPTH;
7use super::rule_index::RuleIndex;
8use super::{ContextError, ContextResult};
9
/// Stateful loader for CLAUDE.md / CLAUDE.local.md files and `.claude/rules` indices.
///
/// The loader remembers which files it has already loaded and never clears that
/// record, so a given instance expands each file at most once — both within a single
/// import tree and across later `load*` calls on the same instance.
#[derive(Default, Debug)]
pub struct MemoryLoader {
    /// Canonicalized paths (or the raw path when canonicalization fails) of files
    /// already loaded; a path found here expands to an empty string, which also
    /// breaks import cycles.
    loaded_paths: HashSet<PathBuf>,
    /// Nesting level of the import currently being expanded; compared against
    /// `MAX_IMPORT_DEPTH` before a file is loaded.
    current_depth: usize,
}
15
16impl MemoryLoader {
17    pub fn new() -> Self {
18        Self::default()
19    }
20
21    /// Loads all memory content (CLAUDE.md + CLAUDE.local.md + rules).
22    pub async fn load(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
23        let mut content = self.load_shared(start_dir).await?;
24        let local = self.load_local(start_dir).await?;
25        content.merge(local);
26        Ok(content)
27    }
28
29    /// Loads CLAUDE.md and rules from any resource level (enterprise/user/project).
30    pub async fn load_shared(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
31        let mut content = MemoryContent::default();
32
33        for path in self.find_claude_files(start_dir) {
34            if let Ok(text) = self.load_file_with_imports(&path).await {
35                content.claude_md.push(text);
36            }
37        }
38
39        let rules_dir = start_dir.join(".claude").join("rules");
40        if rules_dir.exists() {
41            content.rule_indices = self.scan_rules_directory_recursive(&rules_dir).await?;
42        }
43
44        Ok(content)
45    }
46
47    /// Loads CLAUDE.local.md only (project-level private config).
48    pub async fn load_local(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
49        let mut content = MemoryContent::default();
50
51        for path in self.find_local_files(start_dir) {
52            if let Ok(text) = self.load_file_with_imports(&path).await {
53                content.local_md.push(text);
54            }
55        }
56
57        Ok(content)
58    }
59
60    fn find_claude_files(&self, start_dir: &Path) -> Vec<PathBuf> {
61        let mut files = Vec::new();
62
63        let claude_md = start_dir.join("CLAUDE.md");
64        if claude_md.exists() {
65            files.push(claude_md);
66        }
67
68        let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
69        if claude_dir_md.exists() {
70            files.push(claude_dir_md);
71        }
72
73        files
74    }
75
76    fn find_local_files(&self, start_dir: &Path) -> Vec<PathBuf> {
77        let mut files = Vec::new();
78
79        let local_md = start_dir.join("CLAUDE.local.md");
80        if local_md.exists() {
81            files.push(local_md);
82        }
83
84        let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
85        if local_dir_md.exists() {
86            files.push(local_dir_md);
87        }
88
89        files
90    }
91
92    fn scan_rules_directory_recursive<'a>(
93        &'a self,
94        dir: &'a Path,
95    ) -> std::pin::Pin<
96        Box<dyn std::future::Future<Output = ContextResult<Vec<RuleIndex>>> + Send + 'a>,
97    > {
98        Box::pin(async move {
99            let mut indices = Vec::new();
100
101            let mut entries = tokio::fs::read_dir(dir)
102                .await
103                .map_err(|e| ContextError::Source {
104                    message: format!("Failed to read rules directory: {}", e),
105                })?;
106
107            while let Some(entry) =
108                entries
109                    .next_entry()
110                    .await
111                    .map_err(|e| ContextError::Source {
112                        message: format!("Failed to read directory entry: {}", e),
113                    })?
114            {
115                let path = entry.path();
116
117                if path.is_dir() {
118                    let sub_indices = self.scan_rules_directory_recursive(&path).await?;
119                    indices.extend(sub_indices);
120                } else if path.extension().is_some_and(|e| e == "md")
121                    && let Some(index) = RuleIndex::from_file(&path)
122                {
123                    indices.push(index);
124                }
125            }
126
127            indices.sort_by(|a, b| b.priority.cmp(&a.priority));
128            Ok(indices)
129        })
130    }
131
132    fn load_file_with_imports<'a>(
133        &'a mut self,
134        path: &'a Path,
135    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
136    {
137        Box::pin(async move {
138            if self.current_depth >= MAX_IMPORT_DEPTH {
139                tracing::warn!(
140                    "Import depth limit ({}) reached, skipping: {}",
141                    MAX_IMPORT_DEPTH,
142                    path.display()
143                );
144                return Ok(String::new());
145            }
146
147            let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
148            if self.loaded_paths.contains(&canonical) {
149                return Ok(String::new());
150            }
151            self.loaded_paths.insert(canonical.clone());
152
153            let content =
154                tokio::fs::read_to_string(path)
155                    .await
156                    .map_err(|e| ContextError::Source {
157                        message: format!("Failed to read {}: {}", path.display(), e),
158                    })?;
159
160            self.current_depth += 1;
161            let result = self
162                .process_imports(&content, path.parent().unwrap_or(Path::new(".")))
163                .await;
164            self.current_depth -= 1;
165
166            result
167        })
168    }
169
170    fn expand_home(path: &str) -> PathBuf {
171        if let Some(rest) = path.strip_prefix("~/")
172            && let Some(home) = crate::common::home_dir()
173        {
174            return home.join(rest);
175        }
176        PathBuf::from(path)
177    }
178
179    fn process_imports<'a>(
180        &'a mut self,
181        content: &'a str,
182        base_dir: &'a Path,
183    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
184    {
185        Box::pin(async move {
186            let mut result = String::new();
187
188            for line in content.lines() {
189                let trimmed = line.trim();
190
191                if trimmed.starts_with('@') && !trimmed.starts_with("@@") {
192                    let import_path = trimmed.trim_start_matches('@').trim();
193                    if !import_path.is_empty() {
194                        let full_path = if import_path.starts_with("~/") {
195                            Self::expand_home(import_path)
196                        } else if import_path.starts_with('/') {
197                            PathBuf::from(import_path)
198                        } else {
199                            base_dir.join(import_path)
200                        };
201
202                        if full_path.exists() {
203                            match self.load_file_with_imports(&full_path).await {
204                                Ok(imported) => {
205                                    result.push_str(&imported);
206                                    result.push('\n');
207                                }
208                                Err(e) => {
209                                    tracing::warn!("Failed to import {}: {}", import_path, e);
210                                    result.push_str(line);
211                                    result.push('\n');
212                                }
213                            }
214                        } else {
215                            result.push_str(line);
216                            result.push('\n');
217                        }
218                    } else {
219                        result.push_str(line);
220                        result.push('\n');
221                    }
222                } else {
223                    result.push_str(line);
224                    result.push('\n');
225                }
226            }
227
228            Ok(result)
229        })
230    }
231}
232
233#[derive(Debug, Default, Clone)]
234pub struct MemoryContent {
235    pub claude_md: Vec<String>,
236    pub local_md: Vec<String>,
237    pub rule_indices: Vec<RuleIndex>,
238}
239
240impl MemoryContent {
241    pub fn combined_claude_md(&self) -> String {
242        self.claude_md
243            .iter()
244            .chain(self.local_md.iter())
245            .filter(|c| !c.trim().is_empty())
246            .cloned()
247            .collect::<Vec<_>>()
248            .join("\n\n")
249    }
250
251    pub fn is_empty(&self) -> bool {
252        self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
253    }
254
255    pub fn merge(&mut self, other: MemoryContent) {
256        self.claude_md.extend(other.claude_md);
257        self.local_md.extend(other.local_md);
258        self.rule_indices.extend(other.rule_indices);
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    /// Writes `body` to `path`, panicking if the write fails.
    async fn put(path: impl AsRef<Path>, body: &str) {
        fs::write(path, body).await.unwrap();
    }

    /// Runs a fresh `MemoryLoader` over `root` and unwraps the result.
    async fn load_from(root: &Path) -> MemoryContent {
        let mut loader = MemoryLoader::new();
        loader.load(root).await.unwrap()
    }

    #[tokio::test]
    async fn test_load_claude_md() {
        let tmp = tempdir().unwrap();
        put(tmp.path().join("CLAUDE.md"), "# Project\nTest content").await;

        let loaded = load_from(tmp.path()).await;

        assert_eq!(loaded.claude_md.len(), 1);
        assert!(loaded.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let tmp = tempdir().unwrap();
        put(tmp.path().join("CLAUDE.local.md"), "# Local\nPrivate").await;

        let loaded = load_from(tmp.path()).await;

        assert_eq!(loaded.local_md.len(), 1);
        assert!(loaded.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let tmp = tempdir().unwrap();
        let rules = tmp.path().join(".claude").join("rules");
        let frontend = rules.join("frontend");
        fs::create_dir_all(&frontend).await.unwrap();

        put(
            rules.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await;
        put(
            frontend.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await;

        let loaded = load_from(tmp.path()).await;

        assert_eq!(loaded.rule_indices.len(), 2);
        assert!(loaded.rule_indices.iter().any(|rule| rule.name == "rust"));
        assert!(loaded.rule_indices.iter().any(|rule| rule.name == "react"));
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let tmp = tempdir().unwrap();
        put(tmp.path().join("CLAUDE.md"), "# Main\n@docs/guidelines.md\nEnd").await;

        let docs = tmp.path().join("docs");
        fs::create_dir_all(&docs).await.unwrap();
        put(docs.join("guidelines.md"), "Imported content").await;

        let loaded = load_from(tmp.path()).await;

        assert!(loaded.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let tmp = tempdir().unwrap();
        put(tmp.path().join("CLAUDE.md"), "Main content").await;
        put(tmp.path().join("CLAUDE.local.md"), "Local content").await;

        let combined = load_from(tmp.path()).await.combined_claude_md();

        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }
}