// rig/providers/openai/completion/mod.rs
// ================================================================
// OpenAI Completion API
// ================================================================
4
5use super::{
6    CompletionsClient as Client,
7    client::{ApiErrorResponse, ApiResponse},
8    streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize, Serializer};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28/// Serializes user content as a plain string when there's a single text item,
29/// otherwise as an array of content parts.
30fn serialize_user_content<S>(
31    content: &OneOrMany<UserContent>,
32    serializer: S,
33) -> Result<S::Ok, S::Error>
34where
35    S: Serializer,
36{
37    if content.len() == 1
38        && let UserContent::Text { text } = content.first_ref()
39    {
40        return serializer.serialize_str(text);
41    }
42    content.serialize(serializer)
43}
44
/// `gpt-5.2` completion model
pub const GPT_5_2: &str = "gpt-5.2";

/// `gpt-5.1` completion model
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` completion model
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model
pub const GPT_5_NANO: &str = "gpt-5-nano";

/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
126
127impl From<ApiErrorResponse> for CompletionError {
128    fn from(err: ApiErrorResponse) -> Self {
129        CompletionError::ProviderError(err.message)
130    }
131}
132
/// A message in OpenAI's Chat Completions wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System (or `developer`) instructions for the model.
    #[serde(alias = "developer")]
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input: text, images, or audio.
    User {
        // Serialized back as a plain string when it is a single text part.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output: text/refusal parts plus any requested tool calls.
    Assistant {
        #[serde(
            default,
            deserialize_with = "json_utils::string_or_vec",
            skip_serializing_if = "Vec::is_empty",
            serialize_with = "serialize_assistant_content_vec"
        )]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may be `null` on the wire; treated as an empty list.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of a tool call, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: ToolResultContentValue,
    },
}
179
180impl Message {
181    pub fn system(content: &str) -> Self {
182        Message::System {
183            content: OneOrMany::one(content.to_owned().into()),
184            name: None,
185        }
186    }
187}
188
/// Reference to a prior assistant audio response (by id).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}

/// A single text part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

/// Content type discriminator for [`SystemContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
207
/// Assistant-produced content: normal text or a refusal message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
214
215impl From<AssistantContent> for completion::AssistantContent {
216    fn from(value: AssistantContent) -> Self {
217        match value {
218            AssistantContent::Text { text } => completion::AssistantContent::text(text),
219            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
220        }
221    }
222}
223
/// User-supplied content parts, tagged by `type` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    // Wire tag is `image_url` rather than `image`.
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}

/// An image reference: an HTTP(S) URL or a `data:` URI.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

/// Base64-encoded audio data plus its format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}
251
/// A single text part of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

/// Content type discriminator for [`ToolResultContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
265
266impl FromStr for ToolResultContent {
267    type Err = Infallible;
268
269    fn from_str(s: &str) -> Result<Self, Self::Err> {
270        Ok(s.to_owned().into())
271    }
272}
273
274impl From<String> for ToolResultContent {
275    fn from(s: String) -> Self {
276        ToolResultContent {
277            r#type: ToolResultContentType::default(),
278            text: s,
279        }
280    }
281}
282
/// Tool-result `content` on the wire: some OpenAI-compatible providers
/// expect a plain string, others an array of content parts.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolResultContentValue {
    Array(Vec<ToolResultContent>),
    String(String),
}
289
290impl ToolResultContentValue {
291    pub fn from_string(s: String, use_array_format: bool) -> Self {
292        if use_array_format {
293            ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
294        } else {
295            ToolResultContentValue::String(s)
296        }
297    }
298
299    pub fn as_text(&self) -> String {
300        match self {
301            ToolResultContentValue::Array(arr) => arr
302                .iter()
303                .map(|c| c.text.clone())
304                .collect::<Vec<_>>()
305                .join("\n"),
306            ToolResultContentValue::String(s) => s.clone(),
307        }
308    }
309
310    pub fn to_array(&self) -> Self {
311        match self {
312            ToolResultContentValue::Array(_) => self.clone(),
313            ToolResultContentValue::String(s) => {
314                ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
315            }
316        }
317    }
318}
319
/// A tool call requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Tool kind discriminator; only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Function definition for a tool, with optional strict mode
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: String,
    /// JSON schema describing the function's parameters.
    pub parameters: serde_json::Value,
    /// When `Some(true)`, the schema is enforced exactly (strict mode).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

