llm_connector/core/
configurable.rs

1//! 可配置的协议适配器 - 配置驱动的抽象
2//!
3//! 这个模块提供了一个通用的协议适配器,通过配置来定制行为,
4//! 避免为每个 Provider 编写重复的样板代码。
5
6use crate::core::Protocol;
7use crate::error::LlmConnectorError;
8use crate::types::{ChatRequest, ChatResponse};
9use async_trait::async_trait;
10use std::sync::Arc;
11
12#[cfg(feature = "streaming")]
13use crate::types::ChatStream;
14
/// Configurable protocol adapter.
///
/// Wraps a base protocol and customizes its behavior (endpoint paths,
/// authentication scheme, extra headers) through static configuration,
/// avoiding per-provider boilerplate adapter code.
///
/// # Example
/// ```rust,no_run
/// use llm_connector::core::{ConfigurableProtocol, ProtocolConfig, EndpointConfig, AuthConfig};
/// use llm_connector::protocols::OpenAIProtocol;
///
/// let config = ProtocolConfig {
///     name: "custom".to_string(),
///     endpoints: EndpointConfig {
///         chat_template: "{base_url}/v1/chat/completions".to_string(),
///         models_template: Some("{base_url}/v1/models".to_string()),
///     },
///     auth: AuthConfig::Bearer,
///     extra_headers: vec![],
/// };
///
/// let protocol = ConfigurableProtocol::new(
///     OpenAIProtocol::new("sk-..."),
///     config
/// );
/// ```
#[derive(Clone)]
pub struct ConfigurableProtocol<P: Protocol> {
    // Base protocol that performs request building / response parsing.
    inner: P,
    // Static configuration overriding name, endpoints, auth, and headers.
    config: ProtocolConfig,
}
44
/// Protocol configuration.
///
/// Defines the static configuration of a protocol: its name, endpoint
/// templates, authentication scheme, and any extra static headers.
#[derive(Clone, Debug)]
pub struct ProtocolConfig {
    /// Protocol name (returned by `Protocol::name`).
    pub name: String,

    /// Endpoint configuration (URL templates).
    pub endpoints: EndpointConfig,

    /// Authentication configuration.
    pub auth: AuthConfig,

    /// Extra static headers appended to every request's auth headers.
    pub extra_headers: Vec<(String, String)>,
}
62
/// Endpoint configuration.
///
/// Defines API endpoint path templates; the `{base_url}` variable is
/// substituted at request time.
#[derive(Clone, Debug)]
pub struct EndpointConfig {
    /// Chat endpoint template.
    ///
    /// Supported variable: `{base_url}`
    ///
    /// Example: `"{base_url}/v1/chat/completions"`
    pub chat_template: String,

    /// Model-listing endpoint template (optional).
    ///
    /// Example: `"{base_url}/v1/models"`
    pub models_template: Option<String>,
}
80
/// Authentication configuration.
///
/// Defines how API authentication headers are produced.
#[derive(Clone)]
pub enum AuthConfig {
    /// Bearer token authentication.
    ///
    /// Produces: `Authorization: Bearer {token}`
    Bearer,

    /// API-key header authentication.
    ///
    /// Produces: `{header_name}: {token}`
    ApiKeyHeader {
        /// Header name to carry the token.
        header_name: String,
    },

    /// No authentication.
    None,

