llm_connector/protocols/
openai.rs

1//! OpenAI协议实现 - V2架构
2//!
3//! 这个模块实现了标准的OpenAI API协议规范。
4
5use crate::core::Protocol;
6use crate::types::{ChatRequest, ChatResponse, Message, Role, Choice, Usage};
7use crate::error::LlmConnectorError;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10
/// OpenAI protocol implementation.
///
/// Holds the API key used to authenticate requests. `Debug` is implemented by
/// hand so that formatting this value (e.g. in logs) never reveals the key.
#[derive(Clone)]
pub struct OpenAIProtocol {
    api_key: String,
}

impl std::fmt::Debug for OpenAIProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived Debug would print the raw secret; redact it instead.
        f.debug_struct("OpenAIProtocol")
            .field("api_key", &"<redacted>")
            .finish()
    }
}

impl OpenAIProtocol {
    /// Creates a new OpenAI protocol instance that authenticates with `api_key`.
    pub fn new(api_key: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
        }
    }

    /// Returns the configured API key.
    pub fn api_key(&self) -> &str {
        &self.api_key
    }
}
30
31#[async_trait]
32impl Protocol for OpenAIProtocol {
33    type Request = OpenAIRequest;
34    type Response = OpenAIResponse;
35    
36    fn name(&self) -> &str {
37        "openai"
38    }
39    
40    fn chat_endpoint(&self, base_url: &str) -> String {
41        format!("{}/v1/chat/completions", base_url.trim_end_matches('/'))
42    }
43    
44    fn models_endpoint(&self, base_url: &str) -> Option<String> {
45        Some(format!("{}/v1/models", base_url.trim_end_matches('/')))
46    }
47    
48    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
49        let messages = request.messages.iter()
50            .map(|msg| {
51                // 转换 MessageBlock 到 OpenAI 格式
52                let content = if msg.content.len() == 1 && msg.content[0].is_text() {
53                    // 纯文本:使用字符串格式
54                    serde_json::json!(msg.content[0].as_text().unwrap())
55                } else {
56                    // 多模态:使用数组格式
57                    serde_json::to_value(&msg.content).unwrap()
58                };
59
60                OpenAIMessage {
61                    role: match msg.role {
62                        Role::User => "user".to_string(),
63                        Role::Assistant => "assistant".to_string(),
64                        Role::System => "system".to_string(),
65                        Role::Tool => "tool".to_string(),
66                    },
67                    content,
68                    tool_calls: msg.tool_calls.as_ref().map(|calls| {
69                        calls.iter().map(|call| {
70                            serde_json::json!({
71                                "id": call.id,
72                                "type": call.call_type,
73                                "function": {
74                                    "name": call.function.name,
75                                    "arguments": call.function.arguments,
76                                }
77                            })
78                        }).collect()
79                    }),
80                    tool_call_id: msg.tool_call_id.clone(),
81                    name: msg.name.clone(),
82                }
83            })
84            .collect();
85
86        // 转换 tools
87        let tools = request.tools.as_ref().map(|tools| {
88            tools.iter().map(|tool| {
89                serde_json::json!({
90                    "type": tool.tool_type,
91                    "function": {
92                        "name": tool.function.name,
93                        "description": tool.function.description,
94                        "parameters": tool.function.parameters,
95                    }
96                })
97            }).collect()
98        });
99
100        // 转换 tool_choice
101        let tool_choice = request.tool_choice.as_ref().map(|choice| {
102            serde_json::to_value(choice).unwrap_or(serde_json::json!("auto"))
103        });
104
105        Ok(OpenAIRequest {
106            model: request.model.clone(),
107            messages,
108            temperature: request.temperature,
109            max_tokens: request.max_tokens,
110            top_p: request.top_p,
111            frequency_penalty: request.frequency_penalty,
112            presence_penalty: request.presence_penalty,
113            stream: request.stream,
114            tools,
115            tool_choice,
116        })
117    }
118    
119    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
120        let openai_response: OpenAIResponse = serde_json::from_str(response)
121            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse OpenAI response: {}", e)))?;
122
123        if openai_response.choices.is_empty() {
124            return Err(LlmConnectorError::ParseError("No choices in response".to_string()));
125        }
126
127        let choices: Vec<Choice> = openai_response.choices.into_iter()
128            .map(|choice| {
129                // 转换 tool_calls
130                let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
131                    calls.iter().filter_map(|call| {
132                        Some(crate::types::ToolCall {
133                            id: call.get("id")?.as_str()?.to_string(),
134                            call_type: call.get("type")?.as_str()?.to_string(),
135                            function: crate::types::FunctionCall {
136                                name: call.get("function")?.get("name")?.as_str()?.to_string(),
137                                arguments: call.get("function")?.get("arguments")?.as_str()?.to_string(),
138                            },
139                        })
140                    }).collect()
141                });
142
143                // 转换 content 到 MessageBlock
144                let content = if let Some(content_value) = &choice.message.content {
145                    if let Some(text) = content_value.as_str() {
146                        // 纯文本
147                        vec![crate::types::MessageBlock::text(text)]
148                    } else if let Some(array) = content_value.as_array() {
149                        // 多模态数组
150                        serde_json::from_value(serde_json::Value::Array(array.clone()))
151                            .unwrap_or_else(|_| vec![])
152                    } else {
153                        vec![]
154                    }
155                } else {
156                    vec![]
157                };
158
159                Choice {
160                    index: choice.index,
161                    message: Message {
162                        role: Role::Assistant,
163                        content,
164                        name: None,
165                        tool_calls,
166                        tool_call_id: None,
167                        reasoning_content: choice.message.reasoning_content.clone(),
168                        reasoning: None,
169                        thought: None,
170                        thinking: None,
171                    },
172                    finish_reason: choice.finish_reason,
173                    logprobs: None,
174                }
175            })
176            .collect();
177
178        let usage = openai_response.usage.map(|u| Usage {
179            prompt_tokens: u.prompt_tokens,
180            completion_tokens: u.completion_tokens,
181            total_tokens: u.total_tokens,
182            completion_tokens_details: None,
183            prompt_cache_hit_tokens: None,
184            prompt_cache_miss_tokens: None,
185            prompt_tokens_details: None,
186        });
187
188        // 提取第一个选择的内容作为便利字段(纯文本)
189        let content = choices.first()
190            .map(|choice| choice.message.content_as_text())
191            .unwrap_or_default();
192
193        // 提取第一个choice的reasoning_content
194        let reasoning_content = choices.first()
195            .and_then(|c| c.message.reasoning_content.clone());
196
197        Ok(ChatResponse {
198            id: openai_response.id,
199            object: openai_response.object,
200            created: openai_response.created,
201            model: openai_response.model,
202            choices,
203            content,
204            reasoning_content,
205            usage,
206            system_fingerprint: openai_response.system_fingerprint,
207        })
208    }
209    
210    fn parse_models(&self, response: &str) -> Result<Vec<String>, LlmConnectorError> {
211        let models_response: OpenAIModelsResponse = serde_json::from_str(response)
212            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse models response: {}", e)))?;
213
214        Ok(models_response.data.into_iter().map(|model| model.id).collect())
215    }
216    
217    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
218        let error_info = serde_json::from_str::<serde_json::Value>(body)
219            .ok()
220            .and_then(|v| v.get("error").cloned())
221            .unwrap_or_else(|| serde_json::json!({"message": body}));
222            
223        let message = error_info.get("message")
224            .and_then(|m| m.as_str())
225            .unwrap_or("Unknown OpenAI error");
226
227        match status {
228            400 => LlmConnectorError::InvalidRequest(format!("OpenAI: {}", message)),
229            401 => LlmConnectorError::AuthenticationError(format!("OpenAI: {}", message)),
230            403 => LlmConnectorError::PermissionError(format!("OpenAI: {}", message)),
231            429 => LlmConnectorError::RateLimitError(format!("OpenAI: {}", message)),
232            500..=599 => LlmConnectorError::ServerError(format!("OpenAI: {}", message)),
233            _ => LlmConnectorError::ApiError(format!("OpenAI HTTP {}: {}", status, message)),
234        }
235    }
236    
237    fn auth_headers(&self) -> Vec<(String, String)> {
238        vec![
239            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
240            ("Content-Type".to_string(), "application/json".to_string()),
241        ]
242    }
243}
244
245// OpenAI请求类型
246#[derive(Serialize, Debug)]
247pub struct OpenAIRequest {
248    pub model: String,
249    pub messages: Vec<OpenAIMessage>,
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub temperature: Option<f32>,
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub max_tokens: Option<u32>,
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub top_p: Option<f32>,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub frequency_penalty: Option<f32>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub presence_penalty: Option<f32>,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub stream: Option<bool>,
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub tools: Option<Vec<serde_json::Value>>,
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub tool_choice: Option<serde_json::Value>,
266}
267
268#[derive(Serialize, Debug)]
269pub struct OpenAIMessage {
270    pub role: String,
271    pub content: serde_json::Value,  // 支持 String 或 Array
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub tool_calls: Option<Vec<serde_json::Value>>,
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub tool_call_id: Option<String>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub name: Option<String>,
278}
279
280// OpenAI响应类型
281#[derive(Deserialize, Debug)]
282pub struct OpenAIResponse {
283    pub id: String,
284    pub object: String,
285    pub created: u64,
286    pub model: String,
287    pub choices: Vec<OpenAIChoice>,
288    pub usage: Option<OpenAIUsage>,
289    pub system_fingerprint: Option<String>,
290}
291
292#[derive(Deserialize, Debug)]
293pub struct OpenAIChoice {
294    pub index: u32,
295    pub message: OpenAIResponseMessage,
296    pub finish_reason: Option<String>,
297}
298
299#[derive(Deserialize, Debug)]
300pub struct OpenAIResponseMessage {
301    pub content: Option<serde_json::Value>,  // 支持 String 或 Array
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub tool_calls: Option<Vec<serde_json::Value>>,
304    #[serde(skip_serializing_if = "Option::is_none")]
305    pub reasoning_content: Option<String>,
306}
307
308#[derive(Deserialize, Debug)]
309pub struct OpenAIUsage {
310    pub prompt_tokens: u32,
311    pub completion_tokens: u32,
312    pub total_tokens: u32,
313}
314
315// 模型列表响应
316#[derive(Deserialize, Debug)]
317pub struct OpenAIModelsResponse {
318    pub data: Vec<OpenAIModel>,
319}
320
321#[derive(Deserialize, Debug)]
322pub struct OpenAIModel {
323    pub id: String,
324}