1use anyhow::{Context, Result};
6use rayon::prelude::*;
7use std::sync::{Arc, Mutex};
8use tiktoken_rs::{cl100k_base, CoreBPE};
9
10use super::file_collector::{CollectorConfig, FileCollector};
11use super::project_detector::{FileLoader, ProjectDetector};
12use super::token_counter::TokenCounter;
13use crate::models::ProjectContext;
14use crate::utils::MutexExt;
15
/// File extensions (without the leading dot) used as the default value of
/// `LoaderConfig::priority_extensions`.
///
/// The element order is kept exactly as declared, in case consumers treat
/// earlier entries as more important.
const DEFAULT_PRIORITY_EXTENSIONS: &[&str] = &[
    // Programming, query and shell languages.
    "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "cpp", "c", "h", "hpp", "cs", "rb", "php",
    "swift", "kt", "scala", "r", "sql", "sh",
    // Configuration and structured data.
    "yaml", "yml", "toml", "json", "xml",
    // Web markup and stylesheets.
    "html", "css", "scss",
    // Plain documentation.
    "md", "txt",
];
22
/// Glob-style file name patterns used as the default value of
/// `LoaderConfig::ignore_patterns`.
///
/// The element order is kept exactly as declared.
const DEFAULT_IGNORE_PATTERNS: &[&str] = &[
    // Logs, temporary files and caches.
    "*.log", "*.tmp", "*.cache",
    // Compiled Python artifacts.
    "*.pyc", "*.pyo", "*.pyd",
    // Native libraries, executables and object files.
    "*.so", "*.dylib", "*.dll", "*.exe", "*.o", "*.a", "*.lib",
    // Images.
    "*.png", "*.jpg", "*.jpeg", "*.gif", "*.bmp", "*.ico", "*.svg",
    // Documents.
    "*.pdf",
    // Archives.
    "*.zip", "*.tar", "*.gz", "*.rar", "*.7z",
];
28
/// Running totals of what has been accepted into a context so far.
///
/// Both limits are checked and both counters are updated inside a single
/// `&mut self` call, so wrapping the state in a `Mutex` makes each admission
/// decision atomic with respect to other threads.
#[derive(Debug, Clone)]
struct LoadingState {
    /// Number of files accepted so far.
    files_loaded: usize,
    /// Sum of the token counts of all accepted files.
    tokens_used: usize,
}

impl LoadingState {
    /// Creates a state with no files and no tokens recorded.
    fn new() -> Self {
        Self {
            files_loaded: 0,
            tokens_used: 0,
        }
    }