/// A tool as sent in the request body: `{"type": "function", "function": {..}}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: FunctionDefinition,
}
350
351impl From<completion::ToolDefinition> for ToolDefinition {
352    fn from(tool: completion::ToolDefinition) -> Self {
353        Self {
354            r#type: "function".into(),
355            function: FunctionDefinition {
356                name: tool.name,
357                description: tool.description,
358                parameters: tool.parameters,
359                strict: None,
360            },
361        }
362    }
363}
364
impl ToolDefinition {
    /// Apply strict mode to this tool definition.
    /// This sets `strict: true` and sanitizes the schema to meet OpenAI requirements.
    pub fn with_strict(mut self) -> Self {
        self.function.strict = Some(true);
        // `sanitize_schema` rewrites the parameter schema in place so it is
        // acceptable under OpenAI's strict structured-output rules.
        super::sanitize_schema(&mut self.function.parameters);
        self
    }
}
374
/// OpenAI's `tool_choice` request field.
/// Forcing a *specific* tool is not representable here (see the `TryFrom`
/// below, which rejects `message::ToolChoice::Specific`).
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}
383
384impl TryFrom<crate::message::ToolChoice> for ToolChoice {
385    type Error = CompletionError;
386    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
387        let res = match value {
388            message::ToolChoice::Specific { .. } => {
389                return Err(CompletionError::ProviderError(
390                    "Provider doesn't support only using specific tools".to_string(),
391                ));
392            }
393            message::ToolChoice::Auto => Self::Auto,
394            message::ToolChoice::None => Self::None,
395            message::ToolChoice::Required => Self::Required,
396        };
397
398        Ok(res)
399    }
400}
401
/// The function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // Held as structured JSON in memory but (de)serialized as the
    // stringified JSON that the OpenAI API uses on the wire.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
408
409impl TryFrom<message::ToolResult> for Message {
410    type Error = message::MessageError;
411
412    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
413        let text = value
414            .content
415            .into_iter()
416            .map(|content| {
417                match content {
418                message::ToolResultContent::Text(message::Text { text }) => Ok(text),
419                message::ToolResultContent::Image(_) => Err(message::MessageError::ConversionError(
420                    "OpenAI does not support images in tool results. Tool results must be text."
421                        .into(),
422                )),
423            }
424            })
425            .collect::<Result<Vec<_>, _>>()?
426            .join("\n");
427
428        Ok(Message::ToolResult {
429            tool_call_id: value.id,
430            content: ToolResultContentValue::String(text),
431        })
432    }
433}
434
impl TryFrom<message::UserContent> for UserContent {
    type Error = message::MessageError;

