Skip to main content

codex_convert_proxy/providers/
trait_.rs

1//! Provider trait definition.
2
3use crate::error::ConversionError;
4use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
5use std::collections::HashMap;
6use std::sync::OnceLock;
7
8// ============================================================================
9// Provider Factory Registry
10// ============================================================================
11
12/// Factory function type for creating providers (type-erased function pointer).
13type ProviderFactory = fn() -> Box<dyn Provider + Send + Sync>;
14
15/// Static registry of provider factories.
16fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
17    static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
18    REGISTRY.get_or_init(|| {
19        let mut m = HashMap::new();
20        m.insert("glm", glm_factory as ProviderFactory);
21        m.insert("kimi", kimi_factory as ProviderFactory);
22        m.insert("deepseek", deepseek_factory as ProviderFactory);
23        m.insert("minimax", minimax_factory as ProviderFactory);
24        m
25    })
26}
27
28/// Get all registered provider names.
29pub fn registered_provider_names() -> Vec<&'static str> {
30    get_registry().keys().copied().collect()
31}
32
33// Factory functions (must be in separate functions to get unique addresses)
34fn glm_factory() -> Box<dyn Provider + Send + Sync> {
35    Box::new(super::glm::GLMProvider::new())
36}
37fn kimi_factory() -> Box<dyn Provider + Send + Sync> {
38    Box::new(super::kimi::KimiProvider::new())
39}
40fn deepseek_factory() -> Box<dyn Provider + Send + Sync> {
41    Box::new(super::deepseek::DeepSeekProvider::new())
42}
43fn minimax_factory() -> Box<dyn Provider + Send + Sync> {
44    Box::new(super::minimax::MiniMaxProvider::new())
45}
46
47// ============================================================================
48// Provider Trait
49// ============================================================================
50
51/// Provider trait for LLM provider-specific transformations.
52///
53/// Each Chinese LLM provider may have slightly different API requirements
54/// or model name formats that need to be normalized.
55pub trait Provider: Send + Sync + 'static {
56    /// Get provider name.
57    fn name(&self) -> &'static str;
58
59    /// Normalize model name from Responses API to provider's format.
60    fn normalize_model(&self, model: String) -> String {
61        model
62    }
63
64    /// Get the chat completions path for this provider.
65    /// Only returns the endpoint path, e.g., "/chat/completions".
66    /// The version prefix (e.g., "/v1") should come from the backend URL's base_path.
67    fn chat_completions_path(&self) -> String {
68        "/v1/chat/completions".to_string()
69    }
70
71    /// Transform request before sending to provider.
72    ///
73    /// This is called after the standard conversion but before sending
74    /// to the upstream provider. Providers can modify the request to
75    /// handle API differences.
76    fn transform_request(&self, _request: &mut ChatRequest) {}
77
78    /// Transform response after receiving from provider.
79    ///
80    /// This is called after receiving the response but before converting
81    /// to Responses API format. Providers can normalize response format.
82    fn transform_response(&self, _response: &mut ChatResponse) {}
83
84    /// Transform streaming chunk in real-time.
85    ///
86    /// This is called for each SSE chunk received from the provider.
87    /// Providers can modify chunk content before event conversion.
88    fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
89
90    /// Clone the provider as a boxed trait object.
91    fn clone_box(&self) -> Box<dyn Provider + Send + Sync>;
92
93    /// Convert self to Any for downcasting.
94    fn as_any(&self) -> &dyn std::any::Any;
95}
96
97// ============================================================================
98// Clone Implementation
99// ============================================================================
100
101/// Clone for Box<dyn Provider> uses downcasting (Rust object safety limitation).
102impl Clone for Box<dyn Provider + Send + Sync> {
103    fn clone(&self) -> Self {
104        self.as_ref().clone_box()
105    }
106}
107
108// ============================================================================
109// Factory Function
110// ============================================================================
111
112/// Create a provider by name using the static registry.
113///
114/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
115pub fn create_provider(name: &str) -> Result<Box<dyn Provider + Send + Sync>, ConversionError> {
116    let name_lower = name.to_lowercase();
117
118    // Handle aliases
119    let normalized_name = match name_lower.as_str() {
120        "moonshot" => "kimi",
121        other => other,
122    };
123
124    // Try to get from registry
125    if let Some(factory) = get_registry().get(normalized_name) {
126        return Ok(factory());
127    }
128
129    // Return error with available provider names
130    let available = registered_provider_names();
131    Err(ConversionError::ProviderError(format!(
132        "Unknown provider: {}. Available: {:?}",
133        name, available
134    )))
135}