use std::collections::HashMap;
use std::sync::LazyLock;
/// The backend that served a request; selects which pricing rules apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    /// Parsed from `"claude"` or `"anthropic"`.
    Claude,
    /// Parsed from `"openai"` or `"gpt"`.
    OpenAI,
    /// Parsed from `"mistral"`.
    Mistral,
    /// Parsed from `"groq"`.
    Groq,
    /// Parsed from `"deepseek"`.
    DeepSeek,
    /// Parsed from `"gemini"` or `"google"`.
    Gemini,
    /// Parsed from `"xai"`, `"grok"` or `"x-ai"`.
    XAi,
    /// Parsed from `"native"`; the only provider reported as free.
    Native,
}

impl ProviderKind {
    /// Resolves a provider name or alias to its [`ProviderKind`].
    ///
    /// The input is lowercased before comparison, so matching is
    /// case-insensitive. Whitespace is not trimmed. Returns `None` for any
    /// string that is not a known alias (including the empty string).
    pub fn parse(s: &str) -> Option<Self> {
        // Every accepted spelling, already lowercase, with the provider it names.
        const ALIASES: &[(&str, ProviderKind)] = &[
            ("claude", ProviderKind::Claude),
            ("anthropic", ProviderKind::Claude),
            ("openai", ProviderKind::OpenAI),
            ("gpt", ProviderKind::OpenAI),
            ("mistral", ProviderKind::Mistral),
            ("groq", ProviderKind::Groq),
            ("deepseek", ProviderKind::DeepSeek),
            ("gemini", ProviderKind::Gemini),
            ("google", ProviderKind::Gemini),
            ("xai", ProviderKind::XAi),
            ("grok", ProviderKind::XAi),
            ("x-ai", ProviderKind::XAi),
            ("native", ProviderKind::Native),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted.as_str())
            .map(|&(_, kind)| kind)
    }

    /// Returns the human-readable display name of the provider.
    pub fn name(&self) -> &'static str {
        match *self {
            ProviderKind::Claude => "Claude",
            ProviderKind::OpenAI => "OpenAI",
            ProviderKind::Mistral => "Mistral",
            ProviderKind::Groq => "Groq",
            ProviderKind::DeepSeek => "DeepSeek",
            ProviderKind::Gemini => "Gemini",
            ProviderKind::XAi => "xAI",
            ProviderKind::Native => "Native",
        }
    }

    /// Returns `true` when usage of this provider is never billed.
    ///
    /// Only [`ProviderKind::Native`] qualifies.
    pub fn is_free(&self) -> bool {
        *self == ProviderKind::Native
    }
}
/// Token rates for a single model, quoted per one million tokens.
#[derive(Debug, Clone, Copy)]
pub struct ModelPricing {
    /// Rate charged for one million input tokens.
    pub input_per_million: f64,
    /// Rate charged for one million output tokens.
    pub output_per_million: f64,
}

impl ModelPricing {
    /// Builds a pricing entry from the input and output per-million rates.
    pub const fn new(input_per_million: f64, output_per_million: f64) -> Self {
        ModelPricing {
            input_per_million,
            output_per_million,
        }
    }

