Skip to main content

ai_agent/services/
model_cost.rs

//! Model cost calculation.
//!
//! Provides cost estimation for different AI models, similar to Claude Code's cost tracking.
4
5use serde::{Deserialize, Serialize};
6
7/// Model cost configuration (per million tokens)
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ModelCosts {
10    /// Input tokens cost per million
11    pub input_tokens: f64,
12    /// Output tokens cost per million
13    pub output_tokens: f64,
14    /// Prompt cache write tokens cost per million
15    pub prompt_cache_write_tokens: f64,
16    /// Prompt cache read tokens cost per million
17    pub prompt_cache_read_tokens: f64,
18    /// Web search requests cost per search
19    pub web_search_requests: f64,
20}
21
22impl ModelCosts {
23    /// Calculate cost for input tokens
24    pub fn input_cost(&self, tokens: u32) -> f64 {
25        (tokens as f64 / 1_000_000.0) * self.input_tokens
26    }
27
28    /// Calculate cost for output tokens
29    pub fn output_cost(&self, tokens: u32) -> f64 {
30        (tokens as f64 / 1_000_000.0) * self.output_tokens
31    }
32
33    /// Calculate cost for cache write tokens
34    pub fn cache_write_cost(&self, tokens: u32) -> f64 {
35        (tokens as f64 / 1_000_000.0) * self.prompt_cache_write_tokens
36    }
37
38    /// Calculate cost for cache read tokens
39    pub fn cache_read_cost(&self, tokens: u32) -> f64 {
40        (tokens as f64 / 1_000_000.0) * self.prompt_cache_read_tokens
41    }
42
43    /// Calculate total cost for a usage record
44    pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
45        self.input_cost(usage.input_tokens)
46            + self.output_cost(usage.output_tokens)
47            + self.cache_write_cost(usage.prompt_cache_write_tokens)
48            + self.cache_read_cost(usage.prompt_cache_read_tokens)
49    }
50}
51
52/// Token usage from API response
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct TokenUsage {
55    pub input_tokens: u32,
56    pub output_tokens: u32,
57    #[serde(rename = "promptCacheWriteTokens")]
58    pub prompt_cache_write_tokens: u32,
59    #[serde(rename = "promptCacheReadTokens")]
60    pub prompt_cache_read_tokens: u32,
61}
62
63impl TokenUsage {
64    /// Total tokens used
65    pub fn total(&self) -> u32 {
66        self.input_tokens
67            + self.output_tokens
68            + self.prompt_cache_write_tokens
69            + self.prompt_cache_read_tokens
70    }
71}
72
73/// Model information for listing available models
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelInfo {
76    /// Model identifier
77    pub id: String,
78    /// Display name
79    pub name: String,
80    /// Description
81    pub description: String,
82    /// Context window size in tokens
83    pub context_window: u32,
84}
85
86/// Common cost tiers
87
88/// Standard pricing: $3 input / $15 output per M tokens
89pub const COST_TIER_3_15: ModelCosts = ModelCosts {
90    input_tokens: 3.0,
91    output_tokens: 15.0,
92    prompt_cache_write_tokens: 3.75,
93    prompt_cache_read_tokens: 0.3,
94    web_search_requests: 0.01,
95};
96
97/// Opus pricing: $15 input / $75 output per M tokens
98pub const COST_TIER_15_75: ModelCosts = ModelCosts {
99    input_tokens: 15.0,
100    output_tokens: 75.0,
101    prompt_cache_write_tokens: 18.75,
102    prompt_cache_read_tokens: 1.5,
103    web_search_requests: 0.01,
104};
105
106/// Mid-tier pricing: $5 input / $25 output per M tokens
107pub const COST_TIER_5_25: ModelCosts = ModelCosts {
108    input_tokens: 5.0,
109    output_tokens: 25.0,
110    prompt_cache_write_tokens: 6.25,
111    prompt_cache_read_tokens: 0.5,
112    web_search_requests: 0.01,
113};
114
115/// Fast mode pricing: $30 input / $150 output per M tokens
116pub const COST_TIER_30_150: ModelCosts = ModelCosts {
117    input_tokens: 30.0,
118    output_tokens: 150.0,
119    prompt_cache_write_tokens: 37.5,
120    prompt_cache_read_tokens: 3.0,
121    web_search_requests: 0.01,
122};
123
124/// Haiku 3.5 pricing: $0.80 input / $4 output per M tokens
125pub const COST_HAIKU_35: ModelCosts = ModelCosts {
126    input_tokens: 0.8,
127    output_tokens: 4.0,
128    prompt_cache_write_tokens: 1.0,
129    prompt_cache_read_tokens: 0.08,
130    web_search_requests: 0.01,
131};
132
133/// Haiku 4.5 pricing: $1 input / $5 output per M tokens
134pub const COST_HAIKU_45: ModelCosts = ModelCosts {
135    input_tokens: 1.0,
136    output_tokens: 5.0,
137    prompt_cache_write_tokens: 1.25,
138    prompt_cache_read_tokens: 0.1,
139    web_search_requests: 0.01,
140};
141
142/// Default cost for unknown models
143pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;
144
145/// Model cost registry
146pub struct ModelCostRegistry {
147    costs: std::collections::HashMap<String, ModelCosts>,
148}
149
150impl ModelCostRegistry {
151    pub fn new() -> Self {
152        let mut costs = std::collections::HashMap::new();
153
154        // Anthropic models
155        costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
156        costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
157        costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
158        costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
159        costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
160        costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
161        costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
162        costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
163        costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
164        costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);
165
166        // MiniMax models
167        costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
168        costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);
169
170        // OpenAI models (for compatibility)
171        costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
172        costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
173        costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
174        costs.insert("gpt-4".to_string(), COST_TIER_30_60);
175
176        Self { costs }
177    }
178
179    /// Get cost for a model
180    pub fn get(&self, model: &str) -> &ModelCosts {
181        // Try exact match first
182        if let Some(cost) = self.costs.get(model) {
183            return cost;
184        }
185
186        // Try prefix match
187        for (key, cost) in &self.costs {
188            if model.starts_with(key) || key.starts_with(model) {
189                return cost;
190            }
191        }
192
193        &COST_DEFAULT
194    }
195
196    /// Register a custom model cost
197    pub fn register(&mut self, model: &str, costs: ModelCosts) {
198        self.costs.insert(model.to_string(), costs);
199    }
200}
201
202impl Default for ModelCostRegistry {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208/// Pricing tier for GPT-4: $30 input / $60 output per M tokens
209pub const COST_TIER_30_60: ModelCosts = ModelCosts {
210    input_tokens: 30.0,
211    output_tokens: 60.0,
212    prompt_cache_write_tokens: 30.0,
213    prompt_cache_read_tokens: 10.0,
214    web_search_requests: 0.01,
215};
216
217/// Pricing tier for GPT-4 Turbo: $10 input / $30 output per M tokens
218pub const COST_TIER_10_30: ModelCosts = ModelCosts {
219    input_tokens: 10.0,
220    output_tokens: 30.0,
221    prompt_cache_write_tokens: 10.0,
222    prompt_cache_read_tokens: 3.0,
223    web_search_requests: 0.01,
224};
225
226/// Calculate cost from model name and usage
227pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
228    let registry = ModelCostRegistry::new();
229    let costs = registry.get(model);
230    costs.total_cost(usage)
231}
232
233/// Calculate cost from raw token counts (avoids TokenUsage struct conversion)
234pub fn calculate_cost_for_tokens(
235    model: &str,
236    input_tokens: u32,
237    output_tokens: u32,
238    cache_read_input_tokens: u32,
239    cache_creation_input_tokens: u32,
240) -> f64 {
241    let registry = ModelCostRegistry::new();
242    let costs = registry.get(model);
243    costs.input_cost(input_tokens)
244        + costs.output_cost(output_tokens)
245        + costs.cache_read_cost(cache_read_input_tokens)
246        + costs.cache_write_cost(cache_creation_input_tokens)
247}
248
249/// Get list of available models with their display names and descriptions
250pub fn get_available_models() -> Vec<ModelInfo> {
251    vec![
252        ModelInfo {
253            id: "claude-opus-4-6".to_string(),
254            name: "Opus".to_string(),
255            description: "Most capable for complex work".to_string(),
256            context_window: 200_000,
257        },
258        ModelInfo {
259            id: "claude-sonnet-4-6".to_string(),
260            name: "Sonnet".to_string(),
261            description: "Best for everyday tasks".to_string(),
262            context_window: 200_000,
263        },
264        ModelInfo {
265            id: "claude-sonnet-4-6-20250520".to_string(),
266            name: "Sonnet 4.6".to_string(),
267            description: "Latest Sonnet model".to_string(),
268            context_window: 200_000,
269        },
270        ModelInfo {
271            id: "claude-haiku-4-5".to_string(),
272            name: "Haiku".to_string(),
273            description: "Fastest for quick answers".to_string(),
274            context_window: 200_000,
275        },
276        ModelInfo {
277            id: "claude-opus-4-5".to_string(),
278            name: "Opus 4.5".to_string(),
279            description: "Previous Opus version".to_string(),
280            context_window: 200_000,
281        },
282        ModelInfo {
283            id: "claude-sonnet-4-5".to_string(),
284            name: "Sonnet 4.5".to_string(),
285            description: "Previous Sonnet version".to_string(),
286            context_window: 200_000,
287        },
288        ModelInfo {
289            id: "MiniMaxAI/MiniMax-M2.5".to_string(),
290            name: "MiniMax M2.5".to_string(),
291            description: "Fast and capable (default)".to_string(),
292            context_window: 1_000_000,
293        },
294    ]
295}
296
/// Formats a dollar amount with a leading `$`.
///
/// Amounts in `[0.01, 1.0)` get two decimal places. Everything else gets four: sub-cent
/// amounts, amounts of a dollar or more, negatives and NaN.
pub fn format_cost(cost: f64) -> String {
    let decimals: usize = if (0.01..1.0).contains(&cost) { 2 } else { 4 };
    format!("${:.*}", decimals, cost)
}
307
308/// Cost summary for display
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct CostSummary {
311    pub input_cost: f64,
312    pub output_cost: f64,
313    pub cache_write_cost: f64,
314    pub cache_read_cost: f64,
315    pub total_cost: f64,
316}
317
318impl CostSummary {
319    pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
320        let registry = ModelCostRegistry::new();
321        let costs = registry.get(model);
322
323        Self {
324            input_cost: costs.input_cost(usage.input_tokens),
325            output_cost: costs.output_cost(usage.output_tokens),
326            cache_write_cost: costs.cache_write_cost(usage.prompt_cache_write_tokens),
327            cache_read_cost: costs.cache_read_cost(usage.prompt_cache_read_tokens),
328            total_cost: costs.total_cost(usage),
329        }
330    }
331}
332
333use crate::utils::config::{
334    ModelUsage as ConfigModelUsage, get_current_project_config, save_current_project_config,
335};
336
337/// Stored cost state from project config
338#[derive(Debug, Clone, Default)]
339pub struct StoredCostState {
340    pub total_cost_usd: f64,
341    pub total_api_duration: u64,
342    pub total_api_duration_without_retries: u64,
343    pub total_tool_duration: u64,
344    pub total_lines_added: u32,
345    pub total_lines_removed: u32,
346    pub last_duration: Option<u64>,
347    pub model_usage: Option<std::collections::HashMap<String, ConfigModelUsage>>,
348}
349
350/// Get stored cost state from project config for a specific session.
351/// Returns the cost data if the session ID matches, or None otherwise.
352/// Use this to read costs BEFORE overwriting the config with save_current_session_costs().
353pub fn get_stored_session_costs(session_id: &str) -> Option<StoredCostState> {
354    let project_config = get_current_project_config();
355
356    // Only return costs if this is the same session that was last saved
357    if project_config.last_session_id.as_deref() != Some(session_id) {
358        return None;
359    }
360
361    Some(StoredCostState {
362        total_cost_usd: project_config.last_cost.unwrap_or(0.0),
363        total_api_duration: project_config.last_api_duration.unwrap_or(0),
364        total_api_duration_without_retries: project_config
365            .last_api_duration_without_retries
366            .unwrap_or(0),
367        total_tool_duration: project_config.last_tool_duration.unwrap_or(0),
368        total_lines_added: project_config.last_lines_added.unwrap_or(0),
369        total_lines_removed: project_config.last_lines_removed.unwrap_or(0),
370        last_duration: project_config.last_duration,
371        model_usage: project_config.last_model_usage,
372    })
373}
374
375/// Restores cost state from project config when resuming a session.
376/// Only restores if the session ID matches the last saved session.
377/// Returns true if cost state was restored, false otherwise.
378pub fn restore_cost_state_for_session(session_id: &str) -> bool {
379    let stored = get_stored_session_costs(session_id);
380    let Some(stored) = stored else {
381        return false;
382    };
383
384    update_global_cost_state(|state| {
385        state.total_cost_usd = stored.total_cost_usd;
386        state.total_api_duration = stored.total_api_duration;
387        state.total_api_duration_without_retries = stored.total_api_duration_without_retries;
388        state.total_tool_duration = stored.total_tool_duration;
389        state.total_lines_added = stored.total_lines_added;
390        state.total_lines_removed = stored.total_lines_removed;
391        state.last_duration = stored.last_duration;
392        state.model_usage = stored
393            .model_usage
394            .map(|mu| {
395                mu.into_iter()
396                    .map(|(k, v)| {
397                        (
398                            k,
399                            ModelUsageInfo {
400                                input_tokens: v.input_tokens,
401                                output_tokens: v.output_tokens,
402                                cache_read_input_tokens: v.cache_read_input_tokens,
403                                cache_creation_input_tokens: v.cache_creation_input_tokens,
404                                web_search_requests: v.web_search_requests,
405                                cost_usd: v.cost_usd,
406                                context_window: 0,
407                                max_output_tokens: 0,
408                            },
409                        )
410                    })
411                    .collect()
412            })
413            .unwrap_or_default();
414        state.session_id = session_id.to_string();
415    });
416
417    true
418}
419
420/// Saves the current session's costs to project config.
421/// Call this before switching sessions to avoid losing accumulated costs.
422pub fn save_current_session_costs() {
423    let cost_state = get_global_cost_state();
424
425    let model_usage_map: Option<std::collections::HashMap<String, ConfigModelUsage>> =
426        if cost_state.model_usage.is_empty() {
427            None
428        } else {
429            let mut map = std::collections::HashMap::new();
430            for (model, usage) in &cost_state.model_usage {
431                map.insert(
432                    model.clone(),
433                    ConfigModelUsage {
434                        input_tokens: usage.input_tokens,
435                        output_tokens: usage.output_tokens,
436                        cache_read_input_tokens: usage.cache_read_input_tokens,
437                        cache_creation_input_tokens: usage.cache_creation_input_tokens,
438                        web_search_requests: usage.web_search_requests,
439                        cost_usd: usage.cost_usd,
440                    },
441                );
442            }
443            Some(map)
444        };
445
446    let mut config = get_current_project_config();
447    config.last_cost = Some(cost_state.total_cost_usd);
448    config.last_api_duration = Some(cost_state.total_api_duration);
449    config.last_api_duration_without_retries = Some(cost_state.total_api_duration_without_retries);
450    config.last_tool_duration = Some(cost_state.total_tool_duration);
451    config.last_duration = cost_state.last_duration;
452    config.last_lines_added = Some(cost_state.total_lines_added);
453    config.last_lines_removed = Some(cost_state.total_lines_removed);
454    config.last_total_input_tokens = Some(cost_state.total_input_tokens);
455    config.last_total_output_tokens = Some(cost_state.total_output_tokens);
456    config.last_total_cache_creation_input_tokens =
457        Some(cost_state.total_cache_creation_input_tokens);
458    config.last_total_cache_read_input_tokens = Some(cost_state.total_cache_read_input_tokens);
459    config.last_total_web_search_requests = Some(cost_state.total_web_search_requests);
460    config.last_model_usage = model_usage_map;
461    config.last_session_id = Some(cost_state.session_id.clone());
462
463    let _ = save_current_project_config(config);
464}
465
/// Formats a dollar amount for the usage report, with a leading `$`.
///
/// - Amounts above $0.50 are rounded to cents and shown with two decimals.
/// - Smaller amounts are shown with exactly `max_decimal_places` decimals, so sub-cent
///   costs stay visible.
fn format_cost_for_display(cost: f64, max_decimal_places: usize) -> String {
    if cost > 0.5 {
        format!("${:.2}", (cost * 100.0).round() / 100.0)
    } else {
        // `{:.*}` takes the precision from the next argument, so exactly
        // `max_decimal_places` digits are printed after the decimal point.
        format!("${:.*}", max_decimal_places, cost)
    }
}
474
/// Renders `n` in decimal with a comma between each group of three digits
/// (e.g. `1234567` becomes `"1,234,567"`).
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    // Group from the right so only the leftmost group can be shorter than three digits.
    let groups: Vec<&str> = digits
        .as_bytes()
        .rchunks(3)
        .rev()
        .map(|group| std::str::from_utf8(group).expect("decimal digits are ASCII"))
        .collect();
    groups.join(",")
}
488
/// Accumulated usage and cost for one model, plus its context-window limits.
#[derive(Debug, Clone, Default)]
pub struct ModelUsageInfo {
    /// Input tokens consumed.
    pub input_tokens: u32,
    /// Output tokens generated.
    pub output_tokens: u32,
    /// Tokens read from the prompt cache.
    pub cache_read_input_tokens: u32,
    /// Tokens written to the prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Number of web search requests.
    pub web_search_requests: u32,
    /// Accumulated cost in dollars.
    pub cost_usd: f64,
    /// Context window size in tokens (0 when unknown).
    pub context_window: u32,
    /// Maximum output tokens (0 when unknown).
    pub max_output_tokens: u32,
}
501
/// Maps a full model identifier to the short family label used in usage reports.
///
/// Families are tried in order (Opus, Sonnet, Haiku, MiniMax, GPT) using case-sensitive
/// substring matching. An identifier matching none of them is returned unchanged.
fn get_canonical_name(model: &str) -> String {
    const FAMILIES: [(&str, &str); 5] = [
        ("opus", "Opus"),
        ("sonnet", "Sonnet"),
        ("haiku", "Haiku"),
        ("MiniMax", "MiniMax"),
        ("gpt", "GPT"),
    ];

    FAMILIES
        .iter()
        .find(|&&(needle, _)| model.contains(needle))
        .map_or_else(|| model.to_string(), |&(_, label)| label.to_string())
}
519
520/// Format model usage for display
521pub fn format_model_usage() -> String {
522    let cost_state = get_global_cost_state();
523
524    if cost_state.model_usage.is_empty() {
525        return "Usage:                 0 input, 0 output, 0 cache read, 0 cache write".to_string();
526    }
527
528    // Accumulate usage by short name
529    let mut usage_by_short_name: std::collections::HashMap<String, ModelUsageInfo> =
530        std::collections::HashMap::new();
531    for (model, usage) in &cost_state.model_usage {
532        let short_name = get_canonical_name(model);
533        let entry = usage_by_short_name
534            .entry(short_name)
535            .or_insert_with(|| ModelUsageInfo::default());
536        entry.input_tokens += usage.input_tokens;
537        entry.output_tokens += usage.output_tokens;
538        entry.cache_read_input_tokens += usage.cache_read_input_tokens;
539        entry.cache_creation_input_tokens += usage.cache_creation_input_tokens;
540        entry.web_search_requests += usage.web_search_requests;
541        entry.cost_usd += usage.cost_usd;
542    }
543
544    let mut result = "Usage by model:".to_string();
545    for (short_name, usage) in &usage_by_short_name {
546        let usage_string = format!(
547            "  {} input, {} output, {} cache read, {} cache write{}{} (${})",
548            format_number(usage.input_tokens),
549            format_number(usage.output_tokens),
550            format_number(usage.cache_read_input_tokens),
551            format_number(usage.cache_creation_input_tokens),
552            if usage.web_search_requests > 0 {
553                format!(", {} web search", format_number(usage.web_search_requests))
554            } else {
555                String::new()
556            },
557            if cost_state.has_unknown_model_cost {
558                " (costs may be inaccurate due to usage of unknown models)".to_string()
559            } else {
560                String::new()
561            },
562            format_cost_for_display(usage.cost_usd, 4)
563        );
564        result.push('\n');
565        // Pad the model name to 21 characters
566        let padded_name = format!("{:<21}", format!("{}:", short_name));
567        result.push_str(&padded_name);
568        result.push_str(&usage_string.replace("  ", " "));
569    }
570    result
571}
572
/// Renders a millisecond duration for humans.
///
/// - Under one second: `"<ms>ms"`.
/// - Otherwise only whole units are shown: `"Ns"`, `"Nm Ns"` or `"Nh Nm Ns"`.
fn format_duration(ms: u64) -> String {
    let total_seconds = ms / 1000;
    if total_seconds == 0 {
        return format!("{}ms", ms);
    }

    let hours = total_seconds / 3600;
    let minutes = (total_seconds / 60) % 60;
    let seconds = total_seconds % 60;

    if hours > 0 {
        format!("{}h {}m {}s", hours, minutes, seconds)
    } else if total_seconds >= 60 {
        format!("{}m {}s", minutes, seconds)
    } else {
        format!("{}s", seconds)
    }
}
589
590/// Format total cost for display
591pub fn format_total_cost() -> String {
592    let cost_state = get_global_cost_state();
593
594    let cost_display = format!("Total cost:            ${:.4}", cost_state.total_cost_usd);
595
596    let model_usage_display = format_model_usage();
597
598    format!(
599        "Total cost:            {}\nTotal duration (API):  {}\nTotal duration (wall): {}\nTotal code changes:    {} {} added, {} {}\n{}",
600        cost_display,
601        format_duration(cost_state.total_api_duration),
602        format_duration(cost_state.last_duration.unwrap_or(0)),
603        cost_state.total_lines_added,
604        if cost_state.total_lines_added == 1 {
605            "line"
606        } else {
607            "lines"
608        },
609        cost_state.total_lines_removed,
610        if cost_state.total_lines_removed == 1 {
611            "line"
612        } else {
613            "lines"
614        },
615        model_usage_display
616    )
617}
618
619/// Global cost tracking state
620#[derive(Debug, Clone, Default)]
621pub struct GlobalCostState {
622    pub total_cost_usd: f64,
623    pub total_api_duration: u64,
624    pub total_api_duration_without_retries: u64,
625    pub total_tool_duration: u64,
626    pub total_lines_added: u32,
627    pub total_lines_removed: u32,
628    pub last_duration: Option<u64>,
629    pub total_input_tokens: u32,
630    pub total_output_tokens: u32,
631    pub total_cache_creation_input_tokens: u32,
632    pub total_cache_read_input_tokens: u32,
633    pub total_web_search_requests: u32,
634    pub model_usage: std::collections::HashMap<String, ModelUsageInfo>,
635    pub has_unknown_model_cost: bool,
636    pub session_id: String,
637    /// Per-turn tool metrics (TS: turnToolDurationMs, turnToolCount)
638    pub turn_tool_duration_ms: u64,
639    pub turn_tool_count: u32,
640    /// Turn-level token budget tracking (TS: outputTokensAtTurnStart, currentTurnTokenBudget)
641    pub output_tokens_at_turn_start: u64,
642    pub current_turn_token_budget: Option<u64>,
643    pub budget_continuation_count: u32,
644}
645
646/// Global cost state singleton - thread-safe, persisted across calls
647static GLOBAL_COST_STATE: once_cell::sync::Lazy<std::sync::Mutex<GlobalCostState>> =
648    once_cell::sync::Lazy::new(|| std::sync::Mutex::new(GlobalCostState::default()));
649
650/// Initialize cost tracking for a new session
651pub fn init_cost_state(session_id: &str) {
652    let mut state = GLOBAL_COST_STATE.lock().unwrap();
653    *state = GlobalCostState {
654        session_id: session_id.to_string(),
655        ..Default::default()
656    };
657}
658
659/// Get the global cost state (singleton)
660fn get_global_cost_state() -> GlobalCostState {
661    GLOBAL_COST_STATE.lock().unwrap().clone()
662}
663
664/// Update the global cost state with a mutation closure
665pub fn update_global_cost_state<F: FnOnce(&mut GlobalCostState)>(f: F) {
666    let mut state = GLOBAL_COST_STATE.lock().unwrap();
667    f(&mut state);
668}
669
670/// Add to total model usage
671pub fn add_to_total_model_usage(
672    cost: f64,
673    input_tokens: u32,
674    output_tokens: u32,
675    cache_read_input_tokens: u32,
676    cache_creation_input_tokens: u32,
677    web_search_requests: u32,
678    model: &str,
679) -> ModelUsageInfo {
680    update_global_cost_state(|cost_state| {
681        let model_usage = cost_state
682            .model_usage
683            .entry(model.to_string())
684            .or_insert_with(|| ModelUsageInfo {
685                input_tokens: 0,
686                output_tokens: 0,
687                cache_read_input_tokens: 0,
688                cache_creation_input_tokens: 0,
689                web_search_requests: 0,
690                cost_usd: 0.0,
691                context_window: 0,
692                max_output_tokens: 0,
693            });
694
695        model_usage.input_tokens += input_tokens;
696        model_usage.output_tokens += output_tokens;
697        model_usage.cache_read_input_tokens += cache_read_input_tokens;
698        model_usage.cache_creation_input_tokens += cache_creation_input_tokens;
699        model_usage.web_search_requests += web_search_requests;
700        model_usage.cost_usd += cost;
701
702        cost_state.total_cost_usd += cost;
703        cost_state.total_input_tokens += input_tokens;
704        cost_state.total_output_tokens += output_tokens;
705        cost_state.total_cache_creation_input_tokens += cache_creation_input_tokens;
706        cost_state.total_cache_read_input_tokens += cache_read_input_tokens;
707        cost_state.total_web_search_requests += web_search_requests;
708    });
709
710    get_global_cost_state()
711        .model_usage
712        .get(model)
713        .cloned()
714        .unwrap_or_default()
715}
716
717/// Add to total session cost
718pub fn add_to_total_session_cost(
719    cost: f64,
720    input_tokens: u32,
721    output_tokens: u32,
722    cache_read_input_tokens: u32,
723    cache_creation_input_tokens: u32,
724    web_search_requests: u32,
725    model: &str,
726) -> f64 {
727    add_to_total_model_usage(
728        cost,
729        input_tokens,
730        output_tokens,
731        cache_read_input_tokens,
732        cache_creation_input_tokens,
733        web_search_requests,
734        model,
735    );
736
737    cost
738}
739
740/// Reset per-turn metrics at the start of a new turn
741pub fn reset_turn_metrics() {
742    update_global_cost_state(|state| {
743        state.turn_tool_duration_ms = 0;
744        state.turn_tool_count = 0;
745        state.output_tokens_at_turn_start = state.total_output_tokens as u64;
746    });
747}
748
749/// Record tool execution duration for the current turn
750pub fn record_turn_tool_duration(duration_ms: u64) {
751    update_global_cost_state(|state| {
752        state.turn_tool_duration_ms += duration_ms;
753        state.turn_tool_count += 1;
754    });
755}
756
757/// Get current turn metrics
758pub fn get_turn_metrics() -> (u64, u32) {
759    let state = get_global_cost_state();
760    (state.turn_tool_duration_ms, state.turn_tool_count)
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two dollar amounts agree to within one cent.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected approximately {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_model_costs_input() {
        let tier = COST_TIER_3_15;
        for (tokens, expected) in [(1_000_000, 3.0), (500_000, 1.5)] {
            assert_eq!(tier.input_cost(tokens), expected);
        }
    }

    #[test]
    fn test_model_costs_output() {
        assert_eq!(COST_TIER_3_15.output_cost(1_000_000), 15.0);
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            prompt_cache_write_tokens: 25,
            prompt_cache_read_tokens: 75,
        };
        assert_eq!(usage.total(), 250);
    }

    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();
        assert_eq!(registry.get("claude-sonnet-4-6").input_tokens, 3.0);
        assert_eq!(registry.get("claude-haiku-4-5").input_tokens, 1.0);
    }

    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        assert_eq!(
            registry.get("unknown-model").input_tokens,
            COST_DEFAULT.input_tokens
        );
    }

    #[test]
    fn test_calculate_cost() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            ..Default::default()
        };

        // $3 * 1 + $15 * 0.5 = $3 + $7.50 = $10.50
        assert_close(calculate_cost("claude-sonnet-4-6", &usage), 10.5);
    }

    #[test]
    fn test_format_cost() {
        let cases = [(0.001, "$0.0010"), (0.5, "$0.50"), (1.5, "$1.5000")];
        for (cost, expected) in cases {
            assert_eq!(format_cost(cost), expected);
        }
    }

    #[test]
    fn test_cost_summary() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 100_000,
            prompt_cache_read_tokens: 200_000,
        };

        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);

        // Input: 1M * $3/M = $3
        assert_close(summary.input_cost, 3.0);
        // Output: 500K * $15/M = $7.50
        assert_close(summary.output_cost, 7.5);
        // Cache write: 100K * $3.75/M = $0.375
        assert_close(summary.cache_write_cost, 0.375);
        // Cache read: 200K * $0.3/M = $0.06
        assert_close(summary.cache_read_cost, 0.06);
    }
}