use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::mpsc;
use std::time::{Duration, SystemTime};
use tokio::task;
/// Raw GitHub URL of LiteLLM's `model_prices_and_context_window.json`, downloaded
/// whenever the local pricing cache is missing, stale, or unreadable.
const PRICING_URL: &str =
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
/// Maximum age, in seconds, of the on-disk pricing cache before it is refetched
/// (86 400 s = 24 hours).
const CACHE_MAX_AGE_SECS: u64 = 86400;
/// Per-token prices for a single model entry of the pricing table.
///
/// Either field is `None` when the source entry has no numeric value for it.
/// NOTE(review): amounts are presumably USD (`format_cost` renders them with `$`);
/// confirm against the LiteLLM data.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelPricing {
/// Price of one prompt/input token; a missing field deserializes as `None`.
#[serde(default)]
pub input_cost_per_token: Option<f64>,
/// Price of one completion/output token; a missing field deserializes as `None`.
#[serde(default)]
pub output_cost_per_token: Option<f64>,
}
/// Pricing JSON as deserialized: a top-level object keyed by entry name. Each value is
/// left as untyped JSON, so a single oddly shaped entry cannot fail the whole parse.
type RawPricingData = HashMap<String, serde_json::Value>;
fn parse_pricing_data(raw: RawPricingData) -> HashMap<String, ModelPricing> {
raw.into_iter()
.filter_map(|(key, value)| {
if key.starts_with("sample_") {
return None;
}
let input_cost = value.get("input_cost_per_token").and_then(|v| v.as_f64());
let output_cost = value.get("output_cost_per_token").and_then(|v| v.as_f64());
if input_cost.is_some() || output_cost.is_some() {
Some((
key,
ModelPricing {
input_cost_per_token: input_cost,
output_cost_per_token: output_cost,
},
))
} else {
None
}
})
.collect()
}
/// Non-blocking handle to model pricing data that is loaded in the background.
///
/// Callers poll with `try_get` or block for a bounded time with `wait_get`. Once a
/// result has been received it is kept in `data` and the channel is discarded.
pub struct PricingCache {
/// Receiving end for the background load's result; set to `None` after a result is taken.
receiver: Option<mpsc::Receiver<HashMap<String, ModelPricing>>>,
/// Pricing table once received, keyed by the entry names from the pricing JSON.
data: Option<HashMap<String, ModelPricing>>,
}
impl PricingCache {
pub fn new() -> Self {
let (tx, rx) = mpsc::channel();
task::spawn(async move {
if let Some(pricing) = load_or_fetch_pricing().await {
let _ = tx.send(pricing);
}
});
Self {
receiver: Some(rx),
data: None,
}
}
pub fn try_get(&mut self) -> Option<&HashMap<String, ModelPricing>> {
if self.data.is_some() {
return self.data.as_ref();
}
if let Some(ref rx) = self.receiver {
if let Ok(data) = rx.try_recv() {
self.data = Some(data);
self.receiver = None; return self.data.as_ref();
}
}
None
}
pub fn wait_get(&mut self, timeout: Duration) -> Option<&HashMap<String, ModelPricing>> {
if self.data.is_some() {
return self.data.as_ref();
}
if let Some(ref rx) = self.receiver {
if let Ok(data) = rx.recv_timeout(timeout) {
self.data = Some(data);
self.receiver = None;
return self.data.as_ref();
}
}
None
}
pub fn get_model_pricing(&mut self, provider: &str, model: &str) -> Option<ModelPricing> {
let data = self.try_get()?;
let keys_to_try = generate_model_keys(provider, model);
for key in keys_to_try {
if let Some(pricing) = data.get(&key) {
if pricing.input_cost_per_token.is_some() {
return Some(pricing.clone());
}
}
}
None
}
}
impl Default for PricingCache {
fn default() -> Self {
Self::new()
}
}
/// Builds the ordered list of pricing-table keys to try for `model` from `provider`.
///
/// The pricing table may list a model under its bare name or under a `<prefix>/` name,
/// so each provider gets its own candidate prefixes. The bare model name is always
/// included as a last resort. For Gemini `*-preview` models, the name with exactly one
/// `-preview` suffix removed is also tried under `gemini/`.
///
/// The returned keys are unique and in priority order; callers take the first match.
fn generate_model_keys(provider: &str, model: &str) -> Vec<String> {
    let candidates: Vec<String> = match provider {
        "gemini" => {
            let mut keys = vec![
                format!("gemini/{}", model),
                format!("google/{}", model),
                format!("vertex_ai/{}", model),
            ];
            // Strip a single suffix (not every repetition), and skip an empty base name.
            if let Some(base) = model.strip_suffix("-preview").filter(|b| !b.is_empty()) {
                keys.push(format!("gemini/{}", base));
            }
            keys
        }
        "claude" => vec![format!("claude/{}", model), format!("anthropic/{}", model)],
        "openai" => vec![model.to_string(), format!("openai/{}", model)],
        _ => vec![model.to_string(), format!("{}/{}", provider, model)],
    };
    // De-duplicate while preserving priority order; the list holds at most a few entries,
    // so a linear `contains` is cheaper than hashing.
    let mut keys: Vec<String> = Vec::with_capacity(candidates.len() + 1);
    for key in candidates.into_iter().chain(std::iter::once(model.to_string())) {
        if !keys.contains(&key) {
            keys.push(key);
        }
    }
    keys
}
/// Directory for this tool's on-disk caches: `<home>/.cache/cmt`.
///
/// Returns `None` when `dirs::home_dir` cannot determine a home directory.
fn cache_dir() -> Option<PathBuf> {
    let home = dirs::home_dir()?;
    Some(home.join(".cache").join("cmt"))
}
/// Path of the cached pricing JSON (`model_pricing.json` inside `cache_dir()`), or
/// `None` when no cache directory is available.
fn cache_file() -> Option<PathBuf> {
    let dir = cache_dir()?;
    Some(dir.join("model_pricing.json"))
}
fn is_cache_valid(path: &PathBuf) -> bool {
if let Ok(metadata) = fs::metadata(path) {
if let Ok(modified) = metadata.modified() {
if let Ok(age) = SystemTime::now().duration_since(modified) {
return age.as_secs() < CACHE_MAX_AGE_SECS;
}
}
}
false
}
/// Loads the pricing table, preferring a fresh on-disk cache over the network.
///
/// Order of attempts:
/// 1. the cache file, if it is younger than `CACHE_MAX_AGE_SECS` and parses;
/// 2. a network fetch, which also refreshes the cache file;
/// 3. a stale cache file, if the fetch failed (for example, offline). Slightly outdated
///    prices are more useful than none.
///
/// Returns `None` if no cache path is available or every attempt fails.
/// NOTE(review): uses blocking `std::fs` inside an async fn. This is acceptable for one
/// small file, but consider `spawn_blocking` if the work grows.
async fn load_or_fetch_pricing() -> Option<HashMap<String, ModelPricing>> {
    let cache_path = cache_file()?;
    let fresh = is_cache_valid(&cache_path);
    if fresh {
        if let Some(data) = read_cached_pricing(&cache_path) {
            return Some(data);
        }
    }
    if let Some(data) = fetch_and_cache_pricing(&cache_path).await {
        return Some(data);
    }
    // A fresh-but-unreadable file was already tried above; only retry a stale one.
    if fresh {
        None
    } else {
        read_cached_pricing(&cache_path)
    }
}

