Skip to main content

ai_agent/services/
model_cost.rs

1//! Model cost calculation.
2//!
3//! Provides cost estimation for different AI models similar to claude code.
4
5use serde::{Deserialize, Serialize};
6
7/// Model cost configuration (per million tokens)
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ModelCosts {
10    /// Input tokens cost per million
11    pub input_tokens: f64,
12    /// Output tokens cost per million
13    pub output_tokens: f64,
14    /// Prompt cache write tokens cost per million
15    pub prompt_cache_write_tokens: f64,
16    /// Prompt cache read tokens cost per million
17    pub prompt_cache_read_tokens: f64,
18    /// Web search requests cost per search
19    pub web_search_requests: f64,
20}
21
22impl ModelCosts {
23    /// Calculate cost for input tokens
24    pub fn input_cost(&self, tokens: u32) -> f64 {
25        (tokens as f64 / 1_000_000.0) * self.input_tokens
26    }
27
28    /// Calculate cost for output tokens
29    pub fn output_cost(&self, tokens: u32) -> f64 {
30        (tokens as f64 / 1_000_000.0) * self.output_tokens
31    }
32
33    /// Calculate cost for cache write tokens
34    pub fn cache_write_cost(&self, tokens: u32) -> f64 {
35        (tokens as f64 / 1_000_000.0) * self.prompt_cache_write_tokens
36    }
37
38    /// Calculate cost for cache read tokens
39    pub fn cache_read_cost(&self, tokens: u32) -> f64 {
40        (tokens as f64 / 1_000_000.0) * self.prompt_cache_read_tokens
41    }
42
43    /// Calculate total cost for a usage record
44    pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
45        self.input_cost(usage.input_tokens)
46            + self.output_cost(usage.output_tokens)
47            + self.cache_write_cost(usage.prompt_cache_write_tokens)
48            + self.cache_read_cost(usage.prompt_cache_read_tokens)
49    }
50}
51
52/// Token usage from API response
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct TokenUsage {
55    pub input_tokens: u32,
56    pub output_tokens: u32,
57    #[serde(rename = "promptCacheWriteTokens")]
58    pub prompt_cache_write_tokens: u32,
59    #[serde(rename = "promptCacheReadTokens")]
60    pub prompt_cache_read_tokens: u32,
61}
62
63impl TokenUsage {
64    /// Total tokens used
65    pub fn total(&self) -> u32 {
66        self.input_tokens
67            + self.output_tokens
68            + self.prompt_cache_write_tokens
69            + self.prompt_cache_read_tokens
70    }
71}
72
73/// Model information for listing available models
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelInfo {
76    /// Model identifier
77    pub id: String,
78    /// Display name
79    pub name: String,
80    /// Description
81    pub description: String,
82    /// Context window size in tokens
83    pub context_window: u32,
84}
85
86/// Common cost tiers
87
88/// Standard pricing: $3 input / $15 output per M tokens
89pub const COST_TIER_3_15: ModelCosts = ModelCosts {
90    input_tokens: 3.0,
91    output_tokens: 15.0,
92    prompt_cache_write_tokens: 3.75,
93    prompt_cache_read_tokens: 0.3,
94    web_search_requests: 0.01,
95};
96
97/// Opus pricing: $15 input / $75 output per M tokens
98pub const COST_TIER_15_75: ModelCosts = ModelCosts {
99    input_tokens: 15.0,
100    output_tokens: 75.0,
101    prompt_cache_write_tokens: 18.75,
102    prompt_cache_read_tokens: 1.5,
103    web_search_requests: 0.01,
104};
105
106/// Mid-tier pricing: $5 input / $25 output per M tokens
107pub const COST_TIER_5_25: ModelCosts = ModelCosts {
108    input_tokens: 5.0,
109    output_tokens: 25.0,
110    prompt_cache_write_tokens: 6.25,
111    prompt_cache_read_tokens: 0.5,
112    web_search_requests: 0.01,
113};
114
115/// Fast mode pricing: $30 input / $150 output per M tokens
116pub const COST_TIER_30_150: ModelCosts = ModelCosts {
117    input_tokens: 30.0,
118    output_tokens: 150.0,
119    prompt_cache_write_tokens: 37.5,
120    prompt_cache_read_tokens: 3.0,
121    web_search_requests: 0.01,
122};
123
124/// Haiku 3.5 pricing: $0.80 input / $4 output per M tokens
125pub const COST_HAIKU_35: ModelCosts = ModelCosts {
126    input_tokens: 0.8,
127    output_tokens: 4.0,
128    prompt_cache_write_tokens: 1.0,
129    prompt_cache_read_tokens: 0.08,
130    web_search_requests: 0.01,
131};
132
133/// Haiku 4.5 pricing: $1 input / $5 output per M tokens
134pub const COST_HAIKU_45: ModelCosts = ModelCosts {
135    input_tokens: 1.0,
136    output_tokens: 5.0,
137    prompt_cache_write_tokens: 1.25,
138    prompt_cache_read_tokens: 0.1,
139    web_search_requests: 0.01,
140};
141
142/// Default cost for unknown models
143pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;
144
145/// Model cost registry
146pub struct ModelCostRegistry {
147    costs: std::collections::HashMap<String, ModelCosts>,
148}
149
150impl ModelCostRegistry {
151    pub fn new() -> Self {
152        let mut costs = std::collections::HashMap::new();
153
154        // Anthropic models
155        costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
156        costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
157        costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
158        costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
159        costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
160        costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
161        costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
162        costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
163        costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
164        costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);
165
166        // MiniMax models
167        costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
168        costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);
169
170        // OpenAI models (for compatibility)
171        costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
172        costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
173        costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
174        costs.insert("gpt-4".to_string(), COST_TIER_30_60);
175
176        Self { costs }
177    }
178
179    /// Get cost for a model
180    pub fn get(&self, model: &str) -> &ModelCosts {
181        // Try exact match first
182        if let Some(cost) = self.costs.get(model) {
183            return cost;
184        }
185
186        // Try prefix match
187        for (key, cost) in &self.costs {
188            if model.starts_with(key) || key.starts_with(model) {
189                return cost;
190            }
191        }
192
193        &COST_DEFAULT
194    }
195
196    /// Register a custom model cost
197    pub fn register(&mut self, model: &str, costs: ModelCosts) {
198        self.costs.insert(model.to_string(), costs);
199    }
200}
201
202impl Default for ModelCostRegistry {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208/// Pricing tier for GPT-4: $30 input / $60 output per M tokens
209pub const COST_TIER_30_60: ModelCosts = ModelCosts {
210    input_tokens: 30.0,
211    output_tokens: 60.0,
212    prompt_cache_write_tokens: 30.0,
213    prompt_cache_read_tokens: 10.0,
214    web_search_requests: 0.01,
215};
216
217/// Pricing tier for GPT-4 Turbo: $10 input / $30 output per M tokens
218pub const COST_TIER_10_30: ModelCosts = ModelCosts {
219    input_tokens: 10.0,
220    output_tokens: 30.0,
221    prompt_cache_write_tokens: 10.0,
222    prompt_cache_read_tokens: 3.0,
223    web_search_requests: 0.01,
224};
225
226/// Calculate cost from model name and usage
227pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
228    let registry = ModelCostRegistry::new();
229    let costs = registry.get(model);
230    costs.total_cost(usage)
231}
232
233/// Get list of available models with their display names and descriptions
234pub fn get_available_models() -> Vec<ModelInfo> {
235    vec![
236        ModelInfo {
237            id: "claude-opus-4-6".to_string(),
238            name: "Opus".to_string(),
239            description: "Most capable for complex work".to_string(),
240            context_window: 200_000,
241        },
242        ModelInfo {
243            id: "claude-sonnet-4-6".to_string(),
244            name: "Sonnet".to_string(),
245            description: "Best for everyday tasks".to_string(),
246            context_window: 200_000,
247        },
248        ModelInfo {
249            id: "claude-sonnet-4-6-20250520".to_string(),
250            name: "Sonnet 4.6".to_string(),
251            description: "Latest Sonnet model".to_string(),
252            context_window: 200_000,
253        },
254        ModelInfo {
255            id: "claude-haiku-4-5".to_string(),
256            name: "Haiku".to_string(),
257            description: "Fastest for quick answers".to_string(),
258            context_window: 200_000,
259        },
260        ModelInfo {
261            id: "claude-opus-4-5".to_string(),
262            name: "Opus 4.5".to_string(),
263            description: "Previous Opus version".to_string(),
264            context_window: 200_000,
265        },
266        ModelInfo {
267            id: "claude-sonnet-4-5".to_string(),
268            name: "Sonnet 4.5".to_string(),
269            description: "Previous Sonnet version".to_string(),
270            context_window: 200_000,
271        },
272        ModelInfo {
273            id: "MiniMaxAI/MiniMax-M2.5".to_string(),
274            name: "MiniMax M2.5".to_string(),
275            description: "Fast and capable (default)".to_string(),
276            context_window: 1_000_000,
277        },
278    ]
279}
280
/// Format a dollar amount with a precision that depends on its size.
///
/// Amounts in `[0.01, 1.0)` get two decimal places; everything else
/// (sub-cent amounts, a dollar or more, negatives and NaN) gets four.
///
/// NOTE(review): four decimals for amounts of $1 or more looks unintended,
/// but `test_format_cost` pins `"$1.5000"`, so it is kept.
pub fn format_cost(cost: f64) -> String {
    let in_cents_range = (0.01..1.0).contains(&cost);
    let decimals = if in_cents_range { 2 } else { 4 };
    format!("${:.*}", decimals, cost)
}
291
292/// Cost summary for display
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct CostSummary {
295    pub input_cost: f64,
296    pub output_cost: f64,
297    pub cache_write_cost: f64,
298    pub cache_read_cost: f64,
299    pub total_cost: f64,
300}
301
302impl CostSummary {
303    pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
304        let registry = ModelCostRegistry::new();
305        let costs = registry.get(model);
306
307        Self {
308            input_cost: costs.input_cost(usage.input_tokens),
309            output_cost: costs.output_cost(usage.output_tokens),
310            cache_write_cost: costs.cache_write_cost(usage.prompt_cache_write_tokens),
311            cache_read_cost: costs.cache_read_cost(usage.prompt_cache_read_tokens),
312            total_cost: costs.total_cost(usage),
313        }
314    }
315}
316
317use crate::utils::config::{
318    get_current_project_config, save_current_project_config, ModelUsage as ConfigModelUsage,
319};
320
321/// Stored cost state from project config
322#[derive(Debug, Clone, Default)]
323pub struct StoredCostState {
324    pub total_cost_usd: f64,
325    pub total_api_duration: u64,
326    pub total_api_duration_without_retries: u64,
327    pub total_tool_duration: u64,
328    pub total_lines_added: u32,
329    pub total_lines_removed: u32,
330    pub last_duration: Option<u64>,
331    pub model_usage: Option<std::collections::HashMap<String, ConfigModelUsage>>,
332}
333
334/// Get stored cost state from project config for a specific session.
335/// Returns the cost data if the session ID matches, or None otherwise.
336/// Use this to read costs BEFORE overwriting the config with save_current_session_costs().
337pub fn get_stored_session_costs(session_id: &str) -> Option<StoredCostState> {
338    let project_config = get_current_project_config();
339
340    // Only return costs if this is the same session that was last saved
341    if project_config.last_session_id.as_deref() != Some(session_id) {
342        return None;
343    }
344
345    Some(StoredCostState {
346        total_cost_usd: project_config.last_cost.unwrap_or(0.0),
347        total_api_duration: project_config.last_api_duration.unwrap_or(0),
348        total_api_duration_without_retries: project_config
349            .last_api_duration_without_retries
350            .unwrap_or(0),
351        total_tool_duration: project_config.last_tool_duration.unwrap_or(0),
352        total_lines_added: project_config.last_lines_added.unwrap_or(0),
353        total_lines_removed: project_config.last_lines_removed.unwrap_or(0),
354        last_duration: project_config.last_duration,
355        model_usage: project_config.last_model_usage,
356    })
357}
358
/// Intended to restore cost state from the project config when resuming a
/// session.
///
/// Not implemented yet: this is a stub that never restores anything and
/// always returns `false`, whatever session id it is given. Callers that need
/// the saved figures can read them with `get_stored_session_costs`.
pub fn restore_cost_state_for_session(_session_id: &str) -> bool {
    // NOTE(review): the cost state in this module is not a real singleton
    // (`get_global_cost_state` hands back a fresh default), so there is
    // nowhere to restore into yet; session restoration is handled elsewhere.
    false
}
367
368/// Saves the current session's costs to project config.
369/// Call this before switching sessions to avoid losing accumulated costs.
370pub fn save_current_session_costs() {
371    let cost_state = get_global_cost_state();
372
373    let model_usage_map: Option<std::collections::HashMap<String, ConfigModelUsage>> =
374        if cost_state.model_usage.is_empty() {
375            None
376        } else {
377            let mut map = std::collections::HashMap::new();
378            for (model, usage) in &cost_state.model_usage {
379                map.insert(
380                    model.clone(),
381                    ConfigModelUsage {
382                        input_tokens: usage.input_tokens,
383                        output_tokens: usage.output_tokens,
384                        cache_read_input_tokens: usage.cache_read_input_tokens,
385                        cache_creation_input_tokens: usage.cache_creation_input_tokens,
386                        web_search_requests: usage.web_search_requests,
387                        cost_usd: usage.cost_usd,
388                    },
389                );
390            }
391            Some(map)
392        };
393
394    let mut config = get_current_project_config();
395    config.last_cost = Some(cost_state.total_cost_usd);
396    config.last_api_duration = Some(cost_state.total_api_duration);
397    config.last_api_duration_without_retries = Some(cost_state.total_api_duration_without_retries);
398    config.last_tool_duration = Some(cost_state.total_tool_duration);
399    config.last_duration = cost_state.last_duration;
400    config.last_lines_added = Some(cost_state.total_lines_added);
401    config.last_lines_removed = Some(cost_state.total_lines_removed);
402    config.last_total_input_tokens = Some(cost_state.total_input_tokens);
403    config.last_total_output_tokens = Some(cost_state.total_output_tokens);
404    config.last_total_cache_creation_input_tokens =
405        Some(cost_state.total_cache_creation_input_tokens);
406    config.last_total_cache_read_input_tokens = Some(cost_state.total_cache_read_input_tokens);
407    config.last_total_web_search_requests = Some(cost_state.total_web_search_requests);
408    config.last_model_usage = model_usage_map;
409    config.last_session_id = Some(cost_state.session_id.clone());
410
411    let _ = save_current_project_config(config);
412}
413
/// Format `cost` as a dollar string for the usage report.
///
/// Costs above $0.50 are rounded to whole cents. Smaller costs are printed
/// with exactly `max_decimal_places` digits after the decimal point, so that
/// sub-cent amounts stay visible.
fn format_cost_for_display(cost: f64, max_decimal_places: usize) -> String {
    if cost > 0.5 {
        // Round to cents explicitly with `f64::round` (half away from zero)
        // before formatting.
        format!("${:.2}", (cost * 100.0).round() / 100.0)
    } else {
        // Precision is exactly `max_decimal_places`; the parameter is the
        // number of digits shown, not a base to add to.
        format!("${:.prec$}", cost, prec = max_decimal_places)
    }
}
422
/// Render `n` in decimal with a comma between each group of three digits,
/// e.g. `1234567` becomes `"1,234,567"`.
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    // The leading group holds 1-3 digits; every later group holds exactly 3.
    let lead_len = match digits.len() % 3 {
        0 => 3,
        rem => rem,
    };
    let (lead, rest) = digits.split_at(lead_len);

    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(lead);
    for (idx, digit) in rest.chars().enumerate() {
        if idx % 3 == 0 {
            grouped.push(',');
        }
        grouped.push(digit);
    }
    grouped
}
436
/// Accumulated usage for one model (or one model family) during a session.
#[derive(Debug, Clone, Default)]
pub struct ModelUsageInfo {
    /// Input tokens consumed.
    pub input_tokens: u32,
    /// Output tokens produced.
    pub output_tokens: u32,
    /// Input tokens served from the prompt cache.
    pub cache_read_input_tokens: u32,
    /// Input tokens written to the prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Number of web search requests made.
    pub web_search_requests: u32,
    /// Accumulated cost in US dollars.
    pub cost_usd: f64,
    /// Context window of the model, in tokens (0 when not yet known).
    pub context_window: u32,
    /// Maximum output tokens of the model (0 when not yet known).
    pub max_output_tokens: u32,
}
449
/// Map a model id to its family's short display name.
///
/// Matching is a case-sensitive substring test, tried in the table's order;
/// ids that match no family are returned unchanged.
fn get_canonical_name(model: &str) -> String {
    // (substring to look for, display name) — the first hit wins.
    const FAMILIES: [(&str, &str); 5] = [
        ("opus", "Opus"),
        ("sonnet", "Sonnet"),
        ("haiku", "Haiku"),
        ("MiniMax", "MiniMax"),
        ("gpt", "GPT"),
    ];

    FAMILIES
        .iter()
        .find(|&&(needle, _)| model.contains(needle))
        .map_or_else(|| model.to_string(), |&(_, label)| label.to_string())
}
467
468/// Format model usage for display
469pub fn format_model_usage() -> String {
470    let cost_state = get_global_cost_state();
471
472    if cost_state.model_usage.is_empty() {
473        return "Usage:                 0 input, 0 output, 0 cache read, 0 cache write".to_string();
474    }
475
476    // Accumulate usage by short name
477    let mut usage_by_short_name: std::collections::HashMap<String, ModelUsageInfo> =
478        std::collections::HashMap::new();
479    for (model, usage) in &cost_state.model_usage {
480        let short_name = get_canonical_name(model);
481        let entry = usage_by_short_name
482            .entry(short_name)
483            .or_insert_with(|| ModelUsageInfo::default());
484        entry.input_tokens += usage.input_tokens;
485        entry.output_tokens += usage.output_tokens;
486        entry.cache_read_input_tokens += usage.cache_read_input_tokens;
487        entry.cache_creation_input_tokens += usage.cache_creation_input_tokens;
488        entry.web_search_requests += usage.web_search_requests;
489        entry.cost_usd += usage.cost_usd;
490    }
491
492    let mut result = "Usage by model:".to_string();
493    for (short_name, usage) in &usage_by_short_name {
494        let usage_string = format!(
495            "  {} input, {} output, {} cache read, {} cache write{}{} (${})",
496            format_number(usage.input_tokens),
497            format_number(usage.output_tokens),
498            format_number(usage.cache_read_input_tokens),
499            format_number(usage.cache_creation_input_tokens),
500            if usage.web_search_requests > 0 {
501                format!(", {} web search", format_number(usage.web_search_requests))
502            } else {
503                String::new()
504            },
505            if cost_state.has_unknown_model_cost {
506                " (costs may be inaccurate due to usage of unknown models)".to_string()
507            } else {
508                String::new()
509            },
510            format_cost_for_display(usage.cost_usd, 4)
511        );
512        result.push('\n');
513        // Pad the model name to 21 characters
514        let padded_name = format!("{:<21}", format!("{}:", short_name));
515        result.push_str(&padded_name);
516        result.push_str(&usage_string.replace("  ", " "));
517    }
518    result
519}
520
/// Render a millisecond duration using its largest non-zero unit.
///
/// Produces `"{h}h {m}m {s}s"`, `"{m}m {s}s"` or `"{s}s"`; leftover
/// milliseconds are dropped. Durations under one second are shown as
/// `"{ms}ms"`.
fn format_duration(ms: u64) -> String {
    let total_secs = ms / 1000;
    let total_mins = total_secs / 60;
    let hours = total_mins / 60;
    let secs = total_secs % 60;
    let mins = total_mins % 60;

    match (hours, total_mins, total_secs) {
        (0, 0, 0) => format!("{}ms", ms),
        (0, 0, whole_secs) => format!("{}s", whole_secs),
        (0, whole_mins, _) => format!("{}m {}s", whole_mins, secs),
        (h, _, _) => format!("{}h {}m {}s", h, mins, secs),
    }
}
537
538/// Format total cost for display
539pub fn format_total_cost() -> String {
540    let cost_state = get_global_cost_state();
541
542    let cost_display = format!("Total cost:            ${:.4}", cost_state.total_cost_usd);
543
544    let model_usage_display = format_model_usage();
545
546    format!(
547        "Total cost:            {}\nTotal duration (API):  {}\nTotal duration (wall): {}\nTotal code changes:    {} {} added, {} {}\n{}",
548        cost_display,
549        format_duration(cost_state.total_api_duration),
550        format_duration(cost_state.last_duration.unwrap_or(0)),
551        cost_state.total_lines_added,
552        if cost_state.total_lines_added == 1 { "line" } else { "lines" },
553        cost_state.total_lines_removed,
554        if cost_state.total_lines_removed == 1 { "line" } else { "lines" },
555        model_usage_display
556    )
557}
558
559/// Global cost tracking state
560#[derive(Debug, Clone, Default)]
561pub struct GlobalCostState {
562    pub total_cost_usd: f64,
563    pub total_api_duration: u64,
564    pub total_api_duration_without_retries: u64,
565    pub total_tool_duration: u64,
566    pub total_lines_added: u32,
567    pub total_lines_removed: u32,
568    pub last_duration: Option<u64>,
569    pub total_input_tokens: u32,
570    pub total_output_tokens: u32,
571    pub total_cache_creation_input_tokens: u32,
572    pub total_cache_read_input_tokens: u32,
573    pub total_web_search_requests: u32,
574    pub model_usage: std::collections::HashMap<String, ModelUsageInfo>,
575    pub has_unknown_model_cost: bool,
576    pub session_id: String,
577}
578
579/// Get the global cost state (singleton)
580fn get_global_cost_state() -> GlobalCostState {
581    // In a real implementation, this would be a static or thread-local
582    // For now, return a default state
583    GlobalCostState::default()
584}
585
586/// Add to total model usage
587pub fn add_to_total_model_usage(
588    cost: f64,
589    input_tokens: u32,
590    output_tokens: u32,
591    cache_read_input_tokens: u32,
592    cache_creation_input_tokens: u32,
593    web_search_requests: u32,
594    model: &str,
595) -> ModelUsageInfo {
596    let mut cost_state = get_global_cost_state();
597
598    let model_usage = cost_state
599        .model_usage
600        .entry(model.to_string())
601        .or_insert_with(|| ModelUsageInfo {
602            input_tokens: 0,
603            output_tokens: 0,
604            cache_read_input_tokens: 0,
605            cache_creation_input_tokens: 0,
606            web_search_requests: 0,
607            cost_usd: 0.0,
608            context_window: 0,
609            max_output_tokens: 0,
610        });
611
612    model_usage.input_tokens += input_tokens;
613    model_usage.output_tokens += output_tokens;
614    model_usage.cache_read_input_tokens += cache_read_input_tokens;
615    model_usage.cache_creation_input_tokens += cache_creation_input_tokens;
616    model_usage.web_search_requests += web_search_requests;
617    model_usage.cost_usd += cost;
618
619    ModelUsageInfo {
620        input_tokens: model_usage.input_tokens,
621        output_tokens: model_usage.output_tokens,
622        cache_read_input_tokens: model_usage.cache_read_input_tokens,
623        cache_creation_input_tokens: model_usage.cache_creation_input_tokens,
624        web_search_requests: model_usage.web_search_requests,
625        cost_usd: model_usage.cost_usd,
626        context_window: model_usage.context_window,
627        max_output_tokens: model_usage.max_output_tokens,
628    }
629}
630
631/// Add to total session cost
632pub fn add_to_total_session_cost(
633    cost: f64,
634    input_tokens: u32,
635    output_tokens: u32,
636    cache_read_input_tokens: u32,
637    cache_creation_input_tokens: u32,
638    web_search_requests: u32,
639    model: &str,
640) -> f64 {
641    add_to_total_model_usage(
642        cost,
643        input_tokens,
644        output_tokens,
645        cache_read_input_tokens,
646        cache_creation_input_tokens,
647        web_search_requests,
648        model,
649    );
650
651    cost
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_costs_input() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.input_cost(1_000_000), 3.0);
        assert_eq!(costs.input_cost(500_000), 1.5);
    }

    #[test]
    fn test_model_costs_output() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.output_cost(1_000_000), 15.0);
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            prompt_cache_write_tokens: 25,
            prompt_cache_read_tokens: 75,
        };
        assert_eq!(usage.total(), 250);
    }

    #[test]
    fn test_token_usage_default_is_zero() {
        assert_eq!(TokenUsage::default().total(), 0);
    }

    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();

        let costs = registry.get("claude-sonnet-4-6");
        assert_eq!(costs.input_tokens, 3.0);

        let costs = registry.get("claude-haiku-4-5");
        assert_eq!(costs.input_tokens, 1.0);
    }

    #[test]
    fn test_model_cost_registry_exact_match() {
        let registry = ModelCostRegistry::new();
        // Overlapping keys with different prices must resolve exactly.
        assert_eq!(registry.get("claude-opus-4-1").input_tokens, 15.0);
        assert_eq!(registry.get("claude-opus-4-5").input_tokens, 5.0);
        assert_eq!(registry.get("claude-haiku-3-5").input_tokens, 0.8);
    }

    #[test]
    fn test_model_cost_registry_register() {
        let mut registry = ModelCostRegistry::new();
        registry.register("custom-model", COST_TIER_30_150);
        assert_eq!(registry.get("custom-model").output_tokens, 150.0);
    }

    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        let costs = registry.get("unknown-model");
        assert_eq!(costs.input_tokens, COST_DEFAULT.input_tokens);
    }

    #[test]
    fn test_calculate_cost() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 0,
            prompt_cache_read_tokens: 0,
        };

        let cost = calculate_cost("claude-sonnet-4-6", &usage);
        // $3 * 1 + $15 * 0.5 = $3 + $7.50 = $10.50
        assert!((cost - 10.5).abs() < 0.01);
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(0.001), "$0.0010");
        assert_eq!(format_cost(0.5), "$0.50");
        assert_eq!(format_cost(1.5), "$1.5000");
    }

    #[test]
    fn test_cost_summary() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 100_000,
            prompt_cache_read_tokens: 200_000,
        };

        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);

        // Input: 1M * $3/M = $3
        assert!((summary.input_cost - 3.0).abs() < 0.01);
        // Output: 500K * $15/M = $7.50
        assert!((summary.output_cost - 7.5).abs() < 0.01);
        // Cache write: 100K * $3.75/M = $0.375
        assert!((summary.cache_write_cost - 0.375).abs() < 0.01);
        // Cache read: 200K * $0.3/M = $0.06
        assert!((summary.cache_read_cost - 0.06).abs() < 0.01);
    }

    #[test]
    fn test_cost_summary_total_is_sum_of_parts() {
        let usage = TokenUsage {
            input_tokens: 123_456,
            output_tokens: 7_890,
            prompt_cache_write_tokens: 1_000,
            prompt_cache_read_tokens: 50_000,
        };

        let summary = CostSummary::from_usage("claude-haiku-4-5", &usage);
        let parts = summary.input_cost
            + summary.output_cost
            + summary.cache_write_cost
            + summary.cache_read_cost;
        assert!((summary.total_cost - parts).abs() < 1e-12);
    }

    #[test]
    fn test_format_number() {
        assert_eq!(format_number(0), "0");
        assert_eq!(format_number(999), "999");
        assert_eq!(format_number(1_000), "1,000");
        assert_eq!(format_number(1_234_567), "1,234,567");
        assert_eq!(format_number(u32::MAX), "4,294,967,295");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(0), "0ms");
        assert_eq!(format_duration(999), "999ms");
        assert_eq!(format_duration(1_000), "1s");
        assert_eq!(format_duration(61_000), "1m 1s");
        assert_eq!(format_duration(3_723_000), "1h 2m 3s");
    }

    #[test]
    fn test_get_canonical_name() {
        assert_eq!(get_canonical_name("claude-opus-4-6"), "Opus");
        assert_eq!(get_canonical_name("claude-sonnet-4-5"), "Sonnet");
        assert_eq!(get_canonical_name("claude-haiku-4-5"), "Haiku");
        assert_eq!(get_canonical_name("MiniMaxAI/MiniMax-M2.5"), "MiniMax");
        assert_eq!(get_canonical_name("gpt-4o"), "GPT");
        assert_eq!(get_canonical_name("llama-3"), "llama-3");
    }
}