//! OpenAI provider implementation (`language_barrier_core/provider/openai.rs`).

1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, OpenAi};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Configuration for the OpenAI provider.
///
/// A `Default` implementation reads `OPENAI_API_KEY` and
/// `OPENAI_ORGANIZATION` from the process environment.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key, sent as the `Bearer` token in the `Authorization` header.
    pub api_key: String,
    /// Base URL for the API; the `/chat/completions` path is appended when
    /// building requests.
    pub base_url: String,
    /// Organization ID (optional); sent as the `OpenAI-Organization` header.
    pub organization: Option<String>,
}
20
21impl Default for OpenAIConfig {
22    fn default() -> Self {
23        Self {
24            api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
25            base_url: "https://api.openai.com/v1".to_string(),
26            organization: env::var("OPENAI_ORGANIZATION").ok(),
27        }
28    }
29}
30
/// Implementation of the OpenAI provider.
///
/// Holds only configuration; the HTTP request/response handling lives in the
/// `HTTPProvider<OpenAi>` implementation.
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    /// Configuration (API key, base URL, optional organization) used for every request
    config: OpenAIConfig,
}
37
38impl OpenAIProvider {
39    /// Creates a new OpenAIProvider with default configuration
40    ///
41    /// This method will use the OPENAI_API_KEY environment variable for authentication.
42    ///
43    /// # Examples
44    ///
45    /// ```
46    /// use language_barrier_core::provider::openai::OpenAIProvider;
47    ///
48    /// let provider = OpenAIProvider::new();
49    /// ```
50    #[instrument(level = "debug")]
51    pub fn new() -> Self {
52        info!("Creating new OpenAIProvider with default configuration");
53        let config = OpenAIConfig::default();
54        debug!("API key set: {}", !config.api_key.is_empty());
55        debug!("Base URL: {}", config.base_url);
56        debug!("Organization set: {}", config.organization.is_some());
57
58        Self { config }
59    }
60
61    /// Creates a new OpenAIProvider with custom configuration
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use language_barrier_core::provider::openai::{OpenAIProvider, OpenAIConfig};
67    ///
68    /// let config = OpenAIConfig {
69    ///     api_key: "your-api-key".to_string(),
70    ///     base_url: "https://api.openai.com/v1".to_string(),
71    ///     organization: None,
72    /// };
73    ///
74    /// let provider = OpenAIProvider::with_config(config);
75    /// ```
76    #[instrument(skip(config), level = "debug")]
77    pub fn with_config(config: OpenAIConfig) -> Self {
78        info!("Creating new OpenAIProvider with custom configuration");
79        debug!("API key set: {}", !config.api_key.is_empty());
80        debug!("Base URL: {}", config.base_url);
81        debug!("Organization set: {}", config.organization.is_some());
82
83        Self { config }
84    }
85}
86
87impl Default for OpenAIProvider {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
impl HTTPProvider<OpenAi> for OpenAIProvider {
    /// Builds the complete HTTP request (URL, headers, JSON body) for the
    /// `/chat/completions` endpoint from the given model and chat state.
    ///
    /// Authentication is a `Bearer` token built from the configured API key;
    /// the optional organization is sent as the `OpenAI-Organization` header.
    fn accept(&self, model: OpenAi, chat: &Chat) -> Result<Request> {
        info!("Creating request for OpenAI model: {:?}", model);
        debug!("Messages in chat history: {}", chat.history.len());

        let url_str = format!("{}/chat/completions", self.config.base_url);
        debug!("Parsing URL: {}", url_str);
        let url = match Url::parse(&url_str) {
            Ok(url) => {
                debug!("URL parsed successfully: {}", url);
                url
            }
            Err(e) => {
                error!("Failed to parse URL '{}': {}", url_str, e);
                return Err(e.into());
            }
        };

        let mut request = Request::new(Method::POST, url);
        debug!("Created request: {} {}", request.method(), request.url());

        // Set headers
        debug!("Setting request headers");

        // API key as bearer token. Parsing fails (and the request is aborted)
        // when the key contains characters not valid in an HTTP header value.
        let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
            Ok(header) => header,
            Err(e) => {
                error!("Invalid API key format: {}", e);
                return Err(Error::Authentication("Invalid API key format".into()));
            }
        };

        let content_type_header = match "application/json".parse() {
            Ok(header) => header,
            Err(e) => {
                error!("Failed to set content type: {}", e);
                return Err(Error::Other("Failed to set content type".into()));
            }
        };

        request.headers_mut().insert("Authorization", auth_header);
        request
            .headers_mut()
            .insert("Content-Type", content_type_header);

        // Add organization header if present. Unlike the API key, an
        // unparseable organization value is logged and skipped rather than
        // failing the whole request.
        if let Some(org) = &self.config.organization {
            match org.parse() {
                Ok(header) => {
                    request.headers_mut().insert("OpenAI-Organization", header);
                    debug!("Added organization header");
                }
                Err(e) => {
                    warn!("Failed to set organization header: {}", e);
                    // Continue without organization header
                }
            }
        }

        trace!("Request headers set: {:#?}", request.headers());

        // Create the request payload
        debug!("Creating request payload");
        let payload = match self.create_request_payload(model, chat) {
            Ok(payload) => {
                debug!("Request payload created successfully");
                trace!("Model: {}", payload.model);
                trace!("Max tokens: {:?}", payload.max_tokens);
                trace!("Number of messages: {}", payload.messages.len());
                payload
            }
            Err(e) => {
                error!("Failed to create request payload: {}", e);
                return Err(e);
            }
        };

        // Set the request body
        debug!("Serializing request payload");
        let body_bytes = match serde_json::to_vec(&payload) {
            Ok(bytes) => {
                debug!("Payload serialized successfully ({} bytes)", bytes.len());
                bytes
            }
            Err(e) => {
                error!("Failed to serialize payload: {}", e);
                return Err(Error::Serialization(e));
            }
        };

        *request.body_mut() = Some(body_bytes.into());
        info!("Request created successfully");

        Ok(request)
    }

    /// Parses a raw response body from the OpenAI API into a [`Message`].
    ///
    /// The body is first probed as an error envelope (`{"error": {...}}`).
    /// Because the envelope's `error` field is optional, nearly any JSON
    /// object deserializes into it; the error path is taken only when an
    /// `error` object is actually present. Otherwise the body is
    /// deserialized as a successful chat completion.
    fn parse(&self, raw_response_text: String) -> Result<Message> {
        info!("Parsing response from OpenAI API");
        trace!("Raw response: {}", raw_response_text);

        // First try to parse as an error response
        if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&raw_response_text)
        {
            if let Some(error) = error_response.error {
                error!("OpenAI API returned an error: {}", error.message);
                return Err(Error::ProviderUnavailable(error.message));
            }
        }

        // If not an error, parse as a successful response
        debug!("Deserializing response JSON");
        let openai_response = match serde_json::from_str::<OpenAIResponse>(&raw_response_text) {
            Ok(response) => {
                debug!("Response deserialized successfully");
                debug!("Response model: {}", response.model);
                if !response.choices.is_empty() {
                    debug!("Number of choices: {}", response.choices.len());
                    debug!(
                        "First choice finish reason: {:?}",
                        response.choices[0].finish_reason
                    );
                }
                if let Some(usage) = &response.usage {
                    debug!(
                        "Token usage - prompt: {}, completion: {}, total: {}",
                        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
                    );
                }
                response
            }
            Err(e) => {
                error!("Failed to deserialize response: {}", e);
                error!("Raw response: {}", raw_response_text);
                return Err(Error::Serialization(e));
            }
        };

        // Convert to our message format using the From implementation
        debug!("Converting OpenAI response to Message");
        let message = Message::from(&openai_response);

        info!("Response parsed successfully");
        trace!("Response message processed");

        Ok(message)
    }
}
241
/// Trait mapping a model variant to the OpenAI-specific model ID string.
pub trait OpenAIModelInfo {
    /// Returns the wire-format model identifier (used as the `model` field in requests).
    fn openai_model_id(&self) -> String;
}
246
247impl OpenAIProvider {
248    /// Creates a request payload from a Chat object
249    ///
250    /// This method converts the Chat's messages and settings into an OpenAI-specific
251    /// format for the API request.
252    #[instrument(skip(self, chat), level = "debug")]
253    fn create_request_payload(&self, model: OpenAi, chat: &Chat) -> Result<OpenAIRequest> {
254        info!("Creating request payload for chat with OpenAI model");
255        debug!("System prompt length: {}", chat.system_prompt.len());
256        debug!("Messages in history: {}", chat.history.len());
257        debug!("Max output tokens: {}", chat.max_output_tokens);
258
259        let model_id = model.openai_model_id();
260        debug!("Using model ID: {}", model_id);
261
262        // Convert all messages including system prompt
263        debug!("Converting messages to OpenAI format");
264        let mut messages: Vec<OpenAIMessage> = Vec::new();
265
266        // Add system prompt if present
267        if !chat.system_prompt.is_empty() {
268            debug!("Adding system prompt");
269            messages.push(OpenAIMessage {
270                role: "system".to_string(),
271                content: Some(chat.system_prompt.clone()),
272                function_call: None,
273                name: None,
274                tool_calls: None,
275                tool_call_id: None,
276            });
277        }
278
279        // Add conversation history
280        for msg in &chat.history {
281            debug!("Converting message with role: {}", msg.role_str());
282            messages.push(OpenAIMessage::from(msg));
283        }
284
285        // OpenAI requires that every message with role "tool" directly follows the
286        // corresponding assistant message that contains a matching `tool_call`.
287        // When callers accidentally provide the messages in a different order
288        // (for example: user → *tool* → assistant) the API rejects the request
289        // with a 400:
290        //   "messages with role 'tool' must be a response to a preceeding message \
291        //    with 'tool_calls'".
292        //
293        // To make the client more robust we perform a best-effort re-ordering pass
294        // so that each tool response is immediately preceded by its initiating
295        // assistant message.  If we *cannot* find the corresponding assistant
296        // message we treat this as a programming error and bail out early with
297        // `Error::Other`.
298
299        // Note: The ordering of assistant and tool messages is left untouched.
300        // Responsibility for providing messages in a valid order lies with the
301        // caller.  We simply forward the history as‐is.
302
303        debug!("Converted {} messages for the request", messages.len());
304
305        // Add tools if present
306        let tools = chat
307            .tools
308            .as_ref()
309            .map(|tools| tools.iter().map(OpenAITool::from).collect());
310
311        // Create the tool choice setting
312        let tool_choice = if let Some(choice) = &chat.tool_choice {
313            // Use the explicitly configured choice
314            match choice {
315                crate::tool::ToolChoice::Auto => Some(serde_json::json!("auto")),
316                // OpenAI uses "required" for what we call "Any"
317                crate::tool::ToolChoice::Any => Some(serde_json::json!("required")),
318                crate::tool::ToolChoice::None => Some(serde_json::json!("none")),
319                crate::tool::ToolChoice::Specific(name) => {
320                    // For specific tool, we need to create an object with type and function properties
321                    Some(serde_json::json!({
322                        "type": "function",
323                        "function": {
324                            "name": name
325                        }
326                    }))
327                }
328            }
329        } else if tools.is_some() {
330            // Default to auto if tools are present but no choice specified
331            Some(serde_json::json!("auto"))
332        } else {
333            None
334        };
335
336        // Create the request
337        debug!("Creating OpenAIRequest");
338
339        // Check if this is an O-series model (starts with "o-")
340        let is_o_series = model_id.starts_with("o");
341
342        let request = OpenAIRequest {
343            model: model_id,
344            messages,
345            temperature: None,
346            top_p: None,
347            n: None,
348            // For O-series models, use max_completion_tokens instead of max_tokens
349            max_tokens: if is_o_series {
350                None
351            } else {
352                Some(chat.max_output_tokens)
353            },
354            max_completion_tokens: if is_o_series {
355                Some(chat.max_output_tokens)
356            } else {
357                None
358            },
359            presence_penalty: None,
360            frequency_penalty: None,
361            stream: None,
362            tools,
363            tool_choice,
364        };
365
366        info!("Request payload created successfully");
367        Ok(request)
368    }
369}
370
/// Represents a message in the OpenAI API format.
///
/// The same shape is used in requests and responses; `None` fields are
/// omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIMessage {
    /// The role of the message sender ("system", "user", "assistant", or "tool")
    pub role: String,
    /// The text content of the message (may be absent, e.g. assistant
    /// messages that only carry tool calls)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The function call (deprecated in favor of tool_calls)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<OpenAIFunctionCall>,
    /// Optional name attached to the message (used with the "user" role)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    /// ID of the tool call this message responds to (role "tool" only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