/// Reads and parses the cached pricing JSON at `path`; `None` on any I/O or JSON error.
fn read_cached_pricing(path: &std::path::Path) -> Option<HashMap<String, ModelPricing>> {
    let mut contents = String::new();
    fs::File::open(path).ok()?.read_to_string(&mut contents).ok()?;
    let raw: RawPricingData = serde_json::from_str(&contents).ok()?;
    Some(parse_pricing_data(raw))
}
/// Downloads the pricing JSON from `PRICING_URL`, parses it, and caches the raw body at
/// `cache_path`.
///
/// Returns `None` in any of these cases:
/// - the HTTP client cannot be built;
/// - the request fails or exceeds the 10 s timeout;
/// - the status is not a success;
/// - the body cannot be read;
/// - the body is not valid JSON.
///
/// Failure details are printed to stderr only in test builds.
///
/// The cache write is best-effort and never affects the returned data. It goes through a
/// temporary file plus rename, so an interrupted write cannot leave a truncated cache.
async fn fetch_and_cache_pricing(
    cache_path: &std::path::Path,
) -> Option<HashMap<String, ModelPricing>> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .ok()?;
    let response = match client.get(PRICING_URL).send().await {
        Ok(r) => r,
        Err(_e) => {
            #[cfg(test)]
            eprintln!("Fetch failed: {:?}", _e);
            return None;
        }
    };
    if !response.status().is_success() {
        #[cfg(test)]
        eprintln!("Bad status: {}", response.status());
        return None;
    }
    let text = match response.text().await {
        Ok(t) => t,
        Err(_e) => {
            #[cfg(test)]
            eprintln!("Text read failed: {:?}", _e);
            return None;
        }
    };
    let raw: RawPricingData = match serde_json::from_str(&text) {
        Ok(d) => d,
        Err(_e) => {
            #[cfg(test)]
            eprintln!("JSON parse failed: {:?}", _e);
            return None;
        }
    };
    let data = parse_pricing_data(raw);
    // Best-effort: if writing fails, the next run simply fetches again.
    let _ = write_cache_atomically(cache_path, text.as_bytes());
    Some(data)
}

