ai/
hook.rs

1#![allow(dead_code)]
2use std::collections::HashMap;
3use std::io::{Read, Write};
4use std::path::PathBuf;
5use std::fs::File;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use rayon::prelude::*;
10use structopt::StructOpt;
11use git2::{Diff, DiffFormat, DiffOptions, Repository, Tree};
12use anyhow::{Context, Result};
13use thiserror::Error;
14use num_cpus;
15
16use crate::model::Model;
17use crate::profile;
18
// Constants

/// Upper bound on how many cleared strings a `StringPool` keeps for reuse.
const MAX_POOL_SIZE: usize = 1_000;
/// Initial capacity reserved for each file's accumulated diff text.
const DEFAULT_STRING_CAPACITY: usize = 8 * 1024;
/// Lower bound on the number of files handed to one parallel token-allocation chunk.
const PARALLEL_CHUNK_SIZE: usize = 25;
/// Expected number of changed files, used to pre-size the per-file map.
const ESTIMATED_FILES_COUNT: usize = 100;

// Types

/// One `(path, diff text, token count)` entry per changed file.
type DiffData = Vec<(PathBuf, String, usize)>;
27
28// Error definitions
29#[derive(Error, Debug)]
30pub enum HookError {
31  #[error("Failed to open repository")]
32  OpenRepository,
33
34  #[error("Failed to get patch")]
35  GetPatch,
36
37  #[error("Empty diff output")]
38  EmptyDiffOutput,
39
40  #[error("Failed to write commit message")]
41  WriteCommitMessage,
42
43  #[error(transparent)]
44  Anyhow(#[from] anyhow::Error)
45}
46
// CLI Arguments
// Command-line arguments of the hook, parsed by the structopt derive below.
// Plain `//` comments are deliberate: structopt uses `///` doc comments as --help text,
// so adding them here would change the CLI output.
#[derive(StructOpt, Debug)]
#[structopt(name = "commit-msg-hook", about = "A tool for generating commit messages.")]
pub struct Args {
  // Path of the commit message file. It has no short/long flag, so presumably it is a
  // positional argument — TODO confirm against how git invokes the hook.
  pub commit_msg_file: PathBuf,

  // Optional `-t` / `--type` value.
  // NOTE(review): presumably the commit source git passes to prepare-commit-msg — confirm.
  #[structopt(short = "t", long = "type")]
  pub commit_type: Option<String>,

  // Optional `-s` / `--sha1` value, presumably a commit SHA — TODO confirm.
  #[structopt(short = "s", long = "sha1")]
  pub sha1: Option<String>
}
59
60// Memory management
61#[derive(Debug)]
62struct StringPool {
63  strings:  Vec<String>,
64  capacity: usize
65}
66
67impl StringPool {
68  fn new(capacity: usize) -> Self {
69    Self { strings: Vec::with_capacity(capacity), capacity }
70  }
71
72  fn get(&mut self) -> String {
73    self
74      .strings
75      .pop()
76      .unwrap_or_else(|| String::with_capacity(self.capacity))
77  }
78
79  fn put(&mut self, mut string: String) {
80    string.clear();
81    if self.strings.len() < MAX_POOL_SIZE {
82      self.strings.push(string);
83    }
84  }
85}
86
87// File operations traits
88pub trait FilePath {
89  fn is_empty(&self) -> Result<bool> {
90    self.read().map(|s| s.is_empty())
91  }
92
93  fn write(&self, msg: String) -> Result<()>;
94  fn read(&self) -> Result<String>;
95}
96
97impl FilePath for PathBuf {
98  fn write(&self, msg: String) -> Result<()> {
99    File::create(self)?
100      .write_all(msg.as_bytes())
101      .map_err(Into::into)
102  }
103
104  fn read(&self) -> Result<String> {
105    let mut contents = String::new();
106    File::open(self)?.read_to_string(&mut contents)?;
107    Ok(contents)
108  }
109}
110
111// Git operations traits
112trait DiffDeltaPath {
113  fn path(&self) -> PathBuf;
114}
115
116impl DiffDeltaPath for git2::DiffDelta<'_> {
117  fn path(&self) -> PathBuf {
118    self
119      .new_file()
120      .path()
121      .or_else(|| self.old_file().path())
122      .map(PathBuf::from)
123      .unwrap_or_default()
124  }
125}
126
// String conversion traits

