use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::BoxStream;
use brainwires_core::message::{ChatResponse, Message, MessageContent, StreamChunk, Usage};
use brainwires_core::provider::{ChatOptions, Provider};
use brainwires_core::tool::Tool;
use crate::error::ResilienceError;
/// Optional spending ceilings enforced by a [`BudgetGuard`].
///
/// Every limit is independent of the others. A field left as `None` is never
/// checked, and the derived `Default` leaves all three limits unset.
#[derive(Clone, Debug, Default)]
pub struct BudgetConfig {
    /// Ceiling on accumulated tokens; `None` disables the token check.
    pub max_tokens: Option<u64>,
    /// Ceiling on accumulated cost in US cents; `None` disables the cost check.
    pub max_usd_cents: Option<u64>,
    /// Ceiling on the number of counted rounds; `None` disables the round check.
    pub max_rounds: Option<u64>,
}
/// Tracks consumption against a [`BudgetConfig`].
///
/// The counters live behind an `Arc`, so cloning a guard yields a handle onto
/// the *same* counters (the configuration itself is copied). Usage recorded
/// through one clone is therefore visible through every other clone.
#[derive(Debug, Clone)]
pub struct BudgetGuard {
    /// Limits this guard enforces.
    cfg: BudgetConfig,
    /// Consumption counters shared between all clones of this guard.
    state: Arc<BudgetState>,
}
/// Running consumption totals, stored as atomics so they can be updated
/// through a shared reference.
///
/// The derived `Default` starts every counter at zero.
#[derive(Default, Debug)]
struct BudgetState {
    /// Tokens recorded so far.
    tokens: AtomicU64,
    /// Cost recorded so far, in US cents.
    usd_cents: AtomicU64,
    /// Rounds counted so far.
    rounds: AtomicU64,
}
impl BudgetGuard {
pub fn new(cfg: BudgetConfig) -> Self {
Self {
cfg,
state: Arc::new(BudgetState::default()),
}
}
pub fn config(&self) -> &BudgetConfig {
&self.cfg
}
pub fn tokens_consumed(&self) -> u64 {
self.state.tokens.load(Ordering::Relaxed)
}
pub fn usd_cents_consumed(&self) -> u64 {
self.state.usd_cents.load(Ordering::Relaxed)
}
pub fn rounds_consumed(&self) -> u64 {
self.state.rounds.load(Ordering::Relaxed)
}
pub fn reset(&self) {
self.state.tokens.store(0, Ordering::Relaxed);
self.state.usd_cents.store(0, Ordering::Relaxed);
self.state.rounds.store(0, Ordering::Relaxed);
}
pub fn check(&self) -> Result<(), ResilienceError> {
if let Some(limit) = self.cfg.max_tokens {
let consumed = self.tokens_consumed();
if consumed >= limit {
return Err(ResilienceError::BudgetExceeded {
kind: "tokens",
consumed,
limit,
});
}
}
if let Some(limit) = self.cfg.max_usd_cents {
let consumed = self.usd_cents_consumed();
if consumed >= limit {
return Err(ResilienceError::BudgetExceeded {
kind: "usd_cents",
consumed,
limit,
});
}
}
if let Some(limit) = self.cfg.max_rounds {
let consumed = self.rounds_consumed();
if consumed >= limit {
return Err(ResilienceError::BudgetExceeded {
kind: "rounds",
consumed,
limit,
});
}
}
Ok(())
}
pub fn check_and_tick(&self) -> Result<(), ResilienceError> {
self.check()?;
self.state.rounds.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn record_usage(&self, usage: &Usage) {
self.state
.tokens
.fetch_add(usage.total_tokens as u64, Ordering::Relaxed);
}
pub fn record_cost_cents(&self, cents: u64) {
self.state.usd_cents.fetch_add(cents, Ordering::Relaxed);
}
}
/// A provider wrapper that pairs an inner provider with a [`BudgetGuard`].
///
/// Its `Provider` implementation consults the guard before forwarding each
/// request to `inner` and records the usage reported back.
pub struct BudgetProvider<P>
where
    P: Provider + ?Sized,
{
    /// The wrapped provider that actually serves requests.
    inner: Arc<P>,
    /// Budget bookkeeping applied to every request.
    guard: BudgetGuard,
}
impl<P: Provider + ?Sized> BudgetProvider<P> {
pub fn new(inner: Arc<P>, guard: BudgetGuard) -> Self {
Self { inner, guard }
}
pub fn guard(&self) -> &BudgetGuard {
&self.guard
}
pub fn inner(&self) -> &Arc<P> {
&self.inner
}
}
fn approx_input_tokens(messages: &[Message]) -> u64 {
let mut chars: usize = 0;
for m in messages {
match &m.content {
MessageContent::Text(t) => chars += t.len(),
MessageContent::Blocks(blocks) => {
for b in blocks {
chars += approx_block_len(b);
}
}
}
}
(chars as u64) / 4
}
fn approx_block_len(b: &brainwires_core::ContentBlock) -> usize {
use brainwires_core::ContentBlock::*;
match b {
Text { text } => text.len(),
ToolUse { input, .. } => input.to_string().len(),
ToolResult { content, .. } => content.len(),
Image { .. } => 512, }
}
#[async_trait]
impl<P: Provider + ?Sized + 'static> Provider for BudgetProvider<P> {
fn name(&self) -> &str {
self.inner.name()
}
fn max_output_tokens(&self) -> Option<u32> {
self.inner.max_output_tokens()
}
async fn chat(
&self,
messages: &[Message],
tools: Option<&[Tool]>,
options: &ChatOptions,
) -> Result<ChatResponse> {
self.guard.check()?;
if let Some(limit) = self.guard.cfg.max_tokens {
let projected = self.guard.tokens_consumed() + approx_input_tokens(messages);
if projected > limit {
return Err(ResilienceError::BudgetExceeded {
kind: "tokens",
consumed: projected,
limit,
}
.into());
}
}
self.guard.state.rounds.fetch_add(1, Ordering::Relaxed);
let resp = self.inner.chat(messages, tools, options).await?;
self.guard.record_usage(&resp.usage);
Ok(resp)
}
fn stream_chat<'a>(
&'a self,
messages: &'a [Message],
tools: Option<&'a [Tool]>,
options: &'a ChatOptions,
) -> BoxStream<'a, Result<StreamChunk>> {
let guard = self.guard.clone();
if let Err(e) = guard.check() {
let err_stream = futures::stream::once(async move { Err(anyhow::Error::from(e)) });
return Box::pin(err_stream);
}
if let Some(limit) = guard.cfg.max_tokens {
let projected = guard.tokens_consumed() + approx_input_tokens(messages);
if projected > limit {
let err = ResilienceError::BudgetExceeded {
kind: "tokens",
consumed: projected,
limit,
};
let err_stream =
futures::stream::once(async move { Err(anyhow::Error::from(err)) });
return Box::pin(err_stream);
}
}
guard.state.rounds.fetch_add(1, Ordering::Relaxed);
let upstream = self.inner.stream_chat(messages, tools, options);
let mapped = upstream.map(move |chunk| {
if let Ok(StreamChunk::Usage(ref u)) = chunk {
guard.record_usage(u);
}
chunk
});
Box::pin(mapped)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recorded usage accumulates, and `check` starts failing once the total
    /// has reached the token limit.
    #[test]
    fn guard_tracks_tokens() {
        let cfg = BudgetConfig {
            max_tokens: Some(100),
            ..Default::default()
        };
        let guard = BudgetGuard::new(cfg);

        // 80 tokens recorded: still below the limit of 100.
        guard.record_usage(&Usage::new(40, 40));
        assert_eq!(guard.tokens_consumed(), 80);
        guard.check().expect("under budget");

        // 30 more push the total to 110, past the limit.
        guard.record_usage(&Usage::new(30, 0));
        assert_eq!(guard.tokens_consumed(), 110);
        let err = guard.check().unwrap_err();
        assert!(matches!(
            err,
            ResilienceError::BudgetExceeded { kind: "tokens", .. }
        ));
    }

    /// With a limit of two rounds, two ticks succeed and the third is rejected.
    #[test]
    fn guard_tracks_rounds() {
        let cfg = BudgetConfig {
            max_rounds: Some(2),
            ..Default::default()
        };
        let guard = BudgetGuard::new(cfg);

        for _ in 0..2 {
            guard.check_and_tick().unwrap();
        }

        let err = guard.check_and_tick().unwrap_err();
        assert!(matches!(
            err,
            ResilienceError::BudgetExceeded { kind: "rounds", .. }
        ));
    }

    /// `reset` returns the token and round counters to zero.
    #[test]
    fn guard_reset_zeroes_everything() {
        let cfg = BudgetConfig {
            max_tokens: Some(100),
            max_rounds: Some(5),
            ..Default::default()
        };
        let guard = BudgetGuard::new(cfg);

        guard.record_usage(&Usage::new(5, 5));
        guard.check_and_tick().unwrap();

        guard.reset();
        assert_eq!(guard.tokens_consumed(), 0);
        assert_eq!(guard.rounds_consumed(), 0);
    }

    /// 160 bytes of message text are estimated at 160 / 4 = 40 tokens.
    ///
    /// NOTE(review): despite the name, only a single `Message::user` text
    /// message is exercised here; no block content is constructed.
    #[test]
    fn approx_tokens_text_and_blocks() {
        let text = "abcd".repeat(40);
        let msgs = vec![Message::user(text)];
        let estimate = approx_input_tokens(&msgs);
        assert_eq!(estimate, 40);
    }
}