use serde::{Deserialize, Serialize};
/// AI model families supported by the token estimator.
///
/// Each variant carries its own context-window size and tokens-per-word
/// heuristic (see `context_window` and `tokens_per_word` in `impl AiModel`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AiModel {
    /// OpenAI GPT-4 (128k-token context).
    Gpt4,
    /// OpenAI GPT-4 Turbo (128k-token context).
    Gpt4Turbo,
    /// OpenAI GPT-3.5 (16k-token context).
    Gpt35,
    /// Anthropic Claude (200k-token context).
    Claude,
    /// Anthropic Claude Sonnet (200k-token context).
    ClaudeSonnet,
    /// Google Gemini Pro (1M-token context).
    Gemini,
    /// Google Gemini Flash (1M-token context).
    GeminiFlash,
    /// Fallback model with conservative defaults (8k-token context).
    Generic,
}
impl AiModel {
    /// Parse a model name (case-insensitive) into an [`AiModel`].
    ///
    /// Several aliases are accepted per model, e.g. `"gpt4"` and `"gpt-4"`.
    ///
    /// # Errors
    /// Returns a descriptive `String` when the name is not recognized.
    pub fn parse(s: &str) -> Result<Self, String> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "gpt4" | "gpt-4" => Ok(Self::Gpt4),
            "gpt4-turbo" | "gpt-4-turbo" | "gpt4turbo" => Ok(Self::Gpt4Turbo),
            "gpt35" | "gpt-3.5" | "gpt-3.5-turbo" => Ok(Self::Gpt35),
            "claude" | "claude-4" | "claude-3" => Ok(Self::Claude),
            "claude-sonnet" | "claude-4-sonnet" | "claude-3.5" | "claude-3.5-sonnet"
            | "claude35" | "claude35sonnet" => Ok(Self::ClaudeSonnet),
            "gemini" | "gemini-1.5" | "gemini-1.5-pro" | "gemini-pro" => Ok(Self::Gemini),
            "gemini-flash" | "gemini-2.0-flash" | "gemini-2" => Ok(Self::GeminiFlash),
            "generic" => Ok(Self::Generic),
            // Report the caller's original spelling, not the lowercased copy.
            _ => Err(format!("Unknown AI model: {s}")),
        }
    }

    /// Every supported model, in declaration order.
    pub fn all() -> Vec<Self> {
        Vec::from([
            Self::Gpt4,
            Self::Gpt4Turbo,
            Self::Gpt35,
            Self::Claude,
            Self::ClaudeSonnet,
            Self::Gemini,
            Self::GeminiFlash,
            Self::Generic,
        ])
    }

    /// Canonical display name for the model.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Gpt4 => "gpt-4",
            Self::Gpt4Turbo => "gpt-4-turbo",
            Self::Gpt35 => "gpt-3.5",
            Self::Claude => "claude",
            Self::ClaudeSonnet => "claude-sonnet",
            Self::Gemini => "gemini-1.5-pro",
            Self::GeminiFlash => "gemini-2.0-flash",
            Self::Generic => "generic",
        }
    }

    /// Maximum context window for this model, in tokens.
    pub const fn context_window(&self) -> usize {
        match self {
            Self::Gpt4 => 128_000,
            Self::Gpt4Turbo => 128_000,
            Self::Gpt35 => 16_384,
            Self::Claude => 200_000,
            Self::ClaudeSonnet => 200_000,
            Self::Gemini => 1_000_000,
            Self::GeminiFlash => 1_000_000,
            Self::Generic => 8_192,
        }
    }

    /// Heuristic tokens-per-word ratio used by the estimators below.
    fn tokens_per_word(self) -> f64 {
        match self {
            Self::Gpt4 | Self::Gpt4Turbo | Self::Gpt35 => 1.3,
            Self::Claude | Self::ClaudeSonnet => 1.2,
            Self::Gemini | Self::GeminiFlash => 1.3,
            Self::Generic => 1.5,
        }
    }
}
/// Result of estimating the token footprint of a piece of text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCount {
    /// Estimated token count (model-specific heuristic, not an exact count).
    pub tokens: usize,
    /// Word count, split on whitespace and common punctuation.
    pub words: usize,
    /// Number of Unicode scalar values (`char`s) in the text.
    pub characters: usize,
    /// Model the estimate was computed for.
    pub model: AiModel,
    /// True when `tokens` fits within the model's context window.
    pub fits_in_context: bool,
    /// `tokens` as a percentage of the model's context window.
    pub context_usage_percent: f64,
}
/// Heuristic token counter tuned for a specific [`AiModel`].
pub struct TokenCounter {
    // Model whose context window and tokens-per-word ratio drive estimates.
    model: AiModel,
}
impl TokenCounter {
    /// Create a counter whose estimates are tuned for `model`.
    pub const fn new(model: AiModel) -> Self {
        Self { model }
    }

