use serde::{Deserialize, Serialize};
/// Heuristic token estimator.
///
/// Token counts are approximated from the number of characters in a piece of
/// text using a configurable average characters-per-token ratio; no real
/// tokenizer is involved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCounter {
/// Average number of characters assumed to make up one token; used as the
/// divisor when turning a character count into a token estimate.
avg_chars_per_token: f32,
}
impl Default for TokenCounter {
fn default() -> Self {
Self {
avg_chars_per_token: 4.0, }
}
}
impl TokenCounter {
    /// Ratio used by [`estimate`](Self::estimate) when the configured one is
    /// unusable (zero, negative, NaN or infinite).
    const FALLBACK_CHARS_PER_TOKEN: f32 = 4.0;

    /// Creates a counter with the default characters-per-token ratio.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the counter with its characters-per-token ratio replaced by `avg`.
    ///
    /// The value is stored as given. A ratio that is not a finite, strictly
    /// positive number is ignored by [`estimate`](Self::estimate), which then
    /// falls back to four characters per token.
    pub fn with_avg_chars_per_token(mut self, avg: f32) -> Self {
        self.avg_chars_per_token = avg;
        self
    }

    /// Returns the ratio that estimation actually divides by.
    ///
    /// The stored ratio may be invalid (it can come from the builder or from
    /// deserialised data), and dividing by zero, a negative number or NaN
    /// would yield `u32::MAX` or `0` regardless of the input text.
    fn effective_chars_per_token(&self) -> f32 {
        let ratio = self.avg_chars_per_token;
        if ratio.is_finite() && ratio > 0.0 {
            ratio
        } else {
            Self::FALLBACK_CHARS_PER_TOKEN
        }
    }

    /// Estimates the number of tokens in `text`.
    ///
    /// The estimate is the number of `char`s (Unicode scalar values, not
    /// bytes) divided by the characters-per-token ratio, truncated toward
    /// zero. Text shorter than one ratio's worth of characters therefore
    /// estimates to `0`.
    pub fn estimate(&self, text: &str) -> u32 {
        let char_count = text.chars().count() as f32;
        // Float-to-integer `as` truncates toward zero and saturates at
        // `u32::MAX`, so a very small ratio cannot wrap around.
        (char_count / self.effective_chars_per_token()) as u32
    }

    /// Estimates the tokens used by one chat message.
    ///
    /// Adds a fixed per-message overhead to the estimate for `content`:
    /// 4 tokens for the `"system"` and `"assistant"` roles, 3 tokens for
    /// `"user"` and any other role. The sum saturates at `u32::MAX`.
    pub fn estimate_message(&self, content: &str, role: &str) -> u32 {
        let role_overhead: u32 = match role {
            "system" | "assistant" => 4,
            // "user" and every unrecognised role.
            _ => 3,
        };
        // Saturate rather than overflow: a plain `+` panics in debug builds
        // and wraps in release builds when the base estimate is at the limit.
        self.estimate(content).saturating_add(role_overhead)
    }

    /// Estimates the total tokens used by a list of `(role, content)` pairs.
    ///
    /// Each entry is measured with [`estimate_message`](Self::estimate_message);
    /// the running total saturates at `u32::MAX` instead of overflowing.
    pub fn count_messages(&self, messages: &[(String, String)]) -> u32 {
        messages.iter().fold(0u32, |total, (role, content)| {
            total.saturating_add(self.estimate_message(content, role))
        })
    }

    /// Estimates the tokens in a file's `content`, scaled by `file_type`.
    ///
    /// The plain-text estimate is multiplied by 1.2 for `"rust"`, `"python"`,
    /// `"javascript"` and `"typescript"`, by 1.5 for `"json"` and `"yaml"`,
    /// and by 1.0 for `"markdown"`, `"text"` and any other type. The result is
    /// truncated toward zero and saturates at `u32::MAX`.
    pub fn estimate_file(&self, content: &str, file_type: &str) -> u32 {
        let multiplier: f32 = match file_type {
            "rust" | "python" | "javascript" | "typescript" => 1.2,
            "json" | "yaml" => 1.5,
            // "markdown", "text" and every unrecognised type.
            _ => 1.0,
        };
        (self.estimate(content) as f32 * multiplier) as u32
    }
}
/// Tracks an estimated token total against a fixed context-window size.
pub struct ContextWindowManager {
/// Total size of the context window, in tokens.
context_window: u32,
/// Running total of estimated tokens currently accounted for.
current_tokens: u32,
/// Estimator used to turn message content into token counts.
counter: TokenCounter,
}
impl ContextWindowManager {
    /// Creates a manager for a window of `context_window` tokens.
    ///
    /// The tracked total starts at zero and estimates come from a
    /// default-configured [`TokenCounter`].
    pub fn new(context_window: u32) -> Self {
        Self {
            context_window,
            current_tokens: 0,
            counter: TokenCounter::new(),
        }
    }

