// agtrace_engine/token_usage.rs

//! Type-safe token usage tracking to prevent context window calculation bugs.
//!
//! Raw i32 values are error-prone; these newtypes keep input/output/cache
//! accounting explicit and make it harder to forget cache reads when
//! computing context window pressure.
6use serde::{Deserialize, Serialize};
7use std::ops::Add;
8
9/// Total token count (always non-negative)
10#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Serialize, Deserialize)]
11pub struct TokenCount(u64);
12
13impl TokenCount {
14    pub const fn new(value: u64) -> Self {
15        Self(value)
16    }
17
18    pub const fn zero() -> Self {
19        Self(0)
20    }
21
22    pub const fn as_u64(self) -> u64 {
23        self.0
24    }
25
26    pub fn saturating_add(self, other: Self) -> Self {
27        Self(self.0.saturating_add(other.0))
28    }
29
30    pub fn saturating_sub(self, other: Self) -> Self {
31        Self(self.0.saturating_sub(other.0))
32    }
33}
34
35impl From<u64> for TokenCount {
36    fn from(value: u64) -> Self {
37        Self(value)
38    }
39}
40
41impl Add for TokenCount {
42    type Output = Self;
43    fn add(self, rhs: Self) -> Self {
44        Self(self.0 + rhs.0)
45    }
46}
47
48/// Context window limit (maximum tokens allowed)
49#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
50pub struct ContextLimit(u64);
51
52impl ContextLimit {
53    pub const fn new(value: u64) -> Self {
54        Self(value)
55    }
56
57    pub const fn as_u64(self) -> u64 {
58        self.0
59    }
60
61    /// Calculate usage ratio (0.0 to 1.0+)
62    pub fn usage_ratio(self, used: TokenCount) -> f64 {
63        if self.0 == 0 {
64            0.0
65        } else {
66            used.as_u64() as f64 / self.0 as f64
67        }
68    }
69
70    /// Calculate remaining capacity
71    pub fn remaining(self, used: TokenCount) -> TokenCount {
72        TokenCount::new(self.0.saturating_sub(used.as_u64()))
73    }
74
75    /// Check if usage exceeds limit
76    pub fn is_exceeded(self, used: TokenCount) -> bool {
77        used.as_u64() >= self.0
78    }
79
80    /// Check if usage is in warning zone (>80%)
81    pub fn is_warning_zone(self, used: TokenCount) -> bool {
82        self.usage_ratio(used) > 0.8
83    }
84
85    /// Check if usage is in danger zone (>90%)
86    pub fn is_danger_zone(self, used: TokenCount) -> bool {
87        self.usage_ratio(used) > 0.9
88    }
89}
90
91impl From<u64> for ContextLimit {
92    fn from(value: u64) -> Self {
93        Self(value)
94    }
95}
96
/// Fresh input tokens (new content, not from cache).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FreshInputTokens(pub i32);

/// Tokens used to create new cache entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheCreationTokens(pub i32);

/// Tokens read from existing cache entries (these still consume context).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheReadTokens(pub i32);

/// Output tokens generated by the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct OutputTokens(pub i32);

/// Every per-category counter combines the same way: field-wise `+`.
/// Generating the impls from one macro keeps them in lockstep instead of
/// maintaining four hand-copied blocks.
macro_rules! impl_counter_add {
    ($($counter:ty),+ $(,)?) => {$(
        impl Add for $counter {
            type Output = Self;
            fn add(self, rhs: Self) -> Self {
                Self(self.0 + rhs.0)
            }
        }
    )+};
}

