//! gitai/core/token_optimizer.rs — trims a commit context to fit a token budget.

1use crate::{core::context::CommitContext, debug};
2use tiktoken_rs::cl100k_base;
3
/// Trims a [`CommitContext`] in place so the assembled prompt fits within a
/// fixed token budget, counting tokens with the cl100k_base BPE tokenizer.
pub struct TokenOptimizer {
    // BPE tokenizer used for all counting, truncation, and re-decoding.
    encoder: tiktoken_rs::CoreBPE,
    // Total token budget shared across diffs, commit messages, and contents.
    max_tokens: usize,
}
8
/// Errors produced while initializing the tokenizer or converting text
/// to and from token sequences.
#[derive(Debug)]
pub enum TokenError {
    /// The cl100k_base encoder could not be constructed.
    EncoderInit(String),
    /// Encoding a string into tokens failed.
    EncodingFailed(String),
    /// Decoding a token sequence back into a string failed.
    DecodingFailed(String),
}
15
16impl std::fmt::Display for TokenError {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            TokenError::EncoderInit(e) => write!(f, "Failed to initialize encoder: {e}"),
20            TokenError::EncodingFailed(e) => write!(f, "Encoding failed: {e}"),
21            TokenError::DecodingFailed(e) => write!(f, "Decoding failed: {e}"),
22        }
23    }
24}
25
26impl std::error::Error for TokenError {}
27
28impl TokenOptimizer {
29    pub fn new(max_tokens: usize) -> Result<Self, TokenError> {
30        let encoder = cl100k_base().map_err(|e| TokenError::EncoderInit(e.to_string()))?;
31
32        Ok(Self {
33            encoder,
34            max_tokens,
35        })
36    }
37
38    pub fn optimize_context(&self, context: &mut CommitContext) -> Result<(), TokenError> {
39        let mut remaining_tokens = self.max_tokens;
40
41        // Step 1: Process diffs (highest priority)
42        remaining_tokens = self.optimize_diffs(context, remaining_tokens)?;
43        if remaining_tokens == 0 {
44            debug!("Token budget exhausted after diffs, clearing commits and contents");
45            Self::clear_commits_and_contents(context);
46            return Ok(());
47        }
48
49        // Step 2: Process commits (medium priority)
50        remaining_tokens = self.optimize_commits(context, remaining_tokens)?;
51        if remaining_tokens == 0 {
52            debug!("Token budget exhausted after commits, clearing contents");
53            Self::clear_contents(context);
54            return Ok(());
55        }
56
57        // Step 3: Process file contents (lowest priority)
58        self.optimize_contents(context, remaining_tokens)?;
59
60        debug!("Final token count: {}", self.max_tokens - remaining_tokens);
61
62        Ok(())
63    }
64
65    // Optimize diffs and return remaining tokens
66    fn optimize_diffs(
67        &self,
68        context: &mut CommitContext,
69        mut remaining: usize,
70    ) -> Result<usize, TokenError> {
71        for file in &mut context.staged_files {
72            let diff_tokens = self.count_tokens(&file.diff);
73
74            if diff_tokens > remaining {
75                debug!(
76                    "Truncating diff for {} from {} to {} tokens",
77                    file.path, diff_tokens, remaining
78                );
79                file.diff = self.truncate_string(&file.diff, remaining)?;
80                return Ok(0);
81            }
82
83            remaining = remaining.saturating_sub(diff_tokens);
84        }
85        Ok(remaining)
86    }
87
88    // Optimize commits and return remaining tokens
89    fn optimize_commits(
90        &self,
91        context: &mut CommitContext,
92        mut remaining: usize,
93    ) -> Result<usize, TokenError> {
94        for commit in &mut context.recent_commits {
95            let commit_tokens = self.count_tokens(&commit.message);
96
97            if commit_tokens > remaining {
98                debug!(
99                    "Truncating commit message from {} to {} tokens",
100                    commit_tokens, remaining
101                );
102                commit.message = self.truncate_string(&commit.message, remaining)?;
103                return Ok(0);
104            }
105
106            remaining = remaining.saturating_sub(commit_tokens);
107        }
108        Ok(remaining)
109    }
110
111    // Optimize file contents and return remaining tokens
112    fn optimize_contents(
113        &self,
114        context: &mut CommitContext,
115        mut remaining: usize,
116    ) -> Result<usize, TokenError> {
117        for file in &mut context.staged_files {
118            if let Some(content) = &mut file.content {
119                let content_tokens = self.count_tokens(content);
120
121                if content_tokens > remaining {
122                    debug!(
123                        "Truncating file content for {} from {} to {} tokens",
124                        file.path, content_tokens, remaining
125                    );
126                    *content = self.truncate_string(content, remaining)?;
127                    return Ok(0);
128                }
129
130                remaining = remaining.saturating_sub(content_tokens);
131            }
132        }
133        Ok(remaining)
134    }
135
136    pub fn truncate_string(&self, s: &str, max_tokens: usize) -> Result<String, TokenError> {
137        let tokens = self.encoder.encode_ordinary(s);
138
139        if tokens.len() <= max_tokens {
140            return Ok(s.to_string());
141        }
142
143        if max_tokens == 0 {
144            return Ok(String::from("…"));
145        }
146
147        // Reserve space for ellipsis
148        let truncation_limit = max_tokens.saturating_sub(1);
149        let ellipsis_token = self
150            .encoder
151            .encode_ordinary("…")
152            .first()
153            .copied()
154            .ok_or_else(|| TokenError::EncodingFailed("Failed to encode ellipsis".to_string()))?;
155
156        let mut truncated_tokens = Vec::with_capacity(truncation_limit + 1);
157        truncated_tokens.extend_from_slice(&tokens[..truncation_limit]);
158        truncated_tokens.push(ellipsis_token);
159
160        self.encoder
161            .decode(truncated_tokens)
162            .map_err(|e| TokenError::DecodingFailed(e.to_string()))
163    }
164
165    #[inline]
166    fn clear_commits_and_contents(context: &mut CommitContext) {
167        Self::clear_commits(context);
168        Self::clear_contents(context);
169    }
170
171    #[inline]
172    fn clear_commits(context: &mut CommitContext) {
173        context
174            .recent_commits
175            .iter_mut()
176            .for_each(|c| c.message.clear());
177    }
178
179    #[inline]
180    fn clear_contents(context: &mut CommitContext) {
181        context
182            .staged_files
183            .iter_mut()
184            .for_each(|f| f.content = None);
185    }
186
187    #[inline]
188    pub fn count_tokens(&self, s: &str) -> usize {
189        self.encoder.encode_ordinary(s).len()
190    }
191}