1use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10use super::import_extractor::ImportExtractor;
11use super::rule_index::RuleIndex;
12use super::{ContextError, ContextResult};
13
/// Import depth used by [`MemoryLoaderConfig::default`]: the entry file itself
/// plus up to two levels of nested `@` imports.
pub(crate) const DEFAULT_IMPORT_DEPTH: usize = 2;

/// Import depth used by [`MemoryLoaderConfig::full_expansion`].
pub(crate) const MAX_IMPORT_DEPTH: usize = 5;

/// Settings controlling how `@` imports are expanded while loading memory files.
#[derive(Debug, Clone)]
pub struct MemoryLoaderConfig {
    /// Maximum number of import hops followed from an entry file (the entry file
    /// is hop 0); files nested deeper than this are skipped.
    pub max_depth: usize,
}

impl Default for MemoryLoaderConfig {
    /// Uses [`DEFAULT_IMPORT_DEPTH`].
    fn default() -> Self {
        Self::max_depth(DEFAULT_IMPORT_DEPTH)
    }
}

impl MemoryLoaderConfig {
    /// Configuration that follows imports up to [`MAX_IMPORT_DEPTH`] hops.
    pub fn full_expansion() -> Self {
        Self::max_depth(MAX_IMPORT_DEPTH)
    }

