use crate::{Api, Cost, InputModality, Model};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// Strips the leading provider segment from a registry key.
///
/// Registry keys have the form `"<provider>/<model-id>"`. Only the first `/`
/// separates the provider; everything after it is the model id exactly as the
/// provider's API expects it, which may itself contain `/` (OpenRouter ids such
/// as `"anthropic/claude-3.5-sonnet"`). A key without any `/` is returned as-is.
///
/// Splitting on the first separator keeps `key == format!("{}/{}", provider, id)`,
/// the same key shape that `register_model` and `lookup_model` build.
fn extract_model_name(id: &str) -> &str {
    id.split_once('/').map_or(id, |(_, name)| name)
}
static MODELS: Lazy<HashMap<String, Model>> = Lazy::new(|| {
let mut map = HashMap::new();
add_openai_models(&mut map);
add_anthropic_models(&mut map);
add_google_models(&mut map);
add_deepseek_models(&mut map);
add_mistral_models(&mut map);
add_groq_models(&mut map);
add_cerebras_models(&mut map);
add_xai_models(&mut map);
add_openrouter_models(&mut map);
add_azure_models(&mut map);
add_zai_models(&mut map);
map
});
/// Registers OpenAI's first-party models, served through the Chat Completions API.
///
/// Each row is `(registry key, display name, reasoning, accepts images,
/// input cost, output cost, context window, max output tokens)`. Costs appear
/// to be USD per million tokens (they match OpenAI's list prices) — confirm
/// against the consumers of `Cost`.
///
/// Image support is listed per model instead of being derived from `reasoning`;
/// the two are independent. GPT-4o takes images but is not a reasoning model,
/// while o1-mini and o3-mini are reasoning models that accept text only.
fn add_openai_models(map: &mut HashMap<String, Model>) {
    #[rustfmt::skip]
    let models = [
        ("openai/gpt-4o", "GPT-4o", false, true, 2.5, 10.0, 128_000, 16_384),
        ("openai/gpt-4o-mini", "GPT-4o Mini", false, true, 0.15, 0.60, 128_000, 16_384),
        ("openai/gpt-4-turbo", "GPT-4 Turbo", false, true, 10.0, 30.0, 128_000, 4_096),
        ("openai/gpt-4", "GPT-4", false, false, 30.0, 60.0, 8_192, 8_192),
        ("openai/gpt-3.5-turbo", "GPT-3.5 Turbo", false, false, 0.5, 1.5, 16_385, 4_096),
        ("openai/o1-preview", "OpenAI o1 Preview", true, false, 15.0, 60.0, 128_000, 32_768),
        ("openai/o1-mini", "OpenAI o1 Mini", true, false, 1.1, 4.4, 128_000, 65_536),
        ("openai/o1", "OpenAI o1", true, true, 15.0, 60.0, 200_000, 100_000),
        ("openai/o3", "OpenAI o3", true, true, 2.0, 8.0, 200_000, 100_000),
        ("openai/o3-mini", "OpenAI o3 Mini", true, false, 1.1, 4.4, 200_000, 100_000),
    ];
    for (id, name, reasoning, images, input_cost, output_cost, ctx, max_out) in models {
        let input = if images {
            vec![InputModality::Text, InputModality::Image]
        } else {
            vec![InputModality::Text]
        };
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::OpenAiCompletions,
                provider: "openai".to_string(),
                base_url: "https://api.openai.com/v1".to_string(),
                reasoning,
                input,
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    // Cached prompt tokens are billed at a discount: 50% for most
                    // of these models. o3's discount is deeper, so this slightly
                    // overestimates its cached-read cost.
                    cache_read: input_cost * 0.5,
                    // OpenAI charges no surcharge for writing to its prompt cache.
                    cache_write: 0.0,
                },
                context_window: ctx,
                max_tokens: max_out,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers Anthropic's Claude models, served through the Messages API.
///
/// Each row is `(registry key, display name, reasoning, input cost,
/// output cost, max output tokens)`. `reasoning` is set only for the Claude 4
/// models, which support extended thinking; Claude 3.5 Sonnet does not.
///
/// The output cap is per model because it differs across generations:
/// - Claude 3 models accept at most 4 096 output tokens.
/// - The 3.5 models accept 8 192.
/// - Sonnet 4 accepts 64 000 and Opus 4 accepts 32 000.
///
/// If this value is forwarded as the request's `max_tokens`, a higher figure
/// would be rejected by the API.
///
/// Cache pricing follows Anthropic's multipliers: reads cost 0.1× and writes
/// 1.25× the base input price.
fn add_anthropic_models(map: &mut HashMap<String, Model>) {
    // NOTE(review): the Claude 3 keys below lack the date suffix used by the
    // Messages API (e.g. `claude-3-opus-20240229`). Confirm they resolve, or
    // that they are remapped, before a request is sent.
    #[rustfmt::skip]
    let models = [
        ("anthropic/claude-sonnet-4-20250514", "Claude Sonnet 4", true, 3.0, 15.0, 64_000),
        ("anthropic/claude-opus-4-20250514", "Claude Opus 4", true, 15.0, 75.0, 32_000),
        ("anthropic/claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet", false, 3.0, 15.0, 8_192),
        ("anthropic/claude-3-5-haiku-20241022", "Claude 3.5 Haiku", false, 0.8, 4.0, 8_192),
        ("anthropic/claude-3-opus", "Claude 3 Opus", false, 15.0, 75.0, 4_096),
        ("anthropic/claude-3-sonnet", "Claude 3 Sonnet", false, 3.0, 15.0, 4_096),
        ("anthropic/claude-3-haiku", "Claude 3 Haiku", false, 0.25, 1.25, 4_096),
    ];
    for (id, name, reasoning, input_cost, output_cost, max_out) in models {
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::AnthropicMessages,
                provider: "anthropic".to_string(),
                base_url: "https://api.anthropic.com".to_string(),
                reasoning,
                input: vec![InputModality::Text, InputModality::Image],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: input_cost * 0.1,
                    cache_write: input_cost * 1.25,
                },
                context_window: 200_000,
                max_tokens: max_out,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers Google's Gemini models, served through the Generative Language API.
///
/// Each row is `(registry key, display name, reasoning, accepts images,
/// input cost, output cost, context window, max output tokens)`.
///
/// Per-model notes:
/// - The Gemini 2.5 models are thinking models. They are flagged as
///   `reasoning` and allow up to 65 536 output tokens.
/// - Gemini 2.5 Pro has a 1M-token window, not 2M.
/// - Gemini Pro (1.0) is text-only.
///
/// Costs of 0.0 presumably track Google's free tier — TODO confirm. No cache
/// pricing is modelled.
fn add_google_models(map: &mut HashMap<String, Model>) {
    #[rustfmt::skip]
    let models = [
        ("google/gemini-2.0-flash", "Gemini 2.0 Flash", false, true, 0.0, 0.0, 1_000_000, 8_192),
        ("google/gemini-2.5-flash", "Gemini 2.5 Flash", true, true, 0.0, 0.0, 1_000_000, 65_536),
        ("google/gemini-2.5-pro", "Gemini 2.5 Pro", true, true, 1.25, 10.0, 1_000_000, 65_536),
        ("google/gemini-1.5-flash", "Gemini 1.5 Flash", false, true, 0.0, 0.0, 1_000_000, 8_192),
        ("google/gemini-1.5-pro", "Gemini 1.5 Pro", false, true, 1.25, 5.0, 2_000_000, 8_192),
        ("google/gemini-pro", "Gemini Pro", false, false, 0.125, 0.5, 32_000, 8_192),
    ];
    for (id, name, reasoning, images, input_cost, output_cost, ctx, max_out) in models {
        let input = if images {
            vec![InputModality::Text, InputModality::Image]
        } else {
            vec![InputModality::Text]
        };
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::GoogleGenerativeAi,
                provider: "google".to_string(),
                base_url: "https://generativelanguage.googleapis.com".to_string(),
                reasoning,
                input,
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: ctx,
                max_tokens: max_out,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers DeepSeek's models, reached through DeepSeek's OpenAI-compatible API.
///
/// Each row is `(registry key, display name, reasoning, input cost,
/// output cost, cache-hit input cost)`. All entries are text-only, with a
/// 64 000-token context window and an 8 192-token output cap.
///
/// Cache pricing:
/// - Prompt-cache hits are priced per model; the reasoner's rate is double the
///   chat model's.
/// - Cache misses are billed at the ordinary input rate.
/// - There is no separate write charge, so `cache_write` is 0.0. The previous
///   flat 1.0 exceeded the input price itself.
fn add_deepseek_models(map: &mut HashMap<String, Model>) {
    #[rustfmt::skip]
    let models = [
        ("deepseek/deepseek-chat",     "DeepSeek Chat",     false, 0.27, 1.1,  0.07),
        ("deepseek/deepseek-chat-v3",  "DeepSeek Chat V3",  false, 0.27, 1.1,  0.07),
        ("deepseek/deepseek-reasoner", "DeepSeek Reasoner", true,  0.55, 2.19, 0.14),
        ("deepseek/deepseek-coder",    "DeepSeek Coder",    false, 0.27, 1.1,  0.07),
    ];
    for (id, name, reasoning, input_cost, output_cost, cache_hit_cost) in models {
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::OpenAiCompletions,
                provider: "deepseek".to_string(),
                base_url: "https://api.deepseek.com".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: cache_hit_cost,
                    cache_write: 0.0,
                },
                context_window: 64_000,
                max_tokens: 8192,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers Mistral AI's models, reached through Mistral's OpenAI-compatible API.
///
/// Each row is `(registry key, display name, reasoning, input cost, output cost)`.
/// All entries are text-only, with a 128 000-token context window, a
/// 32 000-token output cap and no cache pricing.
///
/// The base URL includes the `/v1` version segment, like the other
/// OpenAI-compatible providers in this file (OpenAI, Groq, xAI, OpenRouter).
/// Mistral serves chat completions at `https://api.mistral.ai/v1/chat/completions`,
/// so a base URL without `/v1` points at a non-existent route.
fn add_mistral_models(map: &mut HashMap<String, Model>) {
    #[rustfmt::skip]
    let models = [
        ("mistral/mistral-large-latest", "Mistral Large", false, 2.0, 6.0),
        ("mistral/mistral-medium-latest", "Mistral Medium", false, 0.5, 1.5),
        ("mistral/mistral-small-latest", "Mistral Small", false, 0.2, 0.6),
        ("mistral/mistral-nemo", "Mistral Nemo", false, 0.15, 0.15),
        ("mistral/codestral", "Codestral", false, 0.3, 0.9),
        ("mistral/codestral-mamba", "Codestral Mamba", false, 0.25, 0.25),
        ("mistral/open-mixtral-8x22b", "Mixtral 8x22B", false, 0.45, 1.4),
        ("mistral/open-mixtral-8x7b", "Mixtral 8x7B", false, 0.24, 0.24),
    ];
    for (id, name, reasoning, input_cost, output_cost) in models {
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::OpenAiCompletions,
                provider: "mistral".to_string(),
                base_url: "https://api.mistral.ai/v1".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 32_000,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers models hosted on Groq, reached through Groq's OpenAI-compatible endpoint.
///
/// Each row is `(registry key, display name, reasoning, input cost,
/// output cost, context window)`. Costs are 0.0 throughout, presumably
/// tracking a free tier — TODO confirm.
///
/// The context window is per model, because the models have very different
/// native limits:
/// - The Llama 3.1/3.3 models offer 128K tokens.
/// - Llama 3 and Gemma are limited to 8 192.
/// - Mixtral 8x7B is limited to 32 768, as its Groq id `mixtral-8x7b-32768` states.
///
/// Advertising 128K for the smaller models would let callers build prompts the
/// model cannot accept.
fn add_groq_models(map: &mut HashMap<String, Model>) {
    #[rustfmt::skip]
    let models = [
        ("groq/llama-3.3-70b-versatile", "Llama 3.3 70B Versatile", false, 0.0, 0.0, 128_000),
        ("groq/llama-3.1-70b-versatile", "Llama 3.1 70B Versatile", false, 0.0, 0.0, 128_000),
        ("groq/llama-3.1-8b-instant", "Llama 3.1 8B Instant", false, 0.0, 0.0, 128_000),
        ("groq/llama-3-70b-versatile", "Llama 3 70B Versatile", false, 0.0, 0.0, 8_192),
        ("groq/llama-3-8b-versatile", "Llama 3 8B Versatile", false, 0.0, 0.0, 8_192),
        ("groq/mixtral-8x7b-32768", "Mixtral 8x7B", false, 0.0, 0.0, 32_768),
        ("groq/gemma2-9b-it", "Gemma 2 9B", false, 0.0, 0.0, 8_192),
        ("groq/gemma-7b-it", "Gemma 7B", false, 0.0, 0.0, 8_192),
    ];
    for (id, name, reasoning, input_cost, output_cost, ctx) in models {
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::OpenAiCompletions,
                provider: "groq".to_string(),
                base_url: "https://api.groq.com/openai/v1".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: ctx,
                max_tokens: 8192,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers models hosted on Cerebras, reached through its OpenAI-compatible API.
///
/// Each row is `(registry key, display name, reasoning, input cost, output cost)`.
/// All entries are text-only, with a 128 000-token context window, an
/// 8 192-token output cap and no cache pricing.
///
/// The base URL includes the `/v1` version segment, like the other
/// OpenAI-compatible providers in this file. Cerebras serves chat completions
/// at `https://api.cerebras.ai/v1/chat/completions`.
fn add_cerebras_models(map: &mut HashMap<String, Model>) {
    let models = [
        ("cerebras/llama-3.3-70b", "Llama 3.3 70B", false, 0.0, 0.0),
        ("cerebras/llama-3.1-8b", "Llama 3.1 8B", false, 0.0, 0.0),
        ("cerebras/qwen-2.5-32b", "Qwen 2.5 32B", false, 0.0, 0.0),
        ("cerebras/qwen-2.5-7b", "Qwen 2.5 7B", false, 0.0, 0.0),
    ];
    for (id, name, reasoning, input_cost, output_cost) in models {
        map.insert(
            id.to_string(),
            Model {
                id: extract_model_name(id).to_string(),
                name: name.to_string(),
                api: Api::OpenAiCompletions,
                provider: "cerebras".to_string(),
                base_url: "https://api.cerebras.ai/v1".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8192,
                headers: Default::default(),
                compat: None,
            },
        );
    }
}
/// Registers xAI's Grok models under the `xai` provider.
///
/// Rows are `(registry key, display name, reasoning, input cost, output cost)`.
/// Every entry shares a 131 072-token context window, an 8 192-token output
/// cap, text-only input and zero cache pricing.
fn add_xai_models(map: &mut HashMap<String, Model>) {
    let grok_models = [
        ("xai/grok-2", "Grok 2", false, 5.0, 15.0),
        ("xai/grok-2-mini", "Grok 2 Mini", false, 0.3, 0.5),
        ("xai/grok-1", "Grok 1", false, 5.0, 15.0),
        ("xai/grok-1.5", "Grok 1.5", false, 5.0, 15.0),
    ];
    let entries = grok_models
        .iter()
        .map(|&(key, label, reasoning, input_cost, output_cost)| {
            let pricing = Cost {
                input: input_cost,
                output: output_cost,
                cache_read: 0.0,
                cache_write: 0.0,
            };
            let model = Model {
                id: extract_model_name(key).to_string(),
                name: label.to_string(),
                api: Api::OpenAiCompletions,
                provider: "xai".to_string(),
                base_url: "https://api.x.ai/v1".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: pricing,
                context_window: 131_072,
                max_tokens: 8192,
                headers: Default::default(),
                compat: None,
            };
            (key.to_string(), model)
        });
    map.extend(entries);
}
/// Registers models reached through OpenRouter's gateway under the `openrouter` provider.
///
/// Rows are `(registry key, display name, reasoning, input cost, output cost)`.
/// The model id is whatever `extract_model_name` derives from the registry key.
/// Every entry shares:
/// - a 128 000-token context window and a 32 000-token output cap;
/// - text-only input and zero cache pricing;
/// - an `HTTP-Referer` / `X-Title` header pair, which OpenRouter reads to
///   identify the calling application.
fn add_openrouter_models(map: &mut HashMap<String, Model>) {
    let routed = [
        (
            "openrouter/anthropic/claude-3.5-sonnet",
            "Claude 3.5 Sonnet",
            false,
            3.0,
            15.0,
        ),
        (
            "openrouter/anthropic/claude-3-opus",
            "Claude 3 Opus",
            false,
            15.0,
            75.0,
        ),
        (
            "openrouter/google/gemini-pro-1.5",
            "Gemini Pro 1.5",
            false,
            1.25,
            5.0,
        ),
        (
            "openrouter/meta-llama/llama-3-70b",
            "Llama 3 70B",
            false,
            0.65,
            2.75,
        ),
        (
            "openrouter/meta-llama/llama-3-8b",
            "Llama 3 8B",
            false,
            0.2,
            0.2,
        ),
        (
            "openrouter/mistralai/mistral-large",
            "Mistral Large",
            false,
            2.0,
            6.0,
        ),
        (
            "openrouter/deepseek/deepseek-chat",
            "DeepSeek Chat",
            false,
            0.27,
            1.1,
        ),
        ("openrouter/qwen/qwen-2-72b", "Qwen 2 72B", false, 0.9, 0.9),
        (
            "openrouter/nousresearch/hermes-3-llama-3-70b",
            "Hermes 3 70B",
            false,
            0.5,
            1.5,
        ),
    ];
    for &(key, label, reasoning, input_cost, output_cost) in routed.iter() {
        let attribution = [("HTTP-Referer", "https://oxi-ai"), ("X-Title", "oxi-ai")]
            .iter()
            .map(|&(header, value)| (header.to_string(), value.to_string()))
            .collect();
        let model = Model {
            id: extract_model_name(key).to_string(),
            name: label.to_string(),
            api: Api::OpenAiCompletions,
            provider: "openrouter".to_string(),
            base_url: "https://openrouter.ai/api/v1".to_string(),
            reasoning,
            input: vec![InputModality::Text],
            cost: Cost {
                input: input_cost,
                output: output_cost,
                cache_read: 0.0,
                cache_write: 0.0,
            },
            context_window: 128_000,
            max_tokens: 32_000,
            headers: attribution,
            compat: None,
        };
        map.insert(key.to_string(), model);
    }
}
/// Registers Azure OpenAI deployments under the `azure-openai` provider, using
/// `Api::AzureOpenAiResponses`.
///
/// Rows are `(registry key, display name, reasoning, input cost, output cost)`.
/// Every entry accepts text and images, with a 128 000-token context window, a
/// 32 000-token output cap and zero cache pricing.
///
/// NOTE(review): `base_url` contains the literal `{your-resource-name}`, which
/// looks like a placeholder. Confirm where it is substituted.
///
/// The compat settings:
/// - select `MaxTokensField::MaxCompletionTokens`;
/// - require tool-result names;
/// - turn every other capability flag off.
fn add_azure_models(map: &mut HashMap<String, Model>) {
    let deployments = [
        ("azure-openai/gpt-4o", "GPT-4o", false, 2.5, 10.0),
        ("azure-openai/gpt-4o-mini", "GPT-4o Mini", false, 0.15, 0.60),
        ("azure-openai/gpt-4-turbo", "GPT-4 Turbo", false, 10.0, 30.0),
    ];
    for &(key, label, reasoning, input_cost, output_cost) in deployments.iter() {
        let settings = crate::CompatSettings {
            supports_store: false,
            supports_developer_role: false,
            supports_reasoning_effort: false,
            supports_usage_in_streaming: false,
            max_tokens_field: Some(crate::MaxTokensField::MaxCompletionTokens),
            requires_tool_result_name: true,
            requires_assistant_after_tool_result: false,
            requires_thinking_as_text: false,
            thinking_format: None,
        };
        let model = Model {
            id: extract_model_name(key).to_string(),
            name: label.to_string(),
            api: Api::AzureOpenAiResponses,
            provider: "azure-openai".to_string(),
            base_url: "https://{your-resource-name}.openai.azure.com".to_string(),
            reasoning,
            input: vec![InputModality::Text, InputModality::Image],
            cost: Cost {
                input: input_cost,
                output: output_cost,
                cache_read: 0.0,
                cache_write: 0.0,
            },
            context_window: 128_000,
            max_tokens: 32_000,
            headers: Default::default(),
            compat: Some(settings),
        };
        map.insert(key.to_string(), model);
    }
}
/// Read-only view of the built-in model catalogue.
///
/// None of these methods see models added at runtime with `register_model`;
/// use `lookup_model` when those should be included.
pub struct ModelRegistry;

impl ModelRegistry {
    /// Returns the built-in model registered under `"<provider>/<model_id>"`.
    pub fn get(provider: &str, model_id: &str) -> Option<&'static Model> {
        let key = format!("{}/{}", provider, model_id);
        MODELS.get(&key)
    }

    /// Returns every built-in model whose `provider` equals `provider`, ordered
    /// by model id. An unknown provider yields an empty list.
    pub fn get_by_provider(provider: &str) -> Vec<&'static Model> {
        let models = MODELS.values().filter(|m| m.provider == provider).collect();
        Self::sorted(models)
    }

    /// Returns every built-in model, ordered by provider and then model id.
    pub fn all() -> Vec<&'static Model> {
        Self::sorted(MODELS.values().collect())
    }

    /// Case-insensitive substring search over model ids and display names.
    ///
    /// Results are ordered by provider and then model id.
    pub fn search(pattern: &str) -> Vec<&'static Model> {
        let pattern_lower = pattern.to_lowercase();
        let matches = MODELS
            .values()
            .filter(|m| {
                m.id.to_lowercase().contains(&pattern_lower)
                    || m.name.to_lowercase().contains(&pattern_lower)
            })
            .collect();
        Self::sorted(matches)
    }

    /// Sorts by `(provider, id)` so listings do not depend on `HashMap`
    /// iteration order, which std randomises per process.
    fn sorted(mut models: Vec<&'static Model>) -> Vec<&'static Model> {
        models.sort_by(|a, b| a.provider.cmp(&b.provider).then_with(|| a.id.cmp(&b.id)));
        models
    }
}
/// Models registered at runtime, keyed `"<provider>/<model-id>"`.
///
/// Written by `register_model` and `unregister_model`. `lookup_model` reads it
/// before falling back to the built-in table; it starts out empty.
static DYNAMIC_MODELS: Lazy<RwLock<HashMap<String, Model>>> = Lazy::new(Default::default);
/// Adds `model` to the runtime registry under `"<provider>/<id>"`, replacing
/// any earlier registration with the same key.
///
/// Runtime entries are consulted first by `lookup_model`. `ModelRegistry` and
/// `get_model` do not see them, because they read only the built-in table.
pub fn register_model(model: Model) {
    let registry_key = [model.provider.as_str(), model.id.as_str()].join("/");
    let mut table = DYNAMIC_MODELS.write();
    table.insert(registry_key, model);
}
/// Removes the runtime registration for `"<provider>/<model_id>"`, if any.
///
/// Only the runtime table is touched, so built-in models cannot be removed.
/// A key that was never registered is ignored.
pub fn unregister_model(provider: &str, model_id: &str) {
    let registry_key = [provider, model_id].join("/");
    let mut table = DYNAMIC_MODELS.write();
    table.remove(&registry_key);
}
/// Finds a model by provider and id, checking runtime registrations first and
/// then the built-in catalogue.
///
/// Returns an owned clone, so the caller holds no lock once this returns.
pub fn lookup_model(provider: &str, model_id: &str) -> Option<Model> {
    let key = format!("{}/{}", provider, model_id);
    let runtime_hit = DYNAMIC_MODELS.read().get(&key).cloned();
    runtime_hit.or_else(|| MODELS.get(&key).cloned())
}
pub fn get_model(provider: &str, model_id: &str) -> Option<&'static Model> {
ModelRegistry::get(provider, model_id)
}
pub fn get_providers() -> Vec<&'static str> {
let mut providers: Vec<&'static str> = MODELS.values().map(|m| m.provider.as_str()).collect();
providers.sort();
providers.dedup();
providers
}
/// Returns the built-in models offered by `provider`.
///
/// Free-function form of `ModelRegistry::get_by_provider`. Models added with
/// `register_model` are not included.
pub fn get_models(provider: &str) -> Vec<&'static Model> {
    ModelRegistry::get_by_provider(provider)
}
/// Registers Z.ai's GLM models under the `zai` provider, using the coding
/// endpoint's base URL.
///
/// Rows are `(registry key, display name, reasoning, input cost, output cost)`.
/// Every entry is flagged as reasoning and accepts text only. Each has a
/// 200 000-token context window, a 131 072-token output cap and zero pricing
/// (presumably a flat-rate plan — TODO confirm).
fn add_zai_models(map: &mut HashMap<String, Model>) {
    let glm_models = [
        ("zai/glm-4.7", "GLM-4.7", true, 0.0, 0.0),
        ("zai/glm-5-turbo", "GLM-5-Turbo", true, 0.0, 0.0),
        ("zai/glm-5.1", "GLM-5.1", true, 0.0, 0.0),
        ("zai/glm-5v-turbo", "GLM-5V-Turbo", true, 0.0, 0.0),
        ("zai/glm-4.5-air", "GLM-4.5-Air", true, 0.0, 0.0),
    ];
    let entries = glm_models
        .iter()
        .map(|&(key, label, reasoning, input_cost, output_cost)| {
            let model = Model {
                id: extract_model_name(key).to_string(),
                name: label.to_string(),
                api: Api::OpenAiCompletions,
                provider: "zai".to_string(),
                base_url: "https://api.z.ai/api/coding/paas/v4".to_string(),
                reasoning,
                input: vec![InputModality::Text],
                cost: Cost {
                    input: input_cost,
                    output: output_cost,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 200_000,
                max_tokens: 131_072,
                headers: Default::default(),
                compat: None,
            };
            (key.to_string(), model)
        });
    map.extend(entries);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_model() {
        let model = get_model("openai", "gpt-4o").expect("gpt-4o is a built-in model");
        assert_eq!(model.provider, "openai");
    }

    #[test]
    fn test_get_providers() {
        let providers = get_providers();
        for expected in [
            "openai",
            "anthropic",
            "google",
            "deepseek",
            "mistral",
            "groq",
            "cerebras",
            "xai",
            "openrouter",
            "azure-openai",
            "zai",
        ] {
            assert!(providers.contains(&expected), "missing provider {}", expected);
        }
        // get_providers sorts and de-duplicates, so neighbours are strictly increasing.
        assert!(providers.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_deepseek_model() {
        let model =
            get_model("deepseek", "deepseek-chat").expect("deepseek-chat is a built-in model");
        assert_eq!(model.provider, "deepseek");
        assert_eq!(model.base_url, "https://api.deepseek.com");
    }

    #[test]
    fn test_search_models() {
        let results = ModelRegistry::search("gpt");
        assert!(!results.is_empty());
        assert!(results
            .iter()
            .all(|m| m.name.to_lowercase().contains("gpt")));
    }

    #[test]
    fn test_search_is_case_insensitive() {
        let lower = ModelRegistry::search("gpt");
        let upper = ModelRegistry::search("GPT");
        assert_eq!(lower.len(), upper.len());
    }

    #[test]
    fn test_get_models_filters_by_provider() {
        let models = get_models("anthropic");
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.provider == "anthropic"));
        assert!(get_models("no-such-provider").is_empty());
    }

    #[test]
    fn test_lookup_falls_back_to_builtin() {
        let model = lookup_model("openai", "gpt-4o").expect("built-in model found via lookup");
        assert_eq!(model.id, "gpt-4o");
    }

    #[test]
    fn test_dynamic_registration_round_trip() {
        // Unique provider name so tests sharing DYNAMIC_MODELS in parallel cannot collide.
        const PROVIDER: &str = "test-dynamic-registration";
        let mut model = get_model("openai", "gpt-4o")
            .expect("gpt-4o is a built-in model")
            .clone();
        model.provider = PROVIDER.to_string();
        let id = model.id.clone();

        assert!(lookup_model(PROVIDER, &id).is_none());
        register_model(model);
        let found = lookup_model(PROVIDER, &id).expect("registered model is visible");
        assert_eq!(found.provider, PROVIDER);
        // Runtime registrations are invisible to the static registry.
        assert!(get_model(PROVIDER, &id).is_none());

        unregister_model(PROVIDER, &id);
        assert!(lookup_model(PROVIDER, &id).is_none());
    }
}