claude_agent/context/
memory_loader.rs1use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11
12use super::provider::MAX_IMPORT_DEPTH;
13use super::rule_index::RuleIndex;
14use super::{ContextError, ContextResult};
15
/// Discovers `CLAUDE.md`-style memory files and expands their `@path` imports.
///
/// The loader keeps track of which files it has already read and how deeply nested the
/// import currently being expanded is.
#[derive(Default, Debug)]
pub struct MemoryLoader {
    /// Paths of files already read (canonicalized when possible); used to skip duplicates
    /// and to cut import cycles.
    loaded_paths: HashSet<PathBuf>,
    /// Current import nesting level, compared against `MAX_IMPORT_DEPTH`.
    current_depth: usize,
}
21
22impl MemoryLoader {
23 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub async fn load_all(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
28 let mut content = MemoryContent::default();
29
30 let claude_files = self.find_claude_files(start_dir);
31 for path in claude_files {
32 if let Ok(text) = self.load_file_with_imports(&path).await {
33 content.claude_md.push(text);
34 }
35 }
36
37 let local_files = self.find_local_files(start_dir);
38 for path in local_files {
39 if let Ok(text) = self.load_file_with_imports(&path).await {
40 content.local_md.push(text);
41 }
42 }
43
44 let rules_dir = start_dir.join(".claude").join("rules");
45 if rules_dir.exists() {
46 content.rule_indices = self.scan_rules_directory(&rules_dir).await?;
47 }
48
49 Ok(content)
50 }
51
52 pub async fn load_local_only(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
53 let mut content = MemoryContent::default();
54
55 let local_files = self.find_local_files(start_dir);
56 for path in local_files {
57 if let Ok(text) = self.load_file_with_imports(&path).await {
58 content.local_md.push(text);
59 }
60 }
61
62 Ok(content)
63 }
64
65 fn find_claude_files(&self, start_dir: &Path) -> Vec<PathBuf> {
66 let mut files = Vec::new();
67 let mut current = start_dir.to_path_buf();
68
69 loop {
70 let claude_md = current.join("CLAUDE.md");
71 if claude_md.exists() {
72 files.push(claude_md);
73 }
74
75 let claude_dir_md = current.join(".claude").join("CLAUDE.md");
76 if claude_dir_md.exists() {
77 files.push(claude_dir_md);
78 }
79
80 match current.parent() {
81 Some(parent) if parent != current && !parent.as_os_str().is_empty() => {
82 current = parent.to_path_buf();
83 }
84 _ => break,
85 }
86 }
87
88 files.reverse();
89 files
90 }
91
92 fn find_local_files(&self, start_dir: &Path) -> Vec<PathBuf> {
93 let mut files = Vec::new();
94 let mut current = start_dir.to_path_buf();
95
96 loop {
97 let local_md = current.join("CLAUDE.local.md");
98 if local_md.exists() {
99 files.push(local_md);
100 }
101
102 let local_dir_md = current.join(".claude").join("CLAUDE.local.md");
103 if local_dir_md.exists() {
104 files.push(local_dir_md);
105 }
106
107 match current.parent() {
108 Some(parent) if parent != current && !parent.as_os_str().is_empty() => {
109 current = parent.to_path_buf();
110 }
111 _ => break,
112 }
113 }
114
115 files.reverse();
116 files
117 }
118
119 async fn scan_rules_directory(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
120 let mut indices = Vec::new();
121
122 let mut entries = tokio::fs::read_dir(dir)
123 .await
124 .map_err(|e| ContextError::Source {
125 message: format!("Failed to read rules directory: {}", e),
126 })?;
127
128 while let Some(entry) = entries
129 .next_entry()
130 .await
131 .map_err(|e| ContextError::Source {
132 message: format!("Failed to read directory entry: {}", e),
133 })?
134 {
135 let path = entry.path();
136 if path.extension().is_some_and(|e| e == "md")
137 && let Some(index) = RuleIndex::from_file(&path)
138 {
139 indices.push(index);
140 }
141 }
142
143 indices.sort_by(|a, b| b.priority.cmp(&a.priority));
144 Ok(indices)
145 }
146
147 fn load_file_with_imports<'a>(
148 &'a mut self,
149 path: &'a Path,
150 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
151 {
152 Box::pin(async move {
153 if self.current_depth >= MAX_IMPORT_DEPTH {
154 tracing::warn!(
155 "Import depth limit ({}) reached, skipping: {}",
156 MAX_IMPORT_DEPTH,
157 path.display()
158 );
159 return Ok(String::new());
160 }
161
162 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
163 if self.loaded_paths.contains(&canonical) {
164 return Ok(String::new());
165 }
166 self.loaded_paths.insert(canonical.clone());
167
168 let content =
169 tokio::fs::read_to_string(path)
170 .await
171 .map_err(|e| ContextError::Source {
172 message: format!("Failed to read {}: {}", path.display(), e),
173 })?;
174
175 self.current_depth += 1;
176 let result = self
177 .process_imports(&content, path.parent().unwrap_or(Path::new(".")))
178 .await;
179 self.current_depth -= 1;
180
181 result
182 })
183 }
184
185 fn expand_home(path: &str) -> PathBuf {
186 if let Some(rest) = path.strip_prefix("~/")
187 && let Some(home) = crate::common::home_dir()
188 {
189 return home.join(rest);
190 }
191 PathBuf::from(path)
192 }
193
194 fn process_imports<'a>(
195 &'a mut self,
196 content: &'a str,
197 base_dir: &'a Path,
198 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
199 {
200 Box::pin(async move {
201 let mut result = String::new();
202
203 for line in content.lines() {
204 let trimmed = line.trim();
205
206 if trimmed.starts_with('@') && !trimmed.starts_with("@@") {
207 let import_path = trimmed.trim_start_matches('@').trim();
208 if !import_path.is_empty() {
209 let full_path = if import_path.starts_with("~/") {
210 Self::expand_home(import_path)
211 } else if import_path.starts_with('/') {
212 PathBuf::from(import_path)
213 } else {
214 base_dir.join(import_path)
215 };
216
217 if full_path.exists() {
218 match self.load_file_with_imports(&full_path).await {
219 Ok(imported) => {
220 result.push_str(&imported);
221 result.push('\n');
222 }
223 Err(e) => {
224 tracing::warn!("Failed to import {}: {}", import_path, e);
225 result.push_str(line);
226 result.push('\n');
227 }
228 }
229 } else {
230 result.push_str(line);
231 result.push('\n');
232 }
233 } else {
234 result.push_str(line);
235 result.push('\n');
236 }
237 } else {
238 result.push_str(line);
239 result.push('\n');
240 }
241 }
242
243 Ok(result)
244 })
245 }
246}
247
248#[derive(Debug, Default)]
249pub struct MemoryContent {
250 pub claude_md: Vec<String>,
251 pub local_md: Vec<String>,
252 pub rule_indices: Vec<RuleIndex>,
253}
254
255impl MemoryContent {
256 pub fn combined_claude_md(&self) -> String {
257 let mut parts = Vec::new();
258
259 for content in &self.claude_md {
260 if !content.trim().is_empty() {
261 parts.push(content.clone());
262 }
263 }
264
265 for content in &self.local_md {
266 if !content.trim().is_empty() {
267 parts.push(content.clone());
268 }
269 }
270
271 parts.join("\n\n")
272 }
273
274 pub fn is_empty(&self) -> bool {
275 self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    /// Runs `load_all` with a fresh loader rooted at `root`, panicking on error.
    async fn load(root: &Path) -> MemoryContent {
        MemoryLoader::new().load_all(root).await.unwrap()
    }

    #[tokio::test]
    async fn test_load_claude_md() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();
        fs::write(root.join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let memory = load(root).await;

        assert_eq!(memory.claude_md.len(), 1);
        assert!(memory.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();
        fs::write(root.join("CLAUDE.local.md"), "Local settings")
            .await
            .unwrap();

        let memory = load(root).await;

        assert_eq!(memory.local_md.len(), 1);
        assert!(memory.local_md[0].contains("Local settings"));
    }

    #[tokio::test]
    async fn test_scan_rules_indices_only() {
        let tmp = tempdir().unwrap();
        let rules = tmp.path().join(".claude").join("rules");
        fs::create_dir_all(&rules).await.unwrap();

        let rust_rule_source =
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules\nUse snake_case";
        fs::write(rules.join("rust.md"), rust_rule_source)
            .await
            .unwrap();
        fs::write(rules.join("security.md"), "# Security\nNo secrets")
            .await
            .unwrap();

        let memory = load(tmp.path()).await;
        assert_eq!(memory.rule_indices.len(), 2);

        let rust = memory
            .rule_indices
            .iter()
            .find(|rule| rule.name == "rust")
            .expect("rust rule should be indexed");
        assert_eq!(rust.priority, 10);
        assert!(rust.paths.is_some());
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();

        fs::write(root.join("CLAUDE.md"), "# Main\n@docs/guidelines.md\nEnd")
            .await
            .unwrap();

        let docs = root.join("docs");
        fs::create_dir_all(&docs).await.unwrap();
        fs::write(docs.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let memory = load(root).await;

        assert!(memory.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_content() {
        let tmp = tempdir().unwrap();
        let root = tmp.path();
        fs::write(root.join("CLAUDE.md"), "Main content").await.unwrap();
        fs::write(root.join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let combined = load(root).await.combined_claude_md();

        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }
}