use rust_decimal::Decimal;
use rust_decimal_macros::dec;
pub fn model_cost(model_id: &str) -> Option<(Decimal, Decimal)> {
let id = model_id
.rsplit_once('/')
.map(|(_, name)| name)
.unwrap_or(model_id);
match id {
"gpt-4o" | "gpt-4o-2024-11-20" | "gpt-4o-2024-08-06" => {
Some((dec!(0.0000025), dec!(0.00001)))
}
"gpt-4o-mini" | "gpt-4o-mini-2024-07-18" => Some((dec!(0.00000015), dec!(0.0000006))),
"gpt-4-turbo" | "gpt-4-turbo-2024-04-09" => Some((dec!(0.00001), dec!(0.00003))),
"gpt-4" | "gpt-4-0613" => Some((dec!(0.00003), dec!(0.00006))),
"gpt-3.5-turbo" | "gpt-3.5-turbo-0125" => Some((dec!(0.0000005), dec!(0.0000015))),
"o1" | "o1-2024-12-17" => Some((dec!(0.000015), dec!(0.00006))),
"o1-mini" | "o1-mini-2024-09-12" => Some((dec!(0.000003), dec!(0.000012))),
"o3-mini" | "o3-mini-2025-01-31" => Some((dec!(0.0000011), dec!(0.0000044))),
"claude-3-5-sonnet-20241022" | "claude-3-5-sonnet-latest" | "claude-sonnet-4-20250514" => {
Some((dec!(0.000003), dec!(0.000015)))
}
"claude-3-5-haiku-20241022" | "claude-3-5-haiku-latest" => {
Some((dec!(0.0000008), dec!(0.000004)))
}
"claude-3-opus-20240229" | "claude-3-opus-latest" | "claude-opus-4-20250514" => {
Some((dec!(0.000015), dec!(0.000075)))
}
"claude-3-haiku-20240307" => Some((dec!(0.00000025), dec!(0.00000125))),
_ if is_local_model(id) => Some((Decimal::ZERO, Decimal::ZERO)),
_ => None,
}
}
pub fn default_cost() -> (Decimal, Decimal) {
(dec!(0.0000025), dec!(0.00001))
}
fn is_local_model(model_id: &str) -> bool {
let lower = model_id.to_lowercase();
lower.starts_with("llama")
|| lower.starts_with("mistral")
|| lower.starts_with("mixtral")
|| lower.starts_with("phi")
|| lower.starts_with("gemma")
|| lower.starts_with("qwen")
|| lower.starts_with("codellama")
|| lower.starts_with("deepseek")
|| lower.starts_with("starcoder")
|| lower.starts_with("vicuna")
|| lower.starts_with("yi")
|| lower.contains(":latest")
|| lower.contains(":instruct")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_known_model_costs() {
let (input, output) = model_cost("gpt-4o").unwrap();
assert!(input > Decimal::ZERO);
assert!(output > input);
}
#[test]
fn test_claude_costs() {
let (input, output) = model_cost("claude-3-5-sonnet-20241022").unwrap();
assert!(input > Decimal::ZERO);
assert!(output > input);
}
#[test]
fn test_local_model_free() {
let (input, output) = model_cost("llama3").unwrap();
assert_eq!(input, Decimal::ZERO);
assert_eq!(output, Decimal::ZERO);
}
#[test]
fn test_ollama_tagged_model_free() {
let (input, output) = model_cost("mistral:latest").unwrap();
assert_eq!(input, Decimal::ZERO);
assert_eq!(output, Decimal::ZERO);
}
#[test]
fn test_unknown_model_returns_none() {
assert!(model_cost("some-totally-unknown-model-xyz").is_none());
}
#[test]
fn test_default_cost_nonzero() {
let (input, output) = default_cost();
assert!(input > Decimal::ZERO);
assert!(output > Decimal::ZERO);
}
#[test]
fn test_provider_prefix_stripped() {
assert_eq!(model_cost("openai/gpt-4o"), model_cost("gpt-4o"));
}
}