ai_tokenopt 0.5.9

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! Token estimation using a character-based heuristic
//!
//! Estimates token counts without requiring a model-specific tokenizer.
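//! Uses the widely-accepted `chars ÷ 4` heuristic (~85% accurate for most
//! BPE tokenizers on English text) as the baseline, and switches to a
//! per-language-class ratio for non-ASCII content. Input length is measured
//! in UTF-8 bytes (`str::len`), so multi-byte scripts are counted by their
//! encoded size.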

use crate::types::{ChatMessage, Conversation, MessageRole, ToolDefinition};

use crate::estimator_language::detect_language_class;

/// Stateless entry point for heuristic, tokenizer-free token estimation.
///
/// The type is a zero-sized unit struct; every estimator is exposed as an
/// associated function, so no instance is needed to use it.
///
/// Estimates deliberately lean high: when budgeting a context window,
/// overestimating is safer than underestimating.
#[derive(Clone, Copy, Debug)]
pub struct TokenEstimator;

/// Per-section token estimate for a whole conversation.
///
/// All counts are heuristic estimates, not tokenizer-exact values.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConversationTokenEstimate {
    /// Tokens attributed to the system prompt (0 when there is none).
    pub system_prompt: u32,
    /// Tokens attributed to the rolling summary (0 when there is none).
    pub summary: u32,
    /// Tokens attributed to the conversation messages.
    pub history: u32,
    /// Grand total: `system_prompt + summary + history`.
    pub total: u32,
}

/// Fixed per-message token overhead for role markers and separators.
///
/// Chat formats wrap each message in structural tokens, roughly
/// `<|start|>role\n`, content, `<|end|>`. This constant approximates that
/// wrapper at ~4 tokens per message; it is a heuristic, not a
/// tokenizer-exact figure.
pub const MESSAGE_OVERHEAD_TOKENS: u32 = 4;

impl TokenEstimator {
    /// Estimate the number of tokens in a text string.
    ///
    /// The estimate is `ceil(len / ratio)`, where `len` is the UTF-8 byte
    /// length of `text` (`str::len`) and `ratio` is the value returned by
    /// `detect_language_class(text).chars_per_token()`.
    ///
    /// Returns `0` for an empty string and at least `1` for any non-empty one.
    ///
    /// NOTE(review): this doc previously described a `v2` feature switch
    /// between per-language-class ratios (Latin 4.0, CJK 1.5, Cyrillic 2.5,
    /// Arabic 2.0, Mixed 3.0) and a binary ASCII/Unicode heuristic. No feature
    /// gate exists in this function — the language-class ratio is always
    /// used. Confirm the listed ratios against `estimator_language`.
    #[must_use]
    pub fn estimate_tokens(text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }

        // Byte length, not `chars().count()`: multi-byte scripts are counted
        // by their encoded size, which can only raise the estimate relative
        // to a character count. usize → f64 is exact below 2^53.
        #[allow(clippy::cast_precision_loss)]
        let byte_len = text.len() as f64;

        let chars_per_token = detect_language_class(text).chars_per_token();

        // Ceiling division — always round up to be conservative. Float → int
        // `as` casts saturate, so an oversized quotient clamps to `u32::MAX`
        // rather than wrapping.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let tokens = (byte_len / chars_per_token).ceil() as u32;
        tokens.max(1)
    }

    /// Estimate tokens for a single chat message including role overhead.
    ///
    /// This is the content estimate plus [`MESSAGE_OVERHEAD_TOKENS`]. The
    /// addition saturates at `u32::MAX` instead of overflowing.
    #[must_use]
    pub fn estimate_message(message: &ChatMessage) -> u32 {
        Self::estimate_tokens(&message.content).saturating_add(MESSAGE_OVERHEAD_TOKENS)
    }

    /// Estimate total tokens for a sequence of messages.
    ///
    /// Returns `0` for an empty slice. The running total saturates at
    /// `u32::MAX`: a wrapped sum would look small and defeat any budget check
    /// built on top of it, so clamping high is the conservative choice.
    #[must_use]
    pub fn estimate_messages(messages: &[ChatMessage]) -> u32 {
        messages
            .iter()
            .map(Self::estimate_message)
            .fold(0, u32::saturating_add)
    }

    /// Produce a detailed token estimate breakdown for a conversation.
    ///
    /// A missing system prompt or summary contributes `0`. Both are measured
    /// with [`Self::estimate_tokens`] only (no per-message overhead), while
    /// the history goes through [`Self::estimate_messages`]. `total` is the
    /// saturating sum of the three parts.
    #[must_use]
    pub fn estimate_conversation(conversation: &Conversation) -> ConversationTokenEstimate {
        let system_prompt = conversation
            .system_prompt
            .as_deref()
            .map_or(0, Self::estimate_tokens);

        let summary = conversation
            .summary
            .as_deref()
            .map_or(0, Self::estimate_tokens);

        let history = Self::estimate_messages(&conversation.messages);

        ConversationTokenEstimate {
            system_prompt,
            summary,
            history,
            total: system_prompt
                .saturating_add(summary)
                .saturating_add(history),
        }
    }

    /// Estimate tokens for a tool definition (name + description + parameters schema).
    ///
    /// Only each property's description text is measured; its key and type
    /// are covered by a flat per-property allowance, and the surrounding JSON
    /// schema structure by a flat per-tool allowance. All additions saturate
    /// at `u32::MAX`.
    #[must_use]
    pub fn estimate_tool_definition(tool: &ToolDefinition) -> u32 {
        // Flat allowance per property for its key and type entries.
        const PROPERTY_OVERHEAD_TOKENS: u32 = 2;
        // Flat allowance for the JSON schema structure around the properties.
        const SCHEMA_OVERHEAD_TOKENS: u32 = 8;

        let name_tokens = Self::estimate_tokens(&tool.name);
        let desc_tokens = Self::estimate_tokens(&tool.description);

        // Parameter schema: each property's description plus its flat allowance.
        let param_tokens: u32 = tool
            .parameters
            .properties
            .values()
            .map(|p| Self::estimate_tokens(&p.description).saturating_add(PROPERTY_OVERHEAD_TOKENS))
            .fold(0, u32::saturating_add);

        name_tokens
            .saturating_add(desc_tokens)
            .saturating_add(param_tokens)
            .saturating_add(SCHEMA_OVERHEAD_TOKENS)
    }

    /// Estimate tokens for a slice of tool definitions.
    ///
    /// Returns `0` for an empty slice; the total saturates at `u32::MAX`.
    #[must_use]
    pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> u32 {
        tools
            .iter()
            .map(Self::estimate_tool_definition)
            .fold(0, u32::saturating_add)
    }
}

/// Token cost attributed to a message's role marker.
///
/// Tool messages cost one token more than the other roles, to account for
/// the tool role plus its call_id reference.
#[must_use]
pub fn role_token_cost(role: MessageRole) -> u32 {
    const PLAIN_ROLE_COST: u32 = 2;
    const TOOL_ROLE_COST: u32 = 3;

    // Exhaustive on purpose: a new role must be given a cost explicitly.
    match role {
        MessageRole::Tool => TOOL_ROLE_COST,
        MessageRole::Assistant | MessageRole::User | MessageRole::System => PLAIN_ROLE_COST,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reference value for the ratio tests: `ceil(UTF-8 byte length / ratio)`.
    ///
    /// Uses `str::len` (bytes), mirroring what the estimator measures.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    fn expected_at_ratio(text: &str, ratio: f64) -> u32 {
        (text.len() as f64 / ratio).ceil() as u32
    }

    #[test]
    fn empty_string_is_zero_tokens() {
        let tokens = TokenEstimator::estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    #[test]
    fn short_ascii_text() {
        // "Hello" is 5 bytes → ceil(5 / 4) = 2
        let tokens = TokenEstimator::estimate_tokens("Hello");
        assert_eq!(tokens, 2);
    }

    #[test]
    fn longer_ascii_text() {
        // 43 ASCII bytes → ceil(43 / 4) = 11
        let pangram = "The quick brown fox jumps over the lazy dog";
        assert_eq!(TokenEstimator::estimate_tokens(pangram), 11);
    }

    #[test]
    fn german_text_uses_language_class_ratio() {
        let sample = "Ünüöäßüöäßüöäßüöäß ÖÜÄ";
        let actual = TokenEstimator::estimate_tokens(sample);
        assert!(actual > 0);

        // Latin script detected → ratio 4.0 applied to the byte length
        assert_eq!(actual, expected_at_ratio(sample, 4.0));
    }

    #[test]
    fn message_includes_overhead() {
        let message = ChatMessage::user("Hello");
        let content_only = TokenEstimator::estimate_tokens("Hello");
        let with_overhead = TokenEstimator::estimate_message(&message);
        assert_eq!(with_overhead, content_only + MESSAGE_OVERHEAD_TOKENS);
    }

    #[test]
    fn conversation_estimate_breakdown() {
        let mut conversation = Conversation::with_system_prompt("You are helpful.");
        conversation.summary = Some("Previously discussed weather.".to_string());
        conversation.add_user_message("What's the weather?");
        conversation.add_assistant_message("It's sunny today.");

        let breakdown = TokenEstimator::estimate_conversation(&conversation);

        // Every section is populated, so every section must be non-zero.
        assert!(breakdown.system_prompt > 0);
        assert!(breakdown.summary > 0);
        assert!(breakdown.history > 0);
        assert_eq!(
            breakdown.total,
            breakdown.system_prompt + breakdown.summary + breakdown.history
        );
    }

    #[test]
    fn empty_conversation_estimate() {
        let breakdown = TokenEstimator::estimate_conversation(&Conversation::new());
        assert_eq!(breakdown.total, 0);
    }

    #[test]
    fn single_char_is_at_least_one_token() {
        let tokens = TokenEstimator::estimate_tokens("a");
        assert_eq!(tokens, 1);
    }

    #[test]
    fn cjk_text_uses_language_class_ratio() {
        // CJK ratio is 1.5; each of these characters is 3 bytes in UTF-8.
        let sample = "你好世界这是测试文本";
        let actual = TokenEstimator::estimate_tokens(sample);
        assert_eq!(actual, expected_at_ratio(sample, 1.5));
    }

    #[test]
    fn cyrillic_text_uses_language_class_ratio() {
        // Cyrillic ratio is 2.5, applied to the byte length.
        let sample = "Привет мир как дела";
        let actual = TokenEstimator::estimate_tokens(sample);
        assert_eq!(actual, expected_at_ratio(sample, 2.5));
    }
}