use std::collections::HashMap;
use std::sync::OnceLock;
use crate::catalog::BuiltinModelEntry;
use crate::{Api, InputModality};
/// Translates a catalog API identifier string into its [`Api`] variant.
///
/// Matching is exact and case-sensitive. Any string that is not listed in the
/// table resolves to [`Api::OpenAiCompletions`].
fn parse_api(s: &str) -> Api {
    // Wire name -> variant. Kept as data so the set of known identifiers can
    // be read at a glance.
    const KNOWN_APIS: [(&str, Api); 8] = [
        ("anthropic-messages", Api::AnthropicMessages),
        ("openai-completions", Api::OpenAiCompletions),
        ("openai-responses", Api::OpenAiResponses),
        ("google-generative-ai", Api::GoogleGenerativeAi),
        ("google-vertex", Api::GoogleVertex),
        ("mistral-conversations", Api::MistralConversations),
        ("azure-openai-responses", Api::AzureOpenAiResponses),
        ("bedrock-converse-stream", Api::BedrockConverseStream),
    ];

    KNOWN_APIS
        .iter()
        .find(|(wire_name, _)| *wire_name == s)
        .map_or(Api::OpenAiCompletions, |&(_, api)| api)
}
/// Translates a catalog modality string into an [`InputModality`].
///
/// Only the exact spellings `"image"` and `"Image"` yield
/// [`InputModality::Image`]; every other string (including `"text"`,
/// `"Text"` and anything unrecognised) yields [`InputModality::Text`].
fn parse_input_modality(s: &str) -> InputModality {
    if matches!(s, "image" | "Image") {
        InputModality::Image
    } else {
        // Covers the explicit text spellings as well as the fallback.
        InputModality::Text
    }
}
impl From<&BuiltinModelEntry> for ModelEntry {
    /// Builds a `'static` [`ModelEntry`] from a catalog entry.
    ///
    /// The id, name, provider and the parsed input-modality list are copied
    /// to the heap and intentionally leaked (`Box::leak`) so the resulting
    /// entry can hold `&'static` references; that memory is never reclaimed.
    ///
    /// For providers accepted by `is_openclaw_sourced`, an input or output
    /// price of exactly `0.0` is replaced with [`UNVERIFIED_PRICE`]. Cache
    /// prices are copied through unchanged for every provider.
    fn from(e: &BuiltinModelEntry) -> Self {
        /// Copies `text` to the heap and leaks it to obtain a `'static` slice.
        fn leak_str(text: &str) -> &'static str {
            Box::leak(Box::<str>::from(text))
        }

        // Swaps a zero price for the sentinel; any other value passes through.
        let mask_zero = |price: f64| {
            if price == 0.0 {
                UNVERIFIED_PRICE
            } else {
                price
            }
        };

        let modalities: Vec<InputModality> = e
            .input
            .iter()
            .map(|s| parse_input_modality(s))
            .collect();

        let (cost_input, cost_output) = if is_openclaw_sourced(&e.provider) {
            (mask_zero(e.cost_input), mask_zero(e.cost_output))
        } else {
            (e.cost_input, e.cost_output)
        };