/// Lossy conversion of raw bytes into an owned `String`.
pub trait Utf8String {
  /// Decodes `self` as UTF-8, replacing each invalid sequence with U+FFFD.
  fn to_utf8(&self) -> String;
}

impl Utf8String for Vec<u8> {
  fn to_utf8(&self) -> String {
    // Delegate to the slice impl so both conversions share one definition.
    self.as_slice().to_utf8()
  }
}

impl Utf8String for [u8] {
  fn to_utf8(&self) -> String {
    // `from_utf8_lossy` already borrows without copying when the bytes are valid UTF-8,
    // so a separate `str::from_utf8` fast path would only make invalid input be
    // validated twice.
    String::from_utf8_lossy(self).into_owned()
  }
}
153
154// Patch generation traits
155pub trait PatchDiff {
156  fn to_patch(&self, max_token_count: usize, model: Model) -> Result<String>;
157  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>>;
158  fn is_empty(&self) -> Result<bool>;
159}
160
161impl PatchDiff for Diff<'_> {
162  fn to_patch(&self, max_tokens: usize, model: Model) -> Result<String> {
163    profile!("Generating patch diff");
164
165    // Step 1: Collect diff data (non-parallel)
166    let files = self.collect_diff_data()?;
167    if files.is_empty() {
168      return Ok(String::new());
169    }
170
171    // Fast path for small diffs - skip tokenization entirely
172    if files.len() == 1 {
173      profile!("Single file fast path");
174      let (_, content) = files
175        .into_iter()
176        .next()
177        .ok_or_else(|| HookError::EmptyDiffOutput)?;
178
179      // If content is small enough to fit, just return it directly
180      if content.len() < max_tokens * 4 {
181        // Estimate 4 chars per token
182        return Ok(content);
183      }
184
185      // Otherwise do a simple truncation
186      return model.truncate(&content, max_tokens);
187    }
188
189    // Optimization: Skip token counting entirely for small diffs
190    if files.len() <= 5 && max_tokens > 500 {
191      profile!("Small diff fast path");
192      let mut result = String::new();
193      let files_clone = files.clone(); // Clone files for use after iteration
194
195      // Just combine the files with a limit on total size
196      for (i, (_, content)) in files.into_iter().enumerate() {
197        if i > 0 {
198          result.push('\n');
199        }
200        // Only add as much as we can estimate will fit
201        let limit = (max_tokens / files_clone.len()) * 4; // ~4 chars per token
202        let truncated = if content.len() > limit {
203          let truncated = content.chars().take(limit).collect::<String>();
204          // Find last space to avoid cutting words
205          let last_space = truncated
206            .rfind(char::is_whitespace)
207            .unwrap_or(truncated.len());
208          if last_space > 0 {
209            truncated[..last_space].to_string()
210          } else {
211            truncated
212          }
213        } else {
214          content
215        };
216        result.push_str(&truncated);
217      }
218
219      return Ok(result);
220    }
221
222    // Step 2: Prepare files for processing - optimized path for medium diffs
223    if files.len() <= 20 {
224      profile!("Medium diff optimized path");
225
226      // Convert to vector with simple heuristic for token count
227      let mut files_vec: Vec<(PathBuf, String, usize)> = files
228        .into_iter()
229        .map(|(path, content)| {
230          // Estimate token count as character count / 4
231          let estimated_tokens = content.len() / 4;
232          (path, content, estimated_tokens)
233        })
234        .collect();
235
236      // Sort by estimated size
237      files_vec.sort_by_key(|(_, _, count)| *count);
238
239      // Allocate tokens to files and process
240      let mut result = String::new();
241      let mut tokens_used = 0;
242
243      for (i, (_, content, estimated_tokens)) in files_vec.into_iter().enumerate() {
244        if tokens_used >= max_tokens {
245          break;
246        }
247
248        if i > 0 {
249          result.push('\n');
250        }
251
252        let tokens_left = max_tokens.saturating_sub(tokens_used);
253        let tokens_for_file = estimated_tokens.min(tokens_left);
254
255        // Only truncate if needed
256        let processed_content = if estimated_tokens > tokens_for_file {
257          // Simple character-based truncation for speed
258          let char_limit = tokens_for_file * 4;
259          let truncated: String = content.chars().take(char_limit).collect();
260          truncated
261        } else {
262          content
263        };
264
265        result.push_str(&processed_content);
266        tokens_used += tokens_for_file;
267      }
268
269      return Ok(result);
270    }
271
272    // Step 3: Complex diff path - use parallel processing with optimizations
273    profile!("Converting files to vector");
274    let files_vec: Vec<_> = files.into_iter().collect();
275    let total_files = files_vec.len();
276
277    // Use rayon for parallel token counting - with batching for performance
278    let thread_pool = rayon::ThreadPoolBuilder::new()
279      .num_threads(num_cpus::get())
280      .build()
281      .context("Failed to create thread pool")?;
282
283    profile!("Parallel token counting");
284    // Use chunked processing for token counting to reduce contention
285    let chunk_size = (total_files / num_cpus::get().max(1)).max(10);
286    let files_with_tokens: DiffData = thread_pool.install(|| {
287      files_vec
288        .chunks(chunk_size)
289        .collect::<Vec<_>>()
290        .into_par_iter()
291        .flat_map(|chunk| {
292          chunk
293            .iter()
294            .map(|(path, content)| {
295              let token_count = model.count_tokens(content).unwrap_or_default();
296              (path.clone(), content.clone(), token_count)
297            })
298            .collect::<Vec<_>>()
299        })
300        .collect()
301    });
302
303    // Skip sorting for very large diffs - it's not worth the time
304    profile!("Sorting files by token count");
305    let sorted_files = if total_files > 500 {
306      files_with_tokens
307    } else {
308      let mut sorted = files_with_tokens;
309      sorted.sort_by_key(|(_, _, count)| *count);
310      sorted
311    };
312
313    // Step 4: Process files with optimized token allocation
314    let remaining_tokens = Arc::new(AtomicUsize::new(max_tokens));
315    let results = Arc::new(parking_lot::RwLock::new(Vec::with_capacity(total_files)));
316    let processed_files = Arc::new(AtomicUsize::new(0));
317
318    // Optimize chunking - use larger chunks for better performance
319    let adaptive_chunk_size = (total_files / (2 * num_cpus::get().max(1))).max(PARALLEL_CHUNK_SIZE);
320
321    let chunks: Vec<_> = sorted_files
322      .chunks(adaptive_chunk_size)
323      .map(|chunk| chunk.to_vec())
324      .collect();
325
326    let model = Arc::new(model);
327
328    profile!("Parallel chunk processing");
329    thread_pool.install(|| {
330      chunks
331        .par_iter()
332        .try_for_each(|chunk| process_chunk(chunk, &model, total_files, &processed_files, &remaining_tokens, &results))
333    })?;
334
335    // Step 5: Combine results efficiently
336    profile!("Combining results");
337    let results_guard = results.read();
338
339    // Fast path for empty results
340    if results_guard.is_empty() {
341      return Ok(String::new());
342    }
343
344    // Optimize string allocation
345    let total_len = results_guard
346      .iter()
347      .map(|(_, content): &(PathBuf, String)| content.len())
348      .sum::<usize>();
349    let mut final_result = String::with_capacity(total_len + results_guard.len());
350
351    for (i, (_, content)) in results_guard.iter().enumerate() {
352      if i > 0 {
353        final_result.push('\n');
354      }
355      final_result.push_str(content);
356    }
357
358    Ok(final_result)
359  }
360
361  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>> {
362    profile!("Processing diff changes");
363
364    // Pre-allocate HashMap with estimated capacity
365    let mut files = HashMap::with_capacity(ESTIMATED_FILES_COUNT);
366
367    // Use pre-sized buffers to avoid reallocations
368    const BUFFER_SIZE: usize = 64; // Hold context prefix strings
369    static CONTEXT_PREFIX: &str = "context: ";
370
371    // Create thread-local cache for paths to avoid allocations
372    thread_local! {
373      static PATH_CACHE: std::cell::RefCell<HashMap<PathBuf, ()>> =
374        std::cell::RefCell::new(HashMap::with_capacity(20));
375    }
376
377    // Process diffs with optimized buffer handling
378    self.print(DiffFormat::Patch, |diff, _hunk, line| {
379      // Get path with potential reuse from cache for better performance
380      let path = PATH_CACHE.with(|cache| {
381        let mut cache = cache.borrow_mut();
382        let new_path = diff.path();
383        if let Some(existing_path) = cache.keys().find(|p| *p == &new_path) {
384          existing_path.clone()
385        } else {
386          cache.insert(new_path.clone(), ());
387          new_path
388        }
389      });
390
391      // Fast path for UTF-8 content - avoid expensive conversions
392      let content = if let Ok(s) = std::str::from_utf8(line.content()) {
393        s.to_string()
394      } else {
395        // Fallback for non-UTF8 content
396        line.content().to_utf8()
397      };
398
399      // Process line by line origin more efficiently
400      match line.origin() {
401        '+' | '-' => {
402          // Most common case - just get/create entry and append content
403          let entry = files
404            .entry(path)
405            .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
406          entry.push_str(&content);
407        }
408        _ => {
409          // Context line - less common but still needs efficient handling
410          let mut buf = String::with_capacity(CONTEXT_PREFIX.len() + content.len());
411          buf.push_str(CONTEXT_PREFIX);
412          buf.push_str(&content);
413
414          let entry = files
415            .entry(path)
416            .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
417          entry.push_str(&buf);
418        }
419      }
420
421      true
422    })?;
423
424    Ok(files)
425  }
426
427  fn is_empty(&self) -> Result<bool> {
428    let mut has_changes = false;
429
430    self.foreach(
431      &mut |_file, _progress| {
432        has_changes = true;
433        true
434      },
435      None,
436      None,
437      None
438    )?;
439
440    Ok(!has_changes)
441  }
442}
443
444fn process_chunk(
445  chunk: &[(PathBuf, String, usize)], model: &Arc<Model>, total_files: usize, processed_files: &AtomicUsize,
446  remaining_tokens: &AtomicUsize, result_chunks: &Arc<parking_lot::RwLock<Vec<(PathBuf, String)>>>
447) -> Result<()> {
448  profile!("Processing chunk");
449  // Fast path for empty chunks
450  if chunk.is_empty() {
451    return Ok(());
452  }
453
454  // Fast path for no tokens remaining
455  let total_remaining = remaining_tokens.load(Ordering::Acquire);
456  if total_remaining == 0 {
457    return Ok(());
458  }
459
460  // Ultra-fast path for small chunks that will likely fit
461  if chunk.len() <= 3 {
462    let total_token_count = chunk.iter().map(|(_, _, count)| *count).sum::<usize>();
463    // If entire chunk is small enough, process it in one go
464    if total_token_count <= total_remaining {
465      // Try to allocate all tokens at once
466      if remaining_tokens
467        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
468          if current >= total_token_count {
469            Some(current - total_token_count)
470          } else {
471            None
472          }
473        })
474        .is_ok()
475      {
476        // Update processed files counter once
477        processed_files.fetch_add(chunk.len(), Ordering::AcqRel);
478
479        // Collect all results without truncation
480        let chunk_results: Vec<_> = chunk
481          .iter()
482          .map(|(path, content, _)| (path.clone(), content.clone()))
483          .collect();
484
485        if !chunk_results.is_empty() {
486          result_chunks.write().extend(chunk_results);
487        }
488
489        return Ok(());
490      }
491    }
492  }
493
494  // Fast path for small files that don't need tokenization
495  let mut chunk_results = Vec::with_capacity(chunk.len());
496  let mut local_processed = 0;
497
498  for (path, content, token_count) in chunk {
499    local_processed += 1;
500
501    // Recheck remaining tokens to allow early exit
502    let current_remaining = remaining_tokens.load(Ordering::Acquire);
503    if current_remaining == 0 {
504      break;
505    }
506
507    // For very small files or text, don't bother with complex calculations
508    let token_count = *token_count;
509
510    // If small content is less than threshold, just clone without tokenization
511    if token_count <= 100
512      && token_count <= current_remaining
513      && remaining_tokens
514        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
515          if current >= token_count {
516            Some(current - token_count)
517          } else {
518            None
519          }
520        })
521        .is_ok()
522    {
523      chunk_results.push((path.clone(), content.clone()));
524      continue;
525    }
526
527    // For larger content, do the normal allocation
528    // Batch update processed files counter - just once at the end
529    let current_file_num = processed_files.load(Ordering::Acquire);
530    let files_remaining = total_files.saturating_sub(current_file_num + local_processed);
531
532    // Calculate tokens per file
533    let max_tokens_per_file = if files_remaining > 0 {
534      current_remaining.saturating_div(files_remaining)
535    } else {
536      current_remaining
537    };
538
539    if max_tokens_per_file == 0 {
540      continue;
541    }
542
543    let allocated_tokens = token_count.min(max_tokens_per_file);
544
545    if remaining_tokens
546      .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
547        if current >= allocated_tokens {
548          Some(current - allocated_tokens)
549        } else {
550          None
551        }
552      })
553      .is_ok()
554    {
555      // Fast path for content that doesn't need truncation
556      if token_count <= allocated_tokens {
557        chunk_results.push((path.clone(), content.clone()));
558      } else {
559        // Use fast character-based truncation for most cases
560        if content.len() < 2000 || allocated_tokens > 500 {
561          // Character-based truncation is much faster than tokenization
562          let char_limit = allocated_tokens * 4;
563          let truncated: String = content.chars().take(char_limit).collect();
564          chunk_results.push((path.clone(), truncated));
565        } else {
566          // Use proper truncation for complex cases
567          let truncated = model.truncate(content, allocated_tokens)?;
568          chunk_results.push((path.clone(), truncated));
569        }
570      }
571    }
572  }
573
574  // Update processed files counter once at the end
575  if local_processed > 0 {
576    processed_files.fetch_add(local_processed, Ordering::AcqRel);
577  }
578
579  // Batch update the result collection
580  if !chunk_results.is_empty() {
581    result_chunks.write().extend(chunk_results);
582  }
583
584  Ok(())
585}
586
587pub trait PatchRepository {
588  fn to_patch(&self, tree: Option<Tree<'_>>, max_token_count: usize, model: Model) -> Result<String>;
589  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
590  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
591  fn configure_diff_options(&self, opts: &mut DiffOptions);
592  fn configure_commit_diff_options(&self, opts: &mut DiffOptions);
593}
594
595impl PatchRepository for Repository {
596  fn to_patch(&self, tree: Option<Tree>, max_token_count: usize, model: Model) -> Result<String> {
597    profile!("Repository patch generation");
598    self.to_commit_diff(tree)?.to_patch(max_token_count, model)
599  }
600
601  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
602    profile!("Git diff generation");
603    let mut opts = DiffOptions::new();
604    self.configure_diff_options(&mut opts);
605
606    match tree {
607      Some(tree) => {
608        // Get the diff between tree and working directory, including staged changes
609        self.diff_tree_to_workdir_with_index(Some(&tree), Some(&mut opts))
610      }
611      None => {
612        // If there's no HEAD yet, compare against an empty tree
613        let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
614        // Get the diff between empty tree and working directory, including staged changes
615        self.diff_tree_to_workdir_with_index(Some(&empty_tree), Some(&mut opts))
616      }
617    }
618    .context("Failed to get diff")
619  }
620
621  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
622    profile!("Git commit diff generation");
623    let mut opts = DiffOptions::new();
624    self.configure_commit_diff_options(&mut opts);
625
626    match tree {
627      Some(tree) => {
628        // Get the diff between tree and index (staged changes only)
629        self.diff_tree_to_index(Some(&tree), None, Some(&mut opts))
630      }
631      None => {
632        // If there's no HEAD yet, compare against an empty tree
633        let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
634        // Get the diff between empty tree and index (staged changes only)
635        self.diff_tree_to_index(Some(&empty_tree), None, Some(&mut opts))
636      }
637    }
638    .context("Failed to get diff")
639  }
640
641  fn configure_diff_options(&self, opts: &mut DiffOptions) {
642    opts
643      .ignore_whitespace_change(true)
644      .recurse_untracked_dirs(true)
645      .recurse_ignored_dirs(false)
646      .ignore_whitespace_eol(true)
647      .ignore_blank_lines(true)
648      .include_untracked(true)
649      .ignore_whitespace(true)
650      .indent_heuristic(false)
651      .ignore_submodules(true)
652      .include_ignored(false)
653      .interhunk_lines(0)
654      .context_lines(0)
655      .patience(true)
656      .minimal(true);
657  }
658
659  fn configure_commit_diff_options(&self, opts: &mut DiffOptions) {
660    opts
661      .ignore_whitespace_change(false)
662      .recurse_untracked_dirs(false)
663      .recurse_ignored_dirs(false)
664      .ignore_whitespace_eol(true)
665      .ignore_blank_lines(true)
666      .include_untracked(false)
667      .ignore_whitespace(true)
668      .indent_heuristic(false)
669      .ignore_submodules(true)
670      .include_ignored(false)
671      .interhunk_lines(0)
672      .context_lines(0)
673      .patience(true)
674      .minimal(true);
675  }
676}
677
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_string_pool_new() {
    let pool = StringPool::new(100);
    assert_eq!(pool.strings.len(), 0);
    assert_eq!(pool.capacity, 100);
  }

  #[test]
  fn test_string_pool_get() {
    // An empty pool hands out a fresh, empty string sized by the pool's capacity.
    let mut pool = StringPool::new(10);
    let s1 = pool.get();
    // `with_capacity` only guarantees *at least* the requested capacity.
    assert!(s1.capacity() >= 10);
    assert_eq!(s1.len(), 0);
  }

  #[test]
  fn test_string_pool_put_and_get() {
    let mut pool = StringPool::new(10);
    let mut s1 = String::with_capacity(10);
    s1.push_str("test");
    pool.put(s1);

    assert_eq!(pool.strings.len(), 1);

    // The recycled string comes back cleared but keeps its allocation.
    let s2 = pool.get();
    assert!(s2.capacity() >= 10);
    assert_eq!(s2.len(), 0);
    assert_eq!(pool.strings.len(), 0);
  }

  #[test]
  fn test_string_pool_limit() {
    // Returning more strings than MAX_POOL_SIZE must not grow the pool past the cap.
    let mut pool = StringPool::new(10);
    for _ in 0..MAX_POOL_SIZE + 50 {
      pool.put(String::with_capacity(10));
    }
    assert_eq!(pool.strings.len(), MAX_POOL_SIZE);
  }

  #[test]
  fn test_to_utf8_replaces_invalid_bytes() {
    assert_eq!(b"ok".to_vec().to_utf8(), "ok");
    assert_eq!(vec![b'a', 0xff, b'b'].to_utf8(), "a\u{FFFD}b");
  }
}