use http::HeaderMap;
use serde_json::Value;
use tracing::trace;
use zentinel_config::{InferenceProvider, TokenEstimation};
use super::tiktoken::tiktoken_manager;
/// Provider-specific logic for reading token usage and model names from
/// inference API traffic.
///
/// The `Send + Sync` supertraits allow a boxed adapter (as returned by
/// `create_provider`) to be shared across threads.
pub trait InferenceProviderAdapter: Send + Sync {
    /// Short, stable identifier for the provider (e.g. `"openai"`).
    fn name(&self) -> &'static str;
    /// Token usage derived from HTTP headers, or `None` when the headers do
    /// not carry a usable count.
    // NOTE(review): presumably these are upstream *response* headers — confirm at call sites.
    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64>;
    /// Token usage parsed from a body (the implementations in this module read
    /// its JSON `usage` object), or `None` when it cannot be determined.
    fn tokens_from_body(&self, body: &[u8]) -> Option<u64>;
    /// Estimated token count for a request body using the given estimation `method`.
    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64;
    /// Model name taken from headers, or from the body if no header supplies it.
    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String>;
}
pub fn create_provider(provider: &InferenceProvider) -> Box<dyn InferenceProviderAdapter> {
match provider {
InferenceProvider::OpenAi => Box::new(OpenAiProvider),
InferenceProvider::Anthropic => Box::new(AnthropicProvider),
InferenceProvider::Generic => Box::new(GenericProvider),
}
}
/// Adapter selected for `InferenceProvider::OpenAi`.
struct OpenAiProvider;

