ai_lib/provider/
generic.rs

use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
use crate::transport::{HttpClient, HttpTransport};
use super::config::ProviderConfig;
use std::env;
use futures::stream::{Stream, StreamExt};

/// Configuration-driven generic adapter for OpenAI-compatible APIs
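///
/// # Example
///
/// A minimal usage sketch (hypothetical `config` and `request` values; assumes the
/// environment variable named by `config.api_key_env` holds a valid API key):
///
/// ```ignore
/// let adapter = GenericAdapter::new(config)?;
/// let response = adapter.chat_completion(request).await?;
/// println!("{}", response.choices[0].message.content);
/// ```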
pub struct GenericAdapter {
    transport: HttpTransport,
    config: ProviderConfig,
    api_key: String,
}

impl GenericAdapter {
    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
        Self::with_transport(config, HttpTransport::new())
    }

    /// Create adapter with custom transport layer (for testing)
    pub fn with_transport(config: ProviderConfig, transport: HttpTransport) -> Result<Self, AiLibError> {
        let api_key = env::var(&config.api_key_env)
            .map_err(|_| AiLibError::AuthenticationError(
                format!("{} environment variable not set", config.api_key_env)
            ))?;

        Ok(Self {
            transport,
            config,
            api_key,
        })
    }

    /// Convert generic request to provider-specific format
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let default_role = "user".to_string();

        // Build the messages array, mapping roles through the provider's role mapping
        let messages: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
            let role_key = format!("{:?}", msg.role);
            let mapped_role = self.config.field_mapping.role_mapping
                .get(&role_key)
                .unwrap_or(&default_role);
            serde_json::json!({
                "role": mapped_role,
                "content": msg.content
            })
        }).collect();

        // Base request body; optional sampling parameters are added below
        let mut provider_request = serde_json::json!({
            "model": request.model,
            "messages": messages
        });

        // Add optional parameters
        if let Some(temp) = request.temperature {
            provider_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            provider_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            provider_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            provider_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
        }
        if let Some(presence_penalty) = request.presence_penalty {
            provider_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
        }

        provider_request
    }

    /// Find the end of the next SSE event in the buffer.
    /// Events are separated by a blank line (`\n\n` or `\r\n\r\n`).
    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
        let mut i = 0;
        while i < buffer.len().saturating_sub(1) {
            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                return Some(i + 2);
            }
            if i < buffer.len().saturating_sub(3)
                && buffer[i] == b'\r' && buffer[i + 1] == b'\n'
                && buffer[i + 2] == b'\r' && buffer[i + 3] == b'\n' {
                return Some(i + 4);
            }
            i += 1;
        }
        None
    }

    /// Parse a single SSE event.
    /// Returns `Some(Ok(None))` for the `[DONE]` sentinel and `None` if the event has no data line.
    fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
        for line in event_text.lines() {
            let line = line.trim();
            if line.starts_with("data: ") {
                let data = &line[6..];
                if data == "[DONE]" {
                    return Some(Ok(None));
                }
                return Some(Self::parse_chunk_data(data));
            }
        }
        None
    }

    /// Parse chunk data
    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
        match serde_json::from_str::<serde_json::Value>(data) {
            Ok(json) => {
                let choices = json["choices"].as_array()
                    .map(|arr| {
                        arr.iter()
                            .enumerate()
                            .map(|(index, choice)| {
                                let delta = &choice["delta"];
                                ChoiceDelta {
                                    index: index as u32,
                                    delta: MessageDelta {
                                        role: delta["role"].as_str().map(|r| match r {
                                            "assistant" => Role::Assistant,
                                            "user" => Role::User,
                                            "system" => Role::System,
                                            _ => Role::Assistant,
                                        }),
                                        content: delta["content"].as_str().map(str::to_string),
                                    },
                                    finish_reason: choice["finish_reason"].as_str().map(str::to_string),
                                }
                            })
                            .collect()
                    })
                    .unwrap_or_default();

                Ok(Some(ChatCompletionChunk {
                    id: json["id"].as_str().unwrap_or_default().to_string(),
                    object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
                    created: json["created"].as_u64().unwrap_or(0),
                    model: json["model"].as_str().unwrap_or_default().to_string(),
                    choices,
                }))
            }
            Err(e) => Err(AiLibError::ProviderError(format!("JSON parse error: {}", e)))
        }
    }

    /// Parse a non-streaming provider response into a `ChatCompletionResponse`
    fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                // Use `get` rather than map indexing so a missing key does not panic
                let message = choice["message"].as_object()
                    .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;

                let role = match message.get("role").and_then(|v| v.as_str()).unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message.get("content")
                    .and_then(|v| v.as_str())
                    .unwrap_or("")
                    .to_string();

                Ok(Choice {
                    index: index as u32,
                    message: Message { role, content },
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object()
            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;

        let usage = Usage {
            prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            completion_tokens: usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
            total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GenericAdapter {
    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
        let provider_request = self.convert_request(&request);
        let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);

        let mut headers = self.config.headers.clone();

        // Anthropic authenticates via an `x-api-key` header; other OpenAI-compatible
        // providers expect a Bearer token in the `Authorization` header
        if self.config.base_url.contains("anthropic.com") {
            headers.insert("x-api-key".to_string(), self.api_key.clone());
        } else {
            headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
        }

        let response: serde_json::Value = self.transport
            .post(&url, Some(headers), &provider_request)
            .await?;

        self.parse_response(response)
    }

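    /// Streaming variant of `chat_completion`. A minimal consumption sketch
    /// (hypothetical `adapter` and `request` values; `futures::StreamExt` is
    /// already imported in this module for `next()`):
    ///
    /// ```ignore
    /// let mut stream = adapter.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     for choice in chunk?.choices {
    ///         if let Some(text) = choice.delta.content {
    ///             print!("{}", text);
    ///         }
    ///     }
    /// }
    /// ```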
    async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);

        // Build a dedicated HTTP client for streaming, honoring AI_PROXY_URL if set
        let mut client_builder = reqwest::Client::builder();
        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
                client_builder = client_builder.proxy(proxy);
            }
        }
        let client = client_builder.build()
            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;

        let mut headers = self.config.headers.clone();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());

        // Anthropic authenticates via an `x-api-key` header; other OpenAI-compatible
        // providers expect a Bearer token in the `Authorization` header
        if self.config.base_url.contains("anthropic.com") {
            headers.insert("x-api-key".to_string(), self.api_key.clone());
        } else {
            headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
        }

        let mut req = client
            .post(&url)
            .json(&stream_request);
        for (key, value) in headers {
            req = req.header(key, value);
        }

        let response = req.send().await
            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AiLibError::ProviderError(format!("Stream error: {}", error_text)));
        }

        // Forward parsed chunks through a channel so the caller gets a plain stream
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

        tokio::spawn(async move {
            let mut buffer = Vec::new();
            let mut stream = response.bytes_stream();

            while let Some(result) = stream.next().await {
                match result {
                    Ok(bytes) => {
                        buffer.extend_from_slice(&bytes);

                        // Drain complete SSE events from the buffer as they arrive
                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();

                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                if let Some(chunk) = Self::parse_sse_event(event_text) {
                                    match chunk {
                                        Ok(Some(c)) => {
                                            if tx.send(Ok(c)).is_err() {
                                                return;
                                            }
                                        }
                                        Ok(None) => return,
                                        Err(e) => {
                                            let _ = tx.send(Err(e));
                                            return;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e))));
                        break;
                    }
                }
            }
        });

        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }
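    /// Lists model ids from the provider's models endpoint, if one is configured.
    /// A minimal usage sketch (hypothetical `adapter` value):
    ///
    /// ```ignore
    /// for id in adapter.list_models().await? {
    ///     println!("{}", id);
    /// }
    /// ```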
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        if let Some(models_endpoint) = &self.config.models_endpoint {
            let url = format!("{}{}", self.config.base_url, models_endpoint);
            let mut headers = self.config.headers.clone();

            // Anthropic authenticates via an `x-api-key` header; other OpenAI-compatible
            // providers expect a Bearer token in the `Authorization` header
            if self.config.base_url.contains("anthropic.com") {
                headers.insert("x-api-key".to_string(), self.api_key.clone());
            } else {
                headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
            }

            let response: serde_json::Value = self.transport
                .get(&url, Some(headers))
                .await?;

            Ok(response["data"]
                .as_array()
                .map(|models| {
                    models
                        .iter()
                        .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                        .collect()
                })
                .unwrap_or_default())
        } else {
            Err(AiLibError::ProviderError("Models endpoint not configured".to_string()))
        }
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Return a static default `ModelInfo`; the generic adapter does not query a
        // provider-specific model-info endpoint
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "generic".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}