use std::collections::HashMap;
/// Default model identifiers for a single provider.
///
/// Every provider has a default chat model. The embedding, image and rerank
/// defaults are optional: [`DefaultModelConfig::new`] leaves them as `None`,
/// and the `with_*` builder methods fill them in.
///
/// All fields are `&'static str` (or `Option` of one), so the type is cheap to
/// copy and can be built in `const` contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DefaultModelConfig {
    /// Default chat model identifier.
    pub chat_model: &'static str,
    /// Default embedding model identifier, if the provider has one.
    pub embedding_model: Option<&'static str>,
    /// Default image model identifier, if the provider has one.
    pub image_model: Option<&'static str>,
    /// Default rerank model identifier, if the provider has one.
    pub rerank_model: Option<&'static str>,
}

impl DefaultModelConfig {
    /// Creates a config with `chat_model` as the chat default and no
    /// embedding, image or rerank defaults.
    #[must_use]
    pub const fn new(chat_model: &'static str) -> Self {
        Self {
            chat_model,
            embedding_model: None,
            image_model: None,
            rerank_model: None,
        }
    }

    /// Returns this config with the default embedding model set to `model`.
    ///
    /// The builder takes `self` by value; the returned value must be used.
    #[must_use]
    pub const fn with_embedding(mut self, model: &'static str) -> Self {
        self.embedding_model = Some(model);
        self
    }

    /// Returns this config with the default image model set to `model`.
    #[must_use]
    pub const fn with_image(mut self, model: &'static str) -> Self {
        self.image_model = Some(model);
        self
    }

    /// Returns this config with the default rerank model set to `model`.
    #[must_use]
    pub const fn with_rerank(mut self, model: &'static str) -> Self {
        self.rerank_model = Some(model);
        self
    }
}
/// Lookup table from provider identifier (e.g. `"siliconflow"`, `"deepseek"`)
/// to that provider's [`DefaultModelConfig`].
///
/// Lookups are exact string matches on the provider identifier.
#[derive(Debug, Clone)]
pub struct DefaultModelRegistry {
    configs: HashMap<&'static str, DefaultModelConfig>,
}

impl DefaultModelRegistry {
    /// Builds a registry containing the built-in defaults for the
    /// `siliconflow`, `deepseek`, `openrouter`, `groq` and `xai` providers.
    pub fn new() -> Self {
        // NOTE(review): the deepseek/openrouter embedding ids and the groq/xai
        // chat ids are inline literals rather than constants from
        // `providers::models`; consider moving them there for consistency.
        let configs = HashMap::from([
            (
                "siliconflow",
                DefaultModelConfig::new(super::providers::models::siliconflow::DEEPSEEK_V3_1)
                    .with_embedding(super::providers::models::siliconflow::BGE_LARGE_ZH_V1_5)
                    .with_image(super::providers::models::siliconflow::STABLE_DIFFUSION_3_5_LARGE)
                    .with_rerank(super::providers::models::siliconflow::BGE_RERANKER_V2_M3),
            ),
            (
                "deepseek",
                DefaultModelConfig::new(super::providers::models::deepseek::CHAT)
                    .with_embedding("deepseek-embedding"),
            ),
            (
                "openrouter",
                DefaultModelConfig::new(super::providers::models::openrouter::openai::GPT_4O)
                    .with_embedding("text-embedding-3-small"),
            ),
            ("groq", DefaultModelConfig::new("llama-3.3-70b-versatile")),
            ("xai", DefaultModelConfig::new("grok-2-1212")),
        ]);
        Self { configs }
    }

    /// Returns the full default config for `provider_id`, or `None` if the
    /// provider is not registered.
    pub fn get_config(&self, provider_id: &str) -> Option<&DefaultModelConfig> {
        self.configs.get(provider_id)
    }

    /// Returns the default chat model for `provider_id`, or `None` if the
    /// provider is not registered.
    pub fn get_default_chat_model(&self, provider_id: &str) -> Option<&'static str> {
        self.get_config(provider_id).map(|config| config.chat_model)
    }

    /// Returns the default embedding model for `provider_id`.
    ///
    /// `None` means either the provider is not registered or it has no
    /// embedding default.
    pub fn get_default_embedding_model(&self, provider_id: &str) -> Option<&'static str> {
        self.get_config(provider_id)
            .and_then(|config| config.embedding_model)
    }

    /// Returns the default image model for `provider_id`.
    ///
    /// `None` means either the provider is not registered or it has no
    /// image default.
    pub fn get_default_image_model(&self, provider_id: &str) -> Option<&'static str> {
        self.get_config(provider_id)
            .and_then(|config| config.image_model)
    }

    /// Returns the default rerank model for `provider_id`.
    ///
    /// `None` means either the provider is not registered or it has no
    /// rerank default.
    pub fn get_default_rerank_model(&self, provider_id: &str) -> Option<&'static str> {
        self.get_config(provider_id)
            .and_then(|config| config.rerank_model)
    }

    /// Returns the identifiers of every registered provider, sorted
    /// lexicographically.
    pub fn get_supported_providers(&self) -> Vec<&'static str> {
        let mut providers: Vec<&'static str> = self.configs.keys().copied().collect();
        // `HashMap` iteration order is arbitrary and differs between runs;
        // sort so callers (UIs, logs, snapshots) see a stable order.
        // Keys are unique, so an unstable sort yields the same result.
        providers.sort_unstable();
        providers
    }
}
impl Default for DefaultModelRegistry {
fn default() -> Self {
Self::new()
}
}
/// Process-wide registry instance, constructed by
/// [`DefaultModelRegistry::new`] the first time it is accessed.
static DEFAULT_MODEL_REGISTRY: std::sync::LazyLock<DefaultModelRegistry> =
    std::sync::LazyLock::new(DefaultModelRegistry::new);