impl InferenceProviderAdapter for OpenAiProvider {
    fn name(&self) -> &'static str {
        "openai"
    }

    /// Reads token usage from OpenAI rate-limit headers.
    ///
    /// Prefers an explicit `x-ratelimit-used-tokens` value. Otherwise it
    /// computes `x-ratelimit-limit-tokens - x-ratelimit-remaining-tokens`,
    /// saturating at zero. Headers that are missing, non-ASCII, or not a `u64`
    /// are treated as absent.
    // NOTE(review): limit-minus-remaining is consumption within the current
    // rate-limit window, which may cover more than this single response.
    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        let header_u64 = |name: &str| -> Option<u64> {
            headers
                .get(name)
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
        };

        if let Some(n) = header_u64("x-ratelimit-used-tokens") {
            trace!(tokens = n, "Got token count from OpenAI x-ratelimit-used-tokens");
            return Some(n);
        }

        let l = header_u64("x-ratelimit-limit-tokens")?;
        let r = header_u64("x-ratelimit-remaining-tokens")?;
        let used = l.saturating_sub(r);
        trace!(
            limit = l,
            remaining = r,
            used = used,
            "Calculated token usage from OpenAI headers"
        );
        Some(used)
    }

    /// Returns `usage.total_tokens` from a JSON body.
    ///
    /// Returns `None` if the body is not JSON or the field is absent or not a `u64`.
    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        let parsed: Value = serde_json::from_slice(body).ok()?;
        let total = parsed.get("usage")?.get("total_tokens")?.as_u64()?;
        trace!(tokens = total, "Got token count from OpenAI response body");
        Some(total)
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    /// Model from the `x-model` header, falling back to the JSON body's `model` string.
    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        match headers.get("x-model").and_then(|v| v.to_str().ok()) {
            Some(from_header) => Some(from_header.to_owned()),
            None => {
                let parsed: Value = serde_json::from_slice(body).ok()?;
                parsed.get("model")?.as_str().map(str::to_owned)
            }
        }
    }
}
struct AnthropicProvider;
impl InferenceProviderAdapter for AnthropicProvider {
fn name(&self) -> &'static str {
"anthropic"
}
fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
let limit = headers
.get("anthropic-ratelimit-tokens-limit")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let remaining = headers
.get("anthropic-ratelimit-tokens-remaining")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
if let (Some(l), Some(r)) = (limit, remaining) {
let used = l.saturating_sub(r);
trace!(
limit = l,
remaining = r,
used = used,
"Calculated token usage from Anthropic headers"
);
return Some(used);
}
None
}
fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
let json: Value = serde_json::from_slice(body).ok()?;
let usage = json.get("usage")?;
let input = usage.get("input_tokens")?.as_u64().unwrap_or(0);
let output = usage.get("output_tokens")?.as_u64().unwrap_or(0);
let total = input + output;
trace!(
input = input,
output = output,
total = total,
"Got token count from Anthropic response body"
);
Some(total)
}
fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
estimate_tokens(body, method)
}
fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
return Some(model.to_string());
}
let json: Value = serde_json::from_slice(body).ok()?;
json.get("model")?.as_str().map(|s| s.to_string())
}
}
struct GenericProvider;
impl InferenceProviderAdapter for GenericProvider {
fn name(&self) -> &'static str {
"generic"
}
fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
let candidates = ["x-tokens-used", "x-token-count", "x-total-tokens"];
for header in candidates {
if let Some(value) = headers.get(header) {
if let Ok(s) = value.to_str() {
if let Ok(n) = s.parse::<u64>() {
trace!(
header = header,
tokens = n,
"Got token count from generic header"
);
return Some(n);
}
}
}
}
None
}
fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
let json: Value = serde_json::from_slice(body).ok()?;
if let Some(total) = json
.get("usage")
.and_then(|u| u.get("total_tokens"))
.and_then(|t| t.as_u64())
{
return Some(total);
}
if let Some(usage) = json.get("usage") {
let input = usage
.get("input_tokens")
.and_then(|t| t.as_u64())
.unwrap_or(0);
let output = usage
.get("output_tokens")
.and_then(|t| t.as_u64())
.unwrap_or(0);
if input > 0 || output > 0 {
return Some(input + output);
}
}
None
}
fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
estimate_tokens(body, method)
}
fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
let candidates = ["x-model", "x-model-id", "model"];
for header in candidates {
if let Some(model) = headers.get(header).and_then(|v| v.to_str().ok()) {
return Some(model.to_string());
}
}
let json: Value = serde_json::from_slice(body).ok()?;
json.get("model")?.as_str().map(|s| s.to_string())
}
}
/// Estimates the token count of `body` using `method`, with no model hint.
///
/// This is a convenience wrapper around `estimate_tokens_with_model`.
fn estimate_tokens(body: &[u8], method: TokenEstimation) -> u64 {
    let no_model: Option<&str> = None;
    estimate_tokens_with_model(body, method, no_model)
}
/// Estimates the token count of `body` using the selected heuristic.
///
/// - `Chars`: counts characters after lossy UTF-8 decoding, divides by 4
///   (integer division), and returns at least 1.
/// - `Words`: counts whitespace-separated words, multiplies by 1.3, rounds
///   up, and returns at least 1.
/// - `Tiktoken`: delegates to `estimate_tokens_tiktoken`, forwarding `model`.
fn estimate_tokens_with_model(body: &[u8], method: TokenEstimation, model: Option<&str>) -> u64 {
    match method {
        TokenEstimation::Tiktoken => estimate_tokens_tiktoken(body, model),
        TokenEstimation::Words => {
            let words = String::from_utf8_lossy(body).split_whitespace().count();
            let scaled = (words as f64 * 1.3).ceil();
            std::cmp::max(scaled as u64, 1)
        }
        TokenEstimation::Chars => {
            let chars = String::from_utf8_lossy(body).chars().count() as u64;
            std::cmp::max(chars / 4, 1)
        }
    }
}
/// Counts tokens for a chat-request body using the shared tiktoken manager.
///
/// `model` is forwarded to the manager as a hint. The count is traced along
/// with whether the manager reports tiktoken as available.
fn estimate_tokens_tiktoken(body: &[u8], model: Option<&str>) -> u64 {
    let mgr = tiktoken_manager();
    let count = mgr.count_chat_request(body, model);

    // `is_available()` is only evaluated if this trace event is enabled.
    trace!(
        token_count = count,
        model = ?model,
        tiktoken_available = mgr.is_available(),
        "Tiktoken token count"
    );

    count
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An OpenAI body's `usage.total_tokens` is reported as-is.
    #[test]
    fn test_openai_body_parsing() {
        let provider = OpenAiProvider;
        let body: &[u8] =
            br#"{"usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}"#;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    /// Anthropic input and output tokens are summed.
    #[test]
    fn test_anthropic_body_parsing() {
        let provider = AnthropicProvider;
        let body: &[u8] = br#"{"usage": {"input_tokens": 100, "output_tokens": 50}}"#;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    /// The character heuristic yields a small positive count for a short sentence.
    #[test]
    fn test_token_estimation_chars() {
        let estimate = estimate_tokens(
            b"Hello world, this is a test message for token counting!",
            TokenEstimation::Chars,
        );
        assert!((1..100).contains(&estimate));
    }

    /// The model is read from the JSON body when no header supplies it.
    #[test]
    fn test_model_extraction() {
        let headers = HeaderMap::new();
        let body: &[u8] = br#"{"model": "gpt-4", "messages": []}"#;
        let model = OpenAiProvider.extract_model(&headers, body);
        assert_eq!(model.as_deref(), Some("gpt-4"));
    }

    /// Tiktoken estimation yields a small positive count, with or without the feature.
    #[test]
    fn test_token_estimation_tiktoken() {
        let estimate = estimate_tokens(
            b"Hello world, this is a test message for token counting!",
            TokenEstimation::Tiktoken,
        );
        assert!((1..100).contains(&estimate));
    }

    /// With real tiktoken available, "Hello world" counts as exactly two tokens for gpt-4.
    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_count() {
        assert_eq!(estimate_tokens_tiktoken(b"Hello world", Some("gpt-4")), 2);
    }

    /// A chat-shaped request body produces a non-zero count.
    #[test]
    fn test_tiktoken_chat_request() {
        let body = br#"{
    "model": "gpt-4",
    "messages": [
        {"role": "user", "content": "Hello!"}
    ]
}"#;
        assert_ne!(estimate_tokens_tiktoken(body, None), 0);
    }
}