codex_convert_proxy/providers/trait_.rs
1//! Provider trait definition.
2
3use tracing::info;
4
5use crate::error::ConversionError;
6use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
7use std::collections::HashMap;
8use std::sync::{Arc, OnceLock};
9
10// ============================================================================
11// Provider Factory Registry
12// ============================================================================
13
14/// Factory function type for creating providers (type-erased function pointer).
15type ProviderFactory = fn() -> Arc<dyn Provider>;
16
17/// Static registry of provider factories.
18fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
19 static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
20 REGISTRY.get_or_init(|| {
21 let mut m = HashMap::new();
22 m.insert("glm", glm_factory as ProviderFactory);
23 m.insert("kimi", kimi_factory as ProviderFactory);
24 m.insert("deepseek", deepseek_factory as ProviderFactory);
25 m.insert("minimax", minimax_factory as ProviderFactory);
26 m
27 })
28}
29
30/// Get all registered provider names.
31pub fn registered_provider_names() -> Vec<&'static str> {
32 get_registry().keys().copied().collect()
33}
34
35// Factory functions (must be in separate functions to get unique addresses)
36fn glm_factory() -> Arc<dyn Provider> {
37 Arc::new(super::glm::GLMProvider::new())
38}
39fn kimi_factory() -> Arc<dyn Provider> {
40 Arc::new(super::kimi::KimiProvider::new())
41}
42fn deepseek_factory() -> Arc<dyn Provider> {
43 Arc::new(super::deepseek::DeepSeekProvider::new())
44}
45fn minimax_factory() -> Arc<dyn Provider> {
46 Arc::new(super::minimax::MiniMaxProvider::new())
47}
48// ============================================================================
49// Provider Trait
50// ============================================================================
51
52/// Provider trait for LLM provider-specific transformations.
53///
54/// Each Chinese LLM provider may have slightly different API requirements
55/// or model name formats that need to be normalized.
56///
57/// Implementations are expected to be **stateless** so a single instance can
58/// be shared across all requests via `Arc<dyn Provider>`.
59pub trait Provider: Send + Sync + 'static {
60 /// Get provider type identifier.
61 ///
62 /// Returns a static string identifying the provider type, e.g. `"glm"`,
63 /// `"kimi"`, `"default"`. This is used for programmatic dispatch
64 /// (e.g. `provider.name() == "minimax"`) and should **not** be used
65 /// for user-facing logs.
66 fn name(&self) -> &'static str;
67
68 /// Get the display name for logging and diagnostics.
69 ///
70 /// For named providers (GLM, Kimi, etc.) this returns the same value as
71 /// [`name()`](Self::name). For [`DefaultProvider`] this returns the
72 /// original backend name from config (e.g. `"qwen"`, `"yi-lightning"`),
73 /// making it easy to identify which backend a request was routed to.
74 ///
75 /// Use this in log messages, metrics labels, and diagnostics.
76 fn display_name(&self) -> &str {
77 self.name()
78 }
79
80 /// Normalize model name from Responses API to provider's format.
81 fn normalize_model(&self, model: String) -> String {
82 model
83 }
84
85 /// Get the chat completions path for this provider.
86 ///
87 /// Returns the endpoint path **without** the version prefix, e.g.
88 /// `"/chat/completions"`. The version prefix (e.g., `"/v1"`) comes
89 /// from the backend URL's `base_path` configured in `config.json` and
90 /// is prepended automatically during path rewriting in
91 /// `upstream_request_filter`.
92 ///
93 /// # Example
94 ///
95 /// Config URL `https://api.moonshot.cn/v1` → `base_path = "/v1"`
96 /// `chat_completions_path() = "/chat/completions"`
97 /// → final path: `/v1/chat/completions`
98 fn chat_completions_path(&self) -> String {
99 "/chat/completions".to_string()
100 }
101
102 /// Transform request before sending to provider.
103 ///
104 /// This is called after the standard conversion but before sending
105 /// to the upstream provider. Providers can modify the request to
106 /// handle API differences.
107 fn transform_request(&self, _request: &mut ChatRequest) {}
108
109 /// Transform response after receiving from provider.
110 ///
111 /// This is called after receiving the response but before converting
112 /// to Responses API format. Providers can normalize response format.
113 fn transform_response(&self, _response: &mut ChatResponse) {}
114
115 /// Transform streaming chunk in real-time.
116 ///
117 /// This is called for each SSE chunk received from the provider.
118 /// Providers can modify chunk content before event conversion.
119 fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
120}
121
122// ============================================================================
123// Factory Function
124// ============================================================================
125
126/// Create a provider by name using the static registry.
127///
128/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
129///
130/// When the provider name is not found in the registry, falls back to
131/// [`DefaultProvider`] which makes minimal assumptions about the provider
132/// (standard OpenAI-compatible Chat API format). This allows any Chinese LLM
133/// provider not explicitly registered to work out of the box.
134pub fn create_provider(name: &str) -> Result<Arc<dyn Provider>, ConversionError> {
135 let name_lower = name.to_lowercase();
136
137 // Handle aliases
138 let normalized_name = match name_lower.as_str() {
139 "moonshot" => "kimi",
140 other => other,
141 };
142
143 // Try to get from registry
144 if let Some(factory) = get_registry().get(normalized_name) {
145 return Ok(factory());
146 }
147
148 // Fall back to DefaultProvider (OpenAI compatible) for unknown providers.
149 // This is not in the registry to keep "default" as a reserved fallback
150 // concept, separate from user-facing provider names.
151 info!(
152 "[PROVIDER] Unknown provider '{}', falling back to DefaultProvider (OpenAI compatible)",
153 name
154 );
155 Ok(Arc::new(super::default::DefaultProvider::new(name)))
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Every explicitly registered provider reports its own name for both
    /// `name()` and `display_name()`.
    #[test]
    fn test_create_provider_known() {
        for expected in ["glm", "kimi", "deepseek", "minimax"] {
            let provider = create_provider(expected).unwrap();
            assert_eq!(provider.name(), expected);
            assert_eq!(provider.display_name(), expected);
        }
    }

    /// Names missing from the registry resolve to the `"default"` provider.
    #[test]
    fn test_create_provider_unknown_fallback_to_default() {
        for requested in ["qwen", "some-unknown-provider", "abc"] {
            let provider = create_provider(requested).unwrap();
            assert_eq!(provider.name(), "default");
        }
    }

    /// The fallback provider keeps the backend name for display purposes,
    /// lowercased (see the `"Yi-Lightning"` case).
    #[test]
    fn test_default_provider_display_name_preserves_backend_name() {
        let cases = [
            ("qwen", "qwen"),
            ("Yi-Lightning", "yi-lightning"),
            ("some-unknown-provider", "some-unknown-provider"),
        ];

        for (requested, expected_display) in cases {
            let provider = create_provider(requested).unwrap();
            assert_eq!(provider.name(), "default");
            assert_eq!(provider.display_name(), expected_display);
        }
    }

    /// The `"moonshot"` alias resolves to the Kimi provider.
    #[test]
    fn test_create_provider_alias() {
        let provider = create_provider("moonshot").unwrap();
        assert_eq!(provider.name(), "kimi");
        assert_eq!(provider.display_name(), "kimi");
    }

    /// `"default"` is a fallback, not a named provider, so it must not be
    /// listed; all explicitly registered names must be.
    #[test]
    fn test_registered_provider_names_excludes_default() {
        let names = registered_provider_names();
        assert!(
            !names.contains(&"default"),
            "default should not be in registered_provider_names"
        );

        for expected in ["glm", "kimi", "deepseek", "minimax"] {
            assert!(names.contains(&expected));
        }
    }
}