rig/providers/openai/completion/
mod.rs

1// ================================================================
2// OpenAI Completion API
3// ================================================================
4
5use super::{
6    CompletionsClient as Client,
7    client::{ApiErrorResponse, ApiResponse},
8    streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, Level, enabled, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
28/// `gpt-5.1` completion model
29pub const GPT_5_1: &str = "gpt-5.1";
30
31/// `gpt-5` completion model
32pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model
34pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model
36pub const GPT_5_NANO: &str = "gpt-5-nano";
37
38/// `gpt-4.5-preview` completion model
39pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
40/// `gpt-4.5-preview-2025-02-27` completion model
41pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
42/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
43pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
44/// `gpt-4o` completion model
45pub const GPT_4O: &str = "gpt-4o";
46/// `gpt-4o-mini` completion model
47pub const GPT_4O_MINI: &str = "gpt-4o-mini";
48/// `gpt-4o-2024-05-13` completion model
49pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
50/// `gpt-4-turbo` completion model
51pub const GPT_4_TURBO: &str = "gpt-4-turbo";
52/// `gpt-4-turbo-2024-04-09` completion model
53pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
54/// `gpt-4-turbo-preview` completion model
55pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
56/// `gpt-4-0125-preview` completion model
57pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
58/// `gpt-4-1106-preview` completion model
59pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
60/// `gpt-4-vision-preview` completion model
61pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
62/// `gpt-4-1106-vision-preview` completion model
63pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
64/// `gpt-4` completion model
65pub const GPT_4: &str = "gpt-4";
66/// `gpt-4-0613` completion model
67pub const GPT_4_0613: &str = "gpt-4-0613";
68/// `gpt-4-32k` completion model
69pub const GPT_4_32K: &str = "gpt-4-32k";
70/// `gpt-4-32k-0613` completion model
71pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
72
73/// `o4-mini-2025-04-16` completion model
74pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
75/// `o4-mini` completion model
76pub const O4_MINI: &str = "o4-mini";
77/// `o3` completion model
78pub const O3: &str = "o3";
79/// `o3-mini` completion model
80pub const O3_MINI: &str = "o3-mini";
81/// `o3-mini-2025-01-31` completion model
82pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
83/// `o1-pro` completion model
84pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
86pub const O1: &str = "o1";
87/// `o1-2024-12-17` completion model
88pub const O1_2024_12_17: &str = "o1-2024-12-17";
89/// `o1-preview` completion model
90pub const O1_PREVIEW: &str = "o1-preview";
91/// `o1-preview-2024-09-12` completion model
92pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
94pub const O1_MINI: &str = "o1-mini";
95/// `o1-mini-2024-09-12` completion model
96pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
97
98/// `gpt-4.1-mini` completion model
99pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
100/// `gpt-4.1-nano` completion model
101pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
102/// `gpt-4.1-2025-04-14` completion model
103pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
104/// `gpt-4.1` completion model
105pub const GPT_4_1: &str = "gpt-4.1";
106
107impl From<ApiErrorResponse> for CompletionError {
108    fn from(err: ApiErrorResponse) -> Self {
109        CompletionError::ProviderError(err.message)
110    }
111}
112
113#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
114#[serde(tag = "role", rename_all = "lowercase")]
115pub enum Message {
116    #[serde(alias = "developer")]
117    System {
118        #[serde(deserialize_with = "string_or_one_or_many")]
119        content: OneOrMany<SystemContent>,
120        #[serde(skip_serializing_if = "Option::is_none")]
121        name: Option<String>,
122    },
123    User {
124        #[serde(deserialize_with = "string_or_one_or_many")]
125        content: OneOrMany<UserContent>,
126        #[serde(skip_serializing_if = "Option::is_none")]
127        name: Option<String>,
128    },
129    Assistant {
130        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
131        content: Vec<AssistantContent>,
132        #[serde(skip_serializing_if = "Option::is_none")]
133        refusal: Option<String>,
134        #[serde(skip_serializing_if = "Option::is_none")]
135        audio: Option<AudioAssistant>,
136        #[serde(skip_serializing_if = "Option::is_none")]
137        name: Option<String>,
138        #[serde(
139            default,
140            deserialize_with = "json_utils::null_or_vec",
141            skip_serializing_if = "Vec::is_empty"
142        )]
143        tool_calls: Vec<ToolCall>,
144    },
145    #[serde(rename = "tool")]
146    ToolResult {
147        tool_call_id: String,
148        content: ToolResultContentValue,
149    },
150}
151
152impl Message {
153    pub fn system(content: &str) -> Self {
154        Message::System {
155            content: OneOrMany::one(content.to_owned().into()),
156            name: None,
157        }
158    }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct AudioAssistant {
163    pub id: String,
164}
165
166#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
167pub struct SystemContent {
168    #[serde(default)]
169    pub r#type: SystemContentType,
170    pub text: String,
171}
172
173#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
174#[serde(rename_all = "lowercase")]
175pub enum SystemContentType {
176    #[default]
177    Text,
178}
179
180#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
181#[serde(tag = "type", rename_all = "lowercase")]
182pub enum AssistantContent {
183    Text { text: String },
184    Refusal { refusal: String },
185}
186
187impl From<AssistantContent> for completion::AssistantContent {
188    fn from(value: AssistantContent) -> Self {
189        match value {
190            AssistantContent::Text { text } => completion::AssistantContent::text(text),
191            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
192        }
193    }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum UserContent {
199    Text {
200        text: String,
201    },
202    #[serde(rename = "image_url")]
203    Image {
204        image_url: ImageUrl,
205    },
206    Audio {
207        input_audio: InputAudio,
208    },
209}
210
211#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
212pub struct ImageUrl {
213    pub url: String,
214    #[serde(default)]
215    pub detail: ImageDetail,
216}
217
218#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
219pub struct InputAudio {
220    pub data: String,
221    pub format: AudioMediaType,
222}
223
224#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
225pub struct ToolResultContent {
226    #[serde(default)]
227    r#type: ToolResultContentType,
228    pub text: String,
229}
230
231#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
232#[serde(rename_all = "lowercase")]
233pub enum ToolResultContentType {
234    #[default]
235    Text,
236}
237
238impl FromStr for ToolResultContent {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(s.to_owned().into())
243    }
244}
245
246impl From<String> for ToolResultContent {
247    fn from(s: String) -> Self {
248        ToolResultContent {
249            r#type: ToolResultContentType::default(),
250            text: s,
251        }
252    }
253}
254
255#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
256#[serde(untagged)]
257pub enum ToolResultContentValue {
258    Array(Vec<ToolResultContent>),
259    String(String),
260}
261
262impl ToolResultContentValue {
263    pub fn from_string(s: String, use_array_format: bool) -> Self {
264        if use_array_format {
265            ToolResultContentValue::Array(vec![ToolResultContent::from(s)])
266        } else {
267            ToolResultContentValue::String(s)
268        }
269    }
270
271    pub fn as_text(&self) -> String {
272        match self {
273            ToolResultContentValue::Array(arr) => arr
274                .iter()
275                .map(|c| c.text.clone())
276                .collect::<Vec<_>>()
277                .join("\n"),
278            ToolResultContentValue::String(s) => s.clone(),
279        }
280    }
281
282    pub fn to_array(&self) -> Self {
283        match self {
284            ToolResultContentValue::Array(_) => self.clone(),
285            ToolResultContentValue::String(s) => {
286                ToolResultContentValue::Array(vec![ToolResultContent::from(s.clone())])
287            }
288        }
289    }
290}
291
292#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
293pub struct ToolCall {
294    pub id: String,
295    #[serde(default)]
296    pub r#type: ToolType,
297    pub function: Function,
298}
299
300#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
301#[serde(rename_all = "lowercase")]
302pub enum ToolType {
303    #[default]
304    Function,
305}
306
307/// Function definition for a tool, with optional strict mode
308#[derive(Debug, Deserialize, Serialize, Clone)]
309pub struct FunctionDefinition {
310    pub name: String,
311    pub description: String,
312    pub parameters: serde_json::Value,
313    #[serde(skip_serializing_if = "Option::is_none")]
314    pub strict: Option<bool>,
315}
316
317#[derive(Debug, Deserialize, Serialize, Clone)]
318pub struct ToolDefinition {
319    pub r#type: String,
320    pub function: FunctionDefinition,
321}
322
323impl From<completion::ToolDefinition> for ToolDefinition {
324    fn from(tool: completion::ToolDefinition) -> Self {
325        Self {
326            r#type: "function".into(),
327            function: FunctionDefinition {
328                name: tool.name,
329                description: tool.description,
330                parameters: tool.parameters,
331                strict: None,
332            },
333        }
334    }
335}
336
337impl ToolDefinition {
338    /// Apply strict mode to this tool definition.
339    /// This sets `strict: true` and sanitizes the schema to meet OpenAI requirements.
340    pub fn with_strict(mut self) -> Self {
341        self.function.strict = Some(true);
342        super::sanitize_schema(&mut self.function.parameters);
343        self
344    }
345}
346
347#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
348#[serde(rename_all = "snake_case")]
349pub enum ToolChoice {
350    #[default]
351    Auto,
352    None,
353    Required,
354}
355
356impl TryFrom<crate::message::ToolChoice> for ToolChoice {
357    type Error = CompletionError;
358    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
359        let res = match value {
360            message::ToolChoice::Specific { .. } => {
361                return Err(CompletionError::ProviderError(
362                    "Provider doesn't support only using specific tools".to_string(),
363                ));
364            }
365            message::ToolChoice::Auto => Self::Auto,
366            message::ToolChoice::None => Self::None,
367            message::ToolChoice::Required => Self::Required,
368        };
369
370        Ok(res)
371    }
372}
373
374#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
375pub struct Function {
376    pub name: String,
377    #[serde(with = "json_utils::stringified_json")]
378    pub arguments: serde_json::Value,
379}
380
381impl TryFrom<message::ToolResult> for Message {
382    type Error = message::MessageError;
383
384    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
385        let text = value
386            .content
387            .into_iter()
388            .map(|content| match content {
389                message::ToolResultContent::Text(message::Text { text }) => Ok(text),
390                _ => Err(message::MessageError::ConversionError(
391                    "Tool result content does not support non-text".into(),
392                )),
393            })
394            .collect::<Result<Vec<_>, _>>()?
395            .join("\n");
396
397        Ok(Message::ToolResult {
398            tool_call_id: value.id,
399            content: ToolResultContentValue::String(text),
400        })
401    }
402}
403
404impl TryFrom<message::UserContent> for UserContent {
405    type Error = message::MessageError;
406
407    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
408        match value {
409            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
410            message::UserContent::Image(message::Image {
411                data,
412                detail,
413                media_type,
414                ..
415            }) => match data {
416                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
417                    image_url: ImageUrl {
418                        url,
419                        detail: detail.unwrap_or_default(),
420                    },
421                }),
422                DocumentSourceKind::Base64(data) => {
423                    let url = format!(
424                        "data:{};base64,{}",
425                        media_type.map(|i| i.to_mime_type()).ok_or(
426                            message::MessageError::ConversionError(
427                                "OpenAI Image URI must have media type".into()
428                            )
429                        )?,
430                        data
431                    );
432
433                    let detail = detail.ok_or(message::MessageError::ConversionError(
434                        "OpenAI image URI must have image detail".into(),
435                    ))?;
436
437                    Ok(UserContent::Image {
438                        image_url: ImageUrl { url, detail },
439                    })
440                }
441                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
442                    "Raw files not supported, encode as base64 first".into(),
443                )),
444                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
445                    "Document has no body".into(),
446                )),
447                doc => Err(message::MessageError::ConversionError(format!(
448                    "Unsupported document type: {doc:?}"
449                ))),
450            },
451            message::UserContent::Document(message::Document { data, .. }) => {
452                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
453                    Ok(UserContent::Text { text })
454                } else {
455                    Err(message::MessageError::ConversionError(
456                        "Documents must be base64 or a string".into(),
457                    ))
458                }
459            }
460            message::UserContent::Audio(message::Audio {
461                data, media_type, ..
462            }) => match data {
463                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
464                    input_audio: InputAudio {
465                        data,
466                        format: match media_type {
467                            Some(media_type) => media_type,
468                            None => AudioMediaType::MP3,
469                        },
470                    },
471                }),
472                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
473                    "URLs are not supported for audio".into(),
474                )),
475                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
476                    "Raw files are not supported for audio".into(),
477                )),
478                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
479                    "Audio has no body".into(),
480                )),
481                audio => Err(message::MessageError::ConversionError(format!(
482                    "Unsupported audio type: {audio:?}"
483                ))),
484            },
485            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
486                "Tool result is in unsupported format".into(),
487            )),
488            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
489                "Video is in unsupported format".into(),
490            )),
491        }
492    }
493}
494
495impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
496    type Error = message::MessageError;
497
498    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
499        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
500            .into_iter()
501            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
502
503        // If there are messages with both tool results and user content, openai will only
504        //  handle tool results. It's unlikely that there will be both.
505        if !tool_results.is_empty() {
506            tool_results
507                .into_iter()
508                .map(|content| match content {
509                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
510                    _ => unreachable!(),
511                })
512                .collect::<Result<Vec<_>, _>>()
513        } else {
514            let other_content: Vec<UserContent> = other_content
515                .into_iter()
516                .map(|content| content.try_into())
517                .collect::<Result<Vec<_>, _>>()?;
518
519            let other_content = OneOrMany::many(other_content)
520                .expect("There must be other content here if there were no tool result content");
521
522            Ok(vec![Message::User {
523                content: other_content,
524                name: None,
525            }])
526        }
527    }
528}
529
530impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
531    type Error = message::MessageError;
532
533    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
534        let (text_content, tool_calls) = value.into_iter().fold(
535            (Vec::new(), Vec::new()),
536            |(mut texts, mut tools), content| {
537                match content {
538                    message::AssistantContent::Text(text) => texts.push(text),
539                    message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
540                    message::AssistantContent::Reasoning(_) => {
541                        panic!("The OpenAI Completions API doesn't support reasoning!");
542                    }
543                    message::AssistantContent::Image(_) => {
544                        panic!(
545                            "The OpenAI Completions API doesn't support image content in assistant messages!"
546                        );
547                    }
548                }
549                (texts, tools)
550            },
551        );
552
553        // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
554        //  so either `content` or `tool_calls` will have some content.
555        Ok(vec![Message::Assistant {
556            content: text_content
557                .into_iter()
558                .map(|content| content.text.into())
559                .collect::<Vec<_>>(),
560            refusal: None,
561            audio: None,
562            name: None,
563            tool_calls: tool_calls
564                .into_iter()
565                .map(|tool_call| tool_call.into())
566                .collect::<Vec<_>>(),
567        }])
568    }
569}
570
571impl TryFrom<message::Message> for Vec<Message> {
572    type Error = message::MessageError;
573
574    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
575        match message {
576            message::Message::User { content } => content.try_into(),
577            message::Message::Assistant { content, .. } => content.try_into(),
578        }
579    }
580}
581
582impl From<message::ToolCall> for ToolCall {
583    fn from(tool_call: message::ToolCall) -> Self {
584        Self {
585            id: tool_call.id,
586            r#type: ToolType::default(),
587            function: Function {
588                name: tool_call.function.name,
589                arguments: tool_call.function.arguments,
590            },
591        }
592    }
593}
594
595impl From<ToolCall> for message::ToolCall {
596    fn from(tool_call: ToolCall) -> Self {
597        Self {
598            id: tool_call.id,
599            call_id: None,
600            function: message::ToolFunction {
601                name: tool_call.function.name,
602                arguments: tool_call.function.arguments,
603            },
604            signature: None,
605            additional_params: None,
606        }
607    }
608}
609
610impl TryFrom<Message> for message::Message {
611    type Error = message::MessageError;
612
613    fn try_from(message: Message) -> Result<Self, Self::Error> {
614        Ok(match message {
615            Message::User { content, .. } => message::Message::User {
616                content: content.map(|content| content.into()),
617            },
618            Message::Assistant {
619                content,
620                tool_calls,
621                ..
622            } => {
623                let mut content = content
624                    .into_iter()
625                    .map(|content| match content {
626                        AssistantContent::Text { text } => message::AssistantContent::text(text),
627
628                        // TODO: Currently, refusals are converted into text, but should be
629                        //  investigated for generalization.
630                        AssistantContent::Refusal { refusal } => {
631                            message::AssistantContent::text(refusal)
632                        }
633                    })
634                    .collect::<Vec<_>>();
635
636                content.extend(
637                    tool_calls
638                        .into_iter()
639                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
640                        .collect::<Result<Vec<_>, _>>()?,
641                );
642
643                message::Message::Assistant {
644                    id: None,
645                    content: OneOrMany::many(content).map_err(|_| {
646                        message::MessageError::ConversionError(
647                            "Neither `content` nor `tool_calls` was provided to the Message"
648                                .to_owned(),
649                        )
650                    })?,
651                }
652            }
653
654            Message::ToolResult {
655                tool_call_id,
656                content,
657            } => message::Message::User {
658                content: OneOrMany::one(message::UserContent::tool_result(
659                    tool_call_id,
660                    OneOrMany::one(message::ToolResultContent::text(content.as_text())),
661                )),
662            },
663
664            // System messages should get stripped out when converting messages, this is just a
665            // stop gap to avoid obnoxious error handling or panic occurring.
666            Message::System { content, .. } => message::Message::User {
667                content: content.map(|content| message::UserContent::text(content.text)),
668            },
669        })
670    }
671}
672
673impl From<UserContent> for message::UserContent {
674    fn from(content: UserContent) -> Self {
675        match content {
676            UserContent::Text { text } => message::UserContent::text(text),
677            UserContent::Image { image_url } => {
678                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
679            }
680            UserContent::Audio { input_audio } => {
681                message::UserContent::audio(input_audio.data, Some(input_audio.format))
682            }
683        }
684    }
685}
686
687impl From<String> for UserContent {
688    fn from(s: String) -> Self {
689        UserContent::Text { text: s }
690    }
691}
692
693impl FromStr for UserContent {
694    type Err = Infallible;
695
696    fn from_str(s: &str) -> Result<Self, Self::Err> {
697        Ok(UserContent::Text {
698            text: s.to_string(),
699        })
700    }
701}
702
703impl From<String> for AssistantContent {
704    fn from(s: String) -> Self {
705        AssistantContent::Text { text: s }
706    }
707}
708
709impl FromStr for AssistantContent {
710    type Err = Infallible;
711
712    fn from_str(s: &str) -> Result<Self, Self::Err> {
713        Ok(AssistantContent::Text {
714            text: s.to_string(),
715        })
716    }
717}
718impl From<String> for SystemContent {
719    fn from(s: String) -> Self {
720        SystemContent {
721            r#type: SystemContentType::default(),
722            text: s,
723        }
724    }
725}
726
727impl FromStr for SystemContent {
728    type Err = Infallible;
729
730    fn from_str(s: &str) -> Result<Self, Self::Err> {
731        Ok(SystemContent {
732            r#type: SystemContentType::default(),
733            text: s.to_string(),
734        })
735    }
736}
737
738#[derive(Debug, Deserialize, Serialize)]
739pub struct CompletionResponse {
740    pub id: String,
741    pub object: String,
742    pub created: u64,
743    pub model: String,
744    pub system_fingerprint: Option<String>,
745    pub choices: Vec<Choice>,
746    pub usage: Option<Usage>,
747}
748
749impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
750    type Error = CompletionError;
751
752    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
753        let choice = response.choices.first().ok_or_else(|| {
754            CompletionError::ResponseError("Response contained no choices".to_owned())
755        })?;
756
757        let content = match &choice.message {
758            Message::Assistant {
759                content,
760                tool_calls,
761                ..
762            } => {
763                let mut content = content
764                    .iter()
765                    .filter_map(|c| {
766                        let s = match c {
767                            AssistantContent::Text { text } => text,
768                            AssistantContent::Refusal { refusal } => refusal,
769                        };
770                        if s.is_empty() {
771                            None
772                        } else {
773                            Some(completion::AssistantContent::text(s))
774                        }
775                    })
776                    .collect::<Vec<_>>();
777
778                content.extend(
779                    tool_calls
780                        .iter()
781                        .map(|call| {
782                            completion::AssistantContent::tool_call(
783                                &call.id,
784                                &call.function.name,
785                                call.function.arguments.clone(),
786                            )
787                        })
788                        .collect::<Vec<_>>(),
789                );
790                Ok(content)
791            }
792            _ => Err(CompletionError::ResponseError(
793                "Response did not contain a valid message or tool call".into(),
794            )),
795        }?;
796
797        let choice = OneOrMany::many(content).map_err(|_| {
798            CompletionError::ResponseError(
799                "Response contained no message or tool call (empty)".to_owned(),
800            )
801        })?;
802
803        let usage = response
804            .usage
805            .as_ref()
806            .map(|usage| completion::Usage {
807                input_tokens: usage.prompt_tokens as u64,
808                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
809                total_tokens: usage.total_tokens as u64,
810            })
811            .unwrap_or_default();
812
813        Ok(completion::CompletionResponse {
814            choice,
815            usage,
816            raw_response: response,
817        })
818    }
819}
820
821impl ProviderResponseExt for CompletionResponse {
822    type OutputMessage = Choice;
823    type Usage = Usage;
824
825    fn get_response_id(&self) -> Option<String> {
826        Some(self.id.to_owned())
827    }
828
829    fn get_response_model_name(&self) -> Option<String> {
830        Some(self.model.to_owned())
831    }
832
833    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
834        self.choices.clone()
835    }
836
837    fn get_text_response(&self) -> Option<String> {
838        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
839            return None;
840        };
841
842        let UserContent::Text { text } = content.first() else {
843            return None;
844        };
845
846        Some(text)
847    }
848
849    fn get_usage(&self) -> Option<Self::Usage> {
850        self.usage.clone()
851    }
852}
853
854#[derive(Clone, Debug, Serialize, Deserialize)]
855pub struct Choice {
856    pub index: usize,
857    pub message: Message,
858    pub logprobs: Option<serde_json::Value>,
859    pub finish_reason: String,
860}
861
862#[derive(Clone, Debug, Deserialize, Serialize)]
863pub struct Usage {
864    pub prompt_tokens: usize,
865    pub total_tokens: usize,
866}
867
868impl Usage {
869    pub fn new() -> Self {
870        Self {
871            prompt_tokens: 0,
872            total_tokens: 0,
873        }
874    }
875}
876
877impl Default for Usage {
878    fn default() -> Self {
879        Self::new()
880    }
881}
882
883impl fmt::Display for Usage {
884    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
885        let Usage {
886            prompt_tokens,
887            total_tokens,
888        } = self;
889        write!(
890            f,
891            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
892        )
893    }
894}
895
896impl GetTokenUsage for Usage {
897    fn token_usage(&self) -> Option<crate::completion::Usage> {
898        let mut usage = crate::completion::Usage::new();
899        usage.input_tokens = self.prompt_tokens as u64;
900        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
901        usage.total_tokens = self.total_tokens as u64;
902
903        Some(usage)
904    }
905}
906
907#[derive(Clone)]
908pub struct CompletionModel<T = reqwest::Client> {
909    pub(crate) client: Client<T>,
910    pub model: String,
911    pub strict_tools: bool,
912    pub tool_result_array_content: bool,
913}
914
915impl<T> CompletionModel<T>
916where
917    T: Default + std::fmt::Debug + Clone + 'static,
918{
919    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
920        Self {
921            client,
922            model: model.into(),
923            strict_tools: false,
924            tool_result_array_content: false,
925        }
926    }
927
928    pub fn with_model(client: Client<T>, model: &str) -> Self {
929        Self {
930            client,
931            model: model.into(),
932            strict_tools: false,
933            tool_result_array_content: false,
934        }
935    }
936
937    /// Enable strict mode for tool schemas.
938    ///
939    /// When enabled, tool schemas are automatically sanitized to meet OpenAI's strict mode requirements:
940    /// - `additionalProperties: false` is added to all objects
941    /// - All properties are marked as required
942    /// - `strict: true` is set on each function definition
943    ///
944    /// This allows OpenAI to guarantee that the model's tool calls will match the schema exactly.
945    pub fn with_strict_tools(mut self) -> Self {
946        self.strict_tools = true;
947        self
948    }
949
950    pub fn with_tool_result_array_content(mut self) -> Self {
951        self.tool_result_array_content = true;
952        self
953    }
954}
955
956#[derive(Debug, Serialize, Deserialize, Clone)]
957pub struct CompletionRequest {
958    model: String,
959    messages: Vec<Message>,
960    #[serde(skip_serializing_if = "Vec::is_empty")]
961    tools: Vec<ToolDefinition>,
962    #[serde(skip_serializing_if = "Option::is_none")]
963    tool_choice: Option<ToolChoice>,
964    #[serde(skip_serializing_if = "Option::is_none")]
965    temperature: Option<f64>,
966    #[serde(flatten)]
967    additional_params: Option<serde_json::Value>,
968}
969
970pub struct OpenAIRequestParams {
971    pub model: String,
972    pub request: CoreCompletionRequest,
973    pub strict_tools: bool,
974    pub tool_result_array_content: bool,
975}
976
977impl TryFrom<OpenAIRequestParams> for CompletionRequest {
978    type Error = CompletionError;
979
980    fn try_from(params: OpenAIRequestParams) -> Result<Self, Self::Error> {
981        let OpenAIRequestParams {
982            model,
983            request: req,
984            strict_tools,
985            tool_result_array_content,
986        } = params;
987
988        let mut partial_history = vec![];
989        if let Some(docs) = req.normalized_documents() {
990            partial_history.push(docs);
991        }
992        let CoreCompletionRequest {
993            preamble,
994            chat_history,
995            tools,
996            temperature,
997            additional_params,
998            tool_choice,
999            ..
1000        } = req;
1001
1002        partial_history.extend(chat_history);
1003
1004        let mut full_history: Vec<Message> =
1005            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
1006
1007        full_history.extend(
1008            partial_history
1009                .into_iter()
1010                .map(message::Message::try_into)
1011                .collect::<Result<Vec<Vec<Message>>, _>>()?
1012                .into_iter()
1013                .flatten()
1014                .collect::<Vec<_>>(),
1015        );
1016
1017        if tool_result_array_content {
1018            for msg in &mut full_history {
1019                if let Message::ToolResult { content, .. } = msg {
1020                    *content = content.to_array();
1021                }
1022            }
1023        }
1024
1025        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
1026
1027        let tools: Vec<ToolDefinition> = tools
1028            .into_iter()
1029            .map(|tool| {
1030                let def = ToolDefinition::from(tool);
1031                if strict_tools { def.with_strict() } else { def }
1032            })
1033            .collect();
1034
1035        let res = Self {
1036            model,
1037            messages: full_history,
1038            tools,
1039            tool_choice,
1040            temperature,
1041            additional_params,
1042        };
1043
1044        Ok(res)
1045    }
1046}
1047
1048impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
1049    type Error = CompletionError;
1050
1051    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
1052        CompletionRequest::try_from(OpenAIRequestParams {
1053            model,
1054            request: req,
1055            strict_tools: false,
1056            tool_result_array_content: false,
1057        })
1058    }
1059}
1060
1061impl crate::telemetry::ProviderRequestExt for CompletionRequest {
1062    type InputMessage = Message;
1063
1064    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
1065        self.messages.clone()
1066    }
1067
1068    fn get_system_prompt(&self) -> Option<String> {
1069        let first_message = self.messages.first()?;
1070
1071        let Message::System { ref content, .. } = first_message.clone() else {
1072            return None;
1073        };
1074
1075        let SystemContent { text, .. } = content.first();
1076
1077        Some(text)
1078    }
1079
1080    fn get_prompt(&self) -> Option<String> {
1081        let last_message = self.messages.last()?;
1082
1083        let Message::User { ref content, .. } = last_message.clone() else {
1084            return None;
1085        };
1086
1087        let UserContent::Text { text } = content.first() else {
1088            return None;
1089        };
1090
1091        Some(text)
1092    }
1093
1094    fn get_model_name(&self) -> String {
1095        self.model.clone()
1096    }
1097}
1098
1099impl CompletionModel<reqwest::Client> {
1100    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
1101        crate::agent::AgentBuilder::new(self)
1102    }
1103}
1104
1105impl<T> completion::CompletionModel for CompletionModel<T>
1106where
1107    T: HttpClientExt
1108        + Default
1109        + std::fmt::Debug
1110        + Clone
1111        + WasmCompatSend
1112        + WasmCompatSync
1113        + 'static,
1114{
1115    type Response = CompletionResponse;
1116    type StreamingResponse = StreamingCompletionResponse;
1117
1118    type Client = super::CompletionsClient<T>;
1119
1120    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1121        Self::new(client.clone(), model)
1122    }
1123
1124    async fn completion(
1125        &self,
1126        completion_request: CoreCompletionRequest,
1127    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1128        let span = if tracing::Span::current().is_disabled() {
1129            info_span!(
1130                target: "rig::completions",
1131                "chat",
1132                gen_ai.operation.name = "chat",
1133                gen_ai.provider.name = "openai",
1134                gen_ai.request.model = self.model,
1135                gen_ai.system_instructions = &completion_request.preamble,
1136                gen_ai.response.id = tracing::field::Empty,
1137                gen_ai.response.model = tracing::field::Empty,
1138                gen_ai.usage.output_tokens = tracing::field::Empty,
1139                gen_ai.usage.input_tokens = tracing::field::Empty,
1140            )
1141        } else {
1142            tracing::Span::current()
1143        };
1144
1145        let request = CompletionRequest::try_from(OpenAIRequestParams {
1146            model: self.model.to_owned(),
1147            request: completion_request,
1148            strict_tools: self.strict_tools,
1149            tool_result_array_content: self.tool_result_array_content,
1150        })?;
1151
1152        if enabled!(Level::TRACE) {
1153            tracing::trace!(
1154                target: "rig::completions",
1155                "OpenAI Chat Completions completion request: {}",
1156                serde_json::to_string_pretty(&request)?
1157            );
1158        }
1159
1160        let body = serde_json::to_vec(&request)?;
1161
1162        let req = self
1163            .client
1164            .post("/chat/completions")?
1165            .body(body)
1166            .map_err(|e| CompletionError::HttpError(e.into()))?;
1167
1168        async move {
1169            let response = self.client.send(req).await?;
1170
1171            if response.status().is_success() {
1172                let text = http_client::text(response).await?;
1173
1174                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
1175                    ApiResponse::Ok(response) => {
1176                        let span = tracing::Span::current();
1177                        span.record_response_metadata(&response);
1178                        span.record_token_usage(&response.usage);
1179
1180                        if enabled!(Level::TRACE) {
1181                            tracing::trace!(
1182                                target: "rig::completions",
1183                                "OpenAI Chat Completions completion response: {}",
1184                                serde_json::to_string_pretty(&response)?
1185                            );
1186                        }
1187
1188                        response.try_into()
1189                    }
1190                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1191                }
1192            } else {
1193                let text = http_client::text(response).await?;
1194                Err(CompletionError::ProviderError(text))
1195            }
1196        }
1197        .instrument(span)
1198        .await
1199    }
1200
1201    async fn stream(
1202        &self,
1203        request: CoreCompletionRequest,
1204    ) -> Result<
1205        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1206        CompletionError,
1207    > {
1208        Self::stream(self, request).await
1209    }
1210}