claude_agent/tokens/
tracker.rs1use super::{ContextWindow, PricingTier, TokenBudget, WindowStatus};
2use crate::models::ModelSpec;
3
4#[derive(Debug, Clone)]
5pub enum PreflightResult {
6 Ok {
7 estimated_tokens: u64,
8 remaining: u64,
9 tier: PricingTier,
10 },
11 Warning {
12 estimated_tokens: u64,
13 utilization: f64,
14 tier: PricingTier,
15 },
16 Exceeded {
17 estimated_tokens: u64,
18 limit: u64,
19 overage: u64,
20 },
21}
22
23impl PreflightResult {
24 pub fn should_proceed(&self) -> bool {
25 !matches!(self, Self::Exceeded { .. })
26 }
27
28 pub fn estimated_tokens(&self) -> u64 {
29 match self {
30 Self::Ok {
31 estimated_tokens, ..
32 }
33 | Self::Warning {
34 estimated_tokens, ..
35 }
36 | Self::Exceeded {
37 estimated_tokens, ..
38 } => *estimated_tokens,
39 }
40 }
41
42 pub fn tier(&self) -> Option<PricingTier> {
43 match self {
44 Self::Ok { tier, .. } | Self::Warning { tier, .. } => Some(*tier),
45 Self::Exceeded { .. } => None,
46 }
47 }
48}
49
50#[derive(Debug)]
51pub struct TokenTracker {
52 context_window: ContextWindow,
53 cumulative: TokenBudget,
54 last_turn: TokenBudget,
55 model_spec: ModelSpec,
56}
57
58impl TokenTracker {
59 pub fn new(model_spec: ModelSpec, extended_context: bool) -> Self {
60 Self {
61 context_window: ContextWindow::new(&model_spec, extended_context),
62 cumulative: TokenBudget::default(),
63 last_turn: TokenBudget::default(),
64 model_spec,
65 }
66 }
67
68 pub fn with_thresholds(mut self, warning: f64, critical: f64) -> Self {
69 self.context_window = self.context_window.with_thresholds(warning, critical);
70 self
71 }
72
73 pub fn check(&self, estimated_tokens: u64) -> PreflightResult {
74 let new_usage = self.context_window.usage() + estimated_tokens;
75 let limit = self.context_window.limit();
76
77 if new_usage > limit {
78 return PreflightResult::Exceeded {
79 estimated_tokens,
80 limit,
81 overage: new_usage - limit,
82 };
83 }
84
85 let utilization = if limit == 0 {
86 0.0
87 } else {
88 new_usage as f64 / limit as f64
89 };
90 let tier = PricingTier::for_context(new_usage);
91
92 if utilization >= self.context_window.warning_threshold() {
93 PreflightResult::Warning {
94 estimated_tokens,
95 utilization,
96 tier,
97 }
98 } else {
99 PreflightResult::Ok {
100 estimated_tokens,
101 remaining: limit - new_usage,
102 tier,
103 }
104 }
105 }
106
107 pub fn record(&mut self, usage: &crate::types::Usage) {
108 let budget = TokenBudget::from(usage);
109 self.last_turn = budget;
110 self.cumulative.add(&budget);
111 self.context_window.update(budget.context_usage());
112 }
113
114 pub fn status(&self) -> WindowStatus {
115 self.context_window.status()
116 }
117
118 pub fn context_window(&self) -> &ContextWindow {
119 &self.context_window
120 }
121
122 pub fn cumulative(&self) -> &TokenBudget {
123 &self.cumulative
124 }
125
126 pub fn last_turn(&self) -> &TokenBudget {
127 &self.last_turn
128 }
129
130 pub fn pricing_tier(&self) -> PricingTier {
131 PricingTier::for_context(self.context_window.usage())
132 }
133
134 pub fn total_cost(&self) -> f64 {
135 self.model_spec.pricing.calculate(
136 self.cumulative.input_tokens,
137 self.cumulative.output_tokens,
138 self.cumulative.cache_read_tokens,
139 self.cumulative.cache_write_tokens,
140 )
141 }
142
143 pub fn reset(&mut self, new_context_usage: u64) {
144 self.context_window.reset(new_context_usage);
145 }
146
147 pub fn model(&self) -> &ModelSpec {
148 &self.model_spec
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::read_registry;

    // NOTE(review): the constants below presume the registry's "sonnet" entry has a
    // standard window between 180k and 250k tokens, a warning threshold at or below
    // 180k/limit, and an extended window of at least 500k. Confirm against the registry.

    /// Builds a tracker for the registry's "sonnet" model.
    fn sonnet_tracker(extended_context: bool) -> TokenTracker {
        let spec = read_registry().resolve("sonnet").unwrap().clone();
        TokenTracker::new(spec, extended_context)
    }

    #[test]
    fn test_preflight_ok() {
        let outcome = sonnet_tracker(false).check(50_000);
        assert!(outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Ok { .. }));
    }

    #[test]
    fn test_preflight_warning() {
        let outcome = sonnet_tracker(false).check(180_000);
        assert!(outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Warning { .. }));
    }

    #[test]
    fn test_preflight_exceeded() {
        let outcome = sonnet_tracker(false).check(250_000);
        assert!(!outcome.should_proceed());
        assert!(matches!(outcome, PreflightResult::Exceeded { .. }));
    }

    #[test]
    fn test_extended_context_not_exceeded() {
        let outcome = sonnet_tracker(true).check(500_000);
        assert!(outcome.should_proceed());
    }
}