// grapsus_proxy/inference/tokens.rs

use http::HeaderMap;
use tracing::{debug, trace};

use grapsus_config::TokenEstimation;

use super::providers::InferenceProviderAdapter;
/// A token count for a single request or response, together with how the
/// count was obtained and (when known) which model it applies to.
#[derive(Debug, Clone)]
pub struct TokenEstimate {
    /// The token count (actual or estimated, depending on `source`).
    pub tokens: u64,
    /// Where this count came from: headers, body, or a local estimate.
    pub source: TokenSource,
    /// Model name extracted from the request body/headers, if the
    /// provider adapter could find one; `None` for response-side counts.
    pub model: Option<String>,
}
22
/// Origin of a token count, from most to least authoritative.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenSource {
    /// Actual count reported by the provider in response headers.
    Header,
    /// Actual count parsed out of the response body (e.g. a usage object).
    Body,
    /// Heuristic estimate computed locally from the request payload.
    Estimated,
}
33
/// Counts tokens for inference traffic by delegating provider-specific
/// parsing and estimation to an [`InferenceProviderAdapter`].
pub struct TokenCounter {
    /// Adapter that knows the wire format of the configured provider.
    provider: Box<dyn InferenceProviderAdapter>,
    /// Strategy the adapter should use when estimating from raw bytes.
    estimation_method: TokenEstimation,
}
39
40impl TokenCounter {
41 pub fn new(
43 provider: Box<dyn InferenceProviderAdapter>,
44 estimation_method: TokenEstimation,
45 ) -> Self {
46 Self {
47 provider,
48 estimation_method,
49 }
50 }
51
52 pub fn estimate_request(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
54 let model = self.provider.extract_model(headers, body);
56
57 let tokens = self
59 .provider
60 .estimate_request_tokens(body, self.estimation_method);
61
62 trace!(
63 provider = self.provider.name(),
64 tokens = tokens,
65 model = ?model,
66 method = ?self.estimation_method,
67 "Estimated request tokens"
68 );
69
70 TokenEstimate {
71 tokens,
72 source: TokenSource::Estimated,
73 model,
74 }
75 }
76
77 pub fn tokens_from_response(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
81 if let Some(tokens) = self.provider.tokens_from_headers(headers) {
83 debug!(
84 provider = self.provider.name(),
85 tokens = tokens,
86 source = "header",
87 "Got actual token count from response headers"
88 );
89 return TokenEstimate {
90 tokens,
91 source: TokenSource::Header,
92 model: None,
93 };
94 }
95
96 if let Some(tokens) = self.provider.tokens_from_body(body) {
98 debug!(
99 provider = self.provider.name(),
100 tokens = tokens,
101 source = "body",
102 "Got actual token count from response body"
103 );
104 return TokenEstimate {
105 tokens,
106 source: TokenSource::Body,
107 model: None,
108 };
109 }
110
111 trace!(
113 provider = self.provider.name(),
114 "Could not extract actual token count from response"
115 );
116 TokenEstimate {
117 tokens: 0,
118 source: TokenSource::Estimated,
119 model: None,
120 }
121 }
122
123 pub fn provider_name(&self) -> &'static str {
125 self.provider.name()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::providers::create_provider;
    use grapsus_config::InferenceProvider;

    /// A request estimate should be non-zero, marked as estimated,
    /// and carry the model name parsed from the JSON body.
    #[test]
    fn test_request_estimation() {
        let counter = TokenCounter::new(
            create_provider(&InferenceProvider::OpenAi),
            TokenEstimation::Chars,
        );

        let request_body =
            br#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello world"}]}"#;

        let result = counter.estimate_request(&HeaderMap::new(), request_body);
        assert!(result.tokens > 0);
        assert_eq!(result.source, TokenSource::Estimated);
        assert_eq!(result.model, Some("gpt-4".to_string()));
    }

    /// An OpenAI-style usage object in the response body should yield
    /// the exact reported total, sourced from the body.
    #[test]
    fn test_response_parsing() {
        let counter = TokenCounter::new(
            create_provider(&InferenceProvider::OpenAi),
            TokenEstimation::Chars,
        );

        let response_body = br#"{"usage": {"total_tokens": 150}}"#;

        let result = counter.tokens_from_response(&HeaderMap::new(), response_body);
        assert_eq!(result.tokens, 150);
        assert_eq!(result.source, TokenSource::Body);
    }
}