1use crate::time::Instant;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
16use std::time::Duration;
17
18tokio::task_local! {
24 pub static BUDGET_TRACKER: Arc<BudgetTracker>;
25}
26
27#[derive(Clone, Default)]
32pub enum BudgetExceededAction {
33 #[default]
35 Terminate,
36
37 Interrupt,
39
40 Custom(std::sync::Arc<dyn Fn(BudgetUsage) -> BudgetExceededAction + Send + Sync>),
42}
43
44impl std::fmt::Debug for BudgetExceededAction {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Self::Terminate => write!(f, "Terminate"),
48 Self::Interrupt => write!(f, "Interrupt"),
49 Self::Custom(_) => write!(f, "Custom(<fn>)"),
50 }
51 }
52}
53
54#[derive(Clone, Default)]
58pub struct BudgetConfig {
59 pub max_tokens: Option<u64>,
61
62 pub max_cost_usd: Option<f64>,
64
65 pub max_duration: Option<Duration>,
67
68 pub max_steps: Option<usize>,
70
71 pub on_exceeded: BudgetExceededAction,
73}
74
75impl std::fmt::Debug for BudgetConfig {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("BudgetConfig")
78 .field("max_tokens", &self.max_tokens)
79 .field("max_cost_usd", &self.max_cost_usd)
80 .field("max_duration", &self.max_duration)
81 .field("max_steps", &self.max_steps)
82 .field("on_exceeded", &self.on_exceeded)
83 .finish()
84 }
85}
86
87impl BudgetConfig {
88 #[must_use]
90 pub fn new() -> Self {
91 Self {
92 max_tokens: None,
93 max_cost_usd: None,
94 max_duration: None,
95 max_steps: None,
96 on_exceeded: BudgetExceededAction::default(),
97 }
98 }
99
100 #[must_use]
102 pub const fn with_max_tokens(mut self, tokens: u64) -> Self {
103 self.max_tokens = Some(tokens);
104 self
105 }
106
107 #[must_use]
109 pub const fn with_max_cost_usd(mut self, cost: f64) -> Self {
110 self.max_cost_usd = Some(cost);
111 self
112 }
113
114 #[must_use]
116 pub const fn with_max_duration(mut self, duration: Duration) -> Self {
117 self.max_duration = Some(duration);
118 self
119 }
120
121 #[must_use]
123 pub const fn with_max_steps(mut self, steps: usize) -> Self {
124 self.max_steps = Some(steps);
125 self
126 }
127
128 #[must_use]
130 pub const fn has_limits(&self) -> bool {
131 self.max_tokens.is_some()
132 || self.max_cost_usd.is_some()
133 || self.max_duration.is_some()
134 || self.max_steps.is_some()
135 }
136}
137
138pub struct BudgetTracker {
142 tokens_used: AtomicU64,
144
145 cost_usd_micros: AtomicU64,
147
148 start_time: Instant,
150
151 steps_completed: AtomicUsize,
153
154 config: BudgetConfig,
156
157 metrics_collector: Option<std::sync::Arc<dyn crate::observability::MetricsCollector>>,
159}
160
161impl std::fmt::Debug for BudgetTracker {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("BudgetTracker")
164 .field("tokens_used", &self.tokens_used)
165 .field("cost_usd_micros", &self.cost_usd_micros)
166 .field("start_time", &self.start_time)
167 .field("steps_completed", &self.steps_completed)
168 .field("config", &self.config)
169 .field(
170 "metrics_collector",
171 &self.metrics_collector.as_ref().map(|_| "<Arc>"),
172 )
173 .finish()
174 }
175}
176
177impl BudgetTracker {
178 #[must_use]
192 pub fn new(config: BudgetConfig) -> Self {
193 Self {
194 tokens_used: AtomicU64::new(0),
195 cost_usd_micros: AtomicU64::new(0),
196 start_time: Instant::now(),
197 steps_completed: AtomicUsize::new(0),
198 config,
199 metrics_collector: None,
200 }
201 }
202
203 #[must_use]
205 pub fn with_metrics_collector(
206 mut self,
207 collector: Option<std::sync::Arc<dyn crate::observability::MetricsCollector>>,
208 ) -> Self {
209 self.metrics_collector = collector;
210 self
211 }
212
213 pub fn report_tokens(&self, tokens: u64) {
225 self.tokens_used.fetch_add(tokens, Ordering::Relaxed);
226
227 if let Some(ref collector) = self.metrics_collector {
229 collector.inc_counter("juncture.llm.tokens.input", tokens);
230 }
231 }
232
233 pub fn report_output_tokens(&self, tokens: u64) {
238 if let Some(ref collector) = self.metrics_collector {
240 collector.inc_counter("juncture.llm.tokens.output", tokens);
241 }
242 }
243
244 #[allow(
259 clippy::cast_sign_loss,
260 clippy::cast_possible_truncation,
261 reason = "cost values are expected to be positive and within reasonable bounds"
262 )]
263 pub fn report_cost(&self, cost_usd: f64) {
264 let cost_micros = (cost_usd * 1_000_000.0) as u64;
266 self.cost_usd_micros
267 .fetch_add(cost_micros, Ordering::Relaxed);
268
269 if let Some(ref collector) = self.metrics_collector {
271 collector.inc_counter("juncture.llm.cost_usd", cost_micros);
272 }
273 }
274
275 pub fn report_step(&self) {
288 self.steps_completed.fetch_add(1, Ordering::Relaxed);
289 }
290
291 pub fn report_llm_call(&self) {
295 if let Some(ref collector) = self.metrics_collector {
296 collector.inc_counter("juncture.llm.calls", 1);
297 }
298 }
299
300 #[allow(
304 clippy::cast_precision_loss,
305 reason = "milliseconds as f64 is sufficient for histogram metrics"
306 )]
307 pub fn report_llm_duration(&self, duration_ms: u64) {
308 if let Some(ref collector) = self.metrics_collector {
309 collector.record_histogram("juncture.llm.duration_ms", duration_ms as f64);
310 }
311 }
312
313 pub fn report_tool_call(&self) {
317 if let Some(ref collector) = self.metrics_collector {
318 collector.inc_counter("juncture.tool.calls", 1);
319 }
320 }
321
322 pub fn report_tool_error(&self) {
326 if let Some(ref collector) = self.metrics_collector {
327 collector.inc_counter("juncture.tool.errors", 1);
328 }
329 }
330
331 #[allow(
335 clippy::cast_precision_loss,
336 reason = "milliseconds as f64 is sufficient for histogram metrics"
337 )]
338 pub fn report_tool_duration(&self, duration_ms: u64) {
339 if let Some(ref collector) = self.metrics_collector {
340 collector.record_histogram("juncture.tool.duration_ms", duration_ms as f64);
341 }
342 }
343
344 pub fn report_usage(&self, tokens: u64, cost_usd: f64) {
357 self.report_tokens(tokens);
358 self.report_cost(cost_usd);
359 }
360
361 pub fn report_model_call(&self, input_tokens: u64, output_tokens: u64) {
377 self.tokens_used
378 .fetch_add(input_tokens + output_tokens, Ordering::Relaxed);
379 }
380
381 #[must_use]
399 pub fn check(&self) -> Option<BudgetExceededReason> {
400 if let Some(max_tokens) = self.config.max_tokens
402 && self.tokens_used.load(Ordering::Relaxed) > max_tokens
403 {
404 return Some(BudgetExceededReason::Tokens {
405 used: self.tokens_used.load(Ordering::Relaxed),
406 limit: max_tokens,
407 });
408 }
409
410 if let Some(max_cost) = self.config.max_cost_usd {
412 #[allow(
413 clippy::cast_precision_loss,
414 reason = "precision loss is acceptable for cost comparison"
415 )]
416 let cost_micros = self.cost_usd_micros.load(Ordering::Relaxed);
417 #[allow(
418 clippy::cast_precision_loss,
419 reason = "precision loss is acceptable for cost comparison"
420 )]
421 let cost_usd = cost_micros as f64 / 1_000_000.0;
422 if cost_usd > max_cost {
423 return Some(BudgetExceededReason::Cost {
424 used: cost_usd,
425 limit: max_cost,
426 });
427 }
428 }
429
430 if let Some(max_duration) = self.config.max_duration
432 && self.start_time.elapsed() > max_duration
433 {
434 return Some(BudgetExceededReason::Duration {
435 used: self.start_time.elapsed(),
436 limit: max_duration,
437 });
438 }
439
440 if let Some(max_steps) = self.config.max_steps
442 && self.steps_completed.load(Ordering::Relaxed) > max_steps
443 {
444 return Some(BudgetExceededReason::Steps {
445 used: self.steps_completed.load(Ordering::Relaxed),
446 limit: max_steps,
447 });
448 }
449
450 None
451 }
452
453 #[must_use]
472 pub fn current_usage(&self) -> BudgetUsage {
473 let cost_micros = self.cost_usd_micros.load(Ordering::Relaxed);
474 #[allow(
475 clippy::cast_precision_loss,
476 reason = "precision loss is acceptable for cost display"
477 )]
478 BudgetUsage {
479 tokens_used: self.tokens_used.load(Ordering::Relaxed),
480 cost_usd: cost_micros as f64 / 1_000_000.0,
481 duration: self.start_time.elapsed(),
482 steps_completed: self.steps_completed.load(Ordering::Relaxed),
483 }
484 }
485}
486
/// Which budget limit was exceeded, with the observed usage and the configured limit.
///
/// When produced by `BudgetTracker::check`, `used` is strictly greater than `limit`.
#[derive(Debug, Clone)]
pub enum BudgetExceededReason {
    /// Token count exceeded `BudgetConfig::max_tokens`.
    Tokens { used: u64, limit: u64 },

