Skip to main content

ai_agent/
token_budget.rs

// Source: /data/home/swei/claudecode/openclaudecode/src/query/tokenBudget.ts
2use regex::Regex;
3use std::sync::LazyLock;
4use std::time::Instant;
5
6static SHORTHAND_START_RE: LazyLock<Regex> =
7    LazyLock::new(|| Regex::new(r"^\s*\+(\d+(?:\.\d+)?)\s*(k|m|b)\b").unwrap());
8
9static SHORTHAND_END_RE: LazyLock<Regex> = LazyLock::new(|| {
10    Regex::new(r"\s\+(\d+(?:\.\d+)?)\s*(k|m|b)\s*[.!?]?\s*$").unwrap()
11});
12
13static VERBOSE_RE: LazyLock<Regex> = LazyLock::new(|| {
14    Regex::new(r"(?i)\b(?:use|spend)\s+(\d+(?:\.\d+)?)\s*(k|m|b)\s*tokens?\b").unwrap()
15});
16
17static VERBOSE_RE_G: LazyLock<Regex> = LazyLock::new(|| {
18    Regex::new(r"(?i)\b(?:use|spend)\s+(\d+(?:\.\d+)?)\s*(k|m|b)\s*tokens?\b").unwrap()
19});
20
/// Fraction of the budget (`turn_tokens / budget`) at or above which the loop stops
/// instead of continuing (0.9 = 90% of the budget).
const COMPLETION_THRESHOLD: f64 = 0.9;
/// Minimum number of continuations that must already have happened before diminishing
/// returns is allowed to stop the loop.
const DIMINISHING_RETURNS_THRESHOLD: u64 = 3;
/// Per-check token delta below which progress counts as "low production". Diminishing
/// returns requires both the current and the previous delta to be below this value.
const LOW_PRODUCTION_TOKENS: u64 = 500;
27
/// Turns a captured number and its `k`/`m`/`b` suffix into a token count.
///
/// The suffix is lower-cased first, so `K`, `M` and `B` work as well. Any other suffix
/// leaves the number unscaled. A number that does not parse as `f64` yields `0.0`.
fn parse_budget_match(value: &str, suffix: &str) -> f64 {
    let scale = match &*suffix.to_lowercase() {
        "k" => 1e3,
        "m" => 1e6,
        "b" => 1e9,
        _ => 1.0,
    };
    value.parse::<f64>().map_or(0.0, |number| number * scale)
}
38
39pub fn parse_token_budget(text: &str) -> Option<f64> {
40    if let Some(caps) = SHORTHAND_START_RE.captures(text) {
41        let value = caps.get(1).map(|m| m.as_str()).unwrap();
42        let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
43        return Some(parse_budget_match(value, suffix));
44    }
45
46    if let Some(caps) = SHORTHAND_END_RE.captures(text) {
47        let value = caps.get(1).map(|m| m.as_str()).unwrap();
48        let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
49        return Some(parse_budget_match(value, suffix));
50    }
51
52    if let Some(caps) = VERBOSE_RE.captures(text) {
53        let value = caps.get(1).map(|m| m.as_str()).unwrap();
54        let suffix = caps.get(2).map(|m| m.as_str()).unwrap();
55        return Some(parse_budget_match(value, suffix));
56    }
57
58    None
59}
60
/// Location `start..end` of a token-budget directive inside the scanned text.
///
/// Both offsets are byte indices into the scanned `&str`, so `&text[start..end]`
/// yields the directive itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BudgetPosition {
    /// Byte offset of the first byte of the directive (inclusive).
    pub start: usize,
    /// Byte offset one past the last byte of the directive (exclusive).
    pub end: usize,
}
66
67pub fn find_token_budget_positions(text: &str) -> Vec<BudgetPosition> {
68    let mut positions = Vec::new();
69
70    if let Some(m) = SHORTHAND_START_RE.find(text) {
71        let offset = m.start() + m.as_str().len() - m.as_str().trim_start().len();
72        positions.push(BudgetPosition {
73            start: offset,
74            end: m.end(),
75        });
76    }
77
78    if let Some(m) = SHORTHAND_END_RE.find(text) {
79        let end_start = m.start() + 1;
80        let already_covered = positions
81            .iter()
82            .any(|p| end_start >= p.start && end_start < p.end);
83        if !already_covered {
84            positions.push(BudgetPosition {
85                start: end_start,
86                end: m.end(),
87            });
88        }
89    }
90
91    for m in VERBOSE_RE_G.find_iter(text) {
92        positions.push(BudgetPosition {
93            start: m.start(),
94            end: m.end(),
95        });
96    }
97
98    positions
99}
100
/// Builds the nudge text injected when the loop is continued under a token budget.
///
/// `pct` and `budget` use `f64`'s `Display` formatting (so `80.0` renders as `80`), and
/// `turn_tokens` is printed as a plain integer, e.g.
/// `Stopped at 80% of token target (160000 / 200000). Keep working — do not summarize.`
pub fn get_budget_continuation_message(pct: f64, turn_tokens: u64, budget: f64) -> String {
    const INSTRUCTION: &str = "Keep working \u{2014} do not summarize.";
    format!("Stopped at {pct}% of token target ({turn_tokens} / {budget}). {INSTRUCTION}")
}
106
/// Per-query-loop bookkeeping for the token-budget check.
///
/// One tracker per query loop iteration, created when TOKEN_BUDGET is active.
#[derive(Debug)]
pub struct BudgetTracker {
    /// Number of times the loop has been continued with a nudge message.
    pub continuation_count: u64,
    /// Token delta recorded at the most recent continuation.
    pub last_delta_tokens: u64,
    /// Turn-token total observed at the most recent continuation.
    pub last_global_turn_tokens: u64,
    /// Creation time; the completion event's duration is measured from here.
    pub started_at: Instant,
}

