pub mod optimizer;
use std::collections::HashMap;
use std::sync::LazyLock;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use crate::inference::TokenUsage;
use crate::provider::ProviderType;
/// Token prices for a single model, expressed per one million tokens.
///
/// `CostTracker::record` multiplies a request's token counts by these rates and
/// divides by 1,000,000 to obtain the amount it adds to `total_cost_usd`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Rate applied to each million prompt (input) tokens.
    pub input_per_million: f64,
    /// Rate applied to each million completion (output) tokens.
    pub output_per_million: f64,
    /// Optional rate for cached input tokens; `None` when absent on deserialize.
    ///
    /// NOTE(review): presumably a discounted cached-prompt rate. `CostTracker::record`
    /// does not read this field.
    #[serde(default)]
    pub cached_input_per_million: Option<f64>,
}
impl ModelPricing {
    /// Builds a pricing entry from input/output rates (per million tokens),
    /// leaving the cached-input rate unset.
    const fn new(input: f64, output: f64) -> Self {
        Self {
            cached_input_per_million: None,
            input_per_million: input,
            output_per_million: output,
        }
    }

    /// A pricing entry under which every token costs nothing.
    const fn zero() -> Self {
        Self {
            input_per_million: 0.0,
            output_per_million: 0.0,
            cached_input_per_million: None,
        }
    }
}
/// Built-in price table keyed by model id (or model-id prefix).
///
/// `lookup_pricing` consults this table by exact id first and then by the
/// longest key that is a prefix of the requested model id.
static PRICING: LazyLock<HashMap<&'static str, ModelPricing>> = LazyLock::new(|| {
    HashMap::from([
        ("claude-opus-4", ModelPricing::new(15.0, 75.0)),
        ("claude-sonnet-4", ModelPricing::new(3.0, 15.0)),
        // Both spellings of the 3.5 Haiku id are accepted.
        ("claude-3.5-haiku", ModelPricing::new(0.8, 4.0)),
        ("claude-3-5-haiku", ModelPricing::new(0.8, 4.0)),
        ("gpt-4o", ModelPricing::new(2.5, 10.0)),
        ("gpt-4o-mini", ModelPricing::new(0.15, 0.6)),
        ("o1", ModelPricing::new(15.0, 60.0)),
        ("o3-mini", ModelPricing::new(1.1, 4.4)),
        ("deepseek-chat", ModelPricing::new(0.27, 1.1)),
        ("deepseek-coder", ModelPricing::new(0.14, 0.28)),
        ("deepseek-reasoner", ModelPricing::new(0.55, 2.19)),
        ("grok-3", ModelPricing::new(3.0, 15.0)),
        ("grok-3-mini", ModelPricing::new(0.3, 0.5)),
        ("llama-3.3-70b", ModelPricing::new(0.59, 0.79)),
        ("llama-3.1-8b", ModelPricing::new(0.05, 0.08)),
        ("gemini-2.0-flash", ModelPricing::new(0.1, 0.4)),
    ])
});
/// Provider-level default pricing, used when a model id is not in `PRICING`.
///
/// Providers for which `ProviderType::is_local` returns true, and any provider
/// without an explicit default below, are priced at zero.
fn fallback_pricing(provider: ProviderType) -> ModelPricing {
    match provider {
        // Checked before any named variant so a local provider is never billed.
        _ if provider.is_local() => ModelPricing::zero(),
        ProviderType::Anthropic => ModelPricing::new(3.0, 15.0),
        ProviderType::OpenAi => ModelPricing::new(2.5, 10.0),
        ProviderType::DeepSeek => ModelPricing::new(0.27, 1.1),
        ProviderType::Groq => ModelPricing::new(0.59, 0.79),
        ProviderType::Grok => ModelPricing::new(3.0, 15.0),
        ProviderType::Google => ModelPricing::new(0.1, 0.4),
        ProviderType::Mistral => ModelPricing::new(2.0, 6.0),
        ProviderType::OpenRouter => ModelPricing::new(2.5, 10.0),
        _ => ModelPricing::zero(),
    }
}
pub(crate) fn lookup_pricing(model: &str, provider: ProviderType) -> ModelPricing {
if let Some(p) = PRICING.get(model) {
return p.clone();
}
let mut best: Option<(&str, &ModelPricing)> = None;
for (key, pricing) in PRICING.iter() {
if model.starts_with(key) && best.is_none_or(|(k, _)| key.len() > k.len()) {
best = Some((key, pricing));
}
}
if let Some((_, p)) = best {
return p.clone();
}
fallback_pricing(provider)
}
/// Running totals for one provider endpoint, as accumulated by `CostTracker`.
///
/// `CostTracker::record` keeps one of these per `"{provider}:{base_url}"` key.
/// It starts from `Default` (all zeros) and sets `provider` to the provider's
/// `to_string()` form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProviderCostRecord {
    /// Display form of the provider this bucket belongs to.
    pub provider: String,
    /// Sum of prompt tokens across all recorded requests.
    pub total_input_tokens: u64,
    /// Sum of completion tokens across all recorded requests.
    pub total_output_tokens: u64,
    /// Sum of per-request costs, in USD.
    pub total_cost_usd: f64,
    /// Number of requests recorded into this bucket.
    pub request_count: u64,
}
/// Accumulates token usage and spend, bucketed per provider endpoint.
///
/// Buckets live in a `DashMap` so that `record`, `reset` and the read
/// accessors can all operate through `&self`.
pub struct CostTracker {
    /// Buckets keyed by `"{provider}:{base_url}"`.
    records: DashMap<String, ProviderCostRecord>,
}
impl CostTracker {
    /// Creates a tracker with no recorded usage.
    pub fn new() -> Self {
        let records = DashMap::new();
        Self { records }
    }

