Skip to main content

opendev_runtime/
cost_tracker.rs

1//! Session-level cost tracking for LLM API usage.
2//!
3//! Uses ModelInfo pricing ($ per million tokens) to compute cost from
4//! the usage dict returned by each LLM API call.
5//!
6//! Ported from `opendev/core/runtime/cost_tracker.py`.
7
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::debug;
11
/// Per-call token counts taken from an LLM API response.
///
/// Mirrors the `usage` object returned by OpenAI/Anthropic-style APIs.
/// Every counter is zero in the [`Default`] value.
#[derive(Clone, Debug, Default)]
pub struct TokenUsage {
    /// Tokens counted on the prompt (input) side of the call.
    pub prompt_tokens: u64,
    /// Tokens counted on the completion (output) side of the call.
    pub completion_tokens: u64,
    /// Anthropic prompt caching: tokens that were read from the cache.
    pub cache_read_input_tokens: u64,
    /// Anthropic prompt caching: tokens that were written to the cache.
    pub cache_creation_input_tokens: u64,
}
24
25impl TokenUsage {
26    /// Parse from a serde_json::Value (the `usage` field in API responses).
27    pub fn from_json(value: &serde_json::Value) -> Self {
28        Self {
29            prompt_tokens: value
30                .get("prompt_tokens")
31                .and_then(|v| v.as_u64())
32                .unwrap_or(0),
33            completion_tokens: value
34                .get("completion_tokens")
35                .and_then(|v| v.as_u64())
36                .unwrap_or(0),
37            cache_read_input_tokens: value
38                .get("cache_read_input_tokens")
39                .and_then(|v| v.as_u64())
40                .unwrap_or(0),
41            cache_creation_input_tokens: value
42                .get("cache_creation_input_tokens")
43                .and_then(|v| v.as_u64())
44                .unwrap_or(0),
45        }
46    }
47}
48
/// Per-model rates used to turn token counts into a dollar amount.
///
/// Both rates are expressed in USD per 1 million tokens.
#[derive(Clone, Debug)]
pub struct PricingInfo {
    /// USD charged per million prompt (input) tokens.
    pub input_price_per_million: f64,
    /// USD charged per million completion (output) tokens.
    pub output_price_per_million: f64,
}
57
58/// Tracks cumulative token usage and estimated cost for a session.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CostTracker {
61    pub total_input_tokens: u64,
62    pub total_output_tokens: u64,
63    pub total_cost_usd: f64,
64    pub call_count: u64,
65    /// Optional session cost budget in USD. When set, the agent loop should
66    /// check [`is_over_budget`](CostTracker::is_over_budget) and pause when
67    /// the budget is exhausted.
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub budget_usd: Option<f64>,
70}
71
/// Prompt size, in tokens, above which the long-prompt surcharge applies
/// (Anthropic charges higher rates for prompts over 200K tokens).
const OVER_200K_THRESHOLD: u64 = 200_000;
/// Factor applied to the input price for the prompt tokens that lie beyond
/// [`OVER_200K_THRESHOLD`].
const OVER_200K_MULTIPLIER: f64 = 1.5;
/// Fraction of the input price charged for cache-read tokens (typically 10%).
const CACHE_READ_DISCOUNT: f64 = 0.1;
77
78impl CostTracker {
79    /// Create a new empty tracker.
80    pub fn new() -> Self {
81        Self {
82            total_input_tokens: 0,
83            total_output_tokens: 0,
84            total_cost_usd: 0.0,
85            call_count: 0,
86            budget_usd: None,
87        }
88    }
89
90    /// Set a cost budget in USD for the session.
91    ///
92    /// Once [`total_cost_usd`](CostTracker::total_cost_usd) reaches or
93    /// exceeds this value, [`is_over_budget`](CostTracker::is_over_budget)
94    /// returns `true` and the agent loop should pause.
95    pub fn set_budget(&mut self, usd: f64) {
96        self.budget_usd = Some(usd);
97    }
98
99    /// Check whether the session has exceeded its cost budget.
100    ///
101    /// Returns `false` when no budget has been set.
102    pub fn is_over_budget(&self) -> bool {
103        match self.budget_usd {
104            Some(budget) => self.total_cost_usd >= budget,
105            None => false,
106        }
107    }
108
109    /// Return the remaining budget in USD, or `None` if no budget is set.
110    pub fn remaining_budget(&self) -> Option<f64> {
111        self.budget_usd
112            .map(|budget| (budget - self.total_cost_usd).max(0.0))
113    }
114
115    /// Record token usage from a single LLM call.
116    ///
117    /// Returns the incremental cost for this call in USD.
118    pub fn record_usage(&mut self, usage: &TokenUsage, pricing: Option<&PricingInfo>) -> f64 {
119        self.total_input_tokens += usage.prompt_tokens;
120        self.total_output_tokens += usage.completion_tokens;
121        self.call_count += 1;
122
123        let incremental_cost = if let Some(p) = pricing {
124            if p.input_price_per_million > 0.0 || p.output_price_per_million > 0.0 {
125                self.compute_cost(usage, p)
126            } else {
127                0.0
128            }
129        } else {
130            0.0
131        };
132
133        self.total_cost_usd += incremental_cost;
134
135        debug!(
136            call = self.call_count,
137            input = usage.prompt_tokens,
138            output = usage.completion_tokens,
139            cost_delta = format!("${:.6}", incremental_cost),
140            cost_total = format!("${:.6}", self.total_cost_usd),
141            "cost_tracker: recorded usage"
142        );
143
144        incremental_cost
145    }
146
147    fn compute_cost(&self, usage: &TokenUsage, pricing: &PricingInfo) -> f64 {
148        // Handle tiered pricing for inputs over 200K tokens
149        let input_cost = if usage.prompt_tokens > OVER_200K_THRESHOLD {
150            let base = (OVER_200K_THRESHOLD as f64 / 1_000_000.0) * pricing.input_price_per_million;
151            let over = ((usage.prompt_tokens - OVER_200K_THRESHOLD) as f64 / 1_000_000.0)
152                * (pricing.input_price_per_million * OVER_200K_MULTIPLIER);
153            base + over
154        } else {
155            (usage.prompt_tokens as f64 / 1_000_000.0) * pricing.input_price_per_million
156        };
157
158        // Cache read tokens at 10% of input price
159        let cache_cost = if usage.cache_read_input_tokens > 0 {
160            (usage.cache_read_input_tokens as f64 / 1_000_000.0)
161                * (pricing.input_price_per_million * CACHE_READ_DISCOUNT)
162        } else {
163            0.0
164        };
165
166        let output_cost =
167            (usage.completion_tokens as f64 / 1_000_000.0) * pricing.output_price_per_million;
168
169        input_cost + output_cost + cache_cost
170    }
171
172    /// Format the total cost for display.
173    pub fn format_cost(&self) -> String {
174        if self.total_cost_usd < 0.01 {
175            format!("${:.4}", self.total_cost_usd)
176        } else {
177            format!("${:.2}", self.total_cost_usd)
178        }
179    }
180
181    /// Export cost data for session metadata persistence.
182    pub fn to_metadata(&self) -> HashMap<String, serde_json::Value> {
183        let mut map = HashMap::new();
184        map.insert(
185            "total_cost_usd".into(),
186            serde_json::json!(round_f64(self.total_cost_usd, 6)),
187        );
188        map.insert(
189            "total_input_tokens".into(),
190            serde_json::json!(self.total_input_tokens),
191        );
192        map.insert(
193            "total_output_tokens".into(),
194            serde_json::json!(self.total_output_tokens),
195        );
196        map.insert("api_call_count".into(), serde_json::json!(self.call_count));
197        if let Some(budget) = self.budget_usd {
198            map.insert("budget_usd".into(), serde_json::json!(round_f64(budget, 6)));
199        }
200        map
201    }
202
203    /// Restore cost state from session metadata (for `--continue` sessions).
204    pub fn restore_from_metadata(&mut self, metadata: &serde_json::Value) {
205        let cost_data = match metadata.get("cost_tracking") {
206            Some(v) => v,
207            None => return,
208        };
209
210        self.total_cost_usd = cost_data
211            .get("total_cost_usd")
212            .and_then(|v| v.as_f64())
213            .unwrap_or(0.0);
214        self.total_input_tokens = cost_data
215            .get("total_input_tokens")
216            .and_then(|v| v.as_u64())
217            .unwrap_or(0);
218        self.total_output_tokens = cost_data
219            .get("total_output_tokens")
220            .and_then(|v| v.as_u64())
221            .unwrap_or(0);
222        self.call_count = cost_data
223            .get("api_call_count")
224            .and_then(|v| v.as_u64())
225            .unwrap_or(0);
226        self.budget_usd = cost_data.get("budget_usd").and_then(|v| v.as_f64());
227
228        debug!(
229            cost = format!("${:.6}", self.total_cost_usd),
230            calls = self.call_count,
231            "cost_tracker: restored from metadata"
232        );
233    }
234}
235
236impl Default for CostTracker {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
/// Round an f64 to `decimals` decimal places.
///
/// Rounding is done by scaling with `10^decimals`, rounding to the nearest
/// integer (ties away from zero, as [`f64::round`] does), and scaling back.
///
/// If the scaled value is not finite, `value` is returned unchanged instead
/// of `inf` or `NaN`. This happens when `value` is itself `NaN` or infinite,
/// when `value * 10^decimals` overflows, or when `decimals` is so large that
/// `10^decimals` is infinite. In the overflow cases `value` already has no
/// fractional digits at the requested precision, so there is nothing to
/// round.
fn round_f64(value: f64, decimals: u32) -> f64 {
    // Clamp before narrowing: a plain `as i32` would wrap values above
    // i32::MAX into negative exponents and round to tens/hundreds instead.
    let exponent = decimals.min(i32::MAX as u32) as i32;
    let factor = 10f64.powi(exponent);

    let scaled = value * factor;
    if !scaled.is_finite() {
        return value;
    }

    scaled.round() / factor
}
247
248#[cfg(test)]
249#[path = "cost_tracker_tests.rs"]
250mod tests;