Skip to main content

grapsus_proxy/inference/
providers.rs

1//! Inference provider adapters for token extraction
2//!
3//! Each provider has specific headers and body formats for token information:
4//! - OpenAI: `x-ratelimit-remaining-tokens` header, `usage.total_tokens` in body
5//! - Anthropic: `anthropic-ratelimit-tokens-remaining` header, `usage.input_tokens + output_tokens`
6//! - Generic: `x-tokens-used` header, estimation fallback
7
8use http::HeaderMap;
9use serde_json::Value;
10use tracing::trace;
11use grapsus_config::{InferenceProvider, TokenEstimation};
12
13use super::tiktoken::tiktoken_manager;
14
15/// Trait for provider-specific token extraction and estimation
16pub trait InferenceProviderAdapter: Send + Sync {
17    /// Provider name for logging/metrics
18    fn name(&self) -> &'static str;
19
20    /// Extract token count from response headers (primary method)
21    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64>;
22
23    /// Extract token count from response body (fallback method)
24    fn tokens_from_body(&self, body: &[u8]) -> Option<u64>;
25
26    /// Estimate tokens from request body using the specified method
27    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64;
28
29    /// Extract model name from request (header or body)
30    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String>;
31}
32
33/// Create a provider adapter based on provider type
34pub fn create_provider(provider: &InferenceProvider) -> Box<dyn InferenceProviderAdapter> {
35    match provider {
36        InferenceProvider::OpenAi => Box::new(OpenAiProvider),
37        InferenceProvider::Anthropic => Box::new(AnthropicProvider),
38        InferenceProvider::Generic => Box::new(GenericProvider),
39    }
40}
41
42// ============================================================================
43// OpenAI Provider
44// ============================================================================
45
46struct OpenAiProvider;
47
48impl InferenceProviderAdapter for OpenAiProvider {
49    fn name(&self) -> &'static str {
50        "openai"
51    }
52
53    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
54        // OpenAI uses several headers:
55        // - x-ratelimit-remaining-tokens
56        // - x-ratelimit-limit-tokens
57        // - x-ratelimit-used-tokens (what we want)
58        if let Some(value) = headers.get("x-ratelimit-used-tokens") {
59            if let Ok(s) = value.to_str() {
60                if let Ok(n) = s.parse::<u64>() {
61                    trace!(
62                        tokens = n,
63                        "Got token count from OpenAI x-ratelimit-used-tokens"
64                    );
65                    return Some(n);
66                }
67            }
68        }
69
70        // Fallback: calculate from limit - remaining
71        let limit = headers
72            .get("x-ratelimit-limit-tokens")
73            .and_then(|v| v.to_str().ok())
74            .and_then(|s| s.parse::<u64>().ok());
75
76        let remaining = headers
77            .get("x-ratelimit-remaining-tokens")
78            .and_then(|v| v.to_str().ok())
79            .and_then(|s| s.parse::<u64>().ok());
80
81        if let (Some(l), Some(r)) = (limit, remaining) {
82            let used = l.saturating_sub(r);
83            trace!(
84                limit = l,
85                remaining = r,
86                used = used,
87                "Calculated token usage from OpenAI headers"
88            );
89            return Some(used);
90        }
91
92        None
93    }
94
95    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
96        // OpenAI response format:
97        // { "usage": { "prompt_tokens": N, "completion_tokens": M, "total_tokens": T } }
98        let json: Value = serde_json::from_slice(body).ok()?;
99        let total = json.get("usage")?.get("total_tokens")?.as_u64();
100        if let Some(t) = total {
101            trace!(tokens = t, "Got token count from OpenAI response body");
102        }
103        total
104    }
105
106    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
107        estimate_tokens(body, method)
108    }
109
110    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
111        // Check header first
112        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
113            return Some(model.to_string());
114        }
115
116        // Extract from body: { "model": "gpt-4" }
117        let json: Value = serde_json::from_slice(body).ok()?;
118        json.get("model")?.as_str().map(|s| s.to_string())
119    }
120}
121
122// ============================================================================
123// Anthropic Provider
124// ============================================================================
125
126struct AnthropicProvider;
127
128impl InferenceProviderAdapter for AnthropicProvider {
129    fn name(&self) -> &'static str {
130        "anthropic"
131    }
132
133    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
134        // Anthropic uses:
135        // - anthropic-ratelimit-tokens-limit
136        // - anthropic-ratelimit-tokens-remaining
137        // - anthropic-ratelimit-tokens-reset
138        let limit = headers
139            .get("anthropic-ratelimit-tokens-limit")
140            .and_then(|v| v.to_str().ok())
141            .and_then(|s| s.parse::<u64>().ok());
142
143        let remaining = headers
144            .get("anthropic-ratelimit-tokens-remaining")
145            .and_then(|v| v.to_str().ok())
146            .and_then(|s| s.parse::<u64>().ok());
147
148        if let (Some(l), Some(r)) = (limit, remaining) {
149            let used = l.saturating_sub(r);
150            trace!(
151                limit = l,
152                remaining = r,
153                used = used,
154                "Calculated token usage from Anthropic headers"
155            );
156            return Some(used);
157        }
158
159        None
160    }
161
162    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
163        // Anthropic response format:
164        // { "usage": { "input_tokens": N, "output_tokens": M } }
165        let json: Value = serde_json::from_slice(body).ok()?;
166        let usage = json.get("usage")?;
167
168        let input = usage.get("input_tokens")?.as_u64().unwrap_or(0);
169        let output = usage.get("output_tokens")?.as_u64().unwrap_or(0);
170        let total = input + output;
171
172        trace!(
173            input = input,
174            output = output,
175            total = total,
176            "Got token count from Anthropic response body"
177        );
178        Some(total)
179    }
180
181    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
182        estimate_tokens(body, method)
183    }
184
185    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
186        // Check header first
187        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
188            return Some(model.to_string());
189        }
190
191        // Anthropic puts model in body: { "model": "claude-3-opus-20240229" }
192        let json: Value = serde_json::from_slice(body).ok()?;
193        json.get("model")?.as_str().map(|s| s.to_string())
194    }
195}
196
197// ============================================================================
198// Generic Provider
199// ============================================================================
200
201struct GenericProvider;
202
203impl InferenceProviderAdapter for GenericProvider {
204    fn name(&self) -> &'static str {
205        "generic"
206    }
207
208    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
209        // Generic provider looks for common headers
210        let candidates = ["x-tokens-used", "x-token-count", "x-total-tokens"];
211
212        for header in candidates {
213            if let Some(value) = headers.get(header) {
214                if let Ok(s) = value.to_str() {
215                    if let Ok(n) = s.parse::<u64>() {
216                        trace!(
217                            header = header,
218                            tokens = n,
219                            "Got token count from generic header"
220                        );
221                        return Some(n);
222                    }
223                }
224            }
225        }
226
227        None
228    }
229
230    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
231        // Try OpenAI format first (most common)
232        let json: Value = serde_json::from_slice(body).ok()?;
233
234        // Try usage.total_tokens (OpenAI style)
235        if let Some(total) = json
236            .get("usage")
237            .and_then(|u| u.get("total_tokens"))
238            .and_then(|t| t.as_u64())
239        {
240            return Some(total);
241        }
242
243        // Try usage.input_tokens + output_tokens (Anthropic style)
244        if let Some(usage) = json.get("usage") {
245            let input = usage
246                .get("input_tokens")
247                .and_then(|t| t.as_u64())
248                .unwrap_or(0);
249            let output = usage
250                .get("output_tokens")
251                .and_then(|t| t.as_u64())
252                .unwrap_or(0);
253            if input > 0 || output > 0 {
254                return Some(input + output);
255            }
256        }
257
258        None
259    }
260
261    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
262        estimate_tokens(body, method)
263    }
264
265    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
266        // Check common headers
267        let candidates = ["x-model", "x-model-id", "model"];
268        for header in candidates {
269            if let Some(model) = headers.get(header).and_then(|v| v.to_str().ok()) {
270                return Some(model.to_string());
271            }
272        }
273
274        // Extract from body
275        let json: Value = serde_json::from_slice(body).ok()?;
276        json.get("model")?.as_str().map(|s| s.to_string())
277    }
278}
279
280// ============================================================================
281// Token Estimation Utilities
282// ============================================================================
283
284/// Estimate tokens from body content using the specified method
285fn estimate_tokens(body: &[u8], method: TokenEstimation) -> u64 {
286    estimate_tokens_with_model(body, method, None)
287}
288
289/// Estimate tokens from body content using the specified method, with optional model hint
290fn estimate_tokens_with_model(body: &[u8], method: TokenEstimation, model: Option<&str>) -> u64 {
291    match method {
292        TokenEstimation::Chars => {
293            // Simple: ~4 characters per token
294            let char_count = String::from_utf8_lossy(body).chars().count();
295            (char_count / 4).max(1) as u64
296        }
297        TokenEstimation::Words => {
298            // ~1.3 tokens per word (English average)
299            let text = String::from_utf8_lossy(body);
300            let word_count = text.split_whitespace().count();
301            ((word_count as f64 * 1.3).ceil() as u64).max(1)
302        }
303        TokenEstimation::Tiktoken => estimate_tokens_tiktoken(body, model),
304    }
305}
306
307/// Estimate tokens using tiktoken with model-specific encoding
308///
309/// Uses the global TiktokenManager which:
310/// - Caches BPE instances for reuse
311/// - Selects the correct encoding based on model name
312/// - Parses chat completion requests to extract just the message content
313fn estimate_tokens_tiktoken(body: &[u8], model: Option<&str>) -> u64 {
314    let manager = tiktoken_manager();
315
316    // Use the chat request parser for accurate counting
317    // This extracts message content and handles JSON structure
318    let tokens = manager.count_chat_request(body, model);
319
320    trace!(
321        token_count = tokens,
322        model = ?model,
323        tiktoken_available = manager.is_available(),
324        "Tiktoken token count"
325    );
326
327    tokens
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn test_openai_body_parsing() {
        let body =
            br#"{"usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}"#;
        let provider = OpenAiProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    #[test]
    fn test_openai_headers_limit_minus_remaining() {
        let mut headers = HeaderMap::new();
        headers.insert("x-ratelimit-limit-tokens", HeaderValue::from_static("1000"));
        headers.insert("x-ratelimit-remaining-tokens", HeaderValue::from_static("400"));
        let provider = OpenAiProvider;
        assert_eq!(provider.tokens_from_headers(&headers), Some(600));
    }

    #[test]
    fn test_anthropic_body_parsing() {
        let body = br#"{"usage": {"input_tokens": 100, "output_tokens": 50}}"#;
        let provider = AnthropicProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    #[test]
    fn test_generic_body_parsing() {
        let provider = GenericProvider;
        assert_eq!(
            provider.tokens_from_body(br#"{"usage": {"total_tokens": 7}}"#),
            Some(7)
        );
        assert_eq!(
            provider.tokens_from_body(br#"{"usage": {"input_tokens": 3, "output_tokens": 4}}"#),
            Some(7)
        );
        assert_eq!(provider.tokens_from_body(b"not json"), None);
    }

    #[test]
    fn test_token_estimation_chars() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Chars);
        // 55 chars / 4 = 13 tokens (integer division rounds down)
        assert_eq!(estimate, 13);
    }

    #[test]
    fn test_token_estimation_words() {
        // 3 words * 1.3 = 3.9, rounded up to 4
        assert_eq!(estimate_tokens(b"one two three", TokenEstimation::Words), 4);
    }

    #[test]
    fn test_token_estimation_empty_body_is_at_least_one() {
        assert_eq!(estimate_tokens(b"", TokenEstimation::Chars), 1);
        assert_eq!(estimate_tokens(b"", TokenEstimation::Words), 1);
    }

    #[test]
    fn test_model_extraction() {
        let body = br#"{"model": "gpt-4", "messages": []}"#;
        let provider = OpenAiProvider;
        let headers = HeaderMap::new();
        assert_eq!(
            provider.extract_model(&headers, body),
            Some("gpt-4".to_string())
        );
    }

    #[test]
    fn test_token_estimation_tiktoken() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Tiktoken);
        // Should return a reasonable token count regardless of feature flag
        assert!(estimate > 0 && estimate < 100);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_count() {
        // "Hello world" is typically 2 tokens with cl100k_base
        let body = b"Hello world";
        let estimate = estimate_tokens_tiktoken(body, Some("gpt-4"));
        assert_eq!(estimate, 2);
    }

    #[test]
    fn test_tiktoken_chat_request() {
        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ]
        }"#;
        let estimate = estimate_tokens_tiktoken(body, None);
        // Should count message content plus overhead
        assert!(estimate > 0);
    }
}