use crate::core::Protocol;
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse};
use async_trait::async_trait;
use std::sync::Arc;

#[cfg(feature = "streaming")]
use crate::types::ChatStream;

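/// Wraps an existing [`Protocol`] implementation and overrides its name,
/// endpoint templates, authentication scheme, and extra headers according to
/// a [`ProtocolConfig`], while delegating request building, response parsing,
/// and error mapping to the wrapped protocol.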
#[derive(Clone)]
pub struct ConfigurableProtocol<P: Protocol> {
    inner: P,
    config: ProtocolConfig,
}

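/// Declarative configuration applied on top of the wrapped protocol: a
/// protocol name, endpoint URL templates, an authentication scheme, and any
/// extra headers to attach to every request.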
#[derive(Clone, Debug)]
pub struct ProtocolConfig {
    pub name: String,

    pub endpoints: EndpointConfig,

    pub auth: AuthConfig,

    pub extra_headers: Vec<(String, String)>,
}

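/// URL templates for the provider's endpoints. The `{base_url}` placeholder
/// is replaced with the configured base URL (trailing slashes trimmed).
/// `models_template` is optional for providers without a model-listing API.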
#[derive(Clone, Debug)]
pub struct EndpointConfig {
    pub chat_template: String,

    pub models_template: Option<String>,
}

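/// How the credential recovered from the wrapped protocol is presented to the
/// provider: as an `Authorization: Bearer` header, under a custom API-key
/// header, not at all, or via a user-supplied closure that maps the raw token
/// to arbitrary headers.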
#[derive(Clone)]
pub enum AuthConfig {
    Bearer,

    ApiKeyHeader {
        header_name: String,
    },

    None,

    Custom(Arc<dyn Fn(&str) -> Vec<(String, String)> + Send + Sync>),
}

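// Hand-written Debug impl: the `Custom` variant holds a closure, which cannot
// derive `Debug`, so it is rendered opaquely as `Custom(...)`.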
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthConfig::Bearer => write!(f, "Bearer"),
            AuthConfig::ApiKeyHeader { header_name } => {
                write!(f, "ApiKeyHeader({})", header_name)
            }
            AuthConfig::None => write!(f, "None"),
            AuthConfig::Custom(_) => write!(f, "Custom(...)"),
        }
    }
}

impl<P: Protocol> ConfigurableProtocol<P> {
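    /// Wraps `inner` with an explicit [`ProtocolConfig`].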
    pub fn new(inner: P, config: ProtocolConfig) -> Self {
        Self { inner, config }
    }

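    /// Convenience constructor for OpenAI-compatible providers: keeps the
    /// standard `/v1/chat/completions` and `/v1/models` paths and Bearer
    /// auth, and only overrides the protocol name.
    ///
    /// Illustrative sketch, mirroring the unit tests below (marked `ignore`
    /// because the exact import paths depend on how the crate re-exports
    /// these types):
    ///
    /// ```ignore
    /// let protocol =
    ///     ConfigurableProtocol::openai_compatible(OpenAIProtocol::new("sk-test"), "custom");
    /// assert_eq!(protocol.name(), "custom");
    /// ```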
    pub fn openai_compatible(inner: P, name: &str) -> Self {
        Self::new(
            inner,
            ProtocolConfig {
                name: name.to_string(),
                endpoints: EndpointConfig {
                    chat_template: "{base_url}/v1/chat/completions".to_string(),
                    models_template: Some("{base_url}/v1/models".to_string()),
                },
                auth: AuthConfig::Bearer,
                extra_headers: vec![],
            },
        )
    }

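    /// Recovers the raw credential from the wrapped protocol's own auth
    /// headers (either an `Authorization: Bearer <token>` value or an
    /// `x-api-key` value) so it can be re-emitted under the configured
    /// [`AuthConfig`]. Returns an empty string if no credential is found.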
    fn extract_token_from_inner(&self) -> String {
        let headers = self.inner.auth_headers();
        for (key, value) in headers {
            if key.to_lowercase() == "authorization" {
                if let Some(token) = value.strip_prefix("Bearer ") {
                    return token.to_string();
                }
                return value;
            } else if key.to_lowercase() == "x-api-key" {
                return value;
            }
        }
        String::new()
    }
}

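// The trait implementation overrides naming, endpoint construction, and auth
// headers from the config, and forwards request building, parsing, and error
// mapping to the wrapped protocol.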
#[async_trait]
impl<P: Protocol> Protocol for ConfigurableProtocol<P> {
    type Request = P::Request;
    type Response = P::Response;

    fn name(&self) -> &str {
        &self.config.name
    }

    fn chat_endpoint(&self, base_url: &str) -> String {
        self.config
            .endpoints
            .chat_template
            .replace("{base_url}", base_url.trim_end_matches('/'))
    }

    fn models_endpoint(&self, base_url: &str) -> Option<String> {
        self.config
            .endpoints
            .models_template
            .as_ref()
            .map(|template| template.replace("{base_url}", base_url.trim_end_matches('/')))
    }

    fn build_request(
        &self,
        request: &ChatRequest,
    ) -> Result<Self::Request, LlmConnectorError> {
        self.inner.build_request(request)
    }

    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
        self.inner.parse_response(response)
    }

    fn parse_models(&self, response: &str) -> Result<Vec<String>, LlmConnectorError> {
        self.inner.parse_models(response)
    }

    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
        self.inner.map_error(status, body)
    }

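    // Re-issue the credential recovered from the inner protocol under the
    // configured auth scheme, then append any configured extra headers.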
    fn auth_headers(&self) -> Vec<(String, String)> {
        let mut headers = match &self.config.auth {
            AuthConfig::Bearer => {
                let token = self.extract_token_from_inner();
                if token.is_empty() {
                    vec![]
                } else {
                    vec![("Authorization".to_string(), format!("Bearer {}", token))]
                }
            }
            AuthConfig::ApiKeyHeader { header_name } => {
                let token = self.extract_token_from_inner();
                if token.is_empty() {
                    vec![]
                } else {
                    vec![(header_name.clone(), token)]
                }
            }
            AuthConfig::None => vec![],
            AuthConfig::Custom(f) => {
                let token = self.extract_token_from_inner();
                f(&token)
            }
        };

        headers.extend(self.config.extra_headers.clone());
        headers
    }

    #[cfg(feature = "streaming")]
    async fn parse_stream_response(
        &self,
        response: reqwest::Response,
    ) -> Result<ChatStream, LlmConnectorError> {
        self.inner.parse_stream_response(response).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::OpenAIProtocol;

    #[test]
    fn test_configurable_protocol_basic() {
        let config = ProtocolConfig {
            name: "test".to_string(),
            endpoints: EndpointConfig {
                chat_template: "{base_url}/v1/chat/completions".to_string(),
                models_template: Some("{base_url}/v1/models".to_string()),
            },
            auth: AuthConfig::Bearer,
            extra_headers: vec![],
        };

        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), config);

        assert_eq!(protocol.name(), "test");
        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/v1/chat/completions"
        );
        assert_eq!(
            protocol.models_endpoint("https://api.example.com"),
            Some("https://api.example.com/v1/models".to_string())
        );
    }

    #[test]
    fn test_openai_compatible() {
        let protocol =
            ConfigurableProtocol::openai_compatible(OpenAIProtocol::new("sk-test"), "custom");

        assert_eq!(protocol.name(), "custom");
        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_custom_endpoint() {
        let config = ProtocolConfig {
            name: "volcengine".to_string(),
            endpoints: EndpointConfig {
                chat_template: "{base_url}/api/v3/chat/completions".to_string(),
                models_template: Some("{base_url}/api/v3/models".to_string()),
            },
            auth: AuthConfig::Bearer,
            extra_headers: vec![],
        };

        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), config);

        assert_eq!(
            protocol.chat_endpoint("https://api.example.com"),
            "https://api.example.com/api/v3/chat/completions"
        );
    }

    #[test]
    fn test_extra_headers() {
        let config = ProtocolConfig {
            name: "test".to_string(),
            endpoints: EndpointConfig {
                chat_template: "{base_url}/v1/chat/completions".to_string(),
                models_template: None,
            },
            auth: AuthConfig::Bearer,
            extra_headers: vec![
                ("X-Custom-Header".to_string(), "value".to_string()),
                ("X-Another-Header".to_string(), "value2".to_string()),
            ],
        };

        let protocol = ConfigurableProtocol::new(OpenAIProtocol::new("sk-test"), config);
        let headers = protocol.auth_headers();

        assert!(headers
            .iter()
            .any(|(k, v)| k == "X-Custom-Header" && v == "value"));
        assert!(headers
            .iter()
            .any(|(k, v)| k == "X-Another-Header" && v == "value2"));
    }
}