// rig/providers/openai/completion/mod.rs
1// ================================================================
2// OpenAI Completion API
3// ================================================================
4
5use super::{
6    CompletionsClient as Client,
7    client::{ApiErrorResponse, ApiResponse},
8    streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28/// Serializes user content as a plain string when there's a single text item,
29/// otherwise as an array of content parts.
30fn serialize_user_content<S>(
31    content: &OneOrMany<UserContent>,
32    serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35    S: Serializer,
36{
37    if content.len() == 1
38        && let UserContent::Text { text } = content.first_ref()
39    {
40        return serializer.serialize_str(text);
41    }
42    content.serialize(serializer)
43}
44
/// `gpt-5.2` completion model
pub const GPT_5_2: &str = "gpt-5.2";

/// `gpt-5.1` completion model
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` completion model
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model
pub const GPT_5_NANO: &str = "gpt-5-nano";

/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128    fn from(err: ApiErrorResponse) -> Self {
129        CompletionError::ProviderError(err.message)
130    }
131}
132
/// A message in OpenAI's chat-completion wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions; `role: "developer"` is accepted as an alias on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts on input.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user content (`role: "user"`).
    User {
        // Deserializes from a string or an array; serializes back to a plain
        // string when there is a single text part (see `serialize_user_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output (`role: "assistant"`), possibly carrying tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may arrive as JSON `null`; treated as an empty list.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation (`role: "tool"`), linked by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181    pub fn system(content: &str) -> Self {
182        Message::System {
183            content: OneOrMany::one(content.to_owned().into()),
184            name: None,
185        }
186    }
187}
188
/// Reference to assistant audio output; only the identifier is retained.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    // ID of the audio payload (the `audio.id` field in the API response).
    pub id: String,
}
193
/// A single system-message content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    // Content kind; defaults to `text` when absent from the payload.
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}
200
/// Kind of a system content part; `text` is the only variant today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
207
/// Assistant message content, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Ordinary generated text.
    Text { text: String },
    /// A refusal message from the model.
    Refusal { refusal: String },
}
214
215impl From<AssistantContent> for completion::AssistantContent {
216    fn from(value: AssistantContent) -> Self {
217        match value {
218            AssistantContent::Text { text } => completion::AssistantContent::text(text),
219            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
220        }
221    }
222}
223
/// User message content parts, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text input.
    Text {
        text: String,
    },
    /// Image input; serialized with tag `image_url`.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// Base64 audio input.
    Audio {
        input_audio: InputAudio,
    },
}
238
/// An image reference: either an HTTP(S) URL or a base64 data URI.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    // Requested detail level; falls back to the default when omitted.
    #[serde(default)]
    pub detail: ImageDetail,
}
245
/// Base64-encoded audio input along with its format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    // Base64-encoded audio bytes.
    pub data: String,
    pub format: AudioMediaType,
}
251
/// A single text part of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Only `text` exists today; defaults when omitted from the payload.
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}
258
/// Kind of a tool-result content part; `text` is the only variant today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
265
266impl FromStr for ToolResultContent {
267    type Err = Infallible;
268
269    fn from_str(s: &str) -> Result<Self, Self::Err> {
270        Ok(s.to_owned().into())
271    }
272}
273
274impl From<String> for ToolResultContent {
275    fn from(s: String) -> Self {
276        ToolResultContent {
277            r#type: ToolResultContentType::default(),
278            text: s,
279        }
280    }
281}
282
/// Tool-result content on the wire: either a bare string or an array of
/// content parts. Untagged so both JSON shapes deserialize transparently.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
289
290impl ToolResultContentValue {
291    pub fn from_string(s: String, use_array_format: bool) -> Self {
292        if use_array_format {
293            ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
294        } else {
295            ToolResultContentValue::String(s)
296        }
297    }
298
299    pub fn as_text(&self) -> String {
300        match self {
301            ToolResultContentValue::Array(arr) => arr
302                .iter()
303                .map(|c| c.text.clone())
304                .collect::<Vec<_>>()
305                .join("\n"),
306            ToolResultContentValue::String(s) => s.clone(),
307        }
308    }
309
310    pub fn to_array(&self) -> Self {
311        match self {
312            ToolResultContentValue::Array(_) => self.clone(),
313            ToolResultContentValue::String(s) => {
314                ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
315            }
316        }
317    }
318}
319
/// A tool call emitted by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    // Always `function` today; defaults when omitted from the payload.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
327
/// Kind of tool call; `function` is the only variant today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
334
/// Function definition for a tool, with optional strict mode
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's arguments.
    pub parameters: serde_json::Value,
    // When `Some(true)`, the schema is enforced exactly; set via
    // `ToolDefinition::with_strict`, omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
