use rust_decimal::Decimal;
/// Token pricing for a single provider/model pair, in US dollars per 1,000 tokens.
///
/// Both rates (de)serialize as decimal strings (e.g. `"0.005"`) via
/// `rust_decimal::serde::str`, so values round-trip exactly instead of going
/// through binary floating point.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PricingEntry {
    /// USD charged per 1,000 input tokens.
    #[serde(with = "rust_decimal::serde::str")]
    pub input_per_1k_usd: Decimal,
    /// USD charged per 1,000 output tokens.
    #[serde(with = "rust_decimal::serde::str")]
    pub output_per_1k_usd: Decimal,
}
/// One element of the JSON array accepted by `PricingTable::load_from_json_str`.
///
/// Each row names a provider/model pair and its two rates, given as decimal
/// strings. Fields are declared in this order deliberately; keep it stable.
#[derive(serde::Deserialize)]
struct PricingJsonRow {
    /// Provider half of the lookup key.
    provider: crate::budget::types::Provider,
    /// Model half of the lookup key.
    model: crate::budget::types::Model,
    /// USD per 1,000 input tokens, parsed from a decimal string.
    #[serde(deserialize_with = "rust_decimal::serde::str::deserialize")]
    input_per_1k_usd: Decimal,
    /// USD per 1,000 output tokens, parsed from a decimal string.
    #[serde(deserialize_with = "rust_decimal::serde::str::deserialize")]
    output_per_1k_usd: Decimal,
}
/// Lookup table mapping a `(Provider, Model)` pair to its [`PricingEntry`].
///
/// Construct it with `default_table`, `load_from_json_str`, or `load_from_file`.
#[derive(Debug, Clone)]
pub struct PricingTable {
    entries: std::collections::HashMap<
        (crate::budget::types::Provider, crate::budget::types::Model),
        PricingEntry,
    >,
}
impl PricingTable {
    /// Builds the built-in pricing table covering every pair listed below.
    ///
    /// Rates are USD per 1,000 tokens, written as decimal string literals so
    /// they are parsed exactly rather than through binary floating point.
    ///
    /// # Panics
    ///
    /// Only if one of the embedded literals is not a valid decimal, which is
    /// a programming error rather than a runtime condition.
    pub fn default_table() -> Self {
        use crate::budget::types::{Model, Provider};
        fn d(s: &str) -> Decimal {
            s.parse().expect("embedded literal")
        }
        // (provider, model, input USD per 1k tokens, output USD per 1k tokens)
        let rows: &[(Provider, Model, &str, &str)] = &[
            (Provider::OpenAi, Model::Gpt4o, "0.005", "0.015"),
            (Provider::OpenAi, Model::Gpt4, "0.03", "0.06"),
            (Provider::OpenAi, Model::Gpt35Turbo, "0.0005", "0.0015"),
            (Provider::Anthropic, Model::Claude3Opus, "0.015", "0.075"),
            (Provider::Anthropic, Model::Claude3Sonnet, "0.003", "0.015"),
            (Provider::Anthropic, Model::Claude3Haiku, "0.00025", "0.00125"),
            (Provider::Cohere, Model::CommandRPlus, "0.003", "0.015"),
            (Provider::Cohere, Model::CommandR, "0.0005", "0.0015"),
        ];
        let entries = rows
            .iter()
            .map(|(prov, model, inp, out)| {
                (
                    (*prov, *model),
                    PricingEntry {
                        input_per_1k_usd: d(inp),
                        output_per_1k_usd: d(out),
                    },
                )
            })
            .collect();
        Self { entries }
    }

    /// Parses a JSON array of pricing rows and overlays it on the defaults.
    ///
    /// Every row replaces the default entry for its `(provider, model)` pair,
    /// or adds a new one. If the same pair appears more than once, the last
    /// row wins. Pairs not mentioned keep their default rates.
    ///
    /// # Errors
    ///
    /// Returns [`PricingLoadError::Json`] if `json` is not a valid array of rows.
    pub fn load_from_json_str(json: &str) -> Result<Self, PricingLoadError> {
        let rows: Vec<PricingJsonRow> =
            serde_json::from_str(json).map_err(PricingLoadError::Json)?;
        let mut table = Self::default_table();
        // `extend` inserts in order, overwriting existing keys (last wins).
        table.entries.extend(rows.into_iter().map(|row| {
            (
                (row.provider, row.model),
                PricingEntry {
                    input_per_1k_usd: row.input_per_1k_usd,
                    output_per_1k_usd: row.output_per_1k_usd,
                },
            )
        }));
        Ok(table)
    }

    /// Loads pricing overrides from the JSON file at `path`.
    ///
    /// On any failure it falls back to [`PricingTable::default_table`], so
    /// this never fails.
    ///
    /// A missing file is treated as "no overrides" and silently yields the
    /// defaults. An unreadable file or malformed JSON is reported on stderr
    /// before falling back, so a broken configuration is not silently
    /// mistaken for an absent one.
    pub fn load_from_file(path: &std::path::Path) -> Self {
        match std::fs::read_to_string(path) {
            Ok(json) => Self::load_from_json_str(&json).unwrap_or_else(|e| {
                eprintln!("aa-gateway: pricing.json parse error ({e}); using defaults");
                Self::default_table()
            }),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Self::default_table(),
            Err(e) => {
                // Permission errors, invalid UTF-8, etc.: surface them instead
                // of silently ignoring the configured file.
                eprintln!(
                    "aa-gateway: cannot read pricing file {} ({e}); using defaults",
                    path.display()
                );
                Self::default_table()
            }
        }
    }

