Skip to main content

llmkit_tower/
cost.rs

1//! Cost-tracking layer: per-request cost plus a cumulative session total, with
2//! an optional budget cap.
3
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use llmkit_core::{
9    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
10    LlmError, LlmProvider, LlmResult,
11};
12
13use crate::layer::LlmLayer;
14
/// Shared, cloneable handle to the running session cost.
///
/// All clones observe the same total. Internally the total is an atomic counter
/// of nano-dollars (1e-9 USD). That granularity keeps very cheap requests
/// (costing well under a micro-dollar) from being rounded to zero and silently
/// dropped from the total.
#[derive(Clone, Debug, Default)]
pub struct SessionCost(Arc<AtomicU64>);

impl SessionCost {
    /// Counter units (nano-dollars) per USD.
    const UNITS_PER_USD: f64 = 1_000_000_000.0;

    /// Adds `usd` to the running total.
    ///
    /// The amount is rounded to the nearest nano-dollar. Negative and NaN
    /// amounts contribute nothing, because the float-to-int cast saturates at 0.
    /// The counter saturates at `u64::MAX` instead of wrapping, so an absurd
    /// (e.g. infinite) cost can never wrap the total back down and re-open a
    /// budget.
    fn add(&self, usd: f64) {
        let units = (usd * Self::UNITS_PER_USD).round() as u64;
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(units))
            });
    }

    /// Total cost accumulated so far, in USD.
    pub fn total_usd(&self) -> f64 {
        self.0.load(Ordering::Relaxed) as f64 / Self::UNITS_PER_USD
    }
}
30
31/// Tracks per-request and cumulative cost; optionally enforces a budget.
32#[derive(Clone)]
33pub struct CostTrackingLayer {
34    budget_usd: Option<f64>,
35    cost: SessionCost,
36}
37
38impl CostTrackingLayer {
39    /// Track cost with no spending cap.
40    pub fn new() -> Self {
41        Self { budget_usd: None, cost: SessionCost::default() }
42    }
43
44    /// Track cost and refuse requests once `budget_usd` is reached.
45    pub fn with_budget(budget_usd: f64) -> Self {
46        Self { budget_usd: Some(budget_usd), cost: SessionCost::default() }
47    }
48
49    /// Handle to read the running session cost.
50    pub fn session_cost(&self) -> SessionCost {
51        self.cost.clone()
52    }
53}
54
55impl Default for CostTrackingLayer {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl LlmLayer for CostTrackingLayer {
62    type Provider = CostTracking;
63    fn layer(self, inner: Arc<dyn LlmProvider>) -> CostTracking {
64        CostTracking { inner, budget_usd: self.budget_usd, cost: self.cost }
65    }
66}
67
68/// Provider produced by [`CostTrackingLayer`].
69pub struct CostTracking {
70    inner: Arc<dyn LlmProvider>,
71    budget_usd: Option<f64>,
72    cost: SessionCost,
73}
74
75impl CostTracking {
76    /// Handle to read the running session cost.
77    pub fn session_cost(&self) -> SessionCost {
78        self.cost.clone()
79    }
80
81    fn check_budget(&self) -> LlmResult<()> {
82        if let Some(budget) = self.budget_usd {
83            let spent = self.cost.total_usd();
84            if spent >= budget {
85                return Err(LlmError::BudgetExceeded(format!(
86                    "session cost ${spent:.4} reached budget ${budget:.2}"
87                )));
88            }
89        }
90        Ok(())
91    }
92}
93
94#[async_trait]
95impl LlmProvider for CostTracking {
96    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
97        self.check_budget()?;
98        let mut resp = self.inner.chat(req).await?;
99
100        let cost = resp
101            .cost
102            .or_else(|| pricing::pricing_for(&resp.model).map(|p| p.cost_for(resp.usage)));
103        if let Some(c) = cost {
104            self.cost.add(c.total_usd());
105            resp.cost = Some(c);
106            tracing::info!(
107                provider = resp.provider,
108                model = %resp.model,
109                request_usd = c.total_usd(),
110                session_usd = self.cost.total_usd(),
111                "cost tracked"
112            );
113        }
114        Ok(resp)
115    }
116
117    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
118        self.check_budget()?;
119        // Streaming usage arrives in the terminal Done event; we cannot bill it
120        // here without consuming the stream, so the budget check is pre-flight.
121        self.inner.chat_stream(req).await
122    }
123
124    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
125        self.check_budget()?;
126        let resp = self.inner.embed(req).await?;
127        if let Some(p) = pricing::pricing_for(&resp.model) {
128            self.cost.add(p.cost_for(resp.usage).total_usd());
129        }
130        Ok(resp)
131    }
132
133    fn name(&self) -> &'static str {
134        self.inner.name()
135    }
136
137    fn model(&self) -> &str {
138        self.inner.model()
139    }
140
141    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
142        self.inner.estimate_cost(req)
143    }
144}