1use memmap2::{Mmap, MmapOptions};
7use rayon::prelude::*;
8use std::fs::File;
9use std::io;
10use std::path::Path;
11use std::sync::atomic::{AtomicU64, Ordering};
12
13use crate::tokenizer::{TokenCounts, TokenModel, Tokenizer};
14
/// A read-only memory-mapped view of a file on disk.
pub struct MappedFile {
    // The underlying memory map; kept alive for the lifetime of this struct.
    mmap: Mmap,
    // Lossily UTF-8-converted path string, retained for display/reporting.
    path: String,
}
20
21impl MappedFile {
22 #[allow(unsafe_code)]
24 pub fn open(path: &Path) -> io::Result<Self> {
25 let file = File::open(path)?;
26 let mmap = unsafe { MmapOptions::new().map(&file)? };
28
29 Ok(Self { mmap, path: path.to_string_lossy().to_string() })
30 }
31
32 #[inline]
34 pub fn as_bytes(&self) -> &[u8] {
35 &self.mmap
36 }
37
38 pub fn as_str(&self) -> Option<&str> {
40 std::str::from_utf8(&self.mmap).ok()
41 }
42
43 #[inline]
45 pub fn len(&self) -> usize {
46 self.mmap.len()
47 }
48
49 #[inline]
51 pub fn is_empty(&self) -> bool {
52 self.mmap.is_empty()
53 }
54
55 pub fn path(&self) -> &str {
57 &self.path
58 }
59
60 pub fn is_binary(&self) -> bool {
62 let check_len = self.mmap.len().min(8192);
64 let sample = &self.mmap[..check_len];
65
66 if sample.contains(&0) {
68 return true;
69 }
70
71 let non_printable = sample
73 .iter()
74 .filter(|&&b| b < 32 && b != b'\t' && b != b'\n' && b != b'\r')
75 .count();
76
77 non_printable * 10 > check_len
78 }
79
80 pub fn count_lines(&self) -> usize {
82 self.mmap.iter().filter(|&&b| b == b'\n').count()
83 }
84}
85
/// Scans files for size/line/token statistics, memory-mapping large files
/// and reading small ones through the regular `std::fs::read` path.
pub struct MmapScanner {
    // Files at or above this size (bytes) are memory-mapped; smaller ones
    // use a regular read.
    mmap_threshold: u64,
    // Files strictly larger than this (bytes) are skipped entirely.
    max_file_size: u64,
    // Tokenizer used to produce per-model token counts for file contents.
    tokenizer: Tokenizer,
    // Atomic counters describing scan activity; safe to read concurrently.
    stats: ScanStats,
}
97
/// Atomic counters describing scanner activity. All fields use relaxed
/// ordering — they are independent statistics, not synchronization points.
#[derive(Debug, Default)]
pub struct ScanStats {
    pub files_scanned: AtomicU64,
    pub bytes_read: AtomicU64,
    pub files_skipped_binary: AtomicU64,
    pub files_skipped_size: AtomicU64,
    pub mmap_used: AtomicU64,
    pub regular_read_used: AtomicU64,
}