    /// Tries to account for one more file costing `tokens` tokens.
    ///
    /// Returns `true` and updates both counters when the file fits, i.e. when
    /// fewer than `max_files` files are recorded and the new token total does
    /// not exceed `max_tokens`. Returns `false` and leaves the state untouched
    /// otherwise.
    fn try_add_file(&mut self, tokens: usize, max_files: usize, max_tokens: usize) -> bool {
        if self.files_loaded >= max_files {
            return false;
        }

        // `checked_add` keeps a huge `tokens` value from overflowing: a plain
        // `+` would panic in debug builds and wrap around in release builds,
        // where the wrapped sum could slip under `max_tokens`.
        match self.tokens_used.checked_add(tokens) {
            Some(new_total) if new_total <= max_tokens => {
                self.files_loaded += 1;
                self.tokens_used = new_total;
                true
            }
            _ => false,
        }
    }
}
60
61#[derive(Debug, Clone)]
63pub struct LoaderConfig {
64 pub max_file_size: usize,
66 pub max_files: usize,
68 pub max_context_tokens: usize,
70 pub priority_extensions: Vec<&'static str>,
72 pub ignore_patterns: Vec<&'static str>,
74}
75
76impl Default for LoaderConfig {
77 fn default() -> Self {
78 Self {
79 max_file_size: 1024 * 1024, max_files: 100,
81 max_context_tokens: 50000,
82 priority_extensions: DEFAULT_PRIORITY_EXTENSIONS.to_vec(),
83 ignore_patterns: DEFAULT_IGNORE_PATTERNS.to_vec(),
84 }
85 }
86}
87
88pub struct ContextLoader {
90 config: LoaderConfig,
91 tokenizer: CoreBPE,
92 cache_manager: Option<Arc<crate::cache::CacheManager>>,
93}
94
95impl ContextLoader {
96 pub fn new() -> Result<Self> {
98 Ok(Self {
99 config: LoaderConfig::default(),
100 tokenizer: cl100k_base()?,
101 cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
102 })
103 }
104
105 pub fn with_config(config: LoaderConfig) -> Result<Self> {
107 Ok(Self {
108 config,
109 tokenizer: cl100k_base()?,
110 cache_manager: crate::cache::CacheManager::new().ok().map(Arc::new),
111 })
112 }
113
114 pub async fn load(&self, root_path: &std::path::Path) -> Result<ProjectContext> {
116 self.load_context(root_path).await
117 }
118
119 pub async fn load_structure(
121 &self,
122 root_path: &std::path::Path,
123 ) -> Result<crate::models::LazyProjectContext> {
124 let collector_config = CollectorConfig {
125 max_file_size: self.config.max_file_size,
126 max_files: self.config.max_files,
127 priority_extensions: self.config.priority_extensions.clone(),
128 ignore_patterns: self.config.ignore_patterns.clone(),
129 };
130 let collector = FileCollector::new(collector_config);
131 let files = collector.collect_files(root_path).await?;
132
133 let lazy_context =
134 crate::models::LazyProjectContext::new(root_path.to_string_lossy().to_string(), files);
135
136 Ok(lazy_context)
137 }
138
139 pub async fn load_context(&self, root_path: &std::path::Path) -> Result<ProjectContext> {
141 let mut context = ProjectContext::new(root_path.to_string_lossy().to_string());
142
143 context.project_type = ProjectDetector::detect_project_type(root_path);
145
146 let collector_config = CollectorConfig {
148 max_file_size: self.config.max_file_size,
149 max_files: self.config.max_files,
150 priority_extensions: self.config.priority_extensions.clone(),
151 ignore_patterns: self.config.ignore_patterns.clone(),
152 };
153 let collector = FileCollector::new(collector_config);
154 let files = collector.collect_files(root_path).await?;
155
156 let loading_state = Arc::new(Mutex::new(LoadingState::new()));
158 let token_counter = TokenCounter::new(self.tokenizer.clone(), self.cache_manager.clone());
159
160 let max_files = self.config.max_files;
162 let max_tokens = self.config.max_context_tokens;
163
164 let loaded_contents: Vec<(String, String, usize)> = files
166 .par_iter()
167 .filter_map(|file_path| {
168 let remaining_budget = {
170 let state = loading_state.lock_mut_safe();
171 max_tokens.saturating_sub(state.tokens_used)
172 };
173
174 if remaining_budget == 0 {
175 return None;
176 }
177
178 let (content, tokens) = token_counter
180 .load_file_cached(file_path, remaining_budget)
181 .ok()?;
182
183 let mut state = loading_state.lock_mut_safe();
185 if !state.try_add_file(tokens, max_files, max_tokens) {
186 return None;
187 }
188
189 let relative_path = file_path
190 .strip_prefix(root_path)
191 .unwrap_or(file_path)
192 .to_string_lossy()
193 .to_string();
194
195 Some((relative_path, content, tokens))
196 })
197 .collect();
198
199 let mut actual_total_tokens = 0;
201 for (path, content, tokens) in loaded_contents {
202 context.add_file(path, content);
203 actual_total_tokens += tokens;
204 }
205
206 context.token_count = actual_total_tokens;
207
208 ProjectDetector::auto_include_important_files(&mut context, root_path, self);
210
211 Ok(context)
212 }
213}
214
215impl FileLoader for ContextLoader {
216 fn load_file(&self, path: &std::path::Path) -> Result<String> {
217 std::fs::read_to_string(path)
218 .with_context(|| format!("Failed to read file: {}", path.display()))
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    /// A directory with `Cargo.toml` is detected as "rust", and stays "rust"
    /// after a `requirements.txt` is added next to it.
    #[test]
    fn test_detect_project_type() {
        let temp_dir = TempDir::new().unwrap();

        File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        assert_eq!(
            ProjectDetector::detect_project_type(temp_dir.path()),
            Some("rust".to_string())
        );

        File::create(temp_dir.path().join("requirements.txt")).unwrap();
        assert_eq!(
            ProjectDetector::detect_project_type(temp_dir.path()),
            Some("rust".to_string())
        );
    }

    /// Loading a tiny Rust project yields both files, keyed by relative path,
    /// and a non-zero token count.
    #[tokio::test]
    async fn test_load_context() {
        let temp_dir = TempDir::new().unwrap();
        let loader = ContextLoader::new().unwrap();

        let mut cargo_file = File::create(temp_dir.path().join("Cargo.toml")).unwrap();
        writeln!(cargo_file, "[package]\nname = \"test\"").unwrap();

        let src_dir = temp_dir.path().join("src");
        std::fs::create_dir(&src_dir).unwrap();

        let mut main_file = File::create(src_dir.join("main.rs")).unwrap();
        writeln!(main_file, "fn main() {{\n println!(\"Hello\");\n}}").unwrap();

        let context = loader.load_context(temp_dir.path()).await.unwrap();

        assert_eq!(context.project_type, Some("rust".to_string()));
        assert!(context.files.contains_key("Cargo.toml"));

        // The loader derives keys from the platform's path representation, so
        // accept the native separator as well as the forward-slash form.
        let native_main_key = std::path::Path::new("src")
            .join("main.rs")
            .to_string_lossy()
            .to_string();
        assert!(
            context.files.contains_key("src/main.rs")
                || context.files.contains_key(native_main_key.as_str())
        );
        assert!(context.token_count > 0);
    }

    /// A rejected file must leave both counters untouched.
    #[test]
    fn test_loading_state_atomicity() {
        let mut state = LoadingState::new();

        assert!(state.try_add_file(10, 100, 1000));
        assert_eq!(state.files_loaded, 1);
        assert_eq!(state.tokens_used, 10);

        // File limit reached: rejected, counter unchanged.
        state.files_loaded = 100;
        assert!(!state.try_add_file(5, 100, 1000));
        assert_eq!(state.files_loaded, 100);

        // Token budget would be exceeded: rejected, counter unchanged.
        let mut state2 = LoadingState::new();
        state2.tokens_used = 990;
        assert!(!state2.try_add_file(100, 100, 1000));
        assert_eq!(state2.tokens_used, 990);
    }

    /// Ten threads each request 100 tokens against a 500-token budget; exactly
    /// five may succeed regardless of scheduling.
    #[test]
    fn test_concurrent_file_loading_safety() {
        use std::thread;

        let state = Arc::new(Mutex::new(LoadingState::new()));
        let mut handles = vec![];

        for _ in 0..10 {
            let state_clone = Arc::clone(&state);
            let handle = thread::spawn(move || {
                let mut state = state_clone.lock().unwrap();
                state.try_add_file(100, 100, 500)
            });
            handles.push(handle);
        }

        let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(results.iter().filter(|&&r| r).count(), 5);
        assert_eq!(results.iter().filter(|&&r| !r).count(), 5);

        let final_state = state.lock().unwrap();
        assert_eq!(final_state.files_loaded, 5);
        assert_eq!(final_state.tokens_used, 500);
    }
}