// sentinel_proxy/inference/tokens.rs

//! Token counting and estimation for inference requests
//!
//! Provides utilities for counting tokens from responses and estimating
//! tokens from requests for rate limiting purposes.

use http::HeaderMap;
use sentinel_config::TokenEstimation;
use tracing::{debug, trace};

use super::providers::InferenceProviderAdapter;

/// Token count estimate with metadata
///
/// Produced both for incoming requests (estimated before dispatch) and
/// for responses (actual counts read back from the provider when possible).
#[derive(Debug, Clone)]
pub struct TokenEstimate {
    /// Estimated or actual token count
    pub tokens: u64,
    /// Source of the estimate; indicates how trustworthy `tokens` is
    pub source: TokenSource,
    /// Model name if known (request-side extraction; `None` for responses)
    pub model: Option<String>,
}

/// Source of token count information
///
/// Variants are listed from most to least accurate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenSource {
    /// From response headers (most accurate)
    Header,
    /// From response body JSON (accurate)
    Body,
    /// Estimated from content (approximate)
    Estimated,
}

/// Token counter for a specific provider
///
/// Pairs a provider adapter with the configured estimation strategy so
/// callers can estimate request tokens and read back actual response counts.
pub struct TokenCounter {
    // Provider-specific extraction logic (model names, header/body token counts).
    provider: Box<dyn InferenceProviderAdapter>,
    // Strategy used when a token count must be approximated from content.
    estimation_method: TokenEstimation,
}

40impl TokenCounter {
41    /// Create a new token counter for the given provider
42    pub fn new(provider: Box<dyn InferenceProviderAdapter>, estimation_method: TokenEstimation) -> Self {
43        Self {
44            provider,
45            estimation_method,
46        }
47    }
48
49    /// Estimate tokens for an incoming request (before processing)
50    pub fn estimate_request(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
51        // Try to extract model from request
52        let model = self.provider.extract_model(headers, body);
53
54        // Estimate based on body content
55        let tokens = self.provider.estimate_request_tokens(body, self.estimation_method);
56
57        trace!(
58            provider = self.provider.name(),
59            tokens = tokens,
60            model = ?model,
61            method = ?self.estimation_method,
62            "Estimated request tokens"
63        );
64
65        TokenEstimate {
66            tokens,
67            source: TokenSource::Estimated,
68            model,
69        }
70    }
71
72    /// Get actual tokens from response (after processing)
73    ///
74    /// Uses headers first (preferred), then falls back to body parsing.
75    pub fn tokens_from_response(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
76        // Try headers first (most accurate, no body parsing needed)
77        if let Some(tokens) = self.provider.tokens_from_headers(headers) {
78            debug!(
79                provider = self.provider.name(),
80                tokens = tokens,
81                source = "header",
82                "Got actual token count from response headers"
83            );
84            return TokenEstimate {
85                tokens,
86                source: TokenSource::Header,
87                model: None,
88            };
89        }
90
91        // Fall back to body parsing
92        if let Some(tokens) = self.provider.tokens_from_body(body) {
93            debug!(
94                provider = self.provider.name(),
95                tokens = tokens,
96                source = "body",
97                "Got actual token count from response body"
98            );
99            return TokenEstimate {
100                tokens,
101                source: TokenSource::Body,
102                model: None,
103            };
104        }
105
106        // If we can't get actual tokens, return 0 (estimation already done on request)
107        trace!(
108            provider = self.provider.name(),
109            "Could not extract actual token count from response"
110        );
111        TokenEstimate {
112            tokens: 0,
113            source: TokenSource::Estimated,
114            model: None,
115        }
116    }
117
118    /// Get the provider name
119    pub fn provider_name(&self) -> &'static str {
120        self.provider.name()
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::providers::create_provider;
    use sentinel_config::InferenceProvider;

    /// A request-side estimate is non-zero, flagged as estimated, and
    /// carries the model name pulled from the JSON payload.
    #[test]
    fn test_request_estimation() {
        let counter = TokenCounter::new(
            create_provider(&InferenceProvider::OpenAi),
            TokenEstimation::Chars,
        );
        let headers = HeaderMap::new();
        let payload = br#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello world"}]}"#;

        let estimate = counter.estimate_request(&headers, payload);

        assert!(estimate.tokens > 0);
        assert_eq!(estimate.source, TokenSource::Estimated);
        assert_eq!(estimate.model, Some("gpt-4".to_string()));
    }

    /// Usage totals embedded in the response body are surfaced verbatim.
    #[test]
    fn test_response_parsing() {
        let counter = TokenCounter::new(
            create_provider(&InferenceProvider::OpenAi),
            TokenEstimation::Chars,
        );
        let headers = HeaderMap::new();
        let payload = br#"{"usage": {"total_tokens": 150}}"#;

        let estimate = counter.tokens_from_response(&headers, payload);

        assert_eq!(estimate.tokens, 150);
        assert_eq!(estimate.source, TokenSource::Body);
    }
}