codex_convert_proxy/providers/trait_.rs
//! Provider trait definition, plus the static name-to-factory registry used to construct providers by name.
2
3use crate::error::ConversionError;
4use crate::types::chat_api::{ChatRequest, ChatResponse, ChatStreamChunk};
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7
8// ============================================================================
9// Provider Factory Registry
10// ============================================================================
11
12/// Factory function type for creating providers (type-erased function pointer).
13type ProviderFactory = fn() -> Arc<dyn Provider>;
14
15/// Static registry of provider factories.
16fn get_registry() -> &'static HashMap<&'static str, ProviderFactory> {
17 static REGISTRY: OnceLock<HashMap<&'static str, ProviderFactory>> = OnceLock::new();
18 REGISTRY.get_or_init(|| {
19 let mut m = HashMap::new();
20 m.insert("glm", glm_factory as ProviderFactory);
21 m.insert("kimi", kimi_factory as ProviderFactory);
22 m.insert("deepseek", deepseek_factory as ProviderFactory);
23 m.insert("minimax", minimax_factory as ProviderFactory);
24 m
25 })
26}
27
28/// Get all registered provider names.
29pub fn registered_provider_names() -> Vec<&'static str> {
30 get_registry().keys().copied().collect()
31}
32
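// Factory functions: each is a plain `fn` so it can be stored in the registry as a `ProviderFactory` function pointer, and each builds a fresh provider instance wrapped in `Arc`.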
34fn glm_factory() -> Arc<dyn Provider> {
35 Arc::new(super::glm::GLMProvider::new())
36}
37fn kimi_factory() -> Arc<dyn Provider> {
38 Arc::new(super::kimi::KimiProvider::new())
39}
40fn deepseek_factory() -> Arc<dyn Provider> {
41 Arc::new(super::deepseek::DeepSeekProvider::new())
42}
43fn minimax_factory() -> Arc<dyn Provider> {
44 Arc::new(super::minimax::MiniMaxProvider::new())
45}
46
47// ============================================================================
48// Provider Trait
49// ============================================================================
50
51/// Provider trait for LLM provider-specific transformations.
52///
53/// Each Chinese LLM provider may have slightly different API requirements
54/// or model name formats that need to be normalized.
55///
56/// Implementations are expected to be **stateless** so a single instance can
57/// be shared across all requests via `Arc<dyn Provider>`.
58pub trait Provider: Send + Sync + 'static {
59 /// Get provider name.
60 fn name(&self) -> &'static str;
61
62 /// Normalize model name from Responses API to provider's format.
63 fn normalize_model(&self, model: String) -> String {
64 model
65 }
66
67 /// Get the chat completions path for this provider.
68 /// Only returns the endpoint path, e.g., "/chat/completions".
69 /// The version prefix (e.g., "/v1") should come from the backend URL's base_path.
70 fn chat_completions_path(&self) -> String {
71 "/v1/chat/completions".to_string()
72 }
73
74 /// Transform request before sending to provider.
75 ///
76 /// This is called after the standard conversion but before sending
77 /// to the upstream provider. Providers can modify the request to
78 /// handle API differences.
79 fn transform_request(&self, _request: &mut ChatRequest) {}
80
81 /// Transform response after receiving from provider.
82 ///
83 /// This is called after receiving the response but before converting
84 /// to Responses API format. Providers can normalize response format.
85 fn transform_response(&self, _response: &mut ChatResponse) {}
86
87 /// Transform streaming chunk in real-time.
88 ///
89 /// This is called for each SSE chunk received from the provider.
90 /// Providers can modify chunk content before event conversion.
91 fn transform_stream_chunk(&self, _chunk: &mut ChatStreamChunk) {}
92}
93
94// ============================================================================
95// Factory Function
96// ============================================================================
97
98/// Create a provider by name using the static registry.
99///
100/// Supports both exact names and aliases (e.g., "moonshot" -> "kimi").
101pub fn create_provider(name: &str) -> Result<Arc<dyn Provider>, ConversionError> {
102 let name_lower = name.to_lowercase();
103
104 // Handle aliases
105 let normalized_name = match name_lower.as_str() {
106 "moonshot" => "kimi",
107 other => other,
108 };
109
110 // Try to get from registry
111 if let Some(factory) = get_registry().get(normalized_name) {
112 return Ok(factory());
113 }
114
115 // Return error with available provider names
116 let available = registered_provider_names();
117 Err(ConversionError::ProviderError(format!(
118 "Unknown provider: {}. Available: {:?}",
119 name, available
120 )))
121}