// autoagents_llm/backends/openai.rs

1//! OpenAI API client implementation for chat and completion functionality.
2//!
3//! This module provides integration with OpenAI's GPT models through their API.
4
5use crate::{
6    builder::LLMBackend,
7    chat::Tool,
8    chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
9    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10    embedding::EmbeddingProvider,
11    error::LLMError,
12    models::{ModelListRawEntry, ModelListRequest, ModelListResponse, ModelsProvider},
13    LLMProvider,
14};
15use crate::{
16    builder::LLMBuilder,
17    chat::{ChatResponse, ToolChoice},
18    FunctionCall, ToolCall,
19};
20use async_trait::async_trait;
21use chrono::{DateTime, Utc};
22use either::*;
23use futures::stream::Stream;
24use reqwest::{Client, Url};
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::sync::Arc;
28
/// Client for interacting with OpenAI's API.
///
/// Provides methods for chat and completion requests using OpenAI's models.
pub struct OpenAI {
    /// API key sent as a bearer token on every request.
    pub api_key: String,
    /// Base endpoint; request URLs are built by joining paths onto this.
    pub base_url: Url,
    /// Model identifier used for chat and embedding requests.
    pub model: String,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// System prompt inserted as the first message of every conversation.
    pub system: Option<String>,
    /// Per-request timeout in seconds (also applied to the shared client).
    pub timeout_seconds: Option<u64>,
    /// Whether non-streaming chat requests set `stream: true` in the body.
    pub stream: Option<bool>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    pub top_k: Option<u32>,
    /// Default tools offered to the model when the caller supplies none.
    pub tools: Option<Vec<Tool>>,
    /// How the model is allowed to pick among tools; only sent when tools are present.
    pub tool_choice: Option<ToolChoice>,
    /// Embedding output encoding (defaults to "float" when unset).
    pub embedding_encoding_format: Option<String>,
    /// Requested dimensionality of embedding vectors.
    pub embedding_dimensions: Option<u32>,
    /// Reasoning effort level forwarded verbatim to the API.
    pub reasoning_effort: Option<String>,
    /// JSON schema for structured output.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Voice selection (stored here; not used by the chat path in this file).
    pub voice: Option<String>,
    /// When true, `web_search_options` is attached to chat requests.
    pub enable_web_search: Option<bool>,
    /// Web-search context size ("low"/"medium"/"high" per OpenAI docs — TODO confirm).
    pub web_search_context_size: Option<String>,
    /// User-location type; only "exact" or "approximate" are forwarded.
    pub web_search_user_location_type: Option<String>,
    pub web_search_user_location_approximate_country: Option<String>,
    pub web_search_user_location_approximate_city: Option<String>,
    pub web_search_user_location_approximate_region: Option<String>,
    // Shared reqwest client, built once in `new` with the configured timeout.
    client: Client,
}
60
/// Individual message in an OpenAI chat conversation.
#[derive(Serialize, Debug)]
struct OpenAIChatMessage<'a> {
    /// Role string: "system", "user", "assistant", or "tool".
    #[allow(dead_code)]
    role: &'a str,
    /// Either a list of typed content parts (Left) or a plain string (Right);
    /// serialized untagged so both shapes match OpenAI's wire format.
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<MessageContent<'a>>, String>>,
    /// Tool calls issued by the assistant, if this is a tool-use message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIFunctionCall<'a>>>,
    /// For "tool" role messages: the id of the call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
76
/// Function name plus raw JSON argument string for a tool call.
#[derive(Serialize, Debug)]
struct OpenAIFunctionPayload<'a> {
    name: &'a str,
    /// Arguments serialized as a JSON string, passed through verbatim.
    arguments: &'a str,
}

/// A single tool call entry in an assistant message.
#[derive(Serialize, Debug)]
struct OpenAIFunctionCall<'a> {
    id: &'a str,
    /// Always "function" in this codebase (set in `chat_message_to_api_message`).
    #[serde(rename = "type")]
    content_type: &'a str,
    function: OpenAIFunctionPayload<'a>,
}
90
/// One typed content part of a chat message (text, image URL, or tool output).
/// All fields are optional; only the fields relevant to the part's type are set.
#[derive(Serialize, Debug)]
struct MessageContent<'a> {
    /// Part type tag, e.g. "text" or "image_url".
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    /// Id of the tool call this content answers (tool-result parts).
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    /// Tool output payload; serialized under the key "content".
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}

