pub mod cache;
pub mod embedded;
pub mod litellm;
use anyhow::Result;
pub use embedded::{ModelPricing, MODEL_PRICING};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::RwLock;
/// Process-wide pricing table: embedded defaults overlaid with any cached prices.
///
/// Built lazily on first access. Entries returned by `cache::load_cached_pricing`
/// overwrite embedded entries that share the same model id (`HashMap::extend`
/// semantics). `update_pricing_from_litellm` replaces the whole table at runtime.
static DYNAMIC_PRICING: Lazy<RwLock<HashMap<String, ModelPricing>>> = Lazy::new(|| {
    let mut pricing = embedded::MODEL_PRICING.clone();
    match cache::load_cached_pricing() {
        Ok(Some(cached)) => {
            tracing::info!(
                "Merging {} cached prices with {} embedded prices",
                cached.len(),
                pricing.len()
            );
            pricing.extend(cached);
        }
        Ok(None) => {
            tracing::debug!("Using embedded pricing only (no cache available)");
        }
        Err(err) => {
            // A cache that exists but cannot be loaded is a real problem worth
            // surfacing; previously it was indistinguishable from "no cache".
            tracing::warn!(
                "Failed to load pricing cache, using embedded pricing only: {:#}",
                err
            );
        }
    }
    RwLock::new(pricing)
});
/// Returns pricing for `model_id`, preferring the live (embedded + cached/fetched) table.
///
/// An exact key match in the live table wins. Otherwise the lookup is delegated to
/// `embedded::get_model_pricing`, which always returns a `ModelPricing`; how it
/// resolves ids that are not exact keys is defined in the `embedded` module.
pub fn get_model_pricing(model_id: &str) -> ModelPricing {
    // Recover from a poisoned lock instead of silently ignoring the live table:
    // the writer in this module only ever replaces the map wholesale, so its
    // contents remain valid. The guard is a temporary and is released at the end
    // of this statement, before the fallback lookup runs.
    let live_hit = DYNAMIC_PRICING
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(model_id)
        .cloned();
    live_hit.unwrap_or_else(|| embedded::get_model_pricing(model_id))
}
/// Fetches current prices from LiteLLM, persists them, and installs them in the live table.
///
/// The fetched prices are first saved via `cache::save_pricing_cache`. Only if that
/// succeeds is the in-memory table replaced by the embedded defaults overlaid with the
/// fetched prices (fetched entries win on key collisions).
///
/// # Returns
/// The number of model entries fetched from LiteLLM.
///
/// # Errors
/// Returns an error if fetching from LiteLLM or saving the pricing cache fails. In
/// either case the in-memory table is left unchanged.
pub async fn update_pricing_from_litellm() -> Result<usize> {
    tracing::info!("Updating pricing from LiteLLM");
    let fetched = litellm::fetch_litellm_pricing().await?;
    let count = fetched.len();
    cache::save_pricing_cache(fetched.clone())?;

    // Build the replacement table before taking the write lock, so readers are
    // blocked only while the new map is installed, not during the clone/merge.
    let mut merged = embedded::MODEL_PRICING.clone();
    merged.extend(fetched);

    // A poisoned lock only means another thread panicked while holding it. We
    // overwrite the map wholesale, so recover the guard rather than silently
    // skipping the update while still reporting success.
    *DYNAMIC_PRICING
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner()) = merged;
    Ok(count)
}
/// Clears the pricing cache managed by the `cache` module.
///
/// This does not touch the in-memory pricing table: prices already loaded or
/// fetched in this process stay in effect for `get_model_pricing`.
///
/// # Errors
/// Propagates any error returned by `cache::clear_pricing_cache`.
pub fn clear_cache() -> Result<()> {
    cache::clear_pricing_cache()?;
    Ok(())
}
/// Computes the total cost of a request for `model`, in the pricing table's currency unit.
///
/// Each token count is converted to millions of tokens and multiplied by its rate:
/// - `input` at the input rate;
/// - `output` at the output rate;
/// - `cache_create` at the input rate times `cache_write_multiplier`;
/// - `cache_read` at the input rate times `cache_read_multiplier`.
///
/// Pricing is resolved with [`get_model_pricing`].
pub fn calculate_cost(
    model: &str,
    input: u64,
    output: u64,
    cache_create: u64,
    cache_read: u64,
) -> f64 {
    let rates = get_model_pricing(model);
    let in_millions = |tokens: u64| tokens as f64 / 1_000_000.0;
    let base_rate = rates.input_price_per_million;

    let fresh_input = in_millions(input) * base_rate;
    let generated = in_millions(output) * rates.output_price_per_million;
    // Cache traffic is billed as a multiple of the plain input rate.
    let cache_writes = in_millions(cache_create) * base_rate * rates.cache_write_multiplier;
    let cache_hits = in_millions(cache_read) * base_rate * rates.cache_read_multiplier;

    fresh_input + generated + cache_writes + cache_hits
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One pricing unit's worth of tokens (prices are quoted per million).
    const MTOK: u64 = 1_000_000;

    #[test]
    fn test_get_model_pricing_embedded() {
        let opus = get_model_pricing("claude-opus-4-5");
        assert_eq!(opus.input_price_per_million, 5.0);
        assert_eq!(opus.output_price_per_million, 25.0);
    }

    #[test]
    fn test_get_model_pricing_unknown() {
        // Unknown ids must still resolve to a usable, non-zero input price.
        let fallback = get_model_pricing("unknown-model");
        assert!(fallback.input_price_per_million > 0.0);
    }

    #[test]
    fn test_calculate_cost_opus_basic() {
        assert_eq!(calculate_cost("opus-4", MTOK, MTOK, 0, 0), 30.0);
    }

    #[test]
    fn test_calculate_cost_sonnet_basic() {
        assert_eq!(calculate_cost("sonnet-4", MTOK, MTOK, 0, 0), 18.0);
    }

    #[test]
    fn test_calculate_cost_with_cache() {
        // 1M input, no output, 1M cache writes, 10M cache reads.
        assert_eq!(calculate_cost("opus-4", MTOK, 0, MTOK, 10 * MTOK), 16.25);
    }
}