    /// Spend in US dollars exceeded `BudgetConfig::max_cost_usd`.
    Cost { used: f64, limit: f64 },

    /// Elapsed time exceeded `BudgetConfig::max_duration`.
    Duration { used: Duration, limit: Duration },

    /// Completed steps exceeded `BudgetConfig::max_steps`.
    Steps { used: usize, limit: usize },
}

impl std::fmt::Display for BudgetExceededReason {
    /// Renders a one-line human-readable message: cost with six decimals, durations
    /// in their `Debug` form (e.g. `150ms`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Steps { used, limit } => write!(f, "Step budget exceeded: {used} > {limit}"),
            Self::Duration { used, limit } => {
                write!(f, "Duration budget exceeded: {used:?} > {limit:?}")
            }
            Self::Cost { used, limit } => {
                write!(f, "Cost budget exceeded: ${used:.6} > ${limit:.6}")
            }
            Self::Tokens { used, limit } => write!(f, "Token budget exceeded: {used} > {limit}"),
        }
    }
}
521
/// Point-in-time snapshot of consumed resources, as built by `BudgetTracker::current_usage`.
#[derive(Debug, Clone)]
pub struct BudgetUsage {
    /// Tokens counted against the token budget.
    pub tokens_used: u64,

    /// Accumulated spend in US dollars.
    pub cost_usd: f64,