    /// Convert core user content into the OpenAI wire format.
    ///
    /// # Errors
    /// Returns a `ConversionError` for content OpenAI cannot accept here:
    /// raw/unknown document bodies, non-base64 audio, videos, and tool
    /// results (which must be sent as separate `tool` messages instead).
    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
        match value {
            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
            message::UserContent::Image(message::Image {
                data,
                detail,
                media_type,
                ..
            }) => match data {
                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                    image_url: ImageUrl {
                        url,
                        detail: detail.unwrap_or_default(),
                    },
                }),
                DocumentSourceKind::Base64(data) => {
                    // Base64 images are embedded as a `data:` URI, which
                    // needs an explicit media type to be well-formed.
                    let url = format!(
                        "data:{};base64,{}",
                        media_type.map(|i| i.to_mime_type()).ok_or(
                            message::MessageError::ConversionError(
                                "OpenAI Image URI must have media type".into()
                            )
                        )?,
                        data
                    );

                    // NOTE(review): unlike the URL branch above (which falls
                    // back to the default detail), a missing detail is an
                    // error here — confirm this asymmetry is intentional.
                    let detail = detail.ok_or(message::MessageError::ConversionError(
                        "OpenAI image URI must have image detail".into(),
                    ))?;

                    Ok(UserContent::Image {
                        image_url: ImageUrl { url, detail },
                    })
                }
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files not supported, encode as base64 first".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Document has no body".into(),
                )),
                doc => Err(message::MessageError::ConversionError(format!(
                    "Unsupported document type: {doc:?}"
                ))),
            },
            // Documents are flattened into plain text parts.
            message::UserContent::Document(message::Document { data, .. }) => {
                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
                    Ok(UserContent::Text { text })
                } else {
                    Err(message::MessageError::ConversionError(
                        "Documents must be base64 or a string".into(),
                    ))
                }
            }
            message::UserContent::Audio(message::Audio {
                data, media_type, ..
            }) => match data {
                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
                    input_audio: InputAudio {
                        data,
                        // MP3 is assumed when no media type was provided.
                        format: match media_type {
                            Some(media_type) => media_type,
                            None => AudioMediaType::MP3,
                        },
                    },
                }),
                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
                    "URLs are not supported for audio".into(),
                )),
                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
                    "Raw files are not supported for audio".into(),
                )),
                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
                    "Audio has no body".into(),
                )),
                audio => Err(message::MessageError::ConversionError(format!(
                    "Unsupported audio type: {audio:?}"
                ))),
            },
            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
                "Tool result is in unsupported format".into(),
            )),
            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
                "Video is in unsupported format".into(),
            )),
        }
    }
}
525
impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Split user content into provider messages.
    ///
    /// Tool results must be sent as separate `tool`-role messages, so they
    /// are peeled off first; everything else becomes one user message.
    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
            .into_iter()
            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

        // If there are messages with both tool results and user content, openai will only
        //  handle tool results. It's unlikely that there will be both.
        if !tool_results.is_empty() {
            tool_results
                .into_iter()
                .map(|content| match content {
                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
                    // The partition above guarantees only tool results here.
                    _ => unreachable!(),
                })
                .collect::<Result<Vec<_>, _>>()
        } else {
            let other_content: Vec<UserContent> = other_content
                .into_iter()
                .map(|content| content.try_into())
                .collect::<Result<Vec<_>, _>>()?;

            // Safe: `value` is a non-empty OneOrMany and no tool results were
            // found, so `other_content` holds at least one item.
            let other_content = OneOrMany::many(other_content)
                .expect("There must be other content here if there were no tool result content");

            Ok(vec![Message::User {
                content: other_content,
                name: None,
            }])
        }
    }
}
560
561impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
562    type Error = message::MessageError;
563
564    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
565        let mut text_content = Vec::new();
566        let mut tool_calls = Vec::new();
567
568        for content in value {
569            match content {
570                message::AssistantContent::Text(text) => text_content.push(text),
571                message::AssistantContent::ToolCall(tool_call) => tool_calls.push(tool_call),
572                message::AssistantContent::Reasoning(_) => {
573                    // OpenAI Chat Completions does not support assistant-history reasoning items.
574                    // Silently skip unsupported reasoning content.
575                }
576                message::AssistantContent::Image(_) => {
577                    panic!(
578                        "The OpenAI Completions API doesn't support image content in assistant messages!"
579                    );
580                }
581            }
582        }
583
584        if text_content.is_empty() && tool_calls.is_empty() {
585            return Ok(vec![]);
586        }
587
588        Ok(vec![Message::Assistant {
589            content: text_content
590                .into_iter()
591                .map(|content| content.text.into())
592                .collect::<Vec<_>>(),
593            refusal: None,
594            audio: None,
595            name: None,
596            tool_calls: tool_calls
597                .into_iter()
598                .map(|tool_call| tool_call.into())
599                .collect::<Vec<_>>(),
600        }])
601    }
602}
603
604impl TryFrom<message::Message> for Vec<Message> {
605    type Error = message::MessageError;
606
607    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
608        match message {
609            message::Message::User { content } => content.try_into(),
610            message::Message::Assistant { content, .. } => content.try_into(),
611        }
612    }
613}
614
615impl From<message::ToolCall> for ToolCall {
616    fn from(tool_call: message::ToolCall) -> Self {
617        Self {
618            id: tool_call.id,
619            r#type: ToolType::default(),
620            function: Function {
621                name: tool_call.function.name,
622                arguments: tool_call.function.arguments,
623            },
624        }
625    }
626}
627
628impl From<ToolCall> for message::ToolCall {
629    fn from(tool_call: ToolCall) -> Self {
630        Self {
631            id: tool_call.id,
632            call_id: None,
633            function: message::ToolFunction {
634                name: tool_call.function.name,
635                arguments: tool_call.function.arguments,
636            },
637            signature: None,
638            additional_params: None,
639        }
640    }
641}
642
643impl TryFrom<Message> for message::Message {
644    type Error = message::MessageError;
645
646    fn try_from(message: Message) -> Result<Self, Self::Error> {
647        Ok(match message {
648            Message::User { content, .. } => message::Message::User {
649                content: content.map(|content| content.into()),
650            },
651            Message::Assistant {
652                content,
653                tool_calls,
654                ..
655            } => {
656                let mut content = content
657                    .into_iter()
658                    .map(|content| match content {
659                        AssistantContent::Text { text } => message::AssistantContent::text(text),
660
661                        // TODO: Currently, refusals are converted into text, but should be
662                        //  investigated for generalization.
663                        AssistantContent::Refusal { refusal } => {
664                            message::AssistantContent::text(refusal)
665                        }
666                    })
667                    .collect::<Vec<_>>();
668
669                content.extend(
670                    tool_calls
671                        .into_iter()
672                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
673                        .collect::<Result<Vec<_>, _>>()?,
674                );
675
676                message::Message::Assistant {
677                    id: None,
678                    content: OneOrMany::many(content).map_err(|_| {
679                        message::MessageError::ConversionError(
680                            "Neither `content` nor `tool_calls` was provided to the Message"
681                                .to_owned(),
682                        )
683                    })?,
684                }
685            }
686
687            Message::ToolResult {
688                tool_call_id,
689                content,
690            } => message::Message::User {
691                content: OneOrMany::one(message::UserContent::tool_result(
692                    tool_call_id,
693                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
694                )),
695            },
696
697            // System messages should get stripped out when converting messages, this is just a
698            // stop gap to avoid obnoxious error handling or panic occurring.
699            Message::System { content, .. } => message::Message::User {
700                content: content.map(|content| message::UserContent::text(content.text)),
701            },
702        })
703    }
704}
705
706impl From<UserContent> for message::UserContent {
707    fn from(content: UserContent) -> Self {
708        match content {
709            UserContent::Text { text } => message::UserContent::text(text),
710            UserContent::Image { image_url } => {
711                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
712            }
713            UserContent::Audio { input_audio } => {
714                message::UserContent::audio(input_audio.data, Some(input_audio.format))
715            }
716        }
717    }
718}
719
720impl From<String> for UserContent {
721    fn from(s: String) -> Self {
722        UserContent::Text { text: s }
723    }
724}
725
726impl FromStr for UserContent {
727    type Err = Infallible;
728
729    fn from_str(s: &str) -> Result<Self, Self::Err> {
730        Ok(UserContent::Text {
731            text: s.to_string(),
732        })
733    }
734}
735
736impl From<String> for AssistantContent {
737    fn from(s: String) -> Self {
738        AssistantContent::Text { text: s }
739    }
740}
741
742impl FromStr for AssistantContent {
743    type Err = Infallible;
744
745    fn from_str(s: &str) -> Result<Self, Self::Err> {
746        Ok(AssistantContent::Text {
747            text: s.to_string(),
748        })
749    }
750}
751impl From<String> for SystemContent {
752    fn from(s: String) -> Self {
753        SystemContent {
754            r#type: SystemContentType::default(),
755            text: s,
756        }
757    }
758}
759
760impl FromStr for SystemContent {
761    type Err = Infallible;
762
763    fn from_str(s: &str) -> Result<Self, Self::Err> {
764        Ok(SystemContent {
765            r#type: SystemContentType::default(),
766            text: s.to_string(),
767        })
768    }
769}
770
/// Raw response body of OpenAI's chat completions endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp (seconds) of when the completion was created.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    /// Token accounting; optional in the API response.
    pub usage: Option<Usage>,
}
781
782impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
783    type Error = CompletionError;
784
785    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
786        let choice = response.choices.first().ok_or_else(|| {
787            CompletionError::ResponseError("Response contained no choices".to_owned())
788        })?;
789
790        let content = match &choice.message {
791            Message::Assistant {
792                content,
793                tool_calls,
794                ..
795            } => {
796                let mut content = content
797                    .iter()
798                    .filter_map(|c| {
799                        let s = match c {
800                            AssistantContent::Text { text } => text,
801                            AssistantContent::Refusal { refusal } => refusal,
802                        };
803                        if s.is_empty() {
804                            None
805                        } else {
806                            Some(completion::AssistantContent::text(s))
807                        }
808                    })
809                    .collect::<Vec<_>>();
810
811                content.extend(
812                    tool_calls
813                        .iter()
814                        .map(|call| {
815                            completion::AssistantContent::tool_call(
816                                &call.id,
817                                &call.function.name,
818                                call.function.arguments.clone(),
819                            )
820                        })
821                        .collect::<Vec<_>>(),
822                );
823                Ok(content)
824            }
825            _ => Err(CompletionError::ResponseError(
826                "Response did not contain a valid message or tool call".into(),
827            )),
828        }?;
829
830        let choice = OneOrMany::many(content).map_err(|_| {
831            CompletionError::ResponseError(
832                "Response contained no message or tool call (empty)".to_owned(),
833            )
834        })?;
835
836        let usage = response
837            .usage
838            .as_ref()
839            .map(|usage| completion::Usage {
840                input_tokens: usage.prompt_tokens as u64,
841                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
842                total_tokens: usage.total_tokens as u64,
843                cached_input_tokens: usage
844                    .prompt_tokens_details
845                    .as_ref()
846                    .map(|d| d.cached_tokens as u64)
847                    .unwrap_or(0),
848            })
849            .unwrap_or_default();
850
851        Ok(completion::CompletionResponse {
852            choice,
853            usage,
854            raw_response: response,
855            message_id: None,
856        })
857    }
858}
859
860impl ProviderResponseExt for CompletionResponse {
861    type OutputMessage = Choice;
862    type Usage = Usage;
863
864    fn get_response_id(&self) -> Option<String> {
865        Some(self.id.to_owned())
866    }
867
868    fn get_response_model_name(&self) -> Option<String> {
869        Some(self.model.to_owned())
870    }
871
872    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
873        self.choices.clone()
874    }
875
876    fn get_text_response(&self) -> Option<String> {
877        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
878            return None;
879        };
880
881        let UserContent::Text { text } = content.first() else {
882            return None;
883        };
884
885        Some(text)
886    }
887
888    fn get_usage(&self) -> Option<Self::Usage> {
889        self.usage.clone()
890    }
891}
892
/// A single completion choice returned by the OpenAI Chat Completions API.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    /// Zero-based position of this choice within the response.
    pub index: usize,
    /// The message produced by the model for this choice.
    pub message: Message,
    /// Log-probability payload, if requested; kept as raw JSON since its shape is not modeled here.
    pub logprobs: Option<serde_json::Value>,
    /// Reason the model stopped generating, as reported by the API.
    pub finish_reason: String,
}
900
/// Breakdown of prompt-token accounting reported by the API alongside [`Usage`].
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct PromptTokensDetails {
    /// Cached tokens from prompt caching
    #[serde(default)]
    pub cached_tokens: usize,
}
907
/// Token usage as reported by the OpenAI Chat Completions API.
///
/// Only prompt and total counts are modeled; output tokens are derived as
/// `total_tokens - prompt_tokens` by the consumers in this module.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (input side).
    pub prompt_tokens: usize,
    /// Total tokens for the request (prompt plus completion).
    pub total_tokens: usize,
    /// Optional prompt-token breakdown (e.g. cache hits); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
