Skip to main content

ai_agent/cost/mod.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
/// Running cost and token totals for a session, accumulated one API call at a time.
///
/// Costs are stored in whole cents. Each call's cost is rounded to the nearest cent
/// before it is added, so a call cheaper than half a cent contributes 0.
#[derive(Debug, Clone, Default)]
pub struct CostState {
    /// Sum of the rounded per-call costs, in cents.
    pub total_cents: u64,
    /// Number of calls recorded via the `add_api_cost*` methods.
    pub api_call_count: u64,
    /// Total input tokens recorded.
    pub input_tokens: u64,
    /// Total output tokens recorded.
    pub output_tokens: u64,
}

impl CostState {
    /// Creates an empty state with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one API call, pricing input and output tokens at the same rate.
    ///
    /// `cost_per_million` is dollars per million tokens. This is equivalent to
    /// [`CostState::add_api_cost_with_rates`] with both rates set to `cost_per_million`;
    /// use that method when input and output are priced differently.
    pub fn add_api_cost(&mut self, input_tokens: u64, output_tokens: u64, cost_per_million: f64) {
        self.add_api_cost_with_rates(
            input_tokens,
            output_tokens,
            cost_per_million,
            cost_per_million,
        );
    }

    /// Records one API call with separate input and output prices.
    ///
    /// Both rates are dollars per million tokens. The call's combined cost is rounded
    /// to the nearest cent before being added to `total_cents`. Negative or NaN
    /// results become 0 because float-to-integer `as` casts saturate.
    pub fn add_api_cost_with_rates(
        &mut self,
        input_tokens: u64,
        output_tokens: u64,
        input_cost_per_million: f64,
        output_cost_per_million: f64,
    ) {
        self.api_call_count += 1;
        self.input_tokens += input_tokens;
        self.output_tokens += output_tokens;
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_cost_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_cost_per_million;
        let call_cents = ((input_cost + output_cost) * 100.0).round() as u64;
        self.total_cents += call_cents;
    }
}
27
/// A thread-safe running total of cost, in cents.
///
/// The counter lives behind an `Arc`, so cloning an accumulator yields another
/// handle to the *same* total: every clone observes and contributes to one value.
#[derive(Debug, Default, Clone)]
pub struct CostAccumulator {
    state: Arc<AtomicU64>,
}

impl CostAccumulator {
    /// Creates an accumulator whose total starts at zero cents.
    pub fn new() -> Self {
        // The derived `Default` builds `Arc::new(AtomicU64::new(0))`.
        Self::default()
    }

    /// The shared atomic counter backing this handle and all of its clones.
    fn counter(&self) -> &AtomicU64 {
        &self.state
    }

    /// Adds `cents` to the shared total.
    ///
    /// `fetch_add` wraps around on overflow instead of panicking. `Relaxed` ordering
    /// keeps the counter itself atomic but does not order any other memory accesses.
    pub fn add(&self, cents: u64) {
        self.counter().fetch_add(cents, Ordering::Relaxed);
    }

    /// Returns the current total, in cents.
    pub fn total(&self) -> u64 {
        self.counter().load(Ordering::Relaxed)
    }

    /// Returns the current total converted to dollars.
    pub fn total_dollars(&self) -> f64 {
        let cents = self.total();
        cents as f64 / 100.0
    }

    /// Sets the shared total back to zero for every clone.
    pub fn reset(&self) {
        self.counter().store(0, Ordering::Relaxed);
    }
}
56
/// Formats a cost given in cents for display.
///
/// Amounts of one dollar or more are shown in dollars with two decimals
/// (`150` -> `"$1.50"`). Smaller amounts are shown as whole cents followed by the
/// cent sign U+00A2 (`42` -> `"42"` + cent sign).
pub fn format_cost(cents: u64) -> String {
    if cents >= 100 {
        // Integer division and modulo keep the output exact for every u64.
        // Going through f64 would lose precision above 2^53.
        format!("${}.{:02}", cents / 100, cents % 100)
    } else {
        // The escape keeps this source line ASCII, so a re-encoding of the file
        // cannot mangle the cent sign again.
        format!("{}\u{a2}", cents)
    }
}
65
/// Builds a multi-line, human-readable usage report for a session.
///
/// The report lists the API call count, the input, output, and combined token
/// counts, and `total_cents` converted to dollars with four decimal places. Because
/// the input is whole cents, the last two of those decimals are normally `00`.
pub fn format_total_cost(
    input_tokens: u64,
    output_tokens: u64,
    api_calls: u64,
    total_cents: u64,
) -> String {
    // Saturate rather than overflow. A plain `+` panics in debug builds and wraps in
    // release builds when the sum exceeds u64::MAX.
    let total_tokens = input_tokens.saturating_add(output_tokens);
    let dollars = total_cents as f64 / 100.0;
    format!(
        "Session Costs:\n  API Calls: {}\n  Input Tokens: {}\n  Output Tokens: {}\n  Total Tokens: {}\n  Total Cost: ${:.4}",
        api_calls, input_tokens, output_tokens, total_tokens, dollars
    )
}
79
/// Computes the cost of a single request, in whole cents.
///
/// Both rates are dollars per million tokens. The combined dollar amount is rounded
/// to the nearest cent, so a request cheaper than half a cent returns 0. Summing many
/// such per-request results therefore undercounts the true spend.
pub fn calculate_cost(
    input_tokens: u64,
    output_tokens: u64,
    input_cost_per_million: f64,
    output_cost_per_million: f64,
) -> u64 {
    let dollars_for =
        |tokens: u64, rate_per_million: f64| (tokens as f64 / 1_000_000.0) * rate_per_million;
    let dollars = dollars_for(input_tokens, input_cost_per_million)
        + dollars_for(output_tokens, output_cost_per_million);
    let cents = (dollars * 100.0).round();
    // `as` saturates: negative or NaN totals become 0, oversized ones u64::MAX.
    cents as u64
}
90
/// Per-model-family token prices, in dollars per million tokens.
pub mod pricing {
    /// Opus input price, dollars per million tokens.
    pub const OPUS_INPUT: f64 = 15.0;
    /// Opus output price, dollars per million tokens.
    pub const OPUS_OUTPUT: f64 = 75.0;
    /// Sonnet input price, dollars per million tokens.
    pub const SONNET_INPUT: f64 = 3.0;
    /// Sonnet output price, dollars per million tokens.
    pub const SONNET_OUTPUT: f64 = 15.0;
    /// Haiku input price, dollars per million tokens.
    pub const HAIKU_INPUT: f64 = 0.8;
    /// Haiku output price, dollars per million tokens.
    pub const HAIKU_OUTPUT: f64 = 4.0;

