agtrace_engine/token_usage.rs

//! Type-safe token usage tracking to prevent context window calculation bugs.
//!
//! Raw i32 values are error-prone; these newtypes keep input/output/cache
//! accounting explicit and make it harder to forget cache reads when
//! computing context window pressure.
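//!
//! A minimal usage sketch (illustrative only; the 200_000-token limit and the
//! per-field counts are example values, not constants defined by this module):
//!
//! ```ignore
//! let usage = ContextWindowUsage::from_raw(2_000, 0, 165_000, 1_000);
//! let limit = ContextLimit::new(200_000);
//! let used = usage.total_tokens();
//! assert!(limit.is_warning_zone(used)); // cache reads push usage past 80%
//! assert_eq!(limit.remaining(used), TokenCount::new(32_000));
//! ```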
use serde::{Deserialize, Serialize};
use std::ops::Add;

/// Total token count (always non-negative)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Serialize, Deserialize)]
pub struct TokenCount(u64);

impl TokenCount {
    pub const fn new(value: u64) -> Self {
        Self(value)
    }

    pub const fn zero() -> Self {
        Self(0)
    }

    pub const fn as_u64(self) -> u64 {
        self.0
    }

    pub fn saturating_add(self, other: Self) -> Self {
        Self(self.0.saturating_add(other.0))
    }

    pub fn saturating_sub(self, other: Self) -> Self {
        Self(self.0.saturating_sub(other.0))
    }
}

impl From<u64> for TokenCount {
    fn from(value: u64) -> Self {
        Self(value)
    }
}

impl Add for TokenCount {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}

/// Context window limit (maximum tokens allowed)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextLimit(u64);

impl ContextLimit {
    pub const fn new(value: u64) -> Self {
        Self(value)
    }

    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Calculate usage ratio (0.0 to 1.0+)
    pub fn usage_ratio(self, used: TokenCount) -> f64 {
        if self.0 == 0 {
            0.0
        } else {
            used.as_u64() as f64 / self.0 as f64
        }
    }

    /// Calculate remaining capacity
    pub fn remaining(self, used: TokenCount) -> TokenCount {
        TokenCount::new(self.0.saturating_sub(used.as_u64()))
    }

    /// Check if usage meets or exceeds the limit
    pub fn is_exceeded(self, used: TokenCount) -> bool {
        used.as_u64() >= self.0
    }

    /// Check if usage is in warning zone (>80%)
    pub fn is_warning_zone(self, used: TokenCount) -> bool {
        self.usage_ratio(used) > 0.8
    }

    /// Check if usage is in danger zone (>90%)
    pub fn is_danger_zone(self, used: TokenCount) -> bool {
        self.usage_ratio(used) > 0.9
    }
}

impl From<u64> for ContextLimit {
    fn from(value: u64) -> Self {
        Self(value)
    }
}

/// Fresh input tokens (new content, not from cache)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FreshInputTokens(pub i32);

/// Tokens used to create new cache entries
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheCreationTokens(pub i32);

/// Tokens read from existing cache entries (these still consume context)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheReadTokens(pub i32);

/// Output tokens generated by the model
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct OutputTokens(pub i32);

impl Add for FreshInputTokens {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}

impl Add for CacheCreationTokens {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}

impl Add for CacheReadTokens {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}

impl Add for OutputTokens {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}

/// Complete snapshot of token usage for a single turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ContextWindowUsage {
    pub fresh_input: FreshInputTokens,
    pub cache_creation: CacheCreationTokens,
    pub cache_read: CacheReadTokens,
    pub output: OutputTokens,
}

impl ContextWindowUsage {
    pub fn new(
        fresh_input: FreshInputTokens,
        cache_creation: CacheCreationTokens,
        cache_read: CacheReadTokens,
        output: OutputTokens,
    ) -> Self {
        Self {
            fresh_input,
            cache_creation,
            cache_read,
            output,
        }
    }

    pub fn from_raw(fresh_input: i32, cache_creation: i32, cache_read: i32, output: i32) -> Self {
        Self {
            fresh_input: FreshInputTokens(fresh_input),
            cache_creation: CacheCreationTokens(cache_creation),
            cache_read: CacheReadTokens(cache_read),
            output: OutputTokens(output),
        }
    }

    /// Input-side tokens (fresh + cache creation + cache read)
    pub fn input_tokens(&self) -> i32 {
        self.fresh_input.0 + self.cache_creation.0 + self.cache_read.0
    }

    /// Output-side tokens (model generation)
    pub fn output_tokens(&self) -> i32 {
        self.output.0
    }

    /// Context window tokens consumed this turn (legacy i32 version)
    pub fn context_window_tokens(&self) -> i32 {
        self.input_tokens() + self.output_tokens()
    }

    /// Context window tokens as type-safe TokenCount
    pub fn total_tokens(&self) -> TokenCount {
        let total = self.context_window_tokens().max(0);
        TokenCount::new(total as u64)
    }

    pub fn is_empty(&self) -> bool {
        self.context_window_tokens() == 0
    }
}

impl Add for ContextWindowUsage {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self {
            fresh_input: self.fresh_input + rhs.fresh_input,
            cache_creation: self.cache_creation + rhs.cache_creation,
            cache_read: self.cache_read + rhs.cache_read,
            output: self.output + rhs.output,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_window_usage_calculation() {
        let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);

        assert_eq!(usage.input_tokens(), 600);
        assert_eq!(usage.output_tokens(), 50);
        assert_eq!(usage.context_window_tokens(), 650);
    }

    #[test]
    fn test_cache_read_always_included() {
        let usage = ContextWindowUsage::from_raw(10, 20, 5000, 30);
        assert_eq!(usage.context_window_tokens(), 5060);
    }

    #[test]
    fn test_add_usage() {
        let usage1 = ContextWindowUsage::from_raw(100, 200, 300, 50);
        let usage2 = ContextWindowUsage::from_raw(10, 20, 30, 5);

        let total = usage1 + usage2;

        assert_eq!(total.fresh_input.0, 110);
        assert_eq!(total.cache_creation.0, 220);
        assert_eq!(total.cache_read.0, 330);
        assert_eq!(total.output.0, 55);
        assert_eq!(total.context_window_tokens(), 715);
    }

    #[test]
    fn test_default_is_empty() {
        let usage = ContextWindowUsage::default();
        assert!(usage.is_empty());
        assert_eq!(usage.context_window_tokens(), 0);
    }

    #[test]
    fn test_token_count_operations() {
        let count1 = TokenCount::new(100);
        let count2 = TokenCount::new(50);

        assert_eq!(count1 + count2, TokenCount::new(150));
        assert_eq!(count1.saturating_sub(count2), TokenCount::new(50));
        assert_eq!(count2.saturating_sub(count1), TokenCount::zero());
        assert_eq!(count1.as_u64(), 100);
    }

    #[test]
    fn test_context_limit_usage_ratio() {
        let limit = ContextLimit::new(200_000);
        let used = TokenCount::new(100_000);

        assert_eq!(limit.usage_ratio(used), 0.5);
        assert!(!limit.is_exceeded(used));
        assert!(!limit.is_warning_zone(used));
        assert!(!limit.is_danger_zone(used));
    }

    #[test]
    fn test_context_limit_warning_zones() {
        let limit = ContextLimit::new(100_000);

        let used_normal = TokenCount::new(50_000);
        assert!(!limit.is_warning_zone(used_normal));
        assert!(!limit.is_danger_zone(used_normal));

        let used_warning = TokenCount::new(85_000);
        assert!(limit.is_warning_zone(used_warning));
        assert!(!limit.is_danger_zone(used_warning));

        let used_danger = TokenCount::new(95_000);
        assert!(limit.is_warning_zone(used_danger));
        assert!(limit.is_danger_zone(used_danger));

        let used_exceeded = TokenCount::new(105_000);
        assert!(limit.is_exceeded(used_exceeded));
    }

    #[test]
    fn test_context_limit_remaining() {
        let limit = ContextLimit::new(200_000);
        let used = TokenCount::new(150_000);

        assert_eq!(limit.remaining(used), TokenCount::new(50_000));

        let used_over = TokenCount::new(250_000);
        assert_eq!(limit.remaining(used_over), TokenCount::zero());
    }

    #[test]
    fn test_context_window_usage_total_tokens() {
        let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);
        assert_eq!(usage.total_tokens(), TokenCount::new(650));

        let usage_negative = ContextWindowUsage::from_raw(-100, 0, 0, 0);
        assert_eq!(usage_negative.total_tokens(), TokenCount::zero());
    }
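
    // Illustrative sketch: accumulating per-turn usage and checking the
    // running total against a context limit. The 200_000-token limit and
    // the per-turn values are arbitrary example numbers.
    #[test]
    fn test_multi_turn_accumulation_against_limit() {
        let limit = ContextLimit::new(200_000);

        // Turn 1: mostly fresh input plus cache creation.
        let turn1 = ContextWindowUsage::from_raw(50_000, 40_000, 0, 2_000);
        // Turn 2: the prompt is largely served from cache, but those
        // cache reads still occupy the context window.
        let turn2 = ContextWindowUsage::from_raw(5_000, 0, 90_000, 3_000);

        let used = (turn1 + turn2).total_tokens();

        assert_eq!(used, TokenCount::new(190_000));
        assert!(limit.is_warning_zone(used)); // 95% > 80%
        assert!(limit.is_danger_zone(used)); // 95% > 90%
        assert!(!limit.is_exceeded(used));
        assert_eq!(limit.remaining(used), TokenCount::new(10_000));
    }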
}