1use std::collections::HashMap;
2use std::io::{Read, Write};
3use std::path::PathBuf;
4use std::fs::File;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8use rayon::prelude::*;
9use git2::{Diff, DiffFormat, DiffOptions, Repository, Tree};
10use anyhow::{Context, Result};
11use thiserror::Error;
12use num_cpus;
13
14use crate::model::Model;
15use crate::profile;
16
/// Initial capacity (in bytes) for each per-file patch buffer, to limit
/// reallocation while diff lines are appended.
const DEFAULT_STRING_CAPACITY: usize = 8192;
/// Minimum number of files handed to each rayon task in the large-diff path.
const PARALLEL_CHUNK_SIZE: usize = 25;
/// Pre-sizing hint for the path -> patch map; not a hard limit.
const ESTIMATED_FILES_COUNT: usize = 100;

/// `(path, patch text, token count)` triples produced by parallel token counting.
type DiffData = Vec<(PathBuf, String, usize)>;
25
/// Errors surfaced by the commit-message hook pipeline.
#[derive(Error, Debug)]
pub enum HookError {
  /// The git repository could not be opened.
  #[error("Failed to open repository")]
  OpenRepository,

  /// Generating the patch from the repository diff failed.
  #[error("Failed to get patch")]
  GetPatch,

  /// The diff produced no textual output.
  #[error("Empty diff output")]
  EmptyDiffOutput,

  /// Writing the generated commit message to the hook file failed.
  #[error("Failed to write commit message")]
  WriteCommitMessage,

  /// Wrapper for any other error propagated via `anyhow`.
  #[error(transparent)]
  Anyhow(#[from] anyhow::Error)
}
44
45pub trait FilePath {
47 fn is_empty(&self) -> Result<bool> {
48 self.read().map(|s| s.is_empty())
49 }
50
51 fn write(&self, msg: String) -> Result<()>;
52 fn read(&self) -> Result<String>;
53}
54
55impl FilePath for PathBuf {
56 fn write(&self, msg: String) -> Result<()> {
57 File::create(self)?
58 .write_all(msg.as_bytes())
59 .map_err(Into::into)
60 }
61
62 fn read(&self) -> Result<String> {
63 let mut contents = String::new();
64 File::open(self)?.read_to_string(&mut contents)?;
65 Ok(contents)
66 }
67}
68
/// Extracts a usable file path from a git diff delta.
trait DiffDeltaPath {
  /// The path of the file this delta refers to.
  fn path(&self) -> PathBuf;
}
73
74impl DiffDeltaPath for git2::DiffDelta<'_> {
75 fn path(&self) -> PathBuf {
76 self
77 .new_file()
78 .path()
79 .or_else(|| self.old_file().path())
80 .map(PathBuf::from)
81 .unwrap_or_default()
82 }
83}
84
/// Lossy conversion of raw bytes into an owned UTF-8 string.
pub trait Utf8String {
  /// Converts the bytes to a `String`, replacing invalid UTF-8 sequences
  /// with U+FFFD REPLACEMENT CHARACTER.
  fn to_utf8(&self) -> String;
}

impl Utf8String for Vec<u8> {
  fn to_utf8(&self) -> String {
    // Delegate to the slice implementation to avoid duplicated logic.
    self.as_slice().to_utf8()
  }
}

impl Utf8String for [u8] {
  fn to_utf8(&self) -> String {
    // `from_utf8_lossy` already performs the UTF-8 validity check and
    // borrows when the input is valid, so the previous explicit
    // `str::from_utf8` fast path was redundant.
    String::from_utf8_lossy(self).into_owned()
  }
}
111
/// Conversion of a libgit2 diff into a token-budgeted patch string.
pub trait PatchDiff {
  /// Renders the diff as text, truncating content so it fits within
  /// `max_token_count` tokens as counted by `model`.
  fn to_patch(&self, max_token_count: usize, model: Model) -> Result<String>;

  /// Collects per-file patch text keyed by file path.
  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>>;

