1use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Maximum nesting depth for `@` imports.
///
/// A root memory file is loaded at depth 0. Its imports are at depth 1, and so on.
/// A file reached at a depth greater than this value is skipped with a warning.
pub const MAX_IMPORT_DEPTH: usize = 5;
16
17pub struct MemoryLoader {
31 extractor: ImportExtractor,
32}
33
34impl MemoryLoader {
35 pub fn new() -> Self {
37 Self {
38 extractor: ImportExtractor::new(),
39 }
40 }
41
42 pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
50 let mut content = self.load_shared(start_dir).await?;
51 let local = self.load_local(start_dir).await?;
52 content.merge(local);
53 Ok(content)
54 }
55
56 pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
58 let mut content = MemoryContent::default();
59 let mut visited = HashSet::new();
60
61 for path in Self::find_claude_files(start_dir) {
62 match self.load_with_imports(&path, 0, &mut visited).await {
63 Ok(text) => content.claude_md.push(text),
64 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
65 }
66 }
67
68 let rules_dir = start_dir.join(".claude").join("rules");
69 if rules_dir.exists() {
70 content.rule_indices = self.scan_rules(&rules_dir).await?;
71 }
72
73 Ok(content)
74 }
75
76 pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
78 let mut content = MemoryContent::default();
79 let mut visited = HashSet::new();
80
81 for path in Self::find_local_files(start_dir) {
82 match self.load_with_imports(&path, 0, &mut visited).await {
83 Ok(text) => content.local_md.push(text),
84 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
85 }
86 }
87
88 Ok(content)
89 }
90
91 fn load_with_imports<'a>(
101 &'a self,
102 path: &'a Path,
103 depth: usize,
104 visited: &'a mut HashSet<PathBuf>,
105 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
106 {
107 Box::pin(async move {
108 if depth > MAX_IMPORT_DEPTH {
110 tracing::warn!(
111 "Import depth limit ({}) exceeded, skipping: {}",
112 MAX_IMPORT_DEPTH,
113 path.display()
114 );
115 return Ok(String::new());
116 }
117
118 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
120 if visited.contains(&canonical) {
121 tracing::debug!("Circular import detected, skipping: {}", path.display());
122 return Ok(String::new());
123 }
124 visited.insert(canonical);
125
126 let content =
128 tokio::fs::read_to_string(path)
129 .await
130 .map_err(|e| ContextError::Source {
131 message: format!("Failed to read {}: {}", path.display(), e),
132 })?;
133
134 let base_dir = path.parent().unwrap_or(Path::new("."));
136 let imports = self.extractor.extract(&content, base_dir);
137
138 let mut result = content;
140 for import_path in imports {
141 if import_path.exists() {
142 if let Ok(imported) = self
143 .load_with_imports(&import_path, depth + 1, visited)
144 .await
145 && !imported.is_empty()
146 {
147 result.push_str("\n\n");
148 result.push_str(&imported);
149 }
150 } else {
151 tracing::debug!("Import not found, skipping: {}", import_path.display());
152 }
153 }
154
155 Ok(result)
156 })
157 }
158
159 fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
161 let mut files = Vec::new();
162
163 let claude_md = start_dir.join("CLAUDE.md");
165 if claude_md.exists() {
166 files.push(claude_md);
167 }
168
169 let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
171 if claude_dir_md.exists() {
172 files.push(claude_dir_md);
173 }
174
175 files
176 }
177
178 fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
180 let mut files = Vec::new();
181
182 let local_md = start_dir.join("CLAUDE.local.md");
184 if local_md.exists() {
185 files.push(local_md);
186 }
187
188 let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
190 if local_dir_md.exists() {
191 files.push(local_dir_md);
192 }
193
194 files
195 }
196
197 async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
199 let mut indices = Vec::new();
200 self.scan_rules_recursive(dir, &mut indices).await?;
201 indices.sort_by(|a, b| b.priority.cmp(&a.priority));
202 Ok(indices)
203 }
204
205 fn scan_rules_recursive<'a>(
206 &'a self,
207 dir: &'a Path,
208 indices: &'a mut Vec<RuleIndex>,
209 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
210 Box::pin(async move {
211 let mut entries = tokio::fs::read_dir(dir)
212 .await
213 .map_err(|e| ContextError::Source {
214 message: format!("Failed to read rules directory: {}", e),
215 })?;
216
217 while let Some(entry) =
218 entries
219 .next_entry()
220 .await
221 .map_err(|e| ContextError::Source {
222 message: format!("Failed to read directory entry: {}", e),
223 })?
224 {
225 let path = entry.path();
226
227 if path.is_dir() {
228 self.scan_rules_recursive(&path, indices).await?;
229 } else if path.extension().is_some_and(|e| e == "md")
230 && let Some(index) = RuleIndex::from_file(&path)
231 {
232 indices.push(index);
233 }
234 }
235
236 Ok(())
237 })
238 }
239}
240
241impl Default for MemoryLoader {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
247#[derive(Debug, Default, Clone)]
249pub struct MemoryContent {
250 pub claude_md: Vec<String>,
252 pub local_md: Vec<String>,
254 pub rule_indices: Vec<RuleIndex>,
256}
257
258impl MemoryContent {
259 pub fn combined_claude_md(&self) -> String {
261 self.claude_md
262 .iter()
263 .chain(self.local_md.iter())
264 .filter(|c| !c.trim().is_empty())
265 .cloned()
266 .collect::<Vec<_>>()
267 .join("\n\n")
268 }
269
270 pub fn is_empty(&self) -> bool {
272 self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
273 }
274
275 pub fn merge(&mut self, other: MemoryContent) {
277 self.claude_md.extend(other.claude_md);
278 self.local_md.extend(other.local_md);
279 self.rule_indices.extend(other.rule_indices);
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    #[tokio::test]
    async fn test_load_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.claude_md.len(), 1);
        assert!(content.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_claude_dir_md() {
        let dir = tempdir().unwrap();
        let claude_dir = dir.path().join(".claude");
        fs::create_dir_all(&claude_dir).await.unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Root memory")
            .await
            .unwrap();
        fs::write(claude_dir.join("CLAUDE.md"), "Nested memory")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        // Root-level file is discovered before the one inside `.claude/`.
        assert_eq!(content.claude_md.len(), 2);
        assert!(content.claude_md[0].contains("Root memory"));
        assert!(content.claude_md[1].contains("Nested memory"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let dir = tempdir().unwrap();
        let rules_dir = dir.path().join(".claude").join("rules");
        let sub_dir = rules_dir.join("frontend");
        fs::create_dir_all(&sub_dir).await.unwrap();

        fs::write(
            rules_dir.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await
        .unwrap();

        fs::write(
            sub_dir.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.rule_indices.len(), 2);
        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@docs/guidelines.md\nEnd",
        )
        .await
        .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        let combined = content.combined_claude_md();
        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }

    #[tokio::test]
    async fn test_recursive_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
            .await
            .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
            .await
            .unwrap();
        fs::write(docs_dir.join("detail.md"), "Detail content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Root content"));
        assert!(combined.contains("Guide content"));
        assert!(combined.contains("Detail content"));
    }

    #[tokio::test]
    async fn test_depth_limit() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();

        for i in 1..=6 {
            let content = if i < 6 {
                format!("Level {} @level{}.md", i, i + 1)
            } else {
                format!("Level {}", i)
            };
            fs::write(dir.path().join(format!("level{}.md", i)), content)
                .await
                .unwrap();
        }

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Depth 0 (root) through MAX_IMPORT_DEPTH (5) are loaded; depth 6 is not.
        assert!(combined.contains("Level 0"));
        assert!(combined.contains("Level 5"));
        assert!(!combined.contains("Level 6"));
    }

    #[tokio::test]
    async fn test_circular_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
            .await
            .unwrap();
        fs::write(dir.path().join("a.md"), "A content @b.md")
            .await
            .unwrap();
        fs::write(dir.path().join("b.md"), "B content @a.md")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let result = loader.load(dir.path()).await;

        assert!(result.is_ok());
        let combined = result.unwrap().combined_claude_md();
        assert!(combined.contains("A content"));
        assert!(combined.contains("B content"));
    }

    #[tokio::test]
    async fn test_shared_import_included_once() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root\n@a.md\n@b.md")
            .await
            .unwrap();
        fs::write(dir.path().join("a.md"), "A content\n@shared.md")
            .await
            .unwrap();
        fs::write(dir.path().join("b.md"), "B content\n@shared.md")
            .await
            .unwrap();
        fs::write(dir.path().join("shared.md"), "Shared body")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let combined = loader
            .load(dir.path())
            .await
            .unwrap()
            .combined_claude_md();

        // The visited set persists for the whole pass, so a file reached through two
        // import chains is expanded only once.
        assert!(combined.contains("A content"));
        assert!(combined.contains("B content"));
        assert_eq!(combined.matches("Shared body").count(), 1);
    }

    #[tokio::test]
    async fn test_import_in_code_block_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
        )
        .await
        .unwrap();

        let should_dir = dir.path().join("should");
        fs::create_dir_all(&should_dir).await.unwrap();
        fs::write(should_dir.join("import.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Imported content"));
        assert!(combined.contains("@should/not/import.md"));
    }

    #[tokio::test]
    async fn test_missing_import_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@nonexistent/file.md\nRest of content",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("# Main"));
        assert!(combined.contains("Rest of content"));
    }

    #[tokio::test]
    async fn test_empty_content() {
        let dir = tempdir().unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.is_empty());
        assert!(content.combined_claude_md().is_empty());
    }

    #[test]
    fn test_memory_content_merge() {
        let mut content1 = MemoryContent {
            claude_md: vec!["content1".to_string()],
            local_md: vec!["local1".to_string()],
            rule_indices: vec![],
        };

        let content2 = MemoryContent {
            claude_md: vec!["content2".to_string()],
            local_md: vec!["local2".to_string()],
            rule_indices: vec![],
        };

        content1.merge(content2);

        assert_eq!(content1.claude_md, vec!["content1", "content2"]);
        assert_eq!(content1.local_md, vec!["local1", "local2"]);
    }
}