use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use llmkit_core::{
pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
LlmError, LlmProvider, LlmResult,
};
use crate::layer::LlmLayer;
/// Running, shareable total of USD spent during a session.
///
/// Cloning is cheap: every clone points at the same counter (it lives behind an [`Arc`]),
/// so a handle obtained before requests are made keeps observing the live total.
///
/// The total is stored as the bit pattern of an `f64` inside an [`AtomicU64`]. Summing exact
/// `f64` values (rather than rounding each request to whole micro-dollars) means very cheap
/// requests — e.g. short embeddings costing well under $0.0000005 — still count toward the
/// total instead of being rounded away to zero. `AtomicU64::default()` is `0`, which is the
/// bit pattern of `0.0_f64`, so `Default` yields an empty total.
#[derive(Clone, Default)]
pub struct SessionCost(Arc<AtomicU64>);

impl SessionCost {
    /// Adds `usd` to the running total shared by all clones of this handle.
    ///
    /// Amounts that are NaN, infinite, zero or negative are ignored: they cannot represent
    /// real spend, and an infinite amount would otherwise poison the total permanently.
    fn add(&self, usd: f64) {
        if !usd.is_finite() || usd <= 0.0 {
            return;
        }
        // `fetch_update` retries the read-modify-write until no other writer intervened, so
        // concurrent additions from different clones are never lost. Relaxed ordering is
        // sufficient because the counter does not guard any other data. The closure always
        // returns `Some`, so the result is always `Ok` and can be discarded.
        let _ = self.0.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
            Some((f64::from_bits(bits) + usd).to_bits())
        });
    }

    /// Total USD recorded so far across all clones of this handle.
    pub fn total_usd(&self) -> f64 {
        f64::from_bits(self.0.load(Ordering::Relaxed))
    }
}
/// Layer that wraps a provider in [`CostTracking`], optionally with a session budget.
///
/// The layer owns a [`SessionCost`] handle; cloning the layer clones that handle, so the
/// clones and the providers built from them all share a single running total.
#[derive(Clone)]
pub struct CostTrackingLayer {
    /// Spend limit in USD; `None` means spend is recorded but never enforced.
    budget_usd: Option<f64>,
    /// Shared running total handed to every provider built from this layer.
    cost: SessionCost,
}

impl CostTrackingLayer {
    /// Creates a layer that records spend without enforcing any budget.
    pub fn new() -> Self {
        Self::from_budget(None)
    }

    /// Creates a layer whose providers refuse new requests once the recorded session spend
    /// has reached `budget_usd`.
    pub fn with_budget(budget_usd: f64) -> Self {
        Self::from_budget(Some(budget_usd))
    }

    /// Returns a handle to this layer's running total, usable to read spend from outside.
    pub fn session_cost(&self) -> SessionCost {
        SessionCost::clone(&self.cost)
    }

    // Common constructor: every new layer starts from a fresh, empty cost counter.
    fn from_budget(budget_usd: Option<f64>) -> Self {
        let cost = SessionCost::default();
        Self { budget_usd, cost }
    }
}

impl Default for CostTrackingLayer {
    /// Equivalent to [`CostTrackingLayer::new`]: no budget.
    fn default() -> Self {
        CostTrackingLayer::new()
    }
}
impl LlmLayer for CostTrackingLayer {
    type Provider = CostTracking;

    /// Wraps `inner`, moving this layer's budget and shared cost handle into the new provider.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> CostTracking {
        let Self { budget_usd, cost } = self;
        CostTracking {
            inner,
            budget_usd,
            cost,
        }
    }
}
pub struct CostTracking {
inner: Arc<dyn LlmProvider>,
budget_usd: Option<f64>,
cost: SessionCost,
}
impl CostTracking {
pub fn session_cost(&self) -> SessionCost {
self.cost.clone()
}
fn check_budget(&self) -> LlmResult<()> {
if let Some(budget) = self.budget_usd {
let spent = self.cost.total_usd();
if spent >= budget {
return Err(LlmError::BudgetExceeded(format!(
"session cost ${spent:.4} reached budget ${budget:.2}"
)));
}
}
Ok(())
}
}
#[async_trait]
impl LlmProvider for CostTracking {
    /// Forwards `req` to the inner provider after a budget check, then records its cost.
    ///
    /// The cost is the provider-reported `resp.cost` when present; otherwise it is computed
    /// from the pricing table for `resp.model`. The resolved cost is written back into
    /// `resp.cost` so callers see it either way. If no cost can be determined while a budget
    /// is set, a warning is logged, because that request is invisible to the budget.
    ///
    /// # Errors
    /// Returns `LlmError::BudgetExceeded` without calling the inner provider when the budget
    /// is already spent; otherwise propagates any error from the inner provider.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        self.check_budget()?;
        let mut resp = self.inner.chat(req).await?;
        let cost = resp
            .cost
            .or_else(|| pricing::pricing_for(&resp.model).map(|p| p.cost_for(resp.usage)));
        match cost {
            Some(c) => {
                self.cost.add(c.total_usd());
                resp.cost = Some(c);
                tracing::info!(
                    provider = resp.provider,
                    model = %resp.model,
                    request_usd = c.total_usd(),
                    session_usd = self.cost.total_usd(),
                    "cost tracked"
                );
            }
            // Previously this case was silently skipped, letting a configured budget
            // under-count without any signal.
            None if self.budget_usd.is_some() => {
                tracing::warn!(
                    provider = resp.provider,
                    model = %resp.model,
                    "no reported cost and no pricing known; request not counted toward budget"
                );
            }
            None => {}
        }
        Ok(resp)
    }

    /// Forwards a streaming request to the inner provider after the budget check.
    ///
    /// NOTE(review): the stream is returned untouched, so usage/cost from streamed responses
    /// is NOT added to the session total and the budget cannot account for it. Tracking it
    /// would require wrapping `ChatStream` and tallying usage as it arrives — TODO, depends on
    /// what the stream items expose.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        self.check_budget()?;
        self.inner.chat_stream(req).await
    }

    /// Forwards an embedding request after the budget check and records its cost from the
    /// pricing table for `resp.model`, when the model is priced.
    ///
    /// # Errors
    /// Returns `LlmError::BudgetExceeded` without calling the inner provider when the budget
    /// is already spent; otherwise propagates any error from the inner provider.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        self.check_budget()?;
        let resp = self.inner.embed(req).await?;
        match pricing::pricing_for(&resp.model) {
            Some(p) => {
                let request_usd = p.cost_for(resp.usage).total_usd();
                self.cost.add(request_usd);
                // Log like `chat` does so embedding spend is visible too.
                tracing::info!(
                    model = %resp.model,
                    request_usd = request_usd,
                    session_usd = self.cost.total_usd(),
                    "cost tracked"
                );
            }
            None if self.budget_usd.is_some() => {
                tracing::warn!(
                    model = %resp.model,
                    "no pricing known; embedding request not counted toward budget"
                );
            }
            None => {}
        }
        Ok(resp)
    }

    /// Delegates to the inner provider's name.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Delegates to the inner provider's model identifier.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Delegates cost estimation to the inner provider; this layer adds no estimate of its own.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.inner.estimate_cost(req)
    }
}