    /// The model this counter estimates tokens for.
    pub const fn model(&self) -> AiModel {
        self.model
    }

    /// Count characters and words in `text` and estimate its token count.
    ///
    /// Texts longer than 100k characters are estimated by sampling (see
    /// `estimate_large_file_tokens`) instead of a full scan.
    pub fn count_tokens(&self, text: &str) -> TokenCount {
        let characters = text.chars().count();
        let (words, tokens) = if characters > 100_000 {
            self.estimate_large_file_tokens(text, characters)
        } else {
            let words = self.count_words(text);
            let tokens = self.estimate_tokens(text, words);
            (words, tokens)
        };
        let context_window = self.model.context_window();
        TokenCount {
            tokens,
            words,
            characters,
            model: self.model,
            fits_in_context: tokens <= context_window,
            context_usage_percent: (tokens as f64 / context_window as f64) * 100.0,
        }
    }

    /// Estimate `(words, tokens)` for a very large text by sampling up to
    /// five evenly spaced chunks and scaling the sampled counts up to the
    /// full text size.
    fn estimate_large_file_tokens(&self, text: &str, total_chars: usize) -> (usize, usize) {
        const SAMPLE_SIZE: usize = 10_000;
        const NUM_SAMPLES: usize = 5;
        let mut total_words = 0;
        let mut total_tokens = 0;
        // BUGFIX: scale by the number of characters actually sampled. The
        // final sample runs from `start` to the end of the text and is
        // normally much larger than SAMPLE_SIZE, so the previous
        // `samples_taken * sample_size` denominator understated the sampled
        // fraction and inflated the word/token estimates (by ~35% worst case).
        let mut sampled_chars: usize = 0;
        let chunk_size = total_chars / NUM_SAMPLES;
        for i in 0..NUM_SAMPLES {
            let start = i * chunk_size;
            if start >= total_chars {
                break;
            }
            let end = if i == NUM_SAMPLES - 1 {
                total_chars
            } else {
                (start + SAMPLE_SIZE).min(total_chars)
            };
            let sample: String = text.chars().skip(start).take(end - start).collect();
            if !sample.is_empty() {
                let words = self.count_words(&sample);
                let tokens = self.estimate_tokens(&sample, words);
                total_words += words;
                total_tokens += tokens;
                sampled_chars += end - start;
            }
        }
        if sampled_chars > 0 {
            let scale_factor = total_chars as f64 / sampled_chars as f64;
            let estimated_words = (total_words as f64 * scale_factor) as usize;
            let estimated_tokens = (total_tokens as f64 * scale_factor) as usize;
            (estimated_words, estimated_tokens)
        } else {
            // No samples collected (shouldn't happen for non-empty text);
            // fall back to a full scan.
            let words = self.count_words(text);
            let tokens = self.estimate_tokens(text, words);
            (words, tokens)
        }
    }

    /// Whether `content` plus a prompt of `prompt_tokens` fits the model's
    /// context window.
    pub fn fits_with_prompt(&self, content: &str, prompt_tokens: usize) -> bool {
        let content_tokens = self.estimate_tokens(content, self.count_words(content));
        content_tokens + prompt_tokens <= self.model.context_window()
    }

    /// Rough upper bound, in characters, on content that still leaves room
    /// for a prompt of `prompt_tokens` in the context window.
    pub fn max_content_length(&self, prompt_tokens: usize) -> usize {
        let available_tokens = self.model.context_window().saturating_sub(prompt_tokens);
        // Heuristic: (words per token) * ~4.5 characters per word.
        let chars_per_token = 1.0 / self.model.tokens_per_word() * 4.5;
        (available_tokens as f64 * chars_per_token) as usize
    }

