1use regex::Regex;
3use std::sync::LazyLock;
4use std::time::Instant;
5
6static SHORTHAND_START_RE: LazyLock<Regex> =
7 LazyLock::new(|| Regex::new(r"^\s*\+(\d+(?:\.\d+)?)\s*(k|m|b)\b").unwrap());
8
9static SHORTHAND_END_RE: LazyLock<Regex> = LazyLock::new(|| {
10 Regex::new(r"\s\+(\d+(?:\.\d+)?)\s*(k|m|b)\s*[.!?]?\s*$").unwrap()
11});
12
13static VERBOSE_RE: LazyLock<Regex> = LazyLock::new(|| {
14 Regex::new(r"(?i)\b(?:use|spend)\s+(\d+(?:\.\d+)?)\s*(k|m|b)\s*tokens?\b").unwrap()
15});
16
17static VERBOSE_RE_G: LazyLock<Regex> = LazyLock::new(|| {
18 Regex::new(r"(?i)\b(?:use|spend)\s+(\d+(?:\.\d+)?)\s*(k|m|b)\s*tokens?\b").unwrap()
19});
20
/// Fraction of the budget (`turn_tokens / budget`) below which the agent is
/// nudged to keep working; at or above it, the turn may stop.
const COMPLETION_THRESHOLD: f64 = 0.90;

/// Minimum number of continuations already issued before low-output checks can
/// be classified as diminishing returns.
const DIMINISHING_RETURNS_THRESHOLD: u64 = 3;

/// Token delta below which a check counts as low-production. Both the current
/// and the previous delta must fall under it for diminishing returns.
const LOW_PRODUCTION_TOKENS: u64 = 500;
27
/// Converts a captured number and magnitude suffix into a token count.
///
/// `value` is parsed as `f64`, falling back to `0.0` when it is not a valid
/// number. `suffix` is lowercased and then mapped to a scale: `k` is a thousand,
/// `m` a million and `b` a billion. Any other suffix leaves the value unscaled.
fn parse_budget_match(value: &str, suffix: &str) -> f64 {
    let scale = match &*suffix.to_lowercase() {
        "k" => 1e3,
        "m" => 1e6,
        "b" => 1e9,
        _ => 1.0,
    };
    value.parse::<f64>().map_or(0.0, |base| base * scale)
}
38
39pub fn parse_token_budget(text: &str) -> Option<f64> {
40 if let Some(caps) = SHORTHAND_START_RE.captures(text) {
41 let value = caps.get(1).map(|m| m.as_str()).unwrap();
42 let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
43 return Some(parse_budget_match(value, suffix));
44 }
45
46 if let Some(caps) = SHORTHAND_END_RE.captures(text) {
47 let value = caps.get(1).map(|m| m.as_str()).unwrap();
48 let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
49 return Some(parse_budget_match(value, suffix));
50 }
51
52 if let Some(caps) = VERBOSE_RE.captures(text) {
53 let value = caps.get(1).map(|m| m.as_str()).unwrap();
54 let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
55 return Some(parse_budget_match(value, suffix));
56 }
57
58 None
59}
60
/// A budget directive's location inside scanned text, as produced by
/// `find_token_budget_positions`.
///
/// Offsets are byte indices taken from regex matches; `end` is exclusive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BudgetPosition {
    /// Byte offset of the first byte of the directive.
    pub start: usize,
    /// Byte offset one past the last byte of the directive.
    pub end: usize,
}
66
67pub fn find_token_budget_positions(text: &str) -> Vec<BudgetPosition> {
68 let mut positions = Vec::new();
69
70 if let Some(m) = SHORTHAND_START_RE.find(text) {
71 let offset = m.start() + m.as_str().len() - m.as_str().trim_start().len();
72 positions.push(BudgetPosition {
73 start: offset,
74 end: m.end(),
75 });
76 }
77
78 if let Some(m) = SHORTHAND_END_RE.find(text) {
79 let end_start = m.start() + 1;
80 let already_covered = positions
81 .iter()
82 .any(|p| end_start >= p.start && end_start < p.end);
83 if !already_covered {
84 positions.push(BudgetPosition {
85 start: end_start,
86 end: m.end(),
87 });
88 }
89 }
90
91 for m in VERBOSE_RE_G.find_iter(text) {
92 positions.push(BudgetPosition {
93 start: m.start(),
94 end: m.end(),
95 });
96 }
97
98 positions
99}
100
/// Builds the nudge sent when the agent stops before reaching its token target.
///
/// Every argument is interpolated with its `Display` impl, so whole-valued
/// floats print without a fractional part (`80.0` becomes `80`). No thousands
/// separators are inserted.
pub fn get_budget_continuation_message(pct: f64, turn_tokens: u64, budget: f64) -> String {
    format!(
        "Stopped at {pct}% of token target ({turn_tokens} / {budget}). Keep working \u{2014} do not summarize."
    )
}
106
/// Mutable state threaded through successive `check_token_budget` calls for one
/// turn.
#[derive(Debug)]
pub struct BudgetTracker {
    /// Number of times the agent has been told to keep working.
    pub continuation_count: u64,
    /// Token delta recorded at the most recent continuation.
    pub last_delta_tokens: u64,
    /// Cumulative turn token count recorded at the most recent continuation.
    pub last_global_turn_tokens: u64,
    /// Creation time, used to report elapsed milliseconds on completion.
    pub started_at: Instant,
}

