// reasonkit-core 0.1.8
//
// The Reasoning Engine — Auditable Reasoning for Production AI | Rust-Native | Turn Prompts into Protocols
//! Token Counting and Management
//!
//! This module provides accurate BPE tokenization using tiktoken-rs
//! for managing context windows and tracking token usage.
//!
//! # Features
//! - Accurate token counting for OpenAI models
//! - Context window management
//! - Token budget tracking
//! - Text truncation to fit within limits
//!
//! Enable with: `cargo build --features tokenization`

use anyhow::Result;
use serde::{Deserialize, Serialize};

// Re-export tiktoken for direct access
pub use tiktoken_rs;

/// Token encoding types for different models.
///
/// Each variant identifies a BPE vocabulary. [`TokenEncoding::encoding_name`]
/// returns the matching tiktoken encoding string, and
/// [`TokenEncoding::for_model`] picks a variant from a model name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenEncoding {
    /// GPT-4 and GPT-3.5-turbo (cl100k_base)
    Cl100kBase,
    /// GPT-3 models (p50k_base)
    P50kBase,
    /// Codex models (p50k_edit)
    P50kEdit,
    /// GPT-2 (r50k_base / gpt2)
    R50kBase,
    /// Claude models (estimated, uses cl100k_base as approximation).
    ///
    /// Counts produced with this variant are estimates, not exact values.
    ClaudeApprox,
}

impl TokenEncoding {
    /// Get the tiktoken encoding name
    pub fn encoding_name(&self) -> &'static str {
        match self {
            TokenEncoding::Cl100kBase | TokenEncoding::ClaudeApprox => "cl100k_base",
            TokenEncoding::P50kBase => "p50k_base",
            TokenEncoding::P50kEdit => "p50k_edit",
            TokenEncoding::R50kBase => "r50k_base",
        }
    }

    /// Get encoding for a model name
    pub fn for_model(model: &str) -> Self {
        let model_lower = model.to_lowercase();
        if model_lower.contains("gpt-4") || model_lower.contains("gpt-3.5") {
            TokenEncoding::Cl100kBase
        } else if model_lower.contains("claude") {
            TokenEncoding::ClaudeApprox
        } else if model_lower.contains("davinci") || model_lower.contains("curie") {
            TokenEncoding::P50kBase
        } else if model_lower.contains("codex") {
            TokenEncoding::P50kEdit
        } else {
            TokenEncoding::Cl100kBase // Default to latest
        }
    }
}

/// Token counter for accurate token measurement.
///
/// Holds only the selected [`TokenEncoding`], which is `Copy`, so cloning a
/// counter is cheap.
#[derive(Debug, Clone)]
pub struct TokenCounter {
    /// Encoding used for every count performed through this counter.
    encoding: TokenEncoding,
}

impl TokenCounter {
    /// Creates a token counter that uses the given `encoding`.
    pub fn new(encoding: TokenEncoding) -> Self {
        Self { encoding }
    }

    /// Creates a token counter for a model name.
    ///
    /// The encoding is chosen by [`TokenEncoding::for_model`].
    pub fn for_model(model: &str) -> Self {
        Self::new(TokenEncoding::for_model(model))
    }

    /// Counts the tokens in `text` using this counter's encoding.
    ///
    /// # Errors
    ///
    /// Returns an error if the tokenizer for the encoding cannot be built.
    pub fn count(&self, text: &str) -> Result<usize> {
        // Build the tokenizer directly from the encoding. The previous code passed
        // the *encoding* name to `get_bpe_from_model`, which resolves *model* names
        // (e.g. "gpt-4") and therefore rejected "cl100k_base" and friends.
        // Keep this mapping in sync with `TokenEncoding::encoding_name`.
        // NOTE(review): the tokenizer is constructed on every call; consider
        // caching it if counting shows up in profiles.
        let bpe = match self.encoding {
            TokenEncoding::Cl100kBase | TokenEncoding::ClaudeApprox => tiktoken_rs::cl100k_base(),
            TokenEncoding::P50kBase => tiktoken_rs::p50k_base(),
            TokenEncoding::P50kEdit => tiktoken_rs::p50k_edit(),
            TokenEncoding::R50kBase => tiktoken_rs::r50k_base(),
        }?;
        Ok(bpe.encode_ordinary(text).len())
    }

    /// Counts the tokens of one chat message, including framing overhead.
    ///
    /// The result is the token count of `content` plus the token count of
    /// `role` plus a fixed allowance of 4 tokens for the chat framing
    /// (`<|im_start|>role\ncontent<|im_end|>`). The allowance is approximate.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TokenCounter::count`].
    pub fn count_chat_message(&self, role: &str, content: &str) -> Result<usize> {
        let content_tokens = self.count(content)?;
        let role_tokens = self.count(role)?;
        Ok(content_tokens + role_tokens + 4)
    }

    /// Returns the encoding this counter was created with.
    pub fn encoding(&self) -> TokenEncoding {
        self.encoding
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new(TokenEncoding::Cl100kBase)
    }
}

