use crate::time::Instant;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
tokio::task_local! {
/// Task-local handle to the active [`BudgetTracker`].
///
/// The `try_report_*` free functions in this module read it through
/// `try_with` and map an access failure to `BudgetReportError::NoTracker`.
pub static BUDGET_TRACKER: Arc<BudgetTracker>;
}
#[derive(Clone, Default)]
pub enum BudgetExceededAction {
#[default]
Terminate,
Interrupt,
Custom(std::sync::Arc<dyn Fn(BudgetUsage) -> BudgetExceededAction + Send + Sync>),
}
/// Manual `Debug`: the `Custom` payload is a closure and cannot be printed,
/// so it is rendered as a fixed placeholder.
impl std::fmt::Debug for BudgetExceededAction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Terminate => "Terminate",
            Self::Interrupt => "Interrupt",
            Self::Custom(_) => "Custom(<fn>)",
        };
        f.write_str(label)
    }
}
/// Limits enforced by `BudgetTracker::check`, plus the action to take when
/// one of them is exceeded.
///
/// Every limit is optional; a `None` limit is simply not checked. Limits are
/// exclusive: usage must be strictly greater than the limit to count as
/// exceeded.
#[derive(Clone, Default)]
pub struct BudgetConfig {
/// Maximum number of tokens.
pub max_tokens: Option<u64>,
/// Maximum cost in US dollars.
pub max_cost_usd: Option<f64>,
/// Maximum time elapsed since the tracker was created.
pub max_duration: Option<Duration>,
/// Maximum number of reported steps.
pub max_steps: Option<usize>,
/// Action associated with an exceeded budget (defaults to `Terminate`).
pub on_exceeded: BudgetExceededAction,
}
/// `Debug` output listing every field in declaration order.
impl std::fmt::Debug for BudgetConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this impl becomes a compile error.
        let Self {
            max_tokens,
            max_cost_usd,
            max_duration,
            max_steps,
            on_exceeded,
        } = self;
        f.debug_struct("BudgetConfig")
            .field("max_tokens", max_tokens)
            .field("max_cost_usd", max_cost_usd)
            .field("max_duration", max_duration)
            .field("max_steps", max_steps)
            .field("on_exceeded", on_exceeded)
            .finish()
    }
}
impl BudgetConfig {
#[must_use]
pub fn new() -> Self {
Self {
max_tokens: None,
max_cost_usd: None,
max_duration: None,
max_steps: None,
on_exceeded: BudgetExceededAction::default(),
}
}
#[must_use]
pub const fn with_max_tokens(mut self, tokens: u64) -> Self {
self.max_tokens = Some(tokens);
self
}
#[must_use]
pub const fn with_max_cost_usd(mut self, cost: f64) -> Self {
self.max_cost_usd = Some(cost);
self
}
#[must_use]
pub const fn with_max_duration(mut self, duration: Duration) -> Self {
self.max_duration = Some(duration);
self
}
#[must_use]
pub const fn with_max_steps(mut self, steps: usize) -> Self {
self.max_steps = Some(steps);
self
}
#[must_use]
pub const fn has_limits(&self) -> bool {
self.max_tokens.is_some()
|| self.max_cost_usd.is_some()
|| self.max_duration.is_some()
|| self.max_steps.is_some()
}
}
pub struct BudgetTracker {
tokens_used: AtomicU64,
cost_usd_micros: AtomicU64,
start_time: Instant,
steps_completed: AtomicUsize,
config: BudgetConfig,
metrics_collector: Option<std::sync::Arc<dyn crate::observability::MetricsCollector>>,
}
/// Manual `Debug`: the metrics collector is a trait object, so only its
/// presence is shown (as `Some("<Arc>")` or `None`).
impl std::fmt::Debug for BudgetTracker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            tokens_used,
            cost_usd_micros,
            start_time,
            steps_completed,
            config,
            metrics_collector,
        } = self;
        let collector_marker = metrics_collector.is_some().then_some("<Arc>");
        f.debug_struct("BudgetTracker")
            .field("tokens_used", tokens_used)
            .field("cost_usd_micros", cost_usd_micros)
            .field("start_time", start_time)
            .field("steps_completed", steps_completed)
            .field("config", config)
            .field("metrics_collector", &collector_marker)
            .finish()
    }
}
impl BudgetTracker {
#[must_use]
pub fn new(config: BudgetConfig) -> Self {
Self {
tokens_used: AtomicU64::new(0),
cost_usd_micros: AtomicU64::new(0),
start_time: Instant::now(),
steps_completed: AtomicUsize::new(0),
config,
metrics_collector: None,
}
}
#[must_use]
pub fn with_metrics_collector(
mut self,
collector: Option<std::sync::Arc<dyn crate::observability::MetricsCollector>>,
) -> Self {
self.metrics_collector = collector;
self
}
pub fn report_tokens(&self, tokens: u64) {
self.tokens_used.fetch_add(tokens, Ordering::Relaxed);
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.llm.tokens.input", tokens);
}
}
pub fn report_output_tokens(&self, tokens: u64) {
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.llm.tokens.output", tokens);
}
}
#[allow(
clippy::cast_sign_loss,
clippy::cast_possible_truncation,
reason = "cost values are expected to be positive and within reasonable bounds"
)]
pub fn report_cost(&self, cost_usd: f64) {
let cost_micros = (cost_usd * 1_000_000.0) as u64;
self.cost_usd_micros
.fetch_add(cost_micros, Ordering::Relaxed);
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.llm.cost_usd", cost_micros);
}
}
pub fn report_step(&self) {
self.steps_completed.fetch_add(1, Ordering::Relaxed);
}
pub fn report_llm_call(&self) {
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.llm.calls", 1);
}
}
#[allow(
clippy::cast_precision_loss,
reason = "milliseconds as f64 is sufficient for histogram metrics"
)]
pub fn report_llm_duration(&self, duration_ms: u64) {
if let Some(ref collector) = self.metrics_collector {
collector.record_histogram("juncture.llm.duration_ms", duration_ms as f64);
}
}
pub fn report_tool_call(&self) {
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.tool.calls", 1);
}
}
pub fn report_tool_error(&self) {
if let Some(ref collector) = self.metrics_collector {
collector.inc_counter("juncture.tool.errors", 1);
}
}
#[allow(
clippy::cast_precision_loss,
reason = "milliseconds as f64 is sufficient for histogram metrics"
)]
pub fn report_tool_duration(&self, duration_ms: u64) {
if let Some(ref collector) = self.metrics_collector {
collector.record_histogram("juncture.tool.duration_ms", duration_ms as f64);
}
}
pub fn report_usage(&self, tokens: u64, cost_usd: f64) {
self.report_tokens(tokens);
self.report_cost(cost_usd);
}
pub fn report_model_call(&self, input_tokens: u64, output_tokens: u64) {
self.tokens_used
.fetch_add(input_tokens + output_tokens, Ordering::Relaxed);
}
#[must_use]
pub fn check(&self) -> Option<BudgetExceededReason> {
if let Some(max_tokens) = self.config.max_tokens
&& self.tokens_used.load(Ordering::Relaxed) > max_tokens
{
return Some(BudgetExceededReason::Tokens {
used: self.tokens_used.load(Ordering::Relaxed),
limit: max_tokens,
});
}
if let Some(max_cost) = self.config.max_cost_usd {
#[allow(
clippy::cast_precision_loss,
reason = "precision loss is acceptable for cost comparison"
)]
let cost_micros = self.cost_usd_micros.load(Ordering::Relaxed);
#[allow(
clippy::cast_precision_loss,
reason = "precision loss is acceptable for cost comparison"
)]
let cost_usd = cost_micros as f64 / 1_000_000.0;
if cost_usd > max_cost {
return Some(BudgetExceededReason::Cost {
used: cost_usd,
limit: max_cost,
});
}
}
if let Some(max_duration) = self.config.max_duration
&& self.start_time.elapsed() > max_duration
{
return Some(BudgetExceededReason::Duration {
used: self.start_time.elapsed(),
limit: max_duration,
});
}
if let Some(max_steps) = self.config.max_steps
&& self.steps_completed.load(Ordering::Relaxed) > max_steps
{
return Some(BudgetExceededReason::Steps {
used: self.steps_completed.load(Ordering::Relaxed),
limit: max_steps,
});
}
None
}
#[must_use]
pub fn current_usage(&self) -> BudgetUsage {
let cost_micros = self.cost_usd_micros.load(Ordering::Relaxed);
#[allow(
clippy::cast_precision_loss,
reason = "precision loss is acceptable for cost display"
)]
BudgetUsage {
tokens_used: self.tokens_used.load(Ordering::Relaxed),
cost_usd: cost_micros as f64 / 1_000_000.0,
duration: self.start_time.elapsed(),
steps_completed: self.steps_completed.load(Ordering::Relaxed),
}
}
}
/// Which limit was exceeded, with the usage observed and the configured limit.
#[derive(Clone, Debug)]
pub enum BudgetExceededReason {
    /// Token usage went over the token limit.
    Tokens { used: u64, limit: u64 },
    /// Cost in US dollars went over the cost limit.
    Cost { used: f64, limit: f64 },
    /// Elapsed time went over the duration limit.
    Duration { used: Duration, limit: Duration },
    /// Step count went over the step limit.
    Steps { used: usize, limit: usize },
}

