// algocline_core/tokens.rs

1// ─── Token usage (host-provided) ─────────────────────────────
2
3/// Token usage reported by the host LLM alongside its response.
4///
5/// When the host includes usage data in `alc_continue`, these counts
6/// replace the character-based estimates for that specific response,
7/// upgrading `TokenSource` from `Estimated` to `Provided`.
8#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
9pub struct TokenUsage {
10    /// Prompt tokens consumed by this LLM call, as reported by the host.
11    pub prompt_tokens: Option<u64>,
12    /// Completion (response) tokens produced by this LLM call.
13    pub completion_tokens: Option<u64>,
14}
15
16// ─── Token tracking ─────────────────────────────────────────
17
18/// How a token count was obtained.
19///
20/// When a session mixes sources (e.g. some calls estimated, some provided),
21/// the aggregate source degrades to the weakest (least precise) variant
22/// via [`TokenSource::weaker`].
23#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TokenSource {
26    /// Character-based heuristic (ASCII ~4c/t, CJK ~1.5c/t). ±30% accuracy.
27    Estimated,
28    /// Reported by the host (e.g. MCP Sampling `usage` metadata).
29    /// Accuracy depends on the host's tokenizer.
30    Provided,
31    /// Exact count from a known tokenizer (e.g. local BPE).
32    Definite,
33}
34
35impl TokenSource {
36    /// Return the weaker (less precise) of two sources.
37    ///
38    /// Used when accumulating across multiple LLM calls in a session.
39    /// If any call is `Estimated`, the aggregate is `Estimated`.
40    pub fn weaker(self, other: Self) -> Self {
41        match (self, other) {
42            (Self::Estimated, _) | (_, Self::Estimated) => Self::Estimated,
43            (Self::Provided, _) | (_, Self::Provided) => Self::Provided,
44            _ => Self::Definite,
45        }
46    }
47}
48
49/// Accumulated token count with provenance.
50///
51/// Tracks both the total token count and the weakest [`TokenSource`]
52/// across all accumulated calls. This lets consumers (e.g. `alc_eval_compare`)
53/// know whether a comparison is between precise or estimated values.
54#[derive(Debug, Clone)]
55pub struct TokenCount {
56    pub tokens: u64,
57    pub source: TokenSource,
58}
59
60impl TokenCount {
61    /// New zero-count with the given source.
62    pub(crate) fn new(source: TokenSource) -> Self {
63        Self { tokens: 0, source }
64    }
65
66    /// Add tokens, degrading source to the weaker of the two.
67    pub(crate) fn accumulate(&mut self, tokens: u64, source: TokenSource) {
68        self.tokens += tokens;
69        self.source = self.source.weaker(source);
70    }
71
72    pub(crate) fn to_json(&self) -> serde_json::Value {
73        serde_json::json!({
74            "tokens": self.tokens,
75            "source": self.source,
76        })
77    }
78}
79
/// Estimate token count from a string using a character-based heuristic.
///
/// Blended rates for mixed-language text:
/// - ASCII characters: ~4 chars per token (GPT/Claude typical)
/// - Non-ASCII characters (CJK, etc.): ~1.5 chars per token
///
/// **Accuracy**: order-of-magnitude only. Real counts depend on the
/// model's BPE tokenizer; expect ±30% deviation for typical English,
/// potentially more for code or heavily structured text. Intended for
/// cost trend analysis (eval comparison), not billing.
pub(crate) fn estimate_tokens(text: &str) -> u64 {
    let total = text.chars().count() as u64;
    let ascii = text.chars().filter(char::is_ascii).count() as u64;
    let non_ascii = total - ascii;
    // ceil(ascii / 4) + ceil(non_ascii / 1.5), kept in integer math.
    ascii.div_ceil(4) + (non_ascii * 2).div_ceil(3)
}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn estimate_tokens_ascii() {
        // "hello world": 11 ASCII chars → ceil(11/4) = 3.
        assert_eq!(estimate_tokens("hello world"), 3);
    }

    #[test]
    fn token_source_weaker_estimated_wins() {
        // Estimated degrades the aggregate regardless of argument order.
        let est = TokenSource::Estimated;
        let def = TokenSource::Definite;
        assert_eq!(est.weaker(def), est);
        assert_eq!(def.weaker(est), est);
    }

    #[test]
    fn token_source_weaker_provided_over_definite() {
        assert_eq!(
            TokenSource::Provided.weaker(TokenSource::Definite),
            TokenSource::Provided
        );
    }

    #[test]
    fn token_source_weaker_same_returns_same() {
        for source in [TokenSource::Definite, TokenSource::Estimated] {
            assert_eq!(source.weaker(source), source);
        }
    }

    #[test]
    fn token_count_accumulate_degrades_source() {
        let mut count = TokenCount::new(TokenSource::Definite);

        count.accumulate(10, TokenSource::Definite);
        assert_eq!(count.source, TokenSource::Definite);

        count.accumulate(5, TokenSource::Provided);
        assert_eq!(count.tokens, 15);
        assert_eq!(count.source, TokenSource::Provided);

        count.accumulate(3, TokenSource::Estimated);
        assert_eq!(count.tokens, 18);
        assert_eq!(count.source, TokenSource::Estimated);
    }

    #[test]
    fn token_count_to_json_format() {
        let count = TokenCount {
            tokens: 42,
            source: TokenSource::Provided,
        };
        let value = count.to_json();
        assert_eq!(value["tokens"], 42);
        assert_eq!(value["source"], "provided");
    }

    #[test]
    fn token_source_serde_roundtrip() {
        let original = TokenSource::Estimated;
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""estimated""#);
        let decoded: TokenSource = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn estimate_tokens_cjk() {
        // "あいう": 3 non-ASCII chars → ceil(3 * 2 / 3) = 2.
        assert_eq!(estimate_tokens("あいう"), 2);
    }

    #[test]
    fn estimate_tokens_mixed() {
        // "hello あ": 6 ASCII (ceil(6/4) = 2) + 1 non-ASCII (ceil(2/3) = 1).
        assert_eq!(estimate_tokens("hello あ"), 3);
    }

    #[test]
    fn token_estimation_in_stats() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(), // 12 ASCII → 3 tokens
            system: Some("Expert".into()), // 6 ASCII → 2 tokens
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        }];
        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "4", None); // 1 ASCII → 1 token
        observer.on_resumed();
        observer.on_completed(&serde_json::json!(null));

        let stats = metrics.to_json();
        let auto = &stats["auto"];
        assert_eq!(auto["prompt_tokens"]["tokens"], 5); // 3 + 2
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 1);
        assert_eq!(auto["response_tokens"]["source"], "estimated");
        assert_eq!(auto["total_tokens"]["tokens"], 6);
        assert_eq!(auto["total_tokens"]["source"], "estimated");
    }

    #[test]
    fn token_estimation_accumulates_across_rounds() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();

        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(), // 4 ASCII → 1 token
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        for _round in 0..3 {
            observer.on_paused(&queries);
            // "reply here": 10 ASCII → 3 tokens per round.
            observer.on_response_fed(&QueryId::single(), "reply here", None);
            observer.on_resumed();
        }
        observer.on_completed(&serde_json::json!(null));

        let stats = metrics.to_json();
        let auto = &stats["auto"];
        assert_eq!(auto["prompt_tokens"]["tokens"], 3); // 1 token × 3 rounds
        assert_eq!(auto["prompt_tokens"]["source"], "estimated");
        assert_eq!(auto["response_tokens"]["tokens"], 9); // 3 tokens × 3 rounds
        assert_eq!(auto["response_tokens"]["source"], "estimated");
    }
}
259}