    /// Time elapsed since the tracker was created.
    pub duration: Duration,

    /// Number of steps reported as completed.
    pub steps_completed: usize,
}
537
538pub fn try_report_model_call(
564 input_tokens: u64,
565 output_tokens: u64,
566) -> Result<(), BudgetReportError> {
567 BUDGET_TRACKER
568 .try_with(|tracker| {
569 tracker.report_model_call(input_tokens, output_tokens);
570 })
571 .map_err(|_err| BudgetReportError::NoTracker)
572}
573
574pub fn try_report_llm_call() -> Result<(), BudgetReportError> {
583 BUDGET_TRACKER
584 .try_with(|tracker| {
585 tracker.report_llm_call();
586 })
587 .map_err(|_err| BudgetReportError::NoTracker)
588}
589
590pub fn try_report_llm_duration(duration_ms: u64) -> Result<(), BudgetReportError> {
598 BUDGET_TRACKER
599 .try_with(|tracker| {
600 tracker.report_llm_duration(duration_ms);
601 })
602 .map_err(|_err| BudgetReportError::NoTracker)
603}
604
605pub fn try_report_tool_call() -> Result<(), BudgetReportError> {
613 BUDGET_TRACKER
614 .try_with(|tracker| {
615 tracker.report_tool_call();
616 })
617 .map_err(|_err| BudgetReportError::NoTracker)
618}
619
620pub fn try_report_tool_error() -> Result<(), BudgetReportError> {
628 BUDGET_TRACKER
629 .try_with(|tracker| {
630 tracker.report_tool_error();
631 })
632 .map_err(|_err| BudgetReportError::NoTracker)
633}
634
635pub fn try_report_tool_duration(duration_ms: u64) -> Result<(), BudgetReportError> {
643 BUDGET_TRACKER
644 .try_with(|tracker| {
645 tracker.report_tool_duration(duration_ms);
646 })
647 .map_err(|_err| BudgetReportError::NoTracker)
648}
649
/// Error returned by the `try_report_*` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetReportError {
    /// No `BudgetTracker` was available for the current task, so nothing was recorded.
    NoTracker,
}

