// mermaid_cli/context/loader.rs

//! Context loader orchestration.
//!
//! Orchestrates file collection, token counting, and project detection
//! to build a complete project context.
5use anyhow::{Context, Result};
6use rayon::prelude::*;
7use std::sync::{Arc, Mutex};
8use tiktoken_rs::{cl100k_base, CoreBPE};
9
10use super::file_collector::{CollectorConfig, FileCollector};
11use super::project_detector::{FileLoader, ProjectDetector};
12use super::token_counter::TokenCounter;
13use crate::models::ProjectContext;
14use crate::utils::MutexExt;
15
// Built-in defaults are kept as `'static` slices. They cost nothing until a
// `LoaderConfig` copies them into its own vectors.

/// File extensions (without the leading dot) that the loader prioritizes.
const DEFAULT_PRIORITY_EXTENSIONS: &[&str] = &[
    // Programming and scripting languages
    "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "cpp", "c", "h", "hpp", "cs", "rb", "php",
    "swift", "kt", "scala", "r", "sql", "sh",
    // Data and configuration formats
    "yaml", "yml", "toml", "json", "xml",
    // Web markup and stylesheets
    "html", "css", "scss",
    // Documentation and plain text
    "md", "txt",
];

/// Glob-style patterns for logs, build artifacts, binaries and media that the
/// loader asks the collector to ignore.
const DEFAULT_IGNORE_PATTERNS: &[&str] = &[
    // Logs, temporaries, caches
    "*.log", "*.tmp", "*.cache",
    // Python bytecode and extension modules
    "*.pyc", "*.pyo", "*.pyd",
    // Native libraries, executables and object files
    "*.so", "*.dylib", "*.dll", "*.exe", "*.o", "*.a", "*.lib",
    // Images
    "*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp", "*.ico", "*.svg",
    // Documents and archives
    "*.pdf", "*.zip", "*.tar", "*.gz", "*.rar", "*.7z",
];
28
/// Running totals used to enforce the file-count and token budgets while loading.
///
/// The type has no internal synchronization. To share it between threads, put it
/// behind a lock (the loader wraps it in `Arc<Mutex<_>>`).
#[derive(Debug, Clone, Default)]
struct LoadingState {
    /// Number of files accepted so far.
    files_loaded: usize,
    /// Sum of the token counts of all accepted files.
    tokens_used: usize,
}

