// reasonkit/thinktool/budget.rs

//! Budget Configuration for Adaptive Compute Time
//!
//! Enables the `--budget` flag for controlling reasoning execution within constraints.
//!
//! ## Budget Types
//! - Time budget: Maximum wall-clock time (e.g., "10s", "2m")
//! - Token budget: Maximum tokens to consume (e.g., "1000t")
//! - Cost budget: Maximum USD to spend (e.g., "$0.10")
//!
//! ## Adaptive Behavior
//! When the budget is constrained, the executor will:
//! 1. Skip optional steps if time is running low
//! 2. Reduce max_tokens for remaining steps
//! 3. Use faster/cheaper model tiers when approaching the limit
//! 4. Terminate early if the budget is exhausted

use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};

20/// Budget configuration for protocol execution
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BudgetConfig {
23    /// Maximum execution time
24    pub time_limit: Option<Duration>,
25
26    /// Maximum tokens to consume
27    pub token_limit: Option<u32>,
28
29    /// Maximum cost in USD
30    pub cost_limit: Option<f64>,
31
32    /// Strategy when approaching budget limits
33    #[serde(default)]
34    pub strategy: BudgetStrategy,
35
36    /// Percentage of budget at which to start adapting (0.0-1.0)
37    #[serde(default = "default_adapt_threshold")]
38    pub adapt_threshold: f64,
39}
40
/// Default value for `BudgetConfig::adapt_threshold`.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on that
/// field, so it has to remain a zero-argument function returning `f64`.
fn default_adapt_threshold() -> f64 {
    // Start adapting at 70% budget usage.
    const ADAPT_AT: f64 = 0.7;
    ADAPT_AT
}

45/// Strategy for handling budget constraints
46#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum BudgetStrategy {
49    /// Strict: Fail if budget would be exceeded
50    Strict,
51    /// Adaptive: Reduce quality/scope to stay within budget (default)
52    #[default]
53    Adaptive,
54    /// BestEffort: Try to complete as much as possible, may exceed budget
55    BestEffort,
56}
57
58impl Default for BudgetConfig {
59    fn default() -> Self {
60        Self {
61            time_limit: None,
62            token_limit: None,
63            cost_limit: None,
64            strategy: BudgetStrategy::default(),
65            adapt_threshold: default_adapt_threshold(),
66        }
67    }
68}
69
70impl BudgetConfig {
71    /// Create an unlimited budget (no constraints)
72    pub fn unlimited() -> Self {
73        Self::default()
74    }
75
76    /// Create a time-limited budget
77    pub fn with_time(duration: Duration) -> Self {
78        Self {
79            time_limit: Some(duration),
80            ..Default::default()
81        }
82    }
83
84    /// Create a token-limited budget
85    pub fn with_tokens(limit: u32) -> Self {
86        Self {
87            token_limit: Some(limit),
88            ..Default::default()
89        }
90    }
91
92    /// Create a cost-limited budget
93    pub fn with_cost(usd: f64) -> Self {
94        Self {
95            cost_limit: Some(usd),
96            ..Default::default()
97        }
98    }
99
100    /// Set budget strategy
101    pub fn with_strategy(mut self, strategy: BudgetStrategy) -> Self {
102        self.strategy = strategy;
103        self
104    }
105
106    /// Check if any limits are set
107    pub fn is_constrained(&self) -> bool {
108        self.time_limit.is_some() || self.token_limit.is_some() || self.cost_limit.is_some()
109    }
110
111    /// Parse a budget string like "10s", "1000t", "$0.50"
112    pub fn parse(budget_str: &str) -> Result<Self, BudgetParseError> {
113        let budget_str = budget_str.trim();
114
115        if budget_str.is_empty() {
116            return Err(BudgetParseError::Empty);
117        }
118
119        // Cost budget: $X.XX
120        if let Some(cost) = budget_str.strip_prefix('$') {
121            let usd: f64 = cost.parse().map_err(|_| BudgetParseError::InvalidCost)?;
122            return Ok(Self::with_cost(usd));
123        }
124
125        // Token budget: XXXt or XXXtokens
126        if budget_str.ends_with('t') || budget_str.ends_with("tokens") {
127            let num_str = budget_str.trim_end_matches("tokens").trim_end_matches('t');
128            let tokens: u32 = num_str
129                .parse()
130                .map_err(|_| BudgetParseError::InvalidTokens)?;
131            return Ok(Self::with_tokens(tokens));
132        }
133
134        // Time budget: Xs, Xm, Xh
135        if let Some(secs) = budget_str.strip_suffix('s') {
136            let seconds: u64 = secs.parse().map_err(|_| BudgetParseError::InvalidTime)?;
137            return Ok(Self::with_time(Duration::from_secs(seconds)));
138        }
139
140        if let Some(mins) = budget_str.strip_suffix('m') {
141            let minutes: u64 = mins.parse().map_err(|_| BudgetParseError::InvalidTime)?;
142            return Ok(Self::with_time(Duration::from_secs(minutes * 60)));
143        }
144
145        if let Some(hours) = budget_str.strip_suffix('h') {
146            let hours_val: u64 = hours.parse().map_err(|_| BudgetParseError::InvalidTime)?;
147            return Ok(Self::with_time(Duration::from_secs(hours_val * 3600)));
148        }
149
150        Err(BudgetParseError::UnknownFormat(budget_str.to_string()))
151    }
152}
153
/// Error parsing a budget string.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BudgetParseError {
    /// Budget string is empty.
    Empty,
    /// Invalid time format.
    InvalidTime,
    /// Invalid token count.
    InvalidTokens,
    /// Invalid cost value.
    InvalidCost,
    /// Unknown format string; carries the rejected input.
    UnknownFormat(String),
}