impl BudgetTracker {
    /// Creates a tracker with zeroed counters whose clock starts now.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for BudgetTracker {
    fn default() -> Self {
        BudgetTracker {
            started_at: Instant::now(),
            continuation_count: 0,
            last_delta_tokens: 0,
            last_global_turn_tokens: 0,
        }
    }
}
137
/// Summary attached to a `Stop` decision by `check_token_budget`. It is built
/// when at least one continuation was issued or diminishing returns were
/// detected.
#[derive(Debug, Clone)]
pub struct TokenBudgetCompletion {
    /// Fraction of the budget consumed (`tokens / budget`, so `1.0` is 100%).
    pub pct: f64,
    /// Cumulative turn token count at the stopping check.
    pub tokens: u64,
    /// Token budget that was in effect.
    pub budget: f64,
    /// Continuations issued before stopping.
    pub continuation_count: u64,
    /// Milliseconds elapsed since the tracker was created.
    pub duration_ms: u64,
    /// Whether diminishing returns were detected at the stopping check.
    pub diminishing_returns: bool,
}
154
155#[derive(Debug)]
157pub enum TokenBudgetDecision {
158 Continue { nudge_message: String },
160 Stop { completion: Option<TokenBudgetCompletion> },
162}
163
164pub fn check_token_budget(
166 tracker: &mut BudgetTracker,
167 _agent_id: Option<&str>,
168 budget: Option<f64>,
169 turn_tokens: u64,
170) -> TokenBudgetDecision {
171 let budget = match budget {
172 Some(b) if b > 0.0 => b,
173 _ => return TokenBudgetDecision::Stop { completion: None },
174 };
175
176 if _agent_id.is_some() {
177 return TokenBudgetDecision::Stop { completion: None };
178 }
179
180 let current_delta = if turn_tokens >= tracker.last_global_turn_tokens {
181 turn_tokens - tracker.last_global_turn_tokens
182 } else {
183 turn_tokens
184 };
185
186 let diminishing_returns = tracker.continuation_count >= DIMINISHING_RETURNS_THRESHOLD
187 && current_delta < LOW_PRODUCTION_TOKENS
188 && tracker.last_delta_tokens < LOW_PRODUCTION_TOKENS;
189
190 let pct = if budget > 0.0 {
191 (turn_tokens as f64 / budget)
192 } else {
193 1.0
194 };
195
196 if pct < COMPLETION_THRESHOLD && !diminishing_returns {
197 tracker.continuation_count += 1;
198 tracker.last_delta_tokens = current_delta;
199 tracker.last_global_turn_tokens = turn_tokens;
200 return TokenBudgetDecision::Continue {
201 nudge_message: get_budget_continuation_message((pct * 100.0) as u64 as f64, turn_tokens, budget),
202 };
203 }
204
205 let completion = if tracker.continuation_count > 0 || diminishing_returns {
206 Some(TokenBudgetCompletion {
207 pct,
208 tokens: turn_tokens,
209 budget,
210 continuation_count: tracker.continuation_count,
211 duration_ms: tracker.started_at.elapsed().as_millis() as u64,
212 diminishing_returns,
213 })
214 } else {
215 None
216 };
217
218 TokenBudgetDecision::Stop { completion }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_token_budget_shorthand_start() {
        assert_eq!(parse_token_budget("+500k"), Some(500_000.0));
        assert_eq!(parse_token_budget("+2m"), Some(2_000_000.0));
        assert_eq!(parse_token_budget("+1.5b"), Some(1_500_000_000.0));
    }

    #[test]
    fn test_parse_token_budget_shorthand_end() {
        assert_eq!(parse_token_budget("I want +500k."), Some(500_000.0));
    }

    #[test]
    fn test_parse_token_budget_verbose() {
        assert_eq!(parse_token_budget("use 2M tokens"), Some(2_000_000.0));
        assert_eq!(parse_token_budget("spend 500k tokens"), Some(500_000.0));
    }

    #[test]
    fn test_parse_token_budget_none() {
        assert!(parse_token_budget("hello world").is_none());
    }

    #[test]
    fn test_parse_token_budget_leading_whitespace_and_priority() {
        assert_eq!(parse_token_budget("  +1k rest of prompt"), Some(1_000.0));
        // A start shorthand wins over a verbose phrase later in the text.
        assert_eq!(parse_token_budget("+1k then use 5k tokens"), Some(1_000.0));
    }

    #[test]
    fn test_find_positions() {
        let positions = find_token_budget_positions("+500k");
        assert!(!positions.is_empty());
    }

    #[test]
    fn test_find_positions_ranges() {
        let text = "+500k";
        let p = find_token_budget_positions(text);
        assert_eq!(p.len(), 1);
        assert_eq!((p[0].start, p[0].end), (0, 5));

        // Leading whitespace is excluded from the start directive.
        let text = "  +2m please";
        let p = find_token_budget_positions(text);
        assert_eq!(p.len(), 1);
        assert_eq!(&text[p[0].start..p[0].end], "+2m");

        // The end directive skips its preceding whitespace but keeps punctuation.
        let text = "I want +500k.";
        let p = find_token_budget_positions(text);
        assert_eq!(p.len(), 1);
        assert_eq!(&text[p[0].start..p[0].end], "+500k.");

        // A text that is a single shorthand is reported once, not twice.
        let text = " +500k";
        let p = find_token_budget_positions(text);
        assert_eq!(p.len(), 1);
        assert_eq!(&text[p[0].start..p[0].end], "+500k");

        // Every verbose phrase is reported.
        let text = "please use 2M tokens and then spend 1k tokens";
        let p = find_token_budget_positions(text);
        let found: Vec<&str> = p.iter().map(|p| &text[p.start..p.end]).collect();
        assert_eq!(found, ["use 2M tokens", "spend 1k tokens"]);
    }

    #[test]
    fn test_budget_continuation_message() {
        let msg = get_budget_continuation_message(80.0, 160000, 200000.0);
        assert!(msg.contains("80%"));
        assert!(msg.contains("160000"));
        assert!(msg.contains("200000"));
    }

    #[test]
    fn test_budget_tracker_new() {
        let t = BudgetTracker::new();
        assert_eq!(t.continuation_count, 0);
        assert_eq!(t.last_delta_tokens, 0);
    }

    #[test]
    fn test_check_no_budget() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, None, 100);
        assert!(matches!(d, TokenBudgetDecision::Stop { completion: None }));

        let d2 = check_token_budget(&mut t, None, Some(0.0), 100);
        assert!(matches!(d2, TokenBudgetDecision::Stop { completion: None }));
    }

    #[test]
    fn test_check_subagent_skips_budget() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, Some("sub1"), Some(5_000.0), 100);
        assert!(matches!(d, TokenBudgetDecision::Stop { completion: None }));
    }

    #[test]
    fn test_check_continue_under_threshold() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 100);
        match d {
            TokenBudgetDecision::Continue { nudge_message } => {
                assert!(nudge_message.contains("Keep working"));
            }
            other => panic!("Expected Continue, got {:?}", other),
        }
        assert_eq!(t.continuation_count, 1);
    }

    #[test]
    fn test_check_continue_updates_tracker_and_message() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 100);
        match d {
            TokenBudgetDecision::Continue { nudge_message } => {
                assert!(nudge_message.starts_with("Stopped at 2% of token target (100 / 5000)."));
            }
            other => panic!("Expected Continue, got {:?}", other),
        }
        assert_eq!(t.last_global_turn_tokens, 100);
        assert_eq!(t.last_delta_tokens, 100);

        // A lower cumulative count is treated as a reset: the whole count is the delta.
        let _ = check_token_budget(&mut t, None, Some(5_000.0), 40);
        assert_eq!(t.last_global_turn_tokens, 40);
        assert_eq!(t.last_delta_tokens, 40);
        assert_eq!(t.continuation_count, 2);
    }

    #[test]
    fn test_check_stop_over_threshold() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 5_000);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                assert!(completion.is_none());
            }
            other => panic!("Expected Stop, got {:?}", other),
        }
    }

    #[test]
    fn test_check_continuation_then_stop() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 100);
        assert!(matches!(d, TokenBudgetDecision::Continue { .. }));

        let d = check_token_budget(&mut t, None, Some(5_000.0), 4_800);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                let c = completion.expect("should have completion");
                assert!(c.pct >= 0.9);
                assert_eq!(c.continuation_count, 1);
                assert!(!c.diminishing_returns);
            }
            other => panic!("Expected Stop, got {:?}", other),
        }
    }

    #[test]
    fn test_check_diminishing_returns() {
        let mut t = BudgetTracker::new();
        for _ in 0..3 {
            let tokens = t.last_global_turn_tokens + 100;
            let d = check_token_budget(&mut t, None, Some(10_000.0), tokens);
            assert!(matches!(d, TokenBudgetDecision::Continue { .. }));
        }
        let tokens = t.last_global_turn_tokens + 100;
        let d = check_token_budget(&mut t, None, Some(10_000.0), tokens);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                let c = completion.expect("should have completion");
                assert!(c.diminishing_returns);
            }
            TokenBudgetDecision::Continue { .. } => {
                panic!("Expected Stop due to diminishing returns");
            }
        }
    }
}
349}