use llmtrace_core::LLMProvider;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use tracing::{info, warn};
pub use llmtrace_core::{CostEstimationConfig, ModelPricingConfig};
/// Pricing for a single model as written in an external pricing file.
///
/// This is the value type of [`PricingFile`]; it is (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilePricingEntry {
/// Price charged for prompt (input) tokens, per one million tokens.
pub input_per_million: f64,
/// Price charged for completion (output) tokens, per one million tokens.
pub output_per_million: f64,
}
/// Parsed contents of a pricing file: model name mapped to its pricing entry.
///
/// Keys are kept exactly as written in the file; `CostEstimator` lowercases
/// them when it copies the entries into its own lookup table.
pub type PricingFile = HashMap<String, FilePricingEntry>;
pub fn load_pricing_file(path: &str) -> Result<PricingFile, String> {
let contents = std::fs::read_to_string(path)
.map_err(|e| format!("Failed to read pricing file '{}': {}", path, e))?;
serde_yaml::from_str::<PricingFile>(&contents)
.map_err(|e| format!("Failed to parse pricing file '{}': {}", path, e))
}
/// Internal per-model price pair used for cost computation.
///
/// Small and `Copy`, so lookups hand it out by value.
#[derive(Debug, Clone, Copy)]
struct Pricing {
/// Price for prompt (input) tokens, per one million tokens.
input_per_million: f64,
/// Price for completion (output) tokens, per one million tokens.
output_per_million: f64,
}
/// Returns the compiled-in pricing table, keyed by lowercase model-name prefix.
///
/// Both the dashed (`claude-3-5-…`) and dotted (`claude-3.5-…`) spellings of
/// the Claude 3.5 models are listed so either form resolves to the same price.
fn builtin_pricing() -> HashMap<&'static str, Pricing> {
    // (model prefix, input price per million tokens, output price per million tokens)
    const TABLE: [(&str, f64, f64); 9] = [
        ("gpt-4o-mini", 0.15, 0.60),
        ("gpt-4o", 2.50, 10.0),
        ("gpt-4", 30.0, 60.0),
        ("gpt-3.5-turbo", 0.50, 1.50),
        ("claude-3-5-sonnet", 3.0, 15.0),
        ("claude-3.5-sonnet", 3.0, 15.0),
        ("claude-3-5-haiku", 0.80, 4.0),
        ("claude-3.5-haiku", 0.80, 4.0),
        ("claude-3-opus", 15.0, 75.0),
    ];

    TABLE
        .iter()
        .map(|&(model, input_per_million, output_per_million)| {
            (
                model,
                Pricing {
                    input_per_million,
                    output_per_million,
                },
            )
        })
        .collect()
}
/// Estimates request cost from token counts using layered pricing tables.
///
/// Pricing is looked up in three tables: inline custom models, an optional
/// external pricing file, and the compiled-in defaults.
pub struct CostEstimator {
/// When `false`, `estimate_cost` always returns `None`.
enabled: bool,
/// Compiled-in defaults from `builtin_pricing()`.
builtin: HashMap<&'static str, Pricing>,
/// Entries loaded from the external pricing file, keyed by lowercased model name.
file_pricing: HashMap<String, Pricing>,
/// Inline entries from the configuration, keyed by lowercased model name.
custom: HashMap<String, Pricing>,
/// Path of the external pricing file, retained so it can be reloaded later.
pricing_file_path: Option<String>,
}
impl CostEstimator {
pub fn new(config: &CostEstimationConfig) -> Self {
let custom = config
.custom_models
.iter()
.map(|(name, mc)| {
(
name.to_lowercase(),
Pricing {
input_per_million: mc.input_per_million,
output_per_million: mc.output_per_million,
},
)
})
.collect();
let (file_pricing, _) = Self::load_file_pricing(config.pricing_file.as_deref());
Self {
enabled: config.enabled,
builtin: builtin_pricing(),
file_pricing,
custom,
pricing_file_path: config.pricing_file.clone(),
}
}
fn load_file_pricing(path: Option<&str>) -> (HashMap<String, Pricing>, Option<String>) {
let Some(path) = path else {
return (HashMap::new(), None);
};
if !Path::new(path).exists() {
warn!(
path = path,
"Pricing file not found — using built-in defaults"
);
return (HashMap::new(), Some(path.to_string()));
}
match load_pricing_file(path) {
Ok(entries) => {
let count = entries.len();
let map = entries
.into_iter()
.map(|(name, entry)| {
(
name.to_lowercase(),
Pricing {
input_per_million: entry.input_per_million,
output_per_million: entry.output_per_million,
},
)
})
.collect();
info!(
path = path,
models = count,
"Loaded pricing from external file"
);
(map, Some(path.to_string()))
}
Err(e) => {
warn!(
path = path,
error = %e,
"Failed to load pricing file — using built-in defaults"
);
(HashMap::new(), Some(path.to_string()))
}
}
}
pub fn reload_pricing_file(&mut self) -> bool {
let path = match &self.pricing_file_path {
Some(p) => p.clone(),
None => return false,
};
match load_pricing_file(&path) {
Ok(entries) => {
let count = entries.len();
self.file_pricing = entries
.into_iter()
.map(|(name, entry)| {
(
name.to_lowercase(),
Pricing {
input_per_million: entry.input_per_million,
output_per_million: entry.output_per_million,
},
)
})
.collect();
info!(
path = path,
models = count,
"Reloaded pricing from external file"
);
true
}
Err(e) => {
warn!(
path = path,
error = %e,
"Failed to reload pricing file — keeping existing pricing"
);
false
}
}
}
#[must_use]
pub fn estimate_cost(
&self,
provider: &LLMProvider,
model: &str,
prompt_tokens: Option<u32>,
completion_tokens: Option<u32>,
) -> Option<f64> {
if !self.enabled {
return None;
}
if is_self_hosted(provider) {
return None;
}
let pricing = self.lookup_pricing(model)?;
let input_cost =
prompt_tokens.unwrap_or(0) as f64 * pricing.input_per_million / 1_000_000.0;
let output_cost =
completion_tokens.unwrap_or(0) as f64 * pricing.output_per_million / 1_000_000.0;
Some(input_cost + output_cost)
}
fn lookup_pricing(&self, model: &str) -> Option<Pricing> {
let lower = model.to_lowercase();
if let Some(p) = self.custom.get(&lower) {
return Some(*p);
}
if let Some(p) = self.file_pricing.get(&lower) {
return Some(*p);
}
if let Some(p) = Self::prefix_match_owned(&self.file_pricing, &lower) {
return Some(p);
}
if let Some(p) = self.builtin.get(lower.as_str()) {
return Some(*p);
}
let mut best: Option<(&str, Pricing)> = None;
for (&prefix, &pricing) in &self.builtin {
if lower.starts_with(prefix) {
match best {
Some((bp, _)) if prefix.len() <= bp.len() => {}
_ => best = Some((prefix, pricing)),
}
}
}
best.map(|(_, p)| p)
}
fn prefix_match_owned(map: &HashMap<String, Pricing>, lower: &str) -> Option<Pricing> {
let mut best: Option<(&str, Pricing)> = None;
for (prefix, &pricing) in map {
if lower.starts_with(prefix.as_str()) {
match best {
Some((bp, _)) if prefix.len() <= bp.len() => {}
_ => best = Some((prefix.as_str(), pricing)),
}
}
}
best.map(|(_, p)| p)
}
}
fn is_self_hosted(provider: &LLMProvider) -> bool {
matches!(
provider,
LLMProvider::VLLm | LLMProvider::SGLang | LLMProvider::TGI | LLMProvider::Ollama
)
}
impl From<&CostEstimationConfig> for CostEstimator {
fn from(config: &CostEstimationConfig) -> Self {
Self::new(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Token count at which a per-million price equals the resulting cost.
    const MILLION: u32 = 1_000_000;

    /// Estimator built from `CostEstimationConfig::default()`.
    fn default_estimator() -> CostEstimator {
        CostEstimator::new(&CostEstimationConfig::default())
    }

    /// Enabled configuration with an optional pricing file and inline custom
    /// models given as `(name, input_per_million, output_per_million)`.
    fn config_with(
        pricing_file: Option<String>,
        custom: &[(&str, f64, f64)],
    ) -> CostEstimationConfig {
        CostEstimationConfig {
            enabled: true,
            pricing_file,
            custom_models: custom
                .iter()
                .map(|&(name, input_per_million, output_per_million)| {
                    (
                        name.to_string(),
                        ModelPricingConfig {
                            input_per_million,
                            output_per_million,
                        },
                    )
                })
                .collect(),
        }
    }

    /// Renders one pricing-file entry as block-style YAML.
    ///
    /// The nesting is spelled out with explicit `"\n  "` so the fixture cannot
    /// silently lose its indentation (which would make the nested fields parse
    /// as top-level keys and break deserialization).
    fn yaml_entry(model: &str, input: f64, output: f64) -> String {
        format!(
            "{}:\n  input_per_million: {:?}\n  output_per_million: {:?}\n",
            model, input, output
        )
    }

    /// Writes `contents` to `<tempdir>/<name>` and returns the directory guard
    /// (which must be kept alive for the file to exist) plus the file path.
    fn write_temp(name: &str, contents: &str) -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        let path = path.to_str().unwrap().to_string();
        (dir, path)
    }

    /// Cost of one million prompt tokens plus one million completion tokens.
    fn million_token_cost(est: &CostEstimator, provider: &LLMProvider, model: &str) -> f64 {
        est.estimate_cost(provider, model, Some(MILLION), Some(MILLION))
            .unwrap()
    }

    /// Asserts `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // --- Built-in pricing -------------------------------------------------

    #[test]
    fn test_gpt4_pricing() {
        let est = default_estimator();
        assert_close(million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4"), 90.0);
    }

    #[test]
    fn test_gpt4o_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            12.50,
        );
    }