344
/// A tool definition in OpenAI's `{type, function}` wrapper shape.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    // Always "function" when built from `completion::ToolDefinition`.
    pub r#type: String,
    pub function: FunctionDefinition,
}
350
351impl From<completion::ToolDefinition> for ToolDefinition {
352    fn from(tool: completion::ToolDefinition) -> Self {
353        Self {
354            r#type: "function".into(),
355            function: FunctionDefinition {
356                name: tool.name,
357                description: tool.description,
358                parameters: tool.parameters,
359                strict: None,
360            },
361        }
362    }
363}
364
impl ToolDefinition {
    /// Apply strict mode to this tool definition.
    /// This sets `strict: true` and sanitizes the schema to meet OpenAI requirements
    /// (delegated to `sanitize_schema` in the parent module).
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
374
/// Top-level `tool_choice` values supported by this provider.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
383
384impl TryFrom<crate::message::ToolChoice> for ToolChoice {
385    type Error = CompletionError;
386    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
387        let res = match value {
388            message::ToolChoice::Specific { .. } => {
389                return Err(CompletionError::ProviderError(
390                    "Provider doesn't support only using specific tools".to_string(),
391                ));
392            }
393            message::ToolChoice::Auto => Self::Auto,
394            message::ToolChoice::None => Self::None,
395            message::ToolChoice::Required => Self::Required,
396        };
397
398        Ok(res)
399    }
400}
401
/// The function name/arguments payload of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // Arguments travel as a JSON-encoded string on the wire but are stored
    // parsed; `stringified_json` handles the round trip.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