    /// Returns the total window size in tokens.
    pub fn context_window(&self) -> u32 {
        self.context_window
    }

    /// Returns the estimated number of tokens currently tracked.
    pub fn current_tokens(&self) -> u32 {
        self.current_tokens
    }

    /// Returns how many tokens are left before the window is full.
    ///
    /// Yields `0` (never underflows) when the tracked total exceeds the window.
    pub fn remaining_tokens(&self) -> u32 {
        self.context_window.saturating_sub(self.current_tokens)
    }

    /// Returns the tracked total as a percentage of the window size.
    ///
    /// The value can exceed `100.0`, because adding messages is not capped at
    /// the window size. A zero-sized window reports `100.0`: it has no
    /// capacity left, and dividing by it would produce NaN or infinity.
    pub fn usage_percent(&self) -> f32 {
        if self.context_window == 0 {
            return 100.0;
        }
        (self.current_tokens as f32 / self.context_window as f32) * 100.0
    }

    /// Adds the estimated tokens of one message to the tracked total.
    ///
    /// The total saturates at `u32::MAX`; it is not clamped to the window size.
    pub fn add_message(&mut self, content: &str, role: &str) {
        let tokens = self.counter.estimate_message(content, role);
        self.current_tokens = self.current_tokens.saturating_add(tokens);
    }

    /// Subtracts the estimated tokens of one message from the tracked total.
    ///
    /// The message is re-estimated from `content` and `role`; nothing checks
    /// that it was added earlier. The total saturates at `0`.
    pub fn remove_message(&mut self, content: &str, role: &str) {
        let tokens = self.counter.estimate_message(content, role);
        self.current_tokens = self.current_tokens.saturating_sub(tokens);
    }

    /// Returns `true` when at most `buffer` tokens remain in the window.
    pub fn is_near_limit(&self, buffer: u32) -> bool {
        self.remaining_tokens() <= buffer
    }

    /// Clears the tracked total back to zero; the window size is unchanged.
    pub fn reset(&mut self) {
        self.current_tokens = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both ASCII and multi-byte text must produce a non-zero estimate.
    #[test]
    fn test_estimate() {
        let estimator = TokenCounter::new();
        let english = "Hello, world! This is a test message.";
        let chinese = "你好,世界!这是一个测试消息。";

        for sample in [english, chinese] {
            assert!(estimator.estimate(sample) > 0);
        }
    }

    /// A fresh manager is empty, grows when a message is added and can be reset.
    #[test]
    fn test_context_window_manager() {
        let window = 200_000;
        let mut tracker = ContextWindowManager::new(window);

        // Freshly created: nothing tracked, the whole window is available.
        assert_eq!(tracker.context_window(), window);
        assert_eq!(tracker.current_tokens(), 0);
        assert_eq!(tracker.remaining_tokens(), window);

        tracker.add_message("Hello, world!", "user");
        assert!(tracker.current_tokens() > 0);

        tracker.reset();
        assert_eq!(tracker.current_tokens(), 0);
    }

    /// With 10_000 tokens left, a 20_000 buffer is "near" and a 5_000 one is not.
    #[test]
    fn test_is_near_limit() {
        let mut tracker = ContextWindowManager::new(200_000);
        tracker.current_tokens = 190_000;

        let (wide_buffer, narrow_buffer) = (20_000, 5_000);
        assert!(tracker.is_near_limit(wide_buffer));
        assert!(!tracker.is_near_limit(narrow_buffer));
    }
}