1use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ModelCosts {
10 pub input_tokens: f64,
12 pub output_tokens: f64,
14 pub prompt_cache_write_tokens: f64,
16 pub prompt_cache_read_tokens: f64,
18 pub web_search_requests: f64,
20}
21
22impl ModelCosts {
23 pub fn input_cost(&self, tokens: u32) -> f64 {
25 (tokens as f64 / 1_000_000.0) * self.input_tokens
26 }
27
28 pub fn output_cost(&self, tokens: u32) -> f64 {
30 (tokens as f64 / 1_000_000.0) * self.output_tokens
31 }
32
33 pub fn cache_write_cost(&self, tokens: u32) -> f64 {
35 (tokens as f64 / 1_000_000.0) * self.prompt_cache_write_tokens
36 }
37
38 pub fn cache_read_cost(&self, tokens: u32) -> f64 {
40 (tokens as f64 / 1_000_000.0) * self.prompt_cache_read_tokens
41 }
42
43 pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
45 self.input_cost(usage.input_tokens)
46 + self.output_cost(usage.output_tokens)
47 + self.cache_write_cost(usage.prompt_cache_write_tokens)
48 + self.cache_read_cost(usage.prompt_cache_read_tokens)
49 }
50}
51
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct TokenUsage {
55 pub input_tokens: u32,
56 pub output_tokens: u32,
57 #[serde(rename = "promptCacheWriteTokens")]
58 pub prompt_cache_write_tokens: u32,
59 #[serde(rename = "promptCacheReadTokens")]
60 pub prompt_cache_read_tokens: u32,
61}
62
63impl TokenUsage {
64 pub fn total(&self) -> u32 {
66 self.input_tokens
67 + self.output_tokens
68 + self.prompt_cache_write_tokens
69 + self.prompt_cache_read_tokens
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelInfo {
76 pub id: String,
78 pub name: String,
80 pub description: String,
82 pub context_window: u32,
84}
85
86pub const COST_TIER_3_15: ModelCosts = ModelCosts {
90 input_tokens: 3.0,
91 output_tokens: 15.0,
92 prompt_cache_write_tokens: 3.75,
93 prompt_cache_read_tokens: 0.3,
94 web_search_requests: 0.01,
95};
96
97pub const COST_TIER_15_75: ModelCosts = ModelCosts {
99 input_tokens: 15.0,
100 output_tokens: 75.0,
101 prompt_cache_write_tokens: 18.75,
102 prompt_cache_read_tokens: 1.5,
103 web_search_requests: 0.01,
104};
105
106pub const COST_TIER_5_25: ModelCosts = ModelCosts {
108 input_tokens: 5.0,
109 output_tokens: 25.0,
110 prompt_cache_write_tokens: 6.25,
111 prompt_cache_read_tokens: 0.5,
112 web_search_requests: 0.01,
113};
114
115pub const COST_TIER_30_150: ModelCosts = ModelCosts {
117 input_tokens: 30.0,
118 output_tokens: 150.0,
119 prompt_cache_write_tokens: 37.5,
120 prompt_cache_read_tokens: 3.0,
121 web_search_requests: 0.01,
122};
123
124pub const COST_HAIKU_35: ModelCosts = ModelCosts {
126 input_tokens: 0.8,
127 output_tokens: 4.0,
128 prompt_cache_write_tokens: 1.0,
129 prompt_cache_read_tokens: 0.08,
130 web_search_requests: 0.01,
131};
132
133pub const COST_HAIKU_45: ModelCosts = ModelCosts {
135 input_tokens: 1.0,
136 output_tokens: 5.0,
137 prompt_cache_write_tokens: 1.25,
138 prompt_cache_read_tokens: 0.1,
139 web_search_requests: 0.01,
140};
141
142pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;
144
145pub struct ModelCostRegistry {
147 costs: std::collections::HashMap<String, ModelCosts>,
148}
149
150impl ModelCostRegistry {
151 pub fn new() -> Self {
152 let mut costs = std::collections::HashMap::new();
153
154 costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
156 costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
157 costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
158 costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
159 costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
160 costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
161 costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
162 costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
163 costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
164 costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);
165
166 costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
168 costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);
169
170 costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
172 costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
173 costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
174 costs.insert("gpt-4".to_string(), COST_TIER_30_60);
175
176 Self { costs }
177 }
178
179 pub fn get(&self, model: &str) -> &ModelCosts {
181 if let Some(cost) = self.costs.get(model) {
183 return cost;
184 }
185
186 for (key, cost) in &self.costs {
188 if model.starts_with(key) || key.starts_with(model) {
189 return cost;
190 }
191 }
192
193 &COST_DEFAULT
194 }
195
196 pub fn register(&mut self, model: &str, costs: ModelCosts) {
198 self.costs.insert(model.to_string(), costs);
199 }
200}
201
202impl Default for ModelCostRegistry {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208pub const COST_TIER_30_60: ModelCosts = ModelCosts {
210 input_tokens: 30.0,
211 output_tokens: 60.0,
212 prompt_cache_write_tokens: 30.0,
213 prompt_cache_read_tokens: 10.0,
214 web_search_requests: 0.01,
215};
216
217pub const COST_TIER_10_30: ModelCosts = ModelCosts {
219 input_tokens: 10.0,
220 output_tokens: 30.0,
221 prompt_cache_write_tokens: 10.0,
222 prompt_cache_read_tokens: 3.0,
223 web_search_requests: 0.01,
224};
225
226pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
228 let registry = ModelCostRegistry::new();
229 let costs = registry.get(model);
230 costs.total_cost(usage)
231}
232
233pub fn get_available_models() -> Vec<ModelInfo> {
235 vec![
236 ModelInfo {
237 id: "claude-opus-4-6".to_string(),
238 name: "Opus".to_string(),
239 description: "Most capable for complex work".to_string(),
240 context_window: 200_000,
241 },
242 ModelInfo {
243 id: "claude-sonnet-4-6".to_string(),
244 name: "Sonnet".to_string(),
245 description: "Best for everyday tasks".to_string(),
246 context_window: 200_000,
247 },
248 ModelInfo {
249 id: "claude-sonnet-4-6-20250520".to_string(),
250 name: "Sonnet 4.6".to_string(),
251 description: "Latest Sonnet model".to_string(),
252 context_window: 200_000,
253 },
254 ModelInfo {
255 id: "claude-haiku-4-5".to_string(),
256 name: "Haiku".to_string(),
257 description: "Fastest for quick answers".to_string(),
258 context_window: 200_000,
259 },
260 ModelInfo {
261 id: "claude-opus-4-5".to_string(),
262 name: "Opus 4.5".to_string(),
263 description: "Previous Opus version".to_string(),
264 context_window: 200_000,
265 },
266 ModelInfo {
267 id: "claude-sonnet-4-5".to_string(),
268 name: "Sonnet 4.5".to_string(),
269 description: "Previous Sonnet version".to_string(),
270 context_window: 200_000,
271 },
272 ModelInfo {
273 id: "MiniMaxAI/MiniMax-M2.5".to_string(),
274 name: "MiniMax M2.5".to_string(),
275 description: "Fast and capable (default)".to_string(),
276 context_window: 1_000_000,
277 },
278 ]
279}
280
/// Formats `cost` as a dollar string.
///
/// Values in `[0.01, 1.0)` are printed with two decimals; every other value
/// (below 0.01, or 1.0 and above) is printed with four.
pub fn format_cost(cost: f64) -> String {
    let decimals: usize = if (0.01..1.0).contains(&cost) { 2 } else { 4 };
    format!("${:.prec$}", cost, prec = decimals)
}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct CostSummary {
295 pub input_cost: f64,
296 pub output_cost: f64,
297 pub cache_write_cost: f64,
298 pub cache_read_cost: f64,
299 pub total_cost: f64,
300}
301
302impl CostSummary {
303 pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
304 let registry = ModelCostRegistry::new();
305 let costs = registry.get(model);
306
307 Self {
308 input_cost: costs.input_cost(usage.input_tokens),
309 output_cost: costs.output_cost(usage.output_tokens),
310 cache_write_cost: costs.cache_write_cost(usage.prompt_cache_write_tokens),
311 cache_read_cost: costs.cache_read_cost(usage.prompt_cache_read_tokens),
312 total_cost: costs.total_cost(usage),
313 }
314 }
315}
316
317use crate::utils::config::{
318 get_current_project_config, save_current_project_config, ModelUsage as ConfigModelUsage,
319};
320
321#[derive(Debug, Clone, Default)]
323pub struct StoredCostState {
324 pub total_cost_usd: f64,
325 pub total_api_duration: u64,
326 pub total_api_duration_without_retries: u64,
327 pub total_tool_duration: u64,
328 pub total_lines_added: u32,
329 pub total_lines_removed: u32,
330 pub last_duration: Option<u64>,
331 pub model_usage: Option<std::collections::HashMap<String, ConfigModelUsage>>,
332}
333
334pub fn get_stored_session_costs(session_id: &str) -> Option<StoredCostState> {
338 let project_config = get_current_project_config();
339
340 if project_config.last_session_id.as_deref() != Some(session_id) {
342 return None;
343 }
344
345 Some(StoredCostState {
346 total_cost_usd: project_config.last_cost.unwrap_or(0.0),
347 total_api_duration: project_config.last_api_duration.unwrap_or(0),
348 total_api_duration_without_retries: project_config
349 .last_api_duration_without_retries
350 .unwrap_or(0),
351 total_tool_duration: project_config.last_tool_duration.unwrap_or(0),
352 total_lines_added: project_config.last_lines_added.unwrap_or(0),
353 total_lines_removed: project_config.last_lines_removed.unwrap_or(0),
354 last_duration: project_config.last_duration,
355 model_usage: project_config.last_model_usage,
356 })
357}
358
/// Attempts to restore the cost state of a previous session.
///
/// Currently a stub: the session id is ignored and the function always
/// returns `false`, i.e. nothing is ever restored.
pub fn restore_cost_state_for_session(_session_id: &str) -> bool {
    let restored = false;
    restored
}
367
368pub fn save_current_session_costs() {
371 let cost_state = get_global_cost_state();
372
373 let model_usage_map: Option<std::collections::HashMap<String, ConfigModelUsage>> =
374 if cost_state.model_usage.is_empty() {
375 None
376 } else {
377 let mut map = std::collections::HashMap::new();
378 for (model, usage) in &cost_state.model_usage {
379 map.insert(
380 model.clone(),
381 ConfigModelUsage {
382 input_tokens: usage.input_tokens,
383 output_tokens: usage.output_tokens,
384 cache_read_input_tokens: usage.cache_read_input_tokens,
385 cache_creation_input_tokens: usage.cache_creation_input_tokens,
386 web_search_requests: usage.web_search_requests,
387 cost_usd: usage.cost_usd,
388 },
389 );
390 }
391 Some(map)
392 };
393
394 let mut config = get_current_project_config();
395 config.last_cost = Some(cost_state.total_cost_usd);
396 config.last_api_duration = Some(cost_state.total_api_duration);
397 config.last_api_duration_without_retries = Some(cost_state.total_api_duration_without_retries);
398 config.last_tool_duration = Some(cost_state.total_tool_duration);
399 config.last_duration = cost_state.last_duration;
400 config.last_lines_added = Some(cost_state.total_lines_added);
401 config.last_lines_removed = Some(cost_state.total_lines_removed);
402 config.last_total_input_tokens = Some(cost_state.total_input_tokens);
403 config.last_total_output_tokens = Some(cost_state.total_output_tokens);
404 config.last_total_cache_creation_input_tokens =
405 Some(cost_state.total_cache_creation_input_tokens);
406 config.last_total_cache_read_input_tokens = Some(cost_state.total_cache_read_input_tokens);
407 config.last_total_web_search_requests = Some(cost_state.total_web_search_requests);
408 config.last_model_usage = model_usage_map;
409 config.last_session_id = Some(cost_state.session_id.clone());
410
411 let _ = save_current_project_config(config);
412}
413
/// Formats `cost` as a dollar amount for the usage summary.
///
/// Costs above 0.5 are rounded to the nearest cent and shown with two
/// decimals. Smaller costs are shown with exactly `max_decimal_places`
/// decimals. The precision used to be `max_decimal_places + 2`, which printed
/// two more digits than the parameter promises.
fn format_cost_for_display(cost: f64, max_decimal_places: usize) -> String {
    if cost > 0.5 {
        format!("${:.2}", (cost * 100.0).round() / 100.0)
    } else {
        format!("${:.precision$}", cost, precision = max_decimal_places)
    }
}
422
/// Renders `n` in decimal with a comma between each group of three digits,
/// e.g. `1234567` becomes `1,234,567`.
fn format_number(n: u32) -> String {
    let digits = n.to_string();

    // The leading group holds whatever is left after full groups of three;
    // when the length divides evenly it is itself a full group.
    let lead_len = match digits.len() % 3 {
        0 => 3,
        remainder => remainder,
    };
    let (lead, mut remaining) = digits.split_at(lead_len);

    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(lead);
    while !remaining.is_empty() {
        let (group, tail) = remaining.split_at(3);
        grouped.push(',');
        grouped.push_str(group);
        remaining = tail;
    }
    grouped
}
436
/// Accumulated usage counters for one model (or one group of models).
#[derive(Clone, Debug, Default)]
pub struct ModelUsageInfo {
    /// Input tokens consumed.
    pub input_tokens: u32,
    /// Output tokens produced.
    pub output_tokens: u32,
    /// Input tokens served from the prompt cache.
    pub cache_read_input_tokens: u32,
    /// Input tokens written to the prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Number of web search requests.
    pub web_search_requests: u32,
    /// Accumulated cost in USD.
    pub cost_usd: f64,
    /// Context window of the model (left at zero by the code in this file).
    pub context_window: u32,
    /// Maximum output tokens of the model (left at zero by the code in this file).
    pub max_output_tokens: u32,
}
449
/// Maps a model id to the short family name used in the usage report.
///
/// The first family whose marker occurs in `model` (case-sensitive substring
/// match, checked in table order) wins; ids that match no family are
/// returned unchanged.
fn get_canonical_name(model: &str) -> String {
    // (substring to look for, display name), in priority order.
    const FAMILIES: [(&str, &str); 5] = [
        ("opus", "Opus"),
        ("sonnet", "Sonnet"),
        ("haiku", "Haiku"),
        ("MiniMax", "MiniMax"),
        ("gpt", "GPT"),
    ];

    FAMILIES
        .iter()
        .find(|(marker, _)| model.contains(*marker))
        .map(|(_, display)| (*display).to_string())
        .unwrap_or_else(|| model.to_string())
}
467
468pub fn format_model_usage() -> String {
470 let cost_state = get_global_cost_state();
471
472 if cost_state.model_usage.is_empty() {
473 return "Usage: 0 input, 0 output, 0 cache read, 0 cache write".to_string();
474 }
475
476 let mut usage_by_short_name: std::collections::HashMap<String, ModelUsageInfo> =
478 std::collections::HashMap::new();
479 for (model, usage) in &cost_state.model_usage {
480 let short_name = get_canonical_name(model);
481 let entry = usage_by_short_name
482 .entry(short_name)
483 .or_insert_with(|| ModelUsageInfo::default());
484 entry.input_tokens += usage.input_tokens;
485 entry.output_tokens += usage.output_tokens;
486 entry.cache_read_input_tokens += usage.cache_read_input_tokens;
487 entry.cache_creation_input_tokens += usage.cache_creation_input_tokens;
488 entry.web_search_requests += usage.web_search_requests;
489 entry.cost_usd += usage.cost_usd;
490 }
491
492 let mut result = "Usage by model:".to_string();
493 for (short_name, usage) in &usage_by_short_name {
494 let usage_string = format!(
495 " {} input, {} output, {} cache read, {} cache write{}{} (${})",
496 format_number(usage.input_tokens),
497 format_number(usage.output_tokens),
498 format_number(usage.cache_read_input_tokens),
499 format_number(usage.cache_creation_input_tokens),
500 if usage.web_search_requests > 0 {
501 format!(", {} web search", format_number(usage.web_search_requests))
502 } else {
503 String::new()
504 },
505 if cost_state.has_unknown_model_cost {
506 " (costs may be inaccurate due to usage of unknown models)".to_string()
507 } else {
508 String::new()
509 },
510 format_cost_for_display(usage.cost_usd, 4)
511 );
512 result.push('\n');
513 let padded_name = format!("{:<21}", format!("{}:", short_name));
515 result.push_str(&padded_name);
516 result.push_str(&usage_string.replace(" ", " "));
517 }
518 result
519}
520
/// Formats a millisecond duration using the largest applicable unit.
///
/// Output is `"{h}h {m}m {s}s"`, `"{m}m {s}s"`, `"{s}s"`, or — below one
/// second — `"{ms}ms"`. Sub-second remainders are dropped once the duration
/// reaches a full second.
fn format_duration(ms: u64) -> String {
    let total_seconds = ms / 1000;
    let hours = total_seconds / 3600;
    let minutes = (total_seconds / 60) % 60;
    let seconds = total_seconds % 60;

    match (hours, minutes, seconds) {
        (0, 0, 0) => format!("{}ms", ms),
        (0, 0, s) => format!("{}s", s),
        (0, m, s) => format!("{}m {}s", m, s),
        (h, m, s) => format!("{}h {}m {}s", h, m, s),
    }
}
537
538pub fn format_total_cost() -> String {
540 let cost_state = get_global_cost_state();
541
542 let cost_display = format!("Total cost: ${:.4}", cost_state.total_cost_usd);
543
544 let model_usage_display = format_model_usage();
545
546 format!(
547 "Total cost: {}\nTotal duration (API): {}\nTotal duration (wall): {}\nTotal code changes: {} {} added, {} {}\n{}",
548 cost_display,
549 format_duration(cost_state.total_api_duration),
550 format_duration(cost_state.last_duration.unwrap_or(0)),
551 cost_state.total_lines_added,
552 if cost_state.total_lines_added == 1 { "line" } else { "lines" },
553 cost_state.total_lines_removed,
554 if cost_state.total_lines_removed == 1 { "line" } else { "lines" },
555 model_usage_display
556 )
557}
558
559#[derive(Debug, Clone, Default)]
561pub struct GlobalCostState {
562 pub total_cost_usd: f64,
563 pub total_api_duration: u64,
564 pub total_api_duration_without_retries: u64,
565 pub total_tool_duration: u64,
566 pub total_lines_added: u32,
567 pub total_lines_removed: u32,
568 pub last_duration: Option<u64>,
569 pub total_input_tokens: u32,
570 pub total_output_tokens: u32,
571 pub total_cache_creation_input_tokens: u32,
572 pub total_cache_read_input_tokens: u32,
573 pub total_web_search_requests: u32,
574 pub model_usage: std::collections::HashMap<String, ModelUsageInfo>,
575 pub has_unknown_model_cost: bool,
576 pub session_id: String,
577}
578
579fn get_global_cost_state() -> GlobalCostState {
581 GlobalCostState::default()
584}
585
586pub fn add_to_total_model_usage(
588 cost: f64,
589 input_tokens: u32,
590 output_tokens: u32,
591 cache_read_input_tokens: u32,
592 cache_creation_input_tokens: u32,
593 web_search_requests: u32,
594 model: &str,
595) -> ModelUsageInfo {
596 let mut cost_state = get_global_cost_state();
597
598 let model_usage = cost_state
599 .model_usage
600 .entry(model.to_string())
601 .or_insert_with(|| ModelUsageInfo {
602 input_tokens: 0,
603 output_tokens: 0,
604 cache_read_input_tokens: 0,
605 cache_creation_input_tokens: 0,
606 web_search_requests: 0,
607 cost_usd: 0.0,
608 context_window: 0,
609 max_output_tokens: 0,
610 });
611
612 model_usage.input_tokens += input_tokens;
613 model_usage.output_tokens += output_tokens;
614 model_usage.cache_read_input_tokens += cache_read_input_tokens;
615 model_usage.cache_creation_input_tokens += cache_creation_input_tokens;
616 model_usage.web_search_requests += web_search_requests;
617 model_usage.cost_usd += cost;
618
619 ModelUsageInfo {
620 input_tokens: model_usage.input_tokens,
621 output_tokens: model_usage.output_tokens,
622 cache_read_input_tokens: model_usage.cache_read_input_tokens,
623 cache_creation_input_tokens: model_usage.cache_creation_input_tokens,
624 web_search_requests: model_usage.web_search_requests,
625 cost_usd: model_usage.cost_usd,
626 context_window: model_usage.context_window,
627 max_output_tokens: model_usage.max_output_tokens,
628 }
629}
630
631pub fn add_to_total_session_cost(
633 cost: f64,
634 input_tokens: u32,
635 output_tokens: u32,
636 cache_read_input_tokens: u32,
637 cache_creation_input_tokens: u32,
638 web_search_requests: u32,
639 model: &str,
640) -> f64 {
641 add_to_total_model_usage(
642 cost,
643 input_tokens,
644 output_tokens,
645 cache_read_input_tokens,
646 cache_creation_input_tokens,
647 web_search_requests,
648 model,
649 );
650
651 cost
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point cost comparisons. The expected values are
    /// small multiples of the pricing rates, so anything beyond rounding
    /// noise indicates a real error.
    const EPSILON: f64 = 1e-9;

    /// Asserts that `actual` is within [`EPSILON`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_model_costs_input() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.input_cost(1_000_000), 3.0);
        assert_eq!(costs.input_cost(500_000), 1.5);
    }

    #[test]
    fn test_model_costs_output() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.output_cost(1_000_000), 15.0);
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            prompt_cache_write_tokens: 25,
            prompt_cache_read_tokens: 75,
        };
        assert_eq!(usage.total(), 250);
    }

    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();

        let costs = registry.get("claude-sonnet-4-6");
        assert_eq!(costs.input_tokens, 3.0);

        let costs = registry.get("claude-haiku-4-5");
        assert_eq!(costs.input_tokens, 1.0);
    }

    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        let costs = registry.get("unknown-model");
        assert_eq!(costs.input_tokens, COST_DEFAULT.input_tokens);
    }

    #[test]
    fn test_calculate_cost() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 0,
            prompt_cache_read_tokens: 0,
        };

        // 1M input at 3.0 plus 0.5M output at 15.0.
        assert_close(calculate_cost("claude-sonnet-4-6", &usage), 10.5);
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(0.001), "$0.0010");
        assert_eq!(format_cost(0.5), "$0.50");
        assert_eq!(format_cost(1.5), "$1.5000");
    }

    #[test]
    fn test_cost_summary() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 100_000,
            prompt_cache_read_tokens: 200_000,
        };

        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);

        assert_close(summary.input_cost, 3.0);
        assert_close(summary.output_cost, 7.5);
        assert_close(summary.cache_write_cost, 0.375);
        assert_close(summary.cache_read_cost, 0.06);
        // The total must equal the sum of the four categories.
        assert_close(summary.total_cost, 10.935);
    }
}