/// Individual image message in an OpenAI chat conversation.
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    url: &'a str,
}
110
/// Request payload for OpenAI's embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    /// One embedding is returned per input string.
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
120
/// Request payload for OpenAI's chat API endpoint.
#[derive(Serialize, Debug)]
struct OpenAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Always serialized (no skip) — the API defaults to false if omitted,
    /// but this struct states it explicitly.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// NOTE(review): `top_k` is not a documented OpenAI chat parameter —
    /// presumably kept for OpenAI-compatible backends; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    /// Structured-output directive built from `OpenAI::json_schema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<OpenAIResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<OpenAIWebSearchOptions>,
}
146
147impl std::fmt::Display for ToolCall {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        write!(
150            f,
151            "{{\n  \"id\": \"{}\",\n  \"type\": \"{}\",\n  \"function\": {}\n}}",
152            self.id, self.call_type, self.function
153        )
154    }
155}
156
157impl std::fmt::Display for FunctionCall {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        write!(
160            f,
161            "{{\n  \"name\": \"{}\",\n  \"arguments\": {}\n}}",
162            self.name, self.arguments
163        )
164    }
165}
166
/// Response from OpenAI's chat API endpoint.
#[derive(Deserialize, Debug)]
struct OpenAIChatResponse {
    choices: Vec<OpenAIChatChoice>,
}

/// Individual choice within an OpenAI chat API response.
#[derive(Deserialize, Debug)]
struct OpenAIChatChoice {
    message: OpenAIChatMsg,
}

/// Message content within an OpenAI chat API response.
#[derive(Deserialize, Debug)]
struct OpenAIChatMsg {
    #[allow(dead_code)]
    role: String,
    /// Text content; absent when the model only emitted tool calls.
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}
187
/// One embedding vector within an embeddings response.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

/// Response from OpenAI's embeddings endpoint; one entry per input.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}
196
/// Response from OpenAI's streaming chat API endpoint.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamResponse {
    choices: Vec<OpenAIChatStreamChoice>,
}

/// Individual choice within an OpenAI streaming chat API response.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamChoice {
    delta: OpenAIChatStreamDelta,
}

/// Delta content within an OpenAI streaming chat API response.
#[derive(Deserialize, Debug)]
struct OpenAIChatStreamDelta {
    /// Incremental text token; absent on role/metadata-only deltas.
    content: Option<String>,
}
214
/// An object specifying the format that the model must output.
/// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which ensures the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
/// Setting to `{ "type": "json_object" }` enables the older JSON mode, which ensures the message the model generates is valid JSON. Using `json_schema` is preferred for models that support it.
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}
227
/// `response_format` payload for a chat request; built from a
/// `StructuredOutputFormat` via the `From` impl below.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
    #[serde(rename = "type")]
    response_type: OpenAIResponseType,
    /// Present only when `response_type` is `JsonSchema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
235
/// `web_search_options` payload attached to chat requests when web search
/// is enabled on the client.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIWebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<UserLocation>,
    #[serde(skip_serializing_if = "Option::is_none")]
    search_context_size: Option<String>,
}

/// User location hint for web search.
#[derive(Deserialize, Debug, Serialize)]
struct UserLocation {
    /// "exact" or "approximate" (filtered upstream in `chat_with_tools`).
    #[serde(rename = "type")]
    location_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    approximate: Option<ApproximateLocation>,
}