    /// Configuration with a caller-chosen import depth limit.
    ///
    /// The value is stored as given; it is not clamped to [`MAX_IMPORT_DEPTH`].
    pub fn max_depth(max_depth: usize) -> Self {
        Self { max_depth }
    }
}
51
52pub struct MemoryLoader {
66 extractor: ImportExtractor,
67 config: MemoryLoaderConfig,
68}
69
70impl MemoryLoader {
71 pub fn new() -> Self {
73 Self::from_config(MemoryLoaderConfig::default())
74 }
75
76 pub fn from_config(config: MemoryLoaderConfig) -> Self {
78 Self {
79 extractor: ImportExtractor::new(),
80 config,
81 }
82 }
83
84 pub fn full_expansion() -> Self {
86 Self::from_config(MemoryLoaderConfig::full_expansion())
87 }
88
89 pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
97 let mut content = self.load_shared(start_dir).await?;
98 let local = self.load_local(start_dir).await?;
99 content.merge(local);
100 Ok(content)
101 }
102
103 pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
105 let mut content = MemoryContent::default();
106 let mut visited = HashSet::new();
107
108 for path in Self::find_claude_files(start_dir) {
109 match self
110 .load_with_imports(&path, start_dir, 0, &mut visited)
111 .await
112 {
113 Ok(text) => content.claude_md.push(text),
114 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
115 }
116 }
117
118 let rules_dir = start_dir.join(".claude").join("rules");
119 if rules_dir.exists() {
120 content.rule_indices = self.scan_rules(&rules_dir).await?;
121 }
122
123 Ok(content)
124 }
125
126 pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
128 let mut content = MemoryContent::default();
129 let mut visited = HashSet::new();
130
131 for path in Self::find_local_files(start_dir) {
132 match self
133 .load_with_imports(&path, start_dir, 0, &mut visited)
134 .await
135 {
136 Ok(text) => content.local_md.push(text),
137 Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
138 }
139 }
140
141 Ok(content)
142 }
143
144 fn load_with_imports<'a>(
155 &'a self,
156 path: &'a Path,
157 project_root: &'a Path,
158 depth: usize,
159 visited: &'a mut HashSet<PathBuf>,
160 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
161 {
162 Box::pin(async move {
163 if depth > self.config.max_depth {
165 tracing::warn!(
166 "Import depth limit ({}) exceeded, skipping: {}",
167 self.config.max_depth,
168 path.display()
169 );
170 return Ok(String::new());
171 }
172
173 let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
175 if visited.contains(&canonical) {
176 tracing::debug!("Circular import detected, skipping: {}", path.display());
177 return Ok(String::new());
178 }
179 visited.insert(canonical);
180
181 let content =
182 tokio::fs::read_to_string(path)
183 .await
184 .map_err(|e| ContextError::Source {
185 message: format!("Failed to read {}: {}", path.display(), e),
186 })?;
187
188 let file_dir = path.parent().unwrap_or(Path::new("."));
190 let imports = self.extractor.extract(&content, file_dir);
191
192 let imports: Vec<PathBuf> = imports
195 .into_iter()
196 .map(|p| Self::normalize_project_relative_path(&p, project_root))
197 .collect();
198
199 let mut result = content;
200 for import_path in imports {
201 if import_path.exists() {
202 if let Ok(imported) = self
203 .load_with_imports(&import_path, project_root, depth + 1, visited)
204 .await
205 && !imported.is_empty()
206 {
207 result.push_str("\n\n");
208 result.push_str(&imported);
209 }
210 } else {
211 tracing::debug!("Import not found, skipping: {}", import_path.display());
212 }
213 }
214
215 Ok(result)
216 })
217 }
218
219 fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
221 let mut files = Vec::new();
222
223 let claude_md = start_dir.join("CLAUDE.md");
225 if claude_md.exists() {
226 files.push(claude_md);
227 }
228
229 let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
231 if claude_dir_md.exists() {
232 files.push(claude_dir_md);
233 }
234
235 files
236 }
237
238 fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
240 let mut files = Vec::new();
241
242 let local_md = start_dir.join("CLAUDE.local.md");
244 if local_md.exists() {
245 files.push(local_md);
246 }
247
248 let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
250 if local_dir_md.exists() {
251 files.push(local_dir_md);
252 }
253
254 files
255 }
256
257 fn normalize_project_relative_path(path: &Path, project_root: &Path) -> PathBuf {
265 const MARKERS: [&str; 2] = ["/.agents/", "/.claude/"];
266
267 let path_str = path.to_string_lossy();
268
269 let last_marker_pos = MARKERS
271 .iter()
272 .filter_map(|marker| {
273 let count = path_str.matches(marker).count();
274 if count > 1 || MARKERS.iter().filter(|m| path_str.contains(*m)).count() > 1 {
276 path_str.rfind(marker).map(|pos| (pos, *marker))
277 } else {
278 None
279 }
280 })
281 .max_by_key(|(pos, _)| *pos);
282
283 if let Some((idx, _)) = last_marker_pos {
284 let relative_part = &path_str[idx + 1..]; project_root.join(relative_part)
286 } else {
287 path.to_path_buf()
288 }
289 }
290
291 async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
293 let mut indices = Vec::new();
294 self.scan_rules_recursive(dir, &mut indices).await?;
295 indices.sort_by(|a, b| b.priority.cmp(&a.priority));
296 Ok(indices)
297 }
298
299 fn scan_rules_recursive<'a>(
300 &'a self,
301 dir: &'a Path,
302 indices: &'a mut Vec<RuleIndex>,
303 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
304 Box::pin(async move {
305 let mut entries = tokio::fs::read_dir(dir)
306 .await
307 .map_err(|e| ContextError::Source {
308 message: format!("Failed to read rules directory: {}", e),
309 })?;
310
311 while let Some(entry) =
312 entries
313 .next_entry()
314 .await
315 .map_err(|e| ContextError::Source {
316 message: format!("Failed to read directory entry: {}", e),
317 })?
318 {
319 let path = entry.path();
320
321 if path.is_dir() {
322 self.scan_rules_recursive(&path, indices).await?;
323 } else if path.extension().is_some_and(|e| e == "md")
324 && let Some(index) = RuleIndex::from_file(&path)
325 {
326 indices.push(index);
327 }
328 }
329
330 Ok(())
331 })
332 }
333}
334
335impl Default for MemoryLoader {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341#[derive(Debug, Default, Clone)]
343pub struct MemoryContent {
344 pub claude_md: Vec<String>,
346 pub local_md: Vec<String>,
348 pub rule_indices: Vec<RuleIndex>,
350}
351
352impl MemoryContent {
353 pub fn combined_claude_md(&self) -> String {
355 self.claude_md
356 .iter()
357 .chain(self.local_md.iter())
358 .filter(|c| !c.trim().is_empty())
359 .cloned()
360 .collect::<Vec<_>>()
361 .join("\n\n")
362 }
363
364 pub fn is_empty(&self) -> bool {
366 self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
367 }
368
369 pub fn merge(&mut self, other: MemoryContent) {
371 self.claude_md.extend(other.claude_md);
372 self.local_md.extend(other.local_md);
373 self.rule_indices.extend(other.rule_indices);
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::fs;

    /// Writes `CLAUDE.md` ("Level 0") plus `level1.md..=level{last}.md` into
    /// `dir`, where each level imports the next and the last imports nothing.
    async fn write_import_chain(dir: &Path, last: usize) {
        fs::write(dir.join("CLAUDE.md"), "Level 0 @level1.md")
            .await
            .unwrap();
        for i in 1..=last {
            let content = if i < last {
                format!("Level {} @level{}.md", i, i + 1)
            } else {
                format!("Level {}", i)
            };
            fs::write(dir.join(format!("level{}.md", i)), content)
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_load_claude_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.claude_md.len(), 1);
        assert!(content.claude_md[0].contains("Test content"));
    }

    #[tokio::test]
    async fn test_load_local_md() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.local_md.len(), 1);
        assert!(content.local_md[0].contains("Private"));
    }

    #[tokio::test]
    async fn test_scan_rules_recursive() {
        let dir = tempdir().unwrap();
        let rules_dir = dir.path().join(".claude").join("rules");
        let sub_dir = rules_dir.join("frontend");
        fs::create_dir_all(&sub_dir).await.unwrap();

        fs::write(
            rules_dir.join("rust.md"),
            "---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
        )
        .await
        .unwrap();

        fs::write(
            sub_dir.join("react.md"),
            "---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert_eq!(content.rule_indices.len(), 2);
        assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
        assert!(content.rule_indices.iter().any(|r| r.name == "react"));
    }

    #[tokio::test]
    async fn test_import_syntax() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@docs/guidelines.md\nEnd",
        )
        .await
        .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guidelines.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.combined_claude_md().contains("Imported content"));
    }

    #[tokio::test]
    async fn test_combined_includes_local() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("CLAUDE.md"), "Main content")
            .await
            .unwrap();
        fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        let combined = content.combined_claude_md();
        assert!(combined.contains("Main content"));
        assert!(combined.contains("Local content"));
    }

    #[tokio::test]
    async fn test_recursive_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
            .await
            .unwrap();

        let docs_dir = dir.path().join("docs");
        fs::create_dir_all(&docs_dir).await.unwrap();
        fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
            .await
            .unwrap();
        fs::write(docs_dir.join("detail.md"), "Detail content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Root content"));
        assert!(combined.contains("Guide content"));
        assert!(combined.contains("Detail content"));
    }

    #[tokio::test]
    async fn test_depth_limit_default() {
        let dir = tempdir().unwrap();
        write_import_chain(dir.path(), 3).await;

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        // Default depth 2: the entry file (hop 0) plus two import hops.
        for level in 0..=2 {
            assert!(combined.contains(&format!("Level {}", level)));
        }
        assert!(!combined.contains("Level 3"));
    }

    #[tokio::test]
    async fn test_depth_limit_full_expansion() {
        let dir = tempdir().unwrap();
        write_import_chain(dir.path(), 6).await;

        let loader = MemoryLoader::full_expansion();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        for level in 0..=5 {
            assert!(combined.contains(&format!("Level {}", level)));
        }
        assert!(!combined.contains("Level 6"));
    }

    #[tokio::test]
    async fn test_depth_limit_custom() {
        let dir = tempdir().unwrap();
        write_import_chain(dir.path(), 3).await;

        let loader = MemoryLoader::from_config(MemoryLoaderConfig::max_depth(1));
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Level 0"));
        assert!(combined.contains("Level 1"));
        assert!(!combined.contains("Level 2"));
    }

    #[tokio::test]
    async fn test_circular_import() {
        let dir = tempdir().unwrap();

        fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
            .await
            .unwrap();
        fs::write(dir.path().join("a.md"), "A content @b.md")
            .await
            .unwrap();
        fs::write(dir.path().join("b.md"), "B content @a.md")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let result = loader.load(dir.path()).await;

        // The a -> b -> a cycle must terminate and still include both files once.
        assert!(result.is_ok());
        let combined = result.unwrap().combined_claude_md();
        assert!(combined.contains("A content"));
        assert!(combined.contains("B content"));
    }

    #[tokio::test]
    async fn test_import_in_code_block_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Example\n```\n@should/not/import.md\n```\n@should/import.md",
        )
        .await
        .unwrap();

        let should_dir = dir.path().join("should");
        fs::create_dir_all(&should_dir).await.unwrap();
        fs::write(should_dir.join("import.md"), "Imported content")
            .await
            .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("Imported content"));
        // The fenced reference stays as literal text rather than being expanded.
        assert!(combined.contains("@should/not/import.md"));
    }

    #[tokio::test]
    async fn test_missing_import_ignored() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("CLAUDE.md"),
            "# Main\n@nonexistent/file.md\nRest of content",
        )
        .await
        .unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();
        let combined = content.combined_claude_md();

        assert!(combined.contains("# Main"));
        assert!(combined.contains("Rest of content"));
    }

    #[tokio::test]
    async fn test_empty_content() {
        let dir = tempdir().unwrap();

        let loader = MemoryLoader::new();
        let content = loader.load(dir.path()).await.unwrap();

        assert!(content.is_empty());
        assert!(content.combined_claude_md().is_empty());
    }

    #[test]
    fn test_normalize_repeated_marker() {
        let root = Path::new("/project");

        assert_eq!(
            MemoryLoader::normalize_project_relative_path(
                Path::new("/project/.claude/docs/.claude/x.md"),
                root
            ),
            PathBuf::from("/project/.claude/x.md")
        );
        assert_eq!(
            MemoryLoader::normalize_project_relative_path(
                Path::new("/project/.agents/a/.claude/b.md"),
                root
            ),
            PathBuf::from("/project/.claude/b.md")
        );
        // A single marker is an ordinary project path and is left untouched.
        assert_eq!(
            MemoryLoader::normalize_project_relative_path(
                Path::new("/project/.claude/rules/x.md"),
                root
            ),
            PathBuf::from("/project/.claude/rules/x.md")
        );
    }

    #[test]
    fn test_memory_content_merge() {
        let mut content1 = MemoryContent {
            claude_md: vec!["content1".to_string()],
            local_md: vec!["local1".to_string()],
            rule_indices: vec![],
        };

        let content2 = MemoryContent {
            claude_md: vec!["content2".to_string()],
            local_md: vec!["local2".to_string()],
            rule_indices: vec![],
        };

        content1.merge(content2);

        assert_eq!(content1.claude_md, vec!["content1", "content2"]);
        assert_eq!(content1.local_md, vec!["local1", "local2"]);
    }
}