    /// Prices a single request and folds it into the running totals.
    ///
    /// Pricing comes from `lookup_pricing(model, provider)`. The request is
    /// attributed to the bucket keyed `"{provider}:{base_url}"`, which is created
    /// on first use.
    ///
    /// Returns the USD cost of this request alone.
    pub fn record(
        &self,
        provider: ProviderType,
        base_url: &str,
        model: &str,
        usage: &TokenUsage,
    ) -> f64 {
        let pricing = lookup_pricing(model, provider);
        // Keep the `tokens * rate / 1e6` order so results are bit-identical.
        let input_cost = usage.prompt_tokens as f64 * pricing.input_per_million / 1_000_000.0;
        let output_cost =
            usage.completion_tokens as f64 * pricing.output_per_million / 1_000_000.0;
        let cost = input_cost + output_cost;

        let bucket_key = format!("{provider}:{base_url}");
        let mut bucket = self
            .records
            .entry(bucket_key)
            .or_insert_with(|| ProviderCostRecord {
                provider: provider.to_string(),
                ..Default::default()
            });
        bucket.total_input_tokens += usage.prompt_tokens as u64;
        bucket.total_output_tokens += usage.completion_tokens as u64;
        bucket.total_cost_usd += cost;
        bucket.request_count += 1;

        cost
    }

    /// Returns a snapshot copy of every bucket.
    pub fn all(&self) -> Vec<ProviderCostRecord> {
        self.records
            .iter()
            .map(|bucket| bucket.value().clone())
            .collect()
    }

    /// Returns a snapshot of every bucket together with the sum of their costs,
    /// both gathered in a single pass over the map.
    pub fn all_with_total(&self) -> (Vec<ProviderCostRecord>, f64) {
        self.records
            .iter()
            .fold((Vec::new(), 0.0), |(mut snapshot, running), bucket| {
                let rec = bucket.value();
                snapshot.push(rec.clone());
                (snapshot, running + rec.total_cost_usd)
            })
    }

    /// Sum of `total_cost_usd` across all buckets.
    pub fn total_cost(&self) -> f64 {
        self.records
            .iter()
            .map(|bucket| bucket.value().total_cost_usd)
            .sum::<f64>()
    }