/// Approximate location fields; unset values are sent as empty strings.
#[derive(Deserialize, Debug, Serialize)]
struct ApproximateLocation {
    country: String,
    city: String,
    region: String,
}
258
259impl From<StructuredOutputFormat> for OpenAIResponseFormat {
260    /// Modify the schema to ensure that it meets OpenAI's requirements.
261    fn from(structured_response_format: StructuredOutputFormat) -> Self {
262        // It's possible to pass a StructuredOutputJsonSchema without an actual schema.
263        // In this case, just pass the StructuredOutputJsonSchema object without modifying it.
264        match structured_response_format.schema {
265            None => OpenAIResponseFormat {
266                response_type: OpenAIResponseType::JsonSchema,
267                json_schema: Some(structured_response_format),
268            },
269            Some(mut schema) => {
270                // Although [OpenAI's specifications](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat#additionalproperties-false-must-always-be-set-in-objects) say that the "additionalProperties" field is required, my testing shows that it is not.
271                // Just to be safe, add it to the schema if it is missing.
272                schema = if schema.get("additionalProperties").is_none() {
273                    schema["additionalProperties"] = serde_json::json!(false);
274                    schema
275                } else {
276                    schema
277                };
278
279                OpenAIResponseFormat {
280                    response_type: OpenAIResponseType::JsonSchema,
281                    json_schema: Some(StructuredOutputFormat {
282                        name: structured_response_format.name,
283                        description: structured_response_format.description,
284                        schema: Some(schema),
285                        strict: structured_response_format.strict,
286                    }),
287                }
288            }
289        }
290    }
291}
292
293impl ChatResponse for OpenAIChatResponse {
294    fn text(&self) -> Option<String> {
295        self.choices.first().and_then(|c| c.message.content.clone())
296    }
297
298    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
299        self.choices
300            .first()
301            .and_then(|c| c.message.tool_calls.clone())
302    }
303}
304
305impl std::fmt::Display for OpenAIChatResponse {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        match (
308            &self.choices.first().unwrap().message.content,
309            &self.choices.first().unwrap().message.tool_calls,
310        ) {
311            (Some(content), Some(tool_calls)) => {
312                for tool_call in tool_calls {
313                    write!(f, "{}", tool_call)?;
314                }
315                write!(f, "{}", content)
316            }
317            (Some(content), None) => write!(f, "{}", content),
318            (None, Some(tool_calls)) => {
319                for tool_call in tool_calls {
320                    write!(f, "{}", tool_call)?;
321                }
322                Ok(())
323            }
324            (None, None) => write!(f, ""),
325        }
326    }
327}
328
329impl OpenAI {
330    /// Creates a new OpenAI client with the specified configuration.
331    ///
332    /// # Arguments
333    ///
334    /// * `api_key` - OpenAI API key
335    /// * `model` - Model to use (defaults to "gpt-3.5-turbo")
336    /// * `max_tokens` - Maximum tokens to generate
337    /// * `temperature` - Sampling temperature
338    /// * `timeout_seconds` - Request timeout in seconds
339    /// * `system` - System prompt
340    /// * `stream` - Whether to stream responses
341    /// * `top_p` - Top-p sampling parameter
342    /// * `top_k` - Top-k sampling parameter
343    /// * `embedding_encoding_format` - Format for embedding outputs
344    /// * `embedding_dimensions` - Dimensions for embedding vectors
345    /// * `tools` - Function tools that the model can use
346    /// * `tool_choice` - Determines how the model uses tools
347    /// * `reasoning_effort` - Reasoning effort level
348    /// * `json_schema` - JSON schema for structured output
349    #[allow(clippy::too_many_arguments)]
350    pub fn new(
351        api_key: impl Into<String>,
352        base_url: Option<String>,
353        model: Option<String>,
354        max_tokens: Option<u32>,
355        temperature: Option<f32>,
356        timeout_seconds: Option<u64>,
357        system: Option<String>,
358        stream: Option<bool>,
359        top_p: Option<f32>,
360        top_k: Option<u32>,
361        embedding_encoding_format: Option<String>,
362        embedding_dimensions: Option<u32>,
363        tools: Option<Vec<Tool>>,
364        tool_choice: Option<ToolChoice>,
365        reasoning_effort: Option<String>,
366        json_schema: Option<StructuredOutputFormat>,
367        voice: Option<String>,
368        enable_web_search: Option<bool>,
369        web_search_context_size: Option<String>,
370        web_search_user_location_type: Option<String>,
371        web_search_user_location_approximate_country: Option<String>,
372        web_search_user_location_approximate_city: Option<String>,
373        web_search_user_location_approximate_region: Option<String>,
374    ) -> Self {
375        let mut builder = Client::builder();
376        if let Some(sec) = timeout_seconds {
377            builder = builder.timeout(std::time::Duration::from_secs(sec));
378        }
379        Self {
380            api_key: api_key.into(),
381            base_url: Url::parse(
382                &base_url.unwrap_or_else(|| "https://api.openai.com/v1/".to_owned()),
383            )
384            .expect("Failed to prase base Url"),
385            model: model.unwrap_or("gpt-3.5-turbo".to_string()),
386            max_tokens,
387            temperature,
388            system,
389            timeout_seconds,
390            stream,
391            top_p,
392            top_k,
393            tools,
394            tool_choice,
395            embedding_encoding_format,
396            embedding_dimensions,
397            client: builder.build().expect("Failed to build reqwest Client"),
398            reasoning_effort,
399            json_schema,
400            voice,
401            enable_web_search,
402            web_search_context_size,
403            web_search_user_location_type,
404            web_search_user_location_approximate_country,
405            web_search_user_location_approximate_city,
406            web_search_user_location_approximate_region,
407        }
408    }
409}
410
#[async_trait]
impl ChatProvider for OpenAI {
    /// Sends a chat request to OpenAI's API.
    ///
    /// Builds the message list (expanding tool results, prepending the system
    /// prompt), attaches tools / structured-output / web-search options, and
    /// decodes the JSON response.
    ///
    /// # Arguments
    ///
    /// * `messages` - Slice of chat messages representing the conversation
    /// * `tools` - Optional slice of tools to use in the chat; overrides the
    ///   client-level `self.tools` when provided
    ///
    /// # Returns
    ///
    /// The model's response text or an error
    ///
    /// # Errors
    ///
    /// * `LLMError::AuthError` when the API key is empty
    /// * `LLMError::HttpError` when the request URL cannot be built
    /// * `LLMError::ResponseFormatError` on non-success status or undecodable body
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        // Clone the messages to have an owned mutable vector.
        let messages = messages.to_vec();

        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                // Each tool result becomes its own "tool"-role message, linked
                // to the originating call via `tool_call_id`.
                for result in results {
                    openai_msgs.push(
                        // Clone strings to own them
                        OpenAIChatMessage {
                            role: "tool",
                            tool_call_id: Some(result.id.clone()),
                            tool_calls: None,
                            content: Some(Right(result.function.arguments.clone())),
                        },
                    );
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        // The configured system prompt must be the first message.
        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let response_format: Option<OpenAIResponseFormat> =
            self.json_schema.clone().map(|s| s.into());

        // Per-call tools take precedence over the client-level default.
        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());

