oli_server/apis/openai.rs

1use crate::apis::api_client::{
2    ApiClient, CompletionOptions, Message, ToolCall, ToolDefinition, ToolResult,
3};
4use crate::app::logger::{format_log_with_color, LogLevel};
5use crate::errors::AppError;
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
9use reqwest::Client as ReqwestClient;
10use serde::{Deserialize, Serialize};
11use serde_json::{self, json, Value};
12use std::env;
13
14// OpenAI API Types
15#[derive(Debug, Clone, Serialize, Deserialize)]
16struct OpenAIFunction {
17    name: String,
18    description: String,
19    parameters: Value,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23struct OpenAITool {
24    #[serde(rename = "type")]
25    tool_type: String,
26    function: OpenAIFunction,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct OpenAIFunctionCall {
31    name: String,
32    arguments: String,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36struct OpenAIToolCall {
37    id: String,
38    #[serde(rename = "type")]
39    tool_type: String,
40    function: OpenAIFunctionCall,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44struct OpenAIMessage {
45    role: String,
46    content: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    tool_calls: Option<Vec<OpenAIToolCall>>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    tool_call_id: Option<String>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54struct OpenAIRequest {
55    model: String,
56    messages: Vec<OpenAIMessage>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    max_tokens: Option<u32>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    temperature: Option<f32>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    top_p: Option<f32>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    tools: Option<Vec<OpenAITool>>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    tool_choice: Option<String>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    response_format: Option<Value>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72struct OpenAIResponseChoice {
73    index: usize,
74    message: OpenAIMessage,
75    finish_reason: String,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct OpenAIResponse {
80    id: String,
81    object: String,
82    created: u64,
83    model: String,
84    choices: Vec<OpenAIResponseChoice>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    usage: Option<Value>,
87}
88
89pub struct OpenAIClient {
90    client: ReqwestClient,
91    model: String,
92    api_base: String,
93}
94
95impl OpenAIClient {
96    pub fn new(model: Option<String>) -> Result<Self> {
97        // Try to get API key from environment
98        let api_key =
99            env::var("OPENAI_API_KEY").context("OPENAI_API_KEY environment variable not set")?;
100
101        Self::with_api_key(api_key, model)
102    }
103
104    pub fn with_api_key(api_key: String, model: Option<String>) -> Result<Self> {
105        // Create new client with appropriate headers
106        let mut headers = HeaderMap::new();
107        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
108        headers.insert(
109            AUTHORIZATION,
110            HeaderValue::from_str(&format!("Bearer {}", api_key))?,
111        );
112
113        let client = ReqwestClient::builder().default_headers(headers).build()?;
114
115        // Default to GPT-4o as the latest model with tooling capabilities
116        let model = model.unwrap_or_else(|| "gpt-4o".to_string());
117
118        Ok(Self {
119            client,
120            model,
121            api_base: "https://api.openai.com/v1/chat/completions".to_string(),
122        })
123    }
124
125    fn convert_messages(&self, messages: Vec<Message>) -> Vec<OpenAIMessage> {
126        messages
127            .into_iter()
128            .map(|msg| {
129                // Convert standard messages
130                OpenAIMessage {
131                    role: msg.role,
132                    content: Some(msg.content),
133                    tool_calls: None,
134                    tool_call_id: None,
135                }
136            })
137            .collect()
138    }
139
140    fn convert_tool_definitions(&self, tools: Vec<ToolDefinition>) -> Vec<OpenAITool> {
141        tools
142            .into_iter()
143            .map(|tool| OpenAITool {
144                tool_type: "function".to_string(),
145                function: OpenAIFunction {
146                    name: tool.name,
147                    description: tool.description,
148                    parameters: tool.parameters,
149                },
150            })
151            .collect()
152    }
153}
154
155#[async_trait]
156impl ApiClient for OpenAIClient {
157    async fn complete(&self, messages: Vec<Message>, options: CompletionOptions) -> Result<String> {
158        let openai_messages = self.convert_messages(messages);
159
160        let mut request = OpenAIRequest {
161            model: self.model.clone(),
162            messages: openai_messages,
163            max_tokens: options.max_tokens,
164            temperature: options.temperature,
165            top_p: options.top_p,
166            tools: None,
167            tool_choice: None,
168            response_format: None,
169        };
170
171        // Add structured output format if specified in options
172        if let Some(_json_schema) = &options.json_schema {
173            request.response_format = Some(json!({
174                "type": "json_object"
175            }));
176        }
177
178        eprintln!(
179            "{}",
180            format_log_with_color(
181                LogLevel::Debug,
182                &format!("Sending request to OpenAI API with model: {}", self.model)
183            )
184        );
185
186        let response = self
187            .client
188            .post(&self.api_base)
189            .json(&request)
190            .send()
191            .await
192            .map_err(|e| {
193                let error_msg = format!("Failed to send request to OpenAI: {}", e);
194                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
195                AppError::NetworkError(error_msg)
196            })?;
197
198        if !response.status().is_success() {
199            let status = response.status();
200            let error_text = response
201                .text()
202                .await
203                .unwrap_or_else(|_| "Unknown error".to_string());
204            return Err(AppError::NetworkError(format!(
205                "OpenAI API error: {} - {}",
206                status, error_text
207            ))
208            .into());
209        }
210
211        // Parse response
212        let response_text = response.text().await.map_err(|e| {
213            let error_msg = format!("Failed to get response text: {}", e);
214            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
215            AppError::NetworkError(error_msg)
216        })?;
217
218        eprintln!(
219            "{}",
220            format_log_with_color(
221                LogLevel::Debug,
222                &format!(
223                    "OpenAI API response received: {} bytes",
224                    response_text.len()
225                )
226            )
227        );
228
229        let openai_response: OpenAIResponse =
230            serde_json::from_str(&response_text).map_err(|e| {
231                let error_msg = format!("Failed to parse OpenAI response: {}", e);
232                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
233                AppError::Other(error_msg)
234            })?;
235
236        // Extract content from the first choice
237        if let Some(first_choice) = openai_response.choices.first() {
238            if let Some(content) = &first_choice.message.content {
239                return Ok(content.clone());
240            }
241        }
242
243        let error_msg = "No content in OpenAI response".to_string();
244        eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
245        Err(AppError::LLMError(error_msg).into())
246    }
247
248    async fn complete_with_tools(
249        &self,
250        messages: Vec<Message>,
251        options: CompletionOptions,
252        tool_results: Option<Vec<ToolResult>>,
253    ) -> Result<(String, Option<Vec<ToolCall>>)> {
254        // Convert messages to OpenAI format
255        let mut openai_messages = self.convert_messages(messages);
256
257        // Track tool calls that need responses
258        let mut pending_tool_calls = Vec::new();
259
260        // First pass: identify all tool calls that need responses
261        for msg in &openai_messages {
262            if msg.role == "assistant" && msg.tool_calls.is_some() {
263                if let Some(tool_calls) = &msg.tool_calls {
264                    for call in tool_calls {
265                        pending_tool_calls.push(call.id.clone());
266                    }
267                }
268            }
269        }
270
271        // Second pass: remove tool call IDs that already have responses
272        for msg in &openai_messages {
273            if msg.role == "tool" && msg.tool_call_id.is_some() {
274                if let Some(tool_call_id) = &msg.tool_call_id {
275                    pending_tool_calls.retain(|id| id != tool_call_id);
276                }
277            }
278        }
279
280        // Add tool results for any pending tool calls
281        if let Some(results) = &tool_results {
282            let result_map: std::collections::HashMap<String, String> = results
283                .iter()
284                .map(|r| (r.tool_call_id.clone(), r.output.clone()))
285                .collect();
286
287            // Add responses for any pending tool calls
288            for tool_id in &pending_tool_calls {
289                if let Some(output) = result_map.get(tool_id) {
290                    openai_messages.push(OpenAIMessage {
291                        role: "tool".to_string(),
292                        content: Some(output.clone()),
293                        tool_calls: None,
294                        tool_call_id: Some(tool_id.clone()),
295                    });
296                } else {
297                    // For any tool call without a provided result, add a default response
298                    // This is crucial for OpenAI - every tool call must have a response
299                    openai_messages.push(OpenAIMessage {
300                        role: "tool".to_string(),
301                        content: Some(
302                            "Tool execution completed without detailed results.".to_string(),
303                        ),
304                        tool_calls: None,
305                        tool_call_id: Some(tool_id.clone()),
306                    });
307                }
308            }
309        } else if !pending_tool_calls.is_empty() {
310            // If we have pending tool calls but no results were provided,
311            // we need to add default responses for all pending tool calls
312            for tool_id in &pending_tool_calls {
313                openai_messages.push(OpenAIMessage {
314                    role: "tool".to_string(),
315                    content: Some("Tool execution completed without detailed results.".to_string()),
316                    tool_calls: None,
317                    tool_call_id: Some(tool_id.clone()),
318                });
319            }
320        }
321
322        let mut request = OpenAIRequest {
323            model: self.model.clone(),
324            messages: openai_messages,
325            max_tokens: options.max_tokens,
326            temperature: options.temperature,
327            top_p: options.top_p,
328            tools: None,
329            tool_choice: None,
330            response_format: None,
331        };
332
333        // Add structured output format if specified in options
334        if let Some(_json_schema) = &options.json_schema {
335            request.response_format = Some(json!({
336                "type": "json_object"
337            }));
338
339            // Ensure at least one message contains the word "json" when using json_object response format
340            let has_json_keyword = request.messages.iter().any(|msg| {
341                msg.content
342                    .as_ref()
343                    .is_some_and(|content| content.to_lowercase().contains("json"))
344            });
345
346            if !has_json_keyword && !request.messages.is_empty() {
347                // Add "json" to the user's last message if it doesn't already contain it
348                if let Some(last_user_msg) = request
349                    .messages
350                    .iter_mut()
351                    .rev()
352                    .find(|msg| msg.role == "user")
353                {
354                    if let Some(content) = &mut last_user_msg.content {
355                        *content = format!("{} (Please provide the response as JSON)", content);
356                    }
357                }
358            }
359        }
360
361        // Add tools if they exist
362        if let Some(tools) = options.tools {
363            let converted_tools = self.convert_tool_definitions(tools);
364            request.tools = Some(converted_tools);
365
366            // Set tool_choice based on option
367            request.tool_choice = if options.require_tool_use {
368                Some("required".to_string())
369            } else {
370                Some("auto".to_string())
371            };
372        }
373
374        eprintln!(
375            "{}",
376            format_log_with_color(
377                LogLevel::Debug,
378                &format!("Sending request to OpenAI API with model: {}", self.model)
379            )
380        );
381
382        let response = self
383            .client
384            .post(&self.api_base)
385            .json(&request)
386            .send()
387            .await
388            .map_err(|e| {
389                let error_msg = format!("Failed to send request to OpenAI: {}", e);
390                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
391                AppError::NetworkError(error_msg)
392            })?;
393
394        if !response.status().is_success() {
395            let status = response.status();
396            let error_text = response
397                .text()
398                .await
399                .unwrap_or_else(|_| "Unknown error".to_string());
400            return Err(AppError::NetworkError(format!(
401                "OpenAI API error: {} - {}",
402                status, error_text
403            ))
404            .into());
405        }
406
407        // Parse response
408        let response_text = response.text().await.map_err(|e| {
409            let error_msg = format!("Failed to get response text: {}", e);
410            eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
411            AppError::NetworkError(error_msg)
412        })?;
413
414        eprintln!(
415            "{}",
416            format_log_with_color(
417                LogLevel::Debug,
418                &format!(
419                    "OpenAI API response received: {} bytes",
420                    response_text.len()
421                )
422            )
423        );
424
425        let openai_response: OpenAIResponse =
426            serde_json::from_str(&response_text).map_err(|e| {
427                let error_msg = format!("Failed to parse OpenAI response: {}", e);
428                eprintln!("{}", format_log_with_color(LogLevel::Error, &error_msg));
429                AppError::Other(error_msg)
430            })?;
431
432        // Extract content and tool calls from the first choice
433        if let Some(first_choice) = openai_response.choices.first() {
434            let content = first_choice.message.content.clone().unwrap_or_default();
435
436            // Extract tool calls if present
437            let tool_calls = if let Some(openai_tool_calls) = &first_choice.message.tool_calls {
438                if openai_tool_calls.is_empty() {
439                    None
440                } else {
441                    let calls = openai_tool_calls
442                        .iter()
443                        .map(|call| {
444                            // Parse arguments as JSON
445                            let arguments_result =
446                                serde_json::from_str::<Value>(&call.function.arguments);
447                            let arguments = match arguments_result {
448                                Ok(args) => args,
449                                Err(_) => json!({}),
450                            };
451
452                            // Create a tool call with OpenAI's required format
453                            ToolCall {
454                                id: Some(call.id.clone()), // Important for tool results later
455                                name: call.function.name.clone(),
456                                arguments,
457                            }
458                        })
459                        .collect::<Vec<_>>();
460
461                    if calls.is_empty() {
462                        None
463                    } else {
464                        Some(calls)
465                    }
466                }
467            } else {
468                None
469            };
470
471            return Ok((content, tool_calls));
472        }
473
474        Ok((String::new(), None))
475    }
476}