use std::sync::{Arc, Mutex};
use crate::metrics::SessionStatus;
/// Optional resource limits for a session.
///
/// Every field is independent; a field left as `None` imposes no limit of
/// that kind. The default value has all three limits unset.
#[derive(Clone, Debug, Default)]
pub struct Budget {
    /// Upper bound on the number of LLM calls, if any.
    pub max_llm_calls: Option<u64>,
    /// Upper bound on elapsed time in milliseconds, if any.
    pub max_elapsed_ms: Option<u64>,
    /// Upper bound on the total token count, if any.
    pub max_tokens: Option<u64>,
}
impl Budget {
    /// Builds a budget from the `"budget"` object of a JSON context.
    ///
    /// The keys `max_llm_calls`, `max_elapsed_ms` and `max_tokens` are each
    /// optional. A key counts as unset when it is absent or when
    /// `serde_json::Value::as_u64` yields `None` for its value.
    ///
    /// Returns `None` when `ctx` is not an object, when it has no `"budget"`
    /// entry that is itself an object, or when none of the three limits is
    /// set.
    pub fn from_ctx(ctx: &serde_json::Value) -> Option<Self> {
        let limits = ctx.as_object()?.get("budget")?.as_object()?;
        let read_limit = |key: &str| limits.get(key).and_then(serde_json::Value::as_u64);

        let budget = Self {
            max_llm_calls: read_limit("max_llm_calls"),
            max_elapsed_ms: read_limit("max_elapsed_ms"),
            max_tokens: read_limit("max_tokens"),
        };

        // A budget object that sets no limit at all is treated as "no budget".
        let has_any_limit = budget.max_llm_calls.is_some()
            || budget.max_elapsed_ms.is_some()
            || budget.max_tokens.is_some();
        if has_any_limit {
            Some(budget)
        } else {
            None
        }
    }

    /// Compares current usage against every configured limit.
    ///
    /// Limits are examined in the order LLM calls, elapsed time, tokens, and
    /// only the first violated one is reported. Usage that is *equal* to a
    /// limit already counts as a violation.
    ///
    /// # Errors
    ///
    /// Returns a message starting with `budget_exceeded:` that names the
    /// limit, its configured value and the current usage.
    pub fn check(&self, llm_calls: u64, elapsed_ms: u64, total_tokens: u64) -> Result<(), String> {
        if let Some(max) = self.max_llm_calls.filter(|&cap| llm_calls >= cap) {
            return Err(format!(
                "budget_exceeded: max_llm_calls ({max}) reached ({llm_calls} used)"
            ));
        }

        if let Some(max_ms) = self.max_elapsed_ms.filter(|&cap| elapsed_ms >= cap) {
            return Err(format!(
                "budget_exceeded: max_elapsed_ms ({max_ms}ms) reached ({elapsed_ms}ms elapsed)"
            ));
        }

        if let Some(max) = self.max_tokens.filter(|&cap| total_tokens >= cap) {
            return Err(format!(
                "budget_exceeded: max_tokens ({max}) reached ({total_tokens} used)"
            ));
        }

        Ok(())
    }

    /// Reports how much of each limit is still available.
    ///
    /// The result is a JSON object with the keys `llm_calls`, `elapsed_ms`
    /// and `tokens`. Each value is the limit minus the given usage, clamped
    /// at zero, or `null` when that limit is not configured.
    pub fn remaining_json(
        &self,
        llm_calls: u64,
        elapsed_ms: u64,
        total_tokens: u64,
    ) -> serde_json::Value {
        // Saturating subtraction keeps the value at 0 once usage passes the limit.
        let left = |limit: Option<u64>, used: u64| limit.map(|max| max.saturating_sub(used));

        serde_json::json!({
            "llm_calls": left(self.max_llm_calls, llm_calls),
            "elapsed_ms": left(self.max_elapsed_ms, elapsed_ms),
            "tokens": left(self.max_tokens, total_tokens),
        })
    }

