Skip to main content

codex_convert_proxy/providers/
trait_.rs

//! Provider trait definition and the built-in provider factory registry.
2
3use tracing::info;
4
5use crate::error::ConversionError;
6use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
7use std::collections::HashMap;
8use std::sync::{Arc, OnceLock};
9
10// ============================================================================
11// Provider Factory Registry
12// ============================================================================
13
14/// Factory function type for creating providers (type-erased function pointer).
15type ProviderFactory = fn() -> Arc<dyn Provider>;
16
17/// Static registry of provider factories.
18fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
19    static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
20    REGISTRY.get_or_init(|| {
21        let mut m = HashMap::new();
22        m.insert("glm", glm_factory as ProviderFactory);
23        m.insert("kimi", kimi_factory as ProviderFactory);
24        m.insert("deepseek", deepseek_factory as ProviderFactory);
25        m.insert("minimax", minimax_factory as ProviderFactory);
26        m
27    })
28}
29
30/// Get all registered provider names.
31pub fn registered_provider_names() -> Vec<&'static str> {
32    get_registry().keys().copied().collect()
33}
34
35// Factory functions (must be in separate functions to get unique addresses)
36fn glm_factory() -> Arc<dyn Provider> {
37    Arc::new(super::glm::GLMProvider::new())
38}
39fn kimi_factory() -> Arc<dyn Provider> {
40    Arc::new(super::kimi::KimiProvider::new())
41}
42fn deepseek_factory() -> Arc<dyn Provider> {
43    Arc::new(super::deepseek::DeepSeekProvider::new())
44}
45fn minimax_factory() -> Arc<dyn Provider> {
46    Arc::new(super::minimax::MiniMaxProvider::new())
47}
48// ============================================================================
49// Provider Trait
50// ============================================================================
51
52/// Provider trait for LLM provider-specific transformations.
53///
54/// Each Chinese LLM provider may have slightly different API requirements
55/// or model name formats that need to be normalized.
56///
57/// Implementations are expected to be **stateless** so a single instance can
58/// be shared across all requests via `Arc<dyn Provider>`.
59pub trait Provider: Send + Sync + 'static {
60    /// Get provider type identifier.
61    ///
62    /// Returns a static string identifying the provider type, e.g. `"glm"`,
63    /// `"kimi"`, `"default"`. This is used for programmatic dispatch
64    /// (e.g. `provider.name() == "minimax"`) and should **not** be used
65    /// for user-facing logs.
66    fn name(&self) -> &'static str;
67
68    /// Get the display name for logging and diagnostics.
69    ///
70    /// For named providers (GLM, Kimi, etc.) this returns the same value as
71    /// [`name()`](Self::name). For [`DefaultProvider`] this returns the
72    /// original backend name from config (e.g. `"qwen"`, `"yi-lightning"`),
73    /// making it easy to identify which backend a request was routed to.
74    ///
75    /// Use this in log messages, metrics labels, and diagnostics.
76    fn display_name(&self) -> &str {
77        self.name()
78    }
79
80    /// Normalize model name from Responses API to provider's format.
81    fn normalize_model(&self, model: String) -> String {
82        model
83    }
84
85    /// Get the chat completions path for this provider.
86    ///
87    /// Returns the endpoint path **without** the version prefix, e.g.
88    /// `"/chat/completions"`. The version prefix (e.g., `"/v1"`) comes
89    /// from the backend URL's `base_path` configured in `config.json` and
90    /// is prepended automatically during path rewriting in
91    /// `upstream_request_filter`.
92    ///
93    /// # Example
94    ///
95    /// Config URL `https://api.moonshot.cn/v1` → `base_path = "/v1"`
96    /// `chat_completions_path() = "/chat/completions"`
97    /// → final path: `/v1/chat/completions`
98    fn chat_completions_path(&self) -> String {
99        "/chat/completions".to_string()
100    }
101
102    /// Transform request before sending to provider.
103    ///
104    /// This is called after the standard conversion but before sending
105    /// to the upstream provider. Providers can modify the request to
106    /// handle API differences.
107    fn transform_request(&self, _request: &mut ChatRequest) {}
108
109    /// Transform response after receiving from provider.
110    ///
111    /// This is called after receiving the response but before converting
112    /// to Responses API format. Providers can normalize response format.
113    fn transform_response(&self, _response: &mut ChatResponse) {}
114
115    /// Transform streaming chunk in real-time.
116    ///
117    /// This is called for each SSE chunk received from the provider.
118    /// Providers can modify chunk content before event conversion.
119    fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
120}
121
122// ============================================================================
123// Factory Function
124// ============================================================================
125
126/// Create a provider by name using the static registry.
127///
128/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
129///
130/// When the provider name is not found in the registry, falls back to
131/// [`DefaultProvider`] which makes minimal assumptions about the provider
132/// (standard OpenAI-compatible Chat API format). This allows any Chinese LLM
133/// provider not explicitly registered to work out of the box.
134pub fn create_provider(name: &str) -> Result<Arc<dyn Provider>, ConversionError> {
135    let name_lower = name.to_lowercase();
136
137    // Handle aliases
138    let normalized_name = match name_lower.as_str() {
139        "moonshot" => "kimi",
140        other => other,
141    };
142
143    // Try to get from registry
144    if let Some(factory) = get_registry().get(normalized_name) {
145        return Ok(factory());
146    }
147
148    // Fall back to DefaultProvider (OpenAI compatible) for unknown providers.
149    // This is not in the registry to keep "default" as a reserved fallback
150    // concept, separate from user-facing provider names.
151    info!(
152        "[PROVIDER] Unknown provider '{}', falling back to DefaultProvider (OpenAI compatible)",
153        name
154    );
155    Ok(Arc::new(super::default::DefaultProvider::new(name)))
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Providers that are registered under their own canonical name.
    const KNOWN: [&str; 4] = ["glm", "kimi", "deepseek", "minimax"];

    #[test]
    fn test_create_provider_known() {
        // Each registered name resolves to its own provider, and for named
        // providers the display name matches the dispatch name.
        for expected in KNOWN {
            let provider = create_provider(expected).unwrap();
            assert_eq!(provider.name(), expected);
            assert_eq!(provider.display_name(), expected);
        }
    }

    #[test]
    fn test_create_provider_unknown_fallback_to_default() {
        // Anything outside the registry resolves to the fallback provider.
        for unknown in ["qwen", "some-unknown-provider", "abc"] {
            let provider = create_provider(unknown).unwrap();
            assert_eq!(provider.name(), "default");
        }
    }

    #[test]
    fn test_default_provider_display_name_preserves_backend_name() {
        // The fallback keeps the configured backend name (lowercased) as its
        // display name, while its dispatch name stays "default".
        let cases = [
            ("qwen", "qwen"),
            ("Yi-Lightning", "yi-lightning"),
            ("some-unknown-provider", "some-unknown-provider"),
        ];
        for (input, shown) in cases {
            let provider = create_provider(input).unwrap();
            assert_eq!(provider.name(), "default");
            assert_eq!(provider.display_name(), shown);
        }
    }

    #[test]
    fn test_create_provider_alias() {
        // "moonshot" is an alias that must resolve to the Kimi provider.
        let provider = create_provider("moonshot").unwrap();
        assert_eq!(provider.name(), "kimi");
        assert_eq!(provider.display_name(), "kimi");
    }

    #[test]
    fn test_registered_provider_names_excludes_default() {
        // The fallback is not a named provider, so it must not be listed,
        // while every built-in provider must be.
        let names = registered_provider_names();
        assert!(!names.contains(&"default"), "default should not be in registered_provider_names");
        for expected in KNOWN {
            assert!(names.contains(&expected));
        }
    }
}