408
409impl TryFrom<message::ToolResult> for Message {
410    type Error = message::MessageError;
411
412    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
413        let text = value
414            .content
415            .into_iter()
416            .map(|content| {
417                match content {
418                message::ToolResultContent::Text(message::Text { text }) => Ok(text),
419                message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
420                    "OpenAI does not support images in tool results. Tool results must be text."
421                        .into(),
422                )),
423            }
424            })
425            .collect::<Result<Vec<_>, _>>()?
426            .join("\n");
427
428        Ok(Message::ToolResult {
429            tool_call_id: value.id,
430            content: ToolResultContentValue::String(text),
431        })
432    }
433}
434
435impl TryFrom<message::UserContent> for UserContent {
436    type Error = message::MessageError;
437
438    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
439        match value {
440            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
441            message::UserContent::Image(message::Image {
442                data,
443                detail,
444                media_type,
445                ..
446            }) => match data {
447                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
448                    image_url: ImageUrl {
449                        url,
450                        detail: detail.unwrap_or_default(),
451                    },
452                }),
453                DocumentSourceKind::Base64(data) => {
454                    let url = format!(
455                        "data:{};base64,{}",
456                        media_type.map(|i| i.to_mime_type()).ok_or(
457                            message::MessageError::ConversionError(
458                                "OpenAI Image URI must have media type".into()
459                            )
460                        )?,
461                        data
462                    );
463
464                    let detail = detail.ok_or(message::MessageError::ConversionError(
465                        "OpenAI image URI must have image detail".into(),
466                    ))?;
467
468                    Ok(UserContent::Image {
469                        image_url: ImageUrl { url, detail },
470                    })
471                }
472                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
473                    "Raw files not supported, encode as base64 first".into(),
474                )),
475                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
476                    "Document has no body".into(),
477                )),
478                doc => Err(message::MessageError::ConversionError(format!(
479                    "Unsupported document type: {doc:?}"
480                ))),
481            },
482            message::UserContent::Document(message::Document { data, .. }) => {
483                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
484                    Ok(UserContent::Text { text })
485                } else {
486                    Err(message::MessageError::ConversionError(
487                        "Documents must be base64 or a string".into(),
488                    ))
489                }
490            }
491            message::UserContent::Audio(message::Audio {
492                data, media_type, ..
493            }) => match data {
494                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
495                    input_audio: InputAudio {
496                        data,
497                        format: match media_type {
498                            Some(media_type) => media_type,
499                            None => AudioMediaType::MP3,
500                        },
501                    },
502                }),
503                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
504                    "URLs are not supported for audio".into(),
505                )),
506                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
507                    "Raw files are not supported for audio".into(),
508                )),
509                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
510                    "Audio has no body".into(),
511                )),
512                audio => Err(message::MessageError::ConversionError(format!(
513                    "Unsupported audio type: {audio:?}"
514                ))),
515            },
516            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
517                "Tool result is in unsupported format".into(),
518            )),
519            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
520                "Video is in unsupported format".into(),
521            )),
522        }
523    }
524}
525
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Fans user content out into provider messages: tool results each become
    /// a `tool`-role message; everything else is bundled into one user message.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        // If there are messages with both tool results and user content, openai will only
        //  handle tool results. It's unlikely that there will be both.
        if !tool_results.is_empty() {
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // The partition above guarantees only tool results are here.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // `value` is non-empty (OneOrMany) and this branch means none of it
            // was tool-result content, so `other_content` cannot be empty.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
560
561impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
562    type Error = message::MessageError;
563
564    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
565        let mut text_content = Vec::new();
566        let mut tool_calls = Vec::new();
567
568        for content in value {
569            match content {
570                message::AssistantContent::Text(text) => text_content.push(text),
571                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
572                message::AssistantContent::Reasoning(_) => {
573                    // OpenAI Chat Completions does not support assistant-history reasoning items.
574                    // Silently skip unsupported reasoning content.
575                }
576                message::AssistantContent::Image(_) => {
577                    panic!(
578                        "The OpenAI Completions API doesn't support image content in assistant messages!"
579                    );
580                }
581            }
582        }
583
584        if text_content.is_empty() && tool_calls.is_empty() {
585            return Ok(vec![]);
586        }
587
588        Ok(vec![Message::Assistant {
589            content: text_content
590                .into_iter()
591                .map(|content| content.text.into())
592                .collect::<Vec<_>>(),
593            refusal: None,
594            audio: None,
595            name: None,
596            tool_calls: tool_calls
597                .into_iter()
598                .map(|tool_call| tool_call.into())
599                .collect::<Vec<_>>(),
600        }])
601    }
602}
603
604impl TryFrom<message::Message> for Vec<Message> {
605    type Error = message::MessageError;
606
607    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
608        match message {
609            message::Message::System { content } => Ok(vec![Message::system(&content)]),
610            message::Message::User { content } => content.try_into(),
611            message::Message::Assistant { content, .. } => content.try_into(),
612        }
613    }
614}
615
616impl From<message::ToolCall> for ToolCall {
617    fn from(tool_call: message::ToolCall) -> Self {
618        Self {
619            id: tool_call.id,
620            r#type: ToolType::default(),
621            function: Function {
622                name: tool_call.function.name,
623                arguments: tool_call.function.arguments,
624            },
625        }
626    }
627}
628
629impl From<ToolCall> for message::ToolCall {
630    fn from(tool_call: ToolCall) -> Self {
631        Self {
632            id: tool_call.id,
633            call_id: None,
634            function: message::ToolFunction {
635                name: tool_call.function.name,
636                arguments: tool_call.function.arguments,
637            },
638            signature: None,
639            additional_params: None,
640        }
641    }
642}
643
644impl TryFrom<Message> for message::Message {
645    type Error = message::MessageError;
646
647    fn try_from(message: Message) -> Result<Self, Self::Error> {
648        Ok(match message {
649            Message::User { content, .. } => message::Message::User {
650                content: content.map(|content| content.into()),
651            },
652            Message::Assistant {
653                content,
654                tool_calls,
655                ..
656            } => {
657                let mut content = content
658                    .into_iter()
659                    .map(|content| match content {
660                        AssistantContent::Text { text } => message::AssistantContent::text(text),
661
662                        // TODO: Currently, refusals are converted into text, but should be
663                        //  investigated for generalization.
664                        AssistantContent::Refusal { refusal } => {
665                            message::AssistantContent::text(refusal)
666                        }
667                    })
668                    .collect::<Vec<_>>();
669
670                content.extend(
671                    tool_calls
672                        .into_iter()
673                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
674                        .collect::<Result<Vec<_>, _>>()?,
675                );
676
677                message::Message::Assistant {
678                    id: None,
679                    content: OneOrMany::many(content).map_err(|_| {
680                        message::MessageError::ConversionError(
681                            "Neither `content` nor `tool_calls` was provided to the Message"
682                                .to_owned(),
683                        )
684                    })?,
685                }
686            }
687
688            Message::ToolResult {
689                tool_call_id,
690                content,
691            } => message::Message::User {
692                content: OneOrMany::one(message::UserContent::tool_result(
693                    tool_call_id,
694                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
695                )),
696            },
697
698            // System messages should get stripped out when converting messages, this is just a
699            // stop gap to avoid obnoxious error handling or panic occurring.
700            Message::System { content, .. } => message::Message::User {
701                content: content.map(|content| message::UserContent::text(content.text)),
702            },
703        })
704    }
705}
706
707impl From<UserContent> for message::UserContent {
708    fn from(content: UserContent) -> Self {
709        match content {
710            UserContent::Text { text } => message::UserContent::text(text),
711            UserContent::Image { image_url } => {
712                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
713            }
714            UserContent::Audio { input_audio } => {
715                message::UserContent::audio(input_audio.data, Some(input_audio.format))
716            }
717        }
718    }
719}
720
721impl From<String> for UserContent {
722    fn from(s: String) -> Self {
723        UserContent::Text { text: s }
724    }
725}
726
727impl FromStr for UserContent {
728    type Err = Infallible;
729
730    fn from_str(s: &str) -> Result<Self, Self::Err> {
731        Ok(UserContent::Text {
732            text: s.to_string(),
733        })
734    }
735}
736
737impl From<String> for AssistantContent {
738    fn from(s: String) -> Self {
739        AssistantContent::Text { text: s }
740    }
741}
742
743impl FromStr for AssistantContent {
744    type Err = Infallible;
745
746    fn from_str(s: &str) -> Result<Self, Self::Err> {
747        Ok(AssistantContent::Text {
748            text: s.to_string(),
749        })
750    }
751}
752impl From<String> for SystemContent {
753    fn from(s: String) -> Self {
754        SystemContent {
755            r#type: SystemContentType::default(),
756            text: s,
757        }
758    }
759}
760
761impl FromStr for SystemContent {
762    type Err = Infallible;
763
764    fn from_str(s: &str) -> Result<Self, Self::Err> {
765        Ok(SystemContent {
766            r#type: SystemContentType::default(),
767            text: s.to_string(),
768        })
769    }
770}
771
/// Raw response body of OpenAI's chat completions endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    // Object type string reported by the API.
    pub object: String,
    // Creation time as a Unix timestamp.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    // Token accounting, when the provider includes it.
    pub usage: Option<Usage>,
}
782
783impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
784    type Error = CompletionError;
785
786    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
787        let choice = response.choices.first().ok_or_else(|| {
788            CompletionError::ResponseError("Response contained no choices".to_owned())
789        })?;
790
791        let content = match &choice.message {
792            Message::Assistant {
793                content,
794                tool_calls,
795                ..
796            } => {
797                let mut content = content
798                    .iter()
799                    .filter_map(|c| {
800                        let s = match c {
801                            AssistantContent::Text { text } => text,
802                            AssistantContent::Refusal { refusal } => refusal,
803                        };
804                        if s.is_empty() {
805                            None
806                        } else {
807                            Some(completion::AssistantContent::text(s))
808                        }
809                    })
810                    .collect::<Vec<_>>();
811
812                content.extend(
813                    tool_calls
814                        .iter()
815                        .map(|call| {
816                            completion::AssistantContent::tool_call(
817                                &call.id,
818                                &call.function.name,
819                                call.function.arguments.clone(),
820                            )
821                        })
822                        .collect::<Vec<_>>(),
823                );
824                Ok(content)
825            }
826            _ => Err(CompletionError::ResponseError(
827                "Response did not contain a valid message or tool call".into(),
828            )),
829        }?;
830
831        let choice = OneOrMany::many(content).map_err(|_| {
832            CompletionError::ResponseError(
833                "Response contained no message or tool call (empty)".to_owned(),
834            )
835        })?;
836
837        let usage = response
838            .usage
839            .as_ref()
840            .map(|usage| completion::Usage {
841                input_tokens: usage.prompt_tokens as u64,
842                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
843                total_tokens: usage.total_tokens as u64,
844                cached_input_tokens: usage
845                    .prompt_tokens_details
846                    .as_ref()
847                    .map(|d| d.cached_tokens as u64)
848                    .unwrap_or(0),
849            })
850            .unwrap_or_default();
851
852        Ok(completion::CompletionResponse {
853            choice,
854            usage,
855            raw_response: response,
856            message_id: None,
857        })
858    }
859}
860
861impl ProviderResponseExt for CompletionResponse {
862    type OutputMessage = Choice;
863    type Usage = Usage;
864
865    fn get_response_id(&self) -> Option<String> {
866        Some(self.id.to_owned())
867    }
868
869    fn get_response_model_name(&self) -> Option<String> {
870        Some(self.model.to_owned())
871    }
872
873    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
874        self.choices.clone()
875    }
876
877    fn get_text_response(&self) -> Option<String> {
878        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
879            return None;
880        };
881
882        let UserContent::Text { text } = content.first() else {
883            return None;
884        };
885
886        Some(text)
887    }
888
889    fn get_usage(&self) -> Option<Self::Usage> {
890        self.usage.clone()
891    }
892}
893
/// A single completion choice as returned by the OpenAI Chat Completions API.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Zero-based position of this choice within the response's `choices` array.
    pub index: usize,
    /// The message generated by the model for this choice.
    pub message: Message,
    /// Log-probability information, if requested; kept as raw JSON since it is
    /// passed through unmodified.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. "stop", "length", "tool_calls").
    pub finish_reason: String,
}
901
/// Breakdown of the prompt tokens, nested under `usage.prompt_tokens_details`
/// in the API response.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    // Defaults to 0 when the provider omits the field.
    #[serde(default)]
    pub cached_tokens: usize,
}
908
/// Token usage statistics reported by the OpenAI Chat Completions API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total tokens billed for the request (prompt + completion).
    pub total_tokens: usize,
    /// Optional per-category breakdown of the prompt tokens (e.g. cached
    /// tokens); omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
