Skip to main content

rig/providers/openai/completion/
mod.rs

1// ================================================================
2// OpenAI Completion API
3// ================================================================
4
5use super::{
6    CompletionsClient as Client,
7    client::{ApiErrorResponse, ApiResponse},
8    streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28/// Serializes user content as a plain string when there's a single text item,
29/// otherwise as an array of content parts.
30fn serialize_user_content<S>(
31    content: &OneOrMany<UserContent>,
32    serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35    S: Serializer,
36{
37    if content.len() == 1
38        && let UserContent::Text { text } = content.first_ref()
39    {
40        return serializer.serialize_str(text);
41    }
42    content.serialize(serializer)
43}
44
/// `gpt-5.2` completion model
pub const GPT_5_2: &str = "gpt-5.2";

/// `gpt-5.1` completion model
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` completion model
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model
pub const GPT_5_NANO: &str = "gpt-5-nano";

/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128    fn from(err: ApiErrorResponse) -> Self {
129        CompletionError::ProviderError(err.message)
130    }
131}
132
/// A single message in an OpenAI Chat Completions conversation, tagged by
/// its `role` field on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `system` message; also accepted under the `developer` role on input.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `user` message with text/image/audio content parts.
    User {
        // Serialized back as a plain string when it is a single text part.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `assistant` message: model output, possibly carrying tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may be `null` on the wire; treat that as empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `tool` message reporting the result of a prior tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181    pub fn system(content: &str) -> Self {
182        Message::System {
183            content: OneOrMany::one(content.to_owned().into()),
184            name: None,
185        }
186    }
187}
188
/// Reference to an assistant audio output from a previous turn.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    // NOTE(review): presumably OpenAI's `audio.id` from a prior response — confirm.
    pub id: String,
}

/// A system-message content part (always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

/// Content-part type tag for system messages; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
207
/// An assistant-message content part: either generated text or a refusal,
/// tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
214
215impl From<AssistantContent> for completion::AssistantContent {
216    fn from(value: AssistantContent) -> Self {
217        match value {
218            AssistantContent::Text { text } => completion::AssistantContent::text(text),
219            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
220        }
221    }
222}
223
/// A user-message content part, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// Image, referenced by URL (or a base64 `data:` URI).
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    /// Inline audio data.
    Audio {
        input_audio: InputAudio,
    },
}

/// Image reference for a user message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// HTTP(S) URL or base64 `data:` URI.
    pub url: String,
    /// Requested analysis detail level; falls back to `ImageDetail::default()`.
    #[serde(default)]
    pub detail: ImageDetail,
}

/// Inline audio payload for a user message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Base64-encoded audio bytes (see the `UserContent` conversion below,
    /// which only accepts `DocumentSourceKind::Base64` here).
    pub data: String,
    pub format: AudioMediaType,
}
251
/// A structured tool-result content part (always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

/// Content-part type tag for tool results; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
265
266impl FromStr for ToolResultContent {
267    type Err = Infallible;
268
269    fn from_str(s: &str) -> Result<Self, Self::Err> {
270        Ok(s.to_owned().into())
271    }
272}
273
274impl From<String> for ToolResultContent {
275    fn from(s: String) -> Self {
276        ToolResultContent {
277            r#type: ToolResultContentType::default(),
278            text: s,
279        }
280    }
281}
282
/// Tool-result content as either an array of typed parts or a bare string.
/// `untagged`: serde tries variants in declared order, so arrays are matched
/// before falling back to a plain string.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
289
290impl ToolResultContentValue {
291    pub fn from_string(s: String, use_array_format: bool) -> Self {
292        if use_array_format {
293            ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
294        } else {
295            ToolResultContentValue::String(s)
296        }
297    }
298
299    pub fn as_text(&self) -> String {
300        match self {
301            ToolResultContentValue::Array(arr) => arr
302                .iter()
303                .map(|c| c.text.clone())
304                .collect::<Vec<_>>()
305                .join("\n"),
306            ToolResultContentValue::String(s) => s.clone(),
307        }
308    }
309
310    pub fn to_array(&self) -> Self {
311        match self {
312            ToolResultContentValue::Array(_) => self.clone(),
313            ToolResultContentValue::String(s) => {
314                ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
315            }
316        }
317    }
318}
319
/// A tool call emitted by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Tool-call type tag; only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Function definition for a tool, with optional strict mode
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the function's arguments.
    pub parameters: serde_json::Value,
    /// When `Some(true)`, requests strict (structured-output) mode; set via
    /// `ToolDefinition::with_strict`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

/// Wire-format tool definition (`type` tag plus function payload).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Always "function" when built via `From<completion::ToolDefinition>`.
    pub r#type: String,
    pub function: FunctionDefinition,
}
350
351impl From<completion::ToolDefinition> for ToolDefinition {
352    fn from(tool: completion::ToolDefinition) -> Self {
353        Self {
354            r#type: "function".into(),
355            function: FunctionDefinition {
356                name: tool.name,
357                description: tool.description,
358                parameters: tool.parameters,
359                strict: None,
360            },
361        }
362    }
363}
364
impl ToolDefinition {
    /// Apply strict mode to this tool definition.
    /// This sets `strict: true` and sanitizes the schema to meet OpenAI requirements.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        // `sanitize_schema` (parent module) rewrites the JSON schema in place
        // so it satisfies OpenAI's strict structured-output constraints.
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
374
/// Global tool-choice setting, serialized in `snake_case`.
/// Specific-tool selection is not representable here and is rejected during
/// conversion from the core type.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
383
384impl TryFrom<crate::message::ToolChoice> for ToolChoice {
385    type Error = CompletionError;
386    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
387        let res = match value {
388            message::ToolChoice::Specific { .. } => {
389                return Err(CompletionError::ProviderError(
390                    "Provider doesn't support only using specific tools".to_string(),
391                ));
392            }
393            message::ToolChoice::Auto => Self::Auto,
394            message::ToolChoice::None => Self::None,
395            message::ToolChoice::Required => Self::Required,
396        };
397
398        Ok(res)
399    }
400}
401
/// The function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire this is a JSON-encoded string; `stringified_json` bridges
    // to/from a structured `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
408
409impl TryFrom<message::ToolResult> for Message {
410    type Error = message::MessageError;
411
412    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
413        let text = value
414            .content
415            .into_iter()
416            .map(|content| {
417                match content {
418                message::ToolResultContent::Text(message::Text { text }) => Ok(text),
419                message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
420                    "OpenAI does not support images in tool results. Tool results must be text."
421                        .into(),
422                )),
423            }
424            })
425            .collect::<Result<Vec<_>, _>>()?
426            .join("\n");
427
428        Ok(Message::ToolResult {
429            tool_call_id: value.id,
430            content: ToolResultContentValue::String(text),
431        })
432    }
433}
434
impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    /// Converts a single piece of core user content into the OpenAI wire
    /// format.
    ///
    /// # Errors
    /// Returns `ConversionError` for content OpenAI cannot accept here:
    /// raw (non-base64) binaries, bodyless documents, base64 images missing a
    /// media type or detail, audio given as a URL, tool results, and video.
    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                // URL images pass straight through; a missing detail falls
                // back to the default level.
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                // Base64 images become a `data:` URI; media type and detail
                // are both mandatory on this path.
                DocumentSourceKind::Base64(data) => {
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                // Catch-all for any remaining source kinds (e.g. strings).
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            // Documents are flattened into plain text content.
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                // Only inline base64 audio is supported; MP3 is assumed when
                // no media type is supplied.
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}
525
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Converts a batch of core user content into provider messages: tool
    /// results become individual `tool` messages, while everything else is
    /// bundled into a single `user` message.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        // If there are messages with both tool results and user content, openai will only
        //  handle tool results. It's unlikely that there will be both.
        if !tool_results.is_empty() {
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // The partition above guarantees only tool results here.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // `OneOrMany` input is non-empty, so with no tool results the
            // remainder must be non-empty as well.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
560
561impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
562    type Error = message::MessageError;
563
564    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
565        let mut text_content = Vec::new();
566        let mut tool_calls = Vec::new();
567
568        for content in value {
569            match content {
570                message::AssistantContent::Text(text) => text_content.push(text),
571                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
572                message::AssistantContent::Reasoning(_) => {
573                    // OpenAI Chat Completions does not support assistant-history reasoning items.
574                    // Silently skip unsupported reasoning content.
575                }
576                message::AssistantContent::Image(_) => {
577                    panic!(
578                        "The OpenAI Completions API doesn't support image content in assistant messages!"
579                    );
580                }
581            }
582        }
583
584        if text_content.is_empty() && tool_calls.is_empty() {
585            return Ok(vec![]);
586        }
587
588        Ok(vec![Message::Assistant {
589            content: text_content
590                .into_iter()
591                .map(|content| content.text.into())
592                .collect::<Vec<_>>(),
593            refusal: None,
594            audio: None,
595            name: None,
596            tool_calls: tool_calls
597                .into_iter()
598                .map(|tool_call| tool_call.into())
599                .collect::<Vec<_>>(),
600        }])
601    }
602}
603
604impl TryFrom<message::Message> for Vec<Message> {
605    type Error = message::MessageError;
606
607    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
608        match message {
609            message::Message::System { content } => Ok(vec![Message::system(&content)]),
610            message::Message::User { content } => content.try_into(),
611            message::Message::Assistant { content, .. } => content.try_into(),
612        }
613    }
614}
615
616impl From<message::ToolCall> for ToolCall {
617    fn from(tool_call: message::ToolCall) -> Self {
618        Self {
619            id: tool_call.id,
620            r#type: ToolType::default(),
621            function: Function {
622                name: tool_call.function.name,
623                arguments: tool_call.function.arguments,
624            },
625        }
626    }
627}
628
629impl From<ToolCall> for message::ToolCall {
630    fn from(tool_call: ToolCall) -> Self {
631        Self {
632            id: tool_call.id,
633            call_id: None,
634            function: message::ToolFunction {
635                name: tool_call.function.name,
636                arguments: tool_call.function.arguments,
637            },
638            signature: None,
639            additional_params: None,
640        }
641    }
642}
643
644impl TryFrom<Message> for message::Message {
645    type Error = message::MessageError;
646
647    fn try_from(message: Message) -> Result<Self, Self::Error> {
648        Ok(match message {
649            Message::User { content, .. } => message::Message::User {
650                content: content.map(|content| content.into()),
651            },
652            Message::Assistant {
653                content,
654                tool_calls,
655                ..
656            } => {
657                let mut content = content
658                    .into_iter()
659                    .map(|content| match content {
660                        AssistantContent::Text { text } => message::AssistantContent::text(text),
661
662                        // TODO: Currently, refusals are converted into text, but should be
663                        //  investigated for generalization.
664                        AssistantContent::Refusal { refusal } => {
665                            message::AssistantContent::text(refusal)
666                        }
667                    })
668                    .collect::<Vec<_>>();
669
670                content.extend(
671                    tool_calls
672                        .into_iter()
673                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
674                        .collect::<Result<Vec<_>, _>>()?,
675                );
676
677                message::Message::Assistant {
678                    id: None,
679                    content: OneOrMany::many(content).map_err(|_| {
680                        message::MessageError::ConversionError(
681                            "Neither `content` nor `tool_calls` was provided to the Message"
682                                .to_owned(),
683                        )
684                    })?,
685                }
686            }
687
688            Message::ToolResult {
689                tool_call_id,
690                content,
691            } => message::Message::User {
692                content: OneOrMany::one(message::UserContent::tool_result(
693                    tool_call_id,
694                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
695                )),
696            },
697
698            // System messages should get stripped out when converting messages, this is just a
699            // stop gap to avoid obnoxious error handling or panic occurring.
700            Message::System { content, .. } => message::Message::User {
701                content: content.map(|content| message::UserContent::text(content.text)),
702            },
703        })
704    }
705}
706
707impl From<UserContent> for message::UserContent {
708    fn from(content: UserContent) -> Self {
709        match content {
710            UserContent::Text { text } => message::UserContent::text(text),
711            UserContent::Image { image_url } => {
712                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
713            }
714            UserContent::Audio { input_audio } => {
715                message::UserContent::audio(input_audio.data, Some(input_audio.format))
716            }
717        }
718    }
719}
720
721impl From<String> for UserContent {
722    fn from(s: String) -> Self {
723        UserContent::Text { text: s }
724    }
725}
726
727impl FromStr for UserContent {
728    type Err = Infallible;
729
730    fn from_str(s: &str) -> Result<Self, Self::Err> {
731        Ok(UserContent::Text {
732            text: s.to_string(),
733        })
734    }
735}
736
737impl From<String> for AssistantContent {
738    fn from(s: String) -> Self {
739        AssistantContent::Text { text: s }
740    }
741}
742
743impl FromStr for AssistantContent {
744    type Err = Infallible;
745
746    fn from_str(s: &str) -> Result<Self, Self::Err> {
747        Ok(AssistantContent::Text {
748            text: s.to_string(),
749        })
750    }
751}
752impl From<String> for SystemContent {
753    fn from(s: String) -> Self {
754        SystemContent {
755            r#type: SystemContentType::default(),
756            text: s,
757        }
758    }
759}
760
761impl FromStr for SystemContent {
762    type Err = Infallible;
763
764    fn from_str(s: &str) -> Result<Self, Self::Err> {
765        Ok(SystemContent {
766            r#type: SystemContentType::default(),
767            text: s.to_string(),
768        })
769    }
770}
771
/// Raw OpenAI Chat Completions response body.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    /// Object type reported by the API (presumably "chat.completion").
    pub object: String,
    /// Creation time — assumed to be a Unix timestamp per OpenAI docs.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    /// Token accounting; optional because some responses omit it.
    pub usage: Option<Usage>,
}
782
783impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
784    type Error = CompletionError;
785
786    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
787        let choice = response.choices.first().ok_or_else(|| {
788            CompletionError::ResponseError("Response contained no choices".to_owned())
789        })?;
790
791        let content = match &choice.message {
792            Message::Assistant {
793                content,
794                tool_calls,
795                ..
796            } => {
797                let mut content = content
798                    .iter()
799                    .filter_map(|c| {
800                        let s = match c {
801                            AssistantContent::Text { text } => text,
802                            AssistantContent::Refusal { refusal } => refusal,
803                        };
804                        if s.is_empty() {
805                            None
806                        } else {
807                            Some(completion::AssistantContent::text(s))
808                        }
809                    })
810                    .collect::<Vec<_>>();
811
812                content.extend(
813                    tool_calls
814                        .iter()
815                        .map(|call| {
816                            completion::AssistantContent::tool_call(
817                                &call.id,
818                                &call.function.name,
819                                call.function.arguments.clone(),
820                            )
821                        })
822                        .collect::<Vec<_>>(),
823                );
824                Ok(content)
825            }
826            _ => Err(CompletionError::ResponseError(
827                "Response did not contain a valid message or tool call".into(),
828            )),
829        }?;
830
831        let choice = OneOrMany::many(content).map_err(|_| {
832            CompletionError::ResponseError(
833                "Response contained no message or tool call (empty)".to_owned(),
834            )
835        })?;
836
837        let usage = response
838            .usage
839            .as_ref()
840            .map(|usage| completion::Usage {
841                input_tokens: usage.prompt_tokens as u64,
842                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
843                total_tokens: usage.total_tokens as u64,
844                cached_input_tokens: usage
845                    .prompt_tokens_details
846                    .as_ref()
847                    .map(|d| d.cached_tokens as u64)
848                    .unwrap_or(0),
849                cache_creation_input_tokens: 0,
850            })
851            .unwrap_or_default();
852
853        Ok(completion::CompletionResponse {
854            choice,
855            usage,
856            raw_response: response,
857            message_id: None,
858        })
859    }
860}
861
862impl ProviderResponseExt for CompletionResponse {
863    type OutputMessage = Choice;
864    type Usage = Usage;
865
866    fn get_response_id(&self) -> Option<String> {
867        Some(self.id.to_owned())
868    }
869
870    fn get_response_model_name(&self) -> Option<String> {
871        Some(self.model.to_owned())
872    }
873
874    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
875        self.choices.clone()
876    }
877
878    fn get_text_response(&self) -> Option<String> {
879        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
880            return None;
881        };
882
883        let UserContent::Text { text } = content.first() else {
884            return None;
885        };
886
887        Some(text)
888    }
889
890    fn get_usage(&self) -> Option<Self::Usage> {
891        self.usage.clone()
892    }
893}
894
/// A single completion choice returned by the OpenAI Chat Completions API.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Zero-based position of this choice in the response.
    pub index: usize,
    /// The generated message for this choice.
    pub message: Message,
    /// Log-probability information when requested; kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. `"stop"`, `"length"`, `"tool_calls"`).
    pub finish_reason: String,
}
902
/// Breakdown of prompt tokens, nested under `usage.prompt_tokens_details`
/// in the API response.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    #[serde(default)]
    pub cached_tokens: usize,
}
909
/// Token usage as reported by the OpenAI Chat Completions API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total tokens for the request (prompt plus completion).
    pub total_tokens: usize,
    /// Optional prompt-token breakdown (e.g. cache hits); omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