        // `tool_choice` is only meaningful when tools are actually sent.
        let request_tool_choice = if request_tools.is_some() {
            self.tool_choice.clone()
        } else {
            None
        };

        let web_search_options = if self.enable_web_search.unwrap_or(false) {
            // Only "exact" / "approximate" location types are forwarded;
            // anything else is silently dropped.
            let loc_type_opt = self
                .web_search_user_location_type
                .as_ref()
                .filter(|t| matches!(t.as_str(), "exact" | "approximate"));

            let country = self.web_search_user_location_approximate_country.as_ref();
            let city = self.web_search_user_location_approximate_city.as_ref();
            let region = self.web_search_user_location_approximate_region.as_ref();

            // Build the approximate block only when at least one field is set;
            // unset fields are sent as empty strings.
            let approximate = if [country, city, region].iter().any(|v| v.is_some()) {
                Some(ApproximateLocation {
                    country: country.cloned().unwrap_or_default(),
                    city: city.cloned().unwrap_or_default(),
                    region: region.cloned().unwrap_or_default(),
                })
            } else {
                None
            };

            // Without a valid location type, no user_location is sent at all.
            let user_location = loc_type_opt.map(|loc_type| UserLocation {
                location_type: loc_type.clone(),
                approximate,
            });

            Some(OpenAIWebSearchOptions {
                search_context_size: self.web_search_context_size.clone(),
                user_location,
            })
        } else {
            None
        };

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: self.stream.unwrap_or(false),
            top_p: self.top_p,
            top_k: self.top_k,
            tools: request_tools,
            tool_choice: request_tool_choice,
            reasoning_effort: self.reasoning_effort.clone(),
            response_format,
            web_search_options,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        // Serialize the payload for tracing only when the level is enabled,
        // to avoid the serialization cost otherwise.
        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("OpenAI request payload: {}", json);
            }
        }

        // Per-request timeout on top of the client-level one set in `new`.
        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        log::debug!("OpenAI HTTP status: {}", response.status());

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Parse the successful response
        let resp_text = response.text().await?;
        let json_resp: Result<OpenAIChatResponse, serde_json::Error> =
            serde_json::from_str(&resp_text);

        match json_resp {
            Ok(response) => Ok(Box::new(response)),
            Err(e) => Err(LLMError::ResponseFormatError {
                message: format!("Failed to decode OpenAI API response: {}", e),
                raw_response: resp_text,
            }),
        }
    }

    /// Sends a chat request without per-call tools (client-level tools still apply).
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }

    /// Sends a streaming chat request to OpenAI's API.
    ///
    /// Builds the request like `chat_with_tools` (but always with
    /// `stream: true`, no response_format and no web-search options) and
    /// returns an SSE-backed token stream.
    ///
    /// # Arguments
    ///
    /// * `messages` - Slice of chat messages representing the conversation
    ///
    /// # Returns
    ///
    /// A stream of text tokens or an error
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".to_string()));
        }

        let messages = messages.to_vec();
        let mut openai_msgs: Vec<OpenAIChatMessage> = vec![];

        // Same message expansion as `chat_with_tools`.
        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                for result in results {
                    openai_msgs.push(OpenAIChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                openai_msgs.push(chat_message_to_api_message(msg))
            }
        }

        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                OpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![MessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let body = OpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: true,
            top_p: self.top_p,
            top_k: self.top_k,
            tools: self.tools.clone(),
            tool_choice: self.tool_choice.clone(),
            reasoning_effort: self.reasoning_effort.clone(),
            response_format: None,
            web_search_options: None,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        // Hand the raw SSE byte stream to the shared parser.
        Ok(crate::chat::create_sse_stream(response, parse_sse_chunk))
    }
}
673
// Create an owned OpenAIChatMessage that doesn't borrow from any temporary variables
//
// WARNING: this function intentionally uses `Box::leak` to manufacture
// `'static` references for image URLs and tool-call fields. Every call with
// an ImageURL or ToolUse message leaks those strings for the process
// lifetime — acceptable per the inline comments, but worth knowing for
// long-running, high-volume callers.
fn chat_message_to_api_message(chat_msg: ChatMessage) -> OpenAIChatMessage<'static> {
    // For other message types, create an owned OpenAIChatMessage
    OpenAIChatMessage {
        role: match chat_msg.role {
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
        },
        tool_call_id: None,
        content: match &chat_msg.message_type {
            MessageType::Text => Some(Right(chat_msg.content.clone())),
            // Image case is handled separately above
            // NOTE(review): no such separate handling is visible in this
            // function or its callers here — a raw Image message would panic.
            // Confirm callers never pass MessageType::Image.
            MessageType::Image(_) => unreachable!(),
            MessageType::Pdf(_) => unimplemented!(),
            MessageType::ImageURL(url) => {
                // Clone the URL to create an owned version
                let owned_url = url.clone();
                // Leak the string to get a 'static reference
                let url_str = Box::leak(owned_url.into_boxed_str());
                Some(Left(vec![MessageContent {
                    message_type: Some("image_url"),
                    text: None,
                    image_url: Some(ImageUrlContent { url: url_str }),
                    tool_output: None,
                    tool_call_id: None,
                }]))
            }
            // Tool messages carry their payload in `tool_calls` / a separate
            // "tool" message, not in `content`.
            MessageType::ToolUse(_) => None,
            MessageType::ToolResult(_) => None,
        },
        tool_calls: match &chat_msg.message_type {
            MessageType::ToolUse(calls) => {
                let owned_calls: Vec<OpenAIFunctionCall<'static>> = calls
                    .iter()
                    .map(|c| {
                        let owned_id = c.id.clone();
                        let owned_name = c.function.name.clone();
                        let owned_args = c.function.arguments.clone();

                        // Need to leak these strings to create 'static references
                        // This is a deliberate choice to solve the lifetime issue
                        // The small memory leak is acceptable in this context
                        let id_str = Box::leak(owned_id.into_boxed_str());
                        let name_str = Box::leak(owned_name.into_boxed_str());
                        let args_str = Box::leak(owned_args.into_boxed_str());

                        OpenAIFunctionCall {
                            id: id_str,
                            content_type: "function",
                            function: OpenAIFunctionPayload {
                                name: name_str,
                                arguments: args_str,
                            },
                        }
                    })
                    .collect();
                Some(owned_calls)
            }
            _ => None,
        },
    }
}
736
737#[async_trait]
738impl CompletionProvider for OpenAI {
739    /// Sends a completion request to OpenAI's API.
740    ///
741    /// Currently not implemented.
742    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
743        Ok(CompletionResponse {
744            text: "OpenAI completion not implemented.".into(),
745        })
746    }
747}
748
#[cfg(feature = "openai")]
#[async_trait]
impl EmbeddingProvider for OpenAI {
    /// Requests embedding vectors for `input` from the `embeddings` endpoint.
    ///
    /// Returns one `Vec<f32>` per input string, in order.
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing OpenAI API key".into()));
        }

        // Fall back to "float" encoding when none was configured.
        let encoding_format = Some(
            self.embedding_encoding_format
                .clone()
                .unwrap_or_else(|| "float".to_string()),
        );

        let request_body = OpenAIEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format,
            dimensions: self.embedding_dimensions,
        };

        let endpoint = self
            .base_url
            .join("embeddings")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let response = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?;

        let parsed: OpenAIEmbeddingResponse = response.json().await?;

        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
    }
}
789
/// One model entry from the `models` listing endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIModelEntry {
    pub id: String,
    /// Creation time as a Unix timestamp in seconds, when provided.
    pub created: Option<u64>,
    /// All remaining response fields, captured unmodified.
    #[serde(flatten)]
    pub extra: Value,
}
797
798impl ModelListRawEntry for OpenAIModelEntry {
799    fn get_id(&self) -> String {
800        self.id.clone()
801    }
802
803    fn get_created_at(&self) -> DateTime<Utc> {
804        self.created
805            .map(|t| chrono::DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
806            .unwrap_or_default()
807    }
808
809    fn get_raw(&self) -> Value {
810        self.extra.clone()
811    }
812}
813
/// Response payload of the `models` listing endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIModelListResponse {
    pub data: Vec<OpenAIModelEntry>,
}
818
819impl ModelListResponse for OpenAIModelListResponse {
820    fn get_models(&self) -> Vec<String> {
821        self.data.iter().map(|e| e.id.clone()).collect()
822    }
823
824    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
825        self.data
826            .iter()
827            .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
828            .collect()
829    }
830
831    fn get_backend(&self) -> LLMBackend {
832        LLMBackend::OpenAI
833    }
834}
835
836#[async_trait]
837impl ModelsProvider for OpenAI {
838    async fn list_models(
839        &self,
840        _request: Option<&ModelListRequest>,
841    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
842        let url = self
843            .base_url
844            .join("models")
845            .map_err(|e| LLMError::HttpError(e.to_string()))?;
846
847        let resp = self
848            .client
849            .get(url)
850            .bearer_auth(&self.api_key)
851            .send()
852            .await?
853            .error_for_status()?;
854
855        let result = resp.json::<OpenAIModelListResponse>().await?;
856
857        Ok(Box::new(result))
858    }
859}
860
861impl LLMProvider for OpenAI {
862    fn tools(&self) -> Option<&[Tool]> {
863        self.tools.as_deref()
864    }
865}
866
867/// Parses a Server-Sent Events (SSE) chunk from OpenAI's streaming API.
868///
869/// # Arguments
870///
871/// * `chunk` - The raw SSE chunk text
872///
873/// # Returns
874///
875/// * `Ok(Some(String))` - Content token if found
876/// * `Ok(None)` - If chunk should be skipped (e.g., ping, done signal)
877/// * `Err(LLMError)` - If parsing fails
878fn parse_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
879    let mut collected_content = String::new();
880
881    for line in chunk.lines() {
882        let line = line.trim();
883
884        if let Some(data) = line.strip_prefix("data: ") {
885            if data == "[DONE]" {
886                if collected_content.is_empty() {
887                    return Ok(None);
888                } else {
889                    return Ok(Some(collected_content));
890                }
891            }
892
893            match serde_json::from_str::<OpenAIChatStreamResponse>(data) {
894                Ok(response) => {
895                    if let Some(choice) = response.choices.first() {
896                        if let Some(content) = &choice.delta.content {
897                            collected_content.push_str(content);
898                        }
899                    }
900                }
901                Err(_) => continue,
902            }
903        }
904    }
905
906    if collected_content.is_empty() {
907        Ok(None)
908    } else {
909        Ok(Some(collected_content))
910    }
911}
912
impl LLMBuilder<OpenAI> {
    /// Set the voice option forwarded to the constructed [`OpenAI`] client.
    pub fn voice(mut self, voice: impl Into<String>) -> Self {
        self.voice = Some(voice.into());
        self
    }