915
916impl Usage {
917    pub fn new() -> Self {
918        Self {
919            prompt_tokens: 0,
920            total_tokens: 0,
921            prompt_tokens_details: None,
922        }
923    }
924}
925
926impl Default for Usage {
927    fn default() -> Self {
928        Self::new()
929    }
930}
931
932impl fmt::Display for Usage {
933    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
934        let Usage {
935            prompt_tokens,
936            total_tokens,
937            ..
938        } = self;
939        write!(
940            f,
941            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
942        )
943    }
944}
945
946impl GetTokenUsage for Usage {
947    fn token_usage(&self) -> Option<crate::completion::Usage> {
948        let mut usage = crate::completion::Usage::new();
949        usage.input_tokens = self.prompt_tokens as u64;
950        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
951        usage.total_tokens = self.total_tokens as u64;
952        usage.cached_input_tokens = self
953            .prompt_tokens_details
954            .as_ref()
955            .map(|d| d.cached_tokens as u64)
956            .unwrap_or(0);
957
958        Some(usage)
959    }
960}
961
/// Handle for an OpenAI Chat Completions model, generic over the HTTP client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client wrapper used to reach the Completions API.
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request (may be overridden per request).
    pub model: String,
    /// When true, tool schemas are sanitized to meet OpenAI strict-mode requirements.
    pub strict_tools: bool,
    /// When true, tool-result message content is sent as an array of parts
    /// instead of a plain string.
    pub tool_result_array_content: bool,
}
969
970impl<T> CompletionModel<T>
971where
972    T: Default + std::fmt::Debug + Clone + 'static,
973{
974    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
975        Self {
976            client,
977            model: model.into(),
978            strict_tools: false,
979            tool_result_array_content: false,
980        }
981    }
982
983    pub fn with_model(client: Client<T>, model: &str) -> Self {
984        Self {
985            client,
986            model: model.into(),
987            strict_tools: false,
988            tool_result_array_content: false,
989        }
990    }
991
992    /// Enable strict mode for tool schemas.
993    ///
994    /// When enabled, tool schemas are automatically sanitized to meet OpenAI's strict mode requirements:
995    /// - `additionalProperties: false` is added to all objects
996    /// - All properties are marked as required
997    /// - `strict: true` is set on each function definition
998    ///
999    /// This allows OpenAI to guarantee that the model's tool calls will match the schema exactly.
1000    pub fn with_strict_tools(mut self) -> Self {
1001        self.strict_tools = true;
1002        self
1003    }
1004
1005    pub fn with_tool_result_array_content(mut self) -> Self {
1006        self.tool_result_array_content = true;
1007        self
1008    }
1009}
1010
/// Wire-format request body for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    /// Model identifier to run the completion against.
    model: String,
    /// Full conversation history; the system message comes first when present.
    messages: Vec<Message>,
    /// Tool definitions; the field is omitted from JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// How the model should choose tools; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Sampling temperature; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Extra provider parameters (e.g. `response_format`) flattened into the
    /// top level of the request body.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
