claude_agent/context/
memory_loader.rs1use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5
6use super::provider::MAX_IMPORT_DEPTH;
7use super::rule_index::RuleIndex;
8use super::{ContextError, ContextResult};
9
/// Reads `CLAUDE.md` / `CLAUDE.local.md` memory files, expands their `@path`
/// imports, and collects rule indices from `.claude/rules`.
#[derive(Debug, Default)]
pub struct MemoryLoader {
    /// Canonical (or, when canonicalization fails, as-given) paths of files this
    /// loader has already read. Checked before reading so import cycles terminate
    /// and a file is only included once per loader.
    loaded_paths: HashSet<PathBuf>,
    /// How many `@path` imports deep the loader currently is; compared against
    /// `MAX_IMPORT_DEPTH` before each file is read.
    current_depth: usize,
}
15
16impl MemoryLoader {
17 pub fn new() -> Self {
18 Self::default()
19 }
20
21 pub async fn load(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
23 let mut content = self.load_shared(start_dir).await?;
24 let local = self.load_local(start_dir).await?;
25 content.merge(local);
26 Ok(content)
27 }
28
29 pub async fn load_shared(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
31 let mut content = MemoryContent::default();
32
33 for path in self.find_claude_files(start_dir) {
34 if let Ok(text) = self.load_file_with_imports(&path).await {
35 content.claude_md.push(text);
36 }
37 }
38
39 let rules_dir = start_dir.join(".claude").join("rules");
40 if rules_dir.exists() {
41 content.rule_indices = self.scan_rules_directory_recursive(&rules_dir).await?;
42 }
43
44 Ok(content)
45 }
46
47 pub async fn load_local(&mut self, start_dir: &Path) -> ContextResult<MemoryContent> {
49 let mut content = MemoryContent::default();
50
51 for path in self.find_local_files(start_dir) {
52 if let Ok(text) = self.load_file_with_imports(&path).await {
53 content.local_md.push(text);
54 }
55 }
56
57 Ok(content)
58 }
59
60 fn find_claude_files(&self, start_dir: &Path) -> Vec<PathBuf> {
61 let mut files = Vec::new();
62
63 let claude_md = start_dir.join("CLAUDE.md");
64 if claude_md.exists() {
65 files.push(claude_md);
66 }
67
68 let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
69 if claude_dir_md.exists() {
70 files.push(claude_dir_md);
71 }
72
73 files
74 }
75
76 fn find_local_files(&self, start_dir: &Path) -> Vec<PathBuf> {
77 let mut files = Vec::new();
78
79 let local_md = start_dir.join("CLAUDE.local.md");
80 if local_md.exists() {
81 files.push(local_md);
82 }
83
84 let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
85 if local_dir_md.exists() {
86 files.push(local_dir_md);
87 }
88
89 files
90 }
91
92 fn scan_rules_directory_recursive<'a>(
93 &'a self,
94 dir: &'a Path,
95 ) -> std::pin::Pin<
96 Box<dyn std::future::Future<Output = ContextResult<Vec<RuleIndex>>> + Send + 'a>,
97 > {
98 Box::pin(async move {
99 let mut indices = Vec::new();
100
101 let mut entries = tokio::fs::read_dir(dir)
102 .await
103 .map_err(|e| ContextError::Source {
104 message: format!("Failed to read rules directory: {}", e),
105 })?;
106
107 while let Some(entry) =
108 entries
109 .next_entry()
110 .await
111 .map_err(|e| ContextError::Source {
112 message: format!("Failed to read directory entry: {}", e),
113 })?
114 {
115 let path = entry.path();
116
117 if path.is_dir() {
118 let sub_indices = self.scan_rules_directory_recursive(&path).await?;
119 indices.extend(sub_indices);
120 } else if path.extension().is_some_and(|e| e == "md")
121 && let Some(index) = RuleIndex::from_file(&path)
122 {
123 indices.push(index);
124 }
125 }
126
127 indices.sort_by(|a, b| b.priority.cmp(&a.priority));
128 Ok(indices)
129 })
130 }
131
132 fn load_file_with_imports<'a>(
133 &'a mut self,
134 path: &'a Path,
135 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
136 {
137 Box::pin(async move {
138 if self.current_depth >= MAX_IMPORT_DEPTH {
139 tracing::warn!(
140 "Import depth limit ({}) reached, skipping: {}",
141 MAX_IMPORT_DEPTH,
142 path.display()
143 );
144 return Ok(String::new());
145 }
146
147 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
148 if self.loaded_paths.contains(&canonical) {
149 return Ok(String::new());
150 }
151 self.loaded_paths.insert(canonical.clone());
152
153 let content =
154 tokio::fs::read_to_string(path)
155 .await
156 .map_err(|e| ContextError::Source {
157 message: format!("Failed to read {}: {}", path.display(), e),
158 })?;
159
160 self.current_depth += 1;
161 let result = self
162 .process_imports(&content, path.parent().unwrap_or(Path::new(".")))
163 .await;
164 self.current_depth -= 1;
165
166 result
167 })
168 }
169
170 fn expand_home(path: &str) -> PathBuf {
171 if let Some(rest) = path.strip_prefix("~/")
172 && let Some(home) = crate::common::home_dir()
173 {
174 return home.join(rest);
175 }
176 PathBuf::from(path)
177 }
178
179 fn process_imports<'a>(
180 &'a mut self,
181 content: &'a str,
182 base_dir: &'a Path,
183 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
184 {
185 Box::pin(async move {
186 let mut result = String::new();
187
188 for line in content.lines() {
189 let trimmed = line.trim();
190
191 if trimmed.starts_with('@') && !trimmed.starts_with("@@") {
192 let import_path = trimmed.trim_start_matches('@').trim();
193 if !import_path.is_empty() {
194 let full_path = if import_path.starts_with("~/") {
195 Self::expand_home(import_path)
196 } else if import_path.starts_with('/') {
197 PathBuf::from(import_path)
198 } else {
199 base_dir.join(import_path)
200 };
201
202 if full_path.exists() {
203 match self.load_file_with_imports(&full_path).await {
204 Ok(imported) => {
205 result.push_str(&imported);
206 result.push('\n');
207 }
208 Err(e) => {
209 tracing::warn!("Failed to import {}: {}", import_path, e);
210 result.push_str(line);
211 result.push('\n');
212 }
213 }
214 } else {
215 result.push_str(line);
216 result.push('\n');
217 }
218 } else {
219 result.push_str(line);
220 result.push('\n');
221 }
222 } else {
223 result.push_str(line);
224 result.push('\n');
225 }
226 }
227
228 Ok(result)
229 })
230 }
231}
232
233#[derive(Debug, Default, Clone)]
234pub struct MemoryContent {
235 pub claude_md: Vec<String>,
236 pub local_md: Vec<String>,
237 pub rule_indices: Vec<RuleIndex>,
238}
239
240impl MemoryContent {
241 pub fn combined_claude_md(&self) -> String {
242 self.claude_md
243 .iter()
244 .chain(self.local_md.iter())
245 .filter(|c| !c.trim().is_empty())
246 .cloned()
247 .collect::<Vec<_>>()
248 .join("\n\n")
249 }
250
251 pub fn is_empty(&self) -> bool {
252 self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
253 }
254
255 pub fn merge(&mut self, other: MemoryContent) {
256 self.claude_md.extend(other.claude_md);
257 self.local_md.extend(other.local_md);
258 self.rule_indices.extend(other.rule_indices);
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use tempfile::tempdir;
    use tokio::fs;

    /// Writes `body` to `path`, failing the test on any I/O error.
    async fn write_file(path: impl AsRef<Path>, body: &str) {
        fs::write(path, body).await.unwrap();
    }

    /// Loads memory from `root` with a fresh loader, failing the test on error.
    async fn load_from(root: &Path) -> MemoryContent {
        let mut loader = MemoryLoader::new();
        loader.load(root).await.unwrap()
    }

    #[tokio::test]
    async fn test_load_claude_md() {
        let tmp = tempdir().unwrap();
        write_file(tmp.path().join("CLAUDE.md"), "# Project\nTest content").await;

        let loaded = load_from(tmp.path()).await;

        assert_eq!(loaded.claude_md.len(), 1);
        assert!(loaded.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let tmp = tempdir().unwrap();
        write_file(tmp.path().join("CLAUDE.local.md"), "# Local\nPrivate").await;

        let loaded = load_from(tmp.path()).await;

        assert_eq!(loaded.local_md.len(), 1);
        assert!(loaded.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let tmp = tempdir().unwrap();
        let rules = tmp.path().join(".claude").join("rules");
        let frontend = rules.join("frontend");
        fs::create_dir_all(&frontend).await.unwrap();

        write_file(
            rules.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await;
        write_file(
            frontend.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await;

        let loaded = load_from(tmp.path()).await;
        let has_rule = |wanted: &str| loaded.rule_indices.iter().any(|r| r.name == wanted);

        assert_eq!(loaded.rule_indices.len(), 2);
        assert!(has_rule("rust"));
        assert!(has_rule("react"));
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let tmp = tempdir().unwrap();
        let docs = tmp.path().join("docs");

        write_file(tmp.path().join("CLAUDE.md"), "# Main\n@docs/guidelines.md\nEnd").await;
        fs::create_dir_all(&docs).await.unwrap();
        write_file(docs.join("guidelines.md"), "Imported content").await;

        let combined = load_from(tmp.path()).await.combined_claude_md();

        assert!(combined.contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let tmp = tempdir().unwrap();
        write_file(tmp.path().join("CLAUDE.md"), "Main content").await;
        write_file(tmp.path().join("CLAUDE.local.md"), "Local content").await;

        let combined = load_from(tmp.path()).await.combined_claude_md();

        for expected in ["Main content", "Local content"] {
            assert!(combined.contains(expected));
        }
    }
}