1use anyhow::Result;
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14use std::time::{SystemTime, UNIX_EPOCH};
15
16use super::file_collector::{CollectorConfig, FileCollector};
17use super::token_counter::TokenCounter;
18use crate::models::{LazyProjectContext, ProjectContext};
19use crate::utils::MutexExt;
20
/// Default value of [`ContextConfig::priority_extensions`] (extensions without a
/// leading dot), forwarded to the file collector.
const DEFAULT_PRIORITY_EXTENSIONS: &[&str] = &[
    // Source code
    "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "cpp", "c", "h", "hpp", "cs", "rb",
    "php", "swift", "kt", "scala", "r", "sql", "sh",
    // Configuration / data formats
    "yaml", "yml", "toml", "json", "xml",
    // Markup, styles and plain text
    "html", "css", "scss", "md", "txt",
];

/// Default value of [`ContextConfig::ignore_patterns`], forwarded to the file collector.
const DEFAULT_IGNORE_PATTERNS: &[&str] = &[
    // Logs, temporary and cache files
    "*.log", "*.tmp", "*.cache",
    // Compiled Python artifacts
    "*.pyc", "*.pyo", "*.pyd",
    // Native binaries and libraries
    "*.so", "*.dylib", "*.dll", "*.exe", "*.o", "*.a", "*.lib",
    // Images and documents
    "*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp", "*.ico", "*.svg", "*.pdf",
    // Archives
    "*.zip", "*.tar", "*.gz", "*.rar", "*.7z",
];

/// Limits and file filters applied when building a project context.
///
/// All fields are forwarded to the file collector; `max_files` and
/// `max_context_tokens` additionally cap what `Context::load_full_context` reads.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Per-file size limit handed to the collector (presumably bytes; default 1 MiB).
    pub max_file_size: usize,
    /// Maximum number of files collected and loaded.
    pub max_files: usize,
    /// Token budget shared by all files loaded into one context.
    pub max_context_tokens: usize,
    /// File extensions (no leading dot) the collector treats as priority.
    pub priority_extensions: Vec<&'static str>,
    /// Patterns such as `*.log` that the collector uses to skip files.
    pub ignore_patterns: Vec<&'static str>,
}

impl Default for ContextConfig {
    fn default() -> Self {
        let priority_extensions = Vec::from(DEFAULT_PRIORITY_EXTENSIONS);
        let ignore_patterns = Vec::from(DEFAULT_IGNORE_PATTERNS);

        Self {
            max_file_size: 1 << 20, // 1024 * 1024
            max_files: 100,
            max_context_tokens: 50_000,
            priority_extensions,
            ignore_patterns,
        }
    }
}
61
/// Running totals used to enforce the file-count and token budgets while loading.
#[derive(Debug, Clone, Default)]
struct LoadingState {
    /// Number of files accepted so far.
    files_loaded: usize,
    /// Sum of the token counts of all accepted files.
    tokens_used: usize,
}

impl LoadingState {
    /// Returns an empty state: no files accepted, no tokens used.
    fn new() -> Self {
        Self::default()
    }