1024
/// Inputs for converting a core [`CoreCompletionRequest`] into the
/// OpenAI-specific [`CompletionRequest`], bundling the model-level flags
/// configured on [`CompletionModel`].
pub struct OpenAIRequestParams {
    /// Default model name; overridden by `request.model` when that is set.
    pub model: String,
    /// The provider-agnostic request to convert.
    pub request: CoreCompletionRequest,
    /// Whether to sanitize tool schemas for OpenAI strict mode.
    pub strict_tools: bool,
    /// Whether tool-result content should be sent as an array of parts.
    pub tool_result_array_content: bool,
}
1031
impl TryFrom<OpenAIRequestParams> for CompletionRequest {
    type Error = CompletionError;

    /// Converts a provider-agnostic request into the OpenAI wire format.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] when no messages survive
    /// conversion, and propagates message / tool-choice conversion failures.
    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
        let OpenAIRequestParams {
            model,
            request: req,
            strict_tools,
            tool_result_array_content,
        } = params;

        // Documents (if any) are prepended to the chat history so the model
        // sees them before the conversation turns.
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            model: request_model,
            preamble,
            chat_history,
            tools,
            temperature,
            additional_params,
            tool_choice,
            output_schema,
            ..
        } = req;

        partial_history.extend(chat_history);

        // The system preamble, when present, becomes the very first message.
        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Each core message may fan out into several provider messages
        // (Vec<Message>); collect propagates the first conversion error.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        // Conversion can drop unsupported content (e.g. reasoning-only
        // assistant messages); an empty request would be rejected upstream,
        // so fail early with a clear error.
        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "OpenAI Chat Completions request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        // Some OpenAI-compatible providers require tool results as an array
        // of content parts rather than a plain string.
        if tool_result_array_content {
            for msg in &mut full_history {
                if let Message::ToolResult { content, .. } = msg {
                    *content = content.to_array();
                }
            }
        }

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        // Optionally sanitize each tool definition for strict mode.
        let tools: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|tool| {
                let def = ToolDefinition::from(tool);
                if strict_tools { def.with_strict() } else { def }
            })
            .collect();

        // Map output_schema to OpenAI's response_format and merge into additional_params
        let additional_params = if let Some(schema) = output_schema {
            // Prefer the schema's `title` as the response-format name.
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            super::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            // Preserve any caller-supplied params; response_format wins on conflict.
            Some(match additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            additional_params
        };

        let res = Self {
            // A per-request model override takes precedence over the default.
            model: request_model.unwrap_or(model),
            messages: full_history,
            tools,
            tool_choice,
            temperature,
            additional_params,
        };

        Ok(res)
    }
}
1142
1143impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1144    type Error = CompletionError;
1145
1146    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1147        CompletionRequest::try_from(OpenAIRequestParams {
1148            model,
1149            request: req,
1150            strict_tools: false,
1151            tool_result_array_content: false,
1152        })
1153    }
1154}
1155
1156impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1157    type InputMessage = Message;
1158
1159    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1160        self.messages.clone()
1161    }
1162
1163    fn get_system_prompt(&self) -> Option<String> {
1164        let first_message = self.messages.first()?;
1165
1166        let Message::System { ref content, .. } = first_message.clone() else {
1167            return None;
1168        };
1169
1170        let SystemContent { text, .. } = content.first();
1171
1172        Some(text)
1173    }
1174
1175    fn get_prompt(&self) -> Option<String> {
1176        let last_message = self.messages.last()?;
1177
1178        let Message::User { ref content, .. } = last_message.clone() else {
1179            return None;
1180        };
1181
1182        let UserContent::Text { text } = content.first() else {
1183            return None;
1184        };
1185
1186        Some(text)
1187    }
1188
1189    fn get_model_name(&self) -> String {
1190        self.model.clone()
1191    }
1192}
1193
impl CompletionModel<reqwest::Client> {
    /// Consumes this model and wraps it in an [`crate::agent::AgentBuilder`].
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}
1199
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt
        + Default
        + std::fmt::Debug
        + Clone
        + WasmCompatSend
        + WasmCompatSync
        + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = super::CompletionsClient<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a completion request to `/chat/completions` and converts the
    /// provider response into the crate-wide response type.
    ///
    /// # Errors
    /// Propagates request-conversion, serialization, HTTP, and provider
    /// errors as [`CompletionError`].
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the ambient tracing span if one is active; otherwise open a
        // fresh GenAI-convention span with empty response fields to be
        // recorded after the call completes.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Convert the core request using this model's configured flags.
        let request = CompletionRequest::try_from(OpenAIRequestParams {
            model: self.model.to_owned(),
            request: completion_request,
            strict_tools: self.strict_tools,
            tool_result_array_content: self.tool_result_array_content,
        })?;

        // Gate the pretty-print behind `enabled!` so the serialization cost
        // is only paid when TRACE logging is on.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Chat Completions completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                // A 2xx body can still be an API-level error envelope, so
                // parse into ApiResponse and branch on the variant.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Fill in the span fields left Empty above.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);

                        if enabled!(Level::TRACE) {
                            tracing::trace!(
                                target: "rig::completions",
                                "OpenAI Chat Completions completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Delegates to the inherent `stream` defined in the streaming module.
        Self::stream(self, request).await
    }
}
1306
1307fn serialize_assistant_content_vec<S>(
1308    value: &Vec<AssistantContent>,
1309    serializer: S,
1310) -> Result<S::Ok, S::Error>
1311where
1312    S: Serializer,
1313{
1314    if value.is_empty() {
1315        serializer.serialize_str("")
1316    } else {
1317        value.serialize(serializer)
1318    }
1319}
1320
#[cfg(test)]
mod tests {
    use super::*;

