1use serde::{Deserialize, Serialize};
18use std::time::{Duration, Instant};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BudgetConfig {
23 pub time_limit: Option<Duration>,
25
26 pub token_limit: Option<u32>,
28
29 pub cost_limit: Option<f64>,
31
32 #[serde(default)]
34 pub strategy: BudgetStrategy,
35
36 #[serde(default = "default_adapt_threshold")]
38 pub adapt_threshold: f64,
39}
40
/// Default for `BudgetConfig::adapt_threshold`, also used as its serde default.
fn default_adapt_threshold() -> f64 {
    const THRESHOLD: f64 = 0.7;
    THRESHOLD
}
44
45#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum BudgetStrategy {
49 Strict,
51 #[default]
53 Adaptive,
54 BestEffort,
56}
57
58impl Default for BudgetConfig {
59 fn default() -> Self {
60 Self {
61 time_limit: None,
62 token_limit: None,
63 cost_limit: None,
64 strategy: BudgetStrategy::default(),
65 adapt_threshold: default_adapt_threshold(),
66 }
67 }
68}
69
70impl BudgetConfig {
71 pub fn unlimited() -> Self {
73 Self::default()
74 }
75
76 pub fn with_time(duration: Duration) -> Self {
78 Self {
79 time_limit: Some(duration),
80 ..Default::default()
81 }
82 }
83
84 pub fn with_tokens(limit: u32) -> Self {
86 Self {
87 token_limit: Some(limit),
88 ..Default::default()
89 }
90 }
91
92 pub fn with_cost(usd: f64) -> Self {
94 Self {
95 cost_limit: Some(usd),
96 ..Default::default()
97 }
98 }
99
100 pub fn with_strategy(mut self, strategy: BudgetStrategy) -> Self {
102 self.strategy = strategy;
103 self
104 }
105
106 pub fn is_constrained(&self) -> bool {
108 self.time_limit.is_some() || self.token_limit.is_some() || self.cost_limit.is_some()
109 }
110
111 pub fn parse(budget_str: &str) -> Result<Self, BudgetParseError> {
113 let budget_str = budget_str.trim();
114
115 if budget_str.is_empty() {
116 return Err(BudgetParseError::Empty);
117 }
118
119 if let Some(cost) = budget_str.strip_prefix('$') {
121 let usd: f64 = cost.parse().map_err(|_| BudgetParseError::InvalidCost)?;
122 return Ok(Self::with_cost(usd));
123 }
124
125 if budget_str.ends_with('t') || budget_str.ends_with("tokens") {
127 let num_str = budget_str.trim_end_matches("tokens").trim_end_matches('t');
128 let tokens: u32 = num_str
129 .parse()
130 .map_err(|_| BudgetParseError::InvalidTokens)?;
131 return Ok(Self::with_tokens(tokens));
132 }
133
134 if let Some(secs) = budget_str.strip_suffix('s') {
136 let seconds: u64 = secs.parse().map_err(|_| BudgetParseError::InvalidTime)?;
137 return Ok(Self::with_time(Duration::from_secs(seconds)));
138 }
139
140 if let Some(mins) = budget_str.strip_suffix('m') {
141 let minutes: u64 = mins.parse().map_err(|_| BudgetParseError::InvalidTime)?;
142 return Ok(Self::with_time(Duration::from_secs(minutes * 60)));
143 }
144
145 if let Some(hours) = budget_str.strip_suffix('h') {
146 let hours_val: u64 = hours.parse().map_err(|_| BudgetParseError::InvalidTime)?;
147 return Ok(Self::with_time(Duration::from_secs(hours_val * 3600)));
148 }
149
150 Err(BudgetParseError::UnknownFormat(budget_str.to_string()))
151 }
152}
153
/// Reasons a budget string can fail to parse.
#[derive(Debug, Clone)]
pub enum BudgetParseError {
    /// The input was empty after trimming whitespace.
    Empty,

