use crate::budget::{BudgetDecision, BudgetGuard};
use crate::llm::TokenUsage;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Plain, copyable pair of values describing token consumption against an
/// optional cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BudgetSnapshot {
    /// Number of tokens consumed.
    pub consumed_tokens: u64,
    /// Token cap, or `None` when no cap applies.
    pub limit_tokens: Option<u64>,
}

impl BudgetSnapshot {
    /// Returns how many tokens are left under the cap.
    ///
    /// Yields `None` when there is no cap, and `Some(0)` once consumption has
    /// reached or passed the cap (the subtraction saturates instead of
    /// underflowing).
    pub fn remaining_tokens(&self) -> Option<u64> {
        let cap = self.limit_tokens?;
        Some(cap.saturating_sub(self.consumed_tokens))
    }

    /// Returns `true` when a cap is set and consumption has reached or passed
    /// it.
    ///
    /// A snapshot without a cap is never exhausted.
    pub fn is_exhausted(&self) -> bool {
        match self.limit_tokens {
            Some(cap) => self.consumed_tokens >= cap,
            None => false,
        }
    }
}
/// Token budget for a workflow, optionally layered over another
/// [`BudgetGuard`].
///
/// Consumption lives in an atomic counter, so it can be read and updated
/// through a shared reference.
pub struct WorkflowBudget {
    /// Optional guard installed through [`WorkflowBudget::with_inner`].
    inner: Option<Arc<dyn BudgetGuard>>,
    /// Running total of tokens recorded against this budget.
    consumed_tokens: AtomicU64,
    /// Token cap; `None` means the budget is uncapped.
    limit_tokens: Option<u64>,
}

impl WorkflowBudget {
    /// Creates a budget with zero consumption, no inner guard, and the given
    /// cap.
    ///
    /// Passing `None` yields a budget that never reports exhaustion.
    pub fn new(limit_tokens: Option<u64>) -> Self {
        let consumed_tokens = AtomicU64::new(0);
        Self {
            limit_tokens,
            consumed_tokens,
            inner: None,
        }
    }

    /// Builder-style setter: returns the budget with `inner` installed,
    /// replacing any guard that was set earlier.
    pub fn with_inner(self, inner: Arc<dyn BudgetGuard>) -> Self {
        let mut budget = self;
        budget.inner = Some(inner);
        budget
    }

    /// Returns the number of tokens recorded so far.
    pub fn consumed_tokens(&self) -> u64 {
        self.consumed_tokens.load(Ordering::SeqCst)
    }

    /// Captures the current consumption together with the configured cap.
    pub fn snapshot(&self) -> BudgetSnapshot {
        let consumed_tokens = self.consumed_tokens();
        let limit_tokens = self.limit_tokens;
        BudgetSnapshot {
            consumed_tokens,
            limit_tokens,
        }
    }

    /// Returns `true` when a cap is set and current consumption has reached
    /// or passed it.
    pub fn is_exhausted(&self) -> bool {
        let current = self.snapshot();
        current.is_exhausted()
    }
}
// `BudgetGuard` implementation: the workflow-wide token cap is checked first,
// then the decision is handed to the optional inner guard.
#[async_trait]
impl BudgetGuard for WorkflowBudget {
    /// Gates an LLM call.
    ///
    /// Denies with resource `"workflow_tokens"` once the workflow cap has been
    /// reached. Otherwise returns the inner guard's decision, or `Allow` when
    /// no inner guard is installed.
    ///
    /// `estimated_prompt_tokens` is only forwarded to the inner guard; it is
    /// not counted against the workflow cap here.
    async fn check_before_llm(
        &self,
        session_id: &str,
        estimated_prompt_tokens: usize,
    ) -> BudgetDecision {
        // Read the counter once so the figures in the message are the ones the
        // decision was based on, even if usage is recorded in between.
        let snapshot = self.snapshot();
        if snapshot.is_exhausted() {
            return BudgetDecision::Deny {
                resource: "workflow_tokens".to_string(),
                reason: format!(
                    "workflow token budget exhausted ({} / {} tokens)",
                    snapshot.consumed_tokens,
                    // Exhaustion implies a cap is set; 0 is only a formatting fallback.
                    snapshot.limit_tokens.unwrap_or(0)
                ),
            };
        }
        match &self.inner {
            Some(inner) => {
                inner
                    .check_before_llm(session_id, estimated_prompt_tokens)
                    .await
            }
            None => BudgetDecision::Allow,
        }
    }