impl ScanStats {
    /// Renders a one-line, human-readable snapshot of every counter.
    pub fn summary(&self) -> String {
        // Snapshot each counter once; the values may be mid-scan and are
        // not guaranteed to be mutually consistent.
        let scanned = self.files_scanned.load(Ordering::Relaxed);
        let bytes = self.bytes_read.load(Ordering::Relaxed);
        let binary = self.files_skipped_binary.load(Ordering::Relaxed);
        let oversized = self.files_skipped_size.load(Ordering::Relaxed);
        let mmapped = self.mmap_used.load(Ordering::Relaxed);
        let regular = self.regular_read_used.load(Ordering::Relaxed);

        format!(
            "Scanned {} files ({} bytes), skipped {} binary + {} oversized, mmap: {}, regular: {}",
            scanned, bytes, binary, oversized, mmapped, regular,
        )
    }
}
122
/// Result of scanning a single text file: path info, size/line counts,
/// token counts, detected language, and the full file content.
#[derive(Debug)]
pub struct ScannedFile {
    // Original path as passed to the scanner (lossy UTF-8 conversion).
    pub path: String,
    // Path relative to the scan base; falls back to the full path when the
    // base prefix does not match.
    pub relative_path: String,
    // File size in bytes, taken from filesystem metadata.
    pub size_bytes: u64,
    // Line count as produced by `str::lines()` over the content.
    pub lines: usize,
    // Per-model token counts computed by the tokenizer.
    pub token_counts: TokenCounts,
    // Language identifier derived from the file extension, if recognized.
    pub language: Option<String>,
    // Full UTF-8 content; `Some` for every file the scanner returns.
    pub content: Option<String>,
    // Always `false` for returned files — binaries are skipped upstream.
    pub is_binary: bool,
}
135
136impl MmapScanner {
137 pub fn new() -> Self {
139 Self {
140 mmap_threshold: 64 * 1024, max_file_size: 50 * 1024 * 1024, tokenizer: Tokenizer::new(),
143 stats: ScanStats::default(),
144 }
145 }
146
147 pub fn with_mmap_threshold(mut self, bytes: u64) -> Self {
149 self.mmap_threshold = bytes;
150 self
151 }
152
153 pub fn with_max_file_size(mut self, bytes: u64) -> Self {
155 self.max_file_size = bytes;
156 self
157 }
158
159 pub fn scan_file(&self, path: &Path, base_path: &Path) -> io::Result<Option<ScannedFile>> {
161 let metadata = path.metadata()?;
162 let size = metadata.len();
163
164 if size > self.max_file_size {
166 self.stats
167 .files_skipped_size
168 .fetch_add(1, Ordering::Relaxed);
169 return Ok(None);
170 }
171
172 let relative_path = path
173 .strip_prefix(base_path)
174 .unwrap_or(path)
175 .to_string_lossy()
176 .to_string();
177
178 let (content_bytes, _use_mmap) = if size >= self.mmap_threshold {
180 self.stats.mmap_used.fetch_add(1, Ordering::Relaxed);
181 let mapped = MappedFile::open(path)?;
182
183 if mapped.is_binary() {
185 self.stats
186 .files_skipped_binary
187 .fetch_add(1, Ordering::Relaxed);
188 return Ok(None);
189 }
190
191 (mapped.as_bytes().to_vec(), true)
192 } else {
193 self.stats.regular_read_used.fetch_add(1, Ordering::Relaxed);
194 let content = std::fs::read(path)?;
195
196 if is_binary_content(&content) {
198 self.stats
199 .files_skipped_binary
200 .fetch_add(1, Ordering::Relaxed);
201 return Ok(None);
202 }
203
204 (content, false)
205 };
206
207 let content_str = match String::from_utf8(content_bytes) {
209 Ok(s) => s,
210 Err(_) => {
211 self.stats
212 .files_skipped_binary
213 .fetch_add(1, Ordering::Relaxed);
214 return Ok(None);
215 },
216 };
217
218 let token_counts = self.tokenizer.count_all(&content_str);
220
221 let lines = content_str.lines().count();
223
224 let language = detect_language(path);
226
227 self.stats.files_scanned.fetch_add(1, Ordering::Relaxed);
228 self.stats.bytes_read.fetch_add(size, Ordering::Relaxed);
229
230 Ok(Some(ScannedFile {
231 path: path.to_string_lossy().to_string(),
232 relative_path,
233 size_bytes: size,
234 lines,
235 token_counts,
236 language,
237 content: Some(content_str),
238 is_binary: false,
239 }))
240 }
241
242 pub fn scan_files_parallel(&self, paths: &[&Path], base_path: &Path) -> Vec<ScannedFile> {
244 paths
245 .par_iter()
246 .filter_map(|path| match self.scan_file(path, base_path) {
247 Ok(Some(file)) => Some(file),
248 Ok(None) => None,
249 Err(e) => {
250 log::debug!("Error scanning {:?}: {}", path, e);
251 None
252 },
253 })
254 .collect()
255 }
256
257 pub fn stats(&self) -> &ScanStats {
259 &self.stats
260 }
261
262 pub fn reset_stats(&self) {
264 self.stats.files_scanned.store(0, Ordering::Relaxed);
265 self.stats.bytes_read.store(0, Ordering::Relaxed);
266 self.stats.files_skipped_binary.store(0, Ordering::Relaxed);
267 self.stats.files_skipped_size.store(0, Ordering::Relaxed);
268 self.stats.mmap_used.store(0, Ordering::Relaxed);
269 self.stats.regular_read_used.store(0, Ordering::Relaxed);
270 }
271}
272
273impl Default for MmapScanner {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
/// Heuristic binary check over at most the first 8 KiB of `content`.
///
/// Content is considered binary when the sample contains a NUL byte, or
/// when more than 10% of sampled bytes are control characters other than
/// tab, newline, and carriage return. Empty content is not binary.
fn is_binary_content(content: &[u8]) -> bool {
    const SAMPLE_LIMIT: usize = 8192;
    let sample = &content[..content.len().min(SAMPLE_LIMIT)];

    // A NUL byte is a strong signal of binary data.
    if sample.iter().any(|&b| b == 0) {
        return true;
    }

    // Control bytes outside the common text whitespace set.
    let is_control = |b: u8| b < 32 && !matches!(b, b'\t' | b'\n' | b'\r');
    let control_count = sample.iter().copied().filter(|&b| is_control(b)).count();

    // More than 10% control bytes → treat as binary.
    control_count * 10 > sample.len()
}
295
/// Maps a file extension (case-insensitive) to a language identifier.
///
/// Returns `None` for extension-less paths, non-UTF-8 extensions, and
/// extensions that are not in the known table.
fn detect_language(path: &Path) -> Option<String> {
    let ext = path.extension()?.to_str()?.to_lowercase();

    let lang: Option<&'static str> = match ext.as_str() {
        "py" | "pyw" | "pyi" => Some("python"),
        "js" | "mjs" | "cjs" => Some("javascript"),
        "jsx" => Some("jsx"),
        "ts" | "mts" | "cts" => Some("typescript"),
        "tsx" => Some("tsx"),
        "rs" => Some("rust"),
        "go" => Some("go"),
        "java" => Some("java"),
        "c" | "h" => Some("c"),
        "cpp" | "hpp" | "cc" | "cxx" => Some("cpp"),
        "cs" => Some("csharp"),
        "rb" => Some("ruby"),
        "php" => Some("php"),
        "swift" => Some("swift"),
        "kt" | "kts" => Some("kotlin"),
        "scala" => Some("scala"),
        "sh" | "bash" => Some("bash"),
        "lua" => Some("lua"),
        "zig" => Some("zig"),
        "md" | "markdown" => Some("markdown"),
        "json" => Some("json"),
        "yaml" | "yml" => Some("yaml"),
        "toml" => Some("toml"),
        "xml" => Some("xml"),
        "html" | "htm" => Some("html"),
        "css" => Some("css"),
        "scss" | "sass" => Some("scss"),
        "sql" => Some("sql"),
        _ => None,
    };

    lang.map(ToOwned::to_owned)
}
334
/// Processes a memory-mapped file in newline-aligned chunks without
/// copying the whole file into owned memory.
pub struct StreamingProcessor {
    // Target chunk size in bytes; chunks are trimmed back to the last
    // newline where possible.
    chunk_size: usize,
    // Tokenizer used to count tokens for each emitted chunk.
    tokenizer: Tokenizer,
}
340
341impl StreamingProcessor {
342 pub fn new(chunk_size: usize) -> Self {
344 Self { chunk_size, tokenizer: Tokenizer::new() }
345 }
346
347 pub fn process_file<F>(&self, path: &Path, mut callback: F) -> io::Result<()>
349 where
350 F: FnMut(&str, usize, TokenCounts),
351 {
352 let mapped = MappedFile::open(path)?;
353
354 if mapped.is_binary() {
355 return Ok(());
356 }
357
358 let content = mapped
359 .as_str()
360 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
361
362 let mut offset = 0;
363 while offset < content.len() {
364 let end = (offset + self.chunk_size).min(content.len());
365
366 let chunk_end = if end < content.len() {
368 content[offset..end]
369 .rfind('\n')
370 .map(|i| offset + i + 1)
371 .unwrap_or(end)
372 } else {
373 end
374 };
375
376 let chunk = &content[offset..chunk_end];
377 let tokens = self.tokenizer.count_all(chunk);
378
379 callback(chunk, offset, tokens);
380
381 offset = chunk_end;
382 }
383
384 Ok(())
385 }
386
387 pub fn estimate_tokens(&self, path: &Path, model: TokenModel) -> io::Result<u32> {
389 let metadata = path.metadata()?;
390 let size = metadata.len();
391
392 let chars_per_token = model.chars_per_token();
394 Ok((size as f32 / chars_per_token).ceil() as u32)
395 }
396}
397
#[cfg(test)]
#[allow(clippy::str_to_string)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_mapped_file() {
        // Two text lines → two '\n' bytes counted.
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "Hello, World!").unwrap();
        writeln!(file, "Second line").unwrap();

        let mapped = MappedFile::open(file.path()).unwrap();

        assert_eq!(mapped.count_lines(), 2);
        assert!(!mapped.is_binary());
        assert!(!mapped.is_empty());
    }

    #[test]
    fn test_binary_detection() {
        // A NUL byte in the sample marks the file as binary.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&[0x00, 0x01, 0x02, 0x03]).unwrap();

        assert!(MappedFile::open(file.path()).unwrap().is_binary());
    }

    #[test]
    fn test_scanner() {
        let mut source = NamedTempFile::with_suffix(".py").unwrap();
        writeln!(source, "def hello():").unwrap();
        writeln!(source, "    print('hello')").unwrap();

        let base = source.path().parent().unwrap();
        let scanned = MmapScanner::new().scan_file(source.path(), base).unwrap();

        assert!(scanned.is_some());
        let file = scanned.unwrap();
        assert_eq!(file.language, Some("python".to_string()));
        assert!(file.token_counts.claude > 0);
    }

    #[test]
    fn test_detect_language() {
        let cases = [
            ("test.py", Some("python")),
            ("test.rs", Some("rust")),
            ("test.ts", Some("typescript")),
            ("test.unknown", None),
        ];
        for (name, expected) in cases {
            assert_eq!(detect_language(Path::new(name)), expected.map(String::from));
        }
    }

    #[test]
    fn test_streaming_processor() {
        let mut file = NamedTempFile::new().unwrap();
        for i in 0..100 {
            writeln!(file, "Line {}: Some content here", i).unwrap();
        }

        // ~2.6 KB of content with 256-byte chunks must split at least once.
        let mut chunk_count = 0;
        StreamingProcessor::new(256)
            .process_file(file.path(), |_chunk, _offset, _tokens| {
                chunk_count += 1;
            })
            .unwrap();

        assert!(chunk_count > 1);
    }
}
470}