1use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Import nesting depth used by [`MemoryLoaderConfig::default`] (and therefore
/// [`MemoryLoader::new`]). Top-level memory files are depth 0.
pub const DEFAULT_IMPORT_DEPTH: usize = 2;

/// Import nesting depth used by [`MemoryLoaderConfig::full_expansion`].
pub const MAX_IMPORT_DEPTH: usize = 5;
21
22#[derive(Debug, Clone)]
24pub struct MemoryLoaderConfig {
25 pub max_depth: usize,
28}
29
30impl Default for MemoryLoaderConfig {
31 fn default() -> Self {
32 Self {
33 max_depth: DEFAULT_IMPORT_DEPTH,
34 }
35 }
36}
37
38impl MemoryLoaderConfig {
39 pub fn full_expansion() -> Self {
41 Self {
42 max_depth: MAX_IMPORT_DEPTH,
43 }
44 }
45
46 pub fn with_max_depth(max_depth: usize) -> Self {
48 Self { max_depth }
49 }
50}
51
52pub struct MemoryLoader {
66 extractor: ImportExtractor,
67 config: MemoryLoaderConfig,
68}
69
70impl MemoryLoader {
71 pub fn new() -> Self {
73 Self::with_config(MemoryLoaderConfig::default())
74 }
75
76 pub fn with_config(config: MemoryLoaderConfig) -> Self {
78 Self {
79 extractor: ImportExtractor::new(),
80 config,
81 }
82 }
83
84 pub fn full_expansion() -> Self {
86 Self::with_config(MemoryLoaderConfig::full_expansion())
87 }
88
89 pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
97 let mut content = self.load_shared(start_dir).await?;
98 let local = self.load_local(start_dir).await?;
99 content.merge(local);
100 Ok(content)
101 }
102
103 pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
105 let mut content = MemoryContent::default();
106 let mut visited = HashSet::new();
107
108 for path in Self::find_claude_files(start_dir) {
109 match self
110 .load_with_imports(&path, start_dir, 0, &mut visited)
111 .await
112 {
113 Ok(text) => content.claude_md.push(text),
114 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
115 }
116 }
117
118 let rules_dir = start_dir.join(".claude").join("rules");
119 if rules_dir.exists() {
120 content.rule_indices = self.scan_rules(&rules_dir).await?;
121 }
122
123 Ok(content)
124 }
125
126 pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
128 let mut content = MemoryContent::default();
129 let mut visited = HashSet::new();
130
131 for path in Self::find_local_files(start_dir) {
132 match self
133 .load_with_imports(&path, start_dir, 0, &mut visited)
134 .await
135 {
136 Ok(text) => content.local_md.push(text),
137 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
138 }
139 }
140
141 Ok(content)
142 }
143
144 fn load_with_imports<'a>(
155 &'a self,
156 path: &'a Path,
157 project_root: &'a Path,
158 depth: usize,
159 visited: &'a mut HashSet<PathBuf>,
160 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
161 {
162 Box::pin(async move {
163 if depth > self.config.max_depth {
165 tracing::warn!(
166 "Import depth limit ({}) exceeded, skipping: {}",
167 self.config.max_depth,
168 path.display()
169 );
170 return Ok(String::new());
171 }
172
173 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
175 if visited.contains(&canonical) {
176 tracing::debug!("Circular import detected, skipping: {}", path.display());
177 return Ok(String::new());
178 }
179 visited.insert(canonical);
180
181 let content =
183 tokio::fs::read_to_string(path)
184 .await
185 .map_err(|e| ContextError::Source {
186 message: format!("Failed to read {}: {}", path.display(), e),
187 })?;
188
189 let file_dir = path.parent().unwrap_or(Path::new("."));
192 let imports = self.extractor.extract(&content, file_dir);
193
194 let imports: Vec<PathBuf> = imports
197 .into_iter()
198 .map(|p| Self::normalize_project_relative_path(&p, project_root))
199 .collect();
200
201 let mut result = content;
203 for import_path in imports {
204 if import_path.exists() {
205 if let Ok(imported) = self
206 .load_with_imports(&import_path, project_root, depth + 1, visited)
207 .await
208 && !imported.is_empty()
209 {
210 result.push_str("\n\n");
211 result.push_str(&imported);
212 }
213 } else {
214 tracing::debug!("Import not found, skipping: {}", import_path.display());
215 }
216 }
217
218 Ok(result)
219 })
220 }
221
222 fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
224 let mut files = Vec::new();
225
226 let claude_md = start_dir.join("CLAUDE.md");
228 if claude_md.exists() {
229 files.push(claude_md);
230 }
231
232 let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
234 if claude_dir_md.exists() {
235 files.push(claude_dir_md);
236 }
237
238 files
239 }
240
241 fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
243 let mut files = Vec::new();
244
245 let local_md = start_dir.join("CLAUDE.local.md");
247 if local_md.exists() {
248 files.push(local_md);
249 }
250
251 let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
253 if local_dir_md.exists() {
254 files.push(local_dir_md);
255 }
256
257 files
258 }
259
260 fn normalize_project_relative_path(path: &Path, project_root: &Path) -> PathBuf {
268 const MARKERS: [&str; 2] = ["/.agents/", "/.claude/"];
269
270 let path_str = path.to_string_lossy();
271
272 let last_marker_pos = MARKERS
274 .iter()
275 .filter_map(|marker| {
276 let count = path_str.matches(marker).count();
277 if count > 1 || MARKERS.iter().filter(|m| path_str.contains(*m)).count() > 1 {
279 path_str.rfind(marker).map(|pos| (pos, *marker))
280 } else {
281 None
282 }
283 })
284 .max_by_key(|(pos, _)| *pos);
285
286 if let Some((idx, _)) = last_marker_pos {
287 let relative_part = &path_str[idx + 1..]; project_root.join(relative_part)
289 } else {
290 path.to_path_buf()
291 }
292 }
293
294 async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
296 let mut indices = Vec::new();
297 self.scan_rules_recursive(dir, &mut indices).await?;
298 indices.sort_by(|a, b| b.priority.cmp(&a.priority));
299 Ok(indices)
300 }
301
302 fn scan_rules_recursive<'a>(
303 &'a self,
304 dir: &'a Path,
305 indices: &'a mut Vec<RuleIndex>,
306 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
307 Box::pin(async move {
308 let mut entries = tokio::fs::read_dir(dir)
309 .await
310 .map_err(|e| ContextError::Source {
311 message: format!("Failed to read rules directory: {}", e),
312 })?;
313
314 while let Some(entry) =
315 entries
316 .next_entry()
317 .await
318 .map_err(|e| ContextError::Source {
319 message: format!("Failed to read directory entry: {}", e),
320 })?
321 {
322 let path = entry.path();
323
324 if path.is_dir() {
325 self.scan_rules_recursive(&path, indices).await?;
326 } else if path.extension().is_some_and(|e| e == "md")
327 && let Some(index) = RuleIndex::from_file(&path)
328 {
329 indices.push(index);
330 }
331 }
332
333 Ok(())
334 })
335 }
336}
337
338impl Default for MemoryLoader {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344#[derive(Debug, Default, Clone)]
346pub struct MemoryContent {
347 pub claude_md: Vec<String>,
349 pub local_md: Vec<String>,
351 pub rule_indices: Vec<RuleIndex>,
353}
354
355impl MemoryContent {
356 pub fn combined_claude_md(&self) -> String {
358 self.claude_md
359 .iter()
360 .chain(self.local_md.iter())
361 .filter(|c| !c.trim().is_empty())
362 .cloned()
363 .collect::<Vec<_>>()
364 .join("\n\n")
365 }
366
367 pub fn is_empty(&self) -> bool {
369 self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
370 }
371
372 pub fn merge(&mut self, other: MemoryContent) {
374 self.claude_md.extend(other.claude_md);
375 self.local_md.extend(other.local_md);
376 self.rule_indices.extend(other.rule_indices);
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    /// Writes `level1.md` ..= `level{last}.md` into `dir`. Each file imports the
    /// next; the last one imports nothing.
    async fn write_import_chain(dir: &std::path::Path, last: usize) {
        for i in 1..=last {
            let content = if i < last {
                format!("Level {} @level{}.md", i, i + 1)
            } else {
                format!("Level {}", i)
            };
            fs::write(dir.join(format!("level{}.md", i)), content)
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_load_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.claude_md.len(), 1);
        assert!(content.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let dir = tempdir().unwrap();
        let rules_dir = dir.path().join(".claude").join("rules");
        let sub_dir = rules_dir.join("frontend");
        fs::create_dir_all(&sub_dir).await.unwrap();

        fs::write(
            rules_dir.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await
        .unwrap();

        fs::write(
            sub_dir.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.rule_indices.len(), 2);
        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@docs/guidelines.md\nEnd",
        )
        .await
        .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        let combined = content.combined_claude_md();
        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }

    #[tokio::test]
    async fn test_recursive_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
            .await
            .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
            .await
            .unwrap();
        fs::write(docs_dir.join("detail.md"), "Detail content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Root content"));
        assert!(combined.contains("Guide content"));
        assert!(combined.contains("Detail content"));
    }

    #[tokio::test]
    async fn test_depth_limit_default() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();
        write_import_chain(dir.path(), 3).await;

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Default depth is 2: levels 0..=2 are inlined, level 3 is not.
        assert!(combined.contains("Level 0"));
        assert!(combined.contains("Level 1"));
        assert!(combined.contains("Level 2"));
        assert!(!combined.contains("Level 3"));
    }

    #[tokio::test]
    async fn test_depth_limit_full_expansion() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();
        write_import_chain(dir.path(), 6).await;

        let loader = MemoryLoader::full_expansion();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Full expansion follows 5 levels: 0..=5 are inlined, level 6 is not.
        for level in 0..=5 {
            assert!(combined.contains(&format!("Level {}", level)));
        }
        assert!(!combined.contains("Level 6"));
    }

    #[tokio::test]
    async fn test_depth_limit_custom() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();
        write_import_chain(dir.path(), 3).await;

        let loader = MemoryLoader::with_config(MemoryLoaderConfig::with_max_depth(1));
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Level 0"));
        assert!(combined.contains("Level 1"));
        assert!(!combined.contains("Level 2"));
    }

    #[tokio::test]
    async fn test_circular_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
            .await
            .unwrap();
        fs::write(dir.path().join("a.md"), "A content @b.md")
            .await
            .unwrap();
        fs::write(dir.path().join("b.md"), "B content @a.md")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let result = loader.load(dir.path()).await;

        // The a -> b -> a cycle must terminate instead of recursing forever.
        assert!(result.is_ok());
        let combined = result.unwrap().combined_claude_md();
        assert!(combined.contains("Root"));
        assert!(combined.contains("A content"));
        assert!(combined.contains("B content"));
    }

    #[tokio::test]
    async fn test_import_in_code_block_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
        )
        .await
        .unwrap();

        let should_dir = dir.path().join("should");
        fs::create_dir_all(&should_dir).await.unwrap();
        fs::write(should_dir.join("import.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Imported content"));
        // The fenced reference is kept verbatim rather than being resolved.
        assert!(combined.contains("@should/not/import.md"));
    }

    #[tokio::test]
    async fn test_missing_import_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@nonexistent/file.md\nRest of content",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("# Main"));
        assert!(combined.contains("Rest of content"));
    }

    #[tokio::test]
    async fn test_empty_content() {
        let dir = tempdir().unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.is_empty());
        assert!(content.combined_claude_md().is_empty());
    }

    #[tokio::test]
    async fn test_memory_content_merge() {
        let mut content1 = MemoryContent {
            claude_md: vec!["content1".to_string()],
            local_md: vec!["local1".to_string()],
            rule_indices: vec![],
        };

        let content2 = MemoryContent {
            claude_md: vec!["content2".to_string()],
            local_md: vec!["local2".to_string()],
            rule_indices: vec![],
        };

        content1.merge(content2);

        assert_eq!(content1.claude_md.len(), 2);
        assert_eq!(content1.local_md.len(), 2);
    }
}
710}