// mermaid_cli/context/context.rs
//! Unified context management
//!
//! Single module that handles:
//! - File collection and caching
//! - Change detection for dynamic reloading
//! - Token counting and content loading
//! - Project context building

9use anyhow::Result;
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14use std::time::{SystemTime, UNIX_EPOCH};
15
16use super::file_collector::{CollectorConfig, FileCollector};
17use super::token_counter::TokenCounter;
18use crate::models::{LazyProjectContext, ProjectContext};
19use crate::utils::MutexExt;
20
/// Default file extensions to prioritize when collecting project files:
/// source code first, then config/markup/documentation formats.
const DEFAULT_PRIORITY_EXTENSIONS: &[&str] = &[
    "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "cpp", "c", "h", "hpp", "cs", "rb", "php",
    "swift", "kt", "scala", "r", "sql", "sh", "yaml", "yml", "toml", "json", "xml", "html", "css",
    "scss", "md", "txt",
];

/// Default glob-style patterns to ignore: logs, caches, compiled artifacts,
/// images and archives — binary or generated content with no useful text.
const DEFAULT_IGNORE_PATTERNS: &[&str] = &[
    "*.log", "*.tmp", "*.cache", "*.pyc", "*.pyo", "*.pyd", "*.so", "*.dylib", "*.dll", "*.exe",
    "*.o", "*.a", "*.lib", "*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp", "*.ico", "*.svg", "*.pdf",
    "*.zip", "*.tar", "*.gz", "*.rar", "*.7z",
];
34
/// Configuration for context loading.
///
/// Sensible defaults are available via [`Default`].
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Maximum file size to load (in bytes); larger files are skipped by the collector.
    pub max_file_size: usize,
    /// Maximum number of files to include in the loaded context.
    pub max_files: usize,
    /// Maximum total context size in tokens across all loaded files.
    pub max_context_tokens: usize,
    /// File extensions to prioritize during collection.
    pub priority_extensions: Vec<&'static str>,
    /// Additional glob-style patterns to ignore.
    pub ignore_patterns: Vec<&'static str>,
}
49
50impl Default for ContextConfig {
51    fn default() -> Self {
52        Self {
53            max_file_size: 1024 * 1024, // 1MB
54            max_files: 100,
55            max_context_tokens: 50000,
56            priority_extensions: DEFAULT_PRIORITY_EXTENSIONS.to_vec(),
57            ignore_patterns: DEFAULT_IGNORE_PATTERNS.to_vec(),
58        }
59    }
60}
61
/// Thread-safe state for tracking loading progress.
///
/// Intended to be shared behind an `Arc<Mutex<_>>` so concurrent loaders can
/// update the counters without racing.
#[derive(Debug, Clone)]
struct LoadingState {
    // Number of files accepted so far.
    files_loaded: usize,
    // Total tokens consumed by accepted files.
    tokens_used: usize,
}

impl LoadingState {
    /// Create a fresh state with zeroed counters.
    fn new() -> Self {
        Self {
            files_loaded: 0,
            tokens_used: 0,
        }
    }

