use serde::{Deserialize, Serialize};
/// The result of a token estimate: the token count itself plus the raw
/// character and word counts it was derived from, and the heuristic used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEstimate {
    /// Estimated token count.
    pub tokens: usize,
    /// Number of Unicode characters in the input.
    pub characters: usize,
    /// Number of whitespace-separated words in the input.
    pub words: usize,
    /// Heuristic that produced the estimate.
    pub method: EstimationMethod,
}
/// Which heuristic produced a [`TokenEstimate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EstimationMethod {
    /// Character count divided by an adaptive chars-per-token ratio.
    CharacterRatio,
    /// Word count scaled by an average tokens-per-word factor.
    WordBased,
    /// Average of the character and word heuristics.
    Blended,
    /// Reserved for exact counts from a real tokenizer such as tiktoken.
    TikToken,
}
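// A usage sketch: the serde derives make estimates easy to log or persist.
// This assumes `serde_json` is available as a dependency; the exact counts
// follow from the heuristics below, not from a real tokenizer.
//
//     let est = estimate_tokens("hello");
//     let json = serde_json::to_string(&est).unwrap();
//     // => {"tokens":2,"characters":5,"words":1,"method":"Blended"}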
/// Estimates tokens from the character count, adapting the
/// characters-per-token ratio to the shape of the text: fenced code
/// blocks tokenize less densely, while very short words suggest
/// punctuation-heavy text that tokenizes more densely.
pub fn estimate_tokens_characters(text: &str) -> TokenEstimate {
    // Count Unicode characters, not bytes, so multi-byte scripts
    // are not over-counted.
    let characters = text.chars().count();
    let words = text.split_whitespace().count();
    let ratio = if text.contains("```") {
        // Code blocks average more characters per token.
        5.5
    } else if words > 0 {
        let avg_word_len = characters as f64 / words as f64;
        if avg_word_len > 8.0 {
            5.0
        } else if avg_word_len < 3.0 {
            3.5
        } else {
            4.0
        }
    } else {
        4.0
    };
    let tokens = (characters as f64 / ratio).ceil() as usize;
    TokenEstimate {
        tokens,
        characters,
        words,
        method: EstimationMethod::CharacterRatio,
    }
}
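// Example, with values implied by the ratios above rather than a real
// tokenizer:
//
//     let est = estimate_tokens_characters("Hello, world!");
//     assert_eq!(est.characters, 13);
//     assert_eq!(est.tokens, 4); // avg word length 6.5 picks the 4.0 ratio; ceil(13 / 4.0)
//
// Text containing a ``` fence would use the sparser 5.5 ratio instead.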
/// Estimates tokens from the word count using the common rule of thumb
/// that one token is roughly three quarters of an English word, i.e.
/// about 1.3 tokens per word.
pub fn estimate_tokens_words(text: &str) -> TokenEstimate {
    let words = text.split_whitespace().count();
    let characters = text.chars().count();
    // tokens ~= words * 1.3 (the inverse, words / 1.3, would
    // systematically underestimate English text).
    let tokens = (words as f64 * 1.3).ceil() as usize;
    TokenEstimate {
        tokens,
        characters,
        words,
        method: EstimationMethod::WordBased,
    }
}
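// Example: six words at ~1.3 tokens per word round up to 8.
//
//     let est = estimate_tokens_words("Hello world this is a test");
//     assert_eq!(est.words, 6);
//     assert_eq!(est.tokens, 8); // ceil(6 * 1.3)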
/// Blends the character-based and word-based heuristics by averaging
/// them, smoothing out the weaknesses of each.
pub fn estimate_tokens(text: &str) -> TokenEstimate {
    let char_estimate = estimate_tokens_characters(text);
    let word_estimate = estimate_tokens_words(text);
    let tokens = (char_estimate.tokens + word_estimate.tokens) / 2;
    TokenEstimate {
        tokens,
        characters: char_estimate.characters,
        words: char_estimate.words,
        method: EstimationMethod::Blended,
    }
}
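// Example: the blend averages the two heuristics (integer division).
// The numbers follow from the constants above, not from a real tokenizer.
//
//     let est = estimate_tokens("The quick brown fox jumps over the lazy dog");
//     // chars: ceil(43 / 4.0) = 11, words: ceil(9 * 1.3) = 12, blended: 11
//     assert_eq!(est.tokens, 11);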
/// Sums token estimates over a slice of messages, adding a small fixed
/// overhead per message for role tags and framing.
pub fn estimate_messages<T: MessageContent>(messages: &[T]) -> usize {
    messages
        .iter()
        .map(|m| {
            let content = m.content();
            // Rough allowance for the role marker and message separators.
            let role_overhead = 4;
            estimate_tokens(content).tokens + role_overhead
        })
        .sum()
}
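// Example: `&str` implements `MessageContent`, so plain strings work.
// Each message pays its blended content estimate plus the 4-token
// role overhead.
//
//     let msgs = ["Hello", "Hi there!"];
//     assert_eq!(estimate_messages(&msgs), 13); // (2 + 4) + (3 + 4)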
/// Estimates tokens for a plain-text transcript, adding per-turn
/// overhead for the "User:"/"Assistant:" framing.
pub fn estimate_conversation(conversation: &str) -> TokenEstimate {
    // Use whichever speaker marker appears more often as the turn count.
    let turns = conversation
        .matches("User:")
        .count()
        .max(conversation.matches("Assistant:").count());
    let turn_overhead = turns * 10;
    let base = estimate_tokens(conversation);
    TokenEstimate {
        tokens: base.tokens + turn_overhead,
        characters: base.characters,
        words: base.words,
        method: base.method,
    }
}
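// Example: one exchange adds 10 tokens of turn overhead on top of the
// blended base estimate (exact values depend on the heuristics above).
//
//     let conv = "User: Hello\nAssistant: Hi there!";
//     let est = estimate_conversation(conv);
//     assert_eq!(est.tokens, 17); // base 7 + 1 turn * 10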
/// Estimates the token cost of advertising tool definitions to a model:
/// name, description, and input schema, plus a fixed allowance for the
/// JSON scaffolding around each tool.
pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> usize {
    tools
        .iter()
        .map(|t| {
            let name_tokens = estimate_tokens(&t.name).tokens;
            let desc_tokens = t
                .description
                .as_ref()
                .map(|d| estimate_tokens(d).tokens)
                .unwrap_or(0);
            let params_tokens = estimate_tokens(&t.input_schema).tokens;
            // 20 tokens covers the structural JSON wrapping each tool.
            name_tokens + desc_tokens + params_tokens + 20
        })
        .sum()
}
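// Example: each tool pays for its name, description, and schema on top
// of the flat 20-token allowance.
//
//     let tool = ToolDefinition {
//         name: "Read".to_string(),
//         description: Some("Read a file".to_string()),
//         input_schema: r#"{"type":"object"}"#.to_string(),
//     };
//     let total = estimate_tool_definitions(&[tool]);
//     assert!(total > 20);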
/// Anything that can expose its textual content for token estimation.
pub trait MessageContent {
    fn content(&self) -> &str;
}
impl MessageContent for String {
fn content(&self) -> &str {
self.as_str()
}
}
impl MessageContent for &str {
fn content(&self) -> &str {
self
}
}
/// A simple role/content message, the common shape of chat API payloads.
#[derive(Debug, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}
impl MessageContent for ChatMessage {
fn content(&self) -> &str {
&self.content
}
}
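// Example: `ChatMessage` plugs straight into `estimate_messages` via
// the trait.
//
//     let history = vec![ChatMessage {
//         role: "user".to_string(),
//         content: "Hello".to_string(),
//     }];
//     let total = estimate_messages(&history);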
/// A tool made available to the model: its name, optional description,
/// and the JSON Schema for its input, stored as a raw JSON string.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    pub name: String,
    pub description: Option<String>,
    pub input_schema: String,
}
/// How many input tokens remain unused once `max_tokens` is reserved
/// for the response. Returns 0 when the input already fills or exceeds
/// the space available to it.
pub fn calculate_padding(
    input_tokens: usize,
    max_tokens: usize,
    context_limit: usize,
) -> usize {
    // Both subtractions saturate, so this never underflows; the
    // explicit branch the long way around would be redundant.
    context_limit
        .saturating_sub(max_tokens)
        .saturating_sub(input_tokens)
}
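// Example: with a 2000-token window and 500 reserved for the response,
// a 1000-token prompt leaves 500 tokens of headroom.
//
//     assert_eq!(calculate_padding(1000, 500, 2000), 500);
//     assert_eq!(calculate_padding(1500, 500, 2000), 0); // no room left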
/// Whether `content_tokens` plus the reserved `max_tokens` response
/// budget fits within the model's context window.
pub fn fits_in_context(
    content_tokens: usize,
    max_tokens: usize,
    context_limit: usize,
) -> bool {
    // Saturating add so pathological inputs cannot overflow.
    content_tokens.saturating_add(max_tokens) <= context_limit
}
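// Example: 1500 + 500 fits exactly in a 2000-token window; one token
// more does not.
//
//     assert!(fits_in_context(1500, 500, 2000));
//     assert!(!fits_in_context(1501, 500, 2000));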
/// Characters-per-token ratios for common content types, plus cheap
/// detectors for picking between them.
pub mod encoding {
    /// Typical English prose.
    pub const CHARS_PER_TOKEN_EN: f64 = 4.0;
    /// Source code tokenizes less densely than prose.
    pub const CHARS_PER_TOKEN_CODE: f64 = 5.5;
    /// CJK scripts often tokenize at one to two characters per token.
    pub const CHARS_PER_TOKEN_CJK: f64 = 2.0;

    /// Heuristic: does the text look like source code?
    pub fn is_code(text: &str) -> bool {
        let code_indicators = [
            "```", "function", "class ", "def ", "const ", "let ", "var ", "import ",
        ];
        code_indicators.iter().any(|i| text.contains(i))
    }

    /// Heuristic: does the text contain CJK characters? Checks the CJK
    /// Unified Ideographs, Hiragana, Katakana, and Hangul Syllables blocks.
    pub fn is_cjk(text: &str) -> bool {
        text.chars().any(|c| {
            matches!(c,
                '\u{4E00}'..='\u{9FFF}'   // CJK Unified Ideographs
                | '\u{3040}'..='\u{309F}' // Hiragana
                | '\u{30A0}'..='\u{30FF}' // Katakana
                | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables
            )
        })
    }

    /// Picks the chars-per-token ratio for the given text, preferring
    /// the code ratio when both detectors match.
    pub fn chars_per_token(text: &str) -> f64 {
        if is_code(text) {
            CHARS_PER_TOKEN_CODE
        } else if is_cjk(text) {
            CHARS_PER_TOKEN_CJK
        } else {
            CHARS_PER_TOKEN_EN
        }
    }
}
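// Example: the detectors pick a ratio, with code taking precedence
// over CJK when both match.
//
//     assert_eq!(encoding::chars_per_token("let x = 1;"), encoding::CHARS_PER_TOKEN_CODE);
//     assert_eq!(encoding::chars_per_token("你好世界"), encoding::CHARS_PER_TOKEN_CJK);
//     assert_eq!(encoding::chars_per_token("plain prose"), encoding::CHARS_PER_TOKEN_EN);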
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_estimate_tokens_characters() {
let result = estimate_tokens_characters("Hello, world!");
assert!(result.tokens >= 3);
assert_eq!(result.characters, 13);
}
#[test]
fn test_estimate_tokens_words() {
let result = estimate_tokens_words("Hello world this is a test");
assert!(result.tokens > 0);
assert_eq!(result.words, 6);
}
#[test]
fn test_estimate_tokens() {
let result = estimate_tokens("The quick brown fox jumps over the lazy dog");
assert!(result.tokens > 0);
}
#[test]
fn test_estimate_conversation() {
let conv = "User: Hello\nAssistant: Hi there!\nUser: How are you?";
let result = estimate_conversation(conv);
assert!(result.tokens > 0);
}
#[test]
fn test_estimate_tool_definitions() {
let tools = vec![
ToolDefinition {
name: "Read".to_string(),
description: Some("Read a file".to_string()),
input_schema: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#.to_string(),
},
];
let tokens = estimate_tool_definitions(&tools);
assert!(tokens > 0);
}
#[test]
fn test_calculate_padding() {
assert_eq!(calculate_padding(1000, 500, 2000), 500);
assert_eq!(calculate_padding(1500, 500, 2000), 0);
}
#[test]
fn test_fits_in_context() {
assert!(fits_in_context(1000, 500, 2000));
assert!(!fits_in_context(1600, 500, 2000));
}
#[test]
fn test_encoding_chars_per_token() {
assert_eq!(encoding::chars_per_token("Hello world"), encoding::CHARS_PER_TOKEN_EN);
assert_eq!(encoding::chars_per_token("function test() {}"), encoding::CHARS_PER_TOKEN_CODE);
}
#[test]
fn test_is_code() {
assert!(encoding::is_code("function foo() { return 1; }"));
assert!(!encoding::is_code("Hello world"));
}
#[test]
fn test_is_cjk() {
assert!(encoding::is_cjk("你好世界"));
assert!(!encoding::is_cjk("Hello world"));
}
#[test]
fn test_message_content_trait() {
let msg = ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
};
assert_eq!(msg.content(), "Hello");
}
}