    /// A per-request model override (`request.model`) must take precedence
    /// over the model configured on the client/model handle.
    #[test]
    fn test_openai_request_uses_request_model_override() {
        let request = crate::completion::CompletionRequest {
            model: Some("gpt-4.1".to_string()),
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4.1");
    }

    /// Without a per-request override, the default model supplied in the
    /// params is used.
    #[test]
    fn test_openai_request_uses_default_model_when_override_unset() {
        let request = crate::completion::CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let openai_request = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        })
        .expect("request conversion should succeed");
        let serialized =
            serde_json::to_value(openai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "gpt-4o-mini");
    }

    /// Reasoning-only assistant content is dropped during conversion rather
    /// than producing an error or an empty provider message.
    #[test]
    fn assistant_reasoning_is_silently_skipped() {
        let assistant_content = OneOrMany::one(message::AssistantContent::reasoning("hidden"));

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");

        assert!(converted.is_empty());
    }

    /// When reasoning is mixed with text and tool calls, the reasoning part
    /// is dropped while the text and tool call survive in a single message.
    #[test]
    fn assistant_text_and_tool_call_are_preserved_when_reasoning_is_present() {
        let assistant_content = OneOrMany::many(vec![
            message::AssistantContent::reasoning("hidden"),
            message::AssistantContent::text("visible"),
            message::AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 1}),
            ),
        ])
        .expect("non-empty assistant content");

        let converted: Vec<Message> = assistant_content
            .try_into()
            .expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    /// If every message is filtered out during conversion (here: a single
    /// reasoning-only assistant message), the conversion must fail with a
    /// RequestError instead of sending an empty message list.
    #[test]
    fn request_conversion_errors_when_all_messages_are_filtered() {
        let request = CoreCompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let result = CompletionRequest::try_from(OpenAIRequestParams {
            model: "gpt-4o-mini".to_string(),
            request,
            strict_tools: false,
            tool_result_array_content: false,
        });

        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}