//! algocline_core/tokens.rs — token counting with provenance tracking.
1// ─── Token tracking ─────────────────────────────────────────
2
3/// How a token count was obtained.
4///
5/// When a session mixes sources (e.g. some calls estimated, some provided),
6/// the aggregate source degrades to the weakest (least precise) variant
7/// via [`TokenSource::weaker`].
8#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum TokenSource {
11    /// Character-based heuristic (ASCII ~4c/t, CJK ~1.5c/t). ±30% accuracy.
12    Estimated,
13    /// Reported by the host (e.g. MCP Sampling `usage` metadata).
14    /// Accuracy depends on the host's tokenizer.
15    Provided,
16    /// Exact count from a known tokenizer (e.g. local BPE).
17    Definite,
18}
19
20impl TokenSource {
21    /// Return the weaker (less precise) of two sources.
22    ///
23    /// Used when accumulating across multiple LLM calls in a session.
24    /// If any call is `Estimated`, the aggregate is `Estimated`.
25    pub fn weaker(self, other: Self) -> Self {
26        match (self, other) {
27            (Self::Estimated, _) | (_, Self::Estimated) => Self::Estimated,
28            (Self::Provided, _) | (_, Self::Provided) => Self::Provided,
29            _ => Self::Definite,
30        }
31    }
32}
33
34/// Accumulated token count with provenance.
35///
36/// Tracks both the total token count and the weakest [`TokenSource`]
37/// across all accumulated calls. This lets consumers (e.g. `alc_eval_compare`)
38/// know whether a comparison is between precise or estimated values.
39#[derive(Debug, Clone)]
40pub struct TokenCount {
41    pub tokens: u64,
42    pub source: TokenSource,
43}
44
45impl TokenCount {
46    /// New zero-count with the given source.
47    pub(crate) fn new(source: TokenSource) -> Self {
48        Self { tokens: 0, source }
49    }
50
51    /// Add tokens, degrading source to the weaker of the two.
52    pub(crate) fn accumulate(&mut self, tokens: u64, source: TokenSource) {
53        self.tokens += tokens;
54        self.source = self.source.weaker(source);
55    }
56
57    pub(crate) fn to_json(&self) -> serde_json::Value {
58        serde_json::json!({
59            "tokens": self.tokens,
60            "source": self.source,
61        })
62    }
63}
64
/// Estimate token count from a string using a character-based heuristic.
///
/// For mixed-language text (English + CJK), a blended approach is used:
/// - ASCII characters: ~4 chars per token (GPT/Claude typical)
/// - Non-ASCII characters (CJK, etc.): ~1.5 chars per token
///
/// **Accuracy**: This is an order-of-magnitude estimate. Actual token counts
/// depend on the model's tokenizer (BPE). Expect ±30% deviation for typical
/// English text, potentially more for code or heavily structured text.
/// Intended for cost trend analysis (eval comparison), not billing.
pub(crate) fn estimate_tokens(text: &str) -> u64 {
    // Single pass over the code points, tallying ASCII and non-ASCII
    // separately since they carry different chars-per-token rates.
    let (ascii, non_ascii) = text.chars().fold((0u64, 0u64), |(ascii, other), ch| {
        if ch.is_ascii() {
            (ascii + 1, other)
        } else {
            (ascii, other + 1)
        }
    });
    // ASCII: ceil(n/4) tokens; non-ASCII: ceil(n/1.5) = ceil(2n/3) tokens.
    ascii.div_ceil(4) + (non_ascii * 2).div_ceil(3)
}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    // ── estimate_tokens: heuristic edge cases ─────────────────────

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn estimate_tokens_ascii() {
        // "hello world" = 11 ASCII chars → ceil(11/4) = 3
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    // ── TokenSource::weaker: precision degradation ────────────────

    #[test]
    fn token_source_weaker_estimated_wins() {
        // Estimated is the weakest source regardless of argument order.
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Definite),
            TokenSource::Estimated
        );
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    #[test]
    fn token_source_weaker_provided_over_definite() {
        assert_eq!(
            TokenSource::Provided.weaker(TokenSource::Definite),
            TokenSource::Provided
        );
    }

    #[test]
    fn token_source_weaker_same_returns_same() {
        // weaker() is idempotent when both sides match.
        assert_eq!(
            TokenSource::Definite.weaker(TokenSource::Definite),
            TokenSource::Definite
        );
        assert_eq!(
            TokenSource::Estimated.weaker(TokenSource::Estimated),
            TokenSource::Estimated
        );
    }

    // ── TokenCount: accumulation and serialization ────────────────

    #[test]
    fn token_count_accumulate_degrades_source() {
        // Source only ever moves toward less precision: Definite →
        // Provided → Estimated, never back up.
        let mut tc = TokenCount::new(TokenSource::Definite);
        tc.accumulate(10, TokenSource::Definite);
        assert_eq!(tc.source, TokenSource::Definite);

        tc.accumulate(5, TokenSource::Provided);
        assert_eq!(tc.tokens, 15);
        assert_eq!(tc.source, TokenSource::Provided);

        tc.accumulate(3, TokenSource::Estimated);
        assert_eq!(tc.tokens, 18);
        assert_eq!(tc.source, TokenSource::Estimated);
    }

    #[test]
    fn token_count_to_json_format() {
        let tc = TokenCount {
            tokens: 42,
            source: TokenSource::Provided,
        };
        let json = tc.to_json();
        assert_eq!(json["tokens"], 42);
        // Source serializes via the enum's snake_case serde rename.
        assert_eq!(json["source"], "provided");
    }

    #[test]
    fn token_source_serde_roundtrip() {
        let source = TokenSource::Estimated;
        let json = serde_json::to_string(&source).unwrap();
        assert_eq!(json, r#""estimated""#);
        let restored: TokenSource = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, source);
    }

    #[test]
    fn estimate_tokens_cjk() {
        // "あいう" = 3 non-ASCII chars → ceil(3/1.5) = ceil(6/3) = 2
        assert_eq!(estimate_tokens("あいう"), 2);
    }

    #[test]
    fn estimate_tokens_mixed() {
        // "hello あ" = 6 ASCII + 1 non-ASCII
        // ASCII: ceil(6/4) = 2, CJK: ceil(1/1.5) = ceil(2/3) = 1
        assert_eq!(estimate_tokens("hello あ"), 3);
    }

    // ── Integration with ExecutionMetrics/ExecutionObserver ──────
    // NOTE(review): these exercise the observer lifecycle
    // (on_paused → on_response_fed → on_resumed → on_completed) and
    // assume prompt/system/response text is estimated via
    // estimate_tokens — confirm against the observer implementation.

    #[test]
    fn token_estimation_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(), // 12 ASCII → ceil(12/4) = 3
            system: Some("Expert".into()), // 6 ASCII → ceil(6/4) = 2
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];
        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4"); // 1 ASCII → ceil(1/4) = 1
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        assert_eq!(auto["prompt_tokens"]["tokens"], 5); // 3 + 2
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 1);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
        assert_eq!(auto["total_tokens"]["tokens"], 6);
        assert_eq!(auto["total_tokens"]["source"], "estimated");
    }

    #[test]
    fn token_estimation_accumulates_across_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let q = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(), // 4 ASCII → ceil(4/4) = 1
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        // 3 rounds — each pause re-counts the prompt, so totals scale
        // linearly with the number of rounds.
        for _ in 0..3 {
            observer.on_paused(&q);
            observer.on_response_fed(&QueryId::single(), "reply here"); // 10 → ceil(10/4) = 3
            observer.on_resumed();
        }
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let auto = &json["auto"];
        assert_eq!(auto["prompt_tokens"]["tokens"], 3); // 1 * 3
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 9); // 3 * 3
        assert_eq!(auto["response_tokens"]["source"], "estimated");
    }
}