        ModelEntry {
            id: leak_str(&e.id),
            name: leak_str(&e.name),
            api: parse_api(&e.api),
            provider: leak_str(&e.provider),
            reasoning: e.reasoning,
            input: Box::leak(modalities.into_boxed_slice()),
            cost_input,
            cost_output,
            cost_cache_read: e.cost_cache_read,
            cost_cache_write: e.cost_cache_write,
            context_window: e.context_window,
            max_tokens: e.max_tokens,
        }
    }
}
/// Sentinel stored in `cost_input` / `cost_output` in place of a `0.0` price
/// when the entry's provider is one accepted by `is_openclaw_sourced`.
///
/// It is negative so that `ModelEntry::pricing_unverified` can detect it;
/// `ModelEntry::calculate_cost` clamps negative rates to `0.0`.
pub const UNVERIFIED_PRICE: f64 = -1.0;
/// Reports whether `provider` is one of the provider ids whose zero prices
/// are treated as unverified when a `ModelEntry` is built.
///
/// The comparison is exact and case-sensitive.
// NOTE(review): the name suggests these providers' data was imported from
// "openclaw"; that provenance is not visible in this file — confirm.
fn is_openclaw_sourced(provider: &str) -> bool {
    const OPENCLAW_PROVIDERS: [&str; 13] = [
        "gmi",
        "kilocode",
        "moonshot",
        "nvidia",
        "ollama-cloud",
        "qianfan",
        "qwen-oauth",
        "stepfun",
        "byteplus",
        "chutes",
        "deepinfra",
        "stepfun-plan",
        "byteplus-plan",
    ];

    OPENCLAW_PROVIDERS.iter().any(|known| *known == provider)
}
/// One model of one provider, with all text held as `&'static` data so the
/// entry is `Copy` and can be handed out by reference from the global tables.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelEntry {
/// Model identifier; combined with `provider` as `"provider/id"` for lookup.
pub id: &'static str,
/// Model name; matched (case-insensitively) by `search_models` alongside `id`.
pub name: &'static str,
/// API variant, produced from the catalog's API string by `parse_api`.
pub api: Api,
/// Identifier of the provider this model is grouped under.
pub provider: &'static str,
/// Reasoning flag as recorded in the catalog; see `supports_reasoning`.
pub reasoning: bool,
/// Accepted input modalities; containing `InputModality::Image` means vision support.
pub input: &'static [InputModality],
/// Price per 1,000,000 input tokens; negative means unverified (`UNVERIFIED_PRICE`).
pub cost_input: f64,
/// Price per 1,000,000 output tokens; negative means unverified (`UNVERIFIED_PRICE`).
pub cost_output: f64,
/// Price per 1,000,000 cache-read tokens.
pub cost_cache_read: f64,
/// Price per 1,000,000 cache-write tokens.
pub cost_cache_write: f64,
/// Context window size. NOTE(review): presumably measured in tokens — confirm.
pub context_window: u32,
/// Maximum tokens value from the catalog. NOTE(review): presumably the output-token cap — confirm.
pub max_tokens: u32,
}
impl ModelEntry {
    /// Returns `true` when the model lists [`InputModality::Image`] among its inputs.
    pub fn supports_vision(&self) -> bool {
        self.input
            .iter()
            .any(|modality| *modality == InputModality::Image)
    }

    /// Returns the model's `reasoning` flag.
    pub fn supports_reasoning(&self) -> bool {
        self.reasoning
    }

    /// Computes the total cost of a request from its token counts.
    ///
    /// Each rate is interpreted as a price per 1,000,000 tokens. Negative
    /// rates (such as [`UNVERIFIED_PRICE`]) are clamped to `0.0`, so an
    /// unverified price contributes nothing to the total.
    pub fn calculate_cost(
        &self,
        input_tokens: u64,
        output_tokens: u64,
        cache_read: u64,
        cache_write: u64,
    ) -> f64 {
        // Cost of one token bucket at a per-million rate, never negative.
        let bill = |tokens: u64, rate: f64| (tokens as f64 / 1_000_000.0) * rate.max(0.0);

        bill(input_tokens, self.cost_input)
            + bill(output_tokens, self.cost_output)
            + bill(cache_read, self.cost_cache_read)
            + bill(cache_write, self.cost_cache_write)
    }

    /// Returns `true` when both the input and the output price are non-negative.
    pub fn pricing_verified(&self) -> bool {
        [self.cost_input, self.cost_output]
            .iter()
            .all(|price| *price >= 0.0)
    }

    /// Returns `true` when the input or the output price is negative,
    /// i.e. at least one of them carries the unverified sentinel.
    pub fn pricing_unverified(&self) -> bool {
        [self.cost_input, self.cost_output]
            .iter()
            .any(|price| *price < 0.0)
    }
}
/// Lazily built table of `(provider id, models)` pairs backing every lookup
/// function in this module.
static ALL_PROVIDER_MODELS: OnceLock<Vec<(&'static str, &'static [ModelEntry])>> =
    OnceLock::new();

