1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
/// Running usage and cost totals for a session.
///
/// Costs are kept in whole cents. Each recorded call's cost is rounded to the
/// nearest cent *before* being added, so very small calls can contribute 0.
#[derive(Debug, Clone, Default)]
pub struct CostState {
    /// Sum of the per-call costs, each rounded to whole cents.
    pub total_cents: u64,
    /// Number of calls recorded.
    pub api_call_count: u64,
    /// Total input tokens recorded.
    pub input_tokens: u64,
    /// Total output tokens recorded.
    pub output_tokens: u64,
}

impl CostState {
    /// Creates a state with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one API call, pricing input and output tokens at the same rate.
    ///
    /// `cost_per_million` is the price of one million tokens in the unit whose
    /// hundredth is a cent (the result is multiplied by 100 to get cents).
    /// Use [`CostState::add_api_cost_split`] when input and output tokens are
    /// priced differently, e.g. with the `(input, output)` pairs from
    /// `pricing::get_pricing`.
    pub fn add_api_cost(&mut self, input_tokens: u64, output_tokens: u64, cost_per_million: f64) {
        self.add_api_cost_split(input_tokens, output_tokens, cost_per_million, cost_per_million);
    }

    /// Records one API call with separate per-million-token rates for input
    /// and output tokens.
    ///
    /// Increments the call count, adds the token counts, and adds the call's
    /// cost rounded to the nearest cent.
    pub fn add_api_cost_split(
        &mut self,
        input_tokens: u64,
        output_tokens: u64,
        input_cost_per_million: f64,
        output_cost_per_million: f64,
    ) {
        self.api_call_count += 1;
        self.input_tokens += input_tokens;
        self.output_tokens += output_tokens;
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_cost_per_million;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_cost_per_million;
        // `as u64` saturates: a negative or NaN cost contributes 0 cents.
        self.total_cents += ((input_cost + output_cost) * 100.0).round() as u64;
    }
}
27
/// A cent counter that can be shared between owners.
///
/// The counter lives behind an `Arc`, so a clone is another handle to the
/// *same* total: an `add` through any clone is visible through all of them.
#[derive(Debug, Default, Clone)]
pub struct CostAccumulator {
    state: Arc<AtomicU64>,
}

impl CostAccumulator {
    /// Memory ordering for every access. The counter is independent of other
    /// data, so relaxed ordering is enough.
    const ORDER: Ordering = Ordering::Relaxed;

    /// Creates an accumulator holding zero cents.
    pub fn new() -> Self {
        // The derived `Default` builds `Arc::new(AtomicU64::new(0))`.
        Self::default()
    }

    /// Adds `cents` to the shared total. Note that `fetch_add` wraps on overflow.
    pub fn add(&self, cents: u64) {
        let _previous = self.state.fetch_add(cents, Self::ORDER);
    }

    /// Returns the current total in cents.
    pub fn total(&self) -> u64 {
        self.state.load(Self::ORDER)
    }

    /// Returns the current total divided by 100, as a float.
    pub fn total_dollars(&self) -> f64 {
        let cents = self.total();
        (cents as f64) / 100.0
    }

    /// Sets the shared total back to zero for every handle.
    pub fn reset(&self) {
        self.state.store(0, Self::ORDER);
    }
}
56
/// Formats an amount of cents for display.
///
/// Amounts of at least 100 cents render as `$D.CC` (e.g. `150` -> `"$1.50"`).
/// Smaller amounts render as whole cents followed by a cent sign
/// (e.g. `42` -> `"42¢"`).
pub fn format_cost(cents: u64) -> String {
    if cents >= 100 {
        // Integer division keeps the output exact for every u64. Going through
        // f64 loses precision above 2^53 cents.
        format!("${}.{:02}", cents / 100, cents % 100)
    } else {
        // U+00A2 CENT SIGN, written as an escape so the literal cannot be
        // re-encoded into mojibake (it previously read "ยข").
        format!("{}\u{a2}", cents)
    }
}
65
/// Builds the multi-line "Session Costs" report.
///
/// Total tokens is `input_tokens + output_tokens`. The cost is `total_cents`
/// divided by 100, shown with four decimal places.
/// NOTE(review): since the input is whole cents, the last two decimals are
/// always zero; confirm whether 4-place precision is intended.
pub fn format_total_cost(
    input_tokens: u64,
    output_tokens: u64,
    api_calls: u64,
    total_cents: u64,
) -> String {
    let combined = input_tokens + output_tokens;
    let cost = total_cents as f64 / 100.0;
    format!(
        "Session Costs:\n API Calls: {calls}\n Input Tokens: {input}\n Output Tokens: {output}\n Total Tokens: {combined}\n Total Cost: ${cost:.4}",
        calls = api_calls,
        input = input_tokens,
        output = output_tokens,
        combined = combined,
        cost = cost,
    )
}
79
/// Returns the cost of one call in whole cents, rounded to the nearest cent.
///
/// Each side is priced as `tokens / 1_000_000 * rate`, and the sum is
/// multiplied by 100. `as u64` saturates, so a negative or NaN result
/// becomes 0.
pub fn calculate_cost(
    input_tokens: u64,
    output_tokens: u64,
    input_cost_per_million: f64,
    output_cost_per_million: f64,
) -> u64 {
    let side_cost = |tokens: u64, rate: f64| (tokens as f64 / 1_000_000.0) * rate;
    let combined = side_cost(input_tokens, input_cost_per_million)
        + side_cost(output_tokens, output_cost_per_million);
    (combined * 100.0).round() as u64
}
90
/// Per-million-token rates, as `(input, output)`, for known model families.
/// NOTE(review): values are presumably USD per million tokens; confirm
/// against the provider's current price list.
pub mod pricing {
    pub const OPUS_INPUT: f64 = 15.0;
    pub const OPUS_OUTPUT: f64 = 75.0;
    pub const SONNET_INPUT: f64 = 3.0;
    pub const SONNET_OUTPUT: f64 = 15.0;
    pub const HAIKU_INPUT: f64 = 0.8;
    pub const HAIKU_OUTPUT: f64 = 4.0;

