use async_trait::async_trait;
use crate::prompt_block::PromptBlock;
use crate::retry::LlmError;
use crate::types::ChatMessage;
pub mod anthropic_api;
pub mod cascading;
pub mod tiktoken_fallback;
pub use anthropic_api::AnthropicTokenCounter;
pub use cascading::CascadingTokenCounter;
pub use tiktoken_fallback::TiktokenCounter;
/// Counts prompt tokens for LLM requests.
///
/// Implementations are handed out by [`build`] as `Arc<dyn TokenCounter>`, which
/// is why the trait requires `Send + Sync`. Methods are async via `async_trait`.
#[async_trait]
pub trait TokenCounter: Send + Sync {
/// Counts the tokens in a sequence of prompt blocks.
///
/// # Errors
/// Returns an [`LlmError`] when the count cannot be produced (the exact failure
/// modes depend on the implementation).
async fn count_blocks(&self, blocks: &[PromptBlock]) -> Result<u32, LlmError>;
/// Counts the tokens in a list of chat messages for the given `model`.
///
/// NOTE(review): `model` presumably selects the tokenizer / model the count is
/// computed against — confirm against the implementations.
///
/// # Errors
/// Returns an [`LlmError`] when the count cannot be produced.
async fn count_messages(
&self,
model: &str,
messages: &[ChatMessage],
) -> Result<u32, LlmError>;
/// Whether this counter reports exact counts rather than estimates
/// (the tiktoken backend reports `false`).
fn is_exact(&self) -> bool;
/// Short static identifier of the implementation, e.g. `"tiktoken"` or `"cascading"`.
fn backend(&self) -> &'static str;
}
/// Constructs the token counter selected by the `backend` configuration string.
///
/// Recognised `backend` values:
/// - `"anthropic_api"` — an [`AnthropicTokenCounter`] built from `base_url`, `api_key`
///   and `cache_capacity`, wrapped in a [`CascadingTokenCounter`]. This happens
///   regardless of `provider` and of whether `api_key` is blank.
/// - `"tiktoken"` — a local [`TiktokenCounter`].
/// - `"auto"` — the cascading Anthropic counter when `provider == "anthropic"` and
///   `api_key` is non-blank after trimming; otherwise a [`TiktokenCounter`].
/// - anything else — logs a `tracing` warning, then falls back to [`TiktokenCounter`].
///
/// Matching on `backend` and `provider` is exact (case-sensitive, untrimmed).
pub fn build(
    backend: &str,
    provider: &str,
    base_url: &str,
    api_key: &str,
    cache_capacity: u32,
) -> std::sync::Arc<dyn TokenCounter> {
    use std::sync::Arc;

    // Decide first, construct once: `true` selects the Anthropic API counter
    // behind a cascade, `false` selects the local tiktoken counter.
    let wants_anthropic = match backend {
        "anthropic_api" => true,
        "tiktoken" => false,
        "auto" => provider == "anthropic" && !api_key.trim().is_empty(),
        unrecognised => {
            // Warn before constructing the fallback, as the caller asked for
            // something we do not understand.
            tracing::warn!(
                backend = unrecognised,
                "unknown token_counter backend — falling back to tiktoken"
            );
            false
        }
    };

    if !wants_anthropic {
        return Arc::new(TiktokenCounter::new());
    }

    let primary: Arc<dyn TokenCounter> = Arc::new(AnthropicTokenCounter::new(
        base_url,
        api_key,
        cache_capacity,
    ));
    Arc::new(CascadingTokenCounter::new(primary))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `auto` + Anthropic provider + a real key selects the cascading API counter.
    #[test]
    fn build_picks_anthropic_when_auto_and_keyed() {
        let c = build("auto", "anthropic", "https://api.anthropic.com", "sk-x", 16);
        assert_eq!(c.backend(), "cascading");
        assert!(c.is_exact());
    }

    /// `auto` without a key falls back to the local estimator.
    #[test]
    fn build_falls_back_to_tiktoken_when_no_key() {
        let c = build("auto", "anthropic", "", "", 16);
        assert_eq!(c.backend(), "tiktoken");
        assert!(!c.is_exact());
    }

    /// A whitespace-only key is treated as missing, because `build` trims before checking.
    #[test]
    fn build_auto_treats_blank_key_as_missing() {
        let c = build("auto", "anthropic", "https://api.anthropic.com", "  \n\t", 16);
        assert_eq!(c.backend(), "tiktoken");
        assert!(!c.is_exact());
    }

    /// `auto` only uses the Anthropic API for the Anthropic provider, even when keyed.
    #[test]
    fn build_auto_uses_tiktoken_for_other_providers() {
        let c = build("auto", "openai", "https://api.openai.com", "sk-x", 16);
        assert_eq!(c.backend(), "tiktoken");
        assert!(!c.is_exact());
    }

    /// An explicit `tiktoken` backend wins even when Anthropic would be available.
    #[test]
    fn build_explicit_tiktoken_overrides_provider() {
        let c = build("tiktoken", "anthropic", "", "sk-x", 16);
        assert_eq!(c.backend(), "tiktoken");
        assert!(!c.is_exact());
    }

    /// Unrecognised backend strings degrade to tiktoken rather than failing.
    #[test]
    fn build_unknown_backend_falls_back_to_tiktoken() {
        let c = build("nonsense", "openai", "", "", 16);
        assert_eq!(c.backend(), "tiktoken");
        assert!(!c.is_exact());
    }

    /// An explicit `anthropic_api` backend is honoured regardless of provider.
    #[test]
    fn build_explicit_anthropic_works_for_other_providers() {
        let c = build("anthropic_api", "openai", "https://api.anthropic.com", "k", 16);
        assert_eq!(c.backend(), "cascading");
    }
}