impl std::fmt::Display for BudgetReportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoTracker => "Cannot report budget usage: no budget tracker in current context",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BudgetReportError {}
675
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_config_no_limits() {
        assert!(!BudgetConfig::new().has_limits());
    }

    #[test]
    fn test_budget_config_with_limits() {
        let limited = BudgetConfig::new().with_max_tokens(1000).with_max_steps(10);
        assert!(limited.has_limits());
    }

    #[test]
    fn test_budget_tracker_tokens() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_tokens(100));

        // 50 of 100 tokens: still within budget.
        tracker.report_tokens(50);
        assert!(tracker.check().is_none());

        // 110 of 100 tokens: over budget.
        tracker.report_tokens(60);
        assert!(tracker.check().is_some());

        assert_eq!(tracker.current_usage().tokens_used, 110);
    }

    #[test]
    fn test_budget_tracker_cost() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_cost_usd(0.01));

        tracker.report_cost(0.005);
        assert!(tracker.check().is_none());

        tracker.report_cost(0.006);
        assert!(tracker.check().is_some());

        let spent = tracker.current_usage().cost_usd;
        assert!((spent - 0.011).abs() < 0.0001);
    }

    #[test]
    fn test_budget_tracker_steps() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_steps(5));

        // Reaching the limit exactly is allowed; only exceeding it trips the check.
        (0..5).for_each(|_| tracker.report_step());
        assert!(tracker.check().is_none());

        tracker.report_step();
        assert!(tracker.check().is_some());

        assert_eq!(tracker.current_usage().steps_completed, 6);
    }

    #[test]
    fn test_budget_tracker_model_call() {
        let tracker = BudgetTracker::new(BudgetConfig::new());
        let tokens_of = |t: &BudgetTracker| t.current_usage().tokens_used;

        assert_eq!(tokens_of(&tracker), 0);

        // Input and output tokens both count toward the total.
        tracker.report_model_call(50, 100);
        assert_eq!(tokens_of(&tracker), 150);

        tracker.report_model_call(10, 20);
        assert_eq!(tokens_of(&tracker), 180);
    }

    #[test]
    fn test_budget_tracker_model_call_exceeds_limit() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_tokens(100));

        assert!(tracker.check().is_none());
        // 60 + 50 = 110 > 100.
        tracker.report_model_call(60, 50);
        assert!(tracker.check().is_some());
    }

    #[test]
    fn test_budget_tracker_duration() {
        let tracker = BudgetTracker::new(
            BudgetConfig::new().with_max_duration(Duration::from_millis(100)),
        );

        assert!(tracker.check().is_none());
        // Sleep past the 100ms limit with some slack.
        std::thread::sleep(Duration::from_millis(150));
        assert!(tracker.check().is_some());
    }

    #[test]
    fn test_budget_exceeded_reason_display() {
        let tokens = BudgetExceededReason::Tokens {
            used: 150,
            limit: 100,
        };
        assert!(tokens.to_string().contains("Token budget exceeded"));

        let steps = BudgetExceededReason::Steps { used: 10, limit: 5 };
        assert!(steps.to_string().contains("Step budget exceeded"));
    }
}
784
785