/// Returns the provider table, building it on first use.
///
/// Construction steps:
/// 1. Clone every built-in entry out of the catalog.
/// 2. If `load_overrides` yields overrides, regroup the entries by provider
///    and let `apply_model_overrides` edit them.
/// 3. Convert each entry to a [`ModelEntry`] and group by its provider id.
/// 4. Sort each provider's models by id and leak the slice so it is `'static`.
///
/// Providers appear in ascending id order because they are grouped in a
/// `BTreeMap`.
fn all_provider_models() -> &'static [(&'static str, &'static [ModelEntry])] {
    ALL_PROVIDER_MODELS
        .get_or_init(|| {
            let catalog = crate::catalog::CatalogRoot::get();
            use std::collections::BTreeMap;

            let mut all_builtins: Vec<crate::catalog::BuiltinModelEntry> = Vec::new();
            for (_file_pid, builtin_models) in catalog.models.iter() {
                for bm in builtin_models.iter() {
                    all_builtins.push(bm.clone());
                }
            }

            if let Some(overrides) = crate::catalog::load_overrides() {
                // Overrides operate on a provider-keyed map, so regroup the
                // flat list, apply them, then flatten back.
                let mut all_map: BTreeMap<String, Vec<crate::catalog::BuiltinModelEntry>> =
                    BTreeMap::new();
                for bm in all_builtins.into_iter() {
                    all_map.entry(bm.provider.clone()).or_default().push(bm);
                }
                crate::catalog::apply_model_overrides(&mut all_map, &overrides.model);
                all_builtins = all_map.into_values().flatten().collect();
            }

            // Group by the provider string each entry already owns as a
            // leaked `&'static str`; this avoids allocating and leaking a
            // second copy of every provider id just to use it as a key.
            let mut by_pid: BTreeMap<&'static str, Vec<ModelEntry>> = BTreeMap::new();
            for bm in all_builtins.iter() {
                let entry = ModelEntry::from(bm);
                by_pid.entry(entry.provider).or_default().push(entry);
            }

            let mut out: Vec<(&'static str, &'static [ModelEntry])> =
                Vec::with_capacity(by_pid.len());
            for (pid, mut entries) in by_pid {
                // Stable sort keeps catalog order among entries with equal ids.
                entries.sort_by(|a, b| a.id.cmp(b.id));
                let slice: &'static [ModelEntry] = Box::leak(entries.into_boxed_slice());
                out.push((pid, slice));
            }
            out
        })
        .as_slice()
}
/// Lazily built map from `"provider/id"` to the matching model entry.
static MODEL_INDEX: OnceLock<HashMap<&'static str, &'static ModelEntry>> = OnceLock::new();

/// Returns the `"provider/id"` lookup map, building it on first use.
///
/// Every key is formatted once and leaked so it can live in the static map.
/// If two models of one provider share an id, the later one wins.
fn model_index() -> &'static HashMap<&'static str, &'static ModelEntry> {
    MODEL_INDEX.get_or_init(|| {
        let mut index: HashMap<&'static str, &'static ModelEntry> =
            HashMap::with_capacity(model_count());

        for &(provider, models) in all_provider_models() {
            index.extend(models.iter().map(|model| {
                let key: &'static str =
                    Box::leak(format!("{}/{}", provider, model.id).into_boxed_str());
                (key, model)
            }));
        }

        index
    })
}
/// Lazily built map from provider id to that provider's model slice.
static PROVIDER_INDEX: OnceLock<HashMap<&'static str, &'static [ModelEntry]>> = OnceLock::new();