    /// Returns the total cost of `input_tokens` plus `output_tokens`.
    ///
    /// Token counts are converted to `f64` before scaling, so counts above
    /// 2^53 lose precision.
    pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 {
        // Scale the count down to "millions of tokens" first, then apply the rate.
        let charge = |tokens: u64, per_million: f64| (tokens as f64 / 1_000_000.0) * per_million;

        charge(input_tokens, self.input_per_million)
            + charge(output_tokens, self.output_per_million)
    }
}
/// Pricing for Claude models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static CLAUDE_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    // Several identifiers share a rate; name each rate once and reuse it.
    let opus = ModelPricing::new(15.0, 75.0);
    let sonnet = ModelPricing::new(3.0, 15.0);
    let haiku_3_5 = ModelPricing::new(0.8, 4.0);
    let haiku_3 = ModelPricing::new(0.25, 1.25);

    HashMap::from([
        ("claude-opus-4-20250514", opus),
        ("claude-opus-4", opus),
        ("claude-sonnet-4-20250514", sonnet),
        ("claude-sonnet-4-6", sonnet),
        ("claude-sonnet-4", sonnet),
        ("claude-3-5-sonnet-20241022", sonnet),
        ("claude-3-5-sonnet-latest", sonnet),
        ("claude-3-5-haiku-20241022", haiku_3_5),
        ("claude-3-5-haiku-latest", haiku_3_5),
        ("claude-3-opus-20240229", opus),
        ("claude-3-opus-latest", opus),
        ("claude-3-sonnet-20240229", sonnet),
        ("claude-3-haiku-20240307", haiku_3),
    ])
});
/// Pricing for OpenAI models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static OPENAI_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    // Dated snapshots and aliases share the rate of their base model.
    let gpt_4o = ModelPricing::new(2.5, 10.0);
    let gpt_4o_mini = ModelPricing::new(0.15, 0.6);
    let gpt_4_turbo = ModelPricing::new(10.0, 30.0);
    let gpt_4 = ModelPricing::new(30.0, 60.0);
    let gpt_3_5_turbo = ModelPricing::new(0.5, 1.5);
    let o1 = ModelPricing::new(15.0, 60.0);
    let o1_mini = ModelPricing::new(3.0, 12.0);
    let o3_mini = ModelPricing::new(1.1, 4.4);

    HashMap::from([
        ("gpt-4o", gpt_4o),
        ("gpt-4o-2024-11-20", gpt_4o),
        ("gpt-4o-mini", gpt_4o_mini),
        ("gpt-4o-mini-2024-07-18", gpt_4o_mini),
        ("gpt-4-turbo", gpt_4_turbo),
        ("gpt-4-turbo-2024-04-09", gpt_4_turbo),
        ("gpt-4-turbo-preview", gpt_4_turbo),
        ("gpt-4", gpt_4),
        ("gpt-4-0613", gpt_4),
        ("gpt-3.5-turbo", gpt_3_5_turbo),
        ("gpt-3.5-turbo-0125", gpt_3_5_turbo),
        ("o1", o1),
        ("o1-2024-12-17", o1),
        ("o1-preview", o1),
        ("o1-mini", o1_mini),
        ("o1-mini-2024-09-12", o1_mini),
        ("o3-mini", o3_mini),
        ("o3-mini-2025-01-31", o3_mini),
    ])
});
/// Pricing for Mistral models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static MISTRAL_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    let rate = ModelPricing::new;

    HashMap::from([
        ("mistral-large-latest", rate(2.0, 6.0)),
        ("mistral-large-2411", rate(2.0, 6.0)),
        ("mistral-medium-latest", rate(2.7, 8.1)),
        ("mistral-small-latest", rate(0.2, 0.6)),
        ("mistral-small-2409", rate(0.2, 0.6)),
        ("codestral-latest", rate(0.3, 0.9)),
        ("codestral-2501", rate(0.3, 0.9)),
        ("ministral-8b-latest", rate(0.1, 0.1)),
        ("ministral-3b-latest", rate(0.04, 0.04)),
        ("pixtral-large-latest", rate(2.0, 6.0)),
        ("pixtral-12b-2409", rate(0.15, 0.15)),
    ])
});
/// Pricing for models served through Groq, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static GROQ_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    // The 70B and 8B Llama variants each share one rate across generations.
    let llama_70b = ModelPricing::new(0.59, 0.79);
    let llama_8b = ModelPricing::new(0.05, 0.08);

    HashMap::from([
        ("llama-3.3-70b-versatile", llama_70b),
        ("llama-3.3-70b-specdec", ModelPricing::new(0.59, 0.99)),
        ("llama-3.1-70b-versatile", llama_70b),
        ("llama-3.1-8b-instant", llama_8b),
        ("llama3-70b-8192", llama_70b),
        ("llama3-8b-8192", llama_8b),
        ("mixtral-8x7b-32768", ModelPricing::new(0.24, 0.24)),
        ("gemma2-9b-it", ModelPricing::new(0.20, 0.20)),
    ])
});
/// Pricing for DeepSeek models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static DEEPSEEK_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    HashMap::from([
        ("deepseek-chat", ModelPricing::new(0.14, 0.28)),
        ("deepseek-reasoner", ModelPricing::new(0.55, 2.19)),
        ("deepseek-coder", ModelPricing::new(0.14, 0.28)),
    ])
});
/// Pricing for Gemini models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static GEMINI_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    // These two entries are listed at a zero rate.
    let no_charge = ModelPricing::new(0.0, 0.0);
    let pro_1_5 = ModelPricing::new(1.25, 5.0);
    let flash_1_5 = ModelPricing::new(0.075, 0.3);

    HashMap::from([
        ("gemini-2.0-flash", ModelPricing::new(0.1, 0.4)),
        ("gemini-2.0-flash-exp", no_charge),
        ("gemini-2.0-flash-thinking", no_charge),
        ("gemini-1.5-pro", pro_1_5),
        ("gemini-1.5-pro-latest", pro_1_5),
        ("gemini-1.5-flash", flash_1_5),
        ("gemini-1.5-flash-latest", flash_1_5),
        ("gemini-1.5-flash-8b", ModelPricing::new(0.0375, 0.15)),
        ("gemini-pro", ModelPricing::new(0.5, 1.5)),
    ])
});
/// Pricing for xAI models, keyed by exact model identifier.
///
/// Rates are `(input, output)` per one million tokens.
static XAI_PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    let rate = ModelPricing::new;

    HashMap::from([
        ("grok-3", rate(3.0, 15.0)),
        ("grok-3-fast", rate(0.6, 4.0)),
        ("grok-3-mini", rate(0.3, 0.5)),
        ("grok-3-mini-fast", rate(0.1, 0.4)),
        ("grok-2", rate(2.0, 10.0)),
    ])
});
/// Fallback rate used by `get_model_pricing` when a paid provider's table has
/// no entry for the requested model.
const DEFAULT_PRICING: ModelPricing = ModelPricing::new(5.0, 15.0);
/// Zero rate returned by `get_model_pricing` for providers that are free.
const FREE_PRICING: ModelPricing = ModelPricing::new(0.0, 0.0);
/// Looks up the pricing for `model` under `provider`.
///
/// Free providers always yield `FREE_PRICING`. For paid providers the model
/// identifier must match a table key exactly; unknown models fall back to
/// `DEFAULT_PRICING`.
pub fn get_model_pricing(provider: ProviderKind, model: &str) -> ModelPricing {
    if provider.is_free() {
        return FREE_PRICING;
    }

    // Pick the provider's table first, then do a single lookup below.
    let table = match provider {
        ProviderKind::Claude => &*CLAUDE_PRICING,
        ProviderKind::OpenAI => &*OPENAI_PRICING,
        ProviderKind::Mistral => &*MISTRAL_PRICING,
        ProviderKind::Groq => &*GROQ_PRICING,
        ProviderKind::DeepSeek => &*DEEPSEEK_PRICING,
        ProviderKind::Gemini => &*GEMINI_PRICING,
        ProviderKind::XAi => &*XAI_PRICING,
        // Already handled by the `is_free` guard; kept so the match stays exhaustive.
        ProviderKind::Native => return FREE_PRICING,
    };

    match table.get(model) {
        Some(known) => *known,
        None => DEFAULT_PRICING,
    }
}
/// Computes the cost of a request from its token counts.
///
/// Resolves the rate with `get_model_pricing` (including its free-provider
/// and unknown-model fallbacks) and applies it to both token counts.
pub fn calculate_cost(
    provider: ProviderKind,
    model: &str,
    input_tokens: u64,
    output_tokens: u64,
) -> f64 {
    get_model_pricing(provider, model).calculate(input_tokens, output_tokens)
}
pub fn estimate_cost(
provider: ProviderKind,
model: &str,
estimated_input: u64,
estimated_output: u64,
) -> f64 {
calculate_cost(provider, model, estimated_input, estimated_output)
}
/// Formats a cost as a `$`-prefixed decimal string, showing more decimal
/// places the smaller the amount is.
///
/// | magnitude of `cost` | decimals |
/// |---------------------|----------|
/// | `< 0.0001`          | 6        |
/// | `< 0.01`            | 4        |
/// | `< 1.0`             | 3        |
/// | otherwise           | 2        |
///
/// The tier is chosen from the absolute value, so a negative amount such as
/// `-5.5` renders as `$-5.50` rather than always falling into the
/// six-decimal tier. Output for non-negative inputs is unchanged. `NaN`
/// fails every comparison and therefore uses two decimals.
pub fn format_cost(cost: f64) -> String {
    let magnitude = cost.abs();
    let decimals = if magnitude < 0.0001 {
        6
    } else if magnitude < 0.01 {
        4
    } else if magnitude < 1.0 {
        3
    } else {
        2
    };
    // `.*` takes the precision from the argument list, ahead of the value.
    format!("${:.*}", decimals, cost)
}
#[cfg(test)]
mod tests {
    //! Unit tests for provider parsing, pricing lookup, cost arithmetic and
    //! cost formatting.

