use tracing::info;
use crate::error::ConversionError;
use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
/// Signature of every registry entry: a zero-argument constructor that
/// returns a provider as a shared trait object.
type ProviderFactory = fn() -> Arc<dyn Provider>;

/// Returns the process-wide table mapping canonical provider names to the
/// factory that builds them.
///
/// The table is built on the first call (via `OnceLock::get_or_init`) and the
/// same instance is returned on every later call.
fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
    static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();

    REGISTRY.get_or_init(|| {
        // Each fn item is cast explicitly so all entries share the
        // `ProviderFactory` fn-pointer type.
        HashMap::from([
            ("glm", glm_factory as ProviderFactory),
            ("kimi", kimi_factory as ProviderFactory),
            ("deepseek", deepseek_factory as ProviderFactory),
            ("minimax", minimax_factory as ProviderFactory),
        ])
    })
}
/// Returns the canonical names of every provider that has a dedicated
/// implementation, sorted alphabetically.
///
/// Aliases accepted by [`create_provider`] (such as `"moonshot"`) and the
/// generic fallback used for unknown names are not part of the registry and
/// therefore are not listed.
///
/// The registry is a `HashMap`, whose iteration order is unspecified and
/// differs between runs. Sorting makes the result deterministic for logs,
/// help text and comparisons.
pub fn registered_provider_names() -> Vec<&'static str> {
    let mut names: Vec<&'static str> = get_registry().keys().copied().collect();
    // Keys are unique, so an unstable sort yields the same order as a stable one.
    names.sort_unstable();
    names
}
/// Constructs a new `GLMProvider` and returns it as a shared trait object.
fn glm_factory() -> Arc<dyn Provider> {
    let provider = super::glm::GLMProvider::new();
    Arc::new(provider)
}
/// Constructs a new `KimiProvider` and returns it as a shared trait object.
fn kimi_factory() -> Arc<dyn Provider> {
    let provider = super::kimi::KimiProvider::new();
    Arc::new(provider)
}
/// Constructs a new `DeepSeekProvider` and returns it as a shared trait object.
fn deepseek_factory() -> Arc<dyn Provider> {
    let provider = super::deepseek::DeepSeekProvider::new();
    Arc::new(provider)
}
/// Constructs a new `MiniMaxProvider` and returns it as a shared trait object.
fn minimax_factory() -> Arc<dyn Provider> {
    let provider = super::minimax::MiniMaxProvider::new();
    Arc::new(provider)
}
/// Backend-specific behaviour plugged into chat request/response handling.
///
/// Implementations must be `Send + Sync + 'static` so they can be stored and
/// shared as `Arc<dyn Provider>`. Every method except [`Provider::name`] has
/// a default, so a backend only overrides what it needs.
pub trait Provider: Send + Sync + 'static {
    /// Canonical identifier of this implementation (for example `"glm"`).
    fn name(&self) -> &'static str;

    /// Name to present for this backend; unless overridden it is the same as
    /// [`Provider::name`].
    fn display_name(&self) -> &str {
        Self::name(self)
    }

    /// Maps a requested model name to the one this backend expects.
    /// The default returns the input unchanged.
    fn normalize_model(&self, model: String) -> String {
        model
    }

    /// Path of the chat-completions endpoint for this backend.
    /// The default is `"/chat/completions"`.
    fn chat_completions_path(&self) -> String {
        String::from("/chat/completions")
    }

    /// Hook for adjusting a request in place; the default leaves it untouched.
    fn transform_request(&self, _request: &mut ChatRequest) {}

    /// Hook for adjusting a complete response in place; the default leaves it
    /// untouched.
    fn transform_response(&self, _response: &mut ChatResponse) {}

    /// Hook for adjusting a single streamed chunk in place; the default leaves
    /// it untouched.
    fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
}
/// Builds the provider implementation for a backend name.
///
/// Lookup is case-insensitive, and `"moonshot"` is accepted as an alias for
/// `"kimi"`. A name with no registered implementation is logged and falls
/// back to the OpenAI-compatible `DefaultProvider`, which receives `name`
/// exactly as the caller passed it (original casing included).
///
/// # Errors
///
/// This function currently always returns `Ok`. Every name resolves either to
/// a registered provider or to the default fallback.
pub fn create_provider(name: &str) -> Result<Arc<dyn Provider>, ConversionError> {
    let lowered = name.to_lowercase();
    // Resolve known aliases to their canonical registry key.
    let key: &str = if lowered == "moonshot" {
        "kimi"
    } else {
        lowered.as_str()
    };

    let Some(factory) = get_registry().get(key) else {
        info!(
            "[PROVIDER] Unknown provider '{}', falling back to DefaultProvider (OpenAI compatible)",
            name
        );
        return Ok(Arc::new(super::default::DefaultProvider::new(name)));
    };

    Ok(factory())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every registered provider resolves to itself under its canonical name.
    #[test]
    fn test_create_provider_known() {
        for expected in ["glm", "kimi", "deepseek", "minimax"] {
            let provider = create_provider(expected).unwrap();
            assert_eq!(provider.name(), expected, "name mismatch for {expected}");
            assert_eq!(
                provider.display_name(),
                expected,
                "display_name mismatch for {expected}"
            );
        }
    }

    /// Lookup lowercases its input, so any casing of a known name resolves
    /// to the same provider.
    #[test]
    fn test_create_provider_is_case_insensitive() {
        let cases = [
            ("GLM", "glm"),
            ("KIMI", "kimi"),
            ("DeepSeek", "deepseek"),
            ("MiniMax", "minimax"),
        ];
        for (input, expected) in cases {
            let provider = create_provider(input).unwrap();
            assert_eq!(provider.name(), expected, "name mismatch for {input}");
        }
    }

    /// Names without a dedicated implementation fall back to the default provider.
    #[test]
    fn test_create_provider_unknown_fallback_to_default() {
        for input in ["qwen", "some-unknown-provider", "abc"] {
            let provider = create_provider(input).unwrap();
            assert_eq!(provider.name(), "default", "unexpected provider for {input}");
        }
    }

    /// The default provider still reports which backend it stands in for.
    #[test]
    fn test_default_provider_display_name_preserves_backend_name() {
        let cases = [
            ("qwen", "qwen"),
            ("Yi-Lightning", "yi-lightning"),
            ("some-unknown-provider", "some-unknown-provider"),
        ];
        for (input, expected_display) in cases {
            let provider = create_provider(input).unwrap();
            assert_eq!(provider.name(), "default", "unexpected provider for {input}");
            assert_eq!(
                provider.display_name(),
                expected_display,
                "display_name mismatch for {input}"
            );
        }
    }

    /// "moonshot" is an alias for "kimi", and the alias is also case-insensitive.
    #[test]
    fn test_create_provider_alias() {
        for input in ["moonshot", "MoonShot", "MOONSHOT"] {
            let provider = create_provider(input).unwrap();
            assert_eq!(provider.name(), "kimi", "name mismatch for {input}");
            assert_eq!(provider.display_name(), "kimi", "display_name mismatch for {input}");
        }
    }

    /// Only canonical, dedicated providers are listed: no fallback, no aliases.
    #[test]
    fn test_registered_provider_names_excludes_default() {
        let names = registered_provider_names();
        assert!(
            !names.contains(&"default"),
            "default should not be in registered_provider_names"
        );
        assert!(
            !names.contains(&"moonshot"),
            "aliases should not be in registered_provider_names"
        );
        for expected in ["glm", "kimi", "deepseek", "minimax"] {
            assert!(names.contains(&expected), "missing registered provider {expected}");
        }
    }
}