916
917impl Usage {
918    pub fn new() -> Self {
919        Self {
920            prompt_tokens: 0,
921            total_tokens: 0,
922            prompt_tokens_details: None,
923        }
924    }
925}
926
927impl Default for Usage {
928    fn default() -> Self {
929        Self::new()
930    }
931}
932
933impl fmt::Display for Usage {
934    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
935        let Usage {
936            prompt_tokens,
937            total_tokens,
938            ..
939        } = self;
940        write!(
941            f,
942            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
943        )
944    }
945}
946
947impl GetTokenUsage for Usage {
948    fn token_usage(&self) -> Option<crate::completion::Usage> {
949        let mut usage = crate::completion::Usage::new();
950        usage.input_tokens = self.prompt_tokens as u64;
951        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
952        usage.total_tokens = self.total_tokens as u64;
953        usage.cached_input_tokens = self
954            .prompt_tokens_details
955            .as_ref()
956            .map(|d| d.cached_tokens as u64)
957            .unwrap_or(0);
958
959        Some(usage)
960    }
961}
962
/// Handle for an OpenAI Chat Completions model, parameterized over the
/// underlying HTTP client implementation (defaults to `reqwest`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client wrapper used to reach the OpenAI-compatible API.
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request (e.g. `gpt-5`).
    pub model: String,
    /// When true, tool schemas are sanitized to satisfy OpenAI strict mode.
    pub strict_tools: bool,
    /// When true, tool-result message content is sent as an array of content
    /// parts instead of a plain string.
    pub tool_result_array_content: bool,
}
970
971impl<T> CompletionModel<T>
972where
973    T: Default + std::fmt::Debug + Clone + 'static,
974{
975    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
976        Self {
977            client,
978            model: model.into(),
979            strict_tools: false,
980            tool_result_array_content: false,
981        }
982    }
983
984    pub fn with_model(client: Client<T>, model: &str) -> Self {
985        Self {
986            client,
987            model: model.into(),
988            strict_tools: false,
989            tool_result_array_content: false,
990        }
991    }
992
993    /// Enable strict mode for tool schemas.
994    ///
995    /// When enabled, tool schemas are automatically sanitized to meet OpenAI's strict mode requirements:
996    /// - `additionalProperties: false` is added to all objects
997    /// - All properties are marked as required
998    /// - `strict: true` is set on each function definition
999    ///
1000    /// This allows OpenAI to guarantee that the model's tool calls will match the schema exactly.
1001    pub fn with_strict_tools(mut self) -> Self {
1002        self.strict_tools = true;
1003        self
1004    }
1005
1006    pub fn with_tool_result_array_content(mut self) -> Self {
1007        self.tool_result_array_content = true;
1008        self
1009    }
1010}
1011
/// Request payload for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    /// Model identifier sent to the provider.
    model: String,
    /// Full conversation: system preamble (if any), documents, then history.
    messages: Vec<Message>,
    /// Tool (function) definitions; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// How the model may choose tools; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Sampling temperature; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Maximum tokens to generate; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Extra provider parameters flattened into the top-level JSON object
    /// (e.g. `response_format` injected from an output schema).
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1027
/// Bundle of inputs for converting a core rig request into an OpenAI
/// [`CompletionRequest`] (see the `TryFrom<OpenAIRequestParams>` impl).
pub struct OpenAIRequestParams {
    /// Default model to use when the request carries no per-request override.
    pub model: String,
    /// The provider-agnostic completion request to convert.
    pub request: CoreCompletionRequest,
    /// Whether to sanitize tool schemas for OpenAI strict mode.
    pub strict_tools: bool,
    /// Whether tool-result content should be converted to array form.
    pub tool_result_array_content: bool,
}
1034
1035impl TryFrom<OpenAIRequestParams> for CompletionRequest {
1036    type Error = CompletionError;
1037
1038    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
1039        let OpenAIRequestParams {
1040            model,
1041            request: req,
1042            strict_tools,
1043            tool_result_array_content,
1044        } = params;
1045
1046        let mut partial_history = vec![];
1047        if let Some(docs) = req.normalized_documents() {
1048            partial_history.push(docs);
1049        }
1050        let CoreCompletionRequest {
1051            model: request_model,
1052            preamble,
1053            chat_history,
1054            tools,
1055            temperature,
1056            max_tokens,
1057            additional_params,
1058            tool_choice,
1059            output_schema,
1060            ..
1061        } = req;
1062
1063        partial_history.extend(chat_history);
1064
1065        let mut full_history: Vec<Message> =
1066            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
1067
1068        full_history.extend(
1069            partial_history
1070                .into_iter()
1071                .map(message::Message::try_into)
1072                .collect::<Result<Vec<Vec<Message>>, _>>()?
1073                .into_iter()
1074                .flatten()
1075                .collect::<Vec<_>>(),
1076        );
1077
1078        if full_history.is_empty() {
1079            return Err(CompletionError::RequestError(
1080                std::io::Error::new(
1081                    std::io::ErrorKind::InvalidInput,
1082                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
1083                )
1084                .into(),
1085            ));
1086        }
1087
1088        if tool_result_array_content {
1089            for msg in &mut full_history {
1090                if let Message::ToolResult { content, .. } = msg {
1091                    *content = content.to_array();
1092                }
1093            }
1094        }
1095
1096        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
1097
1098        let tools: Vec<ToolDefinition> = tools
1099            .into_iter()
1100            .map(|tool| {
1101                let def = ToolDefinition::from(tool);
1102                if strict_tools { def.with_strict() } else { def }
1103            })
1104            .collect();
1105
1106        // Map output_schema to OpenAI's response_format and merge into additional_params
1107        let additional_params = if let Some(schema) = output_schema {
1108            let name = schema
1109                .as_object()
1110                .and_then(|o| o.get("title"))
1111                .and_then(|v| v.as_str())
1112                .unwrap_or("response_schema")
1113                .to_string();
1114            let mut schema_value = schema.to_value();
1115            super::sanitize_schema(&mut schema_value);
1116            let response_format = serde_json::json!({
1117                "response_format": {
1118                    "type": "json_schema",
1119                    "json_schema": {
1120                        "name": name,
1121                        "strict": true,
1122                        "schema": schema_value
1123                    }
1124                }
1125            });
1126            Some(match additional_params {
1127                Some(existing) => json_utils::merge(existing, response_format),
1128                None => response_format,
1129            })
1130        } else {
1131            additional_params
1132        };
1133
1134        let res = Self {
1135            model: request_model.unwrap_or(model),
1136            messages: full_history,
1137            tools,
1138            tool_choice,
1139            temperature,
1140            max_tokens,
1141            additional_params,
1142        };
1143
1144        Ok(res)
1145    }
1146}
1147
1148impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1149    type Error = CompletionError;
1150
1151    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1152        CompletionRequest::try_from(OpenAIRequestParams {
1153            model,
1154            request: req,
1155            strict_tools: false,
1156            tool_result_array_content: false,
1157        })
1158    }
1159}
1160
1161impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1162    type InputMessage = Message;
1163
1164    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1165        self.messages.clone()
1166    }
1167
1168    fn get_system_prompt(&self) -> Option<String> {
1169        let first_message = self.messages.first()?;
1170
1171        let Message::System { ref content, .. } = first_message.clone() else {
1172            return None;
1173        };
1174
1175        let SystemContent { text, .. } = content.first();
1176
1177        Some(text)
1178    }
1179
1180    fn get_prompt(&self) -> Option<String> {
1181        let last_message = self.messages.last()?;
1182
1183        let Message::User { ref content, .. } = last_message.clone() else {
1184            return None;
1185        };
1186
1187        let UserContent::Text { text } = content.first() else {
1188            return None;
1189        };
1190
1191        Some(text)
1192    }
1193
1194    fn get_model_name(&self) -> String {
1195        self.model.clone()
1196    }
1197}
1198
impl CompletionModel<reqwest::Client> {
    /// Consumes this model and wraps it in an agent builder, enabling fluent
    /// agent construction on the default `reqwest`-backed client.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1204
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    /// Constructs a model handle from a client and model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming request to `/chat/completions` and converts the
    /// provider response into rig's completion response.
    ///
    /// Telemetry: a `gen_ai`-style span is created only when no span is
    /// already active; response metadata and token usage are recorded on
    /// whichever span is current.
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise open our own
        // with empty response/usage fields to be recorded after the call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Convert the core request, applying this model's strict-tools and
        // tool-result-content options.
        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Pretty-printed request body is only serialized when TRACE is on.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The request/response handling runs instrumented with `span` so all
        // recorded fields land on the right span.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Record telemetry before converting, since the
                        // conversion consumes the response.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    // A 2xx body can still carry an API error envelope.
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as a provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a completion by delegating to the inherent `stream`
    /// implementation (defined in the streaming module); the explicit
    /// `Self::` call avoids recursing into this trait method.
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        Self::stream(self, request).await
    }
}
1312
1313fn serialize_assistant_content_vec<S>(
1314    value: &Vec<AssistantContent>,
1315    serializer: S,
1316) -> Result<S::Ok, S::Error>
1317where
1318    S: Serializer,
1319{
1320    if value.is_empty() {
1321        serializer.serialize_str("")
1322    } else {
1323        value.serialize(serializer)
1324    }
1325}
1326
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal core request used throughout these tests: a single
    /// "Hello" user message with every optional field unset. Individual
    /// tests override fields via struct-update syntax. (Previously this
    /// 10-field literal was duplicated in five tests.)
    fn base_request() -> CoreCompletionRequest {
        CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Converts a core request using the default "gpt-4o-mini" model and no
    /// strict-tools / array-content options.
    fn convert(request: CoreCompletionRequest) -> Result<CompletionRequest, CompletionError> {
        CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
    }

    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = CoreCompletionRequest {
            model: Some("gpt-4.1".to_string()),
            ..base_request()
        };

        let openai_request = convert(request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        // The per-request model override wins over the client default.
        assert_eq!(serialized["model"], "gpt-4.1");
    }

    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let openai_request =
            convert(base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        // Reasoning-only content produces no provider messages at all.
        assert!(converted.is_empty());
    }

    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_max_tokens_is_forwarded_to_request() {
        let request = CoreCompletionRequest {
            max_tokens: Some(4096),
            ..base_request()
        };

        let openai_request = convert(request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["max_tokens"], 4096);
    }

    #[test]
    fn test_max_tokens_omitted_when_none() {
        let openai_request =
            convert(base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        // `skip_serializing_if` must drop the field entirely when unset.
        assert!(serialized.get("max_tokens").is_none());
    }

    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        let request = CoreCompletionRequest {
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            ..base_request()
        };

        let result = convert(request);

        // Reasoning-only history converts to zero messages, which must fail.
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}
1523}