//! enact-context 0.0.2
//!
//! Context window management and compaction for Enact.
//! Token Counter
//!
//! Token counting using tiktoken for accurate context management.

use thiserror::Error;
use tiktoken_rs::{cl100k_base, CoreBPE};

/// Errors produced by [`TokenCounter`].
#[derive(Debug, Error)]
pub enum TokenCounterError {
    /// The tiktoken vocabulary failed to load; carries the underlying
    /// tiktoken error rendered as a string.
    #[error("Failed to initialize tokenizer: {0}")]
    InitError(String),
}

/// Token counter using tiktoken
///
/// Uses cl100k_base encoding (GPT-4, GPT-3.5-turbo, text-embedding-ada-002)
pub struct TokenCounter {
    /// Pre-built byte-pair encoder, constructed once in [`TokenCounter::new`]
    /// and reused for every count/encode/decode call.
    bpe: CoreBPE,
}

impl TokenCounter {
    /// Create a new token counter with cl100k_base encoding.
    ///
    /// # Errors
    ///
    /// Returns [`TokenCounterError::InitError`] if the tiktoken vocabulary
    /// fails to load.
    pub fn new() -> Result<Self, TokenCounterError> {
        let bpe = cl100k_base().map_err(|e| TokenCounterError::InitError(e.to_string()))?;
        Ok(Self { bpe })
    }

    /// Count tokens in a string.
    ///
    /// Special-token text (e.g. `<|endoftext|>`) is encoded rather than
    /// rejected, so this never fails.
    pub fn count(&self, text: &str) -> usize {
        self.bpe.encode_with_special_tokens(text).len()
    }

    /// Count tokens for a chat message (includes role overhead)
    ///
    /// Approximates the token overhead for chat formatting:
    /// - Each message has ~4 tokens overhead (role, separators)
    /// - Function/tool messages carry one extra token of overhead
    ///
    /// NOTE(review): these constants approximate OpenAI's chat format and are
    /// not exact for every model — confirm against the target model's spec.
    pub fn count_message(&self, role: &str, content: &str) -> usize {
        let content_tokens = self.count(content);
        // All standard roles cost ~4 tokens; function/tool results carry an
        // extra token for the function-name field.
        let role_overhead = match role {
            "function" | "tool" => 5,
            _ => 4,
        };
        content_tokens + role_overhead
    }

    /// Count tokens for multiple `(role, content)` messages.
    ///
    /// Adds a flat 3 tokens for the assistant-reply priming that chat
    /// completions append after the last message.
    pub fn count_messages(&self, messages: &[(String, String)]) -> usize {
        messages
            .iter()
            .map(|(role, content)| self.count_message(role, content))
            .sum::<usize>()
            + 3 // priming tokens
    }

    /// Truncate text to fit within a token limit.
    ///
    /// Returns the truncated text and its actual token count. Text that
    /// already fits is returned unchanged.
    pub fn truncate(&self, text: &str, max_tokens: usize) -> (String, usize) {
        let tokens = self.bpe.encode_with_special_tokens(text);
        if tokens.len() <= max_tokens {
            return (text.to_string(), tokens.len());
        }

        // A token boundary can fall inside a multi-byte UTF-8 character, in
        // which case decoding the prefix fails. Back off one token at a time
        // until the prefix decodes cleanly (at most a few iterations, since
        // only a trailing partial character can be at fault). This replaces
        // the old fallback, which byte-sliced `text` at an arbitrary offset
        // and could panic on non-ASCII input — and reported `max_tokens`
        // even when the fallback text had a different token count.
        for end in (0..=max_tokens).rev() {
            if let Ok(truncated) = self.bpe.decode(tokens[..end].to_vec()) {
                return (truncated, end);
            }
        }
        (String::new(), 0)
    }

    /// Split text into chunks of at most `chunk_size` tokens.
    ///
    /// Chunk boundaries that land inside a multi-byte character are moved
    /// back so no text is silently dropped (the previous implementation
    /// skipped any chunk that failed to decode). Returns an empty vector
    /// when `chunk_size` is 0 (a fixed-stride split would panic).
    pub fn chunk(&self, text: &str, chunk_size: usize) -> Vec<String> {
        if chunk_size == 0 {
            return Vec::new();
        }

        let tokens = self.bpe.encode_with_special_tokens(text);
        let mut chunks = Vec::new();
        let mut start = 0;

        while start < tokens.len() {
            let mut end = (start + chunk_size).min(tokens.len());
            // Shrink until the slice decodes as valid UTF-8; the next chunk
            // then starts where this one actually ended, so nothing is lost.
            while end > start {
                if let Ok(chunk_text) = self.bpe.decode(tokens[start..end].to_vec()) {
                    chunks.push(chunk_text);
                    break;
                }
                end -= 1;
            }
            if end == start {
                // A lone undecodable token: skip it to guarantee forward
                // progress (mirrors the original's drop-on-failure behavior,
                // but for a single token instead of a whole chunk).
                end = start + 1;
            }
            start = end;
        }

        chunks
    }

    /// Returns `true` if `text` fits within `budget` tokens.
    pub fn will_fit(&self, text: &str, budget: usize) -> bool {
        self.count(text) <= budget
    }
}

impl Default for TokenCounter {
    /// Builds a counter with the cl100k_base encoding.
    ///
    /// # Panics
    ///
    /// Panics if the embedded tiktoken vocabulary cannot be loaded — an
    /// invariant violation rather than a recoverable condition, since the
    /// vocabulary ships with the `tiktoken_rs` crate. Callers that need a
    /// fallible path should use [`TokenCounter::new`] instead.
    fn default() -> Self {
        Self::new().expect("Failed to initialize default token counter")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_tokens() {
        let tc = TokenCounter::new().unwrap();

        // The empty string encodes to zero tokens.
        assert_eq!(tc.count(""), 0);

        // A short phrase yields a small, nonzero count.
        let n = tc.count("Hello, world!");
        assert!(n > 0 && n < 10);
    }

    #[test]
    fn test_count_message() {
        let tc = TokenCounter::new().unwrap();

        let bare = tc.count("Hello");
        let with_role = tc.count_message("user", "Hello");

        // Chat formatting adds per-message overhead on top of the content.
        assert!(with_role > bare);
    }

    #[test]
    fn test_truncate() {
        let tc = TokenCounter::new().unwrap();

        let original = "This is a long text that we want to truncate to a smaller size.";
        let (shortened, n) = tc.truncate(original, 5);

        // The result respects the budget and is strictly shorter.
        assert!(n <= 5);
        assert!(shortened.len() < original.len());
    }

    #[test]
    fn test_will_fit() {
        let tc = TokenCounter::new().unwrap();

        // Small text under a generous budget fits; huge text under a tiny
        // budget does not.
        assert!(tc.will_fit("Hello", 100));
        let huge = "Hello ".repeat(1000);
        assert!(!tc.will_fit(&huge, 10));
    }
}