use crate::error::ConversionError;
use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
use std::collections::HashMap;
use std::sync::OnceLock;
/// Signature shared by every provider constructor stored in the registry.
type ProviderFactory = fn() -> Box<dyn Provider + Send + Sync>;

/// Returns the process-wide table that maps canonical provider names to their factories.
///
/// The table is built lazily on first access and the same instance is returned
/// on every subsequent call.
fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
    static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
    REGISTRY.get_or_init(|| {
        // The explicit element type coerces each fn item to the `ProviderFactory` pointer.
        let entries: [(&'static str, ProviderFactory); 4] = [
            ("glm", glm_factory),
            ("kimi", kimi_factory),
            ("deepseek", deepseek_factory),
            ("minimax", minimax_factory),
        ];
        HashMap::from(entries)
    })
}
/// Returns the canonical names of all registered providers, sorted alphabetically.
///
/// The backing `HashMap` iterates in an unspecified (per-process randomized) order.
/// Sorting makes this list deterministic, which keeps error messages and logs that
/// embed it stable between runs. Aliases accepted at lookup time are not part of
/// the registry and so are not listed.
pub fn registered_provider_names() -> Vec<&'static str> {
    let mut names: Vec<&'static str> = get_registry().keys().copied().collect();
    // Keys are unique, so an unstable sort gives the same result as a stable one.
    names.sort_unstable();
    names
}
/// Registry factory: constructs a fresh `GLMProvider` behind a trait object.
fn glm_factory() -> Box<dyn Provider + Send + Sync> {
    let provider = super::glm::GLMProvider::new();
    Box::new(provider)
}
/// Registry factory: constructs a fresh `KimiProvider` behind a trait object.
fn kimi_factory() -> Box<dyn Provider + Send + Sync> {
    let provider = super::kimi::KimiProvider::new();
    Box::new(provider)
}
/// Registry factory: constructs a fresh `DeepSeekProvider` behind a trait object.
fn deepseek_factory() -> Box<dyn Provider + Send + Sync> {
    let provider = super::deepseek::DeepSeekProvider::new();
    Box::new(provider)
}
/// Registry factory: constructs a fresh `MiniMaxProvider` behind a trait object.
fn minimax_factory() -> Box<dyn Provider + Send + Sync> {
    let provider = super::minimax::MiniMaxProvider::new();
    Box::new(provider)
}
/// Per-vendor adapter for chat-completion traffic.
///
/// Only [`Provider::name`], [`Provider::clone_box`] and [`Provider::as_any`] must be
/// implemented. Every other hook has a default that leaves its input untouched, so an
/// implementor overrides only what it needs to change.
pub trait Provider: Send + Sync + 'static {
    /// Identifier for this provider.
    fn name(&self) -> &'static str;

    /// Hook for rewriting a model name; the default returns `model` unchanged.
    fn normalize_model(&self, model: String) -> String {
        model
    }

    /// Path of the chat-completions endpoint; the default is `/v1/chat/completions`.
    fn chat_completions_path(&self) -> String {
        String::from("/v1/chat/completions")
    }

    /// Hook for mutating a request in place; the default does nothing.
    fn transform_request(&self, _request: &mut ChatRequest) {}

    /// Hook for mutating a complete response in place; the default does nothing.
    fn transform_response(&self, _response: &mut ChatResponse) {}

    /// Hook for mutating a single streamed chunk in place; the default does nothing.
    fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}

    /// Produces an independent boxed copy of this provider; this is what makes
    /// `Box<dyn Provider + Send + Sync>` cloneable.
    fn clone_box(&self) -> Box<dyn Provider + Send + Sync>;

    /// Returns `self` as [`std::any::Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;
}
impl Clone for Box<dyn Provider + Send + Sync> {
fn clone(&self) -> Self {
self.as_ref().clone_box()
}
}
/// Instantiates the provider registered under `name`.
///
/// Matching is case-insensitive, and `"moonshot"` is accepted as an alias for `"kimi"`.
///
/// # Errors
///
/// Returns [`ConversionError::ProviderError`] when no registered provider matches.
/// The message echoes `name` exactly as given and lists the registered provider names.
pub fn create_provider(name: &str) -> Result<Box<dyn Provider + Send + Sync>, ConversionError> {
    let lowered = name.to_lowercase();
    // Translate the single known alias onto its canonical registry key.
    let key = if lowered == "moonshot" {
        "kimi"
    } else {
        lowered.as_str()
    };
    match get_registry().get(key) {
        Some(make) => Ok(make()),
        None => Err(ConversionError::ProviderError(format!(
            "Unknown provider: {}. Available: {:?}",
            name,
            registered_provider_names()
        ))),
    }
}