392
/// Represents a tool function in the OpenAI API format.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIFunction {
    /// The name of the function
    pub name: String,
    /// The description of the function
    pub description: String,
    /// The parameters schema as a JSON object (JSON Schema)
    pub parameters: serde_json::Value,
}
403
/// Represents a tool in the OpenAI API format.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAITool {
    /// The type of the tool (currently always "function")
    pub r#type: String,
    /// The function definition
    pub function: OpenAIFunction,
}
412
413impl From<&LlmToolInfo> for OpenAITool {
414    fn from(value: &LlmToolInfo) -> Self {
415        OpenAITool {
416            r#type: "function".to_string(),
417            function: OpenAIFunction {
418                name: value.name.clone(),
419                description: value.description.clone(),
420                parameters: value.parameters.clone(),
421            },
422        }
423    }
424}
425
/// Represents a function call in the OpenAI API format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIFunctionCall {
    /// The name of the function
    pub name: String,
    /// The arguments encoded as a JSON string (not a JSON object)
    pub arguments: String,
}
434
/// Represents a tool call in the OpenAI API format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIToolCall {
    /// The ID of the tool call (echoed back by "tool" role messages)
    pub id: String,
    /// The type of the tool (currently always "function")
    pub r#type: String,
    /// The function call
    pub function: OpenAIFunctionCall,
}
445
/// Represents a request to the OpenAI API (`/chat/completions` body).
///
/// `None` fields are omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIRequest {
    /// The model to use
    pub model: String,
    /// The messages to send
    pub messages: Vec<OpenAIMessage>,
    /// Temperature (randomness)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-p sampling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Number of completions to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Maximum number of tokens to generate (for GPT models); mutually
    /// exclusive with `max_completion_tokens`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Maximum number of tokens to generate (for O-series models)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<usize>,
    /// Presence penalty
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Stream mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Tools available to the model
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAITool>>,
    /// Tool choice strategy: a JSON string ("auto"/"required"/"none") or an
    /// object selecting a specific function
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
}
484
/// Represents a successful response from the OpenAI API.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIResponse {
    /// Response ID
    pub id: String,
    /// Object type (always "chat.completion")
    pub object: String,
    /// Creation timestamp (Unix seconds)
    pub created: u64,
    /// Model used
    pub model: String,
    /// Choices generated (may be empty; only the first is consumed downstream)
    pub choices: Vec<OpenAIChoice>,
    /// Usage statistics, when the API includes them
    pub usage: Option<OpenAIUsage>,
}
501
/// Represents a choice in an OpenAI response.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIChoice {
    /// The index of the choice
    pub index: usize,
    /// The message generated
    pub message: OpenAIMessage,
    /// The reason generation stopped (e.g. absent for streamed partials)
    pub finish_reason: Option<String>,
}
512
/// Represents token-usage statistics in an OpenAI response.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIUsage {
    /// Number of tokens in the prompt
    pub prompt_tokens: u32,
    /// Number of tokens in the completion
    pub completion_tokens: u32,
    /// Total number of tokens (prompt + completion)
    pub total_tokens: u32,
}
523
/// Represents an error response envelope from the OpenAI API.
///
/// The field is optional so that non-error payloads also deserialize
/// (with `error: None`); callers must check for `Some` before treating the
/// payload as an error.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIErrorResponse {
    /// The error details, present only for actual error payloads
    pub error: Option<OpenAIError>,
}
530
/// Represents an error object from the OpenAI API.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIError {
    /// The human-readable error message
    pub message: String,
    /// The error type (e.g. "invalid_request_error")
    #[serde(rename = "type")]
    pub error_type: String,
    /// The machine-readable error code, when present
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
543
544/// Convert from our Message to OpenAI's message format
545impl From<&Message> for OpenAIMessage {
546    fn from(msg: &Message) -> Self {
547        let role = match msg {
548            Message::System { .. } => "system",
549            Message::User { .. } => "user",
550            Message::Assistant { .. } => "assistant",
551            Message::Tool { .. } => "tool",
552        }
553        .to_string();
554
555        let (content, name, function_call, tool_calls, tool_call_id) = match msg {
556            Message::System { content, .. } => (Some(content.clone()), None, None, None, None),
557            Message::User { content, name, .. } => {
558                let content_str = match content {
559                    Content::Text(text) => Some(text.clone()),
560                    Content::Parts(parts) => {
561                        // For text parts, concatenate them
562                        let combined_text = parts
563                            .iter()
564                            .filter_map(|part| match part {
565                                ContentPart::Text { text } => Some(text.clone()),
566                                _ => None,
567                            })
568                            .collect::<Vec<String>>()
569                            .join("\n");
570
571                        if combined_text.is_empty() {
572                            None
573                        } else {
574                            Some(combined_text)
575                        }
576                    }
577                };
578                (content_str, name.clone(), None, None, None)
579            }
580            Message::Assistant {
581                content,
582                tool_calls,
583                ..
584            } => {
585                let content_str = match content {
586                    Some(Content::Text(text)) => Some(text.clone()),
587                    Some(Content::Parts(parts)) => {
588                        // For text parts, concatenate them
589                        let combined_text = parts
590                            .iter()
591                            .filter_map(|part| match part {
592                                ContentPart::Text { text } => Some(text.clone()),
593                                _ => None,
594                            })
595                            .collect::<Vec<String>>()
596                            .join("\n");
597
598                        if combined_text.is_empty() {
599                            None
600                        } else {
601                            Some(combined_text)
602                        }
603                    }
604                    None => None,
605                };
606
607                // Convert tool calls if present
608                let openai_tool_calls = if !tool_calls.is_empty() {
609                    let mut calls = Vec::with_capacity(tool_calls.len());
610
611                    for tc in tool_calls {
612                        calls.push(OpenAIToolCall {
613                            id: tc.id.clone(),
614                            r#type: tc.tool_type.clone(),
615                            function: OpenAIFunctionCall {
616                                name: tc.function.name.clone(),
617                                arguments: tc.function.arguments.clone(),
618                            },
619                        });
620                    }
621
622                    Some(calls)
623                } else {
624                    None
625                };
626
627                (content_str, None, None, openai_tool_calls, None)
628            }
629            Message::Tool {
630                tool_call_id,
631                content,
632                ..
633            } => (
634                Some(content.clone()),
635                None,
636                None,
637                None,
638                Some(tool_call_id.clone()),
639            ),
640        };
641
642        OpenAIMessage {
643            role,
644            content,
645            function_call,
646            name,
647            tool_calls,
648            tool_call_id,
649        }
650    }
651}
652
/// Convert from OpenAI's response to our message format.
///
/// Only the first choice is consumed. Token usage, when present, is attached
/// to the resulting message's metadata.
impl From<&OpenAIResponse> for Message {
    fn from(response: &OpenAIResponse) -> Self {
        // Get the first choice (there should be at least one); fall back to a
        // placeholder assistant message rather than panicking on an empty list.
        if response.choices.is_empty() {
            return Message::assistant("No response generated");
        }

        let choice = &response.choices[0];
        let message = &choice.message;

        // Create appropriate Message variant based on role
        let mut msg = match message.role.as_str() {
            "assistant" => {
                let content = message
                    .content
                    .as_ref()
                    .map(|text| Content::Text(text.clone()));

                // Handle tool calls if present
                if let Some(openai_tool_calls) = &message.tool_calls {
                    if !openai_tool_calls.is_empty() {
                        let mut tool_calls = Vec::with_capacity(openai_tool_calls.len());

                        for call in openai_tool_calls {
                            let tool_call = crate::message::ToolCall {
                                id: call.id.clone(),
                                tool_type: call.r#type.clone(),
                                function: crate::message::Function {
                                    name: call.function.name.clone(),
                                    arguments: call.function.arguments.clone(),
                                },
                            };
                            tool_calls.push(tool_call);
                        }

                        Message::Assistant {
                            content,
                            tool_calls,
                            metadata: Default::default(),
                        }
                    } else {
                        // No tool calls: prefer the simple constructor when the
                        // content is plain text.
                        if let Some(Content::Text(text)) = content {
                            Message::assistant(text)
                        } else {
                            Message::Assistant {
                                content,
                                tool_calls: Vec::new(),
                                metadata: Default::default(),
                            }
                        }
                    }
                } else if let Some(fc) = &message.function_call {
                    // Handle legacy function_call (older OpenAI API) by
                    // wrapping it as a single tool call with a synthetic ID.
                    let tool_call = crate::message::ToolCall {
                        id: format!("legacy_function_{}", fc.name),
                        tool_type: "function".to_string(),
                        function: crate::message::Function {
                            name: fc.name.clone(),
                            arguments: fc.arguments.clone(),
                        },
                    };

                    Message::Assistant {
                        content,
                        tool_calls: vec![tool_call],
                        metadata: Default::default(),
                    }
                } else {
                    // Simple content only
                    if let Some(Content::Text(text)) = content {
                        Message::assistant(text)
                    } else {
                        Message::Assistant {
                            content,
                            tool_calls: Vec::new(),
                            metadata: Default::default(),
                        }
                    }
                }
            }
            "user" => {
                if let Some(name) = &message.name {
                    if let Some(content) = &message.content {
                        Message::user_with_name(name, content)
                    } else {
                        Message::user_with_name(name, "")
                    }
                } else if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
            "system" => {
                if let Some(content) = &message.content {
                    Message::system(content)
                } else {
                    Message::system("")
                }
            }
            "tool" => {
                if let Some(tool_call_id) = &message.tool_call_id {
                    if let Some(content) = &message.content {
                        Message::tool(tool_call_id, content)
                    } else {
                        Message::tool(tool_call_id, "")
                    }
                } else {
                    // This shouldn't happen (the API always sets tool_call_id
                    // on tool messages), but fall back to a user message
                    // rather than failing.
                    if let Some(content) = &message.content {
                        Message::user(content)
                    } else {
                        Message::user("")
                    }
                }
            }
            _ => {
                // Default to user for unknown roles
                if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
        };

        // Add token usage information to metadata if available
        if let Some(usage) = &response.usage {
            msg = msg.with_metadata(
                "prompt_tokens",
                serde_json::Value::Number(usage.prompt_tokens.into()),
            );
            msg = msg.with_metadata(
                "completion_tokens",
                serde_json::Value::Number(usage.completion_tokens.into()),
            );
            msg = msg.with_metadata(
                "total_tokens",
                serde_json::Value::Number(usage.total_tokens.into()),
            );
        }

        msg
    }
}
800
801#[cfg(test)]
802mod tests {
803    use super::*;
804
805    #[test]
806    fn test_message_conversion() {
807        // Test simple text message
808        let msg = Message::user("Hello, world!");
809        let openai_msg = OpenAIMessage::from(&msg);
810
811        assert_eq!(openai_msg.role, "user");
812        assert_eq!(openai_msg.content, Some("Hello, world!".to_string()));
813
814        // Test system message
815        let msg = Message::system("You are a helpful assistant.");
816        let openai_msg = OpenAIMessage::from(&msg);
817
818        assert_eq!(openai_msg.role, "system");
819        assert_eq!(
820            openai_msg.content,
821            Some("You are a helpful assistant.".to_string())
822        );
823
824        // Test assistant message
825        let msg = Message::assistant("I can help with that.");
826        let openai_msg = OpenAIMessage::from(&msg);
827
828        assert_eq!(openai_msg.role, "assistant");
829        assert_eq!(
830            openai_msg.content,
831            Some("I can help with that.".to_string())
832        );
833
834        // Test assistant message with tool calls
835        let tool_call = crate::message::ToolCall {
836            id: "tool_123".to_string(),
837            tool_type: "function".to_string(),
838            function: crate::message::Function {
839                name: "get_weather".to_string(),
840                arguments: "{\"location\":\"San Francisco\"}".to_string(),
841            },
842        };
843
844        let msg = Message::Assistant {
845            content: Some(Content::Text("I'll check the weather".to_string())),
846            tool_calls: vec![tool_call],
847            metadata: Default::default(),
848        };
849
850        let openai_msg = OpenAIMessage::from(&msg);
851
852        assert_eq!(openai_msg.role, "assistant");
853        assert_eq!(
854            openai_msg.content,
855            Some("I'll check the weather".to_string())
856        );
857        assert!(openai_msg.tool_calls.is_some());
858        let tool_calls = openai_msg.tool_calls.unwrap();
859        assert_eq!(tool_calls.len(), 1);
860        assert_eq!(tool_calls[0].id, "tool_123");
861        assert_eq!(tool_calls[0].function.name, "get_weather");
862    }
863
864    #[test]
865    fn test_error_response_parsing() {
866        let error_json = r#"{
867            "error": {
868                "message": "The model does not exist",
869                "type": "invalid_request_error",
870                "code": "model_not_found"
871            }
872        }"#;
873
874        let error_response: OpenAIErrorResponse = serde_json::from_str(error_json).unwrap();
875        assert!(error_response.error.is_some());
876        let error = error_response.error.unwrap();
877        assert_eq!(error.error_type, "invalid_request_error");
878        assert_eq!(error.code, Some("model_not_found".to_string()));
879    }
880
881    // Note: Reordering and orphan-tool detection tests were removed because this
882    // logic has been moved out of the provider.
883
884    // ---------------------------------------------------------------------
885    // Multi-turn tool-calling serialization tests
886    // ---------------------------------------------------------------------
887
888    /// Returns a minimal `LlmToolInfo` for the `get_weather` tool that the
889    /// fixture conversation relies on.
890    fn get_weather_tool_info() -> crate::tool::LlmToolInfo {
891        use serde_json::json;
892
893        crate::tool::LlmToolInfo {
894            name: "get_weather".to_string(),
895            description: "Get current temperature for a given location.".to_string(),
896            parameters: json!({
897                "type": "object",
898                "properties": {
899                    "location": {
900                        "type": "string",
901                        "description": "City and country e.g. Bogotá, Colombia"
902                    }
903                },
904                "required": ["location"],
905                "additionalProperties": false
906            }),
907        }
908    }
909
910    /// Helper to build a fresh `Chat` with the weather tool already registered.
911    fn base_chat_with_tool() -> crate::Chat {
912        crate::Chat::default().with_tools(vec![get_weather_tool_info()])
913    }
914
915    /// Stage-1: only the initial user message exists.  The request payload
916    /// should contain exactly that user message plus the registered tool
917    /// definition.
918    #[test]
919    fn test_stage1_user_only_serialization() {
920        use crate::model::OpenAi;
921        use crate::message::Message;
922
923        let chat = base_chat_with_tool()
924            .add_message(Message::user("What is the weather like in Paris today?"));
925
926        let provider = OpenAIProvider::new();
927        let request = provider
928            .create_request_payload(OpenAi::GPT35Turbo, &chat)
929            .expect("payload generation failed");
930
931        // 1. Messages
932        assert_eq!(request.messages.len(), 1);
933        let msg = &request.messages[0];
934        assert_eq!(msg.role, "user");
935        assert_eq!(msg.content.as_deref(), Some("What is the weather like in Paris today?"));
936        assert!(msg.tool_calls.is_none());
937        assert!(msg.tool_call_id.is_none());
938
939        // 2. Tools – the weather tool should be present
940        let tools = request.tools.expect("tools should be present");
941        assert!(tools.iter().any(|t| t.function.name == "get_weather"));
942
943        // 3. Default tool_choice should be "auto" when tools are provided
944        assert_eq!(request.tool_choice, Some(serde_json::json!("auto")));
945    }
946
947    /// Stage-2: assistant responds with a tool call (no content).  We expect
948    /// the serialized payload to include the assistant message with the
949    /// correct `tool_calls` structure.
950    #[test]
951    fn test_stage2_assistant_tool_call_serialization() {
952        use crate::model::OpenAi;
953        use crate::message::{Function, Message, ToolCall};
954
955        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";
956
957        // Build assistant message containing the tool call
958        let assistant_msg = Message::assistant_with_tool_calls(vec![ToolCall {
959            id: CALL_ID.to_string(),
960            tool_type: "function".to_string(),
961            function: Function {
962                name: "get_weather".to_string(),
963                arguments: "{\"location\":\"Paris, France\"}".to_string(),
964            },
965        }]);
966
967        let chat = base_chat_with_tool()
968            .add_message(Message::user("What is the weather like in Paris today?"))
969            .add_message(assistant_msg);
970
971        let provider = OpenAIProvider::new();
972        let request = provider
973            .create_request_payload(OpenAi::GPT35Turbo, &chat)
974            .expect("payload generation failed");
975
976        assert_eq!(request.messages.len(), 2);
977
978        // Check ordering: user first, assistant second
979        assert_eq!(request.messages[0].role, "user");
980        let assistant = &request.messages[1];
981        assert_eq!(assistant.role, "assistant");
982        assert!(assistant.content.is_none());
983
984        // Validate tool_calls structure
985        let calls = assistant.tool_calls.as_ref().expect("tool_calls missing");
986        assert_eq!(calls.len(), 1);
987        let call = &calls[0];
988        assert_eq!(call.id, CALL_ID);
989        assert_eq!(call.r#type, "function");
990        assert_eq!(call.function.name, "get_weather");
991        assert_eq!(call.function.arguments, "{\"location\":\"Paris, France\"}");
992    }
993
    /// Stage-3: the tool responds.  The serialized payload must place the
    /// assistant message *immediately* before the corresponding tool message
    /// (the provider serializes messages in the chat's insertion order — the
    /// old re-ordering logic has been moved out of the provider, so the
    /// fixture itself supplies the correct order) and preserve IDs.
    #[test]
    fn test_stage3_tool_response_serialization() {
        use crate::model::OpenAi;
        use crate::message::{Function, Message, ToolCall};

        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";

        // Assistant turn that requests the tool call; the tool reply below
        // must echo the same id so OpenAI can pair them.
        let assistant_msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: CALL_ID.to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Paris, France\"}".to_string(),
            },
        }]);

        let tool_msg = Message::tool(CALL_ID, "10C");

        let chat = base_chat_with_tool()
            .add_message(Message::user("What is the weather like in Paris today?"))
            .add_message(assistant_msg)
            .add_message(tool_msg);

        let provider = OpenAIProvider::new();
        let request = provider
            .create_request_payload(OpenAi::GPT35Turbo, &chat)
            .expect("payload generation failed");

        // Expect 3 messages with correct ordering
        assert_eq!(request.messages.len(), 3);
        assert_eq!(request.messages[0].role, "user");
        assert_eq!(request.messages[1].role, "assistant");
        assert_eq!(request.messages[2].role, "tool");

        // Validate the tool message fields: the call id and raw tool output
        // must round-trip unchanged.
        let tool = &request.messages[2];
        assert_eq!(tool.tool_call_id.as_deref(), Some(CALL_ID));
        assert_eq!(tool.content.as_deref(), Some("10C"));
    }