    /// A time suffix (`s`, `m`, `h`) was present but the number was invalid.
    InvalidTime,

    /// A token suffix (`t`, `tokens`) was present but the count was invalid.
    InvalidTokens,

    /// A `$` prefix was present but the amount was invalid.
    InvalidCost,

    /// No recognised prefix or suffix; carries the trimmed input.
    UnknownFormat(String),
}

impl std::fmt::Display for BudgetParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Empty => "Budget string is empty",
            Self::InvalidTime => "Invalid time format (use Xs, Xm, or Xh)",
            Self::InvalidTokens => "Invalid token count (use Xt or Xtokens)",
            Self::InvalidCost => "Invalid cost format (use $X.XX)",
            Self::UnknownFormat(input) => {
                f.write_str("Unknown budget format: ")?;
                return f.write_str(input);
            }
        };
        f.write_str(message)
    }
}

/// Lets the error participate in `Box<dyn Error>` / `?` chains; the message
/// comes from the `Display` impl.
impl std::error::Error for BudgetParseError {}
182
183#[derive(Debug, Clone)]
185pub struct BudgetTracker {
186 config: BudgetConfig,
188
189 start_time: Instant,
191
192 tokens_used: u32,
194
195 cost_incurred: f64,
197
198 steps_completed: usize,
200
201 steps_skipped: usize,
203}
204
205impl BudgetTracker {
206 pub fn new(config: BudgetConfig) -> Self {
208 Self {
209 config,
210 start_time: Instant::now(),
211 tokens_used: 0,
212 cost_incurred: 0.0,
213 steps_completed: 0,
214 steps_skipped: 0,
215 }
216 }
217
218 pub fn record_usage(&mut self, tokens: u32, cost: f64) {
220 self.tokens_used += tokens;
221 self.cost_incurred += cost;
222 self.steps_completed += 1;
223 }
224
225 pub fn record_skip(&mut self) {
227 self.steps_skipped += 1;
228 }
229
230 pub fn elapsed(&self) -> Duration {
232 self.start_time.elapsed()
233 }
234
235 pub fn time_remaining(&self) -> Option<Duration> {
237 self.config
238 .time_limit
239 .map(|limit| limit.saturating_sub(self.elapsed()))
240 }
241
242 pub fn tokens_remaining(&self) -> Option<u32> {
244 self.config
245 .token_limit
246 .map(|limit| limit.saturating_sub(self.tokens_used))
247 }
248
249 pub fn cost_remaining(&self) -> Option<f64> {
251 self.config
252 .cost_limit
253 .map(|limit| (limit - self.cost_incurred).max(0.0))
254 }
255
256 pub fn is_exhausted(&self) -> bool {
258 if let Some(remaining) = self.time_remaining() {
259 if remaining.is_zero() {
260 return true;
261 }
262 }
263 if let Some(remaining) = self.tokens_remaining() {
264 if remaining == 0 {
265 return true;
266 }
267 }
268 if let Some(remaining) = self.cost_remaining() {
269 if remaining <= 0.0 {
270 return true;
271 }
272 }
273 false
274 }
275
276 pub fn usage_ratio(&self) -> f64 {
278 let mut max_ratio = 0.0f64;
279
280 if let Some(limit) = self.config.time_limit {
281 let ratio = self.elapsed().as_secs_f64() / limit.as_secs_f64();
282 max_ratio = max_ratio.max(ratio);
283 }
284
285 if let Some(limit) = self.config.token_limit {
286 let ratio = self.tokens_used as f64 / limit as f64;
287 max_ratio = max_ratio.max(ratio);
288 }
289
290 if let Some(limit) = self.config.cost_limit {
291 let ratio = self.cost_incurred / limit;
292 max_ratio = max_ratio.max(ratio);
293 }
294
295 max_ratio.min(1.0)
296 }
297
298 pub fn should_adapt(&self) -> bool {
300 self.usage_ratio() >= self.config.adapt_threshold
301 }
302
303 pub fn adaptive_max_tokens(&self, requested: u32) -> u32 {
305 if let Some(remaining) = self.tokens_remaining() {
306 let reserve = remaining / 4;
308 return requested.min(remaining - reserve);
309 }
310 requested
311 }
312
313 pub fn should_skip_step(&self, is_essential: bool) -> bool {
315 if is_essential {
316 return false;
317 }
318
319 match self.config.strategy {
320 BudgetStrategy::Strict => self.is_exhausted(),
321 BudgetStrategy::Adaptive => self.usage_ratio() > 0.9,
322 BudgetStrategy::BestEffort => false,
323 }
324 }
325
326 pub fn summary(&self) -> BudgetSummary {
328 BudgetSummary {
329 elapsed: self.elapsed(),
330 tokens_used: self.tokens_used,
331 cost_incurred: self.cost_incurred,
332 steps_completed: self.steps_completed,
333 steps_skipped: self.steps_skipped,
334 usage_ratio: self.usage_ratio(),
335 exhausted: self.is_exhausted(),
336 }
337 }
338}
339
340#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct BudgetSummary {
343 #[serde(with = "duration_serde")]
345 pub elapsed: Duration,
346
347 pub tokens_used: u32,
349
350 pub cost_incurred: f64,
352
353 pub steps_completed: usize,
355
356 pub steps_skipped: usize,
358
359 pub usage_ratio: f64,
361
362 pub exhausted: bool,
364}
365
366mod duration_serde {
367 use serde::{Deserialize, Deserializer, Serialize, Serializer};
368 use std::time::Duration;
369
370 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
371 where
372 S: Serializer,
373 {
374 duration.as_millis().serialize(serializer)
375 }
376
377 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
378 where
379 D: Deserializer<'de>,
380 {
381 let millis = u64::deserialize(deserializer)?;
382 Ok(Duration::from_millis(millis))
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_time_seconds() {
        let parsed = BudgetConfig::parse("30s").unwrap();
        assert_eq!(Some(Duration::from_secs(30)), parsed.time_limit);
    }

    #[test]
    fn test_parse_time_minutes() {
        let five_minutes = Duration::from_secs(5 * 60);
        assert_eq!(BudgetConfig::parse("5m").unwrap().time_limit, Some(five_minutes));
    }

    #[test]
    fn test_parse_tokens() {
        let parsed = BudgetConfig::parse("1000t").unwrap();
        assert_eq!(Some(1000), parsed.token_limit);
    }

    #[test]
    fn test_parse_tokens_full() {
        let parsed = BudgetConfig::parse("5000tokens").unwrap();
        assert_eq!(Some(5000), parsed.token_limit);
    }

    #[test]
    fn test_parse_cost() {
        let parsed = BudgetConfig::parse("$0.50").unwrap();
        assert_eq!(Some(0.5), parsed.cost_limit);
    }

    #[test]
    fn test_budget_tracker_usage() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        tracker.record_usage(200, 0.01);
        assert_eq!(Some(800), tracker.tokens_remaining());
        assert!(!tracker.is_exhausted());

        tracker.record_usage(800, 0.04);
        assert_eq!(Some(0), tracker.tokens_remaining());
        assert!(tracker.is_exhausted());
    }

    #[test]
    fn test_budget_tracker_adapt() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        // 60% used: below the 0.7 default threshold.
        tracker.record_usage(600, 0.03);
        assert!(!tracker.should_adapt());

        // 75% used: at or above the threshold.
        tracker.record_usage(150, 0.01);
        assert!(tracker.should_adapt());
    }

    #[test]
    fn test_adaptive_max_tokens() {
        let mut tracker = BudgetTracker::new(BudgetConfig::with_tokens(1000));

        assert_eq!(500, tracker.adaptive_max_tokens(500));

        // Only 200 tokens remain, part of which is held in reserve.
        tracker.record_usage(800, 0.04);
        assert!(tracker.adaptive_max_tokens(500) < 200);
    }
}