use std::collections::HashMap;
use std::sync::LazyLock;
use serde::Deserialize;
/// Raw pricing table, embedded into the binary at compile time from
/// `../schemas/pricing.json` (path relative to this source file).
const PRICING_JSON: &str = include_str!("../schemas/pricing.json");

/// Pricing registry parsed from [`PRICING_JSON`] on first access.
///
/// A parse failure is stored as the error's string form rather than panicking,
/// so callers can treat a broken registry as "no pricing available".
static PRICING: LazyLock<std::result::Result<PricingRegistry, String>> = LazyLock::new(|| {
    match serde_json::from_str::<PricingRegistry>(PRICING_JSON) {
        Ok(registry) => Ok(registry),
        Err(parse_error) => Err(parse_error.to_string()),
    }
});
/// Returns the parsed pricing registry, or `None` when the embedded JSON
/// failed to deserialize (the stored error message is discarded here).
fn pricing() -> Option<&'static PricingRegistry> {
    match &*PRICING {
        Ok(registry) => Some(registry),
        Err(_) => None,
    }
}
/// Deserialized form of the embedded pricing JSON document.
#[derive(Debug, Deserialize)]
struct PricingRegistry {
/// Pricing entries keyed by model name; lookups in `model_pricing` go through this map.
models: HashMap<String, ModelPricing>,
}
/// Per-token prices for a single model.
///
/// NOTE(review): the currency is not stated in this file (presumably USD per
/// token) — confirm against `schemas/pricing.json`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelPricing {
/// Price of one input (prompt) token; multiplied by `prompt_tokens` in `completion_cost`.
pub input_cost_per_token: f64,
/// Price of one output (completion) token; multiplied by `completion_tokens` in `completion_cost`.
pub output_cost_per_token: f64,
}
/// Computes the total cost of a completion for `model`.
///
/// The result is `prompt_tokens * input_cost_per_token +
/// completion_tokens * output_cost_per_token`, using the rates returned by
/// [`model_pricing`].
///
/// Returns `None` when [`model_pricing`] finds no entry for `model`.
#[must_use]
pub fn completion_cost(model: &str, prompt_tokens: u64, completion_tokens: u64) -> Option<f64> {
    model_pricing(model).map(|rates| {
        let input_cost = (prompt_tokens as f64) * rates.input_cost_per_token;
        let output_cost = (completion_tokens as f64) * rates.output_cost_per_token;
        input_cost + output_cost
    })
}
/// Looks up the pricing entry for `model`.
///
/// An exact match on the model name is tried first. If that misses, the name
/// is repeatedly cut at its last `-` or `.` and each shorter name is looked
/// up in turn, so a versioned name such as `gpt-4-0613` can resolve to the
/// entry for `gpt-4`.
///
/// Returns `None` when the registry is unavailable or when neither the name
/// nor any of its truncations has an entry.
#[must_use]
pub fn model_pricing(model: &str) -> Option<&'static ModelPricing> {
    let models = &pricing()?.models;
    // Candidates, in order: the full name, then each successively shorter
    // prefix obtained by dropping everything from the last '-' or '.' onward.
    std::iter::successors(Some(model), |&name| {
        let cut = name.rfind(['-', '.'])?;
        Some(&name[..cut])
    })
    .find_map(|name| models.get(name))
}
// Unit tests for the pricing lookup and cost calculation.
//
// The expected per-token rates below are hard-coded, so they must be kept in
// sync with the entries in the embedded `schemas/pricing.json`.
#[cfg(test)]
mod tests {
use super::*;
// Exact-name lookup: cost is input_rate * prompt_tokens + output_rate * completion_tokens.
#[test]
fn completion_cost_known_model_returns_expected_value() {
let cost = completion_cost("gpt-4", 100, 50).expect("gpt-4 must be in registry");
let expected = 100.0 * 0.00003 + 50.0 * 0.00006;
assert!((cost - expected).abs() < 1e-12, "expected {expected}, got {cost}");
}
// A name with no entry, and whose truncations also have no entry, yields `None`.
#[test]
fn completion_cost_unknown_model_returns_none() {
assert!(
completion_cost("unknown-model-xyz", 100, 50).is_none(),
"unknown model should return None"
);
}
// "gpt-4o" contains no '.' and only one '-', so it must be matched by its own entry
// (its only truncation is "gpt").
#[test]
fn completion_cost_gpt4o_matches_published_pricing() {
let cost = completion_cost("gpt-4o", 1_000, 500).expect("gpt-4o must be in registry");
let expected = 1_000.0 * 0.0000025 + 500.0 * 0.00001;
assert!((cost - expected).abs() < 1e-12, "expected {expected}, got {cost}");
}
// Embedding models are expected to charge for input tokens only.
#[test]
fn completion_cost_embedding_model_has_zero_output_cost() {
let cost =
completion_cost("text-embedding-3-small", 100, 0).expect("text-embedding-3-small must be in registry");
assert!(cost > 0.0, "input tokens must have a positive cost");
let pricing = model_pricing("text-embedding-3-small").unwrap();
// Exact float comparison is intentional: the value is expected to be a literal zero.
assert_eq!(pricing.output_cost_per_token, 0.0, "embedding output cost must be zero");
}
// Direct lookup of a name that should not match any entry or any truncated prefix.
#[test]
fn model_pricing_returns_none_for_unknown_model() {
assert!(model_pricing("does-not-exist").is_none());
}
// A versioned name should fall back to the shorter base name's entry.
// NOTE(review): this assumes the registry has no exact "gpt-4-0613" entry with a
// different input rate — confirm against the JSON.
#[test]
fn model_pricing_prefix_fallback_matches_shorter_name() {
let exact = model_pricing("gpt-4").expect("gpt-4 must be in registry");
let prefix = model_pricing("gpt-4-0613").expect("gpt-4-0613 should match gpt-4 via prefix");
assert!(
(exact.input_cost_per_token - prefix.input_cost_per_token).abs() < 1e-15,
"prefix match should return the same pricing as exact match"
);
}
// The prefix fallback must also be reachable through `completion_cost`.
#[test]
fn completion_cost_prefix_fallback() {
let cost = completion_cost("gpt-4-0613", 100, 50);
assert!(cost.is_some(), "gpt-4-0613 should resolve via prefix fallback to gpt-4");
}
// Both rate fields of a known entry are checked against hard-coded values.
#[test]
fn model_pricing_returns_correct_fields_for_known_model() {
let p = model_pricing("gpt-4o-mini").expect("gpt-4o-mini must be in registry");
assert!(
(p.input_cost_per_token - 0.00000015).abs() < 1e-12,
"unexpected input_cost_per_token: {}",
p.input_cost_per_token
);
assert!(
(p.output_cost_per_token - 0.0000006).abs() < 1e-12,
"unexpected output_cost_per_token: {}",
p.output_cost_per_token
);
}
// Guards against a malformed pricing.json: `pricing()` would otherwise hide the
// parse error by returning `None`, so the stored error is surfaced here.
#[test]
fn pricing_registry_embedded_json_is_valid() {
assert!(
PRICING.as_ref().is_ok(),
"embedded schemas/pricing.json failed to parse: {:?}",
PRICING.as_ref().err()
);
}
}