use std::collections::HashSet;
use std::path::{Path, PathBuf};
use super::import_extractor::ImportExtractor;
use super::rule_index::RuleIndex;
use super::{ContextError, ContextResult};
pub(crate) const DEFAULT_IMPORT_DEPTH: usize = 2;
pub(crate) const MAX_IMPORT_DEPTH: usize = 5;
#[derive(Debug, Clone)]
pub struct MemoryLoaderConfig {
pub max_depth: usize,
}
impl Default for MemoryLoaderConfig {
fn default() -> Self {
Self {
max_depth: DEFAULT_IMPORT_DEPTH,
}
}
}
impl MemoryLoaderConfig {
pub fn full_expansion() -> Self {
Self {
max_depth: MAX_IMPORT_DEPTH,
}
}
pub fn max_depth(max_depth: usize) -> Self {
Self { max_depth }
}
}
pub struct MemoryLoader {
extractor: ImportExtractor,
config: MemoryLoaderConfig,
}
impl MemoryLoader {
pub fn new() -> Self {
Self::from_config(MemoryLoaderConfig::default())
}
pub fn from_config(config: MemoryLoaderConfig) -> Self {
Self {
extractor: ImportExtractor::new(),
config,
}
}
pub fn full_expansion() -> Self {
Self::from_config(MemoryLoaderConfig::full_expansion())
}
pub async fn load(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
let mut content = self.load_shared(start_dir).await?;
let local = self.load_local(start_dir).await?;
content.merge(local);
Ok(content)
}
pub async fn load_shared(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
let mut content = MemoryContent::default();
let mut visited = HashSet::new();
for path in Self::find_claude_files(start_dir) {
match self
.load_with_imports(&path, start_dir, 0, &mut visited)
.await
{
Ok(text) => content.claude_md.push(text),
Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
}
}
let rules_dir = start_dir.join(".claude").join("rules");
if rules_dir.exists() {
content.rule_indices = self.scan_rules(&rules_dir).await?;
}
Ok(content)
}
pub async fn load_local(&self, start_dir: &Path) -> ContextResult<MemoryContent> {
let mut content = MemoryContent::default();
let mut visited = HashSet::new();
for path in Self::find_local_files(start_dir) {
match self
.load_with_imports(&path, start_dir, 0, &mut visited)
.await
{
Ok(text) => content.local_md.push(text),
Err(e) => tracing::debug!("Failed to load {}: {}", path.display(), e),
}
}
Ok(content)
}
fn load_with_imports<'a>(
&'a self,
path: &'a Path,
project_root: &'a Path,
depth: usize,
visited: &'a mut HashSet<PathBuf>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<String>> + Send + 'a>>
{
Box::pin(async move {
if depth > self.config.max_depth {
tracing::warn!(
"Import depth limit ({}) exceeded, skipping: {}",
self.config.max_depth,
path.display()
);
return Ok(String::new());
}
let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
if visited.contains(&canonical) {
tracing::debug!("Circular import detected, skipping: {}", path.display());
return Ok(String::new());
}
visited.insert(canonical);
let content =
tokio::fs::read_to_string(path)
.await
.map_err(|e| ContextError::Source {
message: format!("Failed to read {}: {}", path.display(), e),
})?;
let file_dir = path.parent().unwrap_or(Path::new("."));
let imports = self.extractor.extract(&content, file_dir);
let imports: Vec<PathBuf> = imports
.into_iter()
.map(|p| Self::normalize_project_relative_path(&p, project_root))
.collect();
let mut result = content;
for import_path in imports {
if import_path.exists() {
if let Ok(imported) = self
.load_with_imports(&import_path, project_root, depth + 1, visited)
.await
&& !imported.is_empty()
{
result.push_str("\n\n");
result.push_str(&imported);
}
} else {
tracing::debug!("Import not found, skipping: {}", import_path.display());
}
}
Ok(result)
})
}
fn find_claude_files(start_dir: &Path) -> Vec<PathBuf> {
let mut files = Vec::new();
let claude_md = start_dir.join("CLAUDE.md");
if claude_md.exists() {
files.push(claude_md);
}
let claude_dir_md = start_dir.join(".claude").join("CLAUDE.md");
if claude_dir_md.exists() {
files.push(claude_dir_md);
}
files
}
fn find_local_files(start_dir: &Path) -> Vec<PathBuf> {
let mut files = Vec::new();
let local_md = start_dir.join("CLAUDE.local.md");
if local_md.exists() {
files.push(local_md);
}
let local_dir_md = start_dir.join(".claude").join("CLAUDE.local.md");
if local_dir_md.exists() {
files.push(local_dir_md);
}
files
}
fn normalize_project_relative_path(path: &Path, project_root: &Path) -> PathBuf {
const MARKERS: [&str; 2] = ["/.agents/", "/.claude/"];
let path_str = path.to_string_lossy();
let last_marker_pos = MARKERS
.iter()
.filter_map(|marker| {
let count = path_str.matches(marker).count();
if count > 1 || MARKERS.iter().filter(|m| path_str.contains(*m)).count() > 1 {
path_str.rfind(marker).map(|pos| (pos, *marker))
} else {
None
}
})
.max_by_key(|(pos, _)| *pos);
if let Some((idx, _)) = last_marker_pos {
let relative_part = &path_str[idx + 1..]; project_root.join(relative_part)
} else {
path.to_path_buf()
}
}
async fn scan_rules(&self, dir: &Path) -> ContextResult<Vec<RuleIndex>> {
let mut indices = Vec::new();
self.scan_rules_recursive(dir, &mut indices).await?;
indices.sort_by(|a, b| b.priority.cmp(&a.priority));
Ok(indices)
}
fn scan_rules_recursive<'a>(
&'a self,
dir: &'a Path,
indices: &'a mut Vec<RuleIndex>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ContextResult<()>> + Send + 'a>> {
Box::pin(async move {
let mut entries = tokio::fs::read_dir(dir)
.await
.map_err(|e| ContextError::Source {
message: format!("Failed to read rules directory: {}", e),
})?;
while let Some(entry) =
entries
.next_entry()
.await
.map_err(|e| ContextError::Source {
message: format!("Failed to read directory entry: {}", e),
})?
{
let path = entry.path();
if path.is_dir() {
self.scan_rules_recursive(&path, indices).await?;
} else if path.extension().is_some_and(|e| e == "md")
&& let Some(index) = RuleIndex::from_file(&path)
{
indices.push(index);
}
}
Ok(())
})
}
}
impl Default for MemoryLoader {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Default, Clone)]
pub struct MemoryContent {
pub claude_md: Vec<String>,
pub local_md: Vec<String>,
pub rule_indices: Vec<RuleIndex>,
}
impl MemoryContent {
pub fn combined_claude_md(&self) -> String {
self.claude_md
.iter()
.chain(self.local_md.iter())
.filter(|c| !c.trim().is_empty())
.cloned()
.collect::<Vec<_>>()
.join("\n\n")
}
pub fn is_empty(&self) -> bool {
self.claude_md.is_empty() && self.local_md.is_empty() && self.rule_indices.is_empty()
}
pub fn merge(&mut self, other: MemoryContent) {
self.claude_md.extend(other.claude_md);
self.local_md.extend(other.local_md);
self.rule_indices.extend(other.rule_indices);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
use tokio::fs;
#[tokio::test]
async fn test_load_claude_md() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "# Project\nTest content")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
assert_eq!(content.claude_md.len(), 1);
assert!(content.claude_md[0].contains("Test content"));
}
#[tokio::test]
async fn test_load_local_md() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.local.md"), "# Local\nPrivate")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
assert_eq!(content.local_md.len(), 1);
assert!(content.local_md[0].contains("Private"));
}
#[tokio::test]
async fn test_scan_rules_recursive() {
let dir = tempdir().unwrap();
let rules_dir = dir.path().join(".claude").join("rules");
let sub_dir = rules_dir.join("frontend");
fs::create_dir_all(&sub_dir).await.unwrap();
fs::write(
rules_dir.join("rust.md"),
"---\npaths: **/*.rs\npriority: 10\n---\n\n# Rust Rules",
)
.await
.unwrap();
fs::write(
sub_dir.join("react.md"),
"---\npaths: **/*.tsx\npriority: 5\n---\n\n# React Rules",
)
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
assert_eq!(content.rule_indices.len(), 2);
assert!(content.rule_indices.iter().any(|r| r.name == "rust"));
assert!(content.rule_indices.iter().any(|r| r.name == "react"));
}
#[tokio::test]
async fn test_import_syntax() {
let dir = tempdir().unwrap();
fs::write(
dir.path().join("CLAUDE.md"),
"# Main\n@docs/guidelines.md\nEnd",
)
.await
.unwrap();
let docs_dir = dir.path().join("docs");
fs::create_dir_all(&docs_dir).await.unwrap();
fs::write(docs_dir.join("guidelines.md"), "Imported content")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
assert!(content.combined_claude_md().contains("Imported content"));
}
#[tokio::test]
async fn test_combined_includes_local() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Main content")
.await
.unwrap();
fs::write(dir.path().join("CLAUDE.local.md"), "Local content")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Main content"));
assert!(combined.contains("Local content"));
}
#[tokio::test]
async fn test_recursive_import() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Root content @docs/guide.md")
.await
.unwrap();
let docs_dir = dir.path().join("docs");
fs::create_dir_all(&docs_dir).await.unwrap();
fs::write(docs_dir.join("guide.md"), "Guide content @detail.md")
.await
.unwrap();
fs::write(docs_dir.join("detail.md"), "Detail content")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Root content"));
assert!(combined.contains("Guide content"));
assert!(combined.contains("Detail content"));
}
#[tokio::test]
async fn test_depth_limit_default() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
.await
.unwrap();
for i in 1..=3 {
let content = if i < 3 {
format!("Level {} @level{}.md", i, i + 1)
} else {
format!("Level {}", i)
};
fs::write(dir.path().join(format!("level{}.md", i)), content)
.await
.unwrap();
}
let loader = MemoryLoader::new(); let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Level 0"));
assert!(combined.contains("Level 2"));
assert!(!combined.contains("Level 3"));
}
#[tokio::test]
async fn test_depth_limit_full_expansion() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
.await
.unwrap();
for i in 1..=6 {
let content = if i < 6 {
format!("Level {} @level{}.md", i, i + 1)
} else {
format!("Level {}", i)
};
fs::write(dir.path().join(format!("level{}.md", i)), content)
.await
.unwrap();
}
let loader = MemoryLoader::full_expansion(); let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Level 0"));
assert!(combined.contains("Level 5"));
assert!(!combined.contains("Level 6"));
}
#[tokio::test]
async fn test_depth_limit_custom() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Level 0 @level1.md")
.await
.unwrap();
for i in 1..=3 {
let content = if i < 3 {
format!("Level {} @level{}.md", i, i + 1)
} else {
format!("Level {}", i)
};
fs::write(dir.path().join(format!("level{}.md", i)), content)
.await
.unwrap();
}
let loader = MemoryLoader::from_config(MemoryLoaderConfig::max_depth(1));
let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Level 0"));
assert!(combined.contains("Level 1"));
assert!(!combined.contains("Level 2"));
}
#[tokio::test]
async fn test_circular_import() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("CLAUDE.md"), "Root @a.md")
.await
.unwrap();
fs::write(dir.path().join("a.md"), "A content @b.md")
.await
.unwrap();
fs::write(dir.path().join("b.md"), "B content @a.md")
.await
.unwrap();
let loader = MemoryLoader::new();
let result = loader.load(dir.path()).await;
assert!(result.is_ok());
let combined = result.unwrap().combined_claude_md();
assert!(combined.contains("A content"));
assert!(combined.contains("B content"));
}
#[tokio::test]
async fn test_import_in_code_block_ignored() {
let dir = tempdir().unwrap();
fs::write(
dir.path().join("CLAUDE.md"),
"# Example\n```\n@should/not/import.md\n```\n@should/import.md",
)
.await
.unwrap();
fs::write(
dir.path().join("should").join("import.md"),
"This is imported",
)
.await
.ok();
let should_dir = dir.path().join("should");
fs::create_dir_all(&should_dir).await.unwrap();
fs::write(should_dir.join("import.md"), "Imported content")
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("Imported content"));
assert!(combined.contains("@should/not/import.md"));
}
#[tokio::test]
async fn test_missing_import_ignored() {
let dir = tempdir().unwrap();
fs::write(
dir.path().join("CLAUDE.md"),
"# Main\n@nonexistent/file.md\nRest of content",
)
.await
.unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
let combined = content.combined_claude_md();
assert!(combined.contains("# Main"));
assert!(combined.contains("Rest of content"));
}
#[tokio::test]
async fn test_empty_content() {
let dir = tempdir().unwrap();
let loader = MemoryLoader::new();
let content = loader.load(dir.path()).await.unwrap();
assert!(content.is_empty());
assert!(content.combined_claude_md().is_empty());
}
#[tokio::test]
async fn test_memory_content_merge() {
let mut content1 = MemoryContent {
claude_md: vec!["content1".to_string()],
local_md: vec!["local1".to_string()],
rule_indices: vec![],
};
let content2 = MemoryContent {
claude_md: vec!["content2".to_string()],
local_md: vec!["local2".to_string()],
rule_indices: vec![],
};
content1.merge(content2);
assert_eq!(content1.claude_md.len(), 2);
assert_eq!(content1.local_md.len(), 2);
}
}