    /// Serializes the configured limits as a JSON object.
    ///
    /// Only limits that are set are emitted; unset limits are omitted rather
    /// than written as `null`.
    pub fn to_json(&self) -> serde_json::Value {
        let mut fields = serde_json::Map::new();
        for (key, limit) in [
            ("max_llm_calls", self.max_llm_calls),
            ("max_elapsed_ms", self.max_elapsed_ms),
            ("max_tokens", self.max_tokens),
        ] {
            if let Some(value) = limit {
                fields.insert(key.to_owned(), serde_json::Value::from(value));
            }
        }
        serde_json::Value::Object(fields)
    }
}
/// Cloneable handle for querying budget state.
///
/// Wraps a shared, mutex-protected [`SessionStatus`]; cloning the handle
/// clones the `Arc`, so all clones observe the same status.
#[derive(Clone)]
pub struct BudgetHandle {
    /// Shared session status that the budget queries are delegated to.
    auto: Arc<Mutex<SessionStatus>>,
}
impl BudgetHandle {
    /// Wraps the shared session status in a handle.
    pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
        Self { auto }
    }

    /// Locks the shared status and returns the result of its `check_budget`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `check_budget` reports, or
    /// `"budget check: mutex poisoned"` when the lock cannot be taken because
    /// the mutex is poisoned.
    pub fn check(&self) -> Result<(), String> {
        match self.auto.lock() {
            Ok(status) => status.check_budget(),
            Err(_) => Err("budget check: mutex poisoned".to_string()),
        }
    }

    /// Locks the shared status and returns the result of its
    /// `budget_remaining`.
    ///
    /// A poisoned mutex is not treated as an error here: the method falls
    /// back to JSON `null` instead.
    pub fn remaining(&self) -> serde_json::Value {
        self.auto
            .lock()
            .map_or(serde_json::Value::Null, |status| status.budget_remaining())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    // ---- Budget::from_ctx ---------------------------------------------

    /// A context without a `budget` key produces no budget.
    #[test]
    fn budget_from_ctx_none_when_missing() {
        let context = serde_json::json!({"task": "test"});
        let parsed = Budget::from_ctx(&context);
        assert!(parsed.is_none());
    }

    /// An empty `budget` object sets no limit and therefore produces no budget.
    #[test]
    fn budget_from_ctx_none_when_empty() {
        let context = serde_json::json!({"budget": {}});
        let parsed = Budget::from_ctx(&context);
        assert!(parsed.is_none());
    }

    /// Only `max_llm_calls` is given; the elapsed-time limit stays unset.
    #[test]
    fn budget_from_ctx_extracts_llm_calls() {
        let context = serde_json::json!({"budget": {"max_llm_calls": 10}});
        let Budget {
            max_llm_calls,
            max_elapsed_ms,
            ..
        } = Budget::from_ctx(&context).expect("should parse");
        assert_eq!(max_llm_calls, Some(10));
        assert_eq!(max_elapsed_ms, None);
    }

    /// Only `max_elapsed_ms` is given; the call limit stays unset.
    #[test]
    fn budget_from_ctx_extracts_elapsed_ms() {
        let context = serde_json::json!({"budget": {"max_elapsed_ms": 5000}});
        let Budget {
            max_llm_calls,
            max_elapsed_ms,
            ..
        } = Budget::from_ctx(&context).expect("should parse");
        assert_eq!(max_llm_calls, None);
        assert_eq!(max_elapsed_ms, Some(5000));
    }

    /// Call and elapsed-time limits can be given together.
    #[test]
    fn budget_from_ctx_extracts_both() {
        let context =
            serde_json::json!({"budget": {"max_llm_calls": 5, "max_elapsed_ms": 30000}});
        let Budget {
            max_llm_calls,
            max_elapsed_ms,
            ..
        } = Budget::from_ctx(&context).expect("should parse");
        assert_eq!(max_llm_calls, Some(5));
        assert_eq!(max_elapsed_ms, Some(30000));
    }

    /// Only `max_tokens` is given; the other two limits stay unset.
    #[test]
    fn budget_from_ctx_extracts_max_tokens() {
        let context = serde_json::json!({"budget": {"max_tokens": 5000}});
        let Budget {
            max_llm_calls,
            max_elapsed_ms,
            max_tokens,
        } = Budget::from_ctx(&context).expect("should parse");
        assert_eq!(max_llm_calls, None);
        assert_eq!(max_elapsed_ms, None);
        assert_eq!(max_tokens, Some(5000));
    }

    // ---- BudgetHandle::check ------------------------------------------

    /// With a call limit of 5 and no recorded activity the check succeeds.
    #[test]
    fn budget_check_passes_when_within_limits() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(5),
            ..Budget::default()
        });
        let handle = metrics.budget_handle();
        assert!(handle.check().is_ok());
    }

    /// With a call limit of 2, the check is expected to fail once the
    /// observer has seen two pause events.
    #[test]
    fn budget_check_fails_when_llm_calls_exceeded() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(2),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();
        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        // First full pause/response/resume cycle, then a second pause.
        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();
        observer.on_paused(&queries);

        let outcome = handle.check();
        assert!(matches!(&outcome, Err(message) if message.contains("budget_exceeded")));
    }

    /// With a token limit of 10, a 40-character prompt is expected to trip
    /// the limit.
    // NOTE(review): how tokens are counted is defined outside this file —
    // presumably an estimate from prompt/response text; confirm in metrics.
    #[test]
    fn budget_check_fails_when_tokens_exceeded() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_tokens: Some(10),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();
        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "abcdefghijklmnopqrstuvwxyz0123456789abcd".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();

        let outcome = handle.check();
        assert!(matches!(&outcome, Err(message) if message.contains("max_tokens")));
    }

    /// With a token limit of 1000, a short prompt and reply leave the check
    /// passing.
    #[test]
    fn budget_check_passes_when_tokens_within_limit() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_tokens: Some(1000),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();
        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "short".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "reply", None);
        observer.on_resumed();

        assert!(handle.check().is_ok());
    }

    // ---- BudgetHandle::remaining --------------------------------------

    /// With a token limit of 100, one "test"/"r" exchange is expected to
    /// leave 98 tokens.
    #[test]
    fn budget_remaining_tracks_tokens() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_tokens: Some(100),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();
        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "test".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();

        let left = handle.remaining();
        assert_eq!(left["tokens"], 98);
    }

    /// Without a configured budget the remaining report is JSON `null`.
    #[test]
    fn budget_remaining_null_when_no_budget() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.budget_handle();
        assert_eq!(handle.remaining(), serde_json::Value::Null);
    }

    /// With a call limit of 5, a single pause event is expected to leave 4
    /// calls.
    #[test]
    fn budget_remaining_tracks_llm_calls() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(5),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();
        let queries = vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }];

        observer.on_paused(&queries);

        let left = handle.remaining();
        assert_eq!(left["llm_calls"], 4);
    }

    // ---- Stats JSON ---------------------------------------------------

    /// A configured budget shows up under `auto.budget` in the metrics JSON.
    #[test]
    fn budget_in_stats_json() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(10),
            max_elapsed_ms: Some(60000),
            ..Budget::default()
        });
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::Value::Null);

        let stats = metrics.to_json();
        let reported = &stats["auto"]["budget"];
        assert_eq!(reported["max_llm_calls"], 10);
        assert_eq!(reported["max_elapsed_ms"], 60000);
    }

    /// Without a budget the `auto` section has no `budget` key at all.
    #[test]
    fn no_budget_in_stats_json_when_not_set() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::Value::Null);

        let stats = metrics.to_json();
        assert!(stats["auto"].get("budget").is_none());
    }
}