sentinel_proxy/inference/providers.rs

//! Inference provider adapters for token extraction
//!
//! Each provider has specific headers and body formats for token information:
//! - OpenAI: `x-ratelimit-remaining-tokens` header, `usage.total_tokens` in body
//! - Anthropic: `anthropic-ratelimit-tokens-remaining` header, `usage.input_tokens` + `usage.output_tokens` in body
//! - Generic: `x-tokens-used` header, estimation fallback

use http::HeaderMap;
use sentinel_config::{InferenceProvider, TokenEstimation};
use serde_json::Value;
use tracing::trace;

use super::tiktoken::tiktoken_manager;

/// Trait for provider-specific token extraction and estimation
pub trait InferenceProviderAdapter: Send + Sync {
    /// Provider name for logging/metrics
    fn name(&self) -> &'static str;

    /// Extract token count from response headers (primary method)
    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64>;

    /// Extract token count from response body (fallback method)
    fn tokens_from_body(&self, body: &[u8]) -> Option<u64>;

    /// Estimate tokens from request body using the specified method
    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64;

    /// Extract model name from request (header or body)
    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String>;
}

/// Create a provider adapter based on provider type
pub fn create_provider(provider: &InferenceProvider) -> Box<dyn InferenceProviderAdapter> {
    match provider {
        InferenceProvider::OpenAi => Box::new(OpenAiProvider),
        InferenceProvider::Anthropic => Box::new(AnthropicProvider),
        InferenceProvider::Generic => Box::new(GenericProvider),
    }
}

// ============================================================================
// OpenAI Provider
// ============================================================================

struct OpenAiProvider;

impl InferenceProviderAdapter for OpenAiProvider {
    fn name(&self) -> &'static str {
        "openai"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // OpenAI's documented rate-limit headers include:
        // - x-ratelimit-limit-tokens
        // - x-ratelimit-remaining-tokens
        // Some deployments also surface x-ratelimit-used-tokens directly;
        // prefer that when present, otherwise derive usage as limit - remaining.
        if let Some(value) = headers.get("x-ratelimit-used-tokens") {
            if let Ok(s) = value.to_str() {
                if let Ok(n) = s.parse::<u64>() {
                    trace!(tokens = n, "Got token count from OpenAI x-ratelimit-used-tokens");
                    return Some(n);
                }
            }
        }

        // Fallback: derive usage from limit - remaining
        let limit = headers
            .get("x-ratelimit-limit-tokens")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        let remaining = headers
            .get("x-ratelimit-remaining-tokens")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        if let (Some(l), Some(r)) = (limit, remaining) {
            let used = l.saturating_sub(r);
            trace!(limit = l, remaining = r, used = used, "Calculated token usage from OpenAI headers");
            return Some(used);
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        // OpenAI response format:
        // { "usage": { "prompt_tokens": N, "completion_tokens": M, "total_tokens": T } }
        let json: Value = serde_json::from_slice(body).ok()?;
        let total = json.get("usage")?.get("total_tokens")?.as_u64();
        if let Some(t) = total {
            trace!(tokens = t, "Got token count from OpenAI response body");
        }
        total
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // Check header first
        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
            return Some(model.to_string());
        }

        // Extract from body: { "model": "gpt-4" }
        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

// ============================================================================
// Anthropic Provider
// ============================================================================

struct AnthropicProvider;

impl InferenceProviderAdapter for AnthropicProvider {
    fn name(&self) -> &'static str {
        "anthropic"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // Anthropic's rate-limit headers include:
        // - anthropic-ratelimit-tokens-limit
        // - anthropic-ratelimit-tokens-remaining
        // - anthropic-ratelimit-tokens-reset
        let limit = headers
            .get("anthropic-ratelimit-tokens-limit")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        let remaining = headers
            .get("anthropic-ratelimit-tokens-remaining")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        if let (Some(l), Some(r)) = (limit, remaining) {
            let used = l.saturating_sub(r);
            trace!(limit = l, remaining = r, used = used, "Calculated token usage from Anthropic headers");
            return Some(used);
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        // Anthropic response format:
        // { "usage": { "input_tokens": N, "output_tokens": M } }
        let json: Value = serde_json::from_slice(body).ok()?;
        let usage = json.get("usage")?;

        // Require both fields to be present and numeric; a non-numeric
        // value should not be silently counted as zero.
        let input = usage.get("input_tokens")?.as_u64()?;
        let output = usage.get("output_tokens")?.as_u64()?;
        let total = input + output;

        trace!(input = input, output = output, total = total, "Got token count from Anthropic response body");
        Some(total)
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // Check header first
        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
            return Some(model.to_string());
        }

        // Anthropic puts the model in the body: { "model": "claude-3-opus-20240229" }
        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

// ============================================================================
// Generic Provider
// ============================================================================

struct GenericProvider;

impl InferenceProviderAdapter for GenericProvider {
    fn name(&self) -> &'static str {
        "generic"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // The generic provider scans a set of common header names
        let candidates = [
            "x-tokens-used",
            "x-token-count",
            "x-total-tokens",
        ];

        for header in candidates {
            if let Some(value) = headers.get(header) {
                if let Ok(s) = value.to_str() {
                    if let Ok(n) = s.parse::<u64>() {
                        trace!(header = header, tokens = n, "Got token count from generic header");
                        return Some(n);
                    }
                }
            }
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        let json: Value = serde_json::from_slice(body).ok()?;

        // Try usage.total_tokens (OpenAI style, most common)
        if let Some(total) = json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|t| t.as_u64()) {
            return Some(total);
        }

        // Try usage.input_tokens + usage.output_tokens (Anthropic style)
        if let Some(usage) = json.get("usage") {
            let input = usage.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
            let output = usage.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
            if input > 0 || output > 0 {
                return Some(input + output);
            }
        }

        None
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // Check common headers
        let candidates = ["x-model", "x-model-id", "model"];
        for header in candidates {
            if let Some(model) = headers.get(header).and_then(|v| v.to_str().ok()) {
                return Some(model.to_string());
            }
        }

        // Extract from body
        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

// ============================================================================
// Token Estimation Utilities
// ============================================================================

/// Estimate tokens from body content using the specified method
fn estimate_tokens(body: &[u8], method: TokenEstimation) -> u64 {
    estimate_tokens_with_model(body, method, None)
}

/// Estimate tokens from body content using the specified method, with an optional model hint
fn estimate_tokens_with_model(body: &[u8], method: TokenEstimation, model: Option<&str>) -> u64 {
    match method {
        TokenEstimation::Chars => {
            // Simple heuristic: ~4 characters per token
            let char_count = String::from_utf8_lossy(body).chars().count();
            (char_count / 4).max(1) as u64
        }
        TokenEstimation::Words => {
            // ~1.3 tokens per word (English average)
            let text = String::from_utf8_lossy(body);
            let word_count = text.split_whitespace().count();
            ((word_count as f64 * 1.3).ceil() as u64).max(1)
        }
        TokenEstimation::Tiktoken => estimate_tokens_tiktoken(body, model),
    }
}

/// Estimate tokens using tiktoken with model-specific encoding
///
/// Uses the global TiktokenManager, which:
/// - Caches BPE instances for reuse
/// - Selects the correct encoding based on the model name
/// - Parses chat completion requests to extract just the message content
fn estimate_tokens_tiktoken(body: &[u8], model: Option<&str>) -> u64 {
    let manager = tiktoken_manager();

    // Use the chat request parser for accurate counting:
    // it extracts message content and handles the JSON structure
    let tokens = manager.count_chat_request(body, model);

    trace!(
        token_count = tokens,
        model = ?model,
        tiktoken_available = manager.is_available(),
        "Tiktoken token count"
    );

    tokens
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_body_parsing() {
        let body = br#"{"usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}"#;
        let provider = OpenAiProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }
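
    // Added sketch: exercises the `limit - remaining` fallback path in
    // `OpenAiProvider::tokens_from_headers`; the header values here are
    // made up for illustration.
    #[test]
    fn test_openai_header_fallback() {
        let provider = OpenAiProvider;
        let mut headers = HeaderMap::new();
        headers.insert("x-ratelimit-limit-tokens", "1000".parse().unwrap());
        headers.insert("x-ratelimit-remaining-tokens", "900".parse().unwrap());
        // 1000 - 900 = 100 tokens used
        assert_eq!(provider.tokens_from_headers(&headers), Some(100));
    }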

    #[test]
    fn test_anthropic_body_parsing() {
        let body = br#"{"usage": {"input_tokens": 100, "output_tokens": 50}}"#;
        let provider = AnthropicProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }
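
    // Added sketch: checks the limit/remaining derivation in
    // `AnthropicProvider::tokens_from_headers`; values are illustrative.
    #[test]
    fn test_anthropic_header_usage() {
        let provider = AnthropicProvider;
        let mut headers = HeaderMap::new();
        headers.insert("anthropic-ratelimit-tokens-limit", "5000".parse().unwrap());
        headers.insert("anthropic-ratelimit-tokens-remaining", "4850".parse().unwrap());
        // 5000 - 4850 = 150 tokens used
        assert_eq!(provider.tokens_from_headers(&headers), Some(150));
    }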

    #[test]
    fn test_token_estimation_chars() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Chars);
        // 55 chars / 4 = 13 tokens (integer division)
        assert!(estimate > 0 && estimate < 100);
    }
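
    // Added sketch: the Words method rounds up word_count * 1.3, so two
    // words should estimate to ceil(2 * 1.3) = 3 tokens.
    #[test]
    fn test_token_estimation_words() {
        let body = b"Hello world";
        let estimate = estimate_tokens(body, TokenEstimation::Words);
        assert_eq!(estimate, 3);
    }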

    #[test]
    fn test_model_extraction() {
        let body = br#"{"model": "gpt-4", "messages": []}"#;
        let provider = OpenAiProvider;
        let headers = HeaderMap::new();
        assert_eq!(provider.extract_model(&headers, body), Some("gpt-4".to_string()));
    }
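
    // Added sketches for GenericProvider: one for the common-header scan and
    // one for the Anthropic-style `input_tokens + output_tokens` body
    // fallback. Header names and values are illustrative.
    #[test]
    fn test_generic_header_tokens() {
        let provider = GenericProvider;
        let mut headers = HeaderMap::new();
        headers.insert("x-token-count", "42".parse().unwrap());
        assert_eq!(provider.tokens_from_headers(&headers), Some(42));
    }

    #[test]
    fn test_generic_body_fallback() {
        let provider = GenericProvider;
        let body = br#"{"usage": {"input_tokens": 10, "output_tokens": 5}}"#;
        assert_eq!(provider.tokens_from_body(body), Some(15));
    }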

    #[test]
    fn test_token_estimation_tiktoken() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Tiktoken);
        // Should return a reasonable token count regardless of feature flag
        assert!(estimate > 0 && estimate < 100);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_count() {
        // "Hello world" is typically 2 tokens with cl100k_base
        let body = b"Hello world";
        let estimate = estimate_tokens_tiktoken(body, Some("gpt-4"));
        assert_eq!(estimate, 2);
    }

    #[test]
    fn test_tiktoken_chat_request() {
        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ]
        }"#;
        let estimate = estimate_tokens_tiktoken(body, None);
        // Should count message content plus overhead
        assert!(estimate > 0);
    }
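
    // Added sketch: `create_provider` should hand back the adapter whose
    // `name()` matches the configured provider variant.
    #[test]
    fn test_create_provider_dispatch() {
        assert_eq!(create_provider(&InferenceProvider::OpenAi).name(), "openai");
        assert_eq!(create_provider(&InferenceProvider::Anthropic).name(), "anthropic");
        assert_eq!(create_provider(&InferenceProvider::Generic).name(), "generic");
    }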
}