use std::collections::HashMap;
use serde::Deserialize;
use thiserror::Error;
use solvela_protocol::{CostBreakdown, ModelInfo, PLATFORM_FEE_MULTIPLIER, PLATFORM_FEE_PERCENT};
/// Errors produced while loading or querying a `ModelRegistry`.
#[derive(Debug, Error)]
pub enum ModelRegistryError {
/// No model is registered under the requested lookup key; carries the key that was asked for.
#[error("model not found: {0}")]
NotFound(String),
/// The TOML config could not be deserialized or failed validation; carries a
/// human-readable description of what was wrong.
#[error("failed to parse model config: {0}")]
ParseError(String),
}
/// One `[models.<key>]` table from the TOML config, before it is converted to a `ModelInfo`.
///
/// `deny_unknown_fields` turns a misspelled key into a deserialization error
/// instead of letting it be silently ignored.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelEntry {
/// Provider name; first half of the canonical `"<provider>/<model_id>"` id.
pub provider: String,
/// Provider-side model identifier; second half of the canonical id.
pub model_id: String,
/// Display name, copied into `ModelInfo` unchanged.
pub display_name: String,
/// Price per one million input tokens. `ModelRegistry::from_toml` rejects
/// NaN, infinite, and negative values.
pub input_cost_per_million: f64,
/// Price per one million output tokens. `ModelRegistry::from_toml` rejects
/// NaN, infinite, and negative values.
pub output_cost_per_million: f64,
/// Context window size, copied into `ModelInfo` unchanged.
/// NOTE(review): unit is presumably tokens — not established in this file.
pub context_window: u32,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub supports_streaming: bool,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub supports_tools: bool,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub supports_vision: bool,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub reasoning: bool,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub supports_structured_output: bool,
/// Capability flag; `false` when omitted from the config.
#[serde(default)]
pub supports_batch: bool,
/// Optional limit, `None` when omitted; copied into `ModelInfo` unchanged.
#[serde(default)]
pub max_output_tokens: Option<u32>,
}
/// Top-level shape of the TOML config document.
///
/// Unknown top-level keys are rejected (`deny_unknown_fields`).
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelsConfig {
/// The `[models.<key>]` tables, keyed by the `<key>` part of the table name.
pub models: HashMap<String, ModelEntry>,
}
/// In-memory lookup table of model metadata and pricing.
#[derive(Debug)]
pub struct ModelRegistry {
/// Lookup string -> model. `from_toml` inserts every model under both its
/// config key and its canonical `"<provider>/<model_id>"` id, so one model
/// normally appears under two keys.
models: HashMap<String, ModelInfo>,
}
impl ModelRegistry {
pub fn from_toml(toml_str: &str) -> Result<Self, ModelRegistryError> {
let config: ModelsConfig =
toml::from_str(toml_str).map_err(|e| ModelRegistryError::ParseError(e.to_string()))?;
for (key, entry) in &config.models {
if !entry.input_cost_per_million.is_finite() || entry.input_cost_per_million < 0.0 {
return Err(ModelRegistryError::ParseError(format!(
"model {key:?}: input_cost_per_million must be finite and non-negative, got {}",
entry.input_cost_per_million
)));
}
if !entry.output_cost_per_million.is_finite() || entry.output_cost_per_million < 0.0 {
return Err(ModelRegistryError::ParseError(format!(
"model {key:?}: output_cost_per_million must be finite and non-negative, got {}",
entry.output_cost_per_million
)));
}
}
let mut models: HashMap<String, ModelInfo> = HashMap::new();
for (key, entry) in config.models {
let id = format!("{}/{}", entry.provider, entry.model_id);
let info = ModelInfo {
id: id.clone(),
provider: entry.provider,
model_id: entry.model_id,
display_name: entry.display_name,
input_cost_per_million: entry.input_cost_per_million,
output_cost_per_million: entry.output_cost_per_million,
context_window: entry.context_window,
supports_streaming: entry.supports_streaming,
supports_tools: entry.supports_tools,
supports_vision: entry.supports_vision,
reasoning: entry.reasoning,
supports_structured_output: entry.supports_structured_output,
supports_batch: entry.supports_batch,
max_output_tokens: entry.max_output_tokens,
};
if let Some(existing) = models.get(&id) {
let pricing_matches =
(existing.input_cost_per_million - info.input_cost_per_million).abs()
< f64::EPSILON
&& (existing.output_cost_per_million - info.output_cost_per_million).abs()
< f64::EPSILON;
if !pricing_matches {
return Err(ModelRegistryError::ParseError(format!(
"duplicate canonical key {id:?} with conflicting pricing: \
entry {key:?} has input={}/output={} but a previous entry \
registered input={}/output={}",
info.input_cost_per_million,
info.output_cost_per_million,
existing.input_cost_per_million,
existing.output_cost_per_million
)));
}
}
models.insert(key, info.clone());
models.insert(id, info);
}
Ok(Self { models })
}
pub fn get(&self, model_id: &str) -> Option<&ModelInfo> {
self.models.get(model_id)
}
pub fn all(&self) -> Vec<&ModelInfo> {
let mut seen = std::collections::HashSet::new();
self.models
.values()
.filter(|m| seen.insert(&m.id))
.collect()
}
pub fn estimate_cost(
&self,
model_id: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<CostBreakdown, ModelRegistryError> {
let model = self
.get(model_id)
.ok_or_else(|| ModelRegistryError::NotFound(model_id.to_string()))?;
let input_cost = (input_tokens as f64 / 1_000_000.0) * model.input_cost_per_million;
let output_cost = (output_tokens as f64 / 1_000_000.0) * model.output_cost_per_million;
let provider_cost = input_cost + output_cost;
let total_with_fee = provider_cost * PLATFORM_FEE_MULTIPLIER;
let platform_fee = total_with_fee - provider_cost;
Ok(CostBreakdown {
provider_cost: format!("{provider_cost:.6}"),
platform_fee: format!("{platform_fee:.6}"),
total: format!("{total_with_fee:.6}"),
currency: "USDC".to_string(),
fee_percent: PLATFORM_FEE_PERCENT,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared two-model fixture used by the happy-path tests.
    const TEST_TOML: &str = r#"
[models.openai-gpt-4o]
provider = "openai"
model_id = "gpt-4o"
display_name = "GPT-4o"
input_cost_per_million = 2.50
output_cost_per_million = 10.00
context_window = 128000
supports_streaming = true
supports_tools = true
supports_vision = true
[models.deepseek-chat]
provider = "deepseek"
model_id = "deepseek-chat"
display_name = "DeepSeek V3.2 Chat"
input_cost_per_million = 0.28
output_cost_per_million = 0.42
context_window = 128000
supports_streaming = true
"#;

    /// Builds a registry from `TEST_TOML`, panicking if the fixture fails to load.
    fn fixture_registry() -> ModelRegistry {
        ModelRegistry::from_toml(TEST_TOML).unwrap()
    }

    /// Extracts the message of a `ParseError`, panicking on any other variant.
    fn parse_error_message(err: ModelRegistryError) -> String {
        match err {
            ModelRegistryError::ParseError(m) => m,
            other => panic!("expected ParseError, got {other:?}"),
        }
    }

    /// Models resolve by canonical id and by config key; unknown keys do not.
    #[test]
    fn test_load_from_toml() {
        let registry = fixture_registry();
        for present in ["openai/gpt-4o", "openai-gpt-4o", "deepseek/deepseek-chat"] {
            assert!(registry.get(present).is_some());
        }
        assert!(registry.get("nonexistent").is_none());
    }

    /// The registry keeps the provider's rates as written, with no fee baked in.
    #[test]
    fn test_registry_stores_raw_provider_rates() {
        let registry = fixture_registry();
        let gpt4o = registry.get("openai/gpt-4o").unwrap();
        let input_rate = gpt4o.input_cost_per_million;
        let output_rate = gpt4o.output_cost_per_million;
        assert!((input_rate - 2.50).abs() < 0.001, "got {}", input_rate);
        assert!((output_rate - 10.00).abs() < 0.001, "got {}", output_rate);
    }

    /// One million input tokens at 2.50 yields a 0.125 fee and a 2.625 total.
    #[test]
    fn test_estimate_cost_applies_exactly_one_5_percent_fee() {
        let cost = fixture_registry()
            .estimate_cost("openai/gpt-4o", 1_000_000, 0)
            .unwrap();
        let provider = cost.provider_cost.parse::<f64>().unwrap();
        let fee = cost.platform_fee.parse::<f64>().unwrap();
        let total = cost.total.parse::<f64>().unwrap();
        assert!(
            (provider - 2.50).abs() < 1e-6,
            "provider_cost should be 2.50, got {provider}"
        );
        assert!(
            (fee - 0.125).abs() < 1e-6,
            "platform_fee should be 0.125 (5% of 2.50), got {fee}"
        );
        assert!(
            (total - 2.625).abs() < 1e-6,
            "total should be 2.625 (2.50 * 1.05), got {total}"
        );
    }

    /// A small request reports USDC, a 5 percent fee, and a positive total.
    #[test]
    fn test_cost_estimate() {
        let cost = fixture_registry()
            .estimate_cost("openai/gpt-4o", 1000, 500)
            .unwrap();
        assert_eq!(cost.currency, "USDC");
        assert_eq!(cost.fee_percent, 5);
        let total = cost.total.parse::<f64>().unwrap();
        assert!(total > 0.0);
    }

    /// `all()` de-duplicates the alias and canonical entries of each model.
    #[test]
    fn test_all_models() {
        assert_eq!(fixture_registry().all().len(), 2);
    }

    /// NaN, infinite, and negative rates are each rejected with an error naming the field.
    #[test]
    fn from_toml_rejects_nan_negative_and_infinite_pricing() {
        let cases = [
            (
                "NaN input cost",
                r#"
[models.bad]
provider = "test"
model_id = "bad"
display_name = "Bad"
input_cost_per_million = nan
output_cost_per_million = 1.0
context_window = 4096
"#,
            ),
            (
                "Infinity output cost",
                r#"
[models.bad]
provider = "test"
model_id = "bad"
display_name = "Bad"
input_cost_per_million = 1.0
output_cost_per_million = inf
context_window = 4096
"#,
            ),
            (
                "negative input cost",
                r#"
[models.bad]
provider = "test"
model_id = "bad"
display_name = "Bad"
input_cost_per_million = -0.50
output_cost_per_million = 1.0
context_window = 4096
"#,
            ),
        ];
        for (label, body) in cases {
            let outcome = ModelRegistry::from_toml(body);
            let err = outcome.expect_err(&format!("{label} must be rejected"));
            let msg = match err {
                ModelRegistryError::ParseError(msg) => msg,
                other => panic!("{label}: expected ParseError, got {other:?}"),
            };
            let names_field = ["input_cost_per_million", "output_cost_per_million"]
                .iter()
                .any(|field| msg.contains(*field));
            assert!(
                names_field,
                "{label}: error must name the offending field, got: {msg}"
            );
        }
    }

    /// A misspelled field name is rejected rather than silently dropped.
    #[test]
    fn from_toml_rejects_unknown_field_typo() {
        let body = r#"
[models.typo]
provider = "test"
model_id = "typo"
display_name = "Typo"
input_cost_per_milion = 2.50
output_cost_per_million = 5.00
context_window = 4096
"#;
        let err = ModelRegistry::from_toml(body)
            .expect_err("unknown field must be rejected by deny_unknown_fields");
        let msg = parse_error_message(err);
        let surfaces_typo = msg.contains("unknown") || msg.contains("input_cost_per_milion");
        assert!(
            surfaces_typo,
            "error must surface the unknown field, got: {msg}"
        );
    }

    /// Two entries with the same canonical id but different prices fail to load.
    #[test]
    fn from_toml_rejects_canonical_collision_with_conflicting_pricing() {
        let body = r#"
[models.first]
provider = "test"
model_id = "shared"
display_name = "First"
input_cost_per_million = 1.00
output_cost_per_million = 2.00
context_window = 4096
[models.second]
provider = "test"
model_id = "shared"
display_name = "Second"
input_cost_per_million = 9.99
output_cost_per_million = 2.00
context_window = 4096
"#;
        let err = ModelRegistry::from_toml(body)
            .expect_err("canonical-key collision with conflicting pricing must error");
        let msg = parse_error_message(err);
        let mentions_key = msg.contains("duplicate canonical key") && msg.contains("test/shared");
        assert!(
            mentions_key,
            "error must mention the canonical key, got: {msg}"
        );
    }

    /// Two entries with the same canonical id and identical prices both load.
    #[test]
    fn from_toml_allows_canonical_collision_with_identical_pricing() {
        let body = r#"
[models.first]
provider = "test"
model_id = "shared"
display_name = "First"
input_cost_per_million = 3.00
output_cost_per_million = 15.00
context_window = 4096
[models.second]
provider = "test"
model_id = "shared"
display_name = "Second"
input_cost_per_million = 3.00
output_cost_per_million = 15.00
context_window = 4096
"#;
        let registry = ModelRegistry::from_toml(body)
            .expect("equal-pricing duplicates should load successfully");
        for lookup in ["first", "second", "test/shared"] {
            assert!(registry.get(lookup).is_some());
        }
    }
}