use std::path::Path;
use tracing::{debug, warn};
use crate::error::TokenOptError;
use crate::estimator::{ConversationTokenEstimate, MESSAGE_OVERHEAD_TOKENS, TokenEstimator};
use crate::types::{ChatMessage, Conversation, ToolDefinition};
/// Token estimator backed by a HuggingFace `tokenizers::Tokenizer`.
///
/// Counts come from the tokenizer's actual encoding. If encoding fails,
/// `count_tokens` falls back to the heuristic `TokenEstimator::estimate_tokens`.
pub struct HfTokenEstimator {
    /// Loaded tokenizer used to encode text into token ids.
    tokenizer: tokenizers::Tokenizer,
    /// Name reported in logs and by `model_name()`. It is the Hub model id when
    /// built via `from_pretrained`, or the file stem (or `"local"`) when built via
    /// `from_file`.
    model_name: String,
}
/// Debug representation that shows only the model name.
///
/// The tokenizer field is elided and rendered as `..`, which keeps the output short.
impl std::fmt::Debug for HfTokenEstimator {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("HfTokenEstimator");
        builder.field("model_name", &self.model_name);
        builder.finish_non_exhaustive()
    }
}
impl HfTokenEstimator {
    /// Fixed number of tokens added per tool definition on top of the counted
    /// name, description and parameter text.
    ///
    /// The value is carried over unchanged from the original implementation.
    /// NOTE(review): presumably this approximates schema framing; confirm.
    const TOOL_DEFINITION_OVERHEAD_TOKENS: u32 = 8;

    /// Loads a tokenizer from a local tokenizer file.
    ///
    /// The estimator's model name is the file stem (e.g. `foo.json` -> `foo`).
    /// It falls back to `"local"` when the path has no stem or the stem is not
    /// valid UTF-8.
    ///
    /// # Errors
    ///
    /// Returns [`TokenOptError::Configuration`] if the tokenizer cannot be loaded.
    /// The message includes the offending path.
    pub fn from_file(path: &Path) -> Result<Self, TokenOptError> {
        let tokenizer = tokenizers::Tokenizer::from_file(path).map_err(|e| {
            TokenOptError::Configuration(format!(
                "Failed to load tokenizer from {}: {e}",
                path.display()
            ))
        })?;
        let model_name = path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("local")
            .to_string();
        debug!(path = %path.display(), "Loaded HuggingFace tokenizer from file");
        Ok(Self {
            tokenizer,
            model_name,
        })
    }

    /// Loads a tokenizer for `model_name` from the HuggingFace Hub.
    ///
    /// # Errors
    ///
    /// Returns [`TokenOptError::Configuration`] if the tokenizer cannot be loaded.
    pub fn from_pretrained(model_name: &str) -> Result<Self, TokenOptError> {
        let tokenizer = tokenizers::Tokenizer::from_pretrained(model_name, None).map_err(|e| {
            TokenOptError::Configuration(format!(
                "Failed to load tokenizer for '{model_name}': {e}"
            ))
        })?;
        debug!(model = model_name, "Loaded HuggingFace tokenizer from Hub");
        Ok(Self {
            tokenizer,
            model_name: model_name.to_string(),
        })
    }

    /// Counts the tokens in `text` without adding special tokens.
    ///
    /// The count depends on the input:
    /// - Empty text returns 0.
    /// - Any non-empty text counts as at least 1 token.
    /// - If the tokenizer fails to encode, a warning is logged and the heuristic
    ///   `TokenEstimator::estimate_tokens` result is returned instead.
    #[must_use]
    pub fn count_tokens(&self, text: &str) -> u32 {
        if text.is_empty() {
            return 0;
        }
        match self.tokenizer.encode(text, false) {
            Ok(encoding) => {
                // Saturate instead of silently truncating if the id count ever exceeds u32.
                let count = u32::try_from(encoding.get_ids().len()).unwrap_or(u32::MAX);
                count.max(1)
            },
            Err(e) => {
                warn!(
                    error = %e,
                    model = self.model_name,
                    "HF tokenizer encoding failed, falling back to heuristic"
                );
                TokenEstimator::estimate_tokens(text)
            },
        }
    }