/// Context window manager for tracking token budgets.
///
/// The space usable for input is `max_tokens` minus `reserved_for_response`;
/// `current_tokens` records how much of that space has been consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindow {
    /// Maximum tokens in the context window
    pub max_tokens: usize,
    /// Tokens held back for the response; never offered to input
    pub reserved_for_response: usize,
    /// Tokens consumed so far (starts at 0 for a newly created window)
    pub current_tokens: usize,
}

impl ContextWindow {
    /// Create a new context window with the given limits
    pub fn new(max_tokens: usize, reserved_for_response: usize) -> Self {
        Self {
            max_tokens,
            reserved_for_response,
            current_tokens: 0,
        }
    }

    /// Create a context window for a specific model
    pub fn for_model(model: &str) -> Self {
        let (max_tokens, reserved) = match model.to_lowercase().as_str() {
            m if m.contains("gpt-4-turbo") || m.contains("gpt-4o") => (128_000, 4_096),
            m if m.contains("gpt-4-32k") => (32_768, 4_096),
            m if m.contains("gpt-4") => (8_192, 2_048),
            m if m.contains("gpt-3.5-turbo-16k") => (16_384, 4_096),
            m if m.contains("gpt-3.5") => (4_096, 1_024),
            m if m.contains("claude-3-opus") => (200_000, 4_096),
            m if m.contains("claude-3-sonnet") => (200_000, 4_096),
            m if m.contains("claude-3-haiku") => (200_000, 4_096),
            m if m.contains("claude-2") => (100_000, 4_096),
            _ => (8_192, 2_048), // Conservative default
        };
        Self::new(max_tokens, reserved)
    }

    /// Available tokens for input (after reserving response tokens)
    pub fn available(&self) -> usize {
        self.max_tokens
            .saturating_sub(self.reserved_for_response)
            .saturating_sub(self.current_tokens)
    }

    /// Add tokens to the current count
    pub fn add(&mut self, tokens: usize) {
        self.current_tokens = self.current_tokens.saturating_add(tokens);
    }

    /// Reset the current token count
    pub fn reset(&mut self) {
        self.current_tokens = 0;
    }

    /// Check if more content can fit
    pub fn can_fit(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Usage percentage
    pub fn usage_percent(&self) -> f64 {
        let usable = self.max_tokens.saturating_sub(self.reserved_for_response);
        if usable == 0 {
            return 100.0;
        }
        (self.current_tokens as f64 / usable as f64) * 100.0
    }
}

/// Text truncator for fitting content within token limits.
///
/// Wraps a [`TokenCounter`] and uses it to measure candidate prefixes.
/// Derives `Debug` and `Clone` like the other token utilities in this module.
#[derive(Debug, Clone)]
pub struct TextTruncator {
    /// Counter used to measure text while truncating.
    counter: TokenCounter,
}

impl TextTruncator {
    /// Creates a truncator that measures text with the given `encoding`.
    pub fn new(encoding: TokenEncoding) -> Self {
        Self {
            counter: TokenCounter::new(encoding),
        }
    }

    /// Truncates `text` to a prefix that fits within `max_tokens`.
    ///
    /// Text that already fits is returned unchanged. Otherwise a binary search
    /// over the character boundaries of `text` selects a prefix whose token
    /// count is at most `max_tokens`. The search assumes the token count does
    /// not decrease as the prefix grows.
    ///
    /// Only character boundaries are ever used as cut points, so the result is
    /// always valid UTF-8 and slicing cannot panic on multi-byte characters.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying token counter.
    pub fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
        if self.counter.count(text)? <= max_tokens {
            return Ok(text.to_string());
        }

        // Every byte offset at which `text` may legally be cut, in ascending
        // order: the start of each character plus the end of the string.
        let cut_points: Vec<usize> = text
            .char_indices()
            .map(|(offset, _)| offset)
            .chain(std::iter::once(text.len()))
            .collect();

        // Offset 0 (the empty prefix) is the fallback answer. The last cut point
        // is the full text, already known not to fit, so it is excluded.
        let mut best = 0;
        let mut low = 0;
        let mut high = cut_points.len() - 1;

        while low < high {
            let mid = low + (high - low) / 2;
            let cut = cut_points[mid];

            if self.counter.count(&text[..cut])? <= max_tokens {
                best = cut;
                low = mid + 1;
            } else {
                high = mid;
            }
        }

        Ok(text[..best].to_string())
    }

