// gitai/core/token_optimizer.rs
use crate::{core::context::CommitContext, debug};
use tiktoken_rs::cl100k_base;
3
/// Trims a [`CommitContext`] in place so its total token footprint fits a
/// fixed budget, using the cl100k_base BPE tokenizer for counting.
pub struct TokenOptimizer {
    // Tokenizer used for every count/encode/decode operation.
    encoder: tiktoken_rs::CoreBPE,
    // Total token budget the optimized context must fit within.
    max_tokens: usize,
}
8
/// Errors produced while initializing the tokenizer or while counting,
/// encoding, or decoding tokens. Payloads carry the underlying error text.
#[derive(Debug)]
pub enum TokenError {
    /// The cl100k_base encoder could not be constructed.
    EncoderInit(String),
    /// Encoding text into tokens failed.
    EncodingFailed(String),
    /// Decoding a token sequence back into text failed.
    DecodingFailed(String),
}
15
16impl std::fmt::Display for TokenError {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 TokenError::EncoderInit(e) => write!(f, "Failed to initialize encoder: {e}"),
20 TokenError::EncodingFailed(e) => write!(f, "Encoding failed: {e}"),
21 TokenError::DecodingFailed(e) => write!(f, "Decoding failed: {e}"),
22 }
23 }
24}
25
// Marker impl so TokenError can be boxed/propagated as `dyn std::error::Error`.
impl std::error::Error for TokenError {}
27
28impl TokenOptimizer {
29 pub fn new(max_tokens: usize) -> Result<Self, TokenError> {
30 let encoder = cl100k_base().map_err(|e| TokenError::EncoderInit(e.to_string()))?;
31
32 Ok(Self {
33 encoder,
34 max_tokens,
35 })
36 }
37
38 pub fn optimize_context(&self, context: &mut CommitContext) -> Result<(), TokenError> {
39 let mut remaining_tokens = self.max_tokens;
40
41 remaining_tokens = self.optimize_diffs(context, remaining_tokens)?;
43 if remaining_tokens == 0 {
44 debug!("Token budget exhausted after diffs, clearing commits and contents");
45 Self::clear_commits_and_contents(context);
46 return Ok(());
47 }
48
49 remaining_tokens = self.optimize_commits(context, remaining_tokens)?;
51 if remaining_tokens == 0 {
52 debug!("Token budget exhausted after commits, clearing contents");
53 Self::clear_contents(context);
54 return Ok(());
55 }
56
57 self.optimize_contents(context, remaining_tokens)?;
59
60 debug!("Final token count: {}", self.max_tokens - remaining_tokens);
61
62 Ok(())
63 }
64
65 fn optimize_diffs(
67 &self,
68 context: &mut CommitContext,
69 mut remaining: usize,
70 ) -> Result<usize, TokenError> {
71 for file in &mut context.staged_files {
72 let diff_tokens = self.count_tokens(&file.diff);
73
74 if diff_tokens > remaining {
75 debug!(
76 "Truncating diff for {} from {} to {} tokens",
77 file.path, diff_tokens, remaining
78 );
79 file.diff = self.truncate_string(&file.diff, remaining)?;
80 return Ok(0);
81 }
82
83 remaining = remaining.saturating_sub(diff_tokens);
84 }
85 Ok(remaining)
86 }
87
88 fn optimize_commits(
90 &self,
91 context: &mut CommitContext,
92 mut remaining: usize,
93 ) -> Result<usize, TokenError> {
94 for commit in &mut context.recent_commits {
95 let commit_tokens = self.count_tokens(&commit.message);
96
97 if commit_tokens > remaining {
98 debug!(
99 "Truncating commit message from {} to {} tokens",
100 commit_tokens, remaining
101 );
102 commit.message = self.truncate_string(&commit.message, remaining)?;
103 return Ok(0);
104 }
105
106 remaining = remaining.saturating_sub(commit_tokens);
107 }
108 Ok(remaining)
109 }
110
111 fn optimize_contents(
113 &self,
114 context: &mut CommitContext,
115 mut remaining: usize,
116 ) -> Result<usize, TokenError> {
117 for file in &mut context.staged_files {
118 if let Some(content) = &mut file.content {
119 let content_tokens = self.count_tokens(content);
120
121 if content_tokens > remaining {
122 debug!(
123 "Truncating file content for {} from {} to {} tokens",
124 file.path, content_tokens, remaining
125 );
126 *content = self.truncate_string(content, remaining)?;
127 return Ok(0);
128 }
129
130 remaining = remaining.saturating_sub(content_tokens);
131 }
132 }
133 Ok(remaining)
134 }
135
136 pub fn truncate_string(&self, s: &str, max_tokens: usize) -> Result<String, TokenError> {
137 let tokens = self.encoder.encode_ordinary(s);
138
139 if tokens.len() <= max_tokens {
140 return Ok(s.to_string());
141 }
142
143 if max_tokens == 0 {
144 return Ok(String::from("…"));
145 }
146
147 let truncation_limit = max_tokens.saturating_sub(1);
149 let ellipsis_token = self
150 .encoder
151 .encode_ordinary("…")
152 .first()
153 .copied()
154 .ok_or_else(|| TokenError::EncodingFailed("Failed to encode ellipsis".to_string()))?;
155
156 let mut truncated_tokens = Vec::with_capacity(truncation_limit + 1);
157 truncated_tokens.extend_from_slice(&tokens[..truncation_limit]);
158 truncated_tokens.push(ellipsis_token);
159
160 self.encoder
161 .decode(truncated_tokens)
162 .map_err(|e| TokenError::DecodingFailed(e.to_string()))
163 }
164
165 #[inline]
166 fn clear_commits_and_contents(context: &mut CommitContext) {
167 Self::clear_commits(context);
168 Self::clear_contents(context);
169 }
170
171 #[inline]
172 fn clear_commits(context: &mut CommitContext) {
173 context
174 .recent_commits
175 .iter_mut()
176 .for_each(|c| c.message.clear());
177 }
178
179 #[inline]
180 fn clear_contents(context: &mut CommitContext) {
181 context
182 .staged_files
183 .iter_mut()
184 .for_each(|f| f.content = None);
185 }
186
187 #[inline]
188 pub fn count_tokens(&self, s: &str) -> usize {
189 self.encoder.encode_ordinary(s).len()
190 }
191}