codex_convert_proxy/providers/trait_.rs
//! Provider trait definition and the name-based provider factory registry.
2
3use crate::error::ConversionError;
4use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
5use std::collections::HashMap;
6use std::sync::OnceLock;
7
8// ============================================================================
9// Provider Factory Registry
10// ============================================================================
11
12/// Factory function type for creating providers (type-erased function pointer).
13type ProviderFactory = fn() -> Box<dyn Provider + Send + Sync>;
14
15/// Static registry of provider factories.
16fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
17 static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
18 REGISTRY.get_or_init(|| {
19 let mut m = HashMap::new();
20 m.insert("glm", glm_factory as ProviderFactory);
21 m.insert("kimi", kimi_factory as ProviderFactory);
22 m.insert("deepseek", deepseek_factory as ProviderFactory);
23 m.insert("minimax", minimax_factory as ProviderFactory);
24 m
25 })
26}
27
28/// Get all registered provider names.
29pub fn registered_provider_names() -> Vec<&'static str> {
30 get_registry().keys().copied().collect()
31}
32
33// Factory functions (must be in separate functions to get unique addresses)
34fn glm_factory() -> Box<dyn Provider + Send + Sync> {
35 Box::new(super::glm::GLMProvider::new())
36}
37fn kimi_factory() -> Box<dyn Provider + Send + Sync> {
38 Box::new(super::kimi::KimiProvider::new())
39}
40fn deepseek_factory() -> Box<dyn Provider + Send + Sync> {
41 Box::new(super::deepseek::DeepSeekProvider::new())
42}
43fn minimax_factory() -> Box<dyn Provider + Send + Sync> {
44 Box::new(super::minimax::MiniMaxProvider::new())
45}
46
47// ============================================================================
48// Provider Trait
49// ============================================================================
50
51/// Provider trait for LLM provider-specific transformations.
52///
53/// Each Chinese LLM provider may have slightly different API requirements
54/// or model name formats that need to be normalized.
55pub trait Provider: Send + Sync + 'static {
56 /// Get provider name.
57 fn name(&self) -> &'static str;
58
59 /// Normalize model name from Responses API to provider's format.
60 fn normalize_model(&self, model: String) -> String {
61 model
62 }
63
64 /// Get the chat completions path for this provider.
65 /// Only returns the endpoint path, e.g., "/chat/completions".
66 /// The version prefix (e.g., "/v1") should come from the backend URL's base_path.
67 fn chat_completions_path(&self) -> String {
68 "/v1/chat/completions".to_string()
69 }
70
71 /// Transform request before sending to provider.
72 ///
73 /// This is called after the standard conversion but before sending
74 /// to the upstream provider. Providers can modify the request to
75 /// handle API differences.
76 fn transform_request(&self, _request: &mut ChatRequest) {}
77
78 /// Transform response after receiving from provider.
79 ///
80 /// This is called after receiving the response but before converting
81 /// to Responses API format. Providers can normalize response format.
82 fn transform_response(&self, _response: &mut ChatResponse) {}
83
84 /// Transform streaming chunk in real-time.
85 ///
86 /// This is called for each SSE chunk received from the provider.
87 /// Providers can modify chunk content before event conversion.
88 fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
89
90 /// Clone the provider as a boxed trait object.
91 fn clone_box(&self) -> Box<dyn Provider + Send + Sync>;
92
93 /// Convert self to Any for downcasting.
94 fn as_any(&self) -> &dyn std::any::Any;
95}
96
97// ============================================================================
98// Clone Implementation
99// ============================================================================
100
101/// Clone for Box<dyn Provider> uses downcasting (Rust object safety limitation).
102impl Clone for Box<dyn Provider + Send + Sync> {
103 fn clone(&self) -> Self {
104 self.as_ref().clone_box()
105 }
106}
107
108// ============================================================================
109// Factory Function
110// ============================================================================
111
112/// Create a provider by name using the static registry.
113///
114/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
115pub fn create_provider(name: &str) -> Result<Box<dyn Provider + Send + Sync>, ConversionError> {
116 let name_lower = name.to_lowercase();
117
118 // Handle aliases
119 let normalized_name = match name_lower.as_str() {
120 "moonshot" => "kimi",
121 other => other,
122 };
123
124 // Try to get from registry
125 if let Some(factory) = get_registry().get(normalized_name) {
126 return Ok(factory());
127 }
128
129 // Return error with available provider names
130 let available = registered_provider_names();
131 Err(ConversionError::ProviderError(format!(
132 "Unknown provider: {}. Available: {:?}",
133 name, available
134 )))
135}