use crate::protocol::ArcProtocol;
use crate::types::Provider;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// Name-keyed collection of protocol implementations.
///
/// Every method takes `&self`; reads and writes go through an internal
/// read-write lock, so a single registry can be shared (this file stores one
/// in a `static`).
pub struct ProtocolRegistry {
    /// Protocol implementations keyed by provider name string.
    providers: RwLock<HashMap<String, ArcProtocol>>,
}

impl ProtocolRegistry {
    /// Creates a registry with no providers.
    pub fn new() -> Self {
        let empty = HashMap::new();
        Self {
            providers: RwLock::new(empty),
        }
    }

    /// Stores `provider` under the name reported by its own `provider_type()`.
    ///
    /// An implementation already stored under that name is replaced.
    pub fn register(&self, provider: ArcProtocol) {
        let key = provider.provider_type().as_str().to_string();
        self.providers.write().insert(key, provider);
    }

    /// Returns the implementation stored under `provider.as_str()`, if any.
    pub fn get(&self, provider: &Provider) -> Option<ArcProtocol> {
        self.get_by_name(provider.as_str())
    }

    /// Returns the implementation stored under the raw name `provider_name`, if any.
    pub fn get_by_name(&self, provider_name: &str) -> Option<ArcProtocol> {
        self.providers.read().get(provider_name).cloned()
    }

    /// Drops the implementation stored under `provider.as_str()`; no-op if absent.
    pub fn unregister(&self, provider: &Provider) {
        self.providers.write().remove(provider.as_str());
    }

    /// Drops every stored implementation.
    pub fn clear(&self) {
        self.providers.write().clear();
    }

    /// Returns the names of all stored providers (HashMap order, i.e. unspecified).
    pub fn provider_types(&self) -> Vec<String> {
        self.providers
            .read()
            .iter()
            .map(|(name, _)| name.clone())
            .collect()
    }

    /// Reports whether an implementation is stored under `provider.as_str()`.
    pub fn contains(&self, provider: &Provider) -> bool {
        self.providers.read().contains_key(provider.as_str())
    }
}

impl Default for ProtocolRegistry {
    /// Same as [`ProtocolRegistry::new`]: an empty registry.
    fn default() -> Self {
        ProtocolRegistry::new()
    }
}
/// Process-wide registry behind the free functions in this file.
///
/// Created empty on first access; `get_provider` adds built-in defaults on demand.
static GLOBAL_REGISTRY: once_cell::sync::Lazy<ProtocolRegistry> =
    once_cell::sync::Lazy::new(ProtocolRegistry::default);
pub fn global_registry() -> &'static ProtocolRegistry {
&GLOBAL_REGISTRY
}
pub fn register_provider(provider: ArcProtocol) {
GLOBAL_REGISTRY.register(provider);
}
/// Returns the protocol implementation for `provider`, creating and caching a
/// built-in default the first time one is needed.
///
/// Lookup order:
/// 1. Whatever the global registry already holds under `provider.as_str()`
///    (including implementations added via `register_provider`).
/// 2. A new built-in instance from `create_default_provider`, which is then
///    cached.
///
/// Returns `None` when nothing is registered and `provider` has no built-in
/// default.
///
/// The default is cached under the *requested* name (`provider.as_str()`), the
/// same key step 1 reads. It is not cached under the instance's own
/// `provider_type()`. `MiniMax` and `MiniMaxCN` are both served by
/// `MiniMaxProvider::new()`, so (assuming the two variants have distinct
/// names) keying by the instance's type would make one of them miss on every
/// call and overwrite the other's entry.
///
/// The insert uses the map's entry API under one write lock. If two threads
/// race to create the same default, or a provider is registered concurrently,
/// the first stored value is kept and every caller receives that same value.
pub fn get_provider(provider: &Provider) -> Option<ArcProtocol> {
    // Fast path: only a shared read lock.
    if let Some(existing) = GLOBAL_REGISTRY.get(provider) {
        return Some(existing);
    }

    // Build outside the lock so construction never blocks other readers.
    let created = create_default_provider(provider)?;

    // Re-check while holding the write lock: another thread may have stored an
    // entry since the read above. Keep theirs; `created` is then dropped.
    let mut providers = GLOBAL_REGISTRY.providers.write();
    let cached = providers
        .entry(provider.as_str().to_string())
        .or_insert(created);
    Some(cached.clone())
}
fn create_default_provider(provider: &Provider) -> Option<ArcProtocol> {
match provider {
Provider::OpenAI => Some(Arc::new(super::openai::OpenAIProvider::new())),
Provider::OpenAICompatible => Some(Arc::new(
super::openai_compatible::OpenAICompatibleProvider::new(),
)),
Provider::OpenAIResponses => Some(Arc::new(
super::openai_responses::OpenAIResponsesProvider::new(),
)),
Provider::Anthropic => Some(Arc::new(super::anthropic::AnthropicProvider::new())),
Provider::Google => Some(Arc::new(super::google::GoogleProvider::new())),
Provider::Ollama => Some(Arc::new(super::ollama::OllamaProvider::new())),
Provider::XAI => Some(Arc::new(super::xai::XAIProvider::new())),
Provider::Groq => Some(Arc::new(super::groq::GroqProvider::new())),
Provider::OpenRouter => Some(Arc::new(super::openrouter::OpenRouterProvider::new())),
Provider::MiniMax | Provider::MiniMaxCN => {
Some(Arc::new(super::minimax::MiniMaxProvider::new()))
}
Provider::KimiCoding => Some(Arc::new(super::kimi_coding::KimiCodingProvider::new())),
Provider::ZAI => Some(Arc::new(super::zai::ZAIProvider::new())),
Provider::DeepSeek => Some(Arc::new(super::deepseek::DeepSeekProvider::new())),
Provider::Zenmux => Some(Arc::new(super::zenmux::ZenmuxProvider::new())),
Provider::OpenCodeGo => Some(Arc::new(super::opencode_go::OpenCodeGoProvider::new())),
_ => None,
}
}
pub fn get_registered_providers() -> Vec<String> {
GLOBAL_REGISTRY.provider_types()
}
pub fn clear_providers() {
GLOBAL_REGISTRY.clear();
}