    /// Accepts a file of `tokens` tokens if that keeps the totals within
    /// `max_files` files and `max_tokens` tokens.
    ///
    /// Returns `true` and bumps both counters when the file fits; returns
    /// `false` and leaves the state untouched otherwise. Check and update
    /// happen in one call, so a caller holding a lock reserves budget atomically.
    fn try_add_file(&mut self, tokens: usize, max_files: usize, max_tokens: usize) -> bool {
        if self.files_loaded >= max_files {
            return false;
        }

        // `checked_add` rejects an enormous `tokens` instead of overflowing:
        // plain `+` panics in debug builds and wraps past the limit in release,
        // which would silently corrupt `tokens_used`.
        match self.tokens_used.checked_add(tokens) {
            Some(total) if total <= max_tokens => {
                self.files_loaded += 1;
                self.tokens_used = total;
                true
            }
            _ => false,
        }
    }
}
93
94#[derive(Clone)]
98pub struct Context {
99 root_path: PathBuf,
101 config: ContextConfig,
103 cache_manager: Option<Arc<crate::cache::CacheManager>>,
105 last_file_hash: Option<u64>,
107 last_load_time: Option<u64>,
109 cached_files: Vec<PathBuf>,
111}
112
113impl std::fmt::Debug for Context {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("Context")
116 .field("root_path", &self.root_path)
117 .field("config", &self.config)
118 .field("last_file_hash", &self.last_file_hash)
119 .field("last_load_time", &self.last_load_time)
120 .field("cached_files", &self.cached_files.len())
121 .finish()
122 }
123}
124
125impl Context {
126 pub fn new(root_path: impl AsRef<Path>) -> Result<Self> {
128 Ok(Self {
129 root_path: root_path.as_ref().to_path_buf(),
130 config: ContextConfig::default(),
131 cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
132 last_file_hash: None,
133 last_load_time: None,
134 cached_files: Vec::new(),
135 })
136 }
137
138 pub fn with_config(root_path: impl AsRef<Path>, config: ContextConfig) -> Result<Self> {
140 Ok(Self {
141 root_path: root_path.as_ref().to_path_buf(),
142 config,
143 cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
144 last_file_hash: None,
145 last_load_time: None,
146 cached_files: Vec::new(),
147 })
148 }
149
150 pub async fn load(root_path: impl AsRef<Path>) -> Result<ProjectContext> {
155 let ctx = Self::new(&root_path)?;
156 ctx.load_full_context().await
157 }
158
159 pub async fn load_full_context(&self) -> Result<ProjectContext> {
161 let mut context = ProjectContext::new(self.root_path.to_string_lossy().to_string());
162
163 let collector = self.create_collector();
165 let files = collector.collect_files(&self.root_path).await?;
166
167 let loading_state = Arc::new(Mutex::new(LoadingState::new()));
169 let token_counter = TokenCounter::new(self.cache_manager.clone());
170
171 let max_files = self.config.max_files;
173 let max_tokens = self.config.max_context_tokens;
174
175 let loaded_contents: Vec<(String, String, usize)> = files
177 .iter()
178 .filter_map(|file_path| {
179 let remaining_budget = {
181 let state = loading_state.lock_mut_safe();
182 max_tokens.saturating_sub(state.tokens_used)
183 };
184
185 if remaining_budget == 0 {
186 return None;
187 }
188
189 let (content, tokens) = token_counter
191 .load_file_cached(file_path, remaining_budget)
192 .ok()?;
193
194 let mut state = loading_state.lock_mut_safe();
196 if !state.try_add_file(tokens, max_files, max_tokens) {
197 return None;
198 }
199
200 let relative_path = file_path
201 .strip_prefix(&self.root_path)
202 .unwrap_or(file_path)
203 .to_string_lossy()
204 .replace('\\', "/"); Some((relative_path, content, tokens))
207 })
208 .collect();
209
210 let mut actual_total_tokens = 0;
212 for (path, content, tokens) in loaded_contents {
213 context.add_file(path, content);
214 actual_total_tokens += tokens;
215 }
216
217 context.token_count = actual_total_tokens;
218
219 Ok(context)
220 }
221
222 pub async fn load_structure(&self) -> Result<LazyProjectContext> {
224 let collector = self.create_collector();
225 let files = collector.collect_files(&self.root_path).await?;
226
227 let lazy_context =
228 LazyProjectContext::new(self.root_path.to_string_lossy().to_string(), files);
229
230 Ok(lazy_context)
231 }
232
233 pub async fn needs_reload(&self) -> bool {
235 match self.compute_file_hash().await {
236 Ok(current_hash) => {
237 if let Some(last_hash) = self.last_file_hash {
238 current_hash != last_hash
239 } else {
240 true
242 }
243 }
244 Err(_) => false, }
246 }
247
248 pub async fn reload_if_needed(&mut self) -> Result<bool> {
250 if self.needs_reload().await {
251 self.reload().await?;
252 Ok(true)
253 } else {
254 Ok(false)
255 }
256 }
257
258 pub async fn reload(&mut self) -> Result<()> {
260 let collector = self.create_collector();
262 let files = collector.collect_files(&self.root_path).await?;
263
264 let hash = self.compute_hash_from_files(&files)?;
266
267 self.cached_files = files;
269 self.last_file_hash = Some(hash);
270 self.last_load_time = Some(
271 SystemTime::now()
272 .duration_since(UNIX_EPOCH)
273 .unwrap_or_default()
274 .as_secs(),
275 );
276
277 Ok(())
278 }
279
280 pub fn build_context(&self) -> ProjectContext {
287 let mut context = ProjectContext::new(self.root_path.to_string_lossy().to_string());
288
289 for file_path in &self.cached_files {
291 if let Ok(rel_path) = file_path.strip_prefix(&self.root_path) {
292 if let Some(path_str) = rel_path.to_str() {
293 context.add_file(path_str.to_string(), String::new());
295 }
296 }
297 }
298
299 context
300 }
301
302 pub fn get_file_list(&self) -> Vec<String> {
304 self.cached_files
305 .iter()
306 .filter_map(|p| {
307 p.strip_prefix(&self.root_path)
308 .ok()
309 .and_then(|p| p.to_str())
310 .map(|s| s.to_string())
311 })
312 .collect()
313 }
314
315 pub fn total_files(&self) -> usize {
317 self.cached_files.len()
318 }
319
320 fn create_collector(&self) -> FileCollector {
322 let collector_config = CollectorConfig {
323 max_file_size: self.config.max_file_size,
324 max_files: self.config.max_files,
325 priority_extensions: self.config.priority_extensions.clone(),
326 ignore_patterns: self.config.ignore_patterns.clone(),
327 };
328 FileCollector::new(collector_config)
329 }
330
331 async fn compute_file_hash(&self) -> Result<u64> {
333 let collector = self.create_collector();
334 let current_files = collector.collect_files(&self.root_path).await?;
335 self.compute_hash_from_files(¤t_files)
336 }
337
338 fn compute_hash_from_files(&self, files: &[PathBuf]) -> Result<u64> {
340 let mut hasher = DefaultHasher::new();
341
342 let mut file_paths: Vec<_> = files
344 .iter()
345 .filter_map(|p| {
346 p.strip_prefix(&self.root_path)
347 .ok()
348 .and_then(|p| p.to_str())
349 })
350 .collect();
351 file_paths.sort();
352
353 for path in file_paths {
354 path.hash(&mut hasher);
355 }
356
357 Ok(hasher.finish())
358 }
359}
360
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_context_creation() {
        let temp_dir = TempDir::new().unwrap();
        let ctx = Context::new(temp_dir.path()).unwrap();

        assert_eq!(ctx.root_path, temp_dir.path());
        assert_eq!(ctx.total_files(), 0);
        // No snapshot has been taken yet, so a reload is always needed.
        assert!(ctx.needs_reload().await);
    }

    #[tokio::test]
    async fn test_file_tree_change_detection() {
        let temp_dir = TempDir::new().unwrap();
        let mut ctx = Context::new(temp_dir.path()).unwrap();

        ctx.reload().await.unwrap();
        let initial_hash = ctx.last_file_hash;

        assert!(!ctx.needs_reload().await);

        let test_file = temp_dir.path().join("test.py");
        fs::write(&test_file, "print('test')").unwrap();

        assert!(ctx.needs_reload().await);

        ctx.reload().await.unwrap();
        assert_ne!(ctx.last_file_hash, initial_hash);
    }

    #[tokio::test]
    async fn test_project_context_building() {
        let temp_dir = TempDir::new().unwrap();

        fs::write(temp_dir.path().join("main.py"), "print('hello')").unwrap();
        fs::write(temp_dir.path().join("lib.py"), "def helper(): pass").unwrap();
        fs::write(temp_dir.path().join("requirements.txt"), "requests\n").unwrap();

        let mut ctx = Context::new(temp_dir.path()).unwrap();
        ctx.reload().await.unwrap();

        let context = ctx.build_context();
        assert_eq!(
            context.root_path,
            temp_dir.path().to_string_lossy().to_string()
        );
        assert_eq!(context.files.len(), 3);
    }

    #[tokio::test]
    async fn test_load_full_context() {
        let temp_dir = TempDir::new().unwrap();

        let mut cargo_file = File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        writeln!(cargo_file, "[package]\nname = \"test\"").unwrap();

        let src_dir = temp_dir.path().join("src");
        std::fs::create_dir(&src_dir).unwrap();

        let mut main_file = File::create(src_dir.join("main.rs")).unwrap();
        writeln!(main_file, "fn main() {{\n println!(\"Hello\");\n}}").unwrap();

        let context = Context::load(temp_dir.path()).await.unwrap();

        assert!(context.files.contains_key("Cargo.toml"));
        assert!(context.files.contains_key("src/main.rs"));
        assert!(context.token_count > 0);
    }

    /// `max_files` caps how many files end up in a fully loaded context.
    #[tokio::test]
    async fn test_load_full_context_respects_max_files() {
        let temp_dir = TempDir::new().unwrap();
        for name in ["a.py", "b.py", "c.py"].iter() {
            fs::write(temp_dir.path().join(name), "print('x')").unwrap();
        }

        let config = ContextConfig {
            max_files: 1,
            ..ContextConfig::default()
        };
        let ctx = Context::with_config(temp_dir.path(), config).unwrap();
        let context = ctx.load_full_context().await.unwrap();

        assert_eq!(context.files.len(), 1);
    }

    /// `reload_if_needed` reloads on the first call and after a file is added,
    /// but not when the tree is unchanged.
    #[tokio::test]
    async fn test_reload_if_needed_tracks_tree_changes() {
        let temp_dir = TempDir::new().unwrap();
        let mut ctx = Context::new(temp_dir.path()).unwrap();

        assert!(ctx.reload_if_needed().await.unwrap());
        assert!(!ctx.reload_if_needed().await.unwrap());

        fs::write(temp_dir.path().join("added.py"), "x = 1").unwrap();
        assert!(ctx.reload_if_needed().await.unwrap());
        assert_eq!(ctx.total_files(), 1);
    }

    #[test]
    fn test_loading_state_atomicity() {
        let mut state = LoadingState::new();

        assert!(state.try_add_file(10, 100, 1000));
        assert_eq!(state.files_loaded, 1);
        assert_eq!(state.tokens_used, 10);

        // A rejected file must leave the counters untouched.
        state.files_loaded = 100;
        assert!(!state.try_add_file(5, 100, 1000));
        assert_eq!(state.files_loaded, 100);

        let mut state2 = LoadingState::new();
        state2.tokens_used = 990;
        assert!(!state2.try_add_file(100, 100, 1000));
        assert_eq!(state2.tokens_used, 990);
    }

    #[test]
    fn test_concurrent_file_loading_safety() {
        use std::thread;

        let state = Arc::new(Mutex::new(LoadingState::new()));
        let mut handles = vec![];

        for _ in 0..10 {
            let state_clone = Arc::clone(&state);
            let handle = thread::spawn(move || {
                let mut state = state_clone.lock().unwrap();
                state.try_add_file(100, 100, 500)
            });
            handles.push(handle);
        }

        let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Exactly five 100-token files fit in a 500-token budget.
        assert_eq!(results.iter().filter(|&&r| r).count(), 5);
        assert_eq!(results.iter().filter(|&&r| !r).count(), 5);

        let final_state = state.lock().unwrap();
        assert_eq!(final_state.files_loaded, 5);
        assert_eq!(final_state.tokens_used, 500);
    }
}