    use super::*;

    #[test]
    fn provider_from_str_claude() {
        assert_eq!(ProviderKind::parse("claude"), Some(ProviderKind::Claude));
        assert_eq!(ProviderKind::parse("anthropic"), Some(ProviderKind::Claude));
        assert_eq!(ProviderKind::parse("CLAUDE"), Some(ProviderKind::Claude));
    }

    #[test]
    fn provider_from_str_openai() {
        assert_eq!(ProviderKind::parse("openai"), Some(ProviderKind::OpenAI));
        assert_eq!(ProviderKind::parse("gpt"), Some(ProviderKind::OpenAI));
    }

    #[test]
    fn provider_from_str_xai() {
        assert_eq!(ProviderKind::parse("xai"), Some(ProviderKind::XAi));
        assert_eq!(ProviderKind::parse("grok"), Some(ProviderKind::XAi));
        assert_eq!(ProviderKind::parse("x-ai"), Some(ProviderKind::XAi));
    }

    #[test]
    fn provider_from_str_all() {
        assert!(ProviderKind::parse("mistral").is_some());
        assert!(ProviderKind::parse("groq").is_some());
        assert!(ProviderKind::parse("deepseek").is_some());
        assert!(ProviderKind::parse("gemini").is_some());
        assert!(ProviderKind::parse("xai").is_some());
        assert!(ProviderKind::parse("native").is_some());
    }