1036
1037    /// Stage-4: the assistant provides the final answer after the tool call.
1038    /// All 4 turns must serialize in the correct order and structure.
1039    #[test]
1040    fn test_stage4_full_conversation_serialization() {
1041        use crate::model::OpenAi;
1042        use crate::message::{Function, Message, ToolCall};
1043
1044        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";
1045        let user_msg = Message::user("What is the weather like in Paris today?");
1046
1047        let assistant_call = Message::assistant_with_tool_calls(vec![ToolCall {
1048            id: CALL_ID.to_string(),
1049            tool_type: "function".to_string(),
1050            function: Function {
1051                name: "get_weather".to_string(),
1052                arguments: "{\"location\":\"Paris, France\"}".to_string(),
1053            },
1054        }]);
1055
1056        let tool_msg = Message::tool(CALL_ID, "10C");
1057
1058        let final_assistant = Message::assistant("The weather in Paris today is 10°C. Let me know if you need more details or the forecast for the coming days!");
1059
1060        let chat = base_chat_with_tool()
1061            .add_message(user_msg)
1062            .add_message(assistant_call)
1063            .add_message(tool_msg)
1064            .add_message(final_assistant);
1065
1066        let provider = OpenAIProvider::new();
1067        let request = provider
1068            .create_request_payload(OpenAi::GPT35Turbo, &chat)
1069            .expect("payload generation failed");
1070
1071        // Verify the order and integrity of messages
1072        let roles: Vec<_> = request.messages.iter().map(|m| m.role.as_str()).collect();
1073        assert_eq!(roles, vec!["user", "assistant", "tool", "assistant"]);
1074
1075        let assistant_after_tool = &request.messages[3];
1076        assert_eq!(assistant_after_tool.role, "assistant");
1077        assert_eq!(assistant_after_tool.content.as_deref(), Some("The weather in Paris today is 10°C. Let me know if you need more details or the forecast for the coming days!"));
1078        assert!(assistant_after_tool.tool_calls.is_none());
1079    }
1080
1081    // JSON to test against follows
1082    // {
1083    //   "model": "gpt-4.1",
1084    //   "messages": [
1085    //     {
1086    //       "role": "user",
1087    //       "content": "What is the weather like in Paris today?"
1088    //     },
1089    //     {
1090    //         "role": "assistant",
1091    //         "tool_calls": [
1092    //           {
1093    //             "id": "call_19InQqbLUTQIuc6MlV5QSogY",
1094    //             "type": "function",
1095    //             "function": {
1096    //               "name": "get_weather",
1097    //               "arguments": "{\"location\":\"Paris, France\"}"
1098    //             }
1099    //           }
1100    //         ]
1101    //       }
1102    //     ,
1103    //     {
1104    //         "role": "tool",
1105    //         "tool_call_id": "call_19InQqbLUTQIuc6MlV5QSogY",
1106    //       "content": "10C"
1107    //       },
1108    //     {
1109    //         "role": "assistant",
1110    //         "content": "The weather in Paris today is 10°C. Let me know if you need more details or the forecast for the coming days!",
1111    //         "refusal": null,
1112    //         "annotations": []
1113    //       }
1114    //   ],
1115    //   "tools": [
1116    //     {
1117    //       "type": "function",
1118    //       "function": {
1119    //         "name": "get_weather",
1120    //         "description": "Get current temperature for a given location.",
1121    //         "parameters": {
1122    //           "type": "object",
1123    //           "properties": {
1124    //             "location": {
1125    //               "type": "string",
1126    //               "description": "City and country e.g. Bogotá, Colombia"
1127    //             }
1128    //           },
1129    //           "required": [
1130    //             "location"
1131    //           ],
1132    //           "additionalProperties": false
1133    //         },
1134    //         "strict": true
1135    //       }
1136    //     }
1137    //   ]
1138    // }
1139}