impl_counter_add!(FreshInputTokens, CacheCreationTokens, CacheReadTokens, OutputTokens);
140
141/// Complete snapshot of token usage for a single turn.
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
143pub struct ContextWindowUsage {
144    pub fresh_input: FreshInputTokens,
145    pub cache_creation: CacheCreationTokens,
146    pub cache_read: CacheReadTokens,
147    pub output: OutputTokens,
148}
149
150impl ContextWindowUsage {
151    pub fn new(
152        fresh_input: FreshInputTokens,
153        cache_creation: CacheCreationTokens,
154        cache_read: CacheReadTokens,
155        output: OutputTokens,
156    ) -> Self {
157        Self {
158            fresh_input,
159            cache_creation,
160            cache_read,
161            output,
162        }
163    }
164
165    pub fn from_raw(fresh_input: i32, cache_creation: i32, cache_read: i32, output: i32) -> Self {
166        Self {
167            fresh_input: FreshInputTokens(fresh_input),
168            cache_creation: CacheCreationTokens(cache_creation),
169            cache_read: CacheReadTokens(cache_read),
170            output: OutputTokens(output),
171        }
172    }
173
174    /// Input-side tokens (fresh + cache creation + cache read)
175    pub fn input_tokens(&self) -> i32 {
176        self.fresh_input.0 + self.cache_creation.0 + self.cache_read.0
177    }
178
179    /// Output-side tokens (model generation)
180    pub fn output_tokens(&self) -> i32 {
181        self.output.0
182    }
183
184    /// Context window tokens as type-safe TokenCount
185    pub fn total_tokens(&self) -> TokenCount {
186        let total = (self.input_tokens() + self.output_tokens()).max(0);
187        TokenCount::new(total as u64)
188    }
189
190    pub fn is_empty(&self) -> bool {
191        self.total_tokens() == TokenCount::zero()
192    }
193}
194
195impl Add for ContextWindowUsage {
196    type Output = Self;
197    fn add(self, rhs: Self) -> Self {
198        Self {
199            fresh_input: self.fresh_input + rhs.fresh_input,
200            cache_creation: self.cache_creation + rhs.cache_creation,
201            cache_read: self.cache_read + rhs.cache_read,
202            output: self.output + rhs.output,
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// All three input-side categories plus output must land in the totals.
    #[test]
    fn test_context_window_usage_calculation() {
        let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);

        assert_eq!(usage.input_tokens(), 600);
        assert_eq!(usage.output_tokens(), 50);
        assert_eq!(usage.total_tokens(), TokenCount::new(650));
    }

    /// Cache reads still occupy context space; they must never be dropped
    /// from the total — the exact bug this module exists to prevent.
    #[test]
    fn test_cache_read_always_included() {
        let usage = ContextWindowUsage::from_raw(10, 20, 5000, 30);
        assert_eq!(usage.total_tokens(), TokenCount::new(5060));
    }

    /// `Add` on ContextWindowUsage accumulates field-wise.
    #[test]
    fn test_add_usage() {
        let usage1 = ContextWindowUsage::from_raw(100, 200, 300, 50);
        let usage2 = ContextWindowUsage::from_raw(10, 20, 30, 5);

        let total = usage1 + usage2;

        assert_eq!(total.fresh_input.0, 110);
        assert_eq!(total.cache_creation.0, 220);
        assert_eq!(total.cache_read.0, 330);
        assert_eq!(total.output.0, 55);
        assert_eq!(total.total_tokens(), TokenCount::new(715));
    }

    #[test]
    fn test_default_is_empty() {
        let usage = ContextWindowUsage::default();
        assert!(usage.is_empty());
        assert_eq!(usage.total_tokens(), TokenCount::zero());
    }

    #[test]
    fn test_token_count_operations() {
        let count1 = TokenCount::new(100);
        let count2 = TokenCount::new(50);

        assert_eq!(count1 + count2, TokenCount::new(150));
        assert_eq!(count1.saturating_sub(count2), TokenCount::new(50));
        assert_eq!(count2.saturating_sub(count1), TokenCount::zero());
        assert_eq!(count1.as_u64(), 100);
    }

    /// `saturating_add` must clamp at `u64::MAX` rather than wrapping.
    #[test]
    fn test_token_count_saturating_add_clamps() {
        let max = TokenCount::new(u64::MAX);
        assert_eq!(max.saturating_add(TokenCount::new(1)), max);
        assert_eq!(max.saturating_add(TokenCount::zero()), max);
    }

    #[test]
    fn test_context_limit_usage_ratio() {
        let limit = ContextLimit::new(200_000);
        let used = TokenCount::new(100_000);

        assert_eq!(limit.usage_ratio(used), 0.5);
        assert!(!limit.is_exceeded(used));
        assert!(!limit.is_warning_zone(used));
        assert!(!limit.is_danger_zone(used));
    }

    /// A zero limit must report 0.0 usage (the divide-by-zero guard),
    /// not NaN or infinity.
    #[test]
    fn test_zero_limit_reports_zero_ratio() {
        let limit = ContextLimit::new(0);
        assert_eq!(limit.usage_ratio(TokenCount::new(123)), 0.0);
        assert_eq!(limit.remaining(TokenCount::new(123)), TokenCount::zero());
    }

    #[test]
    fn test_context_limit_warning_zones() {
        let limit = ContextLimit::new(100_000);

        let used_normal = TokenCount::new(50_000);
        assert!(!limit.is_warning_zone(used_normal));
        assert!(!limit.is_danger_zone(used_normal));

        let used_warning = TokenCount::new(85_000);
        assert!(limit.is_warning_zone(used_warning));
        assert!(!limit.is_danger_zone(used_warning));

        let used_danger = TokenCount::new(95_000);
        assert!(limit.is_warning_zone(used_danger));
        assert!(limit.is_danger_zone(used_danger));

        let used_exceeded = TokenCount::new(105_000);
        assert!(limit.is_exceeded(used_exceeded));
    }

    /// `remaining` saturates at zero when usage is over the limit.
    #[test]
    fn test_context_limit_remaining() {
        let limit = ContextLimit::new(200_000);
        let used = TokenCount::new(150_000);

        assert_eq!(limit.remaining(used), TokenCount::new(50_000));

        let used_over = TokenCount::new(250_000);
        assert_eq!(limit.remaining(used_over), TokenCount::zero());
    }

    /// Negative raw inputs (defensive: upstream counts are i32) clamp the
    /// total to zero rather than underflowing the unsigned TokenCount.
    #[test]
    fn test_context_window_usage_total_tokens() {
        let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);
        assert_eq!(usage.total_tokens(), TokenCount::new(650));

        let usage_negative = ContextWindowUsage::from_raw(-100, 0, 0, 0);
        assert_eq!(usage_negative.total_tokens(), TokenCount::zero());
    }
}