use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use crate::ProviderId;
/// Token pricing for a single model, with every rate expressed per one million tokens.
///
/// The cache-related rates are optional. When serialized they are omitted if unset,
/// and when deserialized a missing key becomes `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ModelPrice {
    /// Rate per million input tokens.
    pub input_per_million: f64,
    /// Rate per million output tokens.
    pub output_per_million: f64,
    /// Optional dedicated rate per million cached input tokens.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_input_per_million: Option<f64>,
    /// Optional dedicated rate per million cache-write tokens.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_per_million: Option<f64>,
}
impl ModelPrice {
    /// Creates a price from an input rate and an output rate, both per million tokens.
    ///
    /// No cache-specific rates are set; see [`ModelPrice::with_cached_input`] and
    /// [`ModelPrice::with_cache_write`].
    #[must_use]
    pub const fn new(input_per_million: f64, output_per_million: f64) -> Self {
        Self {
            input_per_million,
            output_per_million,
            cached_input_per_million: None,
            cache_write_per_million: None,
        }
    }

    /// Returns `self` with a dedicated rate (per million tokens) for cached input tokens.
    ///
    /// Without it, [`ModelPrice::cost_for`] bills cached input at `input_per_million`.
    #[must_use]
    pub const fn with_cached_input(mut self, rate_per_million: f64) -> Self {
        self.cached_input_per_million = Some(rate_per_million);
        self
    }

    /// Returns `self` with a dedicated rate (per million tokens) for cache-write tokens.
    ///
    /// Without it, [`ModelPrice::cost_for`] bills cache writes at `input_per_million`.
    #[must_use]
    pub const fn with_cache_write(mut self, rate_per_million: f64) -> Self {
        self.cache_write_per_million = Some(rate_per_million);
        self
    }

    /// Computes the cost of a request from its token counts.
    ///
    /// Each count is multiplied by its rate and the total is divided by one million.
    /// Cached-input and cache-write tokens use `input_per_million` when their
    /// dedicated rate is unset.
    ///
    /// The four counts are billed independently and summed. Tokens counted in
    /// `input_tokens` *and* in one of the cache counts are therefore charged twice.
    #[must_use]
    pub fn cost_for(
        &self,
        input_tokens: u64,
        output_tokens: u64,
        cached_input_tokens: u64,
        cache_write_tokens: u64,
    ) -> f64 {
        const PER_MILLION: f64 = 1_000_000.0;
        let input = input_tokens as f64 * self.input_per_million;
        let output = output_tokens as f64 * self.output_per_million;
        // Cache rates fall back to the plain input rate when the model has none.
        let cached_rate = self
            .cached_input_per_million
            .unwrap_or(self.input_per_million);
        let cache_write_rate = self
            .cache_write_per_million
            .unwrap_or(self.input_per_million);
        let cached = cached_input_tokens as f64 * cached_rate;
        let cache_write = cache_write_tokens as f64 * cache_write_rate;
        (input + output + cached + cache_write) / PER_MILLION
    }

    /// Computes the cost of a request from a rig `Usage` record.
    ///
    /// Forwards the usage fields to [`ModelPrice::cost_for`].
    ///
    /// NOTE(review): confirm whether rig's `Usage::input_tokens` already includes
    /// `cached_input_tokens` / `cache_creation_input_tokens` for every provider.
    /// If it does, those tokens are billed twice here.
    #[cfg(feature = "rig-hook")]
    #[must_use]
    pub fn cost_for_usage(&self, usage: &rig_core::completion::Usage) -> f64 {
        self.cost_for(
            usage.input_tokens,
            usage.output_tokens,
            usage.cached_input_tokens,
            usage.cache_creation_input_tokens,
        )
    }
}
/// One row of the JSON pricing catalog parsed by [`PricingTable::from_json`].
///
/// A row carries its own `(provider, model)` key alongside the rates that become
/// a [`ModelPrice`].
#[derive(Clone, Debug, Deserialize)]
struct PricingEntry {
    /// Provider half of the lookup key.
    provider: ProviderId,
    /// Model half of the lookup key.
    model: String,
    /// Rate per million input tokens.
    input_per_million: f64,
    /// Rate per million output tokens.
    output_per_million: f64,
    /// Optional cached-input rate; `None` when the key is absent from the row.
    #[serde(default)]
    cached_input_per_million: Option<f64>,
    /// Optional cache-write rate; `None` when the key is absent from the row.
    #[serde(default)]
    cache_write_per_million: Option<f64>,
}
/// Lookup table from `(provider, model)` pairs to their [`ModelPrice`].
///
/// Backed by a `BTreeMap`, so iteration is ordered by provider and then by model.
#[derive(Clone, Debug, Default)]
pub struct PricingTable {
    /// Prices keyed by `(provider, model)`; inserting an existing key replaces it.
    entries: BTreeMap<(ProviderId, String), ModelPrice>,
}
impl PricingTable {
    /// Creates an empty table.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `price` registered for `(provider, model)`.
    ///
    /// Any existing entry for the same pair is replaced.
    #[must_use]
    pub fn with(
        mut self,
        provider: impl Into<ProviderId>,
        model: impl Into<String>,
        price: ModelPrice,
    ) -> Self {
        self.entries.insert((provider.into(), model.into()), price);
        self
    }