/// Human-readable `"<kind> budget exceeded: <used> > <limit>"` message; cost
/// is printed with six decimals, durations with their `Debug` form.
impl std::fmt::Display for BudgetExceededReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload is `Copy`, so the variants can be matched by value.
        match *self {
            Self::Tokens { used, limit } => {
                f.write_fmt(format_args!("Token budget exceeded: {used} > {limit}"))
            }
            Self::Cost { used, limit } => f.write_fmt(format_args!(
                "Cost budget exceeded: ${used:.6} > ${limit:.6}"
            )),
            Self::Duration { used, limit } => f.write_fmt(format_args!(
                "Duration budget exceeded: {used:?} > {limit:?}"
            )),
            Self::Steps { used, limit } => {
                f.write_fmt(format_args!("Step budget exceeded: {used} > {limit}"))
            }
        }
    }
}
/// Snapshot of accumulated usage, as produced by
/// `BudgetTracker::current_usage`.
#[derive(Clone, Debug)]
pub struct BudgetUsage {
/// Tokens counted so far.
pub tokens_used: u64,
/// Accumulated cost in US dollars.
pub cost_usd: f64,
/// Time elapsed since the tracker was created.
pub duration: Duration,
/// Steps reported so far.
pub steps_completed: usize,
}
/// Forwards to `BudgetTracker::report_model_call` on the task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_model_call(
    input_tokens: u64,
    output_tokens: u64,
) -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_model_call(input_tokens, output_tokens))
        .or(Err(BudgetReportError::NoTracker))
}
/// Forwards to `BudgetTracker::report_llm_call` on the task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_llm_call() -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_llm_call())
        .or(Err(BudgetReportError::NoTracker))
}
/// Forwards `duration_ms` to `BudgetTracker::report_llm_duration` on the
/// task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_llm_duration(duration_ms: u64) -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_llm_duration(duration_ms))
        .or(Err(BudgetReportError::NoTracker))
}
/// Forwards to `BudgetTracker::report_tool_call` on the task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_tool_call() -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_tool_call())
        .or(Err(BudgetReportError::NoTracker))
}
/// Forwards to `BudgetTracker::report_tool_error` on the task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_tool_error() -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_tool_error())
        .or(Err(BudgetReportError::NoTracker))
}
/// Forwards `duration_ms` to `BudgetTracker::report_tool_duration` on the
/// task-local tracker.
///
/// # Errors
///
/// Returns [`BudgetReportError::NoTracker`] when `BUDGET_TRACKER` cannot be
/// accessed from the current context.
pub fn try_report_tool_duration(duration_ms: u64) -> Result<(), BudgetReportError> {
    BUDGET_TRACKER
        .try_with(|active| active.report_tool_duration(duration_ms))
        .or(Err(BudgetReportError::NoTracker))
}
/// Failure of one of the `try_report_*` helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetReportError {
    /// No budget tracker was accessible in the current context.
    NoTracker,
}