    /// Custom authentication (via closure).
    ///
    /// The closure receives the token and returns a list of headers.
    Custom(Arc<dyn Fn(&str) -> Vec<(String, String)> + Send + Sync>),
}
107
108impl std::fmt::Debug for AuthConfig {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        match self {
111            AuthConfig::Bearer => write!(f, "Bearer"),
112            AuthConfig::ApiKeyHeader { header_name } => {
113                write!(f, "ApiKeyHeader({})", header_name)
114            }
115            AuthConfig::None => write!(f, "None"),
116            AuthConfig::Custom(_) => write!(f, "Custom(...)"),
117        }
118    }
119}
120
121impl<P: Protocol> ConfigurableProtocol<P> {
122    /// 创建新的可配置协议适配器
123    ///
124    /// # 参数
125    /// - `inner`: 基础协议实例
126    /// - `config`: 协议配置
127    pub fn new(inner: P, config: ProtocolConfig) -> Self {
128        Self { inner, config }
129    }
130
131    /// 便捷构造器 - OpenAI 兼容协议
132    ///
133    /// 创建一个使用标准 OpenAI 端点和 Bearer 认证的配置。
134    ///
135    /// # 参数
136    /// - `inner`: 基础协议实例
137    /// - `name`: 协议名称
138    ///
139    /// # 示例
140    /// ```rust,no_run
141    /// use llm_connector::core::ConfigurableProtocol;
142    /// use llm_connector::protocols::OpenAIProtocol;
143    ///
144    /// let protocol = ConfigurableProtocol::openai_compatible(
145    ///     OpenAIProtocol::new("sk-..."),
146    ///     "custom-openai"
147    /// );
148    /// ```
149    pub fn openai_compatible(inner: P, name: &str) -> Self {
150        Self::new(
151            inner,
152            ProtocolConfig {
153                name: name.to_string(),
154                endpoints: EndpointConfig {
155                    chat_template: "{base_url}/v1/chat/completions".to_string(),
156                    models_template: Some("{base_url}/v1/models".to_string()),
157                },
158                auth: AuthConfig::Bearer,
159                extra_headers: vec![],
160            },
161        )
162    }
163
164    /// 从内部协议提取 token
165    ///
166    /// 这是一个辅助方法,用于从内部协议的认证头中提取 token。
167    fn extract_token_from_inner(&self) -> String {
168        let headers = self.inner.auth_headers();
169        for (key, value) in headers {
170            if key.to_lowercase() == "authorization" {
171                // 提取 "Bearer xxx" 或 "xxx"
172                if let Some(token) = value.strip_prefix("Bearer ") {
173                    return token.to_string();
174                }
175                return value;
176            } else if key.to_lowercase() == "x-api-key" {
177                return value;
178            }
179        }
180        // 如果找不到,返回空字符串
181        String::new()
182    }
183}
184
185#[async_trait]
186impl<P: Protocol> Protocol for ConfigurableProtocol<P> {
187    type Request = P::Request;
188    type Response = P::Response;
189
190    fn name(&self) -> &str {
191        &self.config.name
192    }
193
194    fn chat_endpoint(&self, base_url: &str) -> String {
195        self.config
196            .endpoints
197            .chat_template
198            .replace("{base_url}", base_url.trim_end_matches('/'))
199    }
200
201    fn models_endpoint(&self, base_url: &str) -> Option<String> {
202        self.config
203            .endpoints
204            .models_template
205            .as_ref()
206            .map(|template| template.replace("{base_url}", base_url.trim_end_matches('/')))
207    }
208
209    fn build_request(
210        &self,
211        request: &ChatRequest,
212    ) -> Result<Self::Request, LlmConnectorError> {
213        self.inner.build_request(request)
214    }
215
216    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
217        self.inner.parse_response(response)
218    }
219
220    fn parse_models(&self, response: &str) -> Result<Vec<String>, LlmConnectorError> {
221        self.inner.parse_models(response)
222    }
223
224    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
225        self.inner.map_error(status, body)
226    }
227
228    fn auth_headers(&self) -> Vec<(String, String)> {
229        let mut headers = match &self.config.auth {
230            AuthConfig::Bearer => {
231                // 从 inner protocol 获取 token 并转换为 Bearer 格式
232                let token = self.extract_token_from_inner();
233                if token.is_empty() {
234                    vec![]
235                } else {
236                    vec![("Authorization".to_string(), format!("Bearer {}", token))]
237                }
238            }
239            AuthConfig::ApiKeyHeader { header_name } => {
240                // 从 inner protocol 获取 token,使用自定义 header 名称
241                let token = self.extract_token_from_inner();
242                if token.is_empty() {
243                    vec![]
244                } else {
245                    vec![(header_name.clone(), token)]
246                }
247            }
248            AuthConfig::None => vec![],
249            AuthConfig::Custom(f) => {
250                let token = self.extract_token_from_inner();
251                f(&token)
252            }
253        };
254
255        // 添加额外的静态头部
256        headers.extend(self.config.extra_headers.clone());
257        headers
258    }
259
260    #[cfg(feature = "streaming")]
261    async fn parse_stream_response(
262        &self,
263        response: reqwest::Response,
264    ) -> Result<ChatStream, LlmConnectorError> {
265        self.inner.parse_stream_response(response).await
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::OpenAIProtocol;

    /// Builds a Bearer-auth `ProtocolConfig` with the given name,
    /// endpoint templates, and extra headers.
    fn make_config(
        name: &str,
        chat: &str,
        models: Option<&str>,
        extra_headers: Vec<(String, String)>,
    ) -> ProtocolConfig {
        ProtocolConfig {
            name: name.to_string(),
            endpoints: EndpointConfig {
                chat_template: chat.to_string(),
                models_template: models.map(str::to_string),
            },
            auth: AuthConfig::Bearer,
            extra_headers,
        }
    }

    #[test]
    fn test_configurable_protocol_basic() {
        let cfg = make_config(
            "test",
            "{base_url}/v1/chat/completions",
            Some("{base_url}/v1/models"),
            vec![],
        );
        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), cfg);

        assert_eq!(protocol.name(), "test");
        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/v1/chat/completions"
        );
        assert_eq!(
            protocol.models_endpoint("https://api.example.com"),
            Some("https://api.example.com/v1/models".to_string())
        );
    }

    #[test]
    fn test_openai_compatible() {
        let protocol =
            ConfigurableProtocol::openai_compatible(OpenAIProtocol::new("sk-test"), "custom");

        assert_eq!(protocol.name(), "custom");
        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_custom_endpoint() {
        let cfg = make_config(
            "volcengine",
            "{base_url}/api/v3/chat/completions",
            Some("{base_url}/api/v3/models"),
            vec![],
        );
        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), cfg);

        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/api/v3/chat/completions"
        );
    }

    #[test]
    fn test_extra_headers() {
        let cfg = make_config(
            "test",
            "{base_url}/v1/chat/completions",
            None,
            vec![
                ("X-Custom-Header".to_string(), "value".to_string()),
                ("X-Another-Header".to_string(), "value2".to_string()),
            ],
        );
        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), cfg);
        let headers = protocol.auth_headers();

        let contains = |name: &str, value: &str| {
            headers.iter().any(|(k, v)| k == name && v == value)
        };
        assert!(contains("X-Custom-Header", "value"));
        assert!(contains("X-Another-Header", "value2"));
    }
}
357