1#![allow(dead_code)]
2use std::collections::HashMap;
3use std::io::{Read, Write};
4use std::path::PathBuf;
5use std::fs::File;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use rayon::prelude::*;
10use structopt::StructOpt;
11use git2::{Diff, DiffFormat, DiffOptions, Repository, Tree};
12use anyhow::{Context, Result};
13use thiserror::Error;
14use num_cpus;
15
16use crate::model::Model;
17use crate::profile;
18
/// Maximum number of strings a `StringPool` will retain for reuse.
const MAX_POOL_SIZE: usize = 1000;
/// Initial capacity for per-file diff content buffers.
const DEFAULT_STRING_CAPACITY: usize = 8192;
/// Minimum chunk size when splitting files across rayon workers.
const PARALLEL_CHUNK_SIZE: usize = 25;
/// Capacity hint for the per-path files map built from a diff.
const ESTIMATED_FILES_COUNT: usize = 100;

/// (path, content, token count) triples used by the parallel patch path.
type DiffData = Vec<(PathBuf, String, usize)>;
27
/// Errors surfaced by the commit-msg hook.
#[derive(Error, Debug)]
pub enum HookError {
  #[error("Failed to open repository")]
  OpenRepository,

  #[error("Failed to get patch")]
  GetPatch,

  #[error("Empty diff output")]
  EmptyDiffOutput,

  #[error("Failed to write commit message")]
  WriteCommitMessage,

  // Catch-all wrapper so any `anyhow::Error` can flow through `?`.
  #[error(transparent)]
  Anyhow(#[from] anyhow::Error)
}
46
/// Command-line arguments for the commit-msg hook.
#[derive(StructOpt, Debug)]
#[structopt(name = "commit-msg-hook", about = "A tool for generating commit messages.")]
pub struct Args {
  // Path to the COMMIT_EDITMSG file git passes to the hook.
  pub commit_msg_file: PathBuf,

  // Commit source type (second hook argument), when git provides one.
  #[structopt(short = "t", long = "type")]
  pub commit_type: Option<String>,

  // Commit SHA-1 (third hook argument), e.g. during `--amend`.
  #[structopt(short = "s", long = "sha1")]
  pub sha1: Option<String>
}
59
/// A reusable pool of `String` buffers to reduce allocation churn.
#[derive(Debug)]
struct StringPool {
  // Cleared strings available for reuse.
  strings: Vec<String>,
  // Capacity used when allocating a fresh string on a pool miss.
  capacity: usize
}
66
67impl StringPool {
68 fn new(capacity: usize) -> Self {
69 Self { strings: Vec::with_capacity(capacity), capacity }
70 }
71
72 fn get(&mut self) -> String {
73 self
74 .strings
75 .pop()
76 .unwrap_or_else(|| String::with_capacity(self.capacity))
77 }
78
79 fn put(&mut self, mut string: String) {
80 string.clear();
81 if self.strings.len() < MAX_POOL_SIZE {
82 self.strings.push(string);
83 }
84 }
85}
86
/// Read/write helpers for files addressed by path (the commit message file).
pub trait FilePath {
  /// Returns `true` when the file's contents are empty.
  fn is_empty(&self) -> Result<bool> {
    self.read().map(|s| s.is_empty())
  }

  /// Overwrites the file with `msg`.
  fn write(&self, msg: String) -> Result<()>;
  /// Reads the whole file into a string.
  fn read(&self) -> Result<String>;
}
96
97impl FilePath for PathBuf {
98 fn write(&self, msg: String) -> Result<()> {
99 File::create(self)?
100 .write_all(msg.as_bytes())
101 .map_err(Into::into)
102 }
103
104 fn read(&self) -> Result<String> {
105 let mut contents = String::new();
106 File::open(self)?.read_to_string(&mut contents)?;
107 Ok(contents)
108 }
109}
110
/// Extracts a usable file path from a git diff delta.
trait DiffDeltaPath {
  /// Best-effort path of the changed file (new side first, then old side).
  fn path(&self) -> PathBuf;
}
115
116impl DiffDeltaPath for git2::DiffDelta<'_> {
117 fn path(&self) -> PathBuf {
118 self
119 .new_file()
120 .path()
121 .or_else(|| self.old_file().path())
122 .map(PathBuf::from)
123 .unwrap_or_default()
124 }
125}
126
/// Lossy conversion of raw bytes into an owned `String`.
pub trait Utf8String {
  /// Decodes the bytes as UTF-8, replacing invalid sequences with U+FFFD.
  fn to_utf8(&self) -> String;
}

