ai/hook.rs

use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::fs::File;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;
use git2::{Diff, DiffFormat, DiffOptions, Repository, Tree};
use anyhow::{Context, Result};
use thiserror::Error;

use crate::model::Model;
use crate::profile;

// Constants

const DEFAULT_STRING_CAPACITY: usize = 8192;
const PARALLEL_CHUNK_SIZE: usize = 25;
const ESTIMATED_FILES_COUNT: usize = 100;
22
// Types
type DiffData = Vec<(PathBuf, String, usize)>;

// Error definitions
#[derive(Error, Debug)]
pub enum HookError {
  #[error("Failed to open repository")]
  OpenRepository,

  #[error("Failed to get patch")]
  GetPatch,

  #[error("Empty diff output")]
  EmptyDiffOutput,

  #[error("Failed to write commit message")]
  WriteCommitMessage,

  #[error(transparent)]
  Anyhow(#[from] anyhow::Error)
}
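
// Illustrative sketch (hypothetical helper, not used by the hook itself): the
// `#[from]`/`transparent` variant lets `HookError` flow through an
// `anyhow::Result` with `?` or an explicit `.into()`.
#[allow(dead_code)]
fn hook_error_example() -> Result<()> {
  Err(HookError::EmptyDiffOutput.into())
}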

// File operations traits
pub trait FilePath {
  fn is_empty(&self) -> Result<bool> {
    self.read().map(|s| s.is_empty())
  }

  fn write(&self, msg: String) -> Result<()>;
  fn read(&self) -> Result<String>;
}

impl FilePath for PathBuf {
  fn write(&self, msg: String) -> Result<()> {
    File::create(self)?
      .write_all(msg.as_bytes())
      .map_err(Into::into)
  }

  fn read(&self) -> Result<String> {
    let mut contents = String::new();
    File::open(self)?.read_to_string(&mut contents)?;
    Ok(contents)
  }
}
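
// Usage sketch (hypothetical path and message, for illustration only): the
// hook receives the commit-message path from git and can round-trip it
// through these helpers.
#[allow(dead_code)]
fn file_path_example() -> Result<()> {
  let path = PathBuf::from(".git/COMMIT_EDITMSG");
  FilePath::write(&path, "feat: add commit hook".to_string())?;
  debug_assert!(!FilePath::is_empty(&path)?);
  Ok(())
}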

// Git operations traits
trait DiffDeltaPath {
  fn path(&self) -> PathBuf;
}

impl DiffDeltaPath for git2::DiffDelta<'_> {
  fn path(&self) -> PathBuf {
    self
      .new_file()
      .path()
      .or_else(|| self.old_file().path())
      .map(PathBuf::from)
      .unwrap_or_default()
  }
}

// String conversion traits
pub trait Utf8String {
  fn to_utf8(&self) -> String;
}

impl Utf8String for Vec<u8> {
  fn to_utf8(&self) -> String {
    self.as_slice().to_utf8()
  }
}

impl Utf8String for [u8] {
  fn to_utf8(&self) -> String {
    // `from_utf8_lossy` borrows when the bytes are already valid UTF-8 and
    // only allocates replacement characters for invalid sequences, so no
    // separate fast path is needed.
    String::from_utf8_lossy(self).into_owned()
  }
}
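
// Illustrative sketch: `to_utf8` passes valid UTF-8 through unchanged and
// substitutes U+FFFD for invalid sequences (the byte values are invented).
#[allow(dead_code)]
fn utf8_example() {
  assert_eq!(b"diff --git".to_utf8(), "diff --git");
  assert_eq!(vec![0xf0, 0x28, 0x8c, 0x28].to_utf8(), "\u{fffd}(\u{fffd}(");
}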

// Patch generation traits
pub trait PatchDiff {
  fn to_patch(&self, max_token_count: usize, model: Model) -> Result<String>;
  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>>;
  fn is_empty(&self) -> Result<bool>;
}

impl PatchDiff for Diff<'_> {
  fn to_patch(&self, max_tokens: usize, model: Model) -> Result<String> {
    profile!("Generating patch diff");

    // Step 1: Collect diff data (non-parallel)
    let files = self.collect_diff_data()?;
    if files.is_empty() {
      return Ok(String::new());
    }

    // Fast path for single-file diffs - skip tokenization entirely
    if files.len() == 1 {
      profile!("Single file fast path");
      let (_, content) = files
        .into_iter()
        .next()
        .ok_or(HookError::EmptyDiffOutput)?;

      // If the content is small enough to fit (estimating ~4 characters per
      // token), return it directly
      if content.len() < max_tokens * 4 {
        return Ok(content);
      }

      // Otherwise do a simple truncation
      return model.truncate(&content, max_tokens);
    }

    // Optimization: skip token counting entirely for small diffs
    if files.len() <= 5 && max_tokens > 500 {
      profile!("Small diff fast path");
      let file_count = files.len();
      // Cap each file at its share of the budget (~4 chars per token)
      let limit = (max_tokens / file_count) * 4;
      let mut result = String::new();

      for (i, (_, content)) in files.into_iter().enumerate() {
        if i > 0 {
          result.push('\n');
        }
        let truncated = if content.len() > limit {
          let truncated = content.chars().take(limit).collect::<String>();
          // Cut at the last whitespace to avoid splitting a word
          let last_space = truncated
            .rfind(char::is_whitespace)
            .unwrap_or(truncated.len());
          if last_space > 0 {
            truncated[..last_space].to_string()
          } else {
            truncated
          }
        } else {
          content
        };
        result.push_str(&truncated);
      }

      return Ok(result);
    }

    // Step 2: Prepare files for processing - optimized path for medium diffs
    if files.len() <= 20 {
      profile!("Medium diff optimized path");

      // Convert to a vector with a simple heuristic for token count
      let mut files_vec: Vec<(PathBuf, String, usize)> = files
        .into_iter()
        .map(|(path, content)| {
          // Estimate token count as character count / 4
          let estimated_tokens = content.len() / 4;
          (path, content, estimated_tokens)
        })
        .collect();

      // Sort by estimated size
      files_vec.sort_by_key(|(_, _, count)| *count);

      // Allocate tokens to files and process
      let mut result = String::new();
      let mut tokens_used = 0;

      for (i, (_, content, estimated_tokens)) in files_vec.into_iter().enumerate() {
        if tokens_used >= max_tokens {
          break;
        }

        if i > 0 {
          result.push('\n');
        }

        let tokens_left = max_tokens.saturating_sub(tokens_used);
        let tokens_for_file = estimated_tokens.min(tokens_left);

        // Only truncate if needed
        let processed_content = if estimated_tokens > tokens_for_file {
          // Simple character-based truncation for speed (~4 chars per token)
          let char_limit = tokens_for_file * 4;
          content.chars().take(char_limit).collect()
        } else {
          content
        };

        result.push_str(&processed_content);
        tokens_used += tokens_for_file;
      }

      return Ok(result);
    }

    // Step 3: Complex diff path - use parallel processing with optimizations
    profile!("Converting files to vector");
    let files_vec: Vec<_> = files.into_iter().collect();
    let total_files = files_vec.len();

    // Use rayon for parallel token counting, with batching for performance
    let thread_pool = rayon::ThreadPoolBuilder::new()
      .num_threads(num_cpus::get())
      .build()
      .context("Failed to create thread pool")?;

    profile!("Parallel token counting");
    // Chunk the token counting to reduce contention
    let chunk_size = (total_files / num_cpus::get().max(1)).max(10);
    let files_with_tokens: DiffData = thread_pool.install(|| {
      files_vec
        .chunks(chunk_size)
        .collect::<Vec<_>>()
        .into_par_iter()
        .flat_map(|chunk| {
          chunk
            .iter()
            .map(|(path, content)| {
              let token_count = model.count_tokens(content).unwrap_or_default();
              (path.clone(), content.clone(), token_count)
            })
            .collect::<Vec<_>>()
        })
        .collect()
    });

    // Skip sorting for very large diffs - it's not worth the time
    profile!("Sorting files by token count");
    let sorted_files = if total_files > 500 {
      files_with_tokens
    } else {
      let mut sorted = files_with_tokens;
      sorted.sort_by_key(|(_, _, count)| *count);
      sorted
    };

    // Step 4: Process files with optimized token allocation
    let remaining_tokens = Arc::new(AtomicUsize::new(max_tokens));
    let results = Arc::new(parking_lot::RwLock::new(Vec::with_capacity(total_files)));
    let processed_files = Arc::new(AtomicUsize::new(0));

    // Use larger chunks for better performance
    let adaptive_chunk_size = (total_files / (2 * num_cpus::get().max(1))).max(PARALLEL_CHUNK_SIZE);

    let chunks: Vec<_> = sorted_files
      .chunks(adaptive_chunk_size)
      .map(|chunk| chunk.to_vec())
      .collect();

    let model = Arc::new(model);

    profile!("Parallel chunk processing");
    thread_pool.install(|| {
      chunks
        .par_iter()
        .try_for_each(|chunk| process_chunk(chunk, &model, total_files, &processed_files, &remaining_tokens, &results))
    })?;

    // Step 5: Combine results efficiently
    profile!("Combining results");
    let results_guard = results.read();

    // Fast path for empty results
    if results_guard.is_empty() {
      return Ok(String::new());
    }

    // Pre-size the output string to avoid reallocations
    let total_len = results_guard
      .iter()
      .map(|(_, content): &(PathBuf, String)| content.len())
      .sum::<usize>();
    let mut final_result = String::with_capacity(total_len + results_guard.len());

    for (i, (_, content)) in results_guard.iter().enumerate() {
      if i > 0 {
        final_result.push('\n');
      }
      final_result.push_str(content);
    }

    Ok(final_result)
  }
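
  // Worked example of the budgeting above (illustrative numbers): with
  // max_tokens = 1000 and five files of similar size, each file is allotted
  // ~200 tokens, i.e. roughly an 800-character slice under the ~4 chars/token
  // estimate the fast paths rely on.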

  fn collect_diff_data(&self) -> Result<HashMap<PathBuf, String>> {
    profile!("Processing diff changes");

    // Pre-allocate the map with an estimated capacity
    let mut files = HashMap::with_capacity(ESTIMATED_FILES_COUNT);

    self.print(DiffFormat::Patch, |delta, _hunk, line| {
      // `delta.path()` already allocates a fresh PathBuf, so no caching layer
      // is worth its cost here
      let path = delta.path();

      // Fast path for UTF-8 content - avoid the lossy conversion
      let content = if let Ok(s) = std::str::from_utf8(line.content()) {
        s.to_string()
      } else {
        // Fallback for non-UTF-8 content
        line.content().to_utf8()
      };

      match line.origin() {
        '+' | '-' | ' ' => {
          // Added, removed, and context lines, prefixed with their origin
          // character to keep the diff format readable
          let entry = files
            .entry(path)
            .or_insert_with(|| String::with_capacity(DEFAULT_STRING_CAPACITY));
          entry.push(line.origin());
          entry.push_str(&content);
        }
        _ => {
          // File and hunk headers are skipped: only the change content is
          // collected here
          log::trace!("Skipping diff line with origin: {:?}", line.origin());
        }
      }

      true
    })?;

    Ok(files)
  }

  fn is_empty(&self) -> Result<bool> {
    let mut has_changes = false;

    // The file callback must return true to keep iterating; it only records
    // whether any delta exists at all
    self.foreach(
      &mut |_file, _progress| {
        has_changes = true;
        true
      },
      None,
      None,
      None
    )?;

    Ok(!has_changes)
  }
}
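
// Usage sketch (hypothetical caller; the 512-token budget is invented): build
// the staged-changes diff and render it as a prompt-sized patch.
#[allow(dead_code)]
fn patch_from_staged(repo: &Repository, model: Model) -> Result<String> {
  let diff = repo.to_commit_diff(None)?;
  if PatchDiff::is_empty(&diff)? {
    return Ok(String::new());
  }
  diff.to_patch(512, model)
}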

fn process_chunk(
  chunk: &[(PathBuf, String, usize)], model: &Arc<Model>, total_files: usize, processed_files: &AtomicUsize,
  remaining_tokens: &AtomicUsize, result_chunks: &Arc<parking_lot::RwLock<Vec<(PathBuf, String)>>>
) -> Result<()> {
  profile!("Processing chunk");
  // Fast path for empty chunks
  if chunk.is_empty() {
    return Ok(());
  }

  // Fast path for no tokens remaining
  let total_remaining = remaining_tokens.load(Ordering::Acquire);
  if total_remaining == 0 {
    return Ok(());
  }

  // Ultra-fast path for small chunks that are likely to fit
  if chunk.len() <= 3 {
    let total_token_count = chunk.iter().map(|(_, _, count)| *count).sum::<usize>();
    // If the entire chunk is small enough, process it in one go
    if total_token_count <= total_remaining {
      // Try to claim all of the chunk's tokens at once; `checked_sub` yields
      // None (no update) if the budget has shrunk in the meantime
      if remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| current.checked_sub(total_token_count))
        .is_ok()
      {
        // Update the processed-files counter once
        processed_files.fetch_add(chunk.len(), Ordering::AcqRel);

        // Collect all results without truncation
        let chunk_results: Vec<_> = chunk
          .iter()
          .map(|(path, content, _)| (path.clone(), content.clone()))
          .collect();

        result_chunks.write().extend(chunk_results);
        return Ok(());
      }
    }
  }

  // General path: allocate tokens file by file
  let mut chunk_results = Vec::with_capacity(chunk.len());
  let mut local_processed = 0;

  for (path, content, token_count) in chunk {
    // Recheck the remaining budget to allow early exit
    let current_remaining = remaining_tokens.load(Ordering::Acquire);
    if current_remaining == 0 {
      break;
    }

    local_processed += 1;
    let token_count = *token_count;

    // Small files that fit the remaining budget are passed through whole,
    // without any truncation work
    if token_count <= 100
      && token_count <= current_remaining
      && remaining_tokens
        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| current.checked_sub(token_count))
        .is_ok()
    {
      chunk_results.push((path.clone(), content.clone()));
      continue;
    }

    // For larger content, split the remaining budget across the files that
    // are still waiting to be processed
    let current_file_num = processed_files.load(Ordering::Acquire);
    let files_remaining = total_files.saturating_sub(current_file_num + local_processed);

    let max_tokens_per_file = if files_remaining > 0 {
      current_remaining / files_remaining
    } else {
      current_remaining
    };

    if max_tokens_per_file == 0 {
      continue;
    }

    let allocated_tokens = token_count.min(max_tokens_per_file);

    if remaining_tokens
      .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| current.checked_sub(allocated_tokens))
      .is_ok()
    {
      // Fast path for content that doesn't need truncation
      if token_count <= allocated_tokens {
        chunk_results.push((path.clone(), content.clone()));
      } else if content.len() < 2000 || allocated_tokens > 500 {
        // Character-based truncation is much faster than tokenization
        // (~4 chars per token)
        let char_limit = allocated_tokens * 4;
        let truncated: String = content.chars().take(char_limit).collect();
        chunk_results.push((path.clone(), truncated));
      } else {
        // Use proper model-aware truncation for the complex cases
        let truncated = model.truncate(content, allocated_tokens)?;
        chunk_results.push((path.clone(), truncated));
      }
    }
  }

  // Update the processed-files counter once at the end
  if local_processed > 0 {
    processed_files.fetch_add(local_processed, Ordering::AcqRel);
  }

  // Batch-append into the shared result collection
  if !chunk_results.is_empty() {
    result_chunks.write().extend(chunk_results);
  }

  Ok(())
}
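
// Minimal sketch of the lock-free budget claim used throughout
// `process_chunk`: atomically subtract `want` tokens only if the budget still
// holds them, otherwise leave the budget untouched and report failure.
#[allow(dead_code)]
fn try_claim_tokens(budget: &AtomicUsize, want: usize) -> bool {
  budget
    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| current.checked_sub(want))
    .is_ok()
}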

pub trait PatchRepository {
  fn to_patch(&self, tree: Option<Tree<'_>>, max_token_count: usize, model: Model) -> Result<String>;
  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>>;
  fn configure_diff_options(&self, opts: &mut DiffOptions);
  fn configure_commit_diff_options(&self, opts: &mut DiffOptions);
}

impl PatchRepository for Repository {
  fn to_patch(&self, tree: Option<Tree>, max_token_count: usize, model: Model) -> Result<String> {
    profile!("Repository patch generation");
    self.to_commit_diff(tree)?.to_patch(max_token_count, model)
  }

  fn to_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
    profile!("Git diff generation");
    let mut opts = DiffOptions::new();
    self.configure_diff_options(&mut opts);

    match tree {
      Some(tree) => {
        // Diff between the tree and the working directory, including staged
        // changes
        self.diff_tree_to_workdir_with_index(Some(&tree), Some(&mut opts))
      }
      None => {
        // With no HEAD yet, compare against an empty tree
        let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
        self.diff_tree_to_workdir_with_index(Some(&empty_tree), Some(&mut opts))
      }
    }
    .context("Failed to get diff")
  }

  fn to_commit_diff(&self, tree: Option<Tree<'_>>) -> Result<git2::Diff<'_>> {
    profile!("Git commit diff generation");
    let mut opts = DiffOptions::new();
    self.configure_commit_diff_options(&mut opts);

    match tree {
      Some(tree) => {
        // Diff between the tree and the index (staged changes only)
        self.diff_tree_to_index(Some(&tree), None, Some(&mut opts))
      }
      None => {
        // With no HEAD yet, compare against an empty tree
        let empty_tree = self.find_tree(self.treebuilder(None)?.write()?)?;
        self.diff_tree_to_index(Some(&empty_tree), None, Some(&mut opts))
      }
    }
    .context("Failed to get diff")
  }

  fn configure_diff_options(&self, opts: &mut DiffOptions) {
    opts
      .ignore_whitespace_change(true)
      .recurse_untracked_dirs(true)
      .recurse_ignored_dirs(false)
      .ignore_whitespace_eol(true)
      .ignore_blank_lines(true)
      .include_untracked(true)
      .ignore_whitespace(true)
      .indent_heuristic(false)
      .ignore_submodules(true)
      .include_ignored(false)
      .interhunk_lines(0)
      .context_lines(0)
      .patience(true)
      .minimal(true);
  }

  fn configure_commit_diff_options(&self, opts: &mut DiffOptions) {
    opts
      .ignore_whitespace_change(false)
      .recurse_untracked_dirs(false)
      .recurse_ignored_dirs(false)
      .ignore_whitespace_eol(true)
      .ignore_blank_lines(true)
      .include_untracked(false)
      .ignore_whitespace(true)
      .indent_heuristic(false)
      .ignore_submodules(true)
      .include_ignored(false)
      .interhunk_lines(0)
      .context_lines(0)
      .patience(true)
      .minimal(true);
  }
}
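
// End-to-end usage sketch (hypothetical wiring; the repository discovery, the
// 1024-token budget, and the error mapping are invented for illustration):
// generate a patch for the staged changes and write it to the commit-message
// file the hook received from git.
#[allow(dead_code)]
fn run_hook_example(commit_msg_path: &PathBuf, model: Model) -> Result<()> {
  let repo = Repository::open_from_env().map_err(|_| HookError::OpenRepository)?;
  // Use HEAD's tree when it exists; fall back to None for an unborn branch
  let tree = repo.head().ok().and_then(|head| head.peel_to_tree().ok());
  let patch = repo.to_patch(tree, 1024, model)?;
  if patch.is_empty() {
    return Err(HookError::EmptyDiffOutput.into());
  }
  commit_msg_path.write(patch)?;
  Ok(())
}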