    #[test]
    fn test_gpt4o_mini_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o-mini"),
            0.75,
        );
    }

    #[test]
    fn test_gpt35_turbo_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-3.5-turbo"),
            2.0,
        );
    }

    #[test]
    fn test_claude_35_sonnet_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::Anthropic, "claude-3-5-sonnet-20241022"),
            18.0,
        );
    }

    #[test]
    fn test_claude_35_sonnet_dot_notation() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::Anthropic, "claude-3.5-sonnet-20241022"),
            18.0,
        );
    }

    #[test]
    fn test_claude_35_haiku_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::Anthropic, "claude-3-5-haiku-20241022"),
            4.80,
        );
    }

    #[test]
    fn test_claude_3_opus_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::Anthropic, "claude-3-opus-20240229"),
            90.0,
        );
    }

    // --- Prefix matching --------------------------------------------------

    #[test]
    fn test_prefix_match_gpt4o_dated() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(
                &LLMProvider::OpenAI,
                "gpt-4o-2024-08-06",
                Some(1000),
                Some(500),
            )
            .unwrap();
        let expected = (1000.0 * 2.50 + 500.0 * 10.0) / 1_000_000.0;
        assert!((cost - expected).abs() < 1e-10);
    }

    #[test]
    fn test_prefix_match_prefers_longest() {
        // "gpt-4", "gpt-4o" and "gpt-4o-mini" all prefix this name; the
        // longest one must win.
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o-mini-2024-07-18"),
            0.75,
        );
    }

    #[test]
    fn test_unknown_model_returns_none() {
        let est = default_estimator();
        assert!(est
            .estimate_cost(
                &LLMProvider::OpenAI,
                "some-unknown-model-v42",
                Some(100),
                Some(50)
            )
            .is_none());
    }

    // --- Self-hosted providers --------------------------------------------

    #[test]
    fn test_vllm_returns_none() {
        let est = default_estimator();
        assert!(est
            .estimate_cost(&LLMProvider::VLLm, "gpt-4o", Some(100), Some(50))
            .is_none());
    }

    #[test]
    fn test_sglang_returns_none() {
        let est = default_estimator();
        assert!(est
            .estimate_cost(&LLMProvider::SGLang, "gpt-4o", Some(100), Some(50))
            .is_none());
    }

    #[test]
    fn test_tgi_returns_none() {
        let est = default_estimator();
        assert!(est
            .estimate_cost(&LLMProvider::TGI, "gpt-4o", Some(100), Some(50))
            .is_none());
    }

    #[test]
    fn test_ollama_returns_none() {
        let est = default_estimator();
        assert!(est
            .estimate_cost(&LLMProvider::Ollama, "llama3", Some(100), Some(50))
            .is_none());
    }

    // --- Token-count edge cases -------------------------------------------

    #[test]
    fn test_zero_tokens_returns_zero_cost() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o", Some(0), Some(0))
            .unwrap();
        assert!((cost - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_none_tokens_returns_zero_cost() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o", None, None)
            .unwrap();
        assert!((cost - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_partial_tokens_only_input() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o", Some(MILLION), None)
            .unwrap();
        assert_close(cost, 2.50);
    }

    #[test]
    fn test_partial_tokens_only_output() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o", None, Some(MILLION))
            .unwrap();
        assert_close(cost, 10.0);
    }

    // --- Inline custom models ---------------------------------------------

    #[test]
    fn test_custom_model_pricing() {
        let est = CostEstimator::new(&config_with(None, &[("my-custom-model", 1.0, 2.0)]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "my-custom-model"),
            3.0,
        );
    }

    #[test]
    fn test_custom_pricing_overrides_builtin() {
        let est = CostEstimator::new(&config_with(None, &[("gpt-4o", 99.0, 99.0)]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            198.0,
        );
    }

    #[test]
    fn test_custom_pricing_case_insensitive() {
        let est = CostEstimator::new(&config_with(None, &[("My-Model", 5.0, 10.0)]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "my-model"),
            15.0,
        );
    }

    #[test]
    fn test_disabled_returns_none() {
        let config = CostEstimationConfig {
            enabled: false,
            pricing_file: None,
            custom_models: HashMap::new(),
        };
        let est = CostEstimator::new(&config);
        assert!(est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o", Some(100), Some(50))
            .is_none());
    }

    // --- Hosted providers that resell known models -----------------------

    #[test]
    fn test_azure_openai_uses_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::AzureOpenAI, "gpt-4o"),
            12.50,
        );
    }

    #[test]
    fn test_bedrock_with_claude_uses_pricing() {
        let est = default_estimator();
        assert_close(
            million_token_cost(&est, &LLMProvider::Bedrock, "claude-3-opus-20240229"),
            90.0,
        );
    }

    #[test]
    fn test_realistic_small_request() {
        let est = default_estimator();
        let cost = est
            .estimate_cost(&LLMProvider::OpenAI, "gpt-4o-mini", Some(500), Some(200))
            .unwrap();
        let expected = (500.0 * 0.15 + 200.0 * 0.60) / 1_000_000.0;
        assert!((cost - expected).abs() < 1e-10);
    }

    // --- load_pricing_file ------------------------------------------------

    #[test]
    fn test_load_pricing_file_valid_yaml() {
        let contents = format!(
            "{}{}",
            yaml_entry("my-custom-llm", 5.0, 10.0),
            yaml_entry("another-model", 1.0, 2.0)
        );
        let (_dir, path) = write_temp("pricing.yaml", &contents);

        let result = load_pricing_file(&path);
        assert!(result.is_ok());
        let map = result.unwrap();
        assert_eq!(map.len(), 2);
        assert_close(map["my-custom-llm"].input_per_million, 5.0);
        assert_close(map["another-model"].output_per_million, 2.0);
    }

    #[test]
    fn test_load_pricing_file_missing_file() {
        let result = load_pricing_file("/nonexistent/pricing.yaml");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to read"));
    }

    #[test]
    fn test_load_pricing_file_invalid_yaml() {
        let (_dir, path) = write_temp("bad.yaml", "this is not: [valid yaml: {");
        let result = load_pricing_file(&path);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to parse"));
    }

    // --- Estimator backed by a pricing file -------------------------------

    #[test]
    fn test_estimator_with_pricing_file() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("file-model", 7.0, 14.0));
        let est = CostEstimator::new(&config_with(Some(path), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "file-model"),
            21.0,
        );
    }

    #[test]
    fn test_file_pricing_prefix_match() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("file-model", 7.0, 14.0));
        let est = CostEstimator::new(&config_with(Some(path), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "file-model-v2"),
            21.0,
        );
    }

    #[test]
    fn test_custom_overrides_file_pricing() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("gpt-4o", 50.0, 50.0));
        let est = CostEstimator::new(&config_with(Some(path), &[("gpt-4o", 99.0, 99.0)]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            198.0,
        );
    }

    #[test]
    fn test_file_pricing_overrides_builtin() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("gpt-4o", 50.0, 50.0));
        let est = CostEstimator::new(&config_with(Some(path), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            100.0,
        );
    }

    #[test]
    fn test_fallback_to_builtin_when_file_missing() {
        let est = CostEstimator::new(&config_with(
            Some("/nonexistent/pricing.yaml".to_string()),
            &[],
        ));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            12.50,
        );
    }

    #[test]
    fn test_fallback_to_builtin_when_file_invalid() {
        let (_dir, path) = write_temp("bad.yaml", "not valid yaml: [[[");
        let est = CostEstimator::new(&config_with(Some(path), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "gpt-4o"),
            12.50,
        );
    }

    // --- Reloading --------------------------------------------------------

    #[test]
    fn test_reload_pricing_file() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("reload-model", 1.0, 2.0));
        let mut est = CostEstimator::new(&config_with(Some(path.clone()), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "reload-model"),
            3.0,
        );

        std::fs::write(&path, yaml_entry("reload-model", 10.0, 20.0)).unwrap();
        assert!(est.reload_pricing_file());
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "reload-model"),
            30.0,
        );
    }

    #[test]
    fn test_reload_returns_false_when_no_file_configured() {
        let mut est = default_estimator();
        assert!(!est.reload_pricing_file());
    }

    #[test]
    fn test_reload_keeps_existing_on_bad_file() {
        let (_dir, path) = write_temp("pricing.yaml", &yaml_entry("keep-model", 5.0, 10.0));
        let mut est = CostEstimator::new(&config_with(Some(path.clone()), &[]));
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "keep-model"),
            15.0,
        );

        // A failed reload must leave the previously loaded prices in place.
        std::fs::write(&path, "not valid yaml: [[[").unwrap();
        assert!(!est.reload_pricing_file());
        assert_close(
            million_token_cost(&est, &LLMProvider::OpenAI, "keep-model"),
            15.0,
        );
    }

    #[test]
    fn test_load_real_config_pricing_yaml() {
        // Only checked when the repository's shipped pricing file is present
        // relative to this crate; otherwise the test is a no-op.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/pricing.yaml");
        if std::path::Path::new(path).exists() {
            let result = load_pricing_file(path);
            assert!(result.is_ok(), "config/pricing.yaml should be valid YAML");
            let map = result.unwrap();
            assert!(!map.is_empty(), "config/pricing.yaml should not be empty");
            assert!(map.contains_key("gpt-4o"), "Should contain gpt-4o");
        }
    }
}