impl Utf8String for Vec<u8> {
  fn to_utf8(&self) -> String {
    // `from_utf8_lossy` already borrows (no copy) when the bytes are valid
    // UTF-8, so the previous explicit `str::from_utf8` fast path was
    // redundant.
    String::from_utf8_lossy(self).into_owned()
  }
}

impl Utf8String for [u8] {
  fn to_utf8(&self) -> String {
    String::from_utf8_lossy(self).into_owned()
  }
}
153
/// Converts a git diff into a token-budgeted patch string.
pub trait PatchDiff {
  /// Renders the diff as text bounded by roughly `max_token_count` tokens.
  fn to_patch(&self, max_token_count: usize, model: Model) -> Result<String>;
  /// Collects each file's changed lines, keyed by file path.
  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>>;
  /// Returns `true` when the diff touches no files.
  fn is_empty(&self) -> Result<bool>;
}
160
161impl PatchDiff for Diff<'_> {
162 fn to_patch(&self, max_tokens: usize, model: Model) -> Result<String> {
163 profile!("Generating patch diff");
164
165 let files = self.collect_diff_data()?;
167 if files.is_empty() {
168 return Ok(String::new());
169 }
170
171 if files.len() == 1 {
173 profile!("Single file fast path");
174 let (_, content) = files
175 .into_iter()
176 .next()
177 .ok_or_else(|| HookError::EmptyDiffOutput)?;
178
179 if content.len() < max_tokens * 4 {
181 return Ok(content);
183 }
184
185 return model.truncate(&content, max_tokens);
187 }
188
189 if files.len() <= 5 && max_tokens > 500 {
191 profile!("Small diff fast path");
192 let mut result = String::new();
193 let files_clone = files.clone(); for (i, (_, content)) in files.into_iter().enumerate() {
197 if i > 0 {
198 result.push('\n');
199 }
200 let limit = (max_tokens / files_clone.len()) * 4; let truncated = if content.len() > limit {
203 let truncated = content.chars().take(limit).collect::<String>();
204 let last_space = truncated
206 .rfind(char::is_whitespace)
207 .unwrap_or(truncated.len());
208 if last_space > 0 {
209 truncated[..last_space].to_string()
210 } else {
211 truncated
212 }
213 } else {
214 content
215 };
216 result.push_str(&truncated);
217 }
218
219 return Ok(result);
220 }
221
222 if files.len() <= 20 {
224 profile!("Medium diff optimized path");
225
226 let mut files_vec: Vec<(PathBuf, String, usize)> = files
228 .into_iter()
229 .map(|(path, content)| {
230 let estimated_tokens = content.len() / 4;
232 (path, content, estimated_tokens)
233 })
234 .collect();
235
236 files_vec.sort_by_key(|(_, _, count)| *count);
238
239 let mut result = String::new();
241 let mut tokens_used = 0;
242
243 for (i, (_, content, estimated_tokens)) in files_vec.into_iter().enumerate() {
244 if tokens_used >= max_tokens {
245 break;
246 }
247
248 if i > 0 {
249 result.push('\n');
250 }
251
252 let tokens_left = max_tokens.saturating_sub(tokens_used);
253 let tokens_for_file = estimated_tokens.min(tokens_left);
254
255 let processed_content = if estimated_tokens > tokens_for_file {
257 let char_limit = tokens_for_file * 4;
259 let truncated: String = content.chars().take(char_limit).collect();
260 truncated
261 } else {
262 content
263 };
264
265 result.push_str(&processed_content);
266 tokens_used += tokens_for_file;
267 }
268
269 return Ok(result);
270 }
271
272 profile!("Converting files to vector");
274 let files_vec: Vec<_> = files.into_iter().collect();
275 let total_files = files_vec.len();
276
277 let thread_pool = rayon::ThreadPoolBuilder::new()
279 .num_threads(num_cpus::get())
280 .build()
281 .context("Failed to create thread pool")?;
282
283 profile!("Parallel token counting");
284 let chunk_size = (total_files / num_cpus::get().max(1)).max(10);
286 let files_with_tokens: DiffData = thread_pool.install(|| {
287 files_vec
288 .chunks(chunk_size)
289 .collect::<Vec<_>>()
290 .into_par_iter()
291 .flat_map(|chunk| {
292 chunk
293 .iter()
294 .map(|(path, content)| {
295 let token_count = model.count_tokens(content).unwrap_or_default();
296 (path.clone(), content.clone(), token_count)
297 })
298 .collect::<Vec<_>>()
299 })
300 .collect()
301 });
302
303 profile!("Sorting files by token count");
305 let sorted_files = if total_files > 500 {
306 files_with_tokens
307 } else {
308 let mut sorted = files_with_tokens;
309 sorted.sort_by_key(|(_, _, count)| *count);
310 sorted
311 };
312
313 let remaining_tokens = Arc::new(AtomicUsize::new(max_tokens));
315 let results = Arc::new(parking_lot::RwLock::new(Vec::with_capacity(total_files)));
316 let processed_files = Arc::new(AtomicUsize::new(0));
317
318 let adaptive_chunk_size = (total_files / (2 * num_cpus::get().max(1))).max(PARALLEL_CHUNK_SIZE);
320
321 let chunks: Vec<_> = sorted_files
322 .chunks(adaptive_chunk_size)
323 .map(|chunk| chunk.to_vec())
324 .collect();
325
326 let model = Arc::new(model);
327
328 profile!("Parallel chunk processing");
329 thread_pool.install(|| {
330 chunks
331 .par_iter()
332 .try_for_each(|chunk| process_chunk(chunk, &model, total_files, &processed_files, &remaining_tokens, &results))
333 })?;
334
335 profile!("Combining results");
337 let results_guard = results.read();
338
339 if results_guard.is_empty() {
341 return Ok(String::new());
342 }
343
344 let total_len = results_guard
346 .iter()
347 .map(|(_, content): &(PathBuf, String)| content.len())
348 .sum::<usize>();
349 let mut final_result = String::with_capacity(total_len + results_guard.len());
350
351 for (i, (_, content)) in results_guard.iter().enumerate() {
352 if i > 0 {
353 final_result.push('\n');
354 }
355 final_result.push_str(content);
356 }
357
358 Ok(final_result)
359 }
360
361 fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>> {
362 profile!("Processing diff changes");
363
364 let mut files = HashMap::with_capacity(ESTIMATED_FILES_COUNT);
366
367 const BUFFER_SIZE: usize = 64; static CONTEXT_PREFIX: &str = "context: ";
370
371 thread_local! {
373 static PATH_CACHE: std::cell::RefCell<HashMap<PathBuf, ()>> =
374 std::cell::RefCell::new(HashMap::with_capacity(20));
375 }
376
377 self.print(DiffFormat::Patch, |diff, _hunk, line| {
379 let path = PATH_CACHE.with(|cache| {
381 let mut cache = cache.borrow_mut();
382 let new_path = diff.path();
383 if let Some(existing_path) = cache.keys().find(|p| *p == &new_path) {
384 existing_path.clone()
385 } else {
386 cache.insert(new_path.clone(), ());
387 new_path
388 }
389 });
390
391 let content = if let Ok(s) = std::str::from_utf8(line.content()) {
393 s.to_string()
394 } else {
395 line.content().to_utf8()
397 };
398
399 match line.origin() {
401 '+' | '-' => {
402 let entry = files
404 .entry(path)
405 .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
406 entry.push_str(&content);
407 }
408 _ => {
409 let mut buf = String::with_capacity(CONTEXT_PREFIX.len() + content.len());
411 buf.push_str(CONTEXT_PREFIX);
412 buf.push_str(&content);
413
414 let entry = files
415 .entry(path)
416 .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
417 entry.push_str(&buf);
418 }
419 }
420
421 true
422 })?;
423
424 Ok(files)
425 }
426
427 fn is_empty(&self) -> Result<bool> {
428 let mut has_changes = false;
429
430 self.foreach(
431 &mut |_file, _progress| {
432 has_changes = true;
433 true
434 },
435 None,
436 None,
437 None
438 )?;
439
440 Ok(!has_changes)
441 }
442}
443
/// Claims tokens from the shared budget for each file in `chunk` and appends
/// the (possibly truncated) file contents to `result_chunks`.
///
/// `remaining_tokens` is a global budget shared across rayon workers; it is
/// claimed with compare-and-swap loops (`fetch_update`) so two threads can
/// never spend the same tokens twice. `processed_files` tracks global
/// progress so later files get a fair share of whatever budget remains.
fn process_chunk(
  chunk: &[(PathBuf, String, usize)], model: &Arc<Model>, total_files: usize, processed_files: &AtomicUsize,
  remaining_tokens: &AtomicUsize, result_chunks: &Arc<parking_lot::RwLock<Vec<(PathBuf, String)>>>
) -> Result<()> {
  profile!("Processing chunk");
  if chunk.is_empty() {
    return Ok(());
  }

  // Another worker may already have drained the budget.
  let total_remaining = remaining_tokens.load(Ordering::Acquire);
  if total_remaining == 0 {
    return Ok(());
  }

  // Fast path for tiny chunks: try to claim the whole chunk's token count in
  // a single CAS and, on success, copy every file through unmodified.
  if chunk.len() <= 3 {
    let total_token_count = chunk.iter().map(|(_, _, count)| *count).sum::<usize>();
    if total_token_count <= total_remaining {
      if remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
          if current >= total_token_count {
            Some(current - total_token_count)
          } else {
            None
          }
        })
        .is_ok()
      {
        processed_files.fetch_add(chunk.len(), Ordering::AcqRel);

        let chunk_results: Vec<_> = chunk
          .iter()
          .map(|(path, content, _)| (path.clone(), content.clone()))
          .collect();

        if !chunk_results.is_empty() {
          result_chunks.write().extend(chunk_results);
        }

        return Ok(());
      }
    }
  }

  // General path: claim budget file by file.
  let mut chunk_results = Vec::with_capacity(chunk.len());
  let mut local_processed = 0;

  for (path, content, token_count) in chunk {
    local_processed += 1;

    // Stop as soon as the global budget is exhausted.
    let current_remaining = remaining_tokens.load(Ordering::Acquire);
    if current_remaining == 0 {
      break;
    }

    let token_count = *token_count;

    // Small files (<= 100 tokens) that fit are claimed whole — cheaper than
    // computing a fair-share allocation for them.
    if token_count <= 100
      && token_count <= current_remaining
      && remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
          if current >= token_count {
            Some(current - token_count)
          } else {
            None
          }
        })
        .is_ok()
    {
      chunk_results.push((path.clone(), content.clone()));
      continue;
    }

    // Fair-share allocation: spread the remaining budget over the files
    // still queued globally (best-effort — the counters lag slightly across
    // threads, which only affects fairness, not correctness).
    let current_file_num = processed_files.load(Ordering::Acquire);
    let files_remaining = total_files.saturating_sub(current_file_num + local_processed);

    let max_tokens_per_file = if files_remaining > 0 {
      current_remaining.saturating_div(files_remaining)
    } else {
      current_remaining
    };

    if max_tokens_per_file == 0 {
      continue;
    }

    let allocated_tokens = token_count.min(max_tokens_per_file);

    if remaining_tokens
      .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
        if current >= allocated_tokens {
          Some(current - allocated_tokens)
        } else {
          None
        }
      })
      .is_ok()
    {
      if token_count <= allocated_tokens {
        // Fits entirely within its allocation — keep verbatim.
        chunk_results.push((path.clone(), content.clone()));
      } else {
        if content.len() < 2000 || allocated_tokens > 500 {
          // Cheap approximate truncation (~4 chars per token).
          let char_limit = allocated_tokens * 4;
          let truncated: String = content.chars().take(char_limit).collect();
          chunk_results.push((path.clone(), truncated));
        } else {
          // Accurate (but slower) model-based truncation.
          let truncated = model.truncate(content, allocated_tokens)?;
          chunk_results.push((path.clone(), truncated));
        }
      }
    }
  }

  if local_processed > 0 {
    processed_files.fetch_add(local_processed, Ordering::AcqRel);
  }

  if !chunk_results.is_empty() {
    result_chunks.write().extend(chunk_results);
  }

  Ok(())
}
586
/// Repository-level helpers for producing token-bounded patch strings.
pub trait PatchRepository {
  /// Generates a patch for the staged changes relative to `tree`.
  fn to_patch(&self, tree: Option<Tree<'_>>, max_token_count: usize, model: Model) -> Result<String>;
  /// Diff of `tree` against the working directory plus index.
  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
  /// Diff of `tree` against the index only (the staged changes).
  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
  /// Configures options for working-directory diffs.
  fn configure_diff_options(&self, opts: &mut DiffOptions);
  /// Configures options for commit (index) diffs.
  fn configure_commit_diff_options(&self, opts: &mut DiffOptions);
}
594
595impl PatchRepository for Repository {
596 fn to_patch(&self, tree: Option<Tree>, max_token_count: usize, model: Model) -> Result<String> {
597 profile!("Repository patch generation");
598 self.to_commit_diff(tree)?.to_patch(max_token_count, model)
599 }
600
601 fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
602 profile!("Git diff generation");
603 let mut opts = DiffOptions::new();
604 self.configure_diff_options(&mut opts);
605
606 match tree {
607 Some(tree) => {
608 self.diff_tree_to_workdir_with_index(Some(&tree), Some(&mut opts))
610 }
611 None => {
612 let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
614 self.diff_tree_to_workdir_with_index(Some(&empty_tree), Some(&mut opts))
616 }
617 }
618 .context("Failed to get diff")
619 }
620
621 fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
622 profile!("Git commit diff generation");
623 let mut opts = DiffOptions::new();
624 self.configure_commit_diff_options(&mut opts);
625
626 match tree {
627 Some(tree) => {
628 self.diff_tree_to_index(Some(&tree), None, Some(&mut opts))
630 }
631 None => {
632 let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
634 self.diff_tree_to_index(Some(&empty_tree), None, Some(&mut opts))
636 }
637 }
638 .context("Failed to get diff")
639 }
640
641 fn configure_diff_options(&self, opts: &mut DiffOptions) {
642 opts
643 .ignore_whitespace_change(true)
644 .recurse_untracked_dirs(true)
645 .recurse_ignored_dirs(false)
646 .ignore_whitespace_eol(true)
647 .ignore_blank_lines(true)
648 .include_untracked(true)
649 .ignore_whitespace(true)
650 .indent_heuristic(false)
651 .ignore_submodules(true)
652 .include_ignored(false)
653 .interhunk_lines(0)
654 .context_lines(0)
655 .patience(true)
656 .minimal(true);
657 }
658
659 fn configure_commit_diff_options(&self, opts: &mut DiffOptions) {
660 opts
661 .ignore_whitespace_change(false)
662 .recurse_untracked_dirs(false)
663 .recurse_ignored_dirs(false)
664 .ignore_whitespace_eol(true)
665 .ignore_blank_lines(true)
666 .include_untracked(false)
667 .ignore_whitespace(true)
668 .indent_heuristic(false)
669 .ignore_submodules(true)
670 .include_ignored(false)
671 .interhunk_lines(0)
672 .context_lines(0)
673 .patience(true)
674 .minimal(true);
675 }
676}
677
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_string_pool_new() {
    let pool = StringPool::new(100);
    assert_eq!(pool.strings.len(), 0);
    assert_eq!(pool.capacity, 100);
  }

  #[test]
  fn test_string_pool_put_and_get() {
    let mut pool = StringPool::new(10);
    let mut s1 = String::with_capacity(10);
    s1.push_str("test");
    pool.put(s1);

    assert_eq!(pool.strings.len(), 1);

    // The pooled string comes back cleared but with its capacity intact.
    let s2 = pool.get();
    assert_eq!(s2.capacity(), 10);
    assert_eq!(s2.len(), 0);
    assert_eq!(pool.strings.len(), 0);
  }

  #[test]
  fn test_string_pool_limit() {
    let mut pool = StringPool::new(10);

    // Overfill the pool: the previous version only put 150 strings, which
    // never reached MAX_POOL_SIZE (1000), so the cap was never exercised.
    for _ in 0..(MAX_POOL_SIZE + 50) {
      pool.put(String::with_capacity(10));
    }

    assert_eq!(pool.strings.len(), MAX_POOL_SIZE);
  }
}
715
// NOTE(review): this test sits outside the `tests` module above — consider
// moving it there for consistency.
#[test]
fn test_string_pool_get() {
  // A fresh pool has nothing pooled, so `get` must allocate a new string
  // with the pool's configured capacity.
  let mut pool = StringPool::new(10);
  let fresh = pool.get();
  assert_eq!(fresh.capacity(), 10);
  assert!(fresh.is_empty());
}