use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use dashmap::DashMap;
use tower::{Layer, Service};
use super::types::{LlmRequest, LlmResponse};
use crate::client::BoxFuture;
use crate::cost;
use crate::error::{LiterLlmError, Result};
/// How budget violations are handled by the budget middleware.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Enforcement {
    /// Reject requests (with `BudgetExceeded`) once a limit is reached.
    Hard,
    /// Allow requests but emit a `tracing` warning once a limit is reached.
    Soft,
}
/// Spending limits and enforcement policy applied by [`BudgetService`].
#[derive(Debug, Clone)]
pub struct BudgetConfig {
    /// Total USD spend allowed across all models; `None` means unlimited.
    pub global_limit: Option<f64>,
    /// Per-model USD spend limits, keyed by model name.
    pub model_limits: HashMap<String, f64>,
    /// Whether exceeding a limit rejects requests (`Hard`) or only warns (`Soft`).
    pub enforcement: Enforcement,
}
impl Default for BudgetConfig {
    /// No limits at all, with hard enforcement — effectively a no-op
    /// until a global or per-model limit is configured.
    fn default() -> Self {
        let model_limits = HashMap::new();
        Self {
            global_limit: None,
            model_limits,
            enforcement: Enforcement::Hard,
        }
    }
}
/// Running spend counters, shared between the layer and all service clones.
///
/// Amounts are stored as integer microcents (1e-6 USD) in atomics so that
/// concurrent recording needs no locking.
#[derive(Debug)]
pub struct BudgetState {
    // Total spend across all models, in microcents.
    global_spend: AtomicU64,
    // Per-model spend in microcents, keyed by model name.
    model_spend: DashMap<String, AtomicU64>,
}
impl BudgetState {
    /// Creates a state with every counter at zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            global_spend: AtomicU64::new(0),
            model_spend: DashMap::new(),
        }
    }

    /// Total USD spent across all models so far.
    #[must_use]
    pub fn global_spend(&self) -> f64 {
        let mc = self.global_spend.load(Ordering::Relaxed);
        microcents_to_usd(mc)
    }

    /// USD spent on `model`, or `0.0` if nothing has been recorded for it.
    #[must_use]
    pub fn model_spend(&self, model: &str) -> f64 {
        match self.model_spend.get(model) {
            Some(counter) => microcents_to_usd(counter.load(Ordering::Relaxed)),
            None => 0.0,
        }
    }

    /// Zeroes the global counter and drops every per-model counter.
    pub fn reset(&self) {
        self.global_spend.store(0, Ordering::Relaxed);
        self.model_spend.clear();
    }

    // Adds `usd` (converted to microcents) to both the global counter and
    // the per-model counter, creating the latter on first use.
    fn record(&self, model: &str, usd: f64) {
        let delta = usd_to_microcents(usd);
        self.global_spend.fetch_add(delta, Ordering::Relaxed);
        let counter = self.model_spend.entry(model.to_owned()).or_default();
        counter.fetch_add(delta, Ordering::Relaxed);
    }
}
impl Default for BudgetState {
fn default() -> Self {
Self::new()
}
}
/// Converts a USD amount to whole microcents, rounding to the nearest unit.
///
/// Non-positive (and NaN) inputs map to zero, so spend counters can never
/// be decremented through this path.
fn usd_to_microcents(usd: f64) -> u64 {
    if usd > 0.0 {
        // `as` is a saturating cast, so absurdly large inputs clamp to u64::MAX.
        (usd * 1_000_000.0).round() as u64
    } else {
        0
    }
}
/// Converts integer microcents back into a USD amount.
fn microcents_to_usd(mc: u64) -> f64 {
    let raw = mc as f64;
    raw / 1_000_000.0
}
/// Tower layer that wraps a service in budget tracking and enforcement.
pub struct BudgetLayer {
    // Limits and enforcement mode copied into every wrapped service.
    config: BudgetConfig,
    // Shared spend counters; each layered service gets an `Arc` clone.
    state: Arc<BudgetState>,
}
impl BudgetLayer {
#[must_use]
pub fn new(config: BudgetConfig, state: Arc<BudgetState>) -> Self {
Self { config, state }
}
}
impl<S> Layer<S> for BudgetLayer {
    type Service = BudgetService<S>;