    /// Records the usage of a completed LLM call.
    ///
    /// Adds `usage.total_tokens` to the workflow counter, then forwards the
    /// same usage to the inner guard, if any. The counter saturates at
    /// `u64::MAX` instead of wrapping, so an oversized usage report cannot
    /// roll it over and make an exhausted budget look open again.
    async fn record_after_llm(&self, session_id: &str, usage: &TokenUsage) {
        let added = usage.total_tokens as u64;
        // The closure always returns `Some`, so `fetch_update` cannot fail and
        // its result carries no information.
        let _ = self
            .consumed_tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                Some(current.saturating_add(added))
            });
        if let Some(inner) = &self.inner {
            inner.record_after_llm(session_id, usage).await;
        }
    }

    /// Gates a tool call.
    ///
    /// Denies with resource `"workflow_tokens"` once the workflow cap has been
    /// reached. Otherwise returns the inner guard's decision, or `Allow` when
    /// no inner guard is installed.
    async fn check_before_tool(&self, session_id: &str, tool_name: &str) -> BudgetDecision {
        if self.is_exhausted() {
            return BudgetDecision::Deny {
                resource: "workflow_tokens".to_string(),
                reason: "workflow token budget exhausted".to_string(),
            };
        }
        match &self.inner {
            Some(inner) => inner.check_before_tool(session_id, tool_name).await,
            None => BudgetDecision::Allow,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Builds a `TokenUsage` whose only non-default field is `total_tokens`.
    fn usage(total: usize) -> TokenUsage {
        TokenUsage {
            total_tokens: total,
            ..Default::default()
        }
    }

    /// Usage accumulates across sessions, and once the cap is crossed further
    /// LLM calls are denied against the `workflow_tokens` resource.
    #[tokio::test]
    async fn accumulates_and_caps() {
        let budget = WorkflowBudget::new(Some(100));

        let before = budget.check_before_llm("s", 0).await;
        assert!(matches!(before, BudgetDecision::Allow));

        budget.record_after_llm("a", &usage(60)).await;
        budget.record_after_llm("b", &usage(50)).await;
        assert_eq!(budget.consumed_tokens(), 110);
        assert!(budget.is_exhausted());

        let after = budget.check_before_llm("s", 0).await;
        match after {
            BudgetDecision::Deny { resource, .. } => assert_eq!(resource, "workflow_tokens"),
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    /// Without a cap, even very large usage never exhausts the budget.
    #[tokio::test]
    async fn uncapped_never_denies() {
        let budget = WorkflowBudget::new(None);
        budget.record_after_llm("a", &usage(1_000_000)).await;
        assert!(!budget.is_exhausted());

        let decision = budget.check_before_llm("s", 0).await;
        assert!(matches!(decision, BudgetDecision::Allow));

        let snap = budget.snapshot();
        assert_eq!(snap.remaining_tokens(), None);
    }

    /// A snapshot reflects recorded usage and the tokens left under the cap.
    #[tokio::test]
    async fn snapshot_reports_remaining() {
        let budget = WorkflowBudget::new(Some(100));
        budget.record_after_llm("a", &usage(30)).await;

        let snap = budget.snapshot();
        assert_eq!(snap.consumed_tokens, 30);
        assert_eq!(snap.remaining_tokens(), Some(70));
        assert!(!snap.is_exhausted());
    }

    /// LLM checks and usage records reach the inner guard while the workflow
    /// counter is still updated.
    #[tokio::test]
    async fn delegates_to_inner_guard() {
        /// Inner guard that only counts how often it is called.
        #[derive(Default)]
        struct Counting {
            checks: AtomicUsize,
            records: AtomicUsize,
        }

        #[async_trait]
        impl BudgetGuard for Counting {
            async fn check_before_llm(&self, _: &str, _: usize) -> BudgetDecision {
                self.checks.fetch_add(1, Ordering::SeqCst);
                BudgetDecision::Allow
            }

            async fn record_after_llm(&self, _: &str, _: &TokenUsage) {
                self.records.fetch_add(1, Ordering::SeqCst);
            }
        }

        let counting = Arc::new(Counting::default());
        let budget = WorkflowBudget::new(Some(1000)).with_inner(counting.clone());

        budget.check_before_llm("s", 0).await;
        budget.record_after_llm("s", &usage(10)).await;

        let seen_checks = counting.checks.load(Ordering::SeqCst);
        let seen_records = counting.records.load(Ordering::SeqCst);
        assert_eq!(seen_checks, 1);
        assert_eq!(seen_records, 1);
        assert_eq!(budget.consumed_tokens(), 10);
    }
}