1use crate::apis::api_client::{
2 ApiClient, CompletionOptions, Message, ToolCall, ToolDefinition, ToolResult,
3};
4use crate::app::logger::{format_log_with_color, LogLevel};
5use crate::errors::AppError;
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
9use reqwest::Client as ReqwestClient;
10use serde::{Deserialize, Serialize};
11use serde_json::{self, json, Value};
12use std::env;
13
/// Wire-format description of a callable function for the OpenAI `tools`
/// request field: name, human-readable description, and a JSON Schema object
/// describing the accepted arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    // JSON Schema for the arguments, passed through verbatim from the caller.
    parameters: Value,
}
21
/// Envelope around a function definition in the OpenAI tools format.
/// `tool_type` serializes as the JSON key `type` (set to "function" by
/// `convert_tool_definitions`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}
28
/// A function invocation emitted by the model: the function name plus its
/// arguments as a raw JSON-encoded string (parsed later when converting to
/// the app's `ToolCall`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    // JSON text, not a parsed Value — OpenAI returns arguments as a string.
    arguments: String,
}
34
/// One tool call inside an assistant message; `tool_type` serializes as the
/// JSON key `type`. The `id` is what a later "tool" role message must echo
/// back via `tool_call_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionCall,
}
42
/// A single chat message in OpenAI wire format.
///
/// `content` is optional because assistant messages that only carry tool
/// calls may have no text. The two tool-related fields are omitted from the
/// serialized JSON when `None`: `tool_calls` appears on assistant messages,
/// `tool_call_id` on "tool" role replies answering a specific call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIMessage {
    role: String,
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
52
/// Request body for the chat-completions endpoint. Every optional field is
/// skipped during serialization when `None` so the emitted JSON contains only
/// what the caller actually set.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    // Serialized as a bare string ("auto" / "required"), not an object.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    // Raw JSON, e.g. {"type": "json_object"} for JSON mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Value>,
}
70
/// One candidate completion from the response; this client only ever reads
/// the first choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIResponseChoice {
    index: usize,
    message: OpenAIMessage,
    // Why generation stopped, as reported by the API (kept as a raw string).
    finish_reason: String,
}
77
/// Top-level chat-completions response body. Token-usage accounting is kept
/// as untyped JSON since this client never inspects it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<OpenAIResponseChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Value>,
}
88
/// Client for the OpenAI chat-completions API.
///
/// The reqwest client is constructed with the `Authorization` and
/// `Content-Type` headers preinstalled as defaults, so per-request code only
/// supplies the JSON body.
pub struct OpenAIClient {
    client: ReqwestClient,
    // Model identifier sent with every request (defaults to "gpt-4o").
    model: String,
    // Full endpoint URL, not just the host.
    api_base: String,
}
94
95impl OpenAIClient {
96 pub fn new(model: Option<String>) -> Result<Self> {
97 let api_key =
99 env::var("OPENAI_API_KEY").context("OPENAI_API_KEY environment variable not set")?;
100
101 Self::with_api_key(api_key, model)
102 }
103
104 pub fn with_api_key(api_key: String, model: Option<String>) -> Result<Self> {
105 let mut headers = HeaderMap::new();
107 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
108 headers.insert(
109 AUTHORIZATION,
110 HeaderValue::from_str(&format!("Bearer {}", api_key))?,
111 );
112
113 let client = ReqwestClient::builder().default_headers(headers).build()?;
114
115 let model = model.unwrap_or_else(|| "gpt-4o".to_string());
117
118 Ok(Self {
119 client,
120 model,
121 api_base: "https://api.openai.com/v1/chat/completions".to_string(),
122 })
123 }
124
125 fn convert_messages(&self, messages: Vec<Message>) -> Vec<OpenAIMessage> {
126 messages
127 .into_iter()
128 .map(|msg| {
129 OpenAIMessage {
131 role: msg.role,
132 content: Some(msg.content),
133 tool_calls: None,
134 tool_call_id: None,
135 }
136 })
137 .collect()
138 }
139
140 fn convert_tool_definitions(&self, tools: Vec<ToolDefinition>) -> Vec<OpenAITool> {
141 tools
142 .into_iter()
143 .map(|tool| OpenAITool {
144 tool_type: "function".to_string(),
145 function: OpenAIFunction {
146 name: tool.name,
147 description: tool.description,
148 parameters: tool.parameters,
149 },
150 })
151 .collect()
152 }
153}
154
155#[async_trait]
156impl ApiClient for OpenAIClient {
    /// Sends a plain chat completion request and returns the assistant's text.
    ///
    /// Tools are never attached here. When `options.json_schema` is set, the
    /// request asks for OpenAI's generic `{"type": "json_object"}` output; the
    /// schema itself is not forwarded.
    ///
    /// # Errors
    /// `AppError::NetworkError` for transport/HTTP failures, `AppError::Other`
    /// for unparseable response bodies, and `AppError::LLMError` when the
    /// response contains no message content.
    async fn complete(&self, messages: Vec<Message>, options: CompletionOptions) -> Result<String> {
        let openai_messages = self.convert_messages(messages);

        let mut request = OpenAIRequest {
            model: self.model.clone(),
            messages: openai_messages,
            max_tokens: options.max_tokens,
            temperature: options.temperature,
            top_p: options.top_p,
            tools: None,
            tool_choice: None,
            response_format: None,
        };

        // JSON mode: only the presence of a schema matters; the schema value
        // itself is intentionally unused (hence the underscore binding).
        if let Some(_json_schema) = &options.json_schema {
            request.response_format = Some(json!({
                "type": "json_object"
            }));
        }

        eprintln!(
            "{}",
            format_log_with_color(
                LogLevel::Debug,
                &format!("Sending request to OpenAI API with model: {}", self.model)
            )
        );

        let response = self
            .client
            .post(&self.api_base)
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                let error_msg = format!("Failed to send request to OpenAI: {}", e);
                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
                AppError::NetworkError(error_msg)
            })?;

        // Surface non-2xx statuses together with the response body text.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(AppError::NetworkError(format!(
                "OpenAI API error: {} - {}",
                status, error_text
            ))
            .into());
        }

        // Read the raw body first so a parse failure can be logged with context.
        let response_text = response.text().await.map_err(|e| {
            let error_msg = format!("Failed to get response text: {}", e);
            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
            AppError::NetworkError(error_msg)
        })?;

        eprintln!(
            "{}",
            format_log_with_color(
                LogLevel::Debug,
                &format!(
                    "OpenAI API response received: {} bytes",
                    response_text.len()
                )
            )
        );

        let openai_response: OpenAIResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                let error_msg = format!("Failed to parse OpenAI response: {}", e);
                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
                AppError::Other(error_msg)
            })?;

        // Only the first choice is consulted; requests never ask for n > 1.
        if let Some(first_choice) = openai_response.choices.first() {
            if let Some(content) = &first_choice.message.content {
                return Ok(content.clone());
            }
        }

        let error_msg = "No content in OpenAI response".to_string();
        eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
        Err(AppError::LLMError(error_msg).into())
    }