    /// Looks up the `(input, output)` prices for a model id by its family name.
    ///
    /// The id is lowercased and checked for the substrings "opus", "sonnet", and
    /// "haiku", in that order; the first match wins. Ids matching no family
    /// return `None`.
    ///
    /// NOTE(review): every version within a family shares one rate pair. Confirm
    /// this against current published pricing when new model versions appear.
    pub fn get_pricing(model_id: &str) -> Option<(f64, f64)> {
        const FAMILIES: [(&str, (f64, f64)); 3] = [
            ("opus", (OPUS_INPUT, OPUS_OUTPUT)),
            ("sonnet", (SONNET_INPUT, SONNET_OUTPUT)),
            ("haiku", (HAIKU_INPUT, HAIKU_OUTPUT)),
        ];
        let lowered = model_id.to_lowercase();
        FAMILIES
            .iter()
            .find(|&&(family, _)| lowered.contains(family))
            .map(|&(_, rates)| rates)
    }
}
112
113#[derive(Debug, Clone)]
114pub struct CostSummary {
115    pub total_cents: u64,
116    pub input_tokens: u64,
117    pub output_tokens: u64,
118    pub api_calls: u64,
119}
120
121impl CostSummary {
122    pub fn format(&self) -> String {
123        format_total_cost(
124            self.input_tokens,
125            self.output_tokens,
126            self.api_calls,
127            self.total_cents,
128        )
129    }
130}
131
/// Reports whether the current session may view console billing information.
///
/// NOTE(review): this is a stub that unconditionally grants access. No credentials,
/// configuration, or environment are consulted. Confirm that every caller is meant
/// to see billing data before relying on it.
pub fn has_console_billing_access() -> bool {
    const ALWAYS_GRANTED: bool = true;
    ALWAYS_GRANTED
}
135
/// Unit tests for cost accounting, formatting, and model pricing lookup.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cost_state() {
        let mut state = CostState::new();
        state.add_api_cost(1000, 500, 3.0);
        assert_eq!(state.api_call_count, 1);
        assert_eq!(state.input_tokens, 1000);
        assert_eq!(state.output_tokens, 500);
        // 1,500 tokens at $3/M is 0.45 cents, which rounds to 0.
        assert_eq!(state.total_cents, 0);

        state.add_api_cost(1_000_000, 1_000_000, 3.0);
        assert_eq!(state.api_call_count, 2);
        assert_eq!(state.input_tokens, 1_001_000);
        assert_eq!(state.total_cents, 600);
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(150), "$1.50");
        assert_eq!(format_cost(100), "$1.00");
        assert_eq!(format_cost(12345), "$123.45");
        assert!(format_cost(99).starts_with("99"));
    }

    #[test]
    fn test_format_total_cost() {
        assert_eq!(
            format_total_cost(1000, 500, 2, 150),
            "Session Costs:\n  API Calls: 2\n  Input Tokens: 1000\n  Output Tokens: 500\n  Total Tokens: 1500\n  Total Cost: $1.5000"
        );
    }

    #[test]
    fn test_calculate_cost() {
        assert_eq!(
            calculate_cost(0, 0, pricing::OPUS_INPUT, pricing::OPUS_OUTPUT),
            0
        );
        assert_eq!(
            calculate_cost(
                1_000_000,
                1_000_000,
                pricing::SONNET_INPUT,
                pricing::SONNET_OUTPUT
            ),
            1800
        );
    }

    #[test]
    fn test_cost_accumulator_shares_state_between_clones() {
        let acc = CostAccumulator::new();
        let handle = acc.clone();
        acc.add(150);
        handle.add(50);
        assert_eq!(acc.total(), 200);
        assert_eq!(handle.total_dollars(), 2.0);
        handle.reset();
        assert_eq!(acc.total(), 0);
    }

    #[test]
    fn test_get_pricing() {
        assert_eq!(
            pricing::get_pricing("Claude-3-Opus"),
            Some((pricing::OPUS_INPUT, pricing::OPUS_OUTPUT))
        );
        assert_eq!(
            pricing::get_pricing("claude-sonnet-4"),
            Some((pricing::SONNET_INPUT, pricing::SONNET_OUTPUT))
        );
        assert_eq!(
            pricing::get_pricing("claude-3-5-haiku"),
            Some((pricing::HAIKU_INPUT, pricing::HAIKU_OUTPUT))
        );
        assert_eq!(pricing::get_pricing("unknown-model"), None);
    }

    #[test]
    fn test_cost_summary_format() {
        let summary = CostSummary {
            total_cents: 150,
            input_tokens: 1000,
            output_tokens: 500,
            api_calls: 2,
        };
        assert_eq!(summary.format(), format_total_cost(1000, 500, 2, 150));
        assert!(summary.format().contains("Total Tokens: 1500"));
    }
}