917
918impl Usage {
919    pub fn new() -> Self {
920        Self {
921            prompt_tokens: 0,
922            total_tokens: 0,
923            prompt_tokens_details: None,
924        }
925    }
926}
927
928impl Default for Usage {
929    fn default() -> Self {
930        Self::new()
931    }
932}
933
934impl fmt::Display for Usage {
935    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
936        let Usage {
937            prompt_tokens,
938            total_tokens,
939            ..
940        } = self;
941        write!(
942            f,
943            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
944        )
945    }
946}
947
948impl GetTokenUsage for Usage {
949    fn token_usage(&self) -> Option<crate::completion::Usage> {
950        let mut usage = crate::completion::Usage::new();
951        usage.input_tokens = self.prompt_tokens as u64;
952        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
953        usage.total_tokens = self.total_tokens as u64;
954        usage.cached_input_tokens = self
955            .prompt_tokens_details
956            .as_ref()
957            .map(|d| d.cached_tokens as u64)
958            .unwrap_or(0);
959
960        Some(usage)
961    }
962}
963
/// Handle for an OpenAI Chat Completions model, generic over the HTTP client type.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request (e.g. `"gpt-4o-mini"`).
    pub model: String,
    /// When true, tool schemas are sanitized for OpenAI strict mode.
    pub strict_tools: bool,
    /// When true, `ToolResult` message content is converted via `to_array()`
    /// during request construction.
    pub tool_result_array_content: bool,
}
971
972impl<T> CompletionModel<T>
973where
974    T: Default + std::fmt::Debug + Clone + 'static,
975{
976    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
977        Self {
978            client,
979            model: model.into(),
980            strict_tools: false,
981            tool_result_array_content: false,
982        }
983    }
984
985    pub fn with_model(client: Client<T>, model: &str) -> Self {
986        Self {
987            client,
988            model: model.into(),
989            strict_tools: false,
990            tool_result_array_content: false,
991        }
992    }
993
994    /// Enable strict mode for tool schemas.
995    ///
996    /// When enabled, tool schemas are automatically sanitized to meet OpenAI's strict mode requirements:
997    /// - `additionalProperties: false` is added to all objects
998    /// - All properties are marked as required
999    /// - `strict: true` is set on each function definition
1000    ///
1001    /// This allows OpenAI to guarantee that the model's tool calls will match the schema exactly.
1002    pub fn with_strict_tools(mut self) -> Self {
1003        self.strict_tools = true;
1004        self
1005    }
1006
1007    pub fn with_tool_result_array_content(mut self) -> Self {
1008        self.tool_result_array_content = true;
1009        self
1010    }
1011}
1012
/// Wire-format request body for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    // Model identifier to run the completion against.
    model: String,
    // Full conversation: optional system message, documents, then chat history.
    messages: Vec<Message>,
    // Omitted from the payload entirely when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Flattened into the top-level JSON object (e.g. `response_format`).
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1028
/// Inputs for converting a core completion request into an OpenAI [`CompletionRequest`].
pub struct OpenAIRequestParams {
    /// Default model, used when the request does not carry its own override.
    pub model: String,
    /// The provider-agnostic request to convert.
    pub request: CoreCompletionRequest,
    /// Apply strict-mode sanitization to each tool definition.
    pub strict_tools: bool,
    /// Convert tool-result content to content-part arrays via `to_array()`.
    pub tool_result_array_content: bool,
}
1035
1036impl TryFrom<OpenAIRequestParams> for CompletionRequest {
1037    type Error = CompletionError;
1038
1039    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
1040        let OpenAIRequestParams {
1041            model,
1042            request: req,
1043            strict_tools,
1044            tool_result_array_content,
1045        } = params;
1046
1047        let mut partial_history = vec![];
1048        if let Some(docs) = req.normalized_documents() {
1049            partial_history.push(docs);
1050        }
1051        let CoreCompletionRequest {
1052            model: request_model,
1053            preamble,
1054            chat_history,
1055            tools,
1056            temperature,
1057            max_tokens,
1058            additional_params,
1059            tool_choice,
1060            output_schema,
1061            ..
1062        } = req;
1063
1064        partial_history.extend(chat_history);
1065
1066        let mut full_history: Vec<Message> =
1067            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
1068
1069        full_history.extend(
1070            partial_history
1071                .into_iter()
1072                .map(message::Message::try_into)
1073                .collect::<Result<Vec<Vec<Message>>, _>>()?
1074                .into_iter()
1075                .flatten()
1076                .collect::<Vec<_>>(),
1077        );
1078
1079        if full_history.is_empty() {
1080            return Err(CompletionError::RequestError(
1081                std::io::Error::new(
1082                    std::io::ErrorKind::InvalidInput,
1083                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
1084                )
1085                .into(),
1086            ));
1087        }
1088
1089        if tool_result_array_content {
1090            for msg in &mut full_history {
1091                if let Message::ToolResult { content, .. } = msg {
1092                    *content = content.to_array();
1093                }
1094            }
1095        }
1096
1097        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
1098
1099        let tools: Vec<ToolDefinition> = tools
1100            .into_iter()
1101            .map(|tool| {
1102                let def = ToolDefinition::from(tool);
1103                if strict_tools { def.with_strict() } else { def }
1104            })
1105            .collect();
1106
1107        // Map output_schema to OpenAI's response_format and merge into additional_params
1108        let additional_params = if let Some(schema) = output_schema {
1109            let name = schema
1110                .as_object()
1111                .and_then(|o| o.get("title"))
1112                .and_then(|v| v.as_str())
1113                .unwrap_or("response_schema")
1114                .to_string();
1115            let mut schema_value = schema.to_value();
1116            super::sanitize_schema(&mut schema_value);
1117            let response_format = serde_json::json!({
1118                "response_format": {
1119                    "type": "json_schema",
1120                    "json_schema": {
1121                        "name": name,
1122                        "strict": true,
1123                        "schema": schema_value
1124                    }
1125                }
1126            });
1127            Some(match additional_params {
1128                Some(existing) => json_utils::merge(existing, response_format),
1129                None => response_format,
1130            })
1131        } else {
1132            additional_params
1133        };
1134
1135        let res = Self {
1136            model: request_model.unwrap_or(model),
1137            messages: full_history,
1138            tools,
1139            tool_choice,
1140            temperature,
1141            max_tokens,
1142            additional_params,
1143        };
1144
1145        Ok(res)
1146    }
1147}
1148
1149impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1150    type Error = CompletionError;
1151
1152    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1153        CompletionRequest::try_from(OpenAIRequestParams {
1154            model,
1155            request: req,
1156            strict_tools: false,
1157            tool_result_array_content: false,
1158        })
1159    }
1160}
1161
1162impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1163    type InputMessage = Message;
1164
1165    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1166        self.messages.clone()
1167    }
1168
1169    fn get_system_prompt(&self) -> Option<String> {
1170        let first_message = self.messages.first()?;
1171
1172        let Message::System { ref content, .. } = first_message.clone() else {
1173            return None;
1174        };
1175
1176        let SystemContent { text, .. } = content.first();
1177
1178        Some(text)
1179    }
1180
1181    fn get_prompt(&self) -> Option<String> {
1182        let last_message = self.messages.last()?;
1183
1184        let Message::User { ref content, .. } = last_message.clone() else {
1185            return None;
1186        };
1187
1188        let UserContent::Text { text } = content.first() else {
1189            return None;
1190        };
1191
1192        Some(text)
1193    }
1194
1195    fn get_model_name(&self) -> String {
1196        self.model.clone()
1197    }
1198}
1199
impl CompletionModel<reqwest::Client> {
    /// Wraps this model in an agent builder for fluent agent construction.
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1205
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a chat-completion request to `/chat/completions` and converts
    /// the provider response into rig's generic `CompletionResponse`.
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse an ambient tracing span when one is active; otherwise open a
        // fresh "chat" span whose response/usage fields are recorded later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Convert the generic request into the OpenAI wire format, applying
        // this model's strict-tools / tool-result-array settings.
        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Gate the pretty-printing behind the TRACE check so the
        // serialization cost is only paid when the level is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Instrument the network round-trip and response handling with the
        // span chosen above.
        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Record telemetry on the instrumented span before
                        // converting into the generic response type.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    // Well-formed API error payload: surface its message.
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx status: surface the raw body as the provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Presumably resolves to an inherent `stream` method (defined in the
        // streaming module), which takes precedence over this trait method —
        // confirm this is not self-recursive.
        Self::stream(self, request).await
    }
}
1313
1314fn serialize_assistant_content_vec<S>(
1315    value: &Vec<AssistantContent>,
1316    serializer: S,
1317) -> Result<S::Ok, S::Error>
1318where
1319    S: Serializer,
1320{
1321    if value.is_empty() {
1322        serializer.serialize_str("")
1323    } else {
1324        value.serialize(serializer)
1325    }
1326}
1327
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request: a single "Hello" user message with every optional
    /// field unset. Tests override individual fields via struct-update syntax
    /// instead of repeating the full literal.
    fn base_request() -> CoreCompletionRequest {
        CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Converts a core request with the given default model and default flags
    /// (no strict tools, no array tool-result content).
    fn convert(
        model: &str,
        request: CoreCompletionRequest,
    ) -> Result<CompletionRequest, CompletionError> {
        CompletionRequest::try_from(OpenAIRequestParams {
            model: model.to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
    }

    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = CoreCompletionRequest {
            model: Some("gpt-4.1".to_string()),
            ..base_request()
        };

        let openai_request =
            convert("gpt-4o-mini", request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4.1");
    }

    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let openai_request =
            convert("gpt-4o-mini", base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        // Reasoning-only content converts to zero provider messages.
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        assert!(converted.is_empty());
    }

    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Reasoning is dropped; text and tool call survive.
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_max_tokens_is_forwarded_to_request() {
        let request = CoreCompletionRequest {
            max_tokens: Some(4096),
            ..base_request()
        };

        let openai_request =
            convert("gpt-4o-mini", request).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["max_tokens"], 4096);
    }

    #[test]
    fn test_max_tokens_omitted_when_none() {
        let openai_request =
            convert("gpt-4o-mini", base_request()).expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        // skip_serializing_if should drop the field entirely.
        assert!(serialized.get("max_tokens").is_none());
    }

    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        // A reasoning-only assistant message converts to nothing, leaving an
        // empty history, which must be rejected.
        let request = CoreCompletionRequest {
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            ..base_request()
        };

        let result = convert("gpt-4o-mini", request);

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}