    /// Discards every bucket.
    pub fn reset(&self) {
        self.records.clear();
    }
}
impl Default for CostTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing small accumulated dollar amounts.
    const COST_TOLERANCE: f64 = 1e-12;

    /// A small request shared by the accumulation tests.
    fn small_usage() -> TokenUsage {
        TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        }
    }

    /// Records `usage` against OpenAI's public endpoint at `gpt-4o` pricing.
    fn record_gpt4o(tracker: &CostTracker, usage: &TokenUsage) -> f64 {
        tracker.record(
            ProviderType::OpenAi,
            "https://api.openai.com",
            "gpt-4o",
            usage,
        )
    }

    #[test]
    fn pricing_lookup_exact_match() {
        let p = lookup_pricing("gpt-4o", ProviderType::OpenAi);
        assert!((p.input_per_million - 2.5).abs() < f64::EPSILON);
        assert!((p.output_per_million - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn pricing_lookup_prefix_match() {
        let p = lookup_pricing("claude-sonnet-4-20250514", ProviderType::Anthropic);
        assert!((p.input_per_million - 3.0).abs() < f64::EPSILON);
        assert!((p.output_per_million - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn pricing_lookup_prefers_longest_prefix() {
        // Both `gpt-4o` and `gpt-4o-mini` are prefixes; the longer key must win.
        let p = lookup_pricing("gpt-4o-mini-2024-07-18", ProviderType::OpenAi);
        assert!((p.input_per_million - 0.15).abs() < f64::EPSILON);
        assert!((p.output_per_million - 0.6).abs() < f64::EPSILON);
    }

    #[test]
    fn pricing_fallback_to_provider() {
        let p = lookup_pricing("some-unknown-model", ProviderType::Anthropic);
        assert!((p.input_per_million - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn pricing_local_providers_are_free() {
        for provider in [
            ProviderType::Ollama,
            ProviderType::LlamaCpp,
            ProviderType::LmStudio,
            ProviderType::LocalAi,
            ProviderType::Synapse,
        ] {
            let p = lookup_pricing("anything", provider);
            assert!(
                p.input_per_million == 0.0 && p.output_per_million == 0.0,
                "local provider {provider} should be free"
            );
        }
    }

    #[test]
    fn cost_calculation_math() {
        let tracker = CostTracker::new();
        let usage = TokenUsage {
            prompt_tokens: 1_000,
            completion_tokens: 500,
            total_tokens: 1_500,
        };
        let cost = record_gpt4o(&tracker, &usage);
        let expected = (1_000.0 * 2.5 / 1_000_000.0) + (500.0 * 10.0 / 1_000_000.0);
        assert!(
            (cost - expected).abs() < COST_TOLERANCE,
            "cost={cost}, expected={expected}"
        );
    }

    #[test]
    fn tracker_accumulates() {
        let tracker = CostTracker::new();
        let usage = small_usage();
        record_gpt4o(&tracker, &usage);
        record_gpt4o(&tracker, &usage);

        let records = tracker.all();
        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.request_count, 2);
        assert_eq!(rec.total_input_tokens, 200);
        assert_eq!(rec.total_output_tokens, 100);
        assert!(rec.total_cost_usd > 0.0);
    }

    #[test]
    fn tracker_separates_providers() {
        let tracker = CostTracker::new();
        let usage = small_usage();
        record_gpt4o(&tracker, &usage);
        tracker.record(
            ProviderType::Anthropic,
            "https://api.anthropic.com",
            "claude-sonnet-4",
            &usage,
        );
        assert_eq!(tracker.all().len(), 2);
    }

    #[test]
    fn tracker_separates_base_urls() {
        // Same provider, different endpoint: buckets are keyed by provider *and* URL.
        let tracker = CostTracker::new();
        let usage = small_usage();
        record_gpt4o(&tracker, &usage);
        tracker.record(
            ProviderType::OpenAi,
            "https://proxy.example.com",
            "gpt-4o",
            &usage,
        );
        assert_eq!(tracker.all().len(), 2);
    }

    #[test]
    fn tracker_all_with_total_matches_recorded_costs() {
        let tracker = CostTracker::new();
        let usage = small_usage();
        let openai = record_gpt4o(&tracker, &usage);
        let anthropic = tracker.record(
            ProviderType::Anthropic,
            "https://api.anthropic.com",
            "claude-sonnet-4",
            &usage,
        );

        let (records, total) = tracker.all_with_total();
        assert_eq!(records.len(), 2);
        assert!((total - (openai + anthropic)).abs() < COST_TOLERANCE);
        assert!((total - tracker.total_cost()).abs() < COST_TOLERANCE);
    }

    #[test]
    fn tracker_total_cost() {
        let tracker = CostTracker::new();
        let usage = TokenUsage {
            prompt_tokens: 1_000_000,
            completion_tokens: 0,
            total_tokens: 1_000_000,
        };
        record_gpt4o(&tracker, &usage);
        assert!((tracker.total_cost() - 2.5).abs() < 1e-9);
    }

    #[test]
    fn tracker_reset() {
        let tracker = CostTracker::new();
        record_gpt4o(&tracker, &small_usage());
        assert!(!tracker.all().is_empty());

        tracker.reset();
        assert!(tracker.all().is_empty());
        assert!(tracker.all_with_total().0.is_empty());
        assert!(tracker.total_cost() == 0.0);
    }
}