  /// Returns `true` when the diff contains no file deltas.
  fn is_empty(&self) -> Result<bool>;
}
118
119impl PatchDiff for Diff<'_> {
120 fn to_patch(&self, max_tokens: usize, model: Model) -> Result<String> {
121 profile!("Generating patch diff");
122
123 let files = self.collect_diff_data()?;
125 if files.is_empty() {
126 return Ok(String::new());
127 }
128
129 if files.len() == 1 {
131 profile!("Single file fast path");
132 let (_, content) = files
133 .into_iter()
134 .next()
135 .ok_or_else(|| HookError::EmptyDiffOutput)?;
136
137 if content.len() < max_tokens * 4 {
139 return Ok(content);
141 }
142
143 return model.truncate(&content, max_tokens);
145 }
146
147 if files.len() <= 5 && max_tokens > 500 {
149 profile!("Small diff fast path");
150 let mut result = String::new();
151 let files_clone = files.clone(); for (i, (_, content)) in files.into_iter().enumerate() {
155 if i > 0 {
156 result.push('\n');
157 }
158 let limit = (max_tokens / files_clone.len()) * 4; let truncated = if content.len() > limit {
161 let truncated = content.chars().take(limit).collect::<String>();
162 let last_space = truncated
164 .rfind(char::is_whitespace)
165 .unwrap_or(truncated.len());
166 if last_space > 0 {
167 truncated[..last_space].to_string()
168 } else {
169 truncated
170 }
171 } else {
172 content
173 };
174 result.push_str(&truncated);
175 }
176
177 return Ok(result);
178 }
179
180 if files.len() <= 20 {
182 profile!("Medium diff optimized path");
183
184 let mut files_vec: Vec<(PathBuf, String, usize)> = files
186 .into_iter()
187 .map(|(path, content)| {
188 let estimated_tokens = content.len() / 4;
190 (path, content, estimated_tokens)
191 })
192 .collect();
193
194 files_vec.sort_by_key(|(_, _, count)| *count);
196
197 let mut result = String::new();
199 let mut tokens_used = 0;
200
201 for (i, (_, content, estimated_tokens)) in files_vec.into_iter().enumerate() {
202 if tokens_used >= max_tokens {
203 break;
204 }
205
206 if i > 0 {
207 result.push('\n');
208 }
209
210 let tokens_left = max_tokens.saturating_sub(tokens_used);
211 let tokens_for_file = estimated_tokens.min(tokens_left);
212
213 let processed_content = if estimated_tokens > tokens_for_file {
215 let char_limit = tokens_for_file * 4;
217 let truncated: String = content.chars().take(char_limit).collect();
218 truncated
219 } else {
220 content
221 };
222
223 result.push_str(&processed_content);
224 tokens_used += tokens_for_file;
225 }
226
227 return Ok(result);
228 }
229
230 profile!("Converting files to vector");
232 let files_vec: Vec<_> = files.into_iter().collect();
233 let total_files = files_vec.len();
234
235 let thread_pool = rayon::ThreadPoolBuilder::new()
237 .num_threads(num_cpus::get())
238 .build()
239 .context("Failed to create thread pool")?;
240
241 profile!("Parallel token counting");
242 let chunk_size = (total_files / num_cpus::get().max(1)).max(10);
244 let files_with_tokens: DiffData = thread_pool.install(|| {
245 files_vec
246 .chunks(chunk_size)
247 .collect::<Vec<_>>()
248 .into_par_iter()
249 .flat_map(|chunk| {
250 chunk
251 .iter()
252 .map(|(path, content)| {
253 let token_count = model.count_tokens(content).unwrap_or_default();
254 (path.clone(), content.clone(), token_count)
255 })
256 .collect::<Vec<_>>()
257 })
258 .collect()
259 });
260
261 profile!("Sorting files by token count");
263 let sorted_files = if total_files > 500 {
264 files_with_tokens
265 } else {
266 let mut sorted = files_with_tokens;
267 sorted.sort_by_key(|(_, _, count)| *count);
268 sorted
269 };
270
271 let remaining_tokens = Arc::new(AtomicUsize::new(max_tokens));
273 let results = Arc::new(parking_lot::RwLock::new(Vec::with_capacity(total_files)));
274 let processed_files = Arc::new(AtomicUsize::new(0));
275
276 let adaptive_chunk_size = (total_files / (2 * num_cpus::get().max(1))).max(PARALLEL_CHUNK_SIZE);
278
279 let chunks: Vec<_> = sorted_files
280 .chunks(adaptive_chunk_size)
281 .map(|chunk| chunk.to_vec())
282 .collect();
283
284 let model = Arc::new(model);
285
286 profile!("Parallel chunk processing");
287 thread_pool.install(|| {
288 chunks
289 .par_iter()
290 .try_for_each(|chunk| process_chunk(chunk, &model, total_files, &processed_files, &remaining_tokens, &results))
291 })?;
292
293 profile!("Combining results");
295 let results_guard = results.read();
296
297 if results_guard.is_empty() {
299 return Ok(String::new());
300 }
301
302 let total_len = results_guard
304 .iter()
305 .map(|(_, content): &(PathBuf, String)| content.len())
306 .sum::<usize>();
307 let mut final_result = String::with_capacity(total_len + results_guard.len());
308
309 for (i, (_, content)) in results_guard.iter().enumerate() {
310 if i > 0 {
311 final_result.push('\n');
312 }
313 final_result.push_str(content);
314 }
315
316 Ok(final_result)
317 }
318
319 fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>> {
320 profile!("Processing diff changes");
321
322 let mut files = HashMap::with_capacity(ESTIMATED_FILES_COUNT);
324
325 thread_local! {
327 static PATH_CACHE: std::cell::RefCell<HashMap<PathBuf, ()>> =
328 std::cell::RefCell::new(HashMap::with_capacity(20));
329 }
330
331 self.print(DiffFormat::Patch, |diff, _hunk, line| {
333 let path = PATH_CACHE.with(|cache| {
335 let mut cache = cache.borrow_mut();
336 let new_path = diff.path();
337 if let Some(existing_path) = cache.keys().find(|p| *p == &new_path) {
338 existing_path.clone()
339 } else {
340 cache.insert(new_path.clone(), ());
341 new_path
342 }
343 });
344
345 let content = if let Ok(s) = std::str::from_utf8(line.content()) {
347 s.to_string()
348 } else {
349 line.content().to_utf8()
351 };
352
353 match line.origin() {
355 '+' | '-' | ' ' => {
356 let entry = files
358 .entry(path)
359 .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
360
361 match line.origin() {
363 '+' => entry.push('+'),
364 '-' => entry.push('-'),
365 ' ' => entry.push(' '),
366 _ => {}
367 }
368 entry.push_str(&content);
369 }
370 _ => {
371 log::trace!("Skipping diff line with origin: {:?}", line.origin());
374 }
375 }
376
377 true
378 })?;
379
380 Ok(files)
381 }
382
383 fn is_empty(&self) -> Result<bool> {
384 let mut has_changes = false;
385
386 self.foreach(
387 &mut |_file, _progress| {
388 has_changes = true;
389 true
390 },
391 None,
392 None,
393 None
394 )?;
395
396 Ok(!has_changes)
397 }
398}
399
/// Processes one chunk of `(path, content, token_count)` triples, appending
/// whatever fits into the shared token budget to `result_chunks`.
///
/// `remaining_tokens` is the global budget; reservations are made with
/// lock-free CAS loops (`fetch_update`) so concurrent chunks never overspend.
/// `processed_files` counts files handled across all chunks and is used to
/// split the leftover budget fairly over the files still pending.
fn process_chunk(
  chunk: &[(PathBuf, String, usize)], model: &Arc<Model>, total_files: usize, processed_files: &AtomicUsize,
  remaining_tokens: &AtomicUsize, result_chunks: &Arc<parking_lot::RwLock<Vec<(PathBuf, String)>>>
) -> Result<()> {
  profile!("Processing chunk");
  if chunk.is_empty() {
    return Ok(());
  }

  // Budget already exhausted by other chunks — nothing to do.
  let total_remaining = remaining_tokens.load(Ordering::Acquire);
  if total_remaining == 0 {
    return Ok(());
  }

  // Fast path: a tiny chunk that fits wholly within the remaining budget can
  // be claimed with a single CAS and copied over without any truncation.
  if chunk.len() <= 3 {
    let total_token_count = chunk.iter().map(|(_, _, count)| *count).sum::<usize>();
    if total_token_count <= total_remaining {
      if remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
          if current >= total_token_count {
            Some(current - total_token_count)
          } else {
            None
          }
        })
        .is_ok()
      {
        processed_files.fetch_add(chunk.len(), Ordering::AcqRel);

        let chunk_results: Vec<_> = chunk
          .iter()
          .map(|(path, content, _)| (path.clone(), content.clone()))
          .collect();

        if !chunk_results.is_empty() {
          result_chunks.write().extend(chunk_results);
        }

        return Ok(());
      }
      // CAS failed: another chunk drained the budget; fall through to the
      // per-file path below.
    }
  }

  let mut chunk_results = Vec::with_capacity(chunk.len());
  let mut local_processed = 0;

  for (path, content, token_count) in chunk {
    local_processed += 1;

    let current_remaining = remaining_tokens.load(Ordering::Acquire);
    if current_remaining == 0 {
      break;
    }

    let token_count = *token_count;

    // Small files (<= 100 tokens) are taken verbatim when the budget allows;
    // the CAS either reserves exactly `token_count` tokens or fails if a
    // concurrent chunk got there first.
    if token_count <= 100
      && token_count <= current_remaining
      && remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
          if current >= token_count {
            Some(current - token_count)
          } else {
            None
          }
        })
        .is_ok()
    {
      chunk_results.push((path.clone(), content.clone()));
      continue;
    }

    // Larger files get a fair share: remaining budget divided by the number
    // of files (globally) still waiting to be processed.
    let current_file_num = processed_files.load(Ordering::Acquire);
    let files_remaining = total_files.saturating_sub(current_file_num + local_processed);

    let max_tokens_per_file = if files_remaining > 0 {
      current_remaining.saturating_div(files_remaining)
    } else {
      current_remaining
    };

    if max_tokens_per_file == 0 {
      continue;
    }

    let allocated_tokens = token_count.min(max_tokens_per_file);

    // Reserve the allocation; on failure the file is dropped (budget gone).
    if remaining_tokens
      .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
        if current >= allocated_tokens {
          Some(current - allocated_tokens)
        } else {
          None
        }
      })
      .is_ok()
    {
      if token_count <= allocated_tokens {
        // Whole file fits within its allocation.
        chunk_results.push((path.clone(), content.clone()));
      } else {
        // Cheap ~4-chars-per-token truncation for small contents or
        // generous budgets; otherwise use the model's exact truncation.
        if content.len() < 2000 || allocated_tokens > 500 {
          let char_limit = allocated_tokens * 4;
          let truncated: String = content.chars().take(char_limit).collect();
          chunk_results.push((path.clone(), truncated));
        } else {
          let truncated = model.truncate(content, allocated_tokens)?;
          chunk_results.push((path.clone(), truncated));
        }
      }
    }
  }

  if local_processed > 0 {
    processed_files.fetch_add(local_processed, Ordering::AcqRel);
  }

  if !chunk_results.is_empty() {
    result_chunks.write().extend(chunk_results);
  }

  Ok(())
}
542
/// Repository-level helpers for producing diffs and token-bounded patches.
pub trait PatchRepository {
  /// Generates a token-bounded patch of the changes relative to `tree`.
  fn to_patch(&self, tree: Option<Tree<'_>>, max_token_count: usize, model: Model) -> Result<String>;

  /// Diff of `tree` (or an empty tree when `None`) against the working
  /// directory plus index.
  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;

  /// Diff of `tree` (or an empty tree when `None`) against the index, via
  /// `diff_tree_to_index`.
  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;

  /// Applies the workdir-diff options (untracked included, whitespace
  /// aggressively ignored, zero context lines).
  fn configure_diff_options(&self, opts: &mut DiffOptions);

  /// Applies the commit-diff options (untracked excluded).
  fn configure_commit_diff_options(&self, opts: &mut DiffOptions);
}
550
551impl PatchRepository for Repository {
552 fn to_patch(&self, tree: Option<Tree>, max_token_count: usize, model: Model) -> Result<String> {
553 profile!("Repository patch generation");
554 self.to_commit_diff(tree)?.to_patch(max_token_count, model)
555 }
556
557 fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
558 profile!("Git diff generation");
559 let mut opts = DiffOptions::new();
560 self.configure_diff_options(&mut opts);
561
562 match tree {
563 Some(tree) => {
564 self.diff_tree_to_workdir_with_index(Some(&tree), Some(&mut opts))
566 }
567 None => {
568 let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
570 self.diff_tree_to_workdir_with_index(Some(&empty_tree), Some(&mut opts))
572 }
573 }
574 .context("Failed to get diff")
575 }
576
577 fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
578 profile!("Git commit diff generation");
579 let mut opts = DiffOptions::new();
580 self.configure_commit_diff_options(&mut opts);
581
582 match tree {
583 Some(tree) => {
584 self.diff_tree_to_index(Some(&tree), None, Some(&mut opts))
586 }
587 None => {
588 let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
590 self.diff_tree_to_index(Some(&empty_tree), None, Some(&mut opts))
592 }
593 }
594 .context("Failed to get diff")
595 }
596
597 fn configure_diff_options(&self, opts: &mut DiffOptions) {
598 opts
599 .ignore_whitespace_change(true)
600 .recurse_untracked_dirs(true)
601 .recurse_ignored_dirs(false)
602 .ignore_whitespace_eol(true)
603 .ignore_blank_lines(true)
604 .include_untracked(true)
605 .ignore_whitespace(true)
606 .indent_heuristic(false)
607 .ignore_submodules(true)
608 .include_ignored(false)
609 .interhunk_lines(0)
610 .context_lines(0)
611 .patience(true)
612 .minimal(true);
613 }
614
615 fn configure_commit_diff_options(&self, opts: &mut DiffOptions) {
616 opts
617 .ignore_whitespace_change(false)
618 .recurse_untracked_dirs(false)
619 .recurse_ignored_dirs(false)
620 .ignore_whitespace_eol(true)
621 .ignore_blank_lines(true)
622 .include_untracked(false)
623 .ignore_whitespace(true)
624 .indent_heuristic(false)
625 .ignore_submodules(true)
626 .include_ignored(false)
627 .interhunk_lines(0)
628 .context_lines(0)
629 .patience(true)
630 .minimal(true);
631 }
632}