ai_tokenopt 0.5.9

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! HuggingFace tokenizer-based token estimation.
//!
//! Provides exact token counts using a real tokenizer vocabulary,
//! replacing the heuristic character-ratio estimation with precise
//! tokenization. Falls back to the heuristic estimate on any encoding error.
//!
//! # Usage
//!
//! ```no_run
//! use ai_tokenopt::estimator_hf::HfTokenEstimator;
//!
//! // Load from HuggingFace Hub (cached automatically)
//! let estimator = HfTokenEstimator::from_pretrained("meta-llama/Llama-3.2-3B")
//!     .expect("failed to load tokenizer");
//!
//! let tokens = estimator.count_tokens("Hello, world!");
//! assert!(tokens > 0);
//! ```

use std::path::Path;

use tracing::{debug, warn};

use crate::error::TokenOptError;
use crate::estimator::{ConversationTokenEstimate, MESSAGE_OVERHEAD_TOKENS, TokenEstimator};
use crate::types::{ChatMessage, Conversation, ToolDefinition};

/// Exact token estimator backed by a HuggingFace tokenizer.
///
/// Wraps [`tokenizers::Tokenizer`] and provides the same estimation surface as
/// [`TokenEstimator`] but with exact token counts from a real vocabulary.
///
/// On encoding failure, transparently falls back to the heuristic estimator.
pub struct HfTokenEstimator {
    /// Loaded tokenizer used to encode text when counting tokens.
    tokenizer: tokenizers::Tokenizer,
    /// Label identifying the loaded tokenizer: the Hub model name when loaded
    /// via `from_pretrained`, or the file stem (falling back to `"local"`)
    /// when loaded via `from_file`. Used in `Debug` output and log fields.
    model_name: String,
}

/// Debug output shows only `model_name`; the wrapped tokenizer is left out
/// and the struct is rendered as non-exhaustive.
impl std::fmt::Debug for HfTokenEstimator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("HfTokenEstimator");
        builder.field("model_name", &self.model_name);
        builder.finish_non_exhaustive()
    }
}

impl HfTokenEstimator {
    /// Build an estimator from a `tokenizer.json` file on disk.
    ///
    /// The estimator's model name is taken from the file stem of `path`
    /// (e.g. `tokenizer` for `tokenizer.json`). When the path has no stem, or
    /// the stem is not valid UTF-8, the name `"local"` is used instead.
    ///
    /// # Errors
    ///
    /// Returns [`TokenOptError::Configuration`] if the tokenizer cannot be
    /// loaded from `path` (the file cannot be read or parsed).
    pub fn from_file(path: &Path) -> Result<Self, TokenOptError> {
        let tokenizer = match tokenizers::Tokenizer::from_file(path) {
            Ok(loaded) => loaded,
            Err(e) => {
                return Err(TokenOptError::Configuration(format!(
                    "Failed to load tokenizer from {}: {e}",
                    path.display()
                )));
            },
        };

        // Derive a human-readable label from the file name.
        let model_name = path
            .file_stem()
            .and_then(std::ffi::OsStr::to_str)
            .map_or_else(|| String::from("local"), str::to_owned);

        debug!(path = %path.display(), "Loaded HuggingFace tokenizer from file");

        Ok(Self {
            tokenizer,
            model_name,
        })
    }

    /// Fetch a tokenizer from the HuggingFace Hub by model name.
    ///
    /// Loading is delegated to [`tokenizers::Tokenizer::from_pretrained`]
    /// with default parameters. The Hub client caches downloads
    /// (`~/.cache/huggingface/hub/`), so later calls with the same model name
    /// load from the cache without network access.
    ///
    /// # Errors
    ///
    /// Returns [`TokenOptError::Configuration`] if the download or parsing
    /// fails.
    pub fn from_pretrained(model_name: &str) -> Result<Self, TokenOptError> {
        let loaded = tokenizers::Tokenizer::from_pretrained(model_name, None);

        let tokenizer = match loaded {
            Ok(tokenizer) => tokenizer,
            Err(e) => {
                return Err(TokenOptError::Configuration(format!(
                    "Failed to load tokenizer for '{model_name}': {e}"
                )));
            },
        };

        debug!(model = model_name, "Loaded HuggingFace tokenizer from Hub");

        let model_name = model_name.to_string();
        Ok(Self {
            tokenizer,
            model_name,
        })
    }

    /// Count the tokens the loaded tokenizer produces for `text`.
    ///
    /// The text is encoded with `add_special_tokens` set to `false`, so the
    /// count covers the content only.
    ///
    /// # Returns
    ///
    /// * `0` when `text` is empty (the tokenizer is not invoked).
    /// * Otherwise the number of token ids in the encoding, raised to at
    ///   least `1` and saturated at `u32::MAX`.
    /// * If encoding fails, a warning is logged and the heuristic estimate
    ///   from [`TokenEstimator::estimate_tokens`] is returned instead.
    #[must_use]
    pub fn count_tokens(&self, text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }

        match self.tokenizer.encode(text, false) {
            Ok(encoding) => {
                // Saturate rather than truncate: an `as u32` cast would wrap
                // an oversized length around to a misleadingly small count.
                let count = u32::try_from(encoding.get_ids().len()).unwrap_or(u32::MAX);
                // Non-empty text is never reported as zero tokens, even if
                // the tokenizer emits no ids for it.
                count.max(1)
            },
            Err(e) => {
                warn!(
                    error = %e,
                    model = self.model_name,
                    "HF tokenizer encoding failed, falling back to heuristic"
                );
                TokenEstimator::estimate_tokens(text)
            },
        }
    }

    /// Token count for one chat message: the tokens of its content plus the
    /// fixed per-message overhead [`MESSAGE_OVERHEAD_TOKENS`].
    #[must_use]
    pub fn count_message_tokens(&self, message: &ChatMessage) -> u32 {
        let content_tokens = self.count_tokens(&message.content);
        content_tokens + MESSAGE_OVERHEAD_TOKENS
    }

    /// Total token count for `messages`, i.e. the sum of
    /// [`Self::count_message_tokens`] over every element.
    ///
    /// Returns `0` for an empty slice.
    #[must_use]
    pub fn count_messages_tokens(&self, messages: &[ChatMessage]) -> u32 {
        let mut total = 0;
        for message in messages {
            total += self.count_message_tokens(message);
        }
        total
    }

    /// Break a conversation's token usage down by component.
    ///
    /// The system prompt and summary are counted with
    /// [`Self::count_tokens`] (each contributing `0` when absent), and the
    /// message history with [`Self::count_messages_tokens`]. `total` is the
    /// sum of the three parts.
    #[must_use]
    pub fn count_conversation_tokens(
        &self,
        conversation: &Conversation,
    ) -> ConversationTokenEstimate {
        let system_prompt = if let Some(prompt) = conversation.system_prompt.as_deref() {
            self.count_tokens(prompt)
        } else {
            0
        };

        let summary = if let Some(text) = conversation.summary.as_deref() {
            self.count_tokens(text)
        } else {
            0
        };

        let history = self.count_messages_tokens(&conversation.messages);
        let total = system_prompt + summary + history;

        ConversationTokenEstimate {
            system_prompt,
            summary,
            history,
            total,
        }
    }

    /// Token count for a single tool definition.
    ///
    /// Adds up the tokens of the tool's name and description, the type and
    /// description of every parameter property, and a fixed allowance for
    /// the surrounding schema structure.
    #[must_use]
    pub fn count_tool_definition_tokens(&self, tool: &ToolDefinition) -> u32 {
        // Flat allowance for schema overhead (JSON structure, brackets, keys).
        const SCHEMA_OVERHEAD_TOKENS: u32 = 8;

        let name_tokens = self.count_tokens(&tool.name);
        let desc_tokens = self.count_tokens(&tool.description);

        // NOTE(review): only property values are visited, so property names
        // are not tokenized here; presumably the flat allowance is meant to
        // cover them. Confirm against the heuristic estimator.
        let param_tokens = tool
            .parameters
            .properties
            .values()
            .fold(0_u32, |running, param| {
                running
                    + (self.count_tokens(&param.param_type)
                        + self.count_tokens(&param.description))
            });

        name_tokens + desc_tokens + param_tokens + SCHEMA_OVERHEAD_TOKENS
    }

    /// Total token count for `tools`, i.e. the sum of
    /// [`Self::count_tool_definition_tokens`] over every element.
    ///
    /// Returns `0` for an empty slice.
    #[must_use]
    pub fn count_tool_definitions_tokens(&self, tools: &[ToolDefinition]) -> u32 {
        let mut total = 0;
        for tool in tools {
            total += self.count_tool_definition_tokens(tool);
        }
        total
    }

    /// Borrow the name recorded when this estimator was constructed.
    #[must_use]
    pub fn model_name(&self) -> &str {
        self.model_name.as_str()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The heuristic estimator, which `count_tokens` falls back to on
    /// encoding errors, reports zero tokens for empty input.
    ///
    /// This test does not construct an `HfTokenEstimator` (no real tokenizer
    /// is loaded); it exercises the fallback function directly.
    #[test]
    fn empty_text_returns_zero() {
        assert_eq!(TokenEstimator::estimate_tokens(""), 0);
    }

    /// Loading from a path that does not exist must fail with a
    /// configuration error whose message names the offending path.
    #[test]
    fn from_file_missing_path_is_configuration_error() {
        const MISSING: &str = "/nonexistent/tokenizer.json";

        let result = HfTokenEstimator::from_file(Path::new(MISSING));

        assert!(matches!(
            &result,
            Err(TokenOptError::Configuration(message)) if message.contains(MISSING)
        ));
    }
}