use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One pricing rule: the per-million-token rates applied to a model whose
/// name contains `prefix`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingEntry {
/// Text looked for in the model name. The lookup code in this file lowercases
/// both sides and uses a substring test, so despite the field name the match
/// is not anchored at the start of the model name.
pub prefix: String,
/// Rate applied per 1 million input tokens.
pub input_per_m: f64,
/// Rate applied per 1 million output tokens.
pub output_per_m: f64,
/// Optional rate per 1 million cache-creation tokens. May be omitted from the
/// config; the cost calculation then uses `input_per_m * 1.25`.
#[serde(default)]
pub cache_write_per_m: Option<f64>,
/// Optional rate per 1 million cache-read tokens. May be omitted from the
/// config; the cost calculation then uses `input_per_m * 0.1`.
#[serde(default)]
pub cache_read_per_m: Option<f64>,
}
/// The pricing entries stored under one provider key of
/// `PricingConfig::providers`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProviderBlock {
/// Pricing rules for this provider. The lookup code scans them in order and
/// stops at the first entry whose prefix matches. Defaults to an empty list
/// when the key is absent from the config.
#[serde(default)]
pub entries: Vec<PricingEntry>,
}
/// Root of the usage-pricing configuration (the current schema of
/// `usage_pricing.toml`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PricingConfig {
/// Provider name mapped to that provider's pricing entries. Defaults to an
/// empty map when the `providers` table is absent.
#[serde(default)]
pub providers: HashMap<String, ProviderBlock>,
}
impl PricingConfig {
/// Computes the cost of a request that involved no prompt caching.
///
/// This is a convenience wrapper around `calculate_cost_with_cache` with both
/// cache token counts fixed at zero, so it returns exactly what that method
/// returns for the same `model`, `input_tokens` and `output_tokens`.
pub fn calculate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    // No cache traffic: neither cache-creation nor cache-read tokens.
    let (cache_creation_tokens, cache_read_tokens) = (0_u32, 0_u32);
    self.calculate_cost_with_cache(
        model,
        input_tokens,
        output_tokens,
        cache_creation_tokens,
        cache_read_tokens,
    )
}
pub fn calculate_cost_with_cache(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
cache_creation_tokens: u32,
cache_read_tokens: u32,
) -> f64 {
let m = model.to_lowercase();
for block in self.providers.values() {
for entry in &block.entries {
if m.contains(&entry.prefix.to_lowercase()) {
let input = (input_tokens as f64 / 1_000_000.0) * entry.input_per_m;
let output = (output_tokens as f64 / 1_000_000.0) * entry.output_per_m;
let cache_write_rate =
entry.cache_write_per_m.unwrap_or(entry.input_per_m * 1.25);
let cache_read_rate = entry.cache_read_per_m.unwrap_or(entry.input_per_m * 0.1);
let cache_write =
(cache_creation_tokens as f64 / 1_000_000.0) * cache_write_rate;
let cache_read = (cache_read_tokens as f64 / 1_000_000.0) * cache_read_rate;
return input + output + cache_write + cache_read;
}
}
}
0.0
}
/// Estimates a cost for `token_count` tokens when only a model name and a
/// total token count are known.
///
/// The lowercased model name is tried as-is first. If no pricing entry
/// matches, each vendor prefix from the fixed list below is prepended in turn
/// and the lookup is retried; the first successful lookup wins.
///
/// Returns `None` when neither the bare name nor any prefixed variant matches
/// a pricing entry.
pub fn estimate_cost(&self, model: &str, token_count: i64) -> Option<f64> {
    // Tried in this order when the bare model name has no match.
    const VENDOR_PREFIXES: [&str; 9] = [
        "claude-",
        "gpt-",
        "gemini-",
        "deepseek-",
        "llama-",
        "qwen",
        "kimi-",
        "zhipu-",
        "glm-",
    ];

    let lowered = model.to_lowercase();
    self.try_match(&lowered, token_count).or_else(|| {
        VENDOR_PREFIXES.iter().find_map(|vendor| {
            let candidate = format!("{vendor}{lowered}");
            self.try_match(&candidate, token_count)
        })
    })
}
fn try_match(&self, model_lower: &str, token_count: i64) -> Option<f64> {
for block in self.providers.values() {
for entry in &block.entries {
if model_lower.contains(&entry.prefix.to_lowercase()) {
let input = (token_count as f64 * 0.80 / 1_000_000.0) * entry.input_per_m;
let output = (token_count as f64 * 0.20 / 1_000_000.0) * entry.output_per_m;
return Some(input + output);
}
}
}
None
}
pub fn load() -> Result<Self, String> {
let path = crate::config::opencrabs_home().join("usage_pricing.toml");
let content = match std::fs::read_to_string(&path) {
Ok(c) => c,
Err(_) => {
let example = include_str!("../../usage_pricing.toml.example");
return Self::parse_content(example, "embedded example");
}
};
Self::parse_content(&content, &format!("{:?}", path))
}
/// Parses pricing TOML, accepting either the current schema or the legacy one.
///
/// `source` is a human-readable label for where `content` came from. It is
/// used in the error message, and the special value `"embedded example"`
/// suppresses the on-disk migration.
///
/// The current schema is tried first and accepted when it yields at least one
/// provider. Otherwise the legacy schema is tried; on success a warning is
/// logged and, unless the content is the embedded example, the config is
/// re-serialized in the current schema and written to `usage_pricing.toml` in
/// the OpenCrabs home directory. A failed write is logged but does not fail
/// the load.
///
/// # Errors
///
/// Returns a message naming `source` when neither schema yields a non-empty
/// provider map.
fn parse_content(content: &str, source: &str) -> Result<Self, String> {
    if let Ok(cfg) = toml::from_str::<PricingConfig>(content)
        && !cfg.providers.is_empty()
    {
        return Ok(cfg);
    }
    if let Ok(cfg) = Self::load_legacy(content)
        && !cfg.providers.is_empty()
    {
        tracing::warn!(
            "usage_pricing.toml uses old format — please update it to the new schema. \
             See usage_pricing.toml.example in the repo"
        );
        if source != "embedded example" {
            // Best-effort migration: the in-memory config is valid either way,
            // so a write failure is reported rather than propagated.
            let path = crate::config::opencrabs_home().join("usage_pricing.toml");
            if let Err(err) = std::fs::write(&path, Self::serialize_to_toml(&cfg)) {
                tracing::warn!(
                    "Failed to write migrated usage_pricing.toml to {:?}: {}",
                    path,
                    err
                );
            }
        }
        return Ok(cfg);
    }
    Err(format!(
        "usage_pricing.toml from {} failed to parse with both schemas.\n\
         Check the file syntax or re-copy from usage_pricing.toml.example",
        source
    ))
}
/// Parses the legacy schema, where pricing lives under `usage.pricing` as a
/// table mapping each provider name to an array of entries.
///
/// Values under `usage.pricing` that are not arrays are ignored, array items
/// that do not convert into a `PricingEntry` are skipped, and providers left
/// with no entries are omitted. A missing `usage` or `pricing` key (or a
/// `pricing` value that is not a table) yields an empty config.
///
/// # Errors
///
/// Returns the TOML error when `content` cannot be deserialized at all.
fn load_legacy(content: &str) -> Result<Self, toml::de::Error> {
    #[derive(serde::Deserialize)]
    struct LegacyUsage {
        pricing: Option<toml::Value>,
    }
    #[derive(serde::Deserialize)]
    struct LegacyRoot {
        usage: Option<LegacyUsage>,
    }

    let parsed: LegacyRoot = toml::from_str(content)?;

    // Anything other than a `usage.pricing` table means there is nothing to import.
    let Some(toml::Value::Table(pricing_table)) = parsed.usage.and_then(|usage| usage.pricing)
    else {
        return Ok(PricingConfig {
            providers: HashMap::new(),
        });
    };

    let providers: HashMap<String, ProviderBlock> = pricing_table
        .into_iter()
        .filter_map(|(provider_name, raw_entries)| {
            let toml::Value::Array(items) = raw_entries else {
                return None;
            };
            let entries: Vec<PricingEntry> = items
                .into_iter()
                .filter_map(|item| item.try_into().ok())
                .collect();
            (!entries.is_empty()).then(|| (provider_name, ProviderBlock { entries }))
        })
        .collect();

    Ok(PricingConfig { providers })
}
/// Renders `cfg` as TOML in the current schema, preceded by an explanatory
/// comment header.
///
/// Providers are written in ascending name order so the output is stable.
/// Each entry becomes one inline table; `cache_write_per_m` and
/// `cache_read_per_m` are written only when set, so migrating a file keeps
/// any custom cache rates. Provider names that are not valid TOML bare keys
/// are written as quoted keys.
fn serialize_to_toml(cfg: &PricingConfig) -> String {
    /// Renders `s` as a TOML basic (double-quoted) string.
    fn toml_string(s: &str) -> String {
        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('"');
        for c in s.chars() {
            match c {
                '"' => quoted.push_str("\\\""),
                '\\' => quoted.push_str("\\\\"),
                '\n' => quoted.push_str("\\n"),
                '\r' => quoted.push_str("\\r"),
                '\t' => quoted.push_str("\\t"),
                // TOML only understands \uXXXX, not Rust's \u{..} form.
                c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", c as u32)),
                c => quoted.push(c),
            }
        }
        quoted.push('"');
        quoted
    }

    /// Renders `name` as a TOML key: bare when allowed, quoted otherwise.
    fn toml_key(name: &str) -> String {
        let is_bare = !name.is_empty()
            && name
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
        if is_bare {
            name.to_owned()
        } else {
            toml_string(name)
        }
    }

    let mut out = String::from(
        "# OpenCrabs Usage Pricing — auto-migrated to current schema.\n\
         # Edit freely. Changes take effect immediately on next /usage open.\n\
         # prefix is matched case-insensitively as a substring of the model name.\n\
         # Costs are per 1 million tokens (USD).\n\n",
    );

    let mut providers: Vec<(&String, &ProviderBlock)> = cfg.providers.iter().collect();
    providers.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (name, block) in providers {
        out.push_str(&format!("[providers.{}]\nentries = [\n", toml_key(name)));
        for e in &block.entries {
            out.push_str(&format!(
                " {{ prefix = {}, input_per_m = {}, output_per_m = {}",
                toml_string(&e.prefix),
                e.input_per_m,
                e.output_per_m
            ));
            // Optional cache rates are emitted only when the entry sets them.
            if let Some(rate) = e.cache_write_per_m {
                out.push_str(&format!(", cache_write_per_m = {}", rate));
            }
            if let Some(rate) = e.cache_read_per_m {
                out.push_str(&format!(", cache_read_per_m = {}", rate));
            }
            out.push_str(" },\n");
        }
        out.push_str("]\n\n");
    }
    out
}
pub fn seed_from_example() {
let path = crate::config::opencrabs_home().join("usage_pricing.toml");
if path.exists() {
return; }
let example_content = include_str!("../../usage_pricing.toml.example");
if let Err(e) = std::fs::write(&path, example_content) {
tracing::error!("Failed to seed usage_pricing.toml from example: {}", e);
} else {
tracing::info!("Seeded usage_pricing.toml from example");
}
}
}