    /// Counts the tokens of one message: its content plus `MESSAGE_OVERHEAD_TOKENS`.
    ///
    /// The sum saturates at `u32::MAX`.
    #[must_use]
    pub fn count_message_tokens(&self, message: &ChatMessage) -> u32 {
        self.count_tokens(&message.content)
            .saturating_add(MESSAGE_OVERHEAD_TOKENS)
    }

    /// Sums `count_message_tokens` over `messages`, saturating at `u32::MAX`.
    #[must_use]
    pub fn count_messages_tokens(&self, messages: &[ChatMessage]) -> u32 {
        messages
            .iter()
            .map(|m| self.count_message_tokens(m))
            .fold(0, u32::saturating_add)
    }

    /// Breaks down a conversation's token count into its parts.
    ///
    /// The parts are the system prompt, the summary and the message history.
    /// A missing system prompt or summary counts as 0. `total` is their
    /// saturating sum.
    #[must_use]
    pub fn count_conversation_tokens(
        &self,
        conversation: &Conversation,
    ) -> ConversationTokenEstimate {
        let system_prompt = conversation
            .system_prompt
            .as_deref()
            .map_or(0, |p| self.count_tokens(p));
        let summary = conversation
            .summary
            .as_deref()
            .map_or(0, |s| self.count_tokens(s));
        let history = self.count_messages_tokens(&conversation.messages);
        ConversationTokenEstimate {
            system_prompt,
            summary,
            history,
            total: system_prompt.saturating_add(summary).saturating_add(history),
        }
    }

    /// Estimates the tokens consumed by one tool definition.
    ///
    /// The estimate adds up:
    /// - the tool name and description;
    /// - each parameter's type and description;
    /// - a fixed per-tool overhead.
    ///
    /// All arithmetic saturates at `u32::MAX`.
    ///
    /// NOTE(review): the `properties` map keys (presumably parameter names) are not
    /// counted. Confirm this matches the heuristic estimator before changing it.
    #[must_use]
    pub fn count_tool_definition_tokens(&self, tool: &ToolDefinition) -> u32 {
        let name_tokens = self.count_tokens(&tool.name);
        let desc_tokens = self.count_tokens(&tool.description);
        let param_tokens = tool
            .parameters
            .properties
            .values()
            .map(|p| {
                self.count_tokens(&p.param_type)
                    .saturating_add(self.count_tokens(&p.description))
            })
            .fold(0, u32::saturating_add);
        name_tokens
            .saturating_add(desc_tokens)
            .saturating_add(param_tokens)
            .saturating_add(Self::TOOL_DEFINITION_OVERHEAD_TOKENS)
    }

    /// Sums `count_tool_definition_tokens` over `tools`, saturating at `u32::MAX`.
    #[must_use]
    pub fn count_tool_definitions_tokens(&self, tools: &[ToolDefinition]) -> u32 {
        tools
            .iter()
            .map(|t| self.count_tool_definition_tokens(t))
            .fold(0, u32::saturating_add)
    }

    /// Returns the name this estimator was created with.
    ///
    /// This is the Hub model id, or the file stem (or `"local"`) for file-loaded
    /// tokenizers.
    #[must_use]
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The heuristic estimator (used as the HF fallback) must treat empty text as free.
    #[test]
    fn heuristic_empty_text_returns_zero() {
        assert_eq!(TokenEstimator::estimate_tokens(""), 0);
    }

    /// Loading from a nonexistent file must yield a `Configuration` error whose
    /// message names the path, rather than panicking or succeeding.
    #[test]
    fn from_file_missing_path_is_configuration_error() {
        let missing = "/nonexistent/tokenizer.json";
        let result = HfTokenEstimator::from_file(Path::new(missing));
        assert!(
            matches!(result, Err(TokenOptError::Configuration(_))),
            "expected a Configuration error for a missing tokenizer file"
        );
        if let Err(TokenOptError::Configuration(msg)) = result {
            assert!(msg.contains(missing), "error should name the path, got: {msg}");
        }
    }
}