    /// Truncates `text` to fit `max_tokens`, marking the cut with `"..."`.
    ///
    /// * Text that already fits within `max_tokens` is returned unchanged.
    /// * Otherwise the text is truncated to `max_tokens` minus the token count
    ///   of `"..."`, and `"..."` is appended.
    /// * If `max_tokens` is smaller than the token count of `"..."`, the
    ///   ellipsis is dropped and the text is simply truncated to `max_tokens`,
    ///   so the limit is not exceeded just to show the marker.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying token counter.
    pub fn truncate_with_ellipsis(&self, text: &str, max_tokens: usize) -> Result<String> {
        if self.counter.count(text)? <= max_tokens {
            return Ok(text.to_string());
        }

        let ellipsis = "...";
        let ellipsis_tokens = self.counter.count(ellipsis)?;

        match max_tokens.checked_sub(ellipsis_tokens) {
            Some(budget) => {
                // The text is known not to fit, so the prefix is always shorter
                // than the input and the marker is always warranted.
                let truncated = self.truncate(text, budget)?;
                Ok(format!("{}{}", truncated, ellipsis))
            }
            // Not enough room for the marker itself.
            None => self.truncate(text, max_tokens),
        }
    }
}

impl Default for TextTruncator {
    fn default() -> Self {
        Self::new(TokenEncoding::Cl100kBase)
    }
}

/// Token budget tracker for multi-turn conversations.
///
/// Combines a [`ContextWindow`] (how much space exists), a [`TokenCounter`]
/// (how large each message is) and the list of messages accepted so far.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Space accounting for the conversation.
    window: ContextWindow,
    /// Counter used to measure each incoming message.
    counter: TokenCounter,
    /// Accepted messages, oldest first, each stored with its token count.
    messages: Vec<(String, usize)>, // (content, tokens)
}

impl TokenBudget {
    /// Creates an empty budget whose window and counter are both chosen from
    /// the model name.
    pub fn for_model(model: &str) -> Self {
        Self {
            window: ContextWindow::for_model(model),
            counter: TokenCounter::for_model(model),
            messages: Vec::new(),
        }
    }

    /// Tries to add a message to the budget.
    ///
    /// Returns `Ok(true)` when the message fits; its tokens are then charged
    /// to the window and the message is recorded. Returns `Ok(false)` and
    /// leaves the budget untouched when it does not fit.
    ///
    /// # Errors
    ///
    /// Propagates any error from counting the tokens of `content`.
    pub fn add_message(&mut self, content: &str) -> Result<bool> {
        let tokens = self.counter.count(content)?;
        if !self.window.can_fit(tokens) {
            return Ok(false);
        }
        self.window.add(tokens);
        self.messages.push((content.to_string(), tokens));
        Ok(true)
    }

    /// Drops the oldest messages until `needed_tokens` fit, or until no
    /// messages remain.
    ///
    /// If `needed_tokens` exceeds what the window could ever offer, every
    /// message is dropped and the request still does not fit.
    pub fn make_room(&mut self, needed_tokens: usize) {
        // Walk from the oldest message, releasing tokens until there is room,
        // then remove the whole evicted prefix in one pass. Removing from the
        // front one element at a time would shift the vector on every step.
        let mut evicted = 0;
        while evicted < self.messages.len() && !self.window.can_fit(needed_tokens) {
            let freed = self.messages[evicted].1;
            self.window.current_tokens = self.window.current_tokens.saturating_sub(freed);
            evicted += 1;
        }
        self.messages.drain(..evicted);
    }

    /// Returns a snapshot of the current usage figures.
    pub fn stats(&self) -> BudgetStats {
        BudgetStats {
            current_tokens: self.window.current_tokens,
            max_tokens: self.window.max_tokens,
            available_tokens: self.window.available(),
            usage_percent: self.window.usage_percent(),
            message_count: self.messages.len(),
        }
    }
}

/// Statistics about token budget usage.
///
/// A point-in-time snapshot produced by `TokenBudget::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStats {
    /// Tokens currently charged to the context window.
    pub current_tokens: usize,
    /// Total size of the context window.
    pub max_tokens: usize,
    /// Tokens still available for input.
    pub available_tokens: usize,
    /// Usage percentage as reported by the context window.
    pub usage_percent: f64,
    /// Number of messages currently held in the budget.
    pub message_count: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Model names map to the expected encodings.
    #[test]
    fn test_encoding_for_model() {
        let expectations = [
            ("gpt-4", TokenEncoding::Cl100kBase),
            ("claude-3-opus", TokenEncoding::ClaudeApprox),
        ];

        for (model, expected) in expectations.iter() {
            assert_eq!(TokenEncoding::for_model(model), *expected);
        }
    }

    /// Available space shrinks as tokens are added, and `can_fit` follows it.
    #[test]
    fn test_context_window() {
        let mut ctx = ContextWindow::new(8192, 2048);
        let before = ctx.available();
        assert_eq!(before, 6144);

        ctx.add(1000);
        let after = ctx.available();
        assert_eq!(after, 5144);

        let fits_5000 = ctx.can_fit(5000);
        let fits_6000 = ctx.can_fit(6000);
        assert!(fits_5000);
        assert!(!fits_6000);
    }

    /// Known model names get their documented window sizes.
    #[test]
    fn test_context_window_for_model() {
        let gpt = ContextWindow::for_model("gpt-4-turbo");
        assert_eq!(gpt.max_tokens, 128_000);

        let claude = ContextWindow::for_model("claude-3-opus");
        assert_eq!(claude.max_tokens, 200_000);
    }
}