    /// Computes the USD cost of a request from its input and output token counts.
    ///
    /// Each side is `rate_per_1k * tokens / 1000`, computed in [`Decimal`] to
    /// avoid floating-point rounding. Returns [`Decimal::ZERO`] when
    /// `(provider, model)` has no entry in the table.
    pub fn cost_usd(
        &self,
        provider: crate::budget::types::Provider,
        model: crate::budget::types::Model,
        input_tokens: u64,
        output_tokens: u64,
    ) -> Decimal {
        match self.entry(provider, model) {
            Some(entry) => {
                let per_1k = Decimal::from(1_000u64);
                let input_cost = entry.input_per_1k_usd * Decimal::from(input_tokens) / per_1k;
                let output_cost = entry.output_per_1k_usd * Decimal::from(output_tokens) / per_1k;
                input_cost + output_cost
            }
            None => Decimal::ZERO,
        }
    }

    /// Returns the pricing entry for `(provider, model)`, or `None` if absent.
    pub fn entry(
        &self,
        provider: crate::budget::types::Provider,
        model: crate::budget::types::Model,
    ) -> Option<&PricingEntry> {
        self.entries.get(&(provider, model))
    }
}
/// Error returned by `PricingTable::load_from_json_str`.
#[derive(Debug)]
pub enum PricingLoadError {
    /// The input could not be deserialized as a JSON array of pricing rows.
    Json(serde_json::Error),
}
impl std::fmt::Display for PricingLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PricingLoadError::Json(e) => write!(f, "pricing JSON error: {e}"),
}
}
}
impl std::error::Error for PricingLoadError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::budget::types::{Model, Provider};

    /// Parses a decimal literal for use in assertions.
    fn d(s: &str) -> rust_decimal::Decimal {
        s.parse().unwrap()
    }

    #[test]
    fn cost_usd_gpt4o_input_only() {
        let table = PricingTable::default_table();
        assert_eq!(table.cost_usd(Provider::OpenAi, Model::Gpt4o, 1_000, 0), d("0.005"));
    }

    #[test]
    fn cost_usd_gpt4o_mixed_tokens() {
        let table = PricingTable::default_table();
        assert_eq!(
            table.cost_usd(Provider::OpenAi, Model::Gpt4o, 100_000, 20_000),
            d("0.80")
        );
    }

    #[test]
    fn cost_usd_zero_tokens_is_zero() {
        let table = PricingTable::default_table();
        assert_eq!(
            table.cost_usd(Provider::OpenAi, Model::Gpt4o, 0, 0),
            rust_decimal::Decimal::ZERO,
        );
    }

    #[test]
    fn cost_usd_unknown_pair_returns_zero() {
        let table = PricingTable::default_table();
        assert_eq!(
            table.cost_usd(Provider::Anthropic, Model::CommandR, 1_000, 1_000),
            rust_decimal::Decimal::ZERO,
        );
    }

    #[test]
    fn load_from_file_falls_back_to_defaults_on_missing_file() {
        let path = std::path::Path::new("/nonexistent/path/pricing.json");
        let table = PricingTable::load_from_file(path);
        assert!(table.entry(Provider::OpenAi, Model::Gpt4o).is_some());
    }

    #[test]
    fn load_from_json_str_overrides_gpt4o_input_price() {
        let json = r#"[
            { "provider": "open_ai", "model": "gpt4o",
              "input_per_1k_usd": "0.999", "output_per_1k_usd": "0.015" }
        ]"#;
        let table = PricingTable::load_from_json_str(json).unwrap();
        let entry = table.entry(Provider::OpenAi, Model::Gpt4o).unwrap();
        assert_eq!(entry.input_per_1k_usd, d("0.999"));
        assert!(table.entry(Provider::Anthropic, Model::Claude3Opus).is_some());
    }

    #[test]
    fn load_from_json_str_rejects_malformed_json() {
        assert!(matches!(
            PricingTable::load_from_json_str("not json"),
            Err(PricingLoadError::Json(_))
        ));
    }

    #[test]
    fn default_table_contains_all_eight_models() {
        let table = PricingTable::default_table();
        for (prov, model) in [
            (Provider::OpenAi, Model::Gpt4o),
            (Provider::OpenAi, Model::Gpt4),
            (Provider::OpenAi, Model::Gpt35Turbo),
            (Provider::Anthropic, Model::Claude3Opus),
            (Provider::Anthropic, Model::Claude3Sonnet),
            (Provider::Anthropic, Model::Claude3Haiku),
            (Provider::Cohere, Model::CommandRPlus),
            (Provider::Cohere, Model::CommandR),
        ] {
            assert!(table.entry(prov, model).is_some(), "{prov:?}/{model:?} missing");
        }
    }

    #[test]
    fn default_table_gpt4o_has_correct_rates() {
        let table = PricingTable::default_table();
        let entry = table.entry(Provider::OpenAi, Model::Gpt4o).unwrap();
        assert_eq!(entry.input_per_1k_usd, d("0.005"));
        assert_eq!(entry.output_per_1k_usd, d("0.015"));
    }

    #[test]
    fn pricing_load_error_displays_message() {
        let raw = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let err = PricingLoadError::Json(raw);
        assert!(err.to_string().contains("pricing JSON error"));
    }

    #[test]
    fn pricing_entry_stores_rates() {
        let entry = PricingEntry {
            input_per_1k_usd: d("0.005"),
            output_per_1k_usd: d("0.015"),
        };
        assert_eq!(entry.input_per_1k_usd, d("0.005"));
        assert_eq!(entry.output_per_1k_usd, d("0.015"));
    }
}