/// Token counts reported for a single LLM call, when the provider supplies them.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt; `None` when not reported.
    pub prompt_tokens: Option<u64>,
    /// Tokens produced in the completion; `None` when not reported.
    pub completion_tokens: Option<u64>,
}
15
/// Provenance of a token count. Per `TokenSource::weaker`, reliability orders
/// `Estimated` (weakest) < `Provided` < `Definite` (strongest).
/// Serialized in snake_case (e.g. `"estimated"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenSource {
    /// Heuristically estimated from text (see `estimate_tokens`).
    Estimated,
    /// Supplied by the caller or provider, but not guaranteed exact.
    Provided,
    /// Known exactly.
    Definite,
}
34
35impl TokenSource {
36 pub fn weaker(self, other: Self) -> Self {
41 match (self, other) {
42 (Self::Estimated, _) | (_, Self::Estimated) => Self::Estimated,
43 (Self::Provided, _) | (_, Self::Provided) => Self::Provided,
44 _ => Self::Definite,
45 }
46 }
47}
48
/// A running token total paired with the weakest provenance seen so far.
#[derive(Debug, Clone)]
pub struct TokenCount {
    /// Accumulated number of tokens.
    pub tokens: u64,
    /// Reliability of `tokens`; degrades as weaker sources are accumulated.
    pub source: TokenSource,
}
59
60impl TokenCount {
61 pub(crate) fn new(source: TokenSource) -> Self {
63 Self { tokens: 0, source }
64 }
65
66 pub(crate) fn accumulate(&mut self, tokens: u64, source: TokenSource) {
68 self.tokens += tokens;
69 self.source = self.source.weaker(source);
70 }
71
72 pub(crate) fn to_json(&self) -> serde_json::Value {
73 serde_json::json!({
74 "tokens": self.tokens,
75 "source": self.source,
76 })
77 }
78}
79
/// Rough token estimate for `text`: ASCII characters count at ~4 chars per
/// token and non-ASCII characters (e.g. CJK) at ~2/3 token each, with each
/// contribution rounded up independently.
pub(crate) fn estimate_tokens(text: &str) -> u64 {
    let total = text.chars().count() as u64;
    let ascii = text.chars().filter(|c| c.is_ascii()).count() as u64;
    let non_ascii = total - ascii;
    ascii.div_ceil(4) + (non_ascii * 2).div_ceil(3)
}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn estimate_tokens_ascii() {
        // 11 ASCII chars -> ceil(11 / 4) = 3.
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    #[test]
    fn token_source_weaker_estimated_wins() {
        // Estimated is the weakest source regardless of argument order.
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Definite),
            TokenSource::Estimated
        );
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    #[test]
    fn token_source_weaker_provided_over_definite() {
        assert_eq!(
            TokenSource::Provided.weaker(TokenSource::Definite),
            TokenSource::Provided
        );
    }

    #[test]
    fn token_source_weaker_same_returns_same() {
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Definite),
            TokenSource::Definite
        );
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    #[test]
    fn token_count_accumulate_degrades_source() {
        let mut count = TokenCount::new(TokenSource::Definite);

        // Definite + Definite stays Definite.
        count.accumulate(10, TokenSource::Definite);
        assert_eq!(count.source, TokenSource::Definite);

        // Mixing in Provided degrades the source.
        count.accumulate(5, TokenSource::Provided);
        assert_eq!(count.tokens, 15);
        assert_eq!(count.source, TokenSource::Provided);

        // Mixing in Estimated degrades it further.
        count.accumulate(3, TokenSource::Estimated);
        assert_eq!(count.tokens, 18);
        assert_eq!(count.source, TokenSource::Estimated);
    }

    #[test]
    fn token_count_to_json_format() {
        let count = TokenCount {
            tokens: 42,
            source: TokenSource::Provided,
        };
        let json = count.to_json();
        assert_eq!(json["tokens"], 42);
        assert_eq!(json["source"], "provided");
    }

    #[test]
    fn token_source_serde_roundtrip() {
        let original = TokenSource::Estimated;
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#""estimated""#);
        let restored: TokenSource = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn estimate_tokens_cjk() {
        // 3 non-ASCII chars -> ceil(3 * 2 / 3) = 2.
        assert_eq!(estimate_tokens("あいう"), 2);
    }

    #[test]
    fn estimate_tokens_mixed() {
        // 6 ASCII chars -> 2, plus 1 non-ASCII char -> 1.
        assert_eq!(estimate_tokens("hello あ"), 3);
    }

    #[test]
    fn token_estimation_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("Expert".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];
        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4", None);
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        // Prompt (12 chars -> 3) plus system ("Expert", 6 chars -> 2).
        assert_eq!(auto["prompt_tokens"]["tokens"], 5);
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 1);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
        assert_eq!(auto["total_tokens"]["tokens"], 6);
        assert_eq!(auto["total_tokens"]["source"], "estimated");
    }

    #[test]
    fn token_estimation_accumulates_across_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        // Three pause/response/resume rounds with the same query.
        for _ in 0..3 {
            observer.on_paused(&queries);
            observer.on_response_fed(&QueryId::single(), "reply here", None);
            observer.on_resumed();
        }
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        // "test" -> 1 token per round, 3 rounds.
        assert_eq!(auto["prompt_tokens"]["tokens"], 3);
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        // "reply here" -> 3 tokens per round, 3 rounds.
        assert_eq!(auto["response_tokens"]["tokens"], 9);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
    }
}