1use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use llmkit_core::{
9 pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
10 LlmError, LlmProvider, LlmResult,
11};
12
13use crate::layer::LlmLayer;
14
/// Shared, thread-safe running total of USD spent during a session.
///
/// Cloning is cheap (an `Arc` bump) and every clone reads and updates the same
/// total. The total is stored as the bit pattern of an `f64` rather than as a
/// rounded integer, so request costs far below one micro-dollar (e.g. short
/// embedding calls) still accumulate instead of being rounded away to zero.
#[derive(Clone, Default)]
pub struct SessionCost(Arc<AtomicU64>);

impl SessionCost {
    /// Adds `usd` to the session total.
    ///
    /// Amounts that are NaN, infinite, zero or negative are ignored: they cannot
    /// represent real spend, and adding them would corrupt the running total.
    fn add(&self, usd: f64) {
        if !(usd.is_finite() && usd > 0.0) {
            return;
        }
        // `fetch_update` retries its compare-and-swap until no concurrent `add`
        // raced in between, so no increment is lost. The closure always returns
        // `Some`, so the call cannot fail and its result can be discarded.
        let _ = self
            .0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                Some((f64::from_bits(bits) + usd).to_bits())
            });
    }

    /// Returns the total USD recorded so far across all clones of this handle.
    pub fn total_usd(&self) -> f64 {
        // `AtomicU64::default()` is 0, which is the bit pattern of `0.0_f64`.
        f64::from_bits(self.0.load(Ordering::Relaxed))
    }
}
30
31#[derive(Clone)]
33pub struct CostTrackingLayer {
34 budget_usd: Option<f64>,
35 cost: SessionCost,
36}
37
38impl CostTrackingLayer {
39 pub fn new() -> Self {
41 Self { budget_usd: None, cost: SessionCost::default() }
42 }
43
44 pub fn with_budget(budget_usd: f64) -> Self {
46 Self { budget_usd: Some(budget_usd), cost: SessionCost::default() }
47 }
48
49 pub fn session_cost(&self) -> SessionCost {
51 self.cost.clone()
52 }
53}
54
55impl Default for CostTrackingLayer {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl LlmLayer for CostTrackingLayer {
62 type Provider = CostTracking;
63 fn layer(self, inner: Arc<dyn LlmProvider>) -> CostTracking {
64 CostTracking { inner, budget_usd: self.budget_usd, cost: self.cost }
65 }
66}
67
68pub struct CostTracking {
70 inner: Arc<dyn LlmProvider>,
71 budget_usd: Option<f64>,
72 cost: SessionCost,
73}
74
75impl CostTracking {
76 pub fn session_cost(&self) -> SessionCost {
78 self.cost.clone()
79 }
80
81 fn check_budget(&self) -> LlmResult<()> {
82 if let Some(budget) = self.budget_usd {
83 let spent = self.cost.total_usd();
84 if spent >= budget {
85 return Err(LlmError::BudgetExceeded(format!(
86 "session cost ${spent:.4} reached budget ${budget:.2}"
87 )));
88 }
89 }
90 Ok(())
91 }
92}
93
94#[async_trait]
95impl LlmProvider for CostTracking {
96 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
97 self.check_budget()?;
98 let mut resp = self.inner.chat(req).await?;
99
100 let cost = resp
101 .cost
102 .or_else(|| pricing::pricing_for(&resp.model).map(|p| p.cost_for(resp.usage)));
103 if let Some(c) = cost {
104 self.cost.add(c.total_usd());
105 resp.cost = Some(c);
106 tracing::info!(
107 provider = resp.provider,
108 model = %resp.model,
109 request_usd = c.total_usd(),
110 session_usd = self.cost.total_usd(),
111 "cost tracked"
112 );
113 }
114 Ok(resp)
115 }
116
117 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
118 self.check_budget()?;
119 self.inner.chat_stream(req).await
122 }
123
124 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
125 self.check_budget()?;
126 let resp = self.inner.embed(req).await?;
127 if let Some(p) = pricing::pricing_for(&resp.model) {
128 self.cost.add(p.cost_for(resp.usage).total_usd());
129 }
130 Ok(resp)
131 }
132
133 fn name(&self) -> &'static str {
134 self.inner.name()
135 }
136
137 fn model(&self) -> &str {
138 self.inner.model()
139 }
140
141 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
142 self.inner.estimate_cost(req)
143 }
144}