169impl std::fmt::Display for BudgetParseError {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        match self {
172            BudgetParseError::Empty => write!(f, "Budget string is empty"),
173            BudgetParseError::InvalidTime => write!(f, "Invalid time format (use Xs, Xm, or Xh)"),
174            BudgetParseError::InvalidTokens => write!(f, "Invalid token count (use Xt or Xtokens)"),
175            BudgetParseError::InvalidCost => write!(f, "Invalid cost format (use $X.XX)"),
176            BudgetParseError::UnknownFormat(s) => write!(f, "Unknown budget format: {}", s),
177        }
178    }
179}
180
181impl std::error::Error for BudgetParseError {}
182
183/// Runtime budget tracker
184#[derive(Debug, Clone)]
185pub struct BudgetTracker {
186    /// Configuration
187    config: BudgetConfig,
188
189    /// When execution started
190    start_time: Instant,
191
192    /// Tokens consumed so far
193    tokens_used: u32,
194
195    /// Cost incurred so far (USD)
196    cost_incurred: f64,
197
198    /// Steps completed
199    steps_completed: usize,
200
201    /// Steps skipped due to budget
202    steps_skipped: usize,
203}
204
205impl BudgetTracker {
206    /// Create a new budget tracker
207    pub fn new(config: BudgetConfig) -> Self {
208        Self {
209            config,
210            start_time: Instant::now(),
211            tokens_used: 0,
212            cost_incurred: 0.0,
213            steps_completed: 0,
214            steps_skipped: 0,
215        }
216    }
217
218    /// Record token and cost usage
219    pub fn record_usage(&mut self, tokens: u32, cost: f64) {
220        self.tokens_used += tokens;
221        self.cost_incurred += cost;
222        self.steps_completed += 1;
223    }
224
225    /// Record a skipped step
226    pub fn record_skip(&mut self) {
227        self.steps_skipped += 1;
228    }
229
230    /// Get elapsed time
231    pub fn elapsed(&self) -> Duration {
232        self.start_time.elapsed()
233    }
234
235    /// Get time remaining (if time limit set)
236    pub fn time_remaining(&self) -> Option<Duration> {
237        self.config
238            .time_limit
239            .map(|limit| limit.saturating_sub(self.elapsed()))
240    }
241
242    /// Get tokens remaining (if token limit set)
243    pub fn tokens_remaining(&self) -> Option<u32> {
244        self.config
245            .token_limit
246            .map(|limit| limit.saturating_sub(self.tokens_used))
247    }
248
249    /// Get cost remaining (if cost limit set)
250    pub fn cost_remaining(&self) -> Option<f64> {
251        self.config
252            .cost_limit
253            .map(|limit| (limit - self.cost_incurred).max(0.0))
254    }
255
256    /// Check if budget is exhausted
257    pub fn is_exhausted(&self) -> bool {
258        if let Some(remaining) = self.time_remaining() {
259            if remaining.is_zero() {
260                return true;
261            }
262        }
263        if let Some(remaining) = self.tokens_remaining() {
264            if remaining == 0 {
265                return true;
266            }
267        }
268        if let Some(remaining) = self.cost_remaining() {
269            if remaining <= 0.0 {
270                return true;
271            }
272        }
273        false
274    }
275
276    /// Calculate budget usage ratio (0.0-1.0)
277    pub fn usage_ratio(&self) -> f64 {
278        let mut max_ratio = 0.0f64;
279
280        if let Some(limit) = self.config.time_limit {
281            let ratio = self.elapsed().as_secs_f64() / limit.as_secs_f64();
282            max_ratio = max_ratio.max(ratio);
283        }
284
285        if let Some(limit) = self.config.token_limit {
286            let ratio = self.tokens_used as f64 / limit as f64;
287            max_ratio = max_ratio.max(ratio);
288        }
289
290        if let Some(limit) = self.config.cost_limit {
291            let ratio = self.cost_incurred / limit;
292            max_ratio = max_ratio.max(ratio);
293        }
294
295        max_ratio.min(1.0)
296    }
297
298    /// Check if we should start adapting (approaching limit)
299    pub fn should_adapt(&self) -> bool {
300        self.usage_ratio() >= self.config.adapt_threshold
301    }
302
303    /// Get adaptive max_tokens based on remaining budget
304    pub fn adaptive_max_tokens(&self, requested: u32) -> u32 {
305        if let Some(remaining) = self.tokens_remaining() {
306            // Reserve some tokens for remaining steps
307            let reserve = remaining / 4;
308            return requested.min(remaining - reserve);
309        }
310        requested
311    }
312
313    /// Check if step should be skipped (non-essential and low budget)
314    pub fn should_skip_step(&self, is_essential: bool) -> bool {
315        if is_essential {
316            return false;
317        }
318
319        match self.config.strategy {
320            BudgetStrategy::Strict => self.is_exhausted(),
321            BudgetStrategy::Adaptive => self.usage_ratio() > 0.9,
322            BudgetStrategy::BestEffort => false,
323        }
324    }
325
326    /// Get budget summary
327    pub fn summary(&self) -> BudgetSummary {
328        BudgetSummary {
329            elapsed: self.elapsed(),
330            tokens_used: self.tokens_used,
331            cost_incurred: self.cost_incurred,
332            steps_completed: self.steps_completed,
333            steps_skipped: self.steps_skipped,
334            usage_ratio: self.usage_ratio(),
335            exhausted: self.is_exhausted(),
336        }
337    }
338}
339
340/// Summary of budget usage
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct BudgetSummary {
343    /// Time elapsed
344    #[serde(with = "duration_serde")]
345    pub elapsed: Duration,
346
347    /// Tokens consumed
348    pub tokens_used: u32,
349
350    /// Cost in USD
351    pub cost_incurred: f64,
352
353    /// Steps completed
354    pub steps_completed: usize,
355
356    /// Steps skipped due to budget
357    pub steps_skipped: usize,
358
359    /// Usage ratio (0.0-1.0)
360    pub usage_ratio: f64,
361
362    /// Whether budget was exhausted
363    pub exhausted: bool,
364}
365
366mod duration_serde {
367    use serde::{Deserialize, Deserializer, Serialize, Serializer};
368    use std::time::Duration;
369
370    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
371    where
372        S: Serializer,
373    {
374        duration.as_millis().serialize(serializer)
375    }
376
377    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
378    where
379        D: Deserializer<'de>,
380    {
381        let millis = u64::deserialize(deserializer)?;
382        Ok(Duration::from_millis(millis))
383    }
384}
385
/// Unit tests for budget parsing and runtime tracking.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_time_seconds() {
        let budget = BudgetConfig::parse("30s").unwrap();
        assert_eq!(budget.time_limit, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_parse_time_minutes() {
        let budget = BudgetConfig::parse("5m").unwrap();
        assert_eq!(budget.time_limit, Some(Duration::from_secs(300)));
    }

    #[test]
    fn test_parse_tokens() {
        let budget = BudgetConfig::parse("1000t").unwrap();
        assert_eq!(budget.token_limit, Some(1000));
    }

    #[test]
    fn test_parse_tokens_full() {
        let budget = BudgetConfig::parse("5000tokens").unwrap();
        assert_eq!(budget.token_limit, Some(5000));
    }

    #[test]
    fn test_parse_cost() {
        let budget = BudgetConfig::parse("$0.50").unwrap();
        assert_eq!(budget.cost_limit, Some(0.50));
    }

    #[test]
    fn test_budget_tracker_usage() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        tracker.record_usage(200, 0.01);
        assert_eq!(tracker.tokens_remaining(), Some(800));
        assert!(!tracker.is_exhausted());

        tracker.record_usage(800, 0.04);
        assert_eq!(tracker.tokens_remaining(), Some(0));
        assert!(tracker.is_exhausted());
    }

    #[test]
    fn test_budget_tracker_adapt() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        tracker.record_usage(600, 0.03);
        assert!(!tracker.should_adapt()); // 60% < 70% threshold

        tracker.record_usage(150, 0.01);
        assert!(tracker.should_adapt()); // 75% > 70% threshold
    }

    #[test]
    fn test_adaptive_max_tokens() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        // Initially can request full amount
        assert_eq!(tracker.adaptive_max_tokens(500), 500);

        // After using some, it limits: 200 tokens remain, a quarter (50) is
        // held in reserve, so at most 150 can be granted.
        tracker.record_usage(800, 0.04);
        assert_eq!(tracker.adaptive_max_tokens(500), 150);
    }
}
458}