use crate::core::Protocol;
use crate::types::{ChatRequest, ChatResponse, Message, Role, Choice, Usage};
use crate::error::LlmConnectorError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

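/// Protocol adapter for OpenAI-compatible Chat Completions APIs.
///
/// Holds the API key and translates between the crate's neutral
/// `ChatRequest`/`ChatResponse` types and the OpenAI wire format.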
#[derive(Clone, Debug)]
pub struct OpenAIProtocol {
    api_key: String,
}

impl OpenAIProtocol {
    pub fn new(api_key: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
        }
    }

    pub fn api_key(&self) -> &str {
        &self.api_key
    }
}

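// OpenAI implementation of the crate's provider-neutral Protocol trait.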
#[async_trait]
impl Protocol for OpenAIProtocol {
    type Request = OpenAIRequest;
    type Response = OpenAIResponse;

    fn name(&self) -> &str {
        "openai"
    }

    fn chat_endpoint(&self, base_url: &str) -> String {
        format!("{}/v1/chat/completions", base_url.trim_end_matches('/'))
    }

    fn models_endpoint(&self, base_url: &str) -> Option<String> {
        Some(format!("{}/v1/models", base_url.trim_end_matches('/')))
    }

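    // Translate the neutral ChatRequest into the OpenAI JSON body. A single
    // text block is flattened to a plain string (the shorthand OpenAI accepts);
    // anything else is serialized as a content-part array.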
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
        let messages = request.messages.iter()
            .map(|msg| {
                let content = if msg.content.len() == 1 && msg.content[0].is_text() {
                    // Guarded by the is_text() check above, so this unwrap cannot panic.
                    serde_json::json!(msg.content[0].as_text().unwrap())
                } else {
                    // Propagate serialization failures instead of panicking;
                    // build_request already returns a Result.
                    serde_json::to_value(&msg.content).map_err(|e| {
                        LlmConnectorError::InvalidRequest(format!(
                            "Failed to serialize message content: {}", e
                        ))
                    })?
                };

                Ok(OpenAIMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "tool".to_string(),
                    },
                    content,
                    tool_calls: msg.tool_calls.as_ref().map(|calls| {
                        calls.iter().map(|call| {
                            serde_json::json!({
                                "id": call.id,
                                "type": call.call_type,
                                "function": {
                                    "name": call.function.name,
                                    "arguments": call.function.arguments,
                                }
                            })
                        }).collect()
                    }),
                    tool_call_id: msg.tool_call_id.clone(),
                    name: msg.name.clone(),
                })
            })
            .collect::<Result<Vec<_>, LlmConnectorError>>()?;

        let tools = request.tools.as_ref().map(|tools| {
            tools.iter().map(|tool| {
                serde_json::json!({
                    "type": tool.tool_type,
                    "function": {
                        "name": tool.function.name,
                        "description": tool.function.description,
                        "parameters": tool.function.parameters,
                    }
                })
            }).collect()
        });

        let tool_choice = request.tool_choice.as_ref().map(|choice| {
            serde_json::to_value(choice).unwrap_or(serde_json::json!("auto"))
        });

        Ok(OpenAIRequest {
            model: request.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            frequency_penalty: request.frequency_penalty,
            presence_penalty: request.presence_penalty,
            stream: request.stream,
            tools,
            tool_choice,
        })
    }

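    // Parse the raw OpenAI JSON into the neutral ChatResponse. Both string and
    // content-part-array message bodies are normalized into message blocks, and
    // the top-level content/reasoning_content fields mirror the first choice.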
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
        let openai_response: OpenAIResponse = serde_json::from_str(response)
            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse OpenAI response: {}", e)))?;

        if openai_response.choices.is_empty() {
            return Err(LlmConnectorError::ParseError("No choices in response".to_string()));
        }

        let choices: Vec<Choice> = openai_response.choices.into_iter()
            .map(|choice| {
                let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
                    calls.iter().filter_map(|call| {
                        Some(crate::types::ToolCall {
                            id: call.get("id")?.as_str()?.to_string(),
                            call_type: call.get("type")?.as_str()?.to_string(),
                            function: crate::types::FunctionCall {
                                name: call.get("function")?.get("name")?.as_str()?.to_string(),
                                arguments: call.get("function")?.get("arguments")?.as_str()?.to_string(),
                            },
                        })
                    }).collect()
                });

                let content = if let Some(content_value) = &choice.message.content {
                    if let Some(text) = content_value.as_str() {
                        vec![crate::types::MessageBlock::text(text)]
                    } else if let Some(array) = content_value.as_array() {
                        serde_json::from_value(serde_json::Value::Array(array.clone()))
                            .unwrap_or_else(|_| vec![])
                    } else {
                        vec![]
                    }
                } else {
                    vec![]
                };

                Choice {
                    index: choice.index,
                    message: Message {
                        role: Role::Assistant,
                        content,
                        name: None,
                        tool_calls,
                        tool_call_id: None,
                        reasoning_content: choice.message.reasoning_content.clone(),
                        reasoning: None,
                        thought: None,
                        thinking: None,
                    },
                    finish_reason: choice.finish_reason,
                    logprobs: None,
                }
            })
            .collect();

        let usage = openai_response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
            completion_tokens_details: None,
            prompt_cache_hit_tokens: None,
            prompt_cache_miss_tokens: None,
            prompt_tokens_details: None,
        });

        let content = choices.first()
            .map(|choice| choice.message.content_as_text())
            .unwrap_or_default();

        let reasoning_content = choices.first()
            .and_then(|c| c.message.reasoning_content.clone());

        Ok(ChatResponse {
            id: openai_response.id,
            object: openai_response.object,
            created: openai_response.created,
            model: openai_response.model,
            choices,
            content,
            reasoning_content,
            usage,
            system_fingerprint: openai_response.system_fingerprint,
        })
    }

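    // The models endpoint returns {"data": [{"id": "..."}]}; only the ids are kept.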
    fn parse_models(&self, response: &str) -> Result<Vec<String>, LlmConnectorError> {
        let models_response: OpenAIModelsResponse = serde_json::from_str(response)
            .map_err(|e| LlmConnectorError::ParseError(format!("Failed to parse models response: {}", e)))?;

        Ok(models_response.data.into_iter().map(|model| model.id).collect())
    }

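    // Map HTTP status codes onto typed errors, extracting OpenAI's
    // {"error": {"message": ...}} payload when the body contains one.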
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
        let error_info = serde_json::from_str::<serde_json::Value>(body)
            .ok()
            .and_then(|v| v.get("error").cloned())
            .unwrap_or_else(|| serde_json::json!({"message": body}));

        let message = error_info.get("message")
            .and_then(|m| m.as_str())
            .unwrap_or("Unknown OpenAI error");

        match status {
            400 => LlmConnectorError::InvalidRequest(format!("OpenAI: {}", message)),
            401 => LlmConnectorError::AuthenticationError(format!("OpenAI: {}", message)),
            403 => LlmConnectorError::PermissionError(format!("OpenAI: {}", message)),
            429 => LlmConnectorError::RateLimitError(format!("OpenAI: {}", message)),
            500..=599 => LlmConnectorError::ServerError(format!("OpenAI: {}", message)),
            _ => LlmConnectorError::ApiError(format!("OpenAI HTTP {}: {}", status, message)),
        }
    }

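    // Standard OpenAI bearer-token authentication headers.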
    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![
            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
            ("Content-Type".to_string(), "application/json".to_string()),
        ]
    }
}

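/// Request body for POST /v1/chat/completions. Optional fields are omitted
/// from the JSON entirely rather than serialized as null.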
#[derive(Serialize, Debug)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
}

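/// A single message in the request body. `content` is either a JSON string
/// or an array of content parts, so it is kept as a raw value.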
#[derive(Serialize, Debug)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

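/// Raw response body from the chat completions endpoint.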
#[derive(Deserialize, Debug)]
pub struct OpenAIResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<OpenAIChoice>,
    pub usage: Option<OpenAIUsage>,
    pub system_fingerprint: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIChoice {
    pub index: u32,
    pub message: OpenAIResponseMessage,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIResponseMessage {
    // Option fields deserialize to None when absent, so no serde attributes
    // are needed here (skip_serializing_if only affects Serialize impls).
    pub content: Option<serde_json::Value>,
    pub tool_calls: Option<Vec<serde_json::Value>>,
    pub reasoning_content: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIModelsResponse {
    pub data: Vec<OpenAIModel>,
}

#[derive(Deserialize, Debug)]
pub struct OpenAIModel {
    pub id: String,
}
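
// A minimal test sketch exercising the pure parts of this adapter: endpoint
// construction, model-list parsing, and error mapping. It assumes only the
// LlmConnectorError variants already used above; build_request/parse_response
// round-trips would additionally need the crate's ChatRequest and MessageBlock
// constructors, so they are not sketched here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn endpoints_strip_trailing_slash() {
        let protocol = OpenAIProtocol::new("sk-test");
        assert_eq!(
            protocol.chat_endpoint("https://api.openai.com/"),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            protocol.models_endpoint("https://api.openai.com").as_deref(),
            Some("https://api.openai.com/v1/models")
        );
    }

    #[test]
    fn parse_models_extracts_ids() {
        let protocol = OpenAIProtocol::new("sk-test");
        let body = r#"{"data": [{"id": "gpt-4o"}, {"id": "gpt-4o-mini"}]}"#;
        assert_eq!(protocol.parse_models(body).unwrap(), vec!["gpt-4o", "gpt-4o-mini"]);
    }

    #[test]
    fn map_error_classifies_status_codes() {
        let protocol = OpenAIProtocol::new("sk-test");
        let body = r#"{"error": {"message": "bad key"}}"#;
        assert!(matches!(
            protocol.map_error(401, body),
            LlmConnectorError::AuthenticationError(_)
        ));
        assert!(matches!(
            protocol.map_error(503, body),
            LlmConnectorError::ServerError(_)
        ));
    }
}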