// grapsus_proxy/inference/tokens.rs
//! Token counting and estimation for inference requests
//!
//! Provides utilities for counting tokens from responses and estimating
//! tokens from requests for rate limiting purposes.

use http::HeaderMap;
use tracing::{debug, trace};
use grapsus_config::TokenEstimation;

use super::providers::InferenceProviderAdapter;
/// Token count estimate with metadata
///
/// Produced by [`TokenCounter`] both when estimating tokens for an
/// incoming request and when extracting actual counts from a response.
#[derive(Debug, Clone)]
pub struct TokenEstimate {
    /// Estimated (or actual) token count; accuracy depends on `source`
    pub tokens: u64,
    /// Where this count came from (header, body, or heuristic estimate)
    pub source: TokenSource,
    /// Model name if it could be extracted from the request, else `None`
    pub model: Option<String>,
}
/// Source of token count information
///
/// Ordered informally by accuracy: headers are preferred over body
/// parsing, which is preferred over content-based estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenSource {
    /// From response headers (most accurate)
    Header,
    /// From response body JSON (accurate)
    Body,
    /// Estimated from content (approximate)
    Estimated,
}
/// Token counter for a specific provider
///
/// Pairs a provider adapter (which knows how to read that provider's
/// request/response formats) with the configured estimation method.
pub struct TokenCounter {
    /// Provider-specific adapter used for model extraction and token parsing
    provider: Box<dyn InferenceProviderAdapter>,
    /// Heuristic applied when estimating tokens from request content
    estimation_method: TokenEstimation,
}
40impl TokenCounter {
41    /// Create a new token counter for the given provider
42    pub fn new(
43        provider: Box<dyn InferenceProviderAdapter>,
44        estimation_method: TokenEstimation,
45    ) -> Self {
46        Self {
47            provider,
48            estimation_method,
49        }
50    }
51
52    /// Estimate tokens for an incoming request (before processing)
53    pub fn estimate_request(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
54        // Try to extract model from request
55        let model = self.provider.extract_model(headers, body);
56
57        // Estimate based on body content
58        let tokens = self
59            .provider
60            .estimate_request_tokens(body, self.estimation_method);
61
62        trace!(
63            provider = self.provider.name(),
64            tokens = tokens,
65            model = ?model,
66            method = ?self.estimation_method,
67            "Estimated request tokens"
68        );
69
70        TokenEstimate {
71            tokens,
72            source: TokenSource::Estimated,
73            model,
74        }
75    }
76
77    /// Get actual tokens from response (after processing)
78    ///
79    /// Uses headers first (preferred), then falls back to body parsing.
80    pub fn tokens_from_response(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
81        // Try headers first (most accurate, no body parsing needed)
82        if let Some(tokens) = self.provider.tokens_from_headers(headers) {
83            debug!(
84                provider = self.provider.name(),
85                tokens = tokens,
86                source = "header",
87                "Got actual token count from response headers"
88            );
89            return TokenEstimate {
90                tokens,
91                source: TokenSource::Header,
92                model: None,
93            };
94        }
95
96        // Fall back to body parsing
97        if let Some(tokens) = self.provider.tokens_from_body(body) {
98            debug!(
99                provider = self.provider.name(),
100                tokens = tokens,
101                source = "body",
102                "Got actual token count from response body"
103            );
104            return TokenEstimate {
105                tokens,
106                source: TokenSource::Body,
107                model: None,
108            };
109        }
110
111        // If we can't get actual tokens, return 0 (estimation already done on request)
112        trace!(
113            provider = self.provider.name(),
114            "Could not extract actual token count from response"
115        );
116        TokenEstimate {
117            tokens: 0,
118            source: TokenSource::Estimated,
119            model: None,
120        }
121    }
122
123    /// Get the provider name
124    pub fn provider_name(&self) -> &'static str {
125        self.provider.name()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::providers::create_provider;
    use grapsus_config::InferenceProvider;

    /// Build an OpenAI-backed counter with char-based estimation,
    /// shared by both tests below.
    fn openai_counter() -> TokenCounter {
        TokenCounter::new(
            create_provider(&InferenceProvider::OpenAi),
            TokenEstimation::Chars,
        )
    }

    #[test]
    fn test_request_estimation() {
        let counter = openai_counter();
        let request_body =
            br#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello world"}]}"#;

        let estimate = counter.estimate_request(&HeaderMap::new(), request_body);

        assert!(estimate.tokens > 0);
        assert_eq!(estimate.source, TokenSource::Estimated);
        assert_eq!(estimate.model, Some("gpt-4".to_string()));
    }

    #[test]
    fn test_response_parsing() {
        let counter = openai_counter();
        let response_body = br#"{"usage": {"total_tokens": 150}}"#;

        let estimate = counter.tokens_from_response(&HeaderMap::new(), response_body);

        assert_eq!(estimate.tokens, 150);
        assert_eq!(estimate.source, TokenSource::Body);
    }
}