/// Replaces `cache_path` with `contents`, creating the parent directory if needed.
///
/// Writes to a process-specific sibling temp file, then renames it over the target. On
/// platforms where rename-over is atomic (e.g. POSIX), readers see either the old file
/// or the complete new one. The temp file is removed if any step fails.
fn write_cache_atomically(cache_path: &std::path::Path, contents: &[u8]) -> std::io::Result<()> {
    if let Some(dir) = cache_path.parent() {
        fs::create_dir_all(dir)?;
    }
    let tmp_path = cache_path.with_extension(format!("json.{}.tmp", std::process::id()));
    // The file handle is dropped at the end of the first closure, before the rename.
    let result = fs::File::create(&tmp_path)
        .and_then(|mut file| file.write_all(contents))
        .and_then(|()| fs::rename(&tmp_path, cache_path));
    if result.is_err() {
        let _ = fs::remove_file(&tmp_path);
    }
    result
}
/// Total cost of a request: input tokens × input price + output tokens × output price.
///
/// Returns `None` unless `pricing` has *both* per-token prices. Token counts are
/// converted to `f64`, which is exact for counts below 2^53.
pub fn calculate_cost(
    pricing: &ModelPricing,
    input_tokens: u64,
    output_tokens: u64,
) -> Option<f64> {
    let (per_input, per_output) = pricing
        .input_cost_per_token
        .zip(pricing.output_cost_per_token)?;
    let input_part = per_input * input_tokens as f64;
    let output_part = per_output * output_tokens as f64;
    Some(input_part + output_part)
}
/// Formats `cost` with a `$` prefix, showing more decimals for smaller amounts.
///
/// Below 0.0001 it uses 6 decimals, below 0.01 it uses 4, and otherwise 2. Zero and
/// negative values fall into the 6-decimal bucket.
pub fn format_cost(cost: f64) -> String {
    let decimals: usize = match cost {
        c if c < 0.0001 => 6,
        c if c < 0.01 => 4,
        _ => 2,
    };
    format!("${:.*}", decimals, cost)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_model_keys() {
        let keys = generate_model_keys("gemini", "gemini-3-flash-preview");
        assert!(keys.contains(&"gemini/gemini-3-flash-preview".to_string()));
        assert!(keys.contains(&"gemini/gemini-3-flash".to_string()));
        let keys = generate_model_keys("openai", "gpt-5.2");
        assert!(keys.contains(&"gpt-5.2".to_string()));
        assert!(keys.contains(&"openai/gpt-5.2".to_string()));
        let keys = generate_model_keys("claude", "claude-sonnet-4-5-20250929");
        assert!(keys.contains(&"claude-sonnet-4-5-20250929".to_string()));
        assert!(keys.contains(&"anthropic/claude-sonnet-4-5-20250929".to_string()));
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(0.000001), "$0.000001");
        assert_eq!(format_cost(0.001), "$0.0010");
        assert_eq!(format_cost(0.15), "$0.15");
    }

    #[test]
    fn test_calculate_cost() {
        let pricing = ModelPricing {
            input_cost_per_token: Some(0.000001),
            output_cost_per_token: Some(0.000002),
        };
        // Compare with a tolerance: exact float equality depends on rounding details.
        let cost = calculate_cost(&pricing, 1000, 500).expect("both prices are set");
        assert!((cost - 0.002).abs() < 1e-12, "unexpected cost {}", cost);

        let missing_output = ModelPricing {
            input_cost_per_token: Some(0.000001),
            output_cost_per_token: None,
        };
        assert_eq!(calculate_cost(&missing_output, 1000, 500), None);
    }

    #[test]
    fn test_get_model_pricing_from_loaded_data() {
        let mut data = HashMap::new();
        data.insert(
            "openai/gpt-x".to_string(),
            ModelPricing {
                input_cost_per_token: Some(0.5),
                output_cost_per_token: Some(1.5),
            },
        );
        data.insert(
            "only-output".to_string(),
            ModelPricing {
                input_cost_per_token: None,
                output_cost_per_token: Some(1.0),
            },
        );
        // Pre-loaded cache: no background load, no network.
        let mut cache = PricingCache {
            receiver: None,
            data: Some(data),
        };
        let found = cache
            .get_model_pricing("openai", "gpt-x")
            .expect("prefixed key should match");
        assert_eq!(found.input_cost_per_token, Some(0.5));
        // Entries without an input price are skipped by the lookup.
        assert!(cache.get_model_pricing("other", "only-output").is_none());
        assert!(cache.get_model_pricing("openai", "unknown").is_none());
    }

    #[test]
    fn test_cache_dir() {
        let dir = super::cache_dir();
        assert!(dir.is_some(), "Cache dir should be available");
        println!("Cache dir: {:?}", dir);
    }

    #[test]
    #[ignore] // Hits the network.
    fn test_fetch_pricing() {
        use std::sync::mpsc;
        use std::time::Duration;
        // Plain `#[test]` functions have no Tokio runtime, so `tokio::task::spawn` would
        // panic. Drive the loader on an explicit runtime and hand its result to the cache
        // through the same kind of channel the cache normally listens on.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build a Tokio runtime");
        let (tx, rx) = mpsc::channel();
        let mut cache = PricingCache {
            receiver: Some(rx),
            data: None,
        };
        if let Some(pricing) = runtime.block_on(load_or_fetch_pricing()) {
            tx.send(pricing).expect("receiver is still alive");
        }
        // Fail fast instead of waiting out the timeout when nothing was sent.
        drop(tx);
        let data = cache.wait_get(Duration::from_secs(15));
        assert!(data.is_some(), "Should fetch pricing data");
        let gemini_pricing = cache.get_model_pricing("gemini", "gemini-2.0-flash");
        println!("Gemini pricing: {:?}", gemini_pricing);
    }
}