    /// Check and update counters atomically (the caller holds the lock).
    ///
    /// Returns `true` and commits the counters if adding a file costing
    /// `tokens` tokens stays within both `max_files` and `max_tokens`;
    /// otherwise returns `false` and leaves the state untouched.
    fn try_add_file(&mut self, tokens: usize, max_files: usize, max_tokens: usize) -> bool {
        if self.files_loaded >= max_files {
            return false;
        }

        // checked_add guards against usize overflow on pathological token
        // counts: an overflowing total is rejected instead of panicking
        // (debug builds) or silently wrapping (release builds).
        let new_total = match self.tokens_used.checked_add(tokens) {
            Some(total) if total <= max_tokens => total,
            _ => return false,
        };

        self.files_loaded += 1;
        self.tokens_used = new_total;
        true
    }
}
93
/// Unified context manager for project files
///
/// Combines file collection, change detection, and content loading into a single interface.
#[derive(Clone)]
pub struct Context {
    /// Root path of the project being scanned.
    root_path: PathBuf,
    /// Loading configuration (size/count/token limits, extension filters).
    config: ContextConfig,
    /// Cache manager for token caching; `None` when cache initialization failed.
    cache_manager: Option<Arc<crate::cache::CacheManager>>,
    /// Last computed hash of the file tree, used for change detection.
    last_file_hash: Option<u64>,
    /// Last time context was loaded (seconds since the Unix epoch).
    last_load_time: Option<u64>,
    /// Cached file list from the last reload.
    cached_files: Vec<PathBuf>,
}
112
// Manual Debug impl: `cache_manager` is omitted entirely (presumably
// CacheManager does not implement Debug — TODO confirm) and `cached_files`
// is summarized by its length so output stays readable for large projects.
impl std::fmt::Debug for Context {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Context")
            .field("root_path", &self.root_path)
            .field("config", &self.config)
            .field("last_file_hash", &self.last_file_hash)
            .field("last_load_time", &self.last_load_time)
            .field("cached_files", &self.cached_files.len())
            .finish()
    }
}
124
125impl Context {
126    /// Create a new context manager for the given project path
127    pub fn new(root_path: impl AsRef<Path>) -> Result<Self> {
128        Ok(Self {
129            root_path: root_path.as_ref().to_path_buf(),
130            config: ContextConfig::default(),
131            cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
132            last_file_hash: None,
133            last_load_time: None,
134            cached_files: Vec::new(),
135        })
136    }
137
138    /// Create with custom config
139    pub fn with_config(root_path: impl AsRef<Path>, config: ContextConfig) -> Result<Self> {
140        Ok(Self {
141            root_path: root_path.as_ref().to_path_buf(),
142            config,
143            cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
144            last_file_hash: None,
145            last_load_time: None,
146            cached_files: Vec::new(),
147        })
148    }
149
150    /// Load full project context from a path (one-shot, static method)
151    ///
152    /// This is the main entry point for loading a complete project context
153    /// with file contents and token counting.
154    pub async fn load(root_path: impl AsRef<Path>) -> Result<ProjectContext> {
155        let ctx = Self::new(&root_path)?;
156        ctx.load_full_context().await
157    }
158
159    /// Load full context with file contents and token counting
160    pub async fn load_full_context(&self) -> Result<ProjectContext> {
161        let mut context = ProjectContext::new(self.root_path.to_string_lossy().to_string());
162
163        // Collect files
164        let collector = self.create_collector();
165        let files = collector.collect_files(&self.root_path).await?;
166
167        // Use Mutex-protected state for thread-safe tracking
168        let loading_state = Arc::new(Mutex::new(LoadingState::new()));
169        let token_counter = TokenCounter::new(self.cache_manager.clone());
170
171        // Configuration for convenient access
172        let max_files = self.config.max_files;
173        let max_tokens = self.config.max_context_tokens;
174
175        // Process files sequentially (I/O bound, parallelism overhead not worth it)
176        let loaded_contents: Vec<(String, String, usize)> = files
177            .iter()
178            .filter_map(|file_path| {
179                // Determine token budget for this file
180                let remaining_budget = {
181                    let state = loading_state.lock_mut_safe();
182                    max_tokens.saturating_sub(state.tokens_used)
183                };
184
185                if remaining_budget == 0 {
186                    return None;
187                }
188
189                // Load file with caching
190                let (content, tokens) = token_counter
191                    .load_file_cached(file_path, remaining_budget)
192                    .ok()?;
193
194                // Try to add file with mutex protection
195                let mut state = loading_state.lock_mut_safe();
196                if !state.try_add_file(tokens, max_files, max_tokens) {
197                    return None;
198                }
199
200                let relative_path = file_path
201                    .strip_prefix(&self.root_path)
202                    .unwrap_or(file_path)
203                    .to_string_lossy()
204                    .replace('\\', "/");  // Normalize to forward slashes for cross-platform consistency
205
206                Some((relative_path, content, tokens))
207            })
208            .collect();
209
210        // Add all loaded files to context
211        let mut actual_total_tokens = 0;
212        for (path, content, tokens) in loaded_contents {
213            context.add_file(path, content);
214            actual_total_tokens += tokens;
215        }
216
217        context.token_count = actual_total_tokens;
218
219        Ok(context)
220    }
221
222    /// Load only project structure (file paths, no contents) - fast
223    pub async fn load_structure(&self) -> Result<LazyProjectContext> {
224        let collector = self.create_collector();
225        let files = collector.collect_files(&self.root_path).await?;
226
227        let lazy_context =
228            LazyProjectContext::new(self.root_path.to_string_lossy().to_string(), files);
229
230        Ok(lazy_context)
231    }
232
233    /// Check if the file tree has changed since last load
234    pub async fn needs_reload(&self) -> bool {
235        match self.compute_file_hash().await {
236            Ok(current_hash) => {
237                if let Some(last_hash) = self.last_file_hash {
238                    current_hash != last_hash
239                } else {
240                    // Never loaded before, needs initial load
241                    true
242                }
243            }
244            Err(_) => false, // Error computing hash, don't reload
245        }
246    }
247
248    /// Reload the project context if needed
249    pub async fn reload_if_needed(&mut self) -> Result<bool> {
250        if self.needs_reload().await {
251            self.reload().await?;
252            Ok(true)
253        } else {
254            Ok(false)
255        }
256    }
257
258    /// Force a reload of the project structure (file paths only)
259    pub async fn reload(&mut self) -> Result<()> {
260        // Collect files from the project
261        let collector = self.create_collector();
262        let files = collector.collect_files(&self.root_path).await?;
263
264        // Compute hash from the files we just collected (avoid re-scanning)
265        let hash = self.compute_hash_from_files(&files)?;
266
267        // Update cached state
268        self.cached_files = files;
269        self.last_file_hash = Some(hash);
270        self.last_load_time = Some(
271            SystemTime::now()
272                .duration_since(UNIX_EPOCH)
273                .unwrap_or_default()
274                .as_secs(),
275        );
276
277        Ok(())
278    }
279
280    /// Build a ProjectContext with the current file tree (paths only, no contents)
281    ///
282    /// This creates a context with:
283    /// - Root path
284    /// - Complete file tree (file paths only, not contents)
285    /// - No file contents loaded (those load on demand)
286    pub fn build_context(&self) -> ProjectContext {
287        let mut context = ProjectContext::new(self.root_path.to_string_lossy().to_string());
288
289        // Add all file paths to context (for file tree structure in prompt)
290        for file_path in &self.cached_files {
291            if let Ok(rel_path) = file_path.strip_prefix(&self.root_path) {
292                if let Some(path_str) = rel_path.to_str() {
293                    // Add file path (empty content, just for tree structure)
294                    context.add_file(path_str.to_string(), String::new());
295                }
296            }
297        }
298
299        context
300    }
301
302    /// Get the list of currently cached file paths
303    pub fn get_file_list(&self) -> Vec<String> {
304        self.cached_files
305            .iter()
306            .filter_map(|p| {
307                p.strip_prefix(&self.root_path)
308                    .ok()
309                    .and_then(|p| p.to_str())
310                    .map(|s| s.to_string())
311            })
312            .collect()
313    }
314
315    /// Get the total number of files in the project
316    pub fn total_files(&self) -> usize {
317        self.cached_files.len()
318    }
319
320    /// Create a file collector with current config
321    fn create_collector(&self) -> FileCollector {
322        let collector_config = CollectorConfig {
323            max_file_size: self.config.max_file_size,
324            max_files: self.config.max_files,
325            priority_extensions: self.config.priority_extensions.clone(),
326            ignore_patterns: self.config.ignore_patterns.clone(),
327        };
328        FileCollector::new(collector_config)
329    }
330
331    /// Compute a hash of the current file tree for change detection
332    async fn compute_file_hash(&self) -> Result<u64> {
333        let collector = self.create_collector();
334        let current_files = collector.collect_files(&self.root_path).await?;
335        self.compute_hash_from_files(&current_files)
336    }
337
338    /// Compute hash from a given list of files
339    fn compute_hash_from_files(&self, files: &[PathBuf]) -> Result<u64> {
340        let mut hasher = DefaultHasher::new();
341
342        // Hash all file paths (sorted for consistency)
343        let mut file_paths: Vec<_> = files
344            .iter()
345            .filter_map(|p| {
346                p.strip_prefix(&self.root_path)
347                    .ok()
348                    .and_then(|p| p.to_str())
349            })
350            .collect();
351        file_paths.sort();
352
353        for path in file_paths {
354            path.hash(&mut hasher);
355        }
356
357        Ok(hasher.finish())
358    }
359}
360
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    // A freshly created Context has no cached files and reports that it
    // needs an initial load.
    #[tokio::test]
    async fn test_context_creation() {
        let temp_dir = TempDir::new().unwrap();
        let ctx = Context::new(temp_dir.path()).unwrap();

        assert_eq!(ctx.root_path, temp_dir.path());
        assert_eq!(ctx.total_files(), 0);
        assert!(ctx.needs_reload().await);
    }

    // Adding a file to the tree must be detected by needs_reload() and
    // change the stored file-tree hash after reload().
    #[tokio::test]
    async fn test_file_tree_change_detection() {
        let temp_dir = TempDir::new().unwrap();
        let mut ctx = Context::new(temp_dir.path()).unwrap();

        // Initial load
        ctx.reload().await.unwrap();
        let initial_hash = ctx.last_file_hash;

        // No changes - should not need reload
        assert!(!ctx.needs_reload().await);

        // Add a file - should need reload
        let test_file = temp_dir.path().join("test.py");
        fs::write(&test_file, "print('test')").unwrap();

        assert!(ctx.needs_reload().await);

        // Reload and verify hash changed
        ctx.reload().await.unwrap();
        assert_ne!(ctx.last_file_hash, initial_hash);
    }

    // build_context() should expose every collected path (contents empty).
    #[tokio::test]
    async fn test_project_context_building() {
        let temp_dir = TempDir::new().unwrap();

        // Create some test files
        fs::write(temp_dir.path().join("main.py"), "print('hello')").unwrap();
        fs::write(temp_dir.path().join("lib.py"), "def helper(): pass").unwrap();
        fs::write(temp_dir.path().join("requirements.txt"), "requests\n").unwrap();

        let mut ctx = Context::new(temp_dir.path()).unwrap();
        ctx.reload().await.unwrap();

        let context = ctx.build_context();
        assert_eq!(
            context.root_path,
            temp_dir.path().to_string_lossy().to_string()
        );
        assert_eq!(context.files.len(), 3);
    }

    // Full load must read contents, key files by /-separated relative paths,
    // and count tokens.
    #[tokio::test]
    async fn test_load_full_context() {
        let temp_dir = TempDir::new().unwrap();

        // Create some test files
        let mut cargo_file = File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        writeln!(cargo_file, "[package]\nname = \"test\"").unwrap();

        let src_dir = temp_dir.path().join("src");
        std::fs::create_dir(&src_dir).unwrap();

        let mut main_file = File::create(src_dir.join("main.rs")).unwrap();
        writeln!(main_file, "fn main() {{\n    println!(\"Hello\");\n}}").unwrap();

        // Load full context (static method)
        let context = Context::load(temp_dir.path()).await.unwrap();

        assert!(context.files.contains_key("Cargo.toml"));
        assert!(context.files.contains_key("src/main.rs"));
        assert!(context.token_count > 0);
    }

    // try_add_file must reject past either limit and leave counters unchanged
    // on rejection.
    #[test]
    fn test_loading_state_atomicity() {
        let mut state = LoadingState::new();

        assert!(state.try_add_file(10, 100, 1000));
        assert_eq!(state.files_loaded, 1);
        assert_eq!(state.tokens_used, 10);

        // At the file limit: rejected, counters untouched
        state.files_loaded = 100;
        assert!(!state.try_add_file(5, 100, 1000));
        assert_eq!(state.files_loaded, 100);

        // Over the token budget: rejected, counters untouched
        let mut state2 = LoadingState::new();
        state2.tokens_used = 990;
        assert!(!state2.try_add_file(100, 100, 1000));
        assert_eq!(state2.tokens_used, 990);
    }

    // Ten threads race to add 100-token files against a 500-token budget:
    // exactly five must win and the final counters must be consistent.
    #[test]
    fn test_concurrent_file_loading_safety() {
        use std::thread;

        let state = Arc::new(Mutex::new(LoadingState::new()));
        let mut handles = vec![];

        for _ in 0..10 {
            let state_clone = Arc::clone(&state);
            let handle = thread::spawn(move || {
                let mut state = state_clone.lock().unwrap();
                state.try_add_file(100, 100, 500)
            });
            handles.push(handle);
        }

        let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(results.iter().filter(|&&r| r).count(), 5);
        assert_eq!(results.iter().filter(|&&r| !r).count(), 5);

        let final_state = state.lock().unwrap();
        assert_eq!(final_state.files_loaded, 5);
        assert_eq!(final_state.tokens_used, 500);
    }
}