impl Default for BudgetTracker {
    /// A fresh tracker: every counter at zero and the clock started now.
    fn default() -> Self {
        Self {
            started_at: Instant::now(),
            continuation_count: 0,
            last_delta_tokens: 0,
            last_global_turn_tokens: 0,
        }
    }
}

impl BudgetTracker {
    /// Creates a tracker with zeroed counters whose clock starts now.
    pub fn new() -> Self {
        Self::default()
    }
}
137
/// Completion event emitted when the token budget causes the loop to stop.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenBudgetCompletion {
    /// Fraction of the budget consumed, `tokens / budget` (0.9 means 90%; may exceed 1.0).
    pub pct: f64,
    /// Total turn tokens consumed.
    pub tokens: u64,
    /// Budget target in tokens.
    pub budget: f64,
    /// How many continuations were performed before stopping.
    pub continuation_count: u64,
    /// Milliseconds elapsed since the tracker was created.
    pub duration_ms: u64,
    /// Whether diminishing returns (rather than reaching the threshold) triggered the stop.
    pub diminishing_returns: bool,
}

/// Decision returned by `check_token_budget()`.
#[derive(Debug, Clone, PartialEq)]
pub enum TokenBudgetDecision {
    /// Continue the loop with `nudge_message` injected.
    Continue { nudge_message: String },
    /// Stop the loop. `completion` carries a telemetry event when one applies.
    Stop { completion: Option<TokenBudgetCompletion> },
}
163
164/// Check whether the token budget allows the query loop to continue.
165pub fn check_token_budget(
166    tracker: &mut BudgetTracker,
167    _agent_id: Option<&str>,
168    budget: Option<f64>,
169    turn_tokens: u64,
170) -> TokenBudgetDecision {
171    let budget = match budget {
172        Some(b) if b > 0.0 => b,
173        _ => return TokenBudgetDecision::Stop { completion: None },
174    };
175
176    if _agent_id.is_some() {
177        return TokenBudgetDecision::Stop { completion: None };
178    }
179
180    let current_delta = if turn_tokens >= tracker.last_global_turn_tokens {
181        turn_tokens - tracker.last_global_turn_tokens
182    } else {
183        turn_tokens
184    };
185
186    let diminishing_returns = tracker.continuation_count >= DIMINISHING_RETURNS_THRESHOLD
187        && current_delta < LOW_PRODUCTION_TOKENS
188        && tracker.last_delta_tokens < LOW_PRODUCTION_TOKENS;
189
190    let pct = if budget > 0.0 {
191        (turn_tokens as f64 / budget)
192    } else {
193        1.0
194    };
195
196    if pct < COMPLETION_THRESHOLD && !diminishing_returns {
197        tracker.continuation_count += 1;
198        tracker.last_delta_tokens = current_delta;
199        tracker.last_global_turn_tokens = turn_tokens;
200        return TokenBudgetDecision::Continue {
201            nudge_message: get_budget_continuation_message((pct * 100.0) as u64 as f64, turn_tokens, budget),
202        };
203    }
204
205    let completion = if tracker.continuation_count > 0 || diminishing_returns {
206        Some(TokenBudgetCompletion {
207            pct,
208            tokens: turn_tokens,
209            budget,
210            continuation_count: tracker.continuation_count,
211            duration_ms: tracker.started_at.elapsed().as_millis() as u64,
212            diminishing_returns,
213        })
214    } else {
215        None
216    };
217
218    TokenBudgetDecision::Stop { completion }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_token_budget_shorthand_start() {
        assert_eq!(parse_token_budget("+500k"), Some(500_000.0));
        assert_eq!(parse_token_budget("+2m"), Some(2_000_000.0));
        assert_eq!(parse_token_budget("+1.5b"), Some(1_500_000_000.0));
    }

    #[test]
    fn test_parse_token_budget_shorthand_end() {
        assert_eq!(parse_token_budget("I want +500k."), Some(500_000.0));
    }

    #[test]
    fn test_parse_token_budget_verbose() {
        assert_eq!(parse_token_budget("use 2M tokens"), Some(2_000_000.0));
        assert_eq!(parse_token_budget("spend 500k tokens"), Some(500_000.0));
    }

    #[test]
    fn test_parse_token_budget_verbose_is_case_insensitive() {
        assert_eq!(parse_token_budget("USE 2M TOKENS"), Some(2_000_000.0));
        assert_eq!(parse_token_budget("please Spend 1.5k token"), Some(1_500.0));
    }

    #[test]
    fn test_parse_token_budget_none() {
        assert!(parse_token_budget("hello world").is_none());
        assert!(parse_token_budget("").is_none());
        // Shorthand is anchored: a `+500k` in the middle of a sentence is not a budget.
        assert!(parse_token_budget("add +500k to the counter please").is_none());
    }

    #[test]
    fn test_find_positions() {
        let positions = find_token_budget_positions("+500k");
        assert_eq!(positions.len(), 1);
        assert_eq!((positions[0].start, positions[0].end), (0, 5));
    }

    #[test]
    fn test_find_positions_skips_leading_whitespace_without_double_counting() {
        // Matched by both shorthand patterns; reported once, starting at `+`.
        let positions = find_token_budget_positions("  +2m");
        assert_eq!(positions.len(), 1);
        assert_eq!((positions[0].start, positions[0].end), (2, 5));
    }

    #[test]
    fn test_find_positions_end_shorthand() {
        let text = "fix the bug +10k";
        let positions = find_token_budget_positions(text);
        assert_eq!(positions.len(), 1);
        assert_eq!(&text[positions[0].start..positions[0].end], "+10k");
    }

    #[test]
    fn test_find_positions_verbose() {
        let text = "please use 2k tokens now";
        let positions = find_token_budget_positions(text);
        assert_eq!(positions.len(), 1);
        assert_eq!(&text[positions[0].start..positions[0].end], "use 2k tokens");
    }

    #[test]
    fn test_budget_continuation_message() {
        let msg = get_budget_continuation_message(80.0, 160000, 200000.0);
        assert!(msg.contains("80%"));
        assert!(msg.contains("160000"));
        assert!(msg.contains("200000"));
    }

    #[test]
    fn test_budget_tracker_new() {
        let t = BudgetTracker::new();
        assert_eq!(t.continuation_count, 0);
        assert_eq!(t.last_delta_tokens, 0);
        assert_eq!(t.last_global_turn_tokens, 0);
    }

    #[test]
    fn test_check_no_budget() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, None, 100);
        assert!(matches!(d, TokenBudgetDecision::Stop { completion: None }));

        let d2 = check_token_budget(&mut t, None, Some(0.0), 100);
        assert!(matches!(d2, TokenBudgetDecision::Stop { completion: None }));

        let d3 = check_token_budget(&mut t, None, Some(-5.0), 100);
        assert!(matches!(d3, TokenBudgetDecision::Stop { completion: None }));

        let d4 = check_token_budget(&mut t, None, Some(f64::NAN), 100);
        assert!(matches!(d4, TokenBudgetDecision::Stop { completion: None }));

        // Rejected budgets never advance the tracker.
        assert_eq!(t.continuation_count, 0);
    }

    #[test]
    fn test_check_subagent_skips_budget() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, Some("sub1"), Some(5_000.0), 100);
        assert!(matches!(d, TokenBudgetDecision::Stop { completion: None }));
        assert_eq!(t.continuation_count, 0);
    }

    #[test]
    fn test_check_continue_under_threshold() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 100);
        match d {
            TokenBudgetDecision::Continue { nudge_message } => {
                assert!(nudge_message.contains("Keep working"));
                assert!(nudge_message.contains("2%"));
                assert!(nudge_message.contains("(100 / 5000)"));
            }
            other => panic!("Expected Continue, got {:?}", other),
        }
        assert_eq!(t.continuation_count, 1);
        assert_eq!(t.last_delta_tokens, 100);
        assert_eq!(t.last_global_turn_tokens, 100);
    }

    #[test]
    fn test_check_stop_over_threshold() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 5_000);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                assert!(completion.is_none());
            }
            other => panic!("Expected Stop, got {:?}", other),
        }
    }

    #[test]
    fn test_check_continuation_then_stop() {
        let mut t = BudgetTracker::new();
        let d = check_token_budget(&mut t, None, Some(5_000.0), 100);
        assert!(matches!(d, TokenBudgetDecision::Continue { .. }));

        let d = check_token_budget(&mut t, None, Some(5_000.0), 4_800);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                let c = completion.expect("should have completion");
                assert!(c.pct >= 0.9);
                assert_eq!(c.tokens, 4_800);
                assert_eq!(c.continuation_count, 1);
                assert!(!c.diminishing_returns);
            }
            other => panic!("Expected Stop, got {:?}", other),
        }
    }

    #[test]
    fn test_check_diminishing_returns() {
        let mut t = BudgetTracker::new();
        for _ in 0..3 {
            let tokens = t.last_global_turn_tokens + 100;
            let d = check_token_budget(&mut t, None, Some(10_000.0), tokens);
            assert!(matches!(d, TokenBudgetDecision::Continue { .. }));
        }
        let tokens = t.last_global_turn_tokens + 100;
        let d = check_token_budget(&mut t, None, Some(10_000.0), tokens);
        match d {
            TokenBudgetDecision::Stop { completion } => {
                let c = completion.expect("should have completion");
                assert!(c.diminishing_returns);
                assert_eq!(c.continuation_count, 3);
            }
            TokenBudgetDecision::Continue { .. } => {
                panic!("Expected Stop due to diminishing returns");
            }
        }
    }
}