Skip to main content

codex_convert_proxy/providers/
trait_.rs

//! Provider trait definition, plus the static registry and factory used to construct providers by name.
2
3use crate::error::ConversionError;
4use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7
8// ============================================================================
9// Provider Factory Registry
10// ============================================================================
11
12/// Factory function type for creating providers (type-erased function pointer).
13type ProviderFactory = fn() -> Arc<dyn Provider>;
14
15/// Static registry of provider factories.
16fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
17    static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
18    REGISTRY.get_or_init(|| {
19        let mut m = HashMap::new();
20        m.insert("glm", glm_factory as ProviderFactory);
21        m.insert("kimi", kimi_factory as ProviderFactory);
22        m.insert("deepseek", deepseek_factory as ProviderFactory);
23        m.insert("minimax", minimax_factory as ProviderFactory);
24        m
25    })
26}
27
28/// Get all registered provider names.
29pub fn registered_provider_names() -> Vec<&'static str> {
30    get_registry().keys().copied().collect()
31}
32
33// Factory functions (must be in separate functions to get unique addresses)
34fn glm_factory() -> Arc<dyn Provider> {
35    Arc::new(super::glm::GLMProvider::new())
36}
37fn kimi_factory() -> Arc<dyn Provider> {
38    Arc::new(super::kimi::KimiProvider::new())
39}
40fn deepseek_factory() -> Arc<dyn Provider> {
41    Arc::new(super::deepseek::DeepSeekProvider::new())
42}
43fn minimax_factory() -> Arc<dyn Provider> {
44    Arc::new(super::minimax::MiniMaxProvider::new())
45}
46
47// ============================================================================
48// Provider Trait
49// ============================================================================
50
51/// Provider trait for LLM provider-specific transformations.
52///
53/// Each Chinese LLM provider may have slightly different API requirements
54/// or model name formats that need to be normalized.
55///
56/// Implementations are expected to be **stateless** so a single instance can
57/// be shared across all requests via `Arc<dyn Provider>`.
58pub trait Provider: Send + Sync + 'static {
59    /// Get provider name.
60    fn name(&self) -> &'static str;
61
62    /// Normalize model name from Responses API to provider's format.
63    fn normalize_model(&self, model: String) -> String {
64        model
65    }
66
67    /// Get the chat completions path for this provider.
68    /// Only returns the endpoint path, e.g., "/chat/completions".
69    /// The version prefix (e.g., "/v1") should come from the backend URL's base_path.
70    fn chat_completions_path(&self) -> String {
71        "/v1/chat/completions".to_string()
72    }
73
74    /// Transform request before sending to provider.
75    ///
76    /// This is called after the standard conversion but before sending
77    /// to the upstream provider. Providers can modify the request to
78    /// handle API differences.
79    fn transform_request(&self, _request: &mut ChatRequest) {}
80
81    /// Transform response after receiving from provider.
82    ///
83    /// This is called after receiving the response but before converting
84    /// to Responses API format. Providers can normalize response format.
85    fn transform_response(&self, _response: &mut ChatResponse) {}
86
87    /// Transform streaming chunk in real-time.
88    ///
89    /// This is called for each SSE chunk received from the provider.
90    /// Providers can modify chunk content before event conversion.
91    fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
92}
93
94// ============================================================================
95// Factory Function
96// ============================================================================
97
98/// Create a provider by name using the static registry.
99///
100/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
101pub fn create_provider(name: &str) -> Result<Arc<dyn Provider>, ConversionError> {
102    let name_lower = name.to_lowercase();
103
104    // Handle aliases
105    let normalized_name = match name_lower.as_str() {
106        "moonshot" => "kimi",
107        other => other,
108    };
109
110    // Try to get from registry
111    if let Some(factory) = get_registry().get(normalized_name) {
112        return Ok(factory());
113    }
114
115    // Return error with available provider names
116    let available = registered_provider_names();
117    Err(ConversionError::ProviderError(format!(
118        "Unknown provider: {}. Available: {:?}",
119        name, available
120    )))
121}