    /// Truncate `content` so it fits alongside a prompt of `prompt_tokens`.
    ///
    /// Returns the (possibly shortened) text and whether truncation occurred.
    pub fn truncate_to_fit(&self, content: &str, prompt_tokens: usize) -> (String, bool) {
        let content_tokens = self.estimate_tokens(content, self.count_words(content));
        let available_tokens = self.model.context_window().saturating_sub(prompt_tokens);
        if content_tokens <= available_tokens {
            return (content.to_string(), false);
        }
        // Shrink in proportion to the token overage, then cut on line and
        // word boundaries rather than mid-word.
        let ratio = available_tokens as f64 / content_tokens as f64;
        let target_chars = (content.chars().count() as f64 * ratio) as usize;
        let truncated = self.smart_truncate(content, target_chars);
        (truncated, true)
    }

    /// Truncate to roughly `target_chars`, keeping whole lines where possible
    /// and ending a partial final line at a word boundary with an ellipsis.
    fn smart_truncate(&self, content: &str, target_chars: usize) -> String {
        if content.chars().count() <= target_chars {
            return content.to_string();
        }
        let lines: Vec<&str> = content.lines().collect();
        let mut result = String::new();
        let mut char_count = 0;
        for line in lines {
            let line_chars = line.chars().count() + 1; // +1 for the newline
            if char_count + line_chars > target_chars {
                let remaining = target_chars.saturating_sub(char_count);
                // Only emit a partial line when enough room remains for it
                // to be meaningful.
                if remaining > 20 {
                    let truncated_line = self.truncate_at_word_boundary(line, remaining);
                    if !truncated_line.is_empty() {
                        result.push_str(&truncated_line);
                        result.push_str("...");
                    }
                }
                break;
            }
            result.push_str(line);
            result.push('\n');
            char_count += line_chars;
        }
        result.trim_end().to_string()
    }

    /// Cut `text` to at most `max_chars`, preferring to break at whitespace
    /// or light punctuation; falls back to a hard cut when no boundary exists.
    fn truncate_at_word_boundary(&self, text: &str, max_chars: usize) -> String {
        if text.chars().count() <= max_chars {
            return text.to_string();
        }
        let chars: Vec<char> = text.chars().collect();
        let mut end = max_chars.min(chars.len());
        // Walk left until a boundary character, then drop trailing spaces.
        while end > 0 && end < chars.len() {
            let ch = chars[end - 1];
            if ch.is_whitespace() || ch == ',' || ch == ';' || ch == '.' {
                while end > 0 && chars[end - 1].is_whitespace() {
                    end -= 1;
                }
                break;
            }
            end -= 1;
        }
        if end == 0 {
            // No boundary found at all: hard-cut at the limit instead.
            end = max_chars.min(chars.len());
        }
        chars[..end].iter().collect()
    }

    /// Count words, further splitting whitespace-separated chunks on common
    /// punctuation so e.g. `foo,bar` counts as two words.
    fn count_words(&self, text: &str) -> usize {
        text.split_whitespace()
            .flat_map(|word| {
                word.split(&['.', ',', ';', ':', '!', '?', '(', ')', '[', ']', '{', '}'])
            })
            .filter(|word| !word.is_empty())
            .count()
    }

    /// Dispatch to the model-family-specific token estimator.
    fn estimate_tokens(&self, text: &str, words: usize) -> usize {
        match self.model {
            AiModel::Gpt4 | AiModel::Gpt4Turbo | AiModel::Gpt35 => {
                self.estimate_gpt_tokens(text, words)
            }
            AiModel::Claude | AiModel::ClaudeSonnet => self.estimate_claude_tokens(text, words),
            AiModel::Gemini | AiModel::GeminiFlash => self.estimate_gemini_tokens(text, words),
            AiModel::Generic => self.estimate_generic_tokens(words),
        }
    }

    /// GPT-family heuristic: word-based estimate, +20% for code-like text,
    /// plus one extra token per two newlines.
    fn estimate_gpt_tokens(&self, text: &str, words: usize) -> usize {
        let base_tokens = (words as f64 * self.model.tokens_per_word()) as usize;
        let code_penalty = if self.looks_like_code(text) { 1.2 } else { 1.0 };
        let newlines = text.matches('\n').count();
        let special_tokens = newlines / 2;
        ((base_tokens as f64 * code_penalty) as usize) + special_tokens
    }

