opendev_runtime/
cost_tracker.rs1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::debug;
11
/// Token counts reported for one model call.
///
/// Every field defaults to zero via the derived `Default`.
#[derive(Clone, Debug, Default)]
pub struct TokenUsage {
    /// Number of prompt (input) tokens.
    pub prompt_tokens: u64,
    /// Number of completion (output) tokens.
    pub completion_tokens: u64,
    /// Number of input tokens reported as read from a cache.
    pub cache_read_input_tokens: u64,
    /// Number of input tokens reported as written to a cache.
    pub cache_creation_input_tokens: u64,
}
24
25impl TokenUsage {
26 pub fn from_json(value: &serde_json::Value) -> Self {
28 Self {
29 prompt_tokens: value
30 .get("prompt_tokens")
31 .and_then(|v| v.as_u64())
32 .unwrap_or(0),
33 completion_tokens: value
34 .get("completion_tokens")
35 .and_then(|v| v.as_u64())
36 .unwrap_or(0),
37 cache_read_input_tokens: value
38 .get("cache_read_input_tokens")
39 .and_then(|v| v.as_u64())
40 .unwrap_or(0),
41 cache_creation_input_tokens: value
42 .get("cache_creation_input_tokens")
43 .and_then(|v| v.as_u64())
44 .unwrap_or(0),
45 }
46 }
47}
48
/// Per-model prices used to turn token counts into a cost.
#[derive(Clone, Debug)]
pub struct PricingInfo {
    /// Price charged per one million input (prompt) tokens.
    pub input_price_per_million: f64,
    /// Price charged per one million output (completion) tokens.
    pub output_price_per_million: f64,
}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CostTracker {
61 pub total_input_tokens: u64,
62 pub total_output_tokens: u64,
63 pub total_cost_usd: f64,
64 pub call_count: u64,
65 #[serde(skip_serializing_if = "Option::is_none")]
69 pub budget_usd: Option<f64>,
70}
71
/// Prompt-token count above which the surcharge multiplier applies.
const OVER_200K_THRESHOLD: u64 = 200_000;
/// Factor applied to the input price for prompt tokens beyond
/// [`OVER_200K_THRESHOLD`].
const OVER_200K_MULTIPLIER: f64 = 1.5;
/// Fraction of the input price charged for cache-read input tokens.
const CACHE_READ_DISCOUNT: f64 = 0.1;
77
78impl CostTracker {
79 pub fn new() -> Self {
81 Self {
82 total_input_tokens: 0,
83 total_output_tokens: 0,
84 total_cost_usd: 0.0,
85 call_count: 0,
86 budget_usd: None,
87 }
88 }
89
90 pub fn set_budget(&mut self, usd: f64) {
96 self.budget_usd = Some(usd);
97 }
98
99 pub fn is_over_budget(&self) -> bool {
103 match self.budget_usd {
104 Some(budget) => self.total_cost_usd >= budget,
105 None => false,
106 }
107 }
108
109 pub fn remaining_budget(&self) -> Option<f64> {
111 self.budget_usd
112 .map(|budget| (budget - self.total_cost_usd).max(0.0))
113 }
114
115 pub fn record_usage(&mut self, usage: &TokenUsage, pricing: Option<&PricingInfo>) -> f64 {
119 self.total_input_tokens += usage.prompt_tokens;
120 self.total_output_tokens += usage.completion_tokens;
121 self.call_count += 1;
122
123 let incremental_cost = if let Some(p) = pricing {
124 if p.input_price_per_million > 0.0 || p.output_price_per_million > 0.0 {
125 self.compute_cost(usage, p)
126 } else {
127 0.0
128 }
129 } else {
130 0.0
131 };
132
133 self.total_cost_usd += incremental_cost;
134
135 debug!(
136 call = self.call_count,
137 input = usage.prompt_tokens,
138 output = usage.completion_tokens,
139 cost_delta = format!("${:.6}", incremental_cost),
140 cost_total = format!("${:.6}", self.total_cost_usd),
141 "cost_tracker: recorded usage"
142 );
143
144 incremental_cost
145 }
146
147 fn compute_cost(&self, usage: &TokenUsage, pricing: &PricingInfo) -> f64 {
148 let input_cost = if usage.prompt_tokens > OVER_200K_THRESHOLD {
150 let base = (OVER_200K_THRESHOLD as f64 / 1_000_000.0) * pricing.input_price_per_million;
151 let over = ((usage.prompt_tokens - OVER_200K_THRESHOLD) as f64 / 1_000_000.0)
152 * (pricing.input_price_per_million * OVER_200K_MULTIPLIER);
153 base + over
154 } else {
155 (usage.prompt_tokens as f64 / 1_000_000.0) * pricing.input_price_per_million
156 };
157
158 let cache_cost = if usage.cache_read_input_tokens > 0 {
160 (usage.cache_read_input_tokens as f64 / 1_000_000.0)
161 * (pricing.input_price_per_million * CACHE_READ_DISCOUNT)
162 } else {
163 0.0
164 };
165
166 let output_cost =
167 (usage.completion_tokens as f64 / 1_000_000.0) * pricing.output_price_per_million;
168
169 input_cost + output_cost + cache_cost
170 }
171
172 pub fn format_cost(&self) -> String {
174 if self.total_cost_usd < 0.01 {
175 format!("${:.4}", self.total_cost_usd)
176 } else {
177 format!("${:.2}", self.total_cost_usd)
178 }
179 }
180
181 pub fn to_metadata(&self) -> HashMap<String, serde_json::Value> {
183 let mut map = HashMap::new();
184 map.insert(
185 "total_cost_usd".into(),
186 serde_json::json!(round_f64(self.total_cost_usd, 6)),
187 );
188 map.insert(
189 "total_input_tokens".into(),
190 serde_json::json!(self.total_input_tokens),
191 );
192 map.insert(
193 "total_output_tokens".into(),
194 serde_json::json!(self.total_output_tokens),
195 );
196 map.insert("api_call_count".into(), serde_json::json!(self.call_count));
197 if let Some(budget) = self.budget_usd {
198 map.insert("budget_usd".into(), serde_json::json!(round_f64(budget, 6)));
199 }
200 map
201 }
202
203 pub fn restore_from_metadata(&mut self, metadata: &serde_json::Value) {
205 let cost_data = match metadata.get("cost_tracking") {
206 Some(v) => v,
207 None => return,
208 };
209
210 self.total_cost_usd = cost_data
211 .get("total_cost_usd")
212 .and_then(|v| v.as_f64())
213 .unwrap_or(0.0);
214 self.total_input_tokens = cost_data
215 .get("total_input_tokens")
216 .and_then(|v| v.as_u64())
217 .unwrap_or(0);
218 self.total_output_tokens = cost_data
219 .get("total_output_tokens")
220 .and_then(|v| v.as_u64())
221 .unwrap_or(0);
222 self.call_count = cost_data
223 .get("api_call_count")
224 .and_then(|v| v.as_u64())
225 .unwrap_or(0);
226 self.budget_usd = cost_data.get("budget_usd").and_then(|v| v.as_f64());
227
228 debug!(
229 cost = format!("${:.6}", self.total_cost_usd),
230 calls = self.call_count,
231 "cost_tracker: restored from metadata"
232 );
233 }
234}
235
236impl Default for CostTracker {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
/// Rounds `value` to `decimals` decimal places (halves round away from zero).
///
/// If scaling by `10^decimals` does not produce a finite number — the input
/// is NaN or infinite, or the multiplication overflows — `value` is returned
/// unchanged rather than being turned into infinity or NaN.
fn round_f64(value: f64, decimals: u32) -> f64 {
    // Clamp before casting: a bare `as i32` would wrap large values negative
    // and silently round to tens, hundreds, ... instead of decimals.
    let exponent = decimals.min(i32::MAX as u32) as i32;
    let factor = 10f64.powi(exponent);
    let scaled = value * factor;
    if !scaled.is_finite() {
        return value;
    }
    scaled.round() / factor
}
247
248#[cfg(test)]
249#[path = "cost_tracker_tests.rs"]
250mod tests;