1use serde::Deserialize;
2
/// Raw token counts as deserialized from a usage payload.
///
/// Field names mirror the payload's keys (presumably the `usage` object of an Anthropic
/// Messages API response — confirm against the caller). The two cache counts are optional:
/// a missing key (or a JSON `null`) deserializes to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct RawTokenUsage {
    /// Input tokens; priced separately from the two cache counts below.
    pub input_tokens: i64,
    /// Output tokens.
    pub output_tokens: i64,
    /// Tokens read from the prompt cache, if the payload reported them.
    #[serde(default)]
    pub cache_read_input_tokens: Option<i64>,
    /// Tokens written to the prompt cache, if the payload reported them.
    #[serde(default)]
    pub cache_creation_input_tokens: Option<i64>,
}
13
/// Token usage for one request, attributed to an agent and annotated with a cost estimate.
///
/// Produced by `parse_api_usage` from a [`RawTokenUsage`].
#[derive(Debug, Clone)]
pub struct ParsedUsage {
    /// Identifier of the agent the usage is attributed to.
    pub agent_id: String,
    /// Session the request belongs to, if the caller supplied one.
    pub session_id: Option<i64>,
    /// Input tokens, copied from the raw payload.
    pub input_tokens: i64,
    /// Output tokens, copied from the raw payload.
    pub output_tokens: i64,
    /// Copied from `RawTokenUsage::cache_read_input_tokens`.
    pub cache_read_tokens: Option<i64>,
    /// Copied from `RawTokenUsage::cache_creation_input_tokens`.
    pub cache_creation_tokens: Option<i64>,
    /// Model name exactly as supplied by the caller (not lowercased or normalized).
    pub model: String,
    /// Estimated cost (presumably USD — the price table does not state a currency);
    /// `None` when the model did not match any known family.
    pub cost_estimate: Option<f64>,
}
26
/// Per-million-token prices for one model family, one rate per billed token category.
///
/// A small plain-old-data value: it derives `Copy`/`PartialEq`/`Debug` so price entries can
/// be copied freely, compared in tests, and printed when diagnosing cost estimates.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ModelPricing {
    /// Price per million input tokens.
    input_per_mtok: f64,
    /// Price per million output tokens.
    output_per_mtok: f64,
    /// Price per million tokens read from the prompt cache.
    cache_read_per_mtok: f64,
    /// Price per million tokens written to the prompt cache.
    cache_creation_per_mtok: f64,
}
34
35pub fn estimate_cost(
37 model: &str,
38 input_tokens: i64,
39 output_tokens: i64,
40 cache_read_tokens: Option<i64>,
41 cache_creation_tokens: Option<i64>,
42) -> Option<f64> {
43 let pricing = model_pricing(model)?;
44
45 let input_cost = input_tokens as f64 * pricing.input_per_mtok / 1_000_000.0;
46 let output_cost = output_tokens as f64 * pricing.output_per_mtok / 1_000_000.0;
47 let cache_read_cost =
48 cache_read_tokens.unwrap_or(0) as f64 * pricing.cache_read_per_mtok / 1_000_000.0;
49 let cache_create_cost =
50 cache_creation_tokens.unwrap_or(0) as f64 * pricing.cache_creation_per_mtok / 1_000_000.0;
51
52 Some(input_cost + output_cost + cache_read_cost + cache_create_cost)
53}
54
55fn model_pricing(model: &str) -> Option<ModelPricing> {
56 let m = model.to_lowercase();
57 if m.contains("opus") {
58 Some(ModelPricing {
59 input_per_mtok: 15.0,
60 output_per_mtok: 75.0,
61 cache_read_per_mtok: 1.5,
62 cache_creation_per_mtok: 18.75,
63 })
64 } else if m.contains("sonnet") {
65 Some(ModelPricing {
66 input_per_mtok: 3.0,
67 output_per_mtok: 15.0,
68 cache_read_per_mtok: 0.3,
69 cache_creation_per_mtok: 3.75,
70 })
71 } else if m.contains("haiku") {
72 Some(ModelPricing {
73 input_per_mtok: 0.80,
74 output_per_mtok: 4.0,
75 cache_read_per_mtok: 0.08,
76 cache_creation_per_mtok: 1.0,
77 })
78 } else {
79 None
80 }
81}
82
83pub fn parse_api_usage(
85 raw: &RawTokenUsage,
86 model: &str,
87 agent_id: &str,
88 session_id: Option<i64>,
89) -> ParsedUsage {
90 let cost = estimate_cost(
91 model,
92 raw.input_tokens,
93 raw.output_tokens,
94 raw.cache_read_input_tokens,
95 raw.cache_creation_input_tokens,
96 );
97
98 ParsedUsage {
99 agent_id: agent_id.to_string(),
100 session_id,
101 input_tokens: raw.input_tokens,
102 output_tokens: raw.output_tokens,
103 cache_read_tokens: raw.cache_read_input_tokens,
104 cache_creation_tokens: raw.cache_creation_input_tokens,
105 model: model.to_string(),
106 cost_estimate: cost,
107 }
108}
109
/// Aggregated usage and cost for one agent/model pair.
///
/// Derives `serde::Serialize`, so the field names below are the serialized keys. How rows
/// are aggregated is not visible here — presumably a grouped query elsewhere; confirm there.
#[derive(Debug, Clone, serde::Serialize)]
pub struct UsageSummaryRow {
    /// Agent the totals belong to.
    pub agent_id: String,
    /// Model the totals belong to.
    pub model: String,
    /// Number of requests aggregated into this row.
    pub request_count: i64,
    /// Summed input tokens.
    pub total_input_tokens: i64,
    /// Summed output tokens.
    pub total_output_tokens: i64,
    /// Summed cache-read tokens (presumably with missing values counted as 0 — confirm).
    pub total_cache_read_tokens: i64,
    /// Summed cache-creation tokens (presumably with missing values counted as 0 — confirm).
    pub total_cache_creation_tokens: i64,
    /// Summed cost estimate.
    pub total_cost: f64,
}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_cost_opus() {
        // 1M input at 15.0/MTok + 1M output at 75.0/MTok.
        let cost = estimate_cost("claude-opus-4-6", 1_000_000, 1_000_000, None, None).unwrap();
        assert!((cost - 90.0).abs() < 0.01);
    }

    #[test]
    fn test_estimate_cost_sonnet() {
        // 1M input at 3.0/MTok + 1M output at 15.0/MTok.
        let cost = estimate_cost("claude-sonnet-4-6", 1_000_000, 1_000_000, None, None).unwrap();
        assert!((cost - 18.0).abs() < 0.01);
    }

    #[test]
    fn test_estimate_cost_haiku() {
        // 1M input at 0.80/MTok + 1M output at 4.0/MTok.
        let cost = estimate_cost("claude-haiku-4-5", 1_000_000, 1_000_000, None, None).unwrap();
        assert!((cost - 4.80).abs() < 0.01);
    }

    #[test]
    fn test_estimate_cost_with_cache() {
        let cost = estimate_cost(
            "claude-opus-4-6",
            500_000,
            200_000,
            Some(300_000),
            Some(100_000),
        )
        .unwrap();
        let expected = 500_000.0 * 15.0 / 1_000_000.0
            + 200_000.0 * 75.0 / 1_000_000.0
            + 300_000.0 * 1.5 / 1_000_000.0
            + 100_000.0 * 18.75 / 1_000_000.0;
        assert!((cost - expected).abs() < 0.01);
    }

    #[test]
    fn test_estimate_cost_unknown_model() {
        assert!(estimate_cost("gpt-4", 1000, 1000, None, None).is_none());
    }

    #[test]
    fn test_estimate_cost_model_match_is_case_insensitive() {
        let lower = estimate_cost("claude-sonnet-4-6", 1_000, 1_000, None, None);
        let mixed = estimate_cost("Claude-SONNET-4-6", 1_000, 1_000, None, None);
        assert!(lower.is_some());
        assert_eq!(lower, mixed);
    }

    #[test]
    fn test_estimate_cost_absent_cache_counts_cost_nothing() {
        // `None` cache counts must be charged exactly like explicit zeros.
        let absent = estimate_cost("claude-haiku-4-5", 1_000, 1_000, None, None);
        let zeroed = estimate_cost("claude-haiku-4-5", 1_000, 1_000, Some(0), Some(0));
        assert!(absent.is_some());
        assert_eq!(absent, zeroed);
    }

    #[test]
    fn test_parse_api_usage() {
        let raw = RawTokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: Some(200),
            cache_creation_input_tokens: None,
        };
        let parsed = parse_api_usage(&raw, "claude-sonnet-4-6", "worker-1", Some(42));
        assert_eq!(parsed.agent_id, "worker-1");
        assert_eq!(parsed.session_id, Some(42));
        assert_eq!(parsed.input_tokens, 1000);
        assert_eq!(parsed.output_tokens, 500);
        assert_eq!(parsed.cache_read_tokens, Some(200));
        assert_eq!(parsed.cache_creation_tokens, None);
        assert_eq!(parsed.model, "claude-sonnet-4-6");
        // The attached estimate must be exactly what `estimate_cost` yields for these counts.
        assert!(parsed.cost_estimate.is_some());
        assert_eq!(
            parsed.cost_estimate,
            estimate_cost("claude-sonnet-4-6", 1000, 500, Some(200), None)
        );
    }

    #[test]
    fn test_parse_api_usage_unknown_model_keeps_counts_without_cost() {
        let raw = RawTokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_read_input_tokens: None,
            cache_creation_input_tokens: Some(3),
        };
        let parsed = parse_api_usage(&raw, "gpt-4", "worker-2", None);
        assert!(parsed.cost_estimate.is_none());
        assert_eq!(parsed.session_id, None);
        assert_eq!(parsed.input_tokens, 10);
        assert_eq!(parsed.output_tokens, 5);
        assert_eq!(parsed.cache_creation_tokens, Some(3));
        assert_eq!(parsed.model, "gpt-4");
    }

    #[test]
    fn test_raw_token_usage_deserialize() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let raw: RawTokenUsage = serde_json::from_str(json).unwrap();
        assert_eq!(raw.input_tokens, 100);
        assert_eq!(raw.output_tokens, 50);
        assert!(raw.cache_read_input_tokens.is_none());
        assert!(raw.cache_creation_input_tokens.is_none());
    }

    #[test]
    fn test_raw_token_usage_with_cache_fields() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 30,
            "cache_creation_input_tokens": 10
        }"#;
        let raw: RawTokenUsage = serde_json::from_str(json).unwrap();
        assert_eq!(raw.cache_read_input_tokens, Some(30));
        assert_eq!(raw.cache_creation_input_tokens, Some(10));
    }

    #[test]
    fn test_raw_token_usage_null_cache_fields() {
        let json = r#"{
            "input_tokens": 1,
            "output_tokens": 2,
            "cache_read_input_tokens": null,
            "cache_creation_input_tokens": null
        }"#;
        let raw: RawTokenUsage = serde_json::from_str(json).unwrap();
        assert!(raw.cache_read_input_tokens.is_none());
        assert!(raw.cache_creation_input_tokens.is_none());
    }
}