    /// Returns the price registered for `(provider, model)`, or `None` if there is none.
    ///
    /// Both halves of the key must match exactly. Each call builds an owned key,
    /// which copies `model` into a new `String`.
    #[must_use]
    pub fn lookup(&self, provider: impl Into<ProviderId>, model: &str) -> Option<&ModelPrice> {
        self.entries.get(&(provider.into(), model.to_string()))
    }

    /// Number of `(provider, model)` entries in the table.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the table has no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterates over `(provider, model, price)` rows in key order: provider first, then model.
    pub fn iter(&self) -> impl Iterator<Item = (&ProviderId, &str, &ModelPrice)> {
        self.entries
            .iter()
            .map(|((provider, model), price)| (provider, model.as_str(), price))
    }

    /// Parses a JSON array of pricing rows into a table.
    ///
    /// Each row needs `provider`, `model`, `input_per_million` and `output_per_million`.
    /// The cache rates are optional. If rows repeat a `(provider, model)` key, the
    /// last row wins and a warning is logged.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `json` is not an array of valid rows.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        let entries: Vec<PricingEntry> = serde_json::from_str(json)?;
        Ok(Self::from_entries(entries))
    }

    /// Builds a table from parsed catalog rows.
    ///
    /// The last row for a given `(provider, model)` key wins. Each override is
    /// logged at `warn` level so a duplicated catalog row is visible instead of
    /// silently discarding the earlier price.
    fn from_entries(entries: Vec<PricingEntry>) -> Self {
        use std::collections::btree_map::Entry;

        let mut table = Self::new();
        for entry in entries {
            let price = ModelPrice {
                input_per_million: entry.input_per_million,
                output_per_million: entry.output_per_million,
                cached_input_per_million: entry.cached_input_per_million,
                cache_write_per_million: entry.cache_write_per_million,
            };
            match table.entries.entry((entry.provider, entry.model)) {
                Entry::Occupied(mut slot) => {
                    let (provider, model) = slot.key();
                    tracing::warn!(
                        provider = ?provider,
                        model = %model,
                        "rig-model-meta: duplicate pricing entry; later row overrides earlier one",
                    );
                    slot.insert(price);
                }
                Entry::Vacant(slot) => {
                    slot.insert(price);
                }
            }
        }
        table
    }

    /// Returns the table parsed from the pricing catalog bundled into the crate.
    ///
    /// The bundled JSON is parsed on every call. If parsing fails, the error is
    /// logged and an empty table is returned instead of panicking.
    #[must_use]
    pub fn builtin() -> Self {
        match Self::from_json(BUILTIN_JSON) {
            Ok(table) => table,
            Err(err) => {
                tracing::error!(
                    error = %err,
                    "rig-model-meta: bundled pricing.json failed to parse; \
                     returning empty table",
                );
                Self::new()
            }
        }
    }
}
/// Bundled pricing catalog, embedded at compile time from `../data/pricing.json`
/// (resolved relative to this source file) and parsed by `PricingTable::builtin`.
const BUILTIN_JSON: &str = include_str!("../data/pricing.json");
#[cfg(test)]
// Test code is allowed to panic on failure; these lints target library code.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point cost comparison.
    const EPSILON: f64 = 1e-9;

    /// Asserts that `actual` is within [`EPSILON`] of `expected`, printing both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {expected}, got {actual}",
        );
    }

    #[test]
    fn cost_for_input_output_only() {
        let price = ModelPrice::new(2.50, 10.00);
        assert_close(price.cost_for(1_000_000, 1_000_000, 0, 0), 12.50);
    }

    #[test]
    fn cost_for_zero_tokens_is_zero() {
        let price = ModelPrice::new(2.50, 10.00)
            .with_cached_input(1.25)
            .with_cache_write(3.00);
        assert_close(price.cost_for(0, 0, 0, 0), 0.0);
    }

    #[test]
    fn cost_for_scales_with_partial_millions() {
        let price = ModelPrice::new(2.00, 8.00);
        assert_close(price.cost_for(500_000, 250_000, 0, 0), 3.0);
    }

    #[test]
    fn cost_for_with_cached_input_uses_dedicated_rate() {
        let price = ModelPrice::new(2.50, 10.00).with_cached_input(1.25);
        assert_close(price.cost_for(1_000_000, 0, 1_000_000, 0), 3.75);
    }

    #[test]
    fn cost_for_with_cache_write_uses_dedicated_rate() {
        let price = ModelPrice::new(3.00, 15.00)
            .with_cached_input(0.30)
            .with_cache_write(3.75);
        assert_close(price.cost_for(0, 0, 0, 1_000_000), 3.75);
    }

    #[test]
    fn cost_for_sums_all_components() {
        let price = ModelPrice::new(3.00, 15.00)
            .with_cached_input(0.30)
            .with_cache_write(3.75);
        let cost = price.cost_for(1_000_000, 1_000_000, 1_000_000, 1_000_000);
        assert_close(cost, 3.00 + 15.00 + 0.30 + 3.75);
    }

    #[test]
    fn cost_falls_back_to_input_rate_when_cache_rates_absent() {
        let price = ModelPrice::new(2.50, 10.00);
        assert_close(price.cost_for(0, 0, 1_000_000, 0), 2.50);
        assert_close(price.cost_for(0, 0, 0, 1_000_000), 2.50);
    }

    #[test]
    fn pricing_table_new_is_empty() {
        let table = PricingTable::new();
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert_eq!(table.iter().count(), 0);
        assert!(table.lookup("openai", "gpt-4o").is_none());
    }

    #[test]
    fn pricing_table_with_then_lookup() {
        let table = PricingTable::new().with("openai", "gpt-4o", ModelPrice::new(2.50, 10.00));
        let price = table.lookup("openai", "gpt-4o").expect("inserted");
        assert_close(price.input_per_million, 2.50);
        assert!(table.lookup("openai", "missing").is_none());
        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());
    }

    #[test]
    fn pricing_table_with_replaces_existing_entry() {
        let table = PricingTable::new()
            .with("openai", "gpt-4o", ModelPrice::new(1.00, 1.00))
            .with("openai", "gpt-4o", ModelPrice::new(2.00, 2.00));
        assert_eq!(table.len(), 1);
        let price = table.lookup("openai", "gpt-4o").expect("present");
        assert_close(price.input_per_million, 2.00);
    }

    #[test]
    fn pricing_table_from_json_round_trip() {
        let json = r#"[
            {
                "provider": "openai",
                "model": "gpt-4o",
                "input_per_million": 2.5,
                "output_per_million": 10.0,
                "cached_input_per_million": 1.25
            }
        ]"#;
        let table = PricingTable::from_json(json).expect("parses");
        let price = table.lookup("openai", "gpt-4o").expect("present");
        assert_eq!(price.cached_input_per_million, Some(1.25));
        assert_eq!(price.cache_write_per_million, None);
    }

    #[test]
    fn pricing_table_from_json_empty_array_is_empty_table() {
        let table = PricingTable::from_json("[]").expect("parses");
        assert!(table.is_empty());
    }

    #[test]
    fn pricing_table_from_json_rejects_malformed_input() {
        assert!(PricingTable::from_json("not json").is_err());
        assert!(PricingTable::from_json(r#"[{"provider": "openai", "model": "x"}]"#).is_err());
    }

    #[test]
    fn pricing_table_from_json_last_duplicate_wins() {
        let json = r#"[
            {
                "provider": "openai",
                "model": "gpt-4o",
                "input_per_million": 1.0,
                "output_per_million": 1.0
            },
            {
                "provider": "openai",
                "model": "gpt-4o",
                "input_per_million": 2.0,
                "output_per_million": 2.0
            }
        ]"#;
        let table = PricingTable::from_json(json).expect("parses");
        assert_eq!(table.len(), 1);
        let price = table.lookup("openai", "gpt-4o").expect("present");
        assert_close(price.input_per_million, 2.0);
    }

    #[test]
    fn builtin_catalog_seeds_known_models() {
        let table = PricingTable::builtin();
        assert!(
            table.lookup("openai", "gpt-4o-mini").is_some(),
            "seed must include gpt-4o-mini",
        );
        assert!(
            table
                .lookup("anthropic", "claude-3-5-sonnet-20241022")
                .is_some(),
            "seed must include claude-3-5-sonnet-20241022",
        );
        for (_, _, price) in table.iter() {
            assert!(price.input_per_million > 0.0);
            assert!(price.output_per_million > 0.0);
        }
    }

    #[test]
    fn pricing_table_iter_is_sorted() {
        let table = PricingTable::builtin();
        let keys: Vec<(String, String)> = table
            .iter()
            .map(|(p, m, _)| (p.as_str().to_string(), m.to_string()))
            .collect();
        let mut sorted = keys.clone();
        sorted.sort();
        assert_eq!(keys, sorted, "iter() must yield rows in sorted order");
    }
}