impl LoadingState {
    /// Creates an empty state with both counters at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Accepts a file costing `tokens` if doing so stays within both budgets.
    ///
    /// When the file fits, both counters are updated and the method returns `true`.
    /// It returns `false` and leaves the state untouched in either of these cases:
    /// - `max_files` has already been reached;
    /// - the new token total would exceed `max_tokens` or overflow `usize`.
    ///
    /// The check and the update happen in a single `&mut self` call. A caller that
    /// holds a lock around this call therefore gets an atomic check-and-reserve.
    fn try_add_file(&mut self, tokens: usize, max_files: usize, max_tokens: usize) -> bool {
        if self.files_loaded >= max_files {
            return false;
        }

        // `checked_add` avoids a debug-build overflow panic on pathological counts.
        let new_total = match self.tokens_used.checked_add(tokens) {
            Some(total) if total <= max_tokens => total,
            _ => return false,
        };

        // Cannot overflow: files_loaded < max_files <= usize::MAX.
        self.files_loaded += 1;
        self.tokens_used = new_total;
        true
    }
}
60
61/// Configuration for the context loader
62#[derive(Debug, Clone)]
63pub struct LoaderConfig {
64    /// Maximum file size to load (in bytes)
65    pub max_file_size: usize,
66    /// Maximum number of files to include
67    pub max_files: usize,
68    /// Maximum total context size in tokens
69    pub max_context_tokens: usize,
70    /// File extensions to prioritize
71    pub priority_extensions: Vec<&'static str>,
72    /// Additional patterns to ignore
73    pub ignore_patterns: Vec<&'static str>,
74}
75
76impl Default for LoaderConfig {
77    fn default() -> Self {
78        Self {
79            max_file_size: 1024 * 1024, // 1MB
80            max_files: 100,
81            max_context_tokens: 50000,
82            priority_extensions: DEFAULT_PRIORITY_EXTENSIONS.to_vec(),
83            ignore_patterns: DEFAULT_IGNORE_PATTERNS.to_vec(),
84        }
85    }
86}
87
88/// Loads project context from the filesystem
89pub struct ContextLoader {
90    config: LoaderConfig,
91    tokenizer: CoreBPE,
92    cache_manager: Option<Arc<crate::cache::CacheManager>>,
93}
94
95impl ContextLoader {
96    /// Create a new context loader with default config
97    pub fn new() -> Result<Self> {
98        Ok(Self {
99            config: LoaderConfig::default(),
100            tokenizer: cl100k_base()?,
101            cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
102        })
103    }
104
105    /// Create with custom config
106    pub fn with_config(config: LoaderConfig) -> Result<Self> {
107        Ok(Self {
108            config,
109            tokenizer: cl100k_base()?,
110            cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
111        })
112    }
113
114    /// Load project context from the given path (alias for compatibility)
115    pub async fn load(&self, root_path: &std::path::Path) -> Result<ProjectContext> {
116        self.load_context(root_path).await
117    }
118
119    /// Load only the project structure without file contents (fast)
120    pub async fn load_structure(
121        &self,
122        root_path: &std::path::Path,
123    ) -> Result<crate::models::LazyProjectContext> {
124        let collector_config = CollectorConfig {
125            max_file_size: self.config.max_file_size,
126            max_files: self.config.max_files,
127            priority_extensions: self.config.priority_extensions.clone(),
128            ignore_patterns: self.config.ignore_patterns.clone(),
129        };
130        let collector = FileCollector::new(collector_config);
131        let files = collector.collect_files(root_path).await?;
132
133        let lazy_context =
134            crate::models::LazyProjectContext::new(root_path.to_string_lossy().to_string(), files);
135
136        Ok(lazy_context)
137    }
138
139    /// Load project context from the given path
140    pub async fn load_context(&self, root_path: &std::path::Path) -> Result<ProjectContext> {
141        let mut context = ProjectContext::new(root_path.to_string_lossy().to_string());
142
143        // Detect project type
144        context.project_type = ProjectDetector::detect_project_type(root_path);
145
146        // Collect files
147        let collector_config = CollectorConfig {
148            max_file_size: self.config.max_file_size,
149            max_files: self.config.max_files,
150            priority_extensions: self.config.priority_extensions.clone(),
151            ignore_patterns: self.config.ignore_patterns.clone(),
152        };
153        let collector = FileCollector::new(collector_config);
154        let files = collector.collect_files(root_path).await?;
155
156        // Use Mutex-protected state for thread-safe tracking
157        let loading_state = Arc::new(Mutex::new(LoadingState::new()));
158        let token_counter = TokenCounter::new(self.tokenizer.clone(), self.cache_manager.clone());
159
160        // Configuration for convenient access
161        let max_files = self.config.max_files;
162        let max_tokens = self.config.max_context_tokens;
163
164        // Process files in parallel
165        let loaded_contents: Vec<(String, String, usize)> = files
166            .par_iter()
167            .filter_map(|file_path| {
168                // Determine token budget for this file
169                let remaining_budget = {
170                    let state = loading_state.lock_mut_safe();
171                    max_tokens.saturating_sub(state.tokens_used)
172                };
173
174                if remaining_budget == 0 {
175                    return None;
176                }
177
178                // Load file with caching
179                let (content, tokens) = token_counter
180                    .load_file_cached(file_path, remaining_budget)
181                    .ok()?;
182
183                // Try to add file with mutex protection
184                let mut state = loading_state.lock_mut_safe();
185                if !state.try_add_file(tokens, max_files, max_tokens) {
186                    return None;
187                }
188
189                let relative_path = file_path
190                    .strip_prefix(root_path)
191                    .unwrap_or(file_path)
192                    .to_string_lossy()
193                    .to_string();
194
195                Some((relative_path, content, tokens))
196            })
197            .collect();
198
199        // Add all loaded files to context
200        let mut actual_total_tokens = 0;
201        for (path, content, tokens) in loaded_contents {
202            context.add_file(path, content);
203            actual_total_tokens += tokens;
204        }
205
206        context.token_count = actual_total_tokens;
207
208        // Auto-include important files
209        ProjectDetector::auto_include_important_files(&mut context, root_path, self);
210
211        Ok(context)
212    }
213}
214
215impl FileLoader for ContextLoader {
216    fn load_file(&self, path: &std::path::Path) -> Result<String> {
217        std::fs::read_to_string(path)
218            .with_context(|| format!("Failed to read file: {}", path.display()))
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    #[test]
    fn test_detect_project_type() {
        let temp_dir = TempDir::new().unwrap();

        // A Cargo.toml marks the directory as a Rust project.
        File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        assert_eq!(
            ProjectDetector::detect_project_type(temp_dir.path()),
            Some("rust".to_string())
        );

        // Adding a Python marker alongside it must not change the result.
        File::create(temp_dir.path().join("requirements.txt")).unwrap();
        assert_eq!(
            ProjectDetector::detect_project_type(temp_dir.path()),
            Some("rust".to_string()) // Cargo.toml takes precedence
        );
    }

    #[tokio::test]
    async fn test_load_context() {
        let temp_dir = TempDir::new().unwrap();
        let loader = ContextLoader::new().unwrap();

        // Create some test files
        let mut cargo_file = File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        writeln!(cargo_file, "[package]\nname = \"test\"").unwrap();

        let src_dir = temp_dir.path().join("src");
        std::fs::create_dir(&src_dir).unwrap();

        let mut main_file = File::create(src_dir.join("main.rs")).unwrap();
        writeln!(main_file, "fn main() {{\n    println!(\"Hello\");\n}}").unwrap();

        // Load context
        let context = loader.load_context(temp_dir.path()).await.unwrap();

        assert_eq!(context.project_type, Some("rust".to_string()));
        assert!(context.files.contains_key("Cargo.toml"));
        // NOTE(review): the loader builds keys with `Path::to_string_lossy`. Unless
        // `add_file` normalizes separators, this literal assumes `/` (non-Windows).
        assert!(context.files.contains_key("src/main.rs"));
        assert!(context.token_count > 0);
    }

    #[test]
    fn test_loading_state_atomicity() {
        let mut state = LoadingState::new();

        assert!(state.try_add_file(10, 100, 1000));
        assert_eq!(state.files_loaded, 1);
        assert_eq!(state.tokens_used, 10);

        state.files_loaded = 100;
        assert!(!state.try_add_file(5, 100, 1000));
        assert_eq!(state.files_loaded, 100);

        let mut state2 = LoadingState::new();
        state2.tokens_used = 990;
        assert!(!state2.try_add_file(100, 100, 1000));
        assert_eq!(state2.tokens_used, 990);

        // A file that exactly fills the remaining token budget is accepted.
        let mut exact = LoadingState::new();
        exact.tokens_used = 990;
        assert!(exact.try_add_file(10, 100, 1000));
        assert_eq!(exact.tokens_used, 1000);
        assert_eq!(exact.files_loaded, 1);
    }

    #[test]
    fn test_concurrent_file_loading_safety() {
        use std::thread;

        let state = Arc::new(Mutex::new(LoadingState::new()));
        let mut handles = vec![];

        for _ in 0..10 {
            let state_clone = Arc::clone(&state);
            let handle = thread::spawn(move || {
                let mut state = state_clone.lock().unwrap();
                state.try_add_file(100, 100, 500)
            });
            handles.push(handle);
        }

        let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(results.iter().filter(|&&r| r).count(), 5);
        assert_eq!(results.iter().filter(|&&r| !r).count(), 5);

        let final_state = state.lock().unwrap();
        assert_eq!(final_state.files_loaded, 5);
        assert_eq!(final_state.tokens_used, 500);
    }
}