    /// Enable or disable OpenAI web search for the constructed client.
    pub fn openai_enable_web_search(mut self, enable: bool) -> Self {
        self.openai_enable_web_search = Some(enable);
        self
    }

    /// Set the web search context size forwarded to the client.
    pub fn openai_web_search_context_size(mut self, context_size: impl Into<String>) -> Self {
        self.openai_web_search_context_size = Some(context_size.into());
        self
    }

    /// Set the web search user location type forwarded to the client.
    pub fn openai_web_search_user_location_type(
        mut self,
        location_type: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_type = Some(location_type.into());
        self
    }

    /// Set the web search user's approximate country forwarded to the client.
    pub fn openai_web_search_user_location_approximate_country(
        mut self,
        country: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_country = Some(country.into());
        self
    }

    /// Set the web search user's approximate city forwarded to the client.
    pub fn openai_web_search_user_location_approximate_city(
        mut self,
        city: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_city = Some(city.into());
        self
    }

    /// Set the web search user's approximate region forwarded to the client.
    pub fn openai_web_search_user_location_approximate_region(
        mut self,
        region: impl Into<String>,
    ) -> Self {
        self.openai_web_search_user_location_approximate_region = Some(region.into());
        self
    }

    /// Consume the builder and construct an [`OpenAI`] provider.
    ///
    /// # Errors
    ///
    /// Returns an error from tool-configuration validation, or
    /// [`LLMError::InvalidRequest`] when no API key was supplied.
    pub fn build(self) -> Result<Arc<dyn LLMProvider>, LLMError> {
        // Tool/tool-choice consistency is checked before the API key, so a
        // misconfigured tool set is reported even if the key is also missing.
        let (tools, tool_choice) = self.validate_tool_config()?;
        let key = self.api_key.ok_or_else(|| {
            LLMError::InvalidRequest("No API key provided for OpenAI".to_string())
        })?;
        // NOTE(review): `OpenAI::new` takes all of these options positionally;
        // the argument order below must match the constructor's parameter
        // order exactly — take care when adding or reordering fields.
        let openai = OpenAI::new(
            key,
            self.base_url,
            self.model,
            self.max_tokens,
            self.temperature,
            self.timeout_seconds,
            self.system,
            self.stream,
            self.top_p,
            self.top_k,
            self.embedding_encoding_format,
            self.embedding_dimensions,
            tools,
            tool_choice,
            self.reasoning_effort,
            self.json_schema,
            self.voice,
            self.openai_enable_web_search,
            self.openai_web_search_context_size,
            self.openai_web_search_user_location_type,
            self.openai_web_search_user_location_approximate_country,
            self.openai_web_search_user_location_approximate_city,
            self.openai_web_search_user_location_approximate_region,
        );

        Ok(Arc::new(openai))
    }
}