use crate::context::CommitContext;
use tiktoken_rs::cl100k_base;
/// Trims the text held in a `CommitContext` so that it fits within a token budget.
///
/// Token counts are measured with the BPE encoder stored in `encoder`.
pub struct TokenOptimizer {
/// BPE encoder used both to count tokens and to cut strings at token boundaries.
encoder: tiktoken_rs::CoreBPE,
/// Total number of tokens the optimized context is allowed to spend.
max_tokens: usize,
}
impl TokenOptimizer {
pub fn new(max_tokens: usize) -> Self {
TokenOptimizer {
encoder: cl100k_base().unwrap(),
max_tokens,
}
}
/// Trims `context` in place so that its text fits the `max_tokens` budget.
///
/// The budget is spent in priority order:
///
/// 1. staged file diffs,
/// 2. recent commit messages,
/// 3. full staged file contents.
///
/// Within each step entries are visited in order. An entry that fits is kept
/// as is and charged against the budget; the first entry that does not fit is
/// shortened with `truncate_string`, which uses up the rest of the budget.
/// Everything after the point where the budget reaches zero is dropped: later
/// diffs and commit messages are emptied and later file contents become `None`.
pub fn optimize_context(&self, context: &mut CommitContext) {
    let mut remaining_tokens = self.max_tokens;

    // Step 1: diffs have the highest priority.
    for file in &mut context.staged_files {
        remaining_tokens = self.consume_budget(&mut file.diff, remaining_tokens);
    }
    if remaining_tokens == 0 {
        self.clear_commits_and_contents(context);
        return;
    }

    // Step 2: commit messages get whatever the diffs left over.
    for commit in &mut context.recent_commits {
        remaining_tokens = self.consume_budget(&mut commit.message, remaining_tokens);
    }
    if remaining_tokens == 0 {
        self.clear_contents(context);
        return;
    }

    // Step 3: full file contents come last.
    for file in &mut context.staged_files {
        if remaining_tokens == 0 {
            // Same representation `clear_contents` uses for "no content".
            file.content = None;
        } else if let Some(content) = file.content.as_mut() {
            remaining_tokens = self.consume_budget(content, remaining_tokens);
        }
    }
}

/// Charges `text` against `budget` and returns the budget left afterwards.
///
/// Text that fits is left untouched. Text that is too long is truncated to
/// `budget` tokens, and with no budget at all it is emptied; both cases
/// return zero so that nothing is left for later entries.
fn consume_budget(&self, text: &mut String, budget: usize) -> usize {
    if budget == 0 {
        // Nothing fits any more: drop the text outright instead of asking
        // for a zero-token truncation.
        text.clear();
        return 0;
    }

    let needed = self.count_tokens(text);
    if needed > budget {
        *text = self.truncate_string(text, budget);
        0
    } else {
        budget - needed
    }
}
/// Shortens `s` to a string built from at most `max_tokens` tokens.
///
/// Returns `s` unchanged when it already fits. Otherwise the leading tokens
/// are kept and an ellipsis (`…`) is appended; the room for the ellipsis is
/// taken out of `max_tokens`, not added on top of it. When the ellipsis
/// itself would not fit, the kept tokens are returned without it, and a
/// budget of zero yields an empty string.
pub fn truncate_string(&self, s: &str, max_tokens: usize) -> String {
    let tokens = self.encoder.encode_ordinary(s);
    if tokens.len() <= max_tokens {
        return s.to_string();
    }
    if max_tokens == 0 {
        // Even a lone ellipsis would exceed a zero budget.
        return String::new();
    }

    // Measure the ellipsis instead of assuming it is exactly one token.
    let ellipsis_cost = self.encoder.encode_ordinary("…").len();
    let with_ellipsis = ellipsis_cost <= max_tokens;
    let mut keep = if with_ellipsis {
        max_tokens - ellipsis_cost
    } else {
        max_tokens
    };

    // A cut can land in the middle of a multi-byte character, in which case
    // `decode` fails. Back off one token at a time until the prefix decodes.
    // `keep` is always below `tokens.len()` here, so the slice is in bounds.
    let mut truncated = loop {
        match self.encoder.decode(tokens[..keep].to_vec()) {
            Ok(text) => break text,
            Err(_) if keep > 0 => keep -= 1,
            Err(_) => break String::new(),
        }
    };

    if with_ellipsis {
        truncated.push('…');
    }
    truncated
}
/// Empties every recent commit message and removes every staged file's content.
///
/// The two clears touch different fields of `context`, so their order does
/// not matter.
fn clear_commits_and_contents(&self, context: &mut CommitContext) {
    self.clear_contents(context);
    self.clear_commits(context);
}
/// Empties the message of every recent commit; the commit entries themselves stay.
fn clear_commits(&self, context: &mut CommitContext) {
    context
        .recent_commits
        .iter_mut()
        .for_each(|entry| entry.message.clear());
}
/// Sets the content of every staged file to `None`; the file entries themselves stay.
fn clear_contents(&self, context: &mut CommitContext) {
    context
        .staged_files
        .iter_mut()
        .for_each(|staged| staged.content = None);
}
/// Returns the number of tokens `s` encodes to with this optimizer's encoder.
pub fn count_tokens(&self, s: &str) -> usize {
    self.encoder.encode_ordinary(s).len()
}
}