/// Returns the shared global [`DefaultModelRegistry`].
///
/// The first call triggers construction of the registry; every call returns
/// a reference to the same instance.
pub fn get_registry() -> &'static DefaultModelRegistry {
    // `force` is what `Deref` does implicitly; spelled out here for clarity.
    std::sync::LazyLock::force(&DEFAULT_MODEL_REGISTRY)
}
/// Looks up the default chat model for `provider_id` in the global registry.
///
/// Returns `None` when the provider is not registered.
pub fn get_default_chat_model(provider_id: &str) -> Option<&'static str> {
    let registry = get_registry();
    registry.get_default_chat_model(provider_id)
}

/// Looks up the default embedding model for `provider_id` in the global
/// registry.
///
/// Returns `None` when the provider is unknown or has no embedding default.
pub fn get_default_embedding_model(provider_id: &str) -> Option<&'static str> {
    let registry = get_registry();
    registry.get_default_embedding_model(provider_id)
}

/// Looks up the default image model for `provider_id` in the global registry.
///
/// Returns `None` when the provider is unknown or has no image default.
pub fn get_default_image_model(provider_id: &str) -> Option<&'static str> {
    let registry = get_registry();
    registry.get_default_image_model(provider_id)
}

/// Looks up the default rerank model for `provider_id` in the global registry.
///
/// Returns `None` when the provider is unknown or has no rerank default.
pub fn get_default_rerank_model(provider_id: &str) -> Option<&'static str> {
    let registry = get_registry();
    registry.get_default_rerank_model(provider_id)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_model_registry() {
        let registry = DefaultModelRegistry::new();

        // siliconflow: every model kind has a default.
        assert_eq!(
            registry.get_default_chat_model("siliconflow"),
            Some("deepseek-ai/DeepSeek-V3.1")
        );
        assert_eq!(
            registry.get_default_embedding_model("siliconflow"),
            Some("BAAI/bge-large-zh-v1.5")
        );
        assert_eq!(
            registry.get_default_image_model("siliconflow"),
            Some("stabilityai/stable-diffusion-3-5-large")
        );
        assert!(registry.get_default_rerank_model("siliconflow").is_some());

        // deepseek: chat + embedding only.
        assert_eq!(
            registry.get_default_chat_model("deepseek"),
            Some("deepseek-chat")
        );
        assert_eq!(
            registry.get_default_embedding_model("deepseek"),
            Some("deepseek-embedding")
        );
        assert_eq!(registry.get_default_image_model("deepseek"), None);
        assert_eq!(registry.get_default_rerank_model("deepseek"), None);

        // openrouter: chat + embedding only.
        assert_eq!(
            registry.get_default_chat_model("openrouter"),
            Some("openai/gpt-4o")
        );
        assert_eq!(
            registry.get_default_embedding_model("openrouter"),
            Some("text-embedding-3-small")
        );

        // groq / xai: chat only.
        assert_eq!(
            registry.get_default_chat_model("groq"),
            Some("llama-3.3-70b-versatile")
        );
        assert_eq!(registry.get_default_embedding_model("groq"), None);
        assert_eq!(registry.get_default_chat_model("xai"), Some("grok-2-1212"));
        assert_eq!(registry.get_default_image_model("xai"), None);

        // Unknown providers yield `None` for every lookup.
        assert_eq!(registry.get_default_chat_model("unknown"), None);
        assert_eq!(registry.get_default_embedding_model("unknown"), None);
        assert_eq!(registry.get_default_rerank_model("unknown"), None);
        assert!(registry.get_config("unknown").is_none());

        // Lookup is an exact (case-sensitive) match on the provider id.
        assert_eq!(registry.get_default_chat_model("DeepSeek"), None);
    }

    #[test]
    fn test_global_registry() {
        assert_eq!(
            get_default_chat_model("siliconflow"),
            Some("deepseek-ai/DeepSeek-V3.1")
        );
        assert_eq!(get_default_chat_model("deepseek"), Some("deepseek-chat"));
        assert_eq!(get_default_chat_model("openrouter"), Some("openai/gpt-4o"));
        assert_eq!(get_default_chat_model("unknown"), None);

        assert_eq!(
            get_default_embedding_model("deepseek"),
            Some("deepseek-embedding")
        );
        assert_eq!(
            get_default_image_model("siliconflow"),
            Some("stabilityai/stable-diffusion-3-5-large")
        );
        assert_eq!(get_default_rerank_model("groq"), None);
        assert_eq!(get_default_image_model("unknown"), None);
    }

    #[test]
    fn test_supported_providers() {
        let registry = DefaultModelRegistry::new();
        let providers = registry.get_supported_providers();
        assert_eq!(providers.len(), 5);
        assert!(providers.contains(&"siliconflow"));
        assert!(providers.contains(&"deepseek"));
        assert!(providers.contains(&"openrouter"));
        assert!(providers.contains(&"groq"));
        assert!(providers.contains(&"xai"));
        assert!(!providers.contains(&"unknown"));
    }

    #[test]
    fn test_model_config_builder() {
        let config = DefaultModelConfig::new("test-chat")
            .with_embedding("test-embedding")
            .with_image("test-image")
            .with_rerank("test-rerank");
        assert_eq!(config.chat_model, "test-chat");
        assert_eq!(config.embedding_model, Some("test-embedding"));
        assert_eq!(config.image_model, Some("test-image"));
        assert_eq!(config.rerank_model, Some("test-rerank"));

        let bare = DefaultModelConfig::new("only-chat");
        assert_eq!(bare.chat_model, "only-chat");
        assert_eq!(bare.embedding_model, None);
        assert_eq!(bare.image_model, None);
        assert_eq!(bare.rerank_model, None);
    }
}