/// How trustworthy a token count is, ordered weakest to strongest.
///
/// Serialized in snake_case (e.g. `"estimated"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenSource {
    /// Count was heuristically estimated (see `estimate_tokens`); weakest.
    Estimated,
    /// Count was supplied from outside — presumably reported rather than
    /// measured here; TODO confirm the exact producer.
    Provided,
    /// Count is treated as authoritative; strongest.
    Definite,
}
19
20impl TokenSource {
21 pub fn weaker(self, other: Self) -> Self {
26 match (self, other) {
27 (Self::Estimated, _) | (_, Self::Estimated) => Self::Estimated,
28 (Self::Provided, _) | (_, Self::Provided) => Self::Provided,
29 _ => Self::Definite,
30 }
31 }
32}
33
/// A running token total together with the reliability of that total.
#[derive(Debug, Clone)]
pub struct TokenCount {
    // Accumulated number of tokens.
    pub tokens: u64,
    // Weakest `TokenSource` observed while accumulating.
    pub source: TokenSource,
}
44
45impl TokenCount {
46 pub(crate) fn new(source: TokenSource) -> Self {
48 Self { tokens: 0, source }
49 }
50
51 pub(crate) fn accumulate(&mut self, tokens: u64, source: TokenSource) {
53 self.tokens += tokens;
54 self.source = self.source.weaker(source);
55 }
56
57 pub(crate) fn to_json(&self) -> serde_json::Value {
58 serde_json::json!({
59 "tokens": self.tokens,
60 "source": self.source,
61 })
62 }
63}
64
/// Roughly estimates the number of LLM tokens in `text`.
///
/// Heuristic: ASCII characters cost ~1/4 token each and non-ASCII
/// characters ~2/3 token each, with each bucket rounded up independently.
pub(crate) fn estimate_tokens(text: &str) -> u64 {
    let non_ascii = text.chars().filter(|c| !c.is_ascii()).count() as u64;
    let ascii = text.chars().count() as u64 - non_ascii;
    ascii.div_ceil(4) + (non_ascii * 2).div_ceil(3)
}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn estimate_tokens_ascii() {
        // 11 ASCII chars -> ceil(11 / 4) = 3 tokens.
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    #[test]
    fn token_source_weaker_estimated_wins() {
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Definite),
            TokenSource::Estimated
        );
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    #[test]
    fn token_source_weaker_provided_over_definite() {
        assert_eq!(
            TokenSource::Provided.weaker(TokenSource::Definite),
            TokenSource::Provided
        );
    }

    #[test]
    fn token_source_weaker_same_returns_same() {
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Definite),
            TokenSource::Definite
        );
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    #[test]
    fn token_count_accumulate_degrades_source() {
        let mut tc = TokenCount::new(TokenSource::Definite);
        tc.accumulate(10, TokenSource::Definite);
        assert_eq!(tc.source, TokenSource::Definite);

        tc.accumulate(5, TokenSource::Provided);
        assert_eq!(tc.tokens, 15);
        assert_eq!(tc.source, TokenSource::Provided);

        tc.accumulate(3, TokenSource::Estimated);
        assert_eq!(tc.tokens, 18);
        assert_eq!(tc.source, TokenSource::Estimated);
    }

    #[test]
    fn token_count_to_json_format() {
        let tc = TokenCount {
            tokens: 42,
            source: TokenSource::Provided,
        };
        let json = tc.to_json();
        assert_eq!(json["tokens"], 42);
        assert_eq!(json["source"], "provided");
    }

    #[test]
    fn token_source_serde_roundtrip() {
        let source = TokenSource::Estimated;
        let json = serde_json::to_string(&source).unwrap();
        assert_eq!(json, r#""estimated""#);
        let restored: TokenSource = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, source);
    }

    #[test]
    fn estimate_tokens_cjk() {
        // 3 non-ASCII chars -> ceil(3 * 2 / 3) = 2 tokens.
        assert_eq!(estimate_tokens("あいう"), 2);
    }

    #[test]
    fn estimate_tokens_mixed() {
        // "hello " is 6 ASCII (-> 2) plus one non-ASCII char (-> 1).
        assert_eq!(estimate_tokens("hello あ"), 3);
    }

    #[test]
    fn token_estimation_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("Expert".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];
        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4");
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        // Prompt ("What is 2+2?" + system "Expert") estimates to 5 tokens.
        assert_eq!(auto["prompt_tokens"]["tokens"], 5);
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 1);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
        assert_eq!(auto["total_tokens"]["tokens"], 6);
        assert_eq!(auto["total_tokens"]["source"], "estimated");
    }

    #[test]
    fn token_estimation_accumulates_across_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        // Three pause/feed/resume rounds; totals should sum across rounds.
        for _ in 0..3 {
            observer.on_paused(&q);
            observer.on_response_fed(&QueryId::single(), "reply here");
            observer.on_resumed();
        }
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        assert_eq!(auto["prompt_tokens"]["tokens"], 3);
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 9);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
    }
}