/// Returns the provider lookup map, building it from the provider table on
/// first use.
fn provider_index() -> &'static HashMap<&'static str, &'static [ModelEntry]> {
    // Each table row is already a `(key, value)` pair of `Copy` references,
    // so the map can be collected straight from the slice.
    PROVIDER_INDEX.get_or_init(|| all_provider_models().iter().copied().collect())
}
/// Looks up a single model by provider id and model id.
///
/// Returns `None` when no model is registered under `"provider/id"`.
pub fn get_model_entry(provider: &str, id: &str) -> Option<&'static ModelEntry> {
    // Assemble the same "provider/id" key the index was built with.
    let mut key = String::with_capacity(provider.len() + 1 + id.len());
    key.push_str(provider);
    key.push('/');
    key.push_str(id);

    model_index().get(key.as_str()).copied()
}
/// Returns all models of `provider`, or an empty slice when the provider is
/// unknown.
pub fn get_provider_models(provider: &str) -> &'static [ModelEntry] {
    match provider_index().get(provider) {
        Some(&models) => models,
        None => &[],
    }
}
/// Iterates over every model of every provider, in provider-table order.
pub fn get_all_models() -> impl Iterator<Item = &'static ModelEntry> {
    all_provider_models()
        .iter()
        .flat_map(|&(_, models)| models)
}
/// Returns the total number of models across all providers.
pub fn model_count() -> usize {
    all_provider_models()
        .iter()
        .fold(0usize, |total, &(_, models)| total + models.len())
}
/// Counts the models whose input or output price is still the unverified
/// sentinel (that is, `pricing_unverified()` is `true`).
pub fn builtin_model_count_sentinel() -> usize {
    let mut unverified = 0usize;
    for model in get_all_models() {
        if model.pricing_unverified() {
            unverified += 1;
        }
    }
    unverified
}
/// Returns the id of every provider, in provider-table order.
pub fn get_providers() -> Vec<&'static str> {
    let table = all_provider_models();
    let mut names = Vec::with_capacity(table.len());
    for &(provider, _) in table {
        names.push(provider);
    }
    names
}
/// Returns every model whose id or name contains `pattern`, compared
/// case-insensitively (both sides are lowercased before matching).
///
/// An empty `pattern` matches every model.
pub fn search_models(pattern: &str) -> Vec<&'static ModelEntry> {
    let needle = pattern.to_lowercase();
    let is_hit = |field: &str| field.to_lowercase().contains(needle.as_str());

    let mut found = Vec::new();
    for model in get_all_models() {
        // The id is checked first; the name only when the id does not match.
        if is_hit(model.id) || is_hit(model.name) {
            found.push(model);
        }
    }
    found
}
/// Returns every model whose `reasoning` flag is set.
pub fn get_reasoning_models() -> Vec<&'static ModelEntry> {
    let mut selected = Vec::new();
    for model in get_all_models() {
        if model.reasoning {
            selected.push(model);
        }
    }
    selected
}
/// Returns every model that accepts image input.
pub fn get_vision_models() -> Vec<&'static ModelEntry> {
    let mut selected = Vec::new();
    selected.extend(get_all_models().filter(|model| model.supports_vision()));
    selected
}
/// Returns up to `limit` models ordered by ascending input price.
///
/// Models whose pricing is unverified (their input or output price holds the
/// negative `UNVERIFIED_PRICE` sentinel) are ranked after every model with
/// verified pricing. Without this, the sentinel value `-1.0` would sort ahead
/// of all real prices and unknown-price models would be reported as the
/// cheapest. They are still included when fewer than `limit` verified models
/// exist, so the result length is always `min(limit, model_count())`.
///
/// Prices are compared with `f64::total_cmp`, which is a total order even if
/// a price is NaN, as `sort_by` requires.
pub fn get_cheapest_models(limit: usize) -> Vec<&'static ModelEntry> {
    let mut all: Vec<_> = get_all_models().collect();
    all.sort_by(|a, b| {
        // `false < true`, so verified-price models come first.
        a.pricing_unverified()
            .cmp(&b.pricing_unverified())
            .then_with(|| a.cost_input.total_cmp(&b.cost_input))
    });
    all.truncate(limit);
    all
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The catalog must not shrink below the known baseline size.
    #[test]
    fn test_total_model_count() {
        let total = model_count();
        assert!(total >= 934, "Expected at least 934 models, got {}", total);
    }

    /// A well-known Anthropic model resolves and reports sane metadata.
    #[test]
    fn test_get_anthropic_model() {
        let Some(model) = get_model_entry("anthropic", "claude-3-5-sonnet-20240620") else {
            panic!("Claude Sonnet 3.5 should exist");
        };
        assert_eq!(model.provider, "anthropic");
        assert!(model.context_window >= 200_000);
    }

    /// A well-known OpenAI model resolves under its provider.
    #[test]
    fn test_get_openai_model() {
        let Some(model) = get_model_entry("openai", "gpt-4o") else {
            panic!("GPT-4o should exist");
        };
        assert_eq!(model.provider, "openai");
    }

    /// Provider lookup returns only that provider's models, and nothing for
    /// an unknown provider.
    #[test]
    fn test_provider_models() {
        let anthropic_models = get_provider_models("anthropic");
        assert!(!anthropic_models.is_empty(), "Anthropic should have models");
        assert!(anthropic_models.iter().all(|m| m.provider == "anthropic"));

        assert!(get_provider_models("nonexistent-provider").is_empty());
    }

    /// Every search hit contains the pattern in its id or name.
    #[test]
    fn test_search_models() {
        let hits = search_models("claude");
        assert!(!hits.is_empty(), "Should find Claude models");

        let mentions_claude = |m: &&ModelEntry| {
            m.name.to_lowercase().contains("claude") || m.id.to_lowercase().contains("claude")
        };
        assert!(hits.iter().all(mentions_claude));
    }

    /// The provider list includes the major providers.
    #[test]
    fn test_all_providers() {
        let names = get_providers();
        assert!(names.contains(&"openai"), "Should have openai");
        assert!(names.contains(&"anthropic"), "Should have anthropic");
    }

    /// The reasoning filter yields only reasoning models.
    #[test]
    fn test_reasoning_models() {
        let models = get_reasoning_models();
        assert!(!models.is_empty(), "Should have reasoning models");
        assert!(models.iter().all(|m| m.reasoning));
    }

    /// The vision filter yields only image-capable models.
    #[test]
    fn test_vision_models() {
        let models = get_vision_models();
        assert!(!models.is_empty(), "Should have vision models");
        assert!(models.iter().all(|m| m.supports_vision()));
    }

    /// The cheapest list honours the limit and is sorted by input price.
    #[test]
    fn test_cheapest_models() {
        let cheapest = get_cheapest_models(5);
        let expected_len = std::cmp::min(5, model_count());
        assert_eq!(cheapest.len(), expected_len);

        for pair in cheapest.windows(2) {
            assert!(pair[1].cost_input >= pair[0].cost_input);
        }
    }
}