247
    /// Sends a chat completion request that may include tool definitions and
    /// prior tool results, returning the assistant text plus any tool calls
    /// the model wants executed (None when it made no calls).
    ///
    /// # Errors
    /// Same mapping as `complete`: `AppError::NetworkError` for transport/HTTP
    /// failures and `AppError::Other` for unparseable response bodies.
    async fn complete_with_tools(
        &self,
        messages: Vec<Message>,
        options: CompletionOptions,
        tool_results: Option<Vec<ToolResult>>,
    ) -> Result<(String, Option<Vec<ToolCall>>)> {
        let mut openai_messages = self.convert_messages(messages);

        // Ids of assistant tool calls that still need a matching "tool" role
        // reply — OpenAI rejects conversations with dangling tool_calls.
        //
        // NOTE(review): convert_messages always sets tool_calls to None, so
        // this scan currently finds nothing and the tool_results below are
        // never appended. Presumably the Message-to-OpenAIMessage conversion
        // was meant to carry over assistant tool calls — confirm against the
        // Message type and its callers.
        let mut pending_tool_calls = Vec::new();

        for msg in &openai_messages {
            if msg.role == "assistant" && msg.tool_calls.is_some() {
                if let Some(tool_calls) = &msg.tool_calls {
                    for call in tool_calls {
                        pending_tool_calls.push(call.id.clone());
                    }
                }
            }
        }

        // Drop ids that already have an answering "tool" message.
        for msg in &openai_messages {
            if msg.role == "tool" && msg.tool_call_id.is_some() {
                if let Some(tool_call_id) = &msg.tool_call_id {
                    pending_tool_calls.retain(|id| id != tool_call_id);
                }
            }
        }

        // Append a "tool" reply for every still-pending call: the real output
        // when a matching result was supplied, otherwise a generic placeholder
        // so the conversation stays well-formed.
        if let Some(results) = &tool_results {
            let result_map: std::collections::HashMap<String, String> = results
                .iter()
                .map(|r| (r.tool_call_id.clone(), r.output.clone()))
                .collect();

            for tool_id in &pending_tool_calls {
                if let Some(output) = result_map.get(tool_id) {
                    openai_messages.push(OpenAIMessage {
                        role: "tool".to_string(),
                        content: Some(output.clone()),
                        tool_calls: None,
                        tool_call_id: Some(tool_id.clone()),
                    });
                } else {
                    openai_messages.push(OpenAIMessage {
                        role: "tool".to_string(),
                        content: Some(
                            "Tool execution completed without detailed results.".to_string(),
                        ),
                        tool_calls: None,
                        tool_call_id: Some(tool_id.clone()),
                    });
                }
            }
        } else if !pending_tool_calls.is_empty() {
            for tool_id in &pending_tool_calls {
                openai_messages.push(OpenAIMessage {
                    role: "tool".to_string(),
                    content: Some("Tool execution completed without detailed results.".to_string()),
                    tool_calls: None,
                    tool_call_id: Some(tool_id.clone()),
                });
            }
        }

        let mut request = OpenAIRequest {
            model: self.model.clone(),
            messages: openai_messages,
            max_tokens: options.max_tokens,
            temperature: options.temperature,
            top_p: options.top_p,
            tools: None,
            tool_choice: None,
            response_format: None,
        };

        // JSON mode: append a "respond as JSON" hint to the last user message
        // when no message mentions "json" (the API's json_object mode expects
        // the prompt to reference JSON).
        if let Some(_json_schema) = &options.json_schema {
            request.response_format = Some(json!({
                "type": "json_object"
            }));

            let has_json_keyword = request.messages.iter().any(|msg| {
                msg.content
                    .as_ref()
                    .is_some_and(|content| content.to_lowercase().contains("json"))
            });

            if !has_json_keyword && !request.messages.is_empty() {
                if let Some(last_user_msg) = request
                    .messages
                    .iter_mut()
                    .rev()
                    .find(|msg| msg.role == "user")
                {
                    if let Some(content) = &mut last_user_msg.content {
                        *content = format!("{} (Please provide the response as JSON)", content);
                    }
                }
            }
        }

        if let Some(tools) = options.tools {
            let converted_tools = self.convert_tool_definitions(tools);
            request.tools = Some(converted_tools);

            // "required" forces the model to call a tool; "auto" lets it decide.
            request.tool_choice = if options.require_tool_use {
                Some("required".to_string())
            } else {
                Some("auto".to_string())
            };
        }

        eprintln!(
            "{}",
            format_log_with_color(
                LogLevel::Debug,
                &format!("Sending request to OpenAI API with model: {}", self.model)
            )
        );

        let response = self
            .client
            .post(&self.api_base)
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                let error_msg = format!("Failed to send request to OpenAI: {}", e);
                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
                AppError::NetworkError(error_msg)
            })?;

        // Surface non-2xx statuses together with the response body text.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(AppError::NetworkError(format!(
                "OpenAI API error: {} - {}",
                status, error_text
            ))
            .into());
        }

        // Read the raw body first so a parse failure can be logged with context.
        let response_text = response.text().await.map_err(|e| {
            let error_msg = format!("Failed to get response text: {}", e);
            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
            AppError::NetworkError(error_msg)
        })?;

        eprintln!(
            "{}",
            format_log_with_color(
                LogLevel::Debug,
                &format!(
                    "OpenAI API response received: {} bytes",
                    response_text.len()
                )
            )
        );

        let openai_response: OpenAIResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                let error_msg = format!("Failed to parse OpenAI response: {}", e);
                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
                AppError::Other(error_msg)
            })?;

        // Only the first choice is consulted; requests never ask for n > 1.
        if let Some(first_choice) = openai_response.choices.first() {
            let content = first_choice.message.content.clone().unwrap_or_default();

            // Convert OpenAI tool calls into the app's ToolCall type;
            // unparseable argument JSON degrades to an empty object rather
            // than failing the whole response.
            let tool_calls = if let Some(openai_tool_calls) = &first_choice.message.tool_calls {
                if openai_tool_calls.is_empty() {
                    None
                } else {
                    let calls = openai_tool_calls
                        .iter()
                        .map(|call| {
                            let arguments_result =
                                serde_json::from_str::<Value>(&call.function.arguments);
                            let arguments = match arguments_result {
                                Ok(args) => args,
                                Err(_) => json!({}),
                            };

                            ToolCall {
                                id: Some(call.id.clone()), name: call.function.name.clone(),
                                arguments,
                            }
                        })
                        .collect::<Vec<_>>();

                    if calls.is_empty() {
                        None
                    } else {
                        Some(calls)
                    }
                }
            } else {
                None
            };

            return Ok((content, tool_calls));
        }

        // No choices at all: degrade to empty content instead of erroring.
        Ok((String::new(), None))
    }
476}