    /// Family keyword and its `(input, output)` rates. Entries are checked in
    /// order and the first match wins, so an id naming several families
    /// resolves to the earliest one listed.
    const FAMILIES: [(&str, f64, f64); 3] = [
        ("opus", OPUS_INPUT, OPUS_OUTPUT),
        ("sonnet", SONNET_INPUT, SONNET_OUTPUT),
        ("haiku", HAIKU_INPUT, HAIKU_OUTPUT),
    ];

    /// Looks up `(input, output)` rates for `model_id`.
    ///
    /// Matching is a case-insensitive substring search for a family keyword.
    /// Returns `None` when no family matches.
    pub fn get_pricing(model_id: &str) -> Option<(f64, f64)> {
        let haystack = model_id.to_lowercase();
        FAMILIES
            .iter()
            .find(|(family, _, _)| haystack.contains(*family))
            .map(|&(_, input, output)| (input, output))
    }
}
112
113#[derive(Debug, Clone)]
114pub struct CostSummary {
115 pub total_cents: u64,
116 pub input_tokens: u64,
117 pub output_tokens: u64,
118 pub api_calls: u64,
119}
120
121impl CostSummary {
122 pub fn format(&self) -> String {
123 format_total_cost(
124 self.input_tokens,
125 self.output_tokens,
126 self.api_calls,
127 self.total_cents,
128 )
129 }
130}
131
/// Reports whether console billing access is available.
///
/// Currently hard-coded to `true`; no check is performed.
pub fn has_console_billing_access() -> bool {
    const GRANTED: bool = true;
    GRANTED
}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cost_state() {
        let mut state = CostState::new();
        state.add_api_cost(1000, 500, 3.0);
        assert_eq!(state.api_call_count, 1);
        assert_eq!(state.input_tokens, 1000);
        assert_eq!(state.output_tokens, 500);
        // 1,500 tokens at 3.0 per million is 0.45 cents, which rounds to 0.
        assert_eq!(state.total_cents, 0);

        state.add_api_cost(1_000_000, 1_000_000, 3.0);
        assert_eq!(state.api_call_count, 2);
        assert_eq!(state.total_cents, 600);
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(150), "$1.50");
        assert_eq!(format_cost(100), "$1.00");
        assert_eq!(format_cost(12345), "$123.45");
    }

    #[test]
    fn test_calculate_cost() {
        assert_eq!(calculate_cost(0, 0, 3.0, 15.0), 0);
        assert_eq!(calculate_cost(1_000_000, 1_000_000, 3.0, 15.0), 1800);
        // 0.3 + 0.75 = 1.05 cents, which rounds to 1.
        assert_eq!(calculate_cost(1000, 500, 3.0, 15.0), 1);
    }

    #[test]
    fn test_accumulator_clones_share_total() {
        let acc = CostAccumulator::new();
        acc.add(150);
        let other = acc.clone();
        other.add(50);
        assert_eq!(acc.total(), 200);
        assert_eq!(acc.total_dollars(), 2.0);
        acc.reset();
        assert_eq!(other.total(), 0);
    }

    #[test]
    fn test_get_pricing() {
        assert_eq!(
            pricing::get_pricing("claude-3-OPUS"),
            Some((pricing::OPUS_INPUT, pricing::OPUS_OUTPUT))
        );
        assert_eq!(
            pricing::get_pricing("claude-sonnet"),
            Some((pricing::SONNET_INPUT, pricing::SONNET_OUTPUT))
        );
        assert_eq!(pricing::get_pricing("gpt-4"), None);
    }

    #[test]
    fn test_cost_summary_format() {
        let summary = CostSummary {
            total_cents: 150,
            input_tokens: 10,
            output_tokens: 20,
            api_calls: 3,
        };
        let text = summary.format();
        assert!(text.starts_with("Session Costs:"));
        assert!(text.contains("Total Tokens: 30"));
    }
}