impl std::fmt::Display for BudgetReportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoTracker => "Cannot report budget usage: no budget tracker in current context",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BudgetReportError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A config straight out of `new()` carries no limits.
    #[test]
    fn test_budget_config_no_limits() {
        assert!(!BudgetConfig::new().has_limits());
    }

    /// Setting any limit flips `has_limits`.
    #[test]
    fn test_budget_config_with_limits() {
        let limited = BudgetConfig::new().with_max_steps(10).with_max_tokens(1000);
        assert!(limited.has_limits());
    }

    /// The token limit is exclusive and trips once usage goes above it.
    #[test]
    fn test_budget_tracker_tokens() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_tokens(100));

        tracker.report_tokens(50);
        assert!(tracker.check().is_none());

        tracker.report_tokens(60);
        assert!(tracker.check().is_some());
        assert_eq!(tracker.current_usage().tokens_used, 110);
    }

    /// Cost accumulates across reports and trips the cost limit.
    #[test]
    fn test_budget_tracker_cost() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_cost_usd(0.01));

        tracker.report_cost(0.005);
        assert!(tracker.check().is_none());

        tracker.report_cost(0.006);
        assert!(tracker.check().is_some());

        let total_cost = tracker.current_usage().cost_usd;
        assert!((total_cost - 0.011).abs() < 0.0001);
    }

    /// Exactly `max_steps` steps is still within budget; one more is not.
    #[test]
    fn test_budget_tracker_steps() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_steps(5));

        (0..5).for_each(|_| tracker.report_step());
        assert!(tracker.check().is_none());

        tracker.report_step();
        assert!(tracker.check().is_some());
        assert_eq!(tracker.current_usage().steps_completed, 6);
    }

    /// A model call counts input and output tokens together.
    #[test]
    fn test_budget_tracker_model_call() {
        let tracker = BudgetTracker::new(BudgetConfig::new());
        assert_eq!(tracker.current_usage().tokens_used, 0);

        tracker.report_model_call(50, 100);
        assert_eq!(tracker.current_usage().tokens_used, 150);

        tracker.report_model_call(10, 20);
        assert_eq!(tracker.current_usage().tokens_used, 180);
    }

    /// Tokens reported through a model call count against the token limit.
    #[test]
    fn test_budget_tracker_model_call_exceeds_limit() {
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_tokens(100));
        assert!(tracker.check().is_none());

        tracker.report_model_call(60, 50);
        assert!(tracker.check().is_some());
    }

    /// The duration limit trips once enough wall-clock time has passed.
    #[test]
    fn test_budget_tracker_duration() {
        let limit = Duration::from_millis(100);
        let tracker = BudgetTracker::new(BudgetConfig::new().with_max_duration(limit));
        assert!(tracker.check().is_none());

        std::thread::sleep(Duration::from_millis(150));
        assert!(tracker.check().is_some());
    }

    /// `Display` names the kind of budget that was exceeded.
    #[test]
    fn test_budget_exceeded_reason_display() {
        let tokens_message = BudgetExceededReason::Tokens {
            used: 150,
            limit: 100,
        }
        .to_string();
        assert!(tokens_message.contains("Token budget exceeded"));

        let steps_message = BudgetExceededReason::Steps { used: 10, limit: 5 }.to_string();
        assert!(steps_message.contains("Step budget exceeded"));
    }
}