    /// Wraps `inner`, handing it a copy of the config and a handle to the
    /// shared counters.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        let state = Arc::clone(&self.state);
        BudgetService { inner, config, state }
    }
}
/// Middleware that records per-request cost and, under hard enforcement,
/// rejects requests once a configured budget is exhausted.
pub struct BudgetService<S> {
    // The wrapped service that actually performs the request.
    inner: S,
    // Limits and enforcement mode for this service.
    config: BudgetConfig,
    // Counters shared with sibling clones and the owning layer.
    state: Arc<BudgetState>,
}
impl<S: Clone> Clone for BudgetService<S> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
config: self.config.clone(),
state: Arc::clone(&self.state),
}
}
}
impl<S> Service<LlmRequest> for BudgetService<S>
where
    S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = LlmResponse;
    type Error = LiterLlmError;
    type Future = BoxFuture<'static, Result<LlmResponse>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    /// Checks the budget up front (hard mode only), forwards the request,
    /// then records the response's cost against the shared counters.
    fn call(&mut self, req: LlmRequest) -> Self::Future {
        let model = req.model().unwrap_or("unknown").to_owned();
        let config = self.config.clone();
        let state = Arc::clone(&self.state);

        // Hard enforcement rejects before the request is ever sent.
        if config.enforcement == Enforcement::Hard {
            if let Some(err) = check_budget(&config, &state, &model) {
                return Box::pin(async move { Err(err) });
            }
        }

        let fut = self.inner.call(req);
        Box::pin(async move {
            let resp = fut.await?;
            // Spend is recorded only when the response reports token usage
            // and the cost table knows this model's pricing.
            if let Some(usage) = resp.usage() {
                let priced =
                    cost::completion_cost(&model, usage.prompt_tokens, usage.completion_tokens);
                if let Some(usd) = priced {
                    state.record(&model, usd);
                    // Soft enforcement never rejects; it warns after the fact.
                    if config.enforcement == Enforcement::Soft {
                        emit_soft_warnings(&config, &state, &model);
                    }
                }
            }
            Ok(resp)
        })
    }
}
/// Returns a `BudgetExceeded` error if the global limit or the per-model
/// limit for `model` has been reached; `None` when spending may continue.
///
/// The global limit is checked first, so a request can be rejected globally
/// even when its model-specific budget still has headroom.
fn check_budget(config: &BudgetConfig, state: &BudgetState, model: &str) -> Option<LiterLlmError> {
    if let Some(limit) = config.global_limit {
        let spent = state.global_spend();
        if spent >= limit {
            let message =
                format!("global budget exceeded: spent ${spent:.6}, limit ${limit:.6}");
            return Some(LiterLlmError::BudgetExceeded { message, model: None });
        }
    }
    if let Some(&limit) = config.model_limits.get(model) {
        let spent = state.model_spend(model);
        if spent >= limit {
            let message =
                format!("model {model} budget exceeded: spent ${spent:.6}, limit ${limit:.6}");
            return Some(LiterLlmError::BudgetExceeded {
                message,
                model: Some(model.to_owned()),
            });
        }
    }
    None
}
/// Logs a warning for every limit (global, then per-model) that is currently
/// exceeded. Never rejects — this is the soft-enforcement counterpart of
/// `check_budget`.
fn emit_soft_warnings(config: &BudgetConfig, state: &BudgetState, model: &str) {
    if let Some(limit) = config.global_limit {
        if state.global_spend() >= limit {
            tracing::warn!(
                spend = state.global_spend(),
                limit,
                "global budget exceeded (soft enforcement)"
            );
        }
    }
    if let Some(&limit) = config.model_limits.get(model) {
        if state.model_spend(model) >= limit {
            tracing::warn!(
                model,
                spend = state.model_spend(model),
                limit,
                "model budget exceeded (soft enforcement)"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use tower::{Layer as _, Service as _};
    use super::*;
    use crate::tower::service::LlmService;
    use crate::tower::tests_common::{MockClient, chat_req};
    use crate::tower::types::LlmRequest;

    // Wires a BudgetLayer around an always-succeeding mock LLM service.
    fn build_service(config: BudgetConfig, state: Arc<BudgetState>) -> BudgetService<LlmService<MockClient>> {
        let layer = BudgetLayer::new(config, state);
        let inner = LlmService::new(MockClient::ok());
        layer.layer(inner)
    }

    // Hard mode: pre-existing spend above the global limit must reject
    // before the inner service is called.
    #[tokio::test]
    async fn hard_enforcement_rejects_when_global_limit_exceeded() {
        let state = Arc::new(BudgetState::new());
        // Seed $10 of spend against a $5 limit.
        state.global_spend.store(usd_to_microcents(10.0), Ordering::Relaxed);
        let config = BudgetConfig {
            global_limit: Some(5.0),
            enforcement: Enforcement::Hard,
            ..Default::default()
        };
        let mut svc = build_service(config, state);
        let err = svc
            .call(LlmRequest::Chat(chat_req("gpt-4")))
            .await
            .expect_err("should reject over-budget request");
        assert!(matches!(err, LiterLlmError::BudgetExceeded { .. }));
    }

    // Hard mode: a per-model limit rejects and names the offending model.
    #[tokio::test]
    async fn hard_enforcement_rejects_when_model_limit_exceeded() {
        let state = Arc::new(BudgetState::new());
        // Seed $2 of gpt-4 spend against a $1 model limit.
        state
            .model_spend
            .entry("gpt-4".to_owned())
            .or_insert_with(|| AtomicU64::new(0))
            .store(usd_to_microcents(2.0), Ordering::Relaxed);
        let mut limits = HashMap::new();
        limits.insert("gpt-4".into(), 1.0);
        let config = BudgetConfig {
            global_limit: None,
            model_limits: limits,
            enforcement: Enforcement::Hard,
        };
        let mut svc = build_service(config, state);
        let err = svc
            .call(LlmRequest::Chat(chat_req("gpt-4")))
            .await
            .expect_err("should reject over-budget model request");
        match &err {
            LiterLlmError::BudgetExceeded { model, .. } => {
                // The error should carry the model whose budget was hit.
                assert_eq!(model.as_deref(), Some("gpt-4"));
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    // Hard mode with headroom: requests pass through untouched.
    #[tokio::test]
    async fn hard_enforcement_allows_requests_under_limit() {
        let state = Arc::new(BudgetState::new());
        let config = BudgetConfig {
            global_limit: Some(100.0),
            enforcement: Enforcement::Hard,
            ..Default::default()
        };
        let mut svc = build_service(config, state);
        let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(resp.is_ok(), "request under budget should succeed");
    }

    // Soft mode: even a wildly exceeded global limit never rejects.
    #[tokio::test]
    async fn soft_enforcement_allows_requests_over_global_limit() {
        let state = Arc::new(BudgetState::new());
        state.global_spend.store(usd_to_microcents(100.0), Ordering::Relaxed);
        let config = BudgetConfig {
            global_limit: Some(5.0),
            enforcement: Enforcement::Soft,
            ..Default::default()
        };
        let mut svc = build_service(config, state);
        let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(resp.is_ok(), "soft mode should never reject");
    }

    // Soft mode: an exceeded per-model limit never rejects either.
    #[tokio::test]
    async fn soft_enforcement_allows_requests_over_model_limit() {
        let state = Arc::new(BudgetState::new());
        state
            .model_spend
            .entry("gpt-4".to_owned())
            .or_insert_with(|| AtomicU64::new(0))
            .store(usd_to_microcents(10.0), Ordering::Relaxed);
        let mut limits = HashMap::new();
        limits.insert("gpt-4".into(), 1.0);
        let config = BudgetConfig {
            global_limit: None,
            model_limits: limits,
            enforcement: Enforcement::Soft,
        };
        let mut svc = build_service(config, state);
        let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(resp.is_ok(), "soft mode should never reject");
    }

    // A successful response must bump both the global and per-model counters.
    #[tokio::test]
    async fn accumulates_cost_after_response() {
        let state = Arc::new(BudgetState::new());
        let config = BudgetConfig {
            global_limit: Some(100.0),
            enforcement: Enforcement::Hard,
            ..Default::default()
        };
        let mut svc = build_service(config, Arc::clone(&state));
        svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
        assert!(state.global_spend() > 0.0, "global spend should be recorded");
        assert!(state.model_spend("gpt-4") > 0.0, "model spend should be recorded");
    }

    // One model hitting its limit must not block other models.
    #[tokio::test]
    async fn per_model_limits_are_independent() {
        let state = Arc::new(BudgetState::new());
        state
            .model_spend
            .entry("gpt-4".to_owned())
            .or_insert_with(|| AtomicU64::new(0))
            .store(usd_to_microcents(5.0), Ordering::Relaxed);
        let mut limits = HashMap::new();
        limits.insert("gpt-4".into(), 1.0);
        let config = BudgetConfig {
            global_limit: None,
            model_limits: limits,
            enforcement: Enforcement::Hard,
        };
        let mut svc = build_service(config, state);
        let err = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(err.is_err(), "gpt-4 should be rejected");
        let ok = svc.call(LlmRequest::Chat(chat_req("gpt-3.5-turbo"))).await;
        assert!(ok.is_ok(), "gpt-3.5-turbo should not be limited");
    }

    // reset() zeroes the global counter and empties the per-model map.
    #[tokio::test]
    async fn reset_clears_all_counters() {
        let state = Arc::new(BudgetState::new());
        state.global_spend.store(usd_to_microcents(50.0), Ordering::Relaxed);
        state
            .model_spend
            .entry("gpt-4".to_owned())
            .or_insert_with(|| AtomicU64::new(0))
            .store(usd_to_microcents(25.0), Ordering::Relaxed);
        assert!(state.global_spend() > 0.0);
        assert!(state.model_spend("gpt-4") > 0.0);
        state.reset();
        assert_eq!(state.global_spend(), 0.0, "global spend should be zero after reset");
        assert_eq!(
            state.model_spend("gpt-4"),
            0.0,
            "model spend should be zero after reset"
        );
    }

    // A service blocked by hard enforcement recovers after reset() because
    // the state is shared, not copied, into the service.
    #[tokio::test]
    async fn reset_allows_previously_blocked_requests() {
        let state = Arc::new(BudgetState::new());
        state.global_spend.store(usd_to_microcents(10.0), Ordering::Relaxed);
        let config = BudgetConfig {
            global_limit: Some(5.0),
            enforcement: Enforcement::Hard,
            ..Default::default()
        };
        let mut svc = build_service(config, Arc::clone(&state));
        let err = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(err.is_err());
        state.reset();
        let ok = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(ok.is_ok(), "should succeed after reset");
    }

    // The default config has no limits, so nothing is ever rejected.
    #[tokio::test]
    async fn unlimited_config_allows_all_requests() {
        let state = Arc::new(BudgetState::new());
        let config = BudgetConfig::default();
        let mut svc = build_service(config, state);
        for _ in 0..20 {
            assert!(svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.is_ok());
        }
    }

    // Errors from the wrapped service pass through `?` unchanged.
    #[tokio::test]
    async fn propagates_inner_service_errors() {
        let state = Arc::new(BudgetState::new());
        let config = BudgetConfig {
            global_limit: Some(100.0),
            enforcement: Enforcement::Hard,
            ..Default::default()
        };
        let layer = BudgetLayer::new(config, state);
        let inner = LlmService::new(MockClient::failing_timeout());
        let mut svc = layer.layer(inner);
        let err = svc
            .call(LlmRequest::Chat(chat_req("gpt-4")))
            .await
            .expect_err("should propagate inner error");
        assert!(matches!(err, LiterLlmError::Timeout));
    }
}