// grapsus_proxy/inference/providers.rs

use http::HeaderMap;
use serde_json::Value;
use tracing::trace;

use grapsus_config::{InferenceProvider, TokenEstimation};

use super::tiktoken::tiktoken_manager;
/// Provider-specific extraction of token usage and model information from
/// inference API traffic.
///
/// Implementations must be `Send + Sync` so a boxed adapter (see
/// `create_provider`) can be shared across proxy workers.
pub trait InferenceProviderAdapter: Send + Sync {
    /// Short, stable identifier for this provider (e.g. "openai"), suitable
    /// for logs and metrics labels.
    fn name(&self) -> &'static str;

    /// Extracts a token count from response headers, or `None` when the
    /// headers carry no usable count.
    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64>;

    /// Extracts a token count from a response body, or `None` when the body
    /// is not JSON or has no usage information.
    fn tokens_from_body(&self, body: &[u8]) -> Option<u64>;

    /// Estimates the token cost of a request body using the configured
    /// estimation method (chars, words, or tiktoken).
    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64;

    /// Determines the model name from headers and/or the JSON body.
    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String>;
}
32
/// Builds the adapter matching the configured inference provider.
///
/// Returns a boxed trait object so callers can hold any provider behind a
/// single type.
pub fn create_provider(provider: &InferenceProvider) -> Box<dyn InferenceProviderAdapter> {
    match provider {
        InferenceProvider::OpenAi => Box::new(OpenAiProvider),
        InferenceProvider::Anthropic => Box::new(AnthropicProvider),
        InferenceProvider::Generic => Box::new(GenericProvider),
    }
}
41
42struct OpenAiProvider;
47
48impl InferenceProviderAdapter for OpenAiProvider {
49 fn name(&self) -> &'static str {
50 "openai"
51 }
52
53 fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
54 if let Some(value) = headers.get("x-ratelimit-used-tokens") {
59 if let Ok(s) = value.to_str() {
60 if let Ok(n) = s.parse::<u64>() {
61 trace!(
62 tokens = n,
63 "Got token count from OpenAI x-ratelimit-used-tokens"
64 );
65 return Some(n);
66 }
67 }
68 }
69
70 let limit = headers
72 .get("x-ratelimit-limit-tokens")
73 .and_then(|v| v.to_str().ok())
74 .and_then(|s| s.parse::<u64>().ok());
75
76 let remaining = headers
77 .get("x-ratelimit-remaining-tokens")
78 .and_then(|v| v.to_str().ok())
79 .and_then(|s| s.parse::<u64>().ok());
80
81 if let (Some(l), Some(r)) = (limit, remaining) {
82 let used = l.saturating_sub(r);
83 trace!(
84 limit = l,
85 remaining = r,
86 used = used,
87 "Calculated token usage from OpenAI headers"
88 );
89 return Some(used);
90 }
91
92 None
93 }
94
95 fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
96 let json: Value = serde_json::from_slice(body).ok()?;
99 let total = json.get("usage")?.get("total_tokens")?.as_u64();
100 if let Some(t) = total {
101 trace!(tokens = t, "Got token count from OpenAI response body");
102 }
103 total
104 }
105
106 fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
107 estimate_tokens(body, method)
108 }
109
110 fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
111 if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
113 return Some(model.to_string());
114 }
115
116 let json: Value = serde_json::from_slice(body).ok()?;
118 json.get("model")?.as_str().map(|s| s.to_string())
119 }
120}
121
122struct AnthropicProvider;
127
128impl InferenceProviderAdapter for AnthropicProvider {
129 fn name(&self) -> &'static str {
130 "anthropic"
131 }
132
133 fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
134 let limit = headers
139 .get("anthropic-ratelimit-tokens-limit")
140 .and_then(|v| v.to_str().ok())
141 .and_then(|s| s.parse::<u64>().ok());
142
143 let remaining = headers
144 .get("anthropic-ratelimit-tokens-remaining")
145 .and_then(|v| v.to_str().ok())
146 .and_then(|s| s.parse::<u64>().ok());
147
148 if let (Some(l), Some(r)) = (limit, remaining) {
149 let used = l.saturating_sub(r);
150 trace!(
151 limit = l,
152 remaining = r,
153 used = used,
154 "Calculated token usage from Anthropic headers"
155 );
156 return Some(used);
157 }
158
159 None
160 }
161
162 fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
163 let json: Value = serde_json::from_slice(body).ok()?;
166 let usage = json.get("usage")?;
167
168 let input = usage.get("input_tokens")?.as_u64().unwrap_or(0);
169 let output = usage.get("output_tokens")?.as_u64().unwrap_or(0);
170 let total = input + output;
171
172 trace!(
173 input = input,
174 output = output,
175 total = total,
176 "Got token count from Anthropic response body"
177 );
178 Some(total)
179 }
180
181 fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
182 estimate_tokens(body, method)
183 }
184
185 fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
186 if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
188 return Some(model.to_string());
189 }
190
191 let json: Value = serde_json::from_slice(body).ok()?;
193 json.get("model")?.as_str().map(|s| s.to_string())
194 }
195}
196
197struct GenericProvider;
202
203impl InferenceProviderAdapter for GenericProvider {
204 fn name(&self) -> &'static str {
205 "generic"
206 }
207
208 fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
209 let candidates = ["x-tokens-used", "x-token-count", "x-total-tokens"];
211
212 for header in candidates {
213 if let Some(value) = headers.get(header) {
214 if let Ok(s) = value.to_str() {
215 if let Ok(n) = s.parse::<u64>() {
216 trace!(
217 header = header,
218 tokens = n,
219 "Got token count from generic header"
220 );
221 return Some(n);
222 }
223 }
224 }
225 }
226
227 None
228 }
229
230 fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
231 let json: Value = serde_json::from_slice(body).ok()?;
233
234 if let Some(total) = json
236 .get("usage")
237 .and_then(|u| u.get("total_tokens"))
238 .and_then(|t| t.as_u64())
239 {
240 return Some(total);
241 }
242
243 if let Some(usage) = json.get("usage") {
245 let input = usage
246 .get("input_tokens")
247 .and_then(|t| t.as_u64())
248 .unwrap_or(0);
249 let output = usage
250 .get("output_tokens")
251 .and_then(|t| t.as_u64())
252 .unwrap_or(0);
253 if input > 0 || output > 0 {
254 return Some(input + output);
255 }
256 }
257
258 None
259 }
260
261 fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
262 estimate_tokens(body, method)
263 }
264
265 fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
266 let candidates = ["x-model", "x-model-id", "model"];
268 for header in candidates {
269 if let Some(model) = headers.get(header).and_then(|v| v.to_str().ok()) {
270 return Some(model.to_string());
271 }
272 }
273
274 let json: Value = serde_json::from_slice(body).ok()?;
276 json.get("model")?.as_str().map(|s| s.to_string())
277 }
278}
279
/// Estimates the token count of `body` with no known model; see
/// `estimate_tokens_with_model` for the model-aware variant.
fn estimate_tokens(body: &[u8], method: TokenEstimation) -> u64 {
    estimate_tokens_with_model(body, method, None)
}
288
289fn estimate_tokens_with_model(body: &[u8], method: TokenEstimation, model: Option<&str>) -> u64 {
291 match method {
292 TokenEstimation::Chars => {
293 let char_count = String::from_utf8_lossy(body).chars().count();
295 (char_count / 4).max(1) as u64
296 }
297 TokenEstimation::Words => {
298 let text = String::from_utf8_lossy(body);
300 let word_count = text.split_whitespace().count();
301 ((word_count as f64 * 1.3).ceil() as u64).max(1)
302 }
303 TokenEstimation::Tiktoken => estimate_tokens_tiktoken(body, model),
304 }
305}
306
307fn estimate_tokens_tiktoken(body: &[u8], model: Option<&str>) -> u64 {
314 let manager = tiktoken_manager();
315
316 let tokens = manager.count_chat_request(body, model);
319
320 trace!(
321 token_count = tokens,
322 model = ?model,
323 tiktoken_available = manager.is_available(),
324 "Tiktoken token count"
325 );
326
327 tokens
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    // OpenAI bodies report a precomputed usage.total_tokens.
    #[test]
    fn test_openai_body_parsing() {
        let provider = OpenAiProvider;
        let body =
            br#"{"usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}"#;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    // Anthropic bodies report input/output separately; total is their sum.
    #[test]
    fn test_anthropic_body_parsing() {
        let provider = AnthropicProvider;
        let body = br#"{"usage": {"input_tokens": 100, "output_tokens": 50}}"#;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    // Chars estimation should land in a plausible, nonzero range.
    #[test]
    fn test_token_estimation_chars() {
        let sample = b"Hello world, this is a test message for token counting!";
        let tokens = estimate_tokens(sample, TokenEstimation::Chars);
        assert!(tokens > 0 && tokens < 100);
    }

    // With no x-model header, the model comes from the JSON body.
    #[test]
    fn test_model_extraction() {
        let empty_headers = HeaderMap::new();
        let request = br#"{"model": "gpt-4", "messages": []}"#;
        let model = OpenAiProvider.extract_model(&empty_headers, request);
        assert_eq!(model, Some("gpt-4".to_string()));
    }

    // Tiktoken estimation should also produce a plausible, nonzero count.
    #[test]
    fn test_token_estimation_tiktoken() {
        let sample = b"Hello world, this is a test message for token counting!";
        let tokens = estimate_tokens(sample, TokenEstimation::Tiktoken);
        assert!(tokens > 0 && tokens < 100);
    }

    // Exact count is only guaranteed when the tiktoken feature is compiled in.
    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_count() {
        assert_eq!(estimate_tokens_tiktoken(b"Hello world", Some("gpt-4")), 2);
    }

    #[test]
    fn test_tiktoken_chat_request() {
        let request = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ]
        }"#;
        assert!(estimate_tokens_tiktoken(request, None) > 0);
    }
}