    #[test]
    fn provider_from_str_unknown() {
        assert_eq!(ProviderKind::parse("unknown"), None);
        assert_eq!(ProviderKind::parse(""), None);
    }

    #[test]
    fn provider_is_free() {
        assert!(ProviderKind::Native.is_free());
        assert!(!ProviderKind::Claude.is_free());
        assert!(!ProviderKind::OpenAI.is_free());
        assert!(!ProviderKind::XAi.is_free());
    }

    #[test]
    fn pricing_calculate_simple() {
        let pricing = ModelPricing::new(10.0, 30.0);
        // One million tokens each way: 10 + 30.
        let cost = pricing.calculate(1_000_000, 1_000_000);
        assert!((cost - 40.0).abs() < 0.0001);
    }

    #[test]
    fn pricing_calculate_fractional() {
        let pricing = ModelPricing::new(3.0, 15.0);
        // 0.001 * 3 + 0.0005 * 15 = 0.003 + 0.0075.
        let cost = pricing.calculate(1000, 500);
        assert!((cost - 0.0105).abs() < 0.0001);
    }

    #[test]
    fn pricing_calculate_zero_tokens() {
        let pricing = ModelPricing::new(10.0, 30.0);
        let cost = pricing.calculate(0, 0);
        assert!((cost - 0.0).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_claude_sonnet() {
        // 0.01 * 3 + 0.005 * 15 = 0.03 + 0.075.
        let cost = calculate_cost(ProviderKind::Claude, "claude-sonnet-4-6", 10_000, 5_000);
        assert!((cost - 0.105).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_openai_gpt4o() {
        // 0.01 * 2.5 + 0.005 * 10 = 0.025 + 0.05.
        let cost = calculate_cost(ProviderKind::OpenAI, "gpt-4o", 10_000, 5_000);
        assert!((cost - 0.075).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_native_free() {
        let cost = calculate_cost(ProviderKind::Native, "llama3.2-q4", 1_000_000, 1_000_000);
        assert!((cost - 0.0).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_unknown_model() {
        // Unknown models use the default rate: 5 + 15.
        let cost = calculate_cost(ProviderKind::Claude, "unknown-model", 1_000_000, 1_000_000);
        assert!((cost - 20.0).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_groq() {
        // 0.1 * 0.59 + 0.05 * 0.79 = 0.059 + 0.0395.
        let cost = calculate_cost(
            ProviderKind::Groq,
            "llama-3.3-70b-versatile",
            100_000,
            50_000,
        );
        assert!((cost - 0.0985).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_deepseek() {
        // 0.1 * 0.14 + 0.05 * 0.28 = 0.014 + 0.014.
        let cost = calculate_cost(ProviderKind::DeepSeek, "deepseek-chat", 100_000, 50_000);
        assert!((cost - 0.028).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_gemini() {
        // 0.1 * 0.1 + 0.05 * 0.4 = 0.01 + 0.02.
        let cost = calculate_cost(ProviderKind::Gemini, "gemini-2.0-flash", 100_000, 50_000);
        assert!((cost - 0.03).abs() < 0.0001);
    }

    #[test]
    fn calculate_cost_xai() {
        // 0.1 * 3 + 0.05 * 15 = 0.3 + 0.75.
        let cost = calculate_cost(ProviderKind::XAi, "grok-3", 100_000, 50_000);
        assert!((cost - 1.05).abs() < 0.0001);
    }

    #[test]
    fn format_cost_tiny() {
        assert_eq!(format_cost(0.00001), "$0.000010");
    }

    #[test]
    fn format_cost_small() {
        assert_eq!(format_cost(0.005), "$0.0050");
    }

    #[test]
    fn format_cost_medium() {
        assert_eq!(format_cost(0.105), "$0.105");
    }

    #[test]
    fn format_cost_large() {
        assert_eq!(format_cost(5.50), "$5.50");
    }

    #[test]
    fn claude_pricing_coverage() {
        assert!(CLAUDE_PRICING.contains_key("claude-sonnet-4-6"));
        assert!(CLAUDE_PRICING.contains_key("claude-opus-4"));
        assert!(CLAUDE_PRICING.contains_key("claude-3-5-haiku-latest"));
    }

    #[test]
    fn openai_pricing_coverage() {
        assert!(OPENAI_PRICING.contains_key("gpt-4o"));
        assert!(OPENAI_PRICING.contains_key("gpt-4o-mini"));
        assert!(OPENAI_PRICING.contains_key("o1"));
    }

    #[test]
    fn mistral_pricing_coverage() {
        assert!(MISTRAL_PRICING.contains_key("mistral-large-latest"));
        assert!(MISTRAL_PRICING.contains_key("mistral-small-latest"));
    }

    #[test]
    fn xai_pricing_coverage() {
        assert!(XAI_PRICING.contains_key("grok-3"));
        assert!(XAI_PRICING.contains_key("grok-3-mini"));
    }

    #[test]
    fn all_providers_have_pricing() {
        assert!(!CLAUDE_PRICING.is_empty());
        assert!(!OPENAI_PRICING.is_empty());
        assert!(!MISTRAL_PRICING.is_empty());
        assert!(!GROQ_PRICING.is_empty());
        assert!(!DEEPSEEK_PRICING.is_empty());
        assert!(!GEMINI_PRICING.is_empty());
        assert!(!XAI_PRICING.is_empty());
    }
}