// ai_lib/provider/generic.rs

1use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpTransport, DynHttpTransportRef};
4use super::config::ProviderConfig;
5use std::env;
6use futures::stream::{Stream, StreamExt};
7
/// Configuration-driven generic adapter for OpenAI-compatible APIs.
///
/// Endpoints, headers, and role mappings come from the supplied
/// `ProviderConfig`, so one implementation can serve many providers.
pub struct GenericAdapter {
    // Object-safe HTTP transport used for the non-streaming calls.
    transport: DynHttpTransportRef,
    // Provider description: base URL, endpoints, headers, field mappings.
    config: ProviderConfig,
    // API key read from the env var named by `config.api_key_env`;
    // `None` for keyless deployments (e.g. a local server) — auth headers
    // are then simply omitted.
    api_key: Option<String>,
}
16
17impl GenericAdapter {
18    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
19        // For generic/config-driven providers we treat the API key as optional.
20        // Some deployments (e.g. local Ollama) don't require a key. If the env var
21        // is missing we continue with None and callers will simply omit auth headers.
22        let api_key = env::var(&config.api_key_env).ok();
23
24        Ok(Self {
25            transport: HttpTransport::new().boxed(),
26            config,
27            api_key,
28        })
29    }
30    
31    /// Create adapter with custom transport layer (for testing)
32    pub fn with_transport(config: ProviderConfig, transport: HttpTransport) -> Result<Self, AiLibError> {
33        let api_key = env::var(&config.api_key_env).ok();
34
35        Ok(Self {
36            transport: transport.boxed(),
37            config,
38            api_key,
39        })
40    }
41
42    /// Accept an object-safe transport reference directly
43    pub fn with_transport_ref(config: ProviderConfig, transport: DynHttpTransportRef) -> Result<Self, AiLibError> {
44        let api_key = env::var(&config.api_key_env).ok();
45        Ok(Self { transport, config, api_key })
46    }
47
48    /// Convert generic request to provider-specific format
49    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
50        let default_role = "user".to_string();
51        
52        // Build messages array
53        let messages: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
54            let role_key = format!("{:?}", msg.role);
55            let mapped_role = self.config.field_mapping.role_mapping
56                .get(&role_key)
57                .unwrap_or(&default_role);
58            serde_json::json!({
59                "role": mapped_role,
60                "content": msg.content
61            })
62        }).collect();
63        
64        // Use string literals as JSON keys
65        let mut provider_request = serde_json::json!({
66            "model": request.model,
67            "messages": messages
68        });
69
70        // Add optional parameters
71        if let Some(temp) = request.temperature {
72            provider_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
73        }
74        if let Some(max_tokens) = request.max_tokens {
75            provider_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
76        }
77        if let Some(top_p) = request.top_p {
78            provider_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
79        }
80        if let Some(freq_penalty) = request.frequency_penalty {
81            provider_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
82        }
83        if let Some(presence_penalty) = request.presence_penalty {
84            provider_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
85        }
86
87        provider_request
88    }
89
90    /// Find event boundary
91    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
92        let mut i = 0;
93        while i < buffer.len().saturating_sub(1) {
94            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
95                return Some(i + 2);
96            }
97            if i < buffer.len().saturating_sub(3) 
98                && buffer[i] == b'\r' && buffer[i + 1] == b'\n' 
99                && buffer[i + 2] == b'\r' && buffer[i + 3] == b'\n' {
100                return Some(i + 4);
101            }
102            i += 1;
103        }
104        None
105    }
106    
107    /// Parse SSE event
108    fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
109        for line in event_text.lines() {
110            let line = line.trim();
111            if line.starts_with("data: ") {
112                let data = &line[6..];
113                if data == "[DONE]" {
114                    return Some(Ok(None));
115                }
116                return Some(Self::parse_chunk_data(data));
117            }
118        }
119        None
120    }
121    
122    /// Parse chunk data
123    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
124        match serde_json::from_str::<serde_json::Value>(data) {
125            Ok(json) => {
126                let choices = json["choices"].as_array()
127                    .map(|arr| {
128                        arr.iter()
129                            .enumerate()
130                            .map(|(index, choice)| {
131                                let delta = &choice["delta"];
132                                ChoiceDelta {
133                                    index: index as u32,
134                                    delta: MessageDelta {
135                                        role: delta["role"].as_str().map(|r| match r {
136                                            "assistant" => Role::Assistant,
137                                            "user" => Role::User,
138                                            "system" => Role::System,
139                                            _ => Role::Assistant,
140                                        }),
141                                        content: delta["content"].as_str().map(str::to_string),
142                                    },
143                                    finish_reason: choice["finish_reason"].as_str().map(str::to_string),
144                                }
145                            })
146                            .collect()
147                    })
148                    .unwrap_or_default();
149                
150                Ok(Some(ChatCompletionChunk {
151                    id: json["id"].as_str().unwrap_or_default().to_string(),
152                    object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
153                    created: json["created"].as_u64().unwrap_or(0),
154                    model: json["model"].as_str().unwrap_or_default().to_string(),
155                    choices,
156                }))
157            }
158            Err(e) => Err(AiLibError::ProviderError(format!("JSON parse error: {}", e)))
159        }
160    }
161    
162    /// Parse response
163    fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
164        let choices = response["choices"]
165            .as_array()
166            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
167            .iter()
168            .enumerate()
169            .map(|(index, choice)| {
170                let message = choice["message"].as_object()
171                    .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
172                
173                let role = match message["role"].as_str().unwrap_or("user") {
174                    "system" => Role::System,
175                    "assistant" => Role::Assistant,
176                    _ => Role::User,
177                };
178                
179                let content = message["content"].as_str()
180                    .unwrap_or("")
181                    .to_string();
182                
183                Ok(Choice {
184                    index: index as u32,
185                    message: Message { role, content },
186                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
187                })
188            })
189            .collect::<Result<Vec<_>, AiLibError>>()?;
190        
191        let usage = response["usage"].as_object()
192            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
193        
194        let usage = Usage {
195            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
196            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
197            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
198        };
199        
200        Ok(ChatCompletionResponse {
201            id: response["id"].as_str().unwrap_or("").to_string(),
202            object: response["object"].as_str().unwrap_or("").to_string(),
203            created: response["created"].as_u64().unwrap_or(0),
204            model: response["model"].as_str().unwrap_or("").to_string(),
205            choices,
206            usage,
207        })
208    }
209}
210
211#[async_trait::async_trait]
212impl ChatApi for GenericAdapter {
213    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
214        let provider_request = self.convert_request(&request);
215        let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
216        
217        let mut headers = self.config.headers.clone();
218        
219        // Set different authentication methods based on provider when an API key is present
220        if let Some(key) = &self.api_key {
221            if self.config.base_url.contains("anthropic.com") {
222                headers.insert("x-api-key".to_string(), key.clone());
223            } else {
224                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
225            }
226        }
227        
228        let response: serde_json::Value = self.transport
229            .post_json(&url, Some(headers), provider_request)
230            .await?;
231        
232        self.parse_response(response)
233    }
234
235    async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
236        let mut stream_request = self.convert_request(&request);
237        stream_request["stream"] = serde_json::Value::Bool(true);
238        
239        let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
240        
241        // Create HTTP client
242        let mut client_builder = reqwest::Client::builder();
243        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
244            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
245                client_builder = client_builder.proxy(proxy);
246            }
247        }
248        let client = client_builder.build()
249            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
250        
251        let mut headers = self.config.headers.clone();
252        headers.insert("Accept".to_string(), "text/event-stream".to_string());
253        
254        // Set different authentication methods based on provider when an API key is present
255        if let Some(key) = &self.api_key {
256            if self.config.base_url.contains("anthropic.com") {
257                headers.insert("x-api-key".to_string(), key.clone());
258            } else {
259                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
260            }
261        }
262        
263        let response = client
264            .post(&url)
265            .json(&stream_request);
266        
267        let mut req = response;
268        for (key, value) in headers {
269            req = req.header(key, value);
270        }
271        
272        let response = req.send().await
273            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
274        
275        if !response.status().is_success() {
276            let error_text = response.text().await.unwrap_or_default();
277            return Err(AiLibError::ProviderError(format!("Stream error: {}", error_text)));
278        }
279        
280        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
281        
282        tokio::spawn(async move {
283            let mut buffer = Vec::new();
284            let mut stream = response.bytes_stream();
285            
286            while let Some(result) = stream.next().await {
287                match result {
288                    Ok(bytes) => {
289                        buffer.extend_from_slice(&bytes);
290                        
291                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
292                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
293                            
294                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
295                                if let Some(chunk) = Self::parse_sse_event(event_text) {
296                                    match chunk {
297                                        Ok(Some(c)) => {
298                                            if tx.send(Ok(c)).is_err() {
299                                                return;
300                                            }
301                                        }
302                                        Ok(None) => return,
303                                        Err(e) => {
304                                            let _ = tx.send(Err(e));
305                                            return;
306                                        }
307                                    }
308                                }
309                            }
310                        }
311                    }
312                    Err(e) => {
313                        let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e))));
314                        break;
315                    }
316                }
317            }
318        });
319        
320        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
321        Ok(Box::new(Box::pin(stream)))
322    }
323    
324
325
326    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
327        if let Some(models_endpoint) = &self.config.models_endpoint {
328            let url = format!("{}{}", self.config.base_url, models_endpoint);
329            let mut headers = self.config.headers.clone();
330            
331            // Set different authentication methods based on provider when an API key is present
332            if let Some(key) = &self.api_key {
333                if self.config.base_url.contains("anthropic.com") {
334                    headers.insert("x-api-key".to_string(), key.clone());
335                } else {
336                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
337                }
338            }
339            
340            let response: serde_json::Value = self.transport
341                .get_json(&url, Some(headers))
342                .await?;
343            
344            Ok(response["data"].as_array()
345                .unwrap_or(&vec![])
346                .iter()
347                .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
348                .collect())
349        } else {
350            Err(AiLibError::ProviderError("Models endpoint not configured".to_string()))
351        }
352    }
353
354    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
355        Ok(ModelInfo {
356            id: model_id.to_string(),
357            object: "model".to_string(),
358            created: 0,
359            owned_by: "generic".to_string(),
360            permission: vec![ModelPermission {
361                id: "default".to_string(),
362                object: "model_permission".to_string(),
363                created: 0,
364                allow_create_engine: false,
365                allow_sampling: true,
366                allow_logprobs: false,
367                allow_search_indices: false,
368                allow_view: true,
369                allow_fine_tuning: false,
370                organization: "*".to_string(),
371                group: None,
372                is_blocking: false,
373            }],
374        })
375    }
376}