1use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ModelCosts {
10 pub input_tokens: f64,
12 pub output_tokens: f64,
14 pub prompt_cache_write_tokens: f64,
16 pub prompt_cache_read_tokens: f64,
18 pub web_search_requests: f64,
20}
21
22impl ModelCosts {
23 pub fn input_cost(&self, tokens: u32) -> f64 {
25 (tokens as f64 / 1_000_000.0) * self.input_tokens
26 }
27
28 pub fn output_cost(&self, tokens: u32) -> f64 {
30 (tokens as f64 / 1_000_000.0) * self.output_tokens
31 }
32
33 pub fn cache_write_cost(&self, tokens: u32) -> f64 {
35 (tokens as f64 / 1_000_000.0) * self.prompt_cache_write_tokens
36 }
37
38 pub fn cache_read_cost(&self, tokens: u32) -> f64 {
40 (tokens as f64 / 1_000_000.0) * self.prompt_cache_read_tokens
41 }
42
43 pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
45 self.input_cost(usage.input_tokens)
46 + self.output_cost(usage.output_tokens)
47 + self.cache_write_cost(usage.prompt_cache_write_tokens)
48 + self.cache_read_cost(usage.prompt_cache_read_tokens)
49 }
50}
51
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct TokenUsage {
55 pub input_tokens: u32,
56 pub output_tokens: u32,
57 #[serde(rename = "promptCacheWriteTokens")]
58 pub prompt_cache_write_tokens: u32,
59 #[serde(rename = "promptCacheReadTokens")]
60 pub prompt_cache_read_tokens: u32,
61}
62
63impl TokenUsage {
64 pub fn total(&self) -> u32 {
66 self.input_tokens
67 + self.output_tokens
68 + self.prompt_cache_write_tokens
69 + self.prompt_cache_read_tokens
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelInfo {
76 pub id: String,
78 pub name: String,
80 pub description: String,
82 pub context_window: u32,
84}
85
86pub const COST_TIER_3_15: ModelCosts = ModelCosts {
90 input_tokens: 3.0,
91 output_tokens: 15.0,
92 prompt_cache_write_tokens: 3.75,
93 prompt_cache_read_tokens: 0.3,
94 web_search_requests: 0.01,
95};
96
97pub const COST_TIER_15_75: ModelCosts = ModelCosts {
99 input_tokens: 15.0,
100 output_tokens: 75.0,
101 prompt_cache_write_tokens: 18.75,
102 prompt_cache_read_tokens: 1.5,
103 web_search_requests: 0.01,
104};
105
106pub const COST_TIER_5_25: ModelCosts = ModelCosts {
108 input_tokens: 5.0,
109 output_tokens: 25.0,
110 prompt_cache_write_tokens: 6.25,
111 prompt_cache_read_tokens: 0.5,
112 web_search_requests: 0.01,
113};
114
115pub const COST_TIER_30_150: ModelCosts = ModelCosts {
117 input_tokens: 30.0,
118 output_tokens: 150.0,
119 prompt_cache_write_tokens: 37.5,
120 prompt_cache_read_tokens: 3.0,
121 web_search_requests: 0.01,
122};
123
124pub const COST_HAIKU_35: ModelCosts = ModelCosts {
126 input_tokens: 0.8,
127 output_tokens: 4.0,
128 prompt_cache_write_tokens: 1.0,
129 prompt_cache_read_tokens: 0.08,
130 web_search_requests: 0.01,
131};
132
133pub const COST_HAIKU_45: ModelCosts = ModelCosts {
135 input_tokens: 1.0,
136 output_tokens: 5.0,
137 prompt_cache_write_tokens: 1.25,
138 prompt_cache_read_tokens: 0.1,
139 web_search_requests: 0.01,
140};
141
142pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;
144
145pub struct ModelCostRegistry {
147 costs: std::collections::HashMap<String, ModelCosts>,
148}
149
150impl ModelCostRegistry {
151 pub fn new() -> Self {
152 let mut costs = std::collections::HashMap::new();
153
154 costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
156 costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
157 costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
158 costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
159 costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
160 costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
161 costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
162 costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
163 costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
164 costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);
165
166 costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
168 costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);
169
170 costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
172 costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
173 costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
174 costs.insert("gpt-4".to_string(), COST_TIER_30_60);
175
176 Self { costs }
177 }
178
179 pub fn get(&self, model: &str) -> &ModelCosts {
181 if let Some(cost) = self.costs.get(model) {
183 return cost;
184 }
185
186 for (key, cost) in &self.costs {
188 if model.starts_with(key) || key.starts_with(model) {
189 return cost;
190 }
191 }
192
193 &COST_DEFAULT
194 }
195
196 pub fn register(&mut self, model: &str, costs: ModelCosts) {
198 self.costs.insert(model.to_string(), costs);
199 }
200}
201
202impl Default for ModelCostRegistry {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208pub const COST_TIER_30_60: ModelCosts = ModelCosts {
210 input_tokens: 30.0,
211 output_tokens: 60.0,
212 prompt_cache_write_tokens: 30.0,
213 prompt_cache_read_tokens: 10.0,
214 web_search_requests: 0.01,
215};
216
217pub const COST_TIER_10_30: ModelCosts = ModelCosts {
219 input_tokens: 10.0,
220 output_tokens: 30.0,
221 prompt_cache_write_tokens: 10.0,
222 prompt_cache_read_tokens: 3.0,
223 web_search_requests: 0.01,
224};
225
226pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
228 let registry = ModelCostRegistry::new();
229 let costs = registry.get(model);
230 costs.total_cost(usage)
231}
232
233pub fn calculate_cost_for_tokens(
235 model: &str,
236 input_tokens: u32,
237 output_tokens: u32,
238 cache_read_input_tokens: u32,
239 cache_creation_input_tokens: u32,
240) -> f64 {
241 let registry = ModelCostRegistry::new();
242 let costs = registry.get(model);
243 costs.input_cost(input_tokens)
244 + costs.output_cost(output_tokens)
245 + costs.cache_read_cost(cache_read_input_tokens)
246 + costs.cache_write_cost(cache_creation_input_tokens)
247}
248
249pub fn get_available_models() -> Vec<ModelInfo> {
251 vec![
252 ModelInfo {
253 id: "claude-opus-4-6".to_string(),
254 name: "Opus".to_string(),
255 description: "Most capable for complex work".to_string(),
256 context_window: 200_000,
257 },
258 ModelInfo {
259 id: "claude-sonnet-4-6".to_string(),
260 name: "Sonnet".to_string(),
261 description: "Best for everyday tasks".to_string(),
262 context_window: 200_000,
263 },
264 ModelInfo {
265 id: "claude-sonnet-4-6-20250520".to_string(),
266 name: "Sonnet 4.6".to_string(),
267 description: "Latest Sonnet model".to_string(),
268 context_window: 200_000,
269 },
270 ModelInfo {
271 id: "claude-haiku-4-5".to_string(),
272 name: "Haiku".to_string(),
273 description: "Fastest for quick answers".to_string(),
274 context_window: 200_000,
275 },
276 ModelInfo {
277 id: "claude-opus-4-5".to_string(),
278 name: "Opus 4.5".to_string(),
279 description: "Previous Opus version".to_string(),
280 context_window: 200_000,
281 },
282 ModelInfo {
283 id: "claude-sonnet-4-5".to_string(),
284 name: "Sonnet 4.5".to_string(),
285 description: "Previous Sonnet version".to_string(),
286 context_window: 200_000,
287 },
288 ModelInfo {
289 id: "MiniMaxAI/MiniMax-M2.5".to_string(),
290 name: "MiniMax M2.5".to_string(),
291 description: "Fast and capable (default)".to_string(),
292 context_window: 1_000_000,
293 },
294 ]
295}
296
/// Formats `cost` with a leading `$`.
///
/// Values in `[0.01, 1.0)` are printed with two decimals; everything else
/// (smaller, larger, or NaN) is printed with four decimals.
pub fn format_cost(cost: f64) -> String {
    let two_decimals = (0.01..1.0).contains(&cost);
    if two_decimals {
        format!("${:.2}", cost)
    } else {
        format!("${:.4}", cost)
    }
}
307
308#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct CostSummary {
311 pub input_cost: f64,
312 pub output_cost: f64,
313 pub cache_write_cost: f64,
314 pub cache_read_cost: f64,
315 pub total_cost: f64,
316}
317
318impl CostSummary {
319 pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
320 let registry = ModelCostRegistry::new();
321 let costs = registry.get(model);
322
323 Self {
324 input_cost: costs.input_cost(usage.input_tokens),
325 output_cost: costs.output_cost(usage.output_tokens),
326 cache_write_cost: costs.cache_write_cost(usage.prompt_cache_write_tokens),
327 cache_read_cost: costs.cache_read_cost(usage.prompt_cache_read_tokens),
328 total_cost: costs.total_cost(usage),
329 }
330 }
331}
332
333use crate::utils::config::{
334 ModelUsage as ConfigModelUsage, get_current_project_config, save_current_project_config,
335};
336
337#[derive(Debug, Clone, Default)]
339pub struct StoredCostState {
340 pub total_cost_usd: f64,
341 pub total_api_duration: u64,
342 pub total_api_duration_without_retries: u64,
343 pub total_tool_duration: u64,
344 pub total_lines_added: u32,
345 pub total_lines_removed: u32,
346 pub last_duration: Option<u64>,
347 pub model_usage: Option<std::collections::HashMap<String, ConfigModelUsage>>,
348}
349
350pub fn get_stored_session_costs(session_id: &str) -> Option<StoredCostState> {
354 let project_config = get_current_project_config();
355
356 if project_config.last_session_id.as_deref() != Some(session_id) {
358 return None;
359 }
360
361 Some(StoredCostState {
362 total_cost_usd: project_config.last_cost.unwrap_or(0.0),
363 total_api_duration: project_config.last_api_duration.unwrap_or(0),
364 total_api_duration_without_retries: project_config
365 .last_api_duration_without_retries
366 .unwrap_or(0),
367 total_tool_duration: project_config.last_tool_duration.unwrap_or(0),
368 total_lines_added: project_config.last_lines_added.unwrap_or(0),
369 total_lines_removed: project_config.last_lines_removed.unwrap_or(0),
370 last_duration: project_config.last_duration,
371 model_usage: project_config.last_model_usage,
372 })
373}
374
375pub fn restore_cost_state_for_session(session_id: &str) -> bool {
379 let stored = get_stored_session_costs(session_id);
380 let Some(stored) = stored else {
381 return false;
382 };
383
384 update_global_cost_state(|state| {
385 state.total_cost_usd = stored.total_cost_usd;
386 state.total_api_duration = stored.total_api_duration;
387 state.total_api_duration_without_retries = stored.total_api_duration_without_retries;
388 state.total_tool_duration = stored.total_tool_duration;
389 state.total_lines_added = stored.total_lines_added;
390 state.total_lines_removed = stored.total_lines_removed;
391 state.last_duration = stored.last_duration;
392 state.model_usage = stored
393 .model_usage
394 .map(|mu| {
395 mu.into_iter()
396 .map(|(k, v)| {
397 (
398 k,
399 ModelUsageInfo {
400 input_tokens: v.input_tokens,
401 output_tokens: v.output_tokens,
402 cache_read_input_tokens: v.cache_read_input_tokens,
403 cache_creation_input_tokens: v.cache_creation_input_tokens,
404 web_search_requests: v.web_search_requests,
405 cost_usd: v.cost_usd,
406 context_window: 0,
407 max_output_tokens: 0,
408 },
409 )
410 })
411 .collect()
412 })
413 .unwrap_or_default();
414 state.session_id = session_id.to_string();
415 });
416
417 true
418}
419
420pub fn save_current_session_costs() {
423 let cost_state = get_global_cost_state();
424
425 let model_usage_map: Option<std::collections::HashMap<String, ConfigModelUsage>> =
426 if cost_state.model_usage.is_empty() {
427 None
428 } else {
429 let mut map = std::collections::HashMap::new();
430 for (model, usage) in &cost_state.model_usage {
431 map.insert(
432 model.clone(),
433 ConfigModelUsage {
434 input_tokens: usage.input_tokens,
435 output_tokens: usage.output_tokens,
436 cache_read_input_tokens: usage.cache_read_input_tokens,
437 cache_creation_input_tokens: usage.cache_creation_input_tokens,
438 web_search_requests: usage.web_search_requests,
439 cost_usd: usage.cost_usd,
440 },
441 );
442 }
443 Some(map)
444 };
445
446 let mut config = get_current_project_config();
447 config.last_cost = Some(cost_state.total_cost_usd);
448 config.last_api_duration = Some(cost_state.total_api_duration);
449 config.last_api_duration_without_retries = Some(cost_state.total_api_duration_without_retries);
450 config.last_tool_duration = Some(cost_state.total_tool_duration);
451 config.last_duration = cost_state.last_duration;
452 config.last_lines_added = Some(cost_state.total_lines_added);
453 config.last_lines_removed = Some(cost_state.total_lines_removed);
454 config.last_total_input_tokens = Some(cost_state.total_input_tokens);
455 config.last_total_output_tokens = Some(cost_state.total_output_tokens);
456 config.last_total_cache_creation_input_tokens =
457 Some(cost_state.total_cache_creation_input_tokens);
458 config.last_total_cache_read_input_tokens = Some(cost_state.total_cache_read_input_tokens);
459 config.last_total_web_search_requests = Some(cost_state.total_web_search_requests);
460 config.last_model_usage = model_usage_map;
461 config.last_session_id = Some(cost_state.session_id.clone());
462
463 let _ = save_current_project_config(config);
464}
465
/// Formats `cost` with a leading `$` for the usage summary.
///
/// Costs above `0.5` are rounded to cents and shown with two decimals;
/// smaller costs are shown with exactly `max_decimal_places` fractional
/// digits so that tiny amounts do not collapse to `$0.00`.
fn format_cost_for_display(cost: f64, max_decimal_places: usize) -> String {
    if cost > 0.5 {
        format!("${:.2}", (cost * 100.0).round() / 100.0)
    } else {
        // `.prec$` sets the number of fractional digits (precision), not the
        // field width.
        format!("${:.prec$}", cost, prec = max_decimal_places)
    }
}
474
/// Renders `n` in decimal with a comma between each group of three digits
/// (e.g. `1234567` becomes `1,234,567`).
fn format_number(n: u32) -> String {
    let digits = n.to_string();

    // The leading group holds the 0..=2 digits left over after splitting the
    // rest into full groups of three.
    let lead = digits.len() % 3;
    let (head, mut tail) = digits.split_at(lead);

    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(head);
    while !tail.is_empty() {
        let (group, rest) = tail.split_at(3);
        if !grouped.is_empty() {
            grouped.push(',');
        }
        grouped.push_str(group);
        tail = rest;
    }
    grouped
}
488
/// Accumulated usage and cost for one model within the global cost state.
#[derive(Clone, Debug, Default)]
pub struct ModelUsageInfo {
    /// Accumulated regular input tokens.
    pub input_tokens: u32,
    /// Accumulated output tokens.
    pub output_tokens: u32,
    /// Accumulated tokens read from the prompt cache.
    pub cache_read_input_tokens: u32,
    /// Accumulated tokens written to the prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Accumulated web search requests.
    pub web_search_requests: u32,
    /// Accumulated cost.
    pub cost_usd: f64,
    /// Context window of the model; this file only ever sets it to zero.
    pub context_window: u32,
    /// Maximum output tokens of the model; this file only ever sets it to zero.
    pub max_output_tokens: u32,
}
501
/// Maps a model identifier to a short family label used for grouping.
///
/// The first entry of the table whose marker occurs in `model` wins
/// (case-sensitive substring match). Identifiers matching no marker are
/// returned unchanged.
fn get_canonical_name(model: &str) -> String {
    // (marker, label); order is significant.
    const FAMILIES: [(&str, &str); 5] = [
        ("opus", "Opus"),
        ("sonnet", "Sonnet"),
        ("haiku", "Haiku"),
        ("MiniMax", "MiniMax"),
        ("gpt", "GPT"),
    ];

    let label = FAMILIES
        .iter()
        .find(|(marker, _)| model.contains(*marker))
        .map(|(_, label)| *label)
        .unwrap_or(model);
    label.to_string()
}
519
520pub fn format_model_usage() -> String {
522 let cost_state = get_global_cost_state();
523
524 if cost_state.model_usage.is_empty() {
525 return "Usage: 0 input, 0 output, 0 cache read, 0 cache write".to_string();
526 }
527
528 let mut usage_by_short_name: std::collections::HashMap<String, ModelUsageInfo> =
530 std::collections::HashMap::new();
531 for (model, usage) in &cost_state.model_usage {
532 let short_name = get_canonical_name(model);
533 let entry = usage_by_short_name
534 .entry(short_name)
535 .or_insert_with(|| ModelUsageInfo::default());
536 entry.input_tokens += usage.input_tokens;
537 entry.output_tokens += usage.output_tokens;
538 entry.cache_read_input_tokens += usage.cache_read_input_tokens;
539 entry.cache_creation_input_tokens += usage.cache_creation_input_tokens;
540 entry.web_search_requests += usage.web_search_requests;
541 entry.cost_usd += usage.cost_usd;
542 }
543
544 let mut result = "Usage by model:".to_string();
545 for (short_name, usage) in &usage_by_short_name {
546 let usage_string = format!(
547 " {} input, {} output, {} cache read, {} cache write{}{} (${})",
548 format_number(usage.input_tokens),
549 format_number(usage.output_tokens),
550 format_number(usage.cache_read_input_tokens),
551 format_number(usage.cache_creation_input_tokens),
552 if usage.web_search_requests > 0 {
553 format!(", {} web search", format_number(usage.web_search_requests))
554 } else {
555 String::new()
556 },
557 if cost_state.has_unknown_model_cost {
558 " (costs may be inaccurate due to usage of unknown models)".to_string()
559 } else {
560 String::new()
561 },
562 format_cost_for_display(usage.cost_usd, 4)
563 );
564 result.push('\n');
565 let padded_name = format!("{:<21}", format!("{}:", short_name));
567 result.push_str(&padded_name);
568 result.push_str(&usage_string.replace(" ", " "));
569 }
570 result
571}
572
/// Renders a duration given in milliseconds using the largest fitting unit:
/// `"<h>h <m>m <s>s"`, `"<m>m <s>s"`, `"<s>s"`, or `"<ms>ms"` for durations
/// under one second. Sub-second remainders are dropped once seconds are shown.
fn format_duration(ms: u64) -> String {
    let total_seconds = ms / 1000;
    let hours = total_seconds / 3600;
    let minutes = (total_seconds / 60) % 60;
    let seconds = total_seconds % 60;

    match (hours, minutes, seconds) {
        (0, 0, 0) => format!("{}ms", ms),
        (0, 0, s) => format!("{}s", s),
        (0, m, s) => format!("{}m {}s", m, s),
        (h, m, s) => format!("{}h {}m {}s", h, m, s),
    }
}
589
590pub fn format_total_cost() -> String {
592 let cost_state = get_global_cost_state();
593
594 let cost_display = format!("Total cost: ${:.4}", cost_state.total_cost_usd);
595
596 let model_usage_display = format_model_usage();
597
598 format!(
599 "Total cost: {}\nTotal duration (API): {}\nTotal duration (wall): {}\nTotal code changes: {} {} added, {} {}\n{}",
600 cost_display,
601 format_duration(cost_state.total_api_duration),
602 format_duration(cost_state.last_duration.unwrap_or(0)),
603 cost_state.total_lines_added,
604 if cost_state.total_lines_added == 1 {
605 "line"
606 } else {
607 "lines"
608 },
609 cost_state.total_lines_removed,
610 if cost_state.total_lines_removed == 1 {
611 "line"
612 } else {
613 "lines"
614 },
615 model_usage_display
616 )
617}
618
619#[derive(Debug, Clone, Default)]
621pub struct GlobalCostState {
622 pub total_cost_usd: f64,
623 pub total_api_duration: u64,
624 pub total_api_duration_without_retries: u64,
625 pub total_tool_duration: u64,
626 pub total_lines_added: u32,
627 pub total_lines_removed: u32,
628 pub last_duration: Option<u64>,
629 pub total_input_tokens: u32,
630 pub total_output_tokens: u32,
631 pub total_cache_creation_input_tokens: u32,
632 pub total_cache_read_input_tokens: u32,
633 pub total_web_search_requests: u32,
634 pub model_usage: std::collections::HashMap<String, ModelUsageInfo>,
635 pub has_unknown_model_cost: bool,
636 pub session_id: String,
637 pub turn_tool_duration_ms: u64,
639 pub turn_tool_count: u32,
640 pub output_tokens_at_turn_start: u64,
642 pub current_turn_token_budget: Option<u64>,
643 pub budget_continuation_count: u32,
644}
645
646static GLOBAL_COST_STATE: once_cell::sync::Lazy<std::sync::Mutex<GlobalCostState>> =
648 once_cell::sync::Lazy::new(|| std::sync::Mutex::new(GlobalCostState::default()));
649
650pub fn init_cost_state(session_id: &str) {
652 let mut state = GLOBAL_COST_STATE.lock().unwrap();
653 *state = GlobalCostState {
654 session_id: session_id.to_string(),
655 ..Default::default()
656 };
657}
658
659fn get_global_cost_state() -> GlobalCostState {
661 GLOBAL_COST_STATE.lock().unwrap().clone()
662}
663
664pub fn update_global_cost_state<F: FnOnce(&mut GlobalCostState)>(f: F) {
666 let mut state = GLOBAL_COST_STATE.lock().unwrap();
667 f(&mut state);
668}
669
670pub fn add_to_total_model_usage(
672 cost: f64,
673 input_tokens: u32,
674 output_tokens: u32,
675 cache_read_input_tokens: u32,
676 cache_creation_input_tokens: u32,
677 web_search_requests: u32,
678 model: &str,
679) -> ModelUsageInfo {
680 update_global_cost_state(|cost_state| {
681 let model_usage = cost_state
682 .model_usage
683 .entry(model.to_string())
684 .or_insert_with(|| ModelUsageInfo {
685 input_tokens: 0,
686 output_tokens: 0,
687 cache_read_input_tokens: 0,
688 cache_creation_input_tokens: 0,
689 web_search_requests: 0,
690 cost_usd: 0.0,
691 context_window: 0,
692 max_output_tokens: 0,
693 });
694
695 model_usage.input_tokens += input_tokens;
696 model_usage.output_tokens += output_tokens;
697 model_usage.cache_read_input_tokens += cache_read_input_tokens;
698 model_usage.cache_creation_input_tokens += cache_creation_input_tokens;
699 model_usage.web_search_requests += web_search_requests;
700 model_usage.cost_usd += cost;
701
702 cost_state.total_cost_usd += cost;
703 cost_state.total_input_tokens += input_tokens;
704 cost_state.total_output_tokens += output_tokens;
705 cost_state.total_cache_creation_input_tokens += cache_creation_input_tokens;
706 cost_state.total_cache_read_input_tokens += cache_read_input_tokens;
707 cost_state.total_web_search_requests += web_search_requests;
708 });
709
710 get_global_cost_state()
711 .model_usage
712 .get(model)
713 .cloned()
714 .unwrap_or_default()
715}
716
717pub fn add_to_total_session_cost(
719 cost: f64,
720 input_tokens: u32,
721 output_tokens: u32,
722 cache_read_input_tokens: u32,
723 cache_creation_input_tokens: u32,
724 web_search_requests: u32,
725 model: &str,
726) -> f64 {
727 add_to_total_model_usage(
728 cost,
729 input_tokens,
730 output_tokens,
731 cache_read_input_tokens,
732 cache_creation_input_tokens,
733 web_search_requests,
734 model,
735 );
736
737 cost
738}
739
740pub fn reset_turn_metrics() {
742 update_global_cost_state(|state| {
743 state.turn_tool_duration_ms = 0;
744 state.turn_tool_count = 0;
745 state.output_tokens_at_turn_start = state.total_output_tokens as u64;
746 });
747}
748
749pub fn record_turn_tool_duration(duration_ms: u64) {
751 update_global_cost_state(|state| {
752 state.turn_tool_duration_ms += duration_ms;
753 state.turn_tool_count += 1;
754 });
755}
756
757pub fn get_turn_metrics() -> (u64, u32) {
759 let state = get_global_cost_state();
760 (state.turn_tool_duration_ms, state.turn_tool_count)
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TokenUsage` from positional counts:
    /// input, output, cache write, cache read.
    fn usage_of(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            prompt_cache_write_tokens: cache_write,
            prompt_cache_read_tokens: cache_read,
        }
    }

    /// Asserts that `actual` is within 0.01 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn test_model_costs_input() {
        let rates = COST_TIER_3_15;
        assert_eq!(rates.input_cost(1_000_000), 3.0);
        assert_eq!(rates.input_cost(500_000), 1.5);
    }

    #[test]
    fn test_model_costs_output() {
        let rates = COST_TIER_3_15;
        assert_eq!(rates.output_cost(1_000_000), 15.0);
    }

    #[test]
    fn test_token_usage_total() {
        assert_eq!(usage_of(100, 50, 25, 75).total(), 250);
    }

    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();
        assert_eq!(registry.get("claude-sonnet-4-6").input_tokens, 3.0);
        assert_eq!(registry.get("claude-haiku-4-5").input_tokens, 1.0);
    }

    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        let fallback = registry.get("unknown-model");
        assert_eq!(fallback.input_tokens, COST_DEFAULT.input_tokens);
    }

    #[test]
    fn test_calculate_cost() {
        let usage = usage_of(1_000_000, 500_000, 0, 0);
        assert_close(calculate_cost("claude-sonnet-4-6", &usage), 10.5);
    }

    #[test]
    fn test_format_cost() {
        let cases = [(0.001, "$0.0010"), (0.5, "$0.50"), (1.5, "$1.5000")];
        for (cost, expected) in cases {
            assert_eq!(format_cost(cost), expected);
        }
    }

    #[test]
    fn test_cost_summary() {
        let usage = usage_of(1_000_000, 500_000, 100_000, 200_000);
        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);

        assert_close(summary.input_cost, 3.0);
        assert_close(summary.output_cost, 7.5);
        assert_close(summary.cache_write_cost, 0.375);
        assert_close(summary.cache_read_cost, 0.06);
    }
}