    /// Claude-family heuristic: word-based estimate, -10% for code-like text.
    fn estimate_claude_tokens(&self, text: &str, words: usize) -> usize {
        let base_tokens = (words as f64 * self.model.tokens_per_word()) as usize;
        let code_bonus = if self.looks_like_code(text) { 0.9 } else { 1.0 };
        (base_tokens as f64 * code_bonus) as usize
    }

    /// Gemini-family heuristic: word-based estimate, +15% for code-like text,
    /// plus one extra token per three newlines.
    fn estimate_gemini_tokens(&self, text: &str, words: usize) -> usize {
        let base_tokens = (words as f64 * self.model.tokens_per_word()) as usize;
        let code_penalty = if self.looks_like_code(text) { 1.15 } else { 1.0 };
        let newlines = text.matches('\n').count();
        let special_tokens = newlines / 3;
        ((base_tokens as f64 * code_penalty) as usize) + special_tokens
    }

    /// Generic fallback: plain word-count estimate with no adjustments.
    fn estimate_generic_tokens(&self, words: usize) -> usize {
        (words as f64 * self.model.tokens_per_word()) as usize
    }

    /// Heuristic code detection: treat the text as code when more than 10% of
    /// its bytes are covered by common code tokens and keywords.
    fn looks_like_code(&self, text: &str) -> bool {
        // BUGFIX: "::" was listed twice, so every `::` occurrence was counted
        // double. Each indicator now appears exactly once.
        let code_indicators = [
            "fn ",
            "function ",
            "class ",
            "import ",
            "def ",
            "#include",
            "pub ",
            "private ",
            "public ",
            "const ",
            "let ",
            "var ",
            "struct ",
            "impl ",
            "trait ",
            "interface ",
            "extends ",
            "{",
            "}",
            "(",
            ")",
            "[",
            "]",
            ";",
            "=>",
            "->",
            "::",
        ];
        let total_chars = text.len();
        if total_chars == 0 {
            return false;
        }
        // NOTE(review): byte lengths and non-overlapping-match counting make
        // this a rough heuristic, not an exact coverage measure — intentional.
        let code_char_count: usize = code_indicators
            .iter()
            .map(|indicator| text.matches(indicator).count() * indicator.len())
            .sum();
        (code_char_count as f64 / total_chars as f64) > 0.1
    }
}
/// Build a [`TokenCounter`] appropriate for a named assistant profile.
///
/// Unrecognized profile names (including `"assistant"`) fall back to the
/// generic model.
pub fn get_token_counter_for_profile(profile: &str) -> TokenCounter {
    let normalized = profile.to_lowercase();
    let model = match normalized.as_str() {
        "claude" => AiModel::Claude,
        "copilot" | "chatgpt" => AiModel::Gpt4,
        "gemini" => AiModel::Gemini,
        // "assistant" and any unknown profile use the generic defaults.
        _ => AiModel::Generic,
    };
    TokenCounter::new(model)
}
#[cfg(test)]
mod tests {
    //! Unit tests for model parsing, token estimation, and truncation.
    use super::*;

    // Parsing accepts aliases and rejects unknown names.
    #[test]
    fn test_ai_model_parsing() {
        assert_eq!(AiModel::parse("gpt-4").unwrap(), AiModel::Gpt4);
        assert_eq!(AiModel::parse("claude").unwrap(), AiModel::Claude);
        assert!(AiModel::parse("unknown").is_err());
    }

    // Context window sizes are as documented per model.
    #[test]
    fn test_context_windows() {
        assert_eq!(AiModel::Gpt4.context_window(), 128_000);
        assert_eq!(AiModel::Claude.context_window(), 200_000);
    }

