claude_agent/tokens/
tracker.rs

1use super::{ContextWindow, PricingTier, TokenBudget, WindowStatus};
2use crate::models::ModelSpec;
3
4#[derive(Debug, Clone)]
5pub enum PreflightResult {
6    Ok {
7        estimated_tokens: u64,
8        remaining: u64,
9        tier: PricingTier,
10    },
11    Warning {
12        estimated_tokens: u64,
13        utilization: f64,
14        tier: PricingTier,
15    },
16    Exceeded {
17        estimated_tokens: u64,
18        limit: u64,
19        overage: u64,
20    },
21}
22
23impl PreflightResult {
24    pub fn should_proceed(&self) -> bool {
25        !matches!(self, Self::Exceeded { .. })
26    }
27
28    pub fn estimated_tokens(&self) -> u64 {
29        match self {
30            Self::Ok {
31                estimated_tokens, ..
32            }
33            | Self::Warning {
34                estimated_tokens, ..
35            }
36            | Self::Exceeded {
37                estimated_tokens, ..
38            } => *estimated_tokens,
39        }
40    }
41
42    pub fn tier(&self) -> Option<PricingTier> {
43        match self {
44            Self::Ok { tier, .. } | Self::Warning { tier, .. } => Some(*tier),
45            Self::Exceeded { .. } => None,
46        }
47    }
48}
49
50#[derive(Debug)]
51pub struct TokenTracker {
52    context_window: ContextWindow,
53    cumulative: TokenBudget,
54    last_turn: TokenBudget,
55    model_spec: ModelSpec,
56}
57
58impl TokenTracker {
59    pub fn new(model_spec: ModelSpec, extended_context: bool) -> Self {
60        Self {
61            context_window: ContextWindow::new(&model_spec, extended_context),
62            cumulative: TokenBudget::default(),
63            last_turn: TokenBudget::default(),
64            model_spec,
65        }
66    }
67
68    pub fn with_thresholds(mut self, warning: f64, critical: f64) -> Self {
69        self.context_window = self.context_window.with_thresholds(warning, critical);
70        self
71    }
72
73    pub fn check(&self, estimated_tokens: u64) -> PreflightResult {
74        let new_usage = self.context_window.usage() + estimated_tokens;
75        let limit = self.context_window.limit();
76
77        if new_usage > limit {
78            return PreflightResult::Exceeded {
79                estimated_tokens,
80                limit,
81                overage: new_usage - limit,
82            };
83        }
84
85        let utilization = if limit == 0 {
86            0.0
87        } else {
88            new_usage as f64 / limit as f64
89        };
90        let tier = PricingTier::for_context(new_usage);
91
92        if utilization >= self.context_window.warning_threshold() {
93            PreflightResult::Warning {
94                estimated_tokens,
95                utilization,
96                tier,
97            }
98        } else {
99            PreflightResult::Ok {
100                estimated_tokens,
101                remaining: limit - new_usage,
102                tier,
103            }
104        }
105    }
106
107    pub fn record(&mut self, usage: &crate::types::Usage) {
108        let budget = TokenBudget::from(usage);
109        self.last_turn = budget;
110        self.cumulative.add(&budget);
111        self.context_window.update(budget.context_usage());
112    }
113
114    pub fn status(&self) -> WindowStatus {
115        self.context_window.status()
116    }
117
118    pub fn context_window(&self) -> &ContextWindow {
119        &self.context_window
120    }
121
122    pub fn cumulative(&self) -> &TokenBudget {
123        &self.cumulative
124    }
125
126    pub fn last_turn(&self) -> &TokenBudget {
127        &self.last_turn
128    }
129
130    pub fn pricing_tier(&self) -> PricingTier {
131        PricingTier::for_context(self.context_window.usage())
132    }
133
134    pub fn total_cost(&self) -> f64 {
135        self.model_spec.pricing.calculate(
136            self.cumulative.input_tokens,
137            self.cumulative.output_tokens,
138            self.cumulative.cache_read_tokens,
139            self.cumulative.cache_write_tokens,
140        )
141    }
142
143    pub fn reset(&mut self, new_context_usage: u64) {
144        self.context_window.reset(new_context_usage);
145    }
146
147    pub fn model(&self) -> &ModelSpec {
148        &self.model_spec
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::read_registry;

    /// Builds a fresh tracker for the registry entry named "sonnet".
    fn sonnet_tracker(extended_context: bool) -> TokenTracker {
        let spec = read_registry().resolve("sonnet").unwrap().clone();
        TokenTracker::new(spec, extended_context)
    }

    /// A modest estimate on an empty tracker is accepted without a warning.
    #[test]
    fn test_preflight_ok() {
        let outcome = sonnet_tracker(false).check(50_000);

        assert!(outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Ok { .. }));
    }

    /// An estimate close to the limit is still allowed but flagged.
    #[test]
    fn test_preflight_warning() {
        let outcome = sonnet_tracker(false).check(180_000);

        assert!(outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Warning { .. }));
    }

    /// An estimate beyond the standard window is rejected.
    #[test]
    fn test_preflight_exceeded() {
        let outcome = sonnet_tracker(false).check(250_000);

        assert!(!outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Exceeded { .. }));
    }

    /// With extended context enabled, a 500k estimate may proceed.
    #[test]
    fn test_extended_context_not_exceeded() {
        let outcome = sonnet_tracker(true).check(500_000);

        assert!(outcome.should_proceed());
    }
}