    // count_tokens fills every field for a small text.
    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new(AiModel::Gpt4);
        let text = "Hello world this is a test";
        let count = counter.count_tokens(text);
        assert!(count.tokens > 0);
        assert!(count.words > 0);
        assert!(count.characters > 0);
        assert_eq!(count.model, AiModel::Gpt4);
    }

    // Code-like text is detected; plain prose is not.
    #[test]
    fn test_code_detection() {
        let counter = TokenCounter::new(AiModel::Gpt4);
        let code_text = "fn main() { println!('Hello, world!'); }";
        assert!(counter.looks_like_code(code_text));
        let natural_text = "This is just regular English text without any code.";
        assert!(!counter.looks_like_code(natural_text));
    }

    // Small content fits; huge content exceeds Generic's 8k window.
    #[test]
    fn test_fits_with_prompt() {
        let counter = TokenCounter::new(AiModel::Generic);
        let small_text = "Hello world";
        assert!(counter.fits_with_prompt(small_text, 100));
        assert!(!counter.fits_with_prompt(&"word ".repeat(10000), 1000));
    }

    // max_content_length is positive and bounded by a sane multiple of the window.
    #[test]
    fn test_max_content_length() {
        let counter = TokenCounter::new(AiModel::Gpt35);
        let max_length = counter.max_content_length(1000);
        assert!(max_length > 0);
        assert!(max_length < counter.model.context_window() * 10);
    }

    // truncate_to_fit passes small text through and shortens oversized text.
    #[test]
    fn test_truncate_to_fit() {
        let counter = TokenCounter::new(AiModel::Generic);
        let small_text = "Hello world";
        let (result, truncated) = counter.truncate_to_fit(small_text, 100);
        assert_eq!(result, small_text);
        assert!(!truncated);
        let large_text = "word ".repeat(10000);
        let (result, truncated) = counter.truncate_to_fit(&large_text, 1000);
        assert!(truncated);
        assert!(result.len() < large_text.len());
    }

    // smart_truncate only emits lines (or prefixes of lines) from the input.
    #[test]
    fn test_smart_truncate() {
        let counter = TokenCounter::new(AiModel::Gpt4);
        let text = "Line one\nLine two\nLine three\nLine four";
        let result = counter.smart_truncate(text, 20);
        assert!(result.lines().all(|line| text.contains(line)));
    }

    // Word-boundary truncation respects the limit and doesn't end on a space.
    #[test]
    fn test_truncate_at_word_boundary() {
        let counter = TokenCounter::new(AiModel::Gpt4);
        let text = "This is a long sentence with many words";
        let result = counter.truncate_at_word_boundary(text, 15);
        assert!(result.len() <= 15);
        assert!(!result.ends_with(' '));
        if result.len() < text.len() {
            let last_char = result.chars().last().unwrap_or(' ');
            assert!(last_char.is_alphabetic() || last_char.is_numeric());
        }
    }

    // Sampling path (>100k chars) stays within 2x of the true word count.
    #[test]
    fn test_large_file_optimization() {
        let counter = TokenCounter::new(AiModel::Gpt4);
        let base_text = "This is a test sentence with several words. ";
        let large_text = base_text.repeat(3000);
        assert!(large_text.len() > 100_000);
        let count = counter.count_tokens(&large_text);
        assert!(count.tokens > 0);
        assert!(count.words > 0);
        assert!(count.characters > 100_000);
        assert_eq!(count.model, AiModel::Gpt4);
        let expected_words = large_text.split_whitespace().count();
        let word_ratio = count.words as f64 / expected_words as f64;
        assert!(word_ratio > 0.5 && word_ratio < 2.0,
            "Word count estimation should be within reasonable range. Expected: {}, Got: {}, Ratio: {:.2}",
            expected_words, count.words, word_ratio);
    }

    // Newer model variants: window sizes, display names, and parse aliases.
    #[test]
    fn test_new_ai_models() {
        let gpt4_turbo = TokenCounter::new(AiModel::Gpt4Turbo);
        assert_eq!(gpt4_turbo.model(), AiModel::Gpt4Turbo);
        assert_eq!(gpt4_turbo.model().context_window(), 128_000);
        assert_eq!(gpt4_turbo.model().as_str(), "gpt-4-turbo");
        let claude_sonnet = TokenCounter::new(AiModel::ClaudeSonnet);
        assert_eq!(claude_sonnet.model(), AiModel::ClaudeSonnet);
        assert_eq!(claude_sonnet.model().context_window(), 200_000);
        assert_eq!(claude_sonnet.model().as_str(), "claude-sonnet");
        let gemini = TokenCounter::new(AiModel::Gemini);
        assert_eq!(gemini.model(), AiModel::Gemini);
        assert_eq!(gemini.model().context_window(), 1_000_000);
        assert_eq!(gemini.model().as_str(), "gemini-1.5-pro");
        let gemini_flash = TokenCounter::new(AiModel::GeminiFlash);
        assert_eq!(gemini_flash.model(), AiModel::GeminiFlash);
        assert_eq!(gemini_flash.model().context_window(), 1_000_000);
        assert_eq!(gemini_flash.model().as_str(), "gemini-2.0-flash");
        assert_eq!(AiModel::parse("gpt-4-turbo").unwrap(), AiModel::Gpt4Turbo);
        assert_eq!(
            AiModel::parse("claude-3.5-sonnet").unwrap(),
            AiModel::ClaudeSonnet
        );
        assert_eq!(
            AiModel::parse("claude-sonnet").unwrap(),
            AiModel::ClaudeSonnet
        );
        assert_eq!(AiModel::parse("gemini").unwrap(), AiModel::Gemini);
        assert_eq!(AiModel::parse("gemini-1.5-pro").unwrap(), AiModel::Gemini);
        assert_eq!(
            AiModel::parse("gemini-flash").unwrap(),
            AiModel::GeminiFlash
        );
        assert_eq!(
            AiModel::parse("gemini-2.0-flash").unwrap(),
            AiModel::GeminiFlash
        );
    }

    // Models sharing a family produce identical estimates for the same text.
    #[test]
    fn test_new_models_token_estimation() {
        let test_text = "fn main() {\n    println!(\"Hello, world!\");\n}";
        let gpt4_turbo = TokenCounter::new(AiModel::Gpt4Turbo);
        let gpt4_turbo_count = gpt4_turbo.count_tokens(test_text);
        assert!(gpt4_turbo_count.tokens > 0);
        assert_eq!(gpt4_turbo_count.model, AiModel::Gpt4Turbo);
        let claude_sonnet = TokenCounter::new(AiModel::ClaudeSonnet);
        let claude_sonnet_count = claude_sonnet.count_tokens(test_text);
        assert!(claude_sonnet_count.tokens > 0);
        assert_eq!(claude_sonnet_count.model, AiModel::ClaudeSonnet);
        let gemini = TokenCounter::new(AiModel::Gemini);
        let gemini_count = gemini.count_tokens(test_text);
        assert!(gemini_count.tokens > 0);
        assert_eq!(gemini_count.model, AiModel::Gemini);
        let gemini_flash = TokenCounter::new(AiModel::GeminiFlash);
        let gemini_flash_count = gemini_flash.count_tokens(test_text);
        assert!(gemini_flash_count.tokens > 0);
        assert_eq!(gemini_flash_count.model, AiModel::GeminiFlash);
        let gpt4 = TokenCounter::new(AiModel::Gpt4);
        let gpt4_count = gpt4.count_tokens(test_text);
        let claude = TokenCounter::new(AiModel::Claude);
        let claude_count = claude.count_tokens(test_text);
        assert_eq!(gpt4_turbo_count.tokens, gpt4_count.tokens);
        assert_eq!(claude_sonnet_count.tokens, claude_count.tokens);
        assert_eq!(gemini_flash_count.tokens, gemini_count.tokens);
    }

    // Larger context windows yield larger max content lengths.
    #[test]
    fn test_new_models_context_fitting() {
        let test_text = "word ".repeat(1000);
        let gpt4_turbo = TokenCounter::new(AiModel::Gpt4Turbo);
        assert!(gpt4_turbo.fits_with_prompt(&test_text, 500));
        let max_length = gpt4_turbo.max_content_length(1000);
        assert!(max_length > 0);
        assert!(max_length < 128_000 * 10);
        let claude_sonnet = TokenCounter::new(AiModel::ClaudeSonnet);
        assert!(claude_sonnet.fits_with_prompt(&test_text, 500));
        let max_length_claude = claude_sonnet.max_content_length(1000);
        assert!(max_length_claude > 0);
        assert!(max_length_claude < 200_000 * 10);
        assert!(max_length_claude > max_length);
        let gemini = TokenCounter::new(AiModel::Gemini);
        assert!(gemini.fits_with_prompt(&test_text, 500));
        let max_length_gemini = gemini.max_content_length(1000);
        assert!(max_length_gemini > 0);
        assert!(max_length_gemini < 1_000_000 * 10);
        assert!(max_length_gemini > max_length_claude);
    }
}