rig/providers/openai/completion/mod.rs

1// ================================================================
2// OpenAI Completion API
3// ================================================================
4
5use super::{
6    CompletionsClient as Client,
7    client::{ApiErrorResponse, ApiResponse},
8    streaming::StreamingCompletionResponse,
9};
10use crate::completion::{
11    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
12};
13use crate::http_client::{self, HttpClientExt};
14use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
15use crate::one_or_many::string_or_one_or_many;
16use crate::telemetry::{ProviderResponseExt, SpanCombinator};
17use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
18use crate::{OneOrMany, completion, json_utils, message};
19use serde::{Deserialize, Serialize};
20use std::convert::Infallible;
21use std::fmt;
22use tracing::{Instrument, info_span};
23
24use std::str::FromStr;
25
26pub mod streaming;
27
/// `gpt-5.1` completion model
pub const GPT_5_1: &str = "gpt-5.1";

/// `gpt-5` completion model
pub const GPT_5: &str = "gpt-5";
/// `gpt-5-mini` completion model
pub const GPT_5_MINI: &str = "gpt-5-mini";
/// `gpt-5-nano` completion model
pub const GPT_5_NANO: &str = "gpt-5-nano";

/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
106
107impl From<ApiErrorResponse> for CompletionError {
108    fn from(err: ApiErrorResponse) -> Self {
109        CompletionError::ProviderError(err.message)
110    }
111}
112
113#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
114#[serde(tag = "role", rename_all = "lowercase")]
115pub enum Message {
116    #[serde(alias = "developer")]
117    System {
118        #[serde(deserialize_with = "string_or_one_or_many")]
119        content: OneOrMany<SystemContent>,
120        #[serde(skip_serializing_if = "Option::is_none")]
121        name: Option<String>,
122    },
123    User {
124        #[serde(deserialize_with = "string_or_one_or_many")]
125        content: OneOrMany<UserContent>,
126        #[serde(skip_serializing_if = "Option::is_none")]
127        name: Option<String>,
128    },
129    Assistant {
130        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
131        content: Vec<AssistantContent>,
132        #[serde(skip_serializing_if = "Option::is_none")]
133        refusal: Option<String>,
134        #[serde(skip_serializing_if = "Option::is_none")]
135        audio: Option<AudioAssistant>,
136        #[serde(skip_serializing_if = "Option::is_none")]
137        name: Option<String>,
138        #[serde(
139            default,
140            deserialize_with = "json_utils::null_or_vec",
141            skip_serializing_if = "Vec::is_empty"
142        )]
143        tool_calls: Vec<ToolCall>,
144    },
145    #[serde(rename = "tool")]
146    ToolResult {
147        tool_call_id: String,
148        content: OneOrMany<ToolResultContent>,
149    },
150}
151
152impl Message {
153    pub fn system(content: &str) -> Self {
154        Message::System {
155            content: OneOrMany::one(content.to_owned().into()),
156            name: None,
157        }
158    }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct AudioAssistant {
163    pub id: String,
164}
165
166#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
167pub struct SystemContent {
168    #[serde(default)]
169    pub r#type: SystemContentType,
170    pub text: String,
171}
172
173#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
174#[serde(rename_all = "lowercase")]
175pub enum SystemContentType {
176    #[default]
177    Text,
178}
179
180#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
181#[serde(tag = "type", rename_all = "lowercase")]
182pub enum AssistantContent {
183    Text { text: String },
184    Refusal { refusal: String },
185}
186
187impl From<AssistantContent> for completion::AssistantContent {
188    fn from(value: AssistantContent) -> Self {
189        match value {
190            AssistantContent::Text { text } => completion::AssistantContent::text(text),
191            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
192        }
193    }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197#[serde(tag = "type", rename_all = "lowercase")]
198pub enum UserContent {
199    Text {
200        text: String,
201    },
202    #[serde(rename = "image_url")]
203    Image {
204        image_url: ImageUrl,
205    },
206    Audio {
207        input_audio: InputAudio,
208    },
209}
210
211#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
212pub struct ImageUrl {
213    pub url: String,
214    #[serde(default)]
215    pub detail: ImageDetail,
216}
217
218#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
219pub struct InputAudio {
220    pub data: String,
221    pub format: AudioMediaType,
222}
223
224#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
225pub struct ToolResultContent {
226    #[serde(default)]
227    r#type: ToolResultContentType,
228    pub text: String,
229}
230
231#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
232#[serde(rename_all = "lowercase")]
233pub enum ToolResultContentType {
234    #[default]
235    Text,
236}
237
238impl FromStr for ToolResultContent {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(s.to_owned().into())
243    }
244}
245
246impl From<String> for ToolResultContent {
247    fn from(s: String) -> Self {
248        ToolResultContent {
249            r#type: ToolResultContentType::default(),
250            text: s,
251        }
252    }
253}
254
255#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
256pub struct ToolCall {
257    pub id: String,
258    #[serde(default)]
259    pub r#type: ToolType,
260    pub function: Function,
261}
262
263#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
264#[serde(rename_all = "lowercase")]
265pub enum ToolType {
266    #[default]
267    Function,
268}
269
270#[derive(Debug, Deserialize, Serialize, Clone)]
271pub struct ToolDefinition {
272    pub r#type: String,
273    pub function: completion::ToolDefinition,
274}
275
276impl From<completion::ToolDefinition> for ToolDefinition {
277    fn from(tool: completion::ToolDefinition) -> Self {
278        Self {
279            r#type: "function".into(),
280            function: tool,
281        }
282    }
283}
284
285#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
286#[serde(rename_all = "snake_case")]
287pub enum ToolChoice {
288    #[default]
289    Auto,
290    None,
291    Required,
292}
293
294impl TryFrom<crate::message::ToolChoice> for ToolChoice {
295    type Error = CompletionError;
296    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
297        let res = match value {
298            message::ToolChoice::Specific { .. } => {
299                return Err(CompletionError::ProviderError(
300                    "Provider doesn't support only using specific tools".to_string(),
301                ));
302            }
303            message::ToolChoice::Auto => Self::Auto,
304            message::ToolChoice::None => Self::None,
305            message::ToolChoice::Required => Self::Required,
306        };
307
308        Ok(res)
309    }
310}
311
312#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
313pub struct Function {
314    pub name: String,
315    #[serde(with = "json_utils::stringified_json")]
316    pub arguments: serde_json::Value,
317}
318
319impl TryFrom<message::ToolResult> for Message {
320    type Error = message::MessageError;
321
322    fn try_from(value: message::ToolResult) -> Result<Self, Self::Error> {
323        Ok(Message::ToolResult {
324            tool_call_id: value.id,
325            content: value.content.try_map(|content| match content {
326                message::ToolResultContent::Text(message::Text { text }) => Ok(text.into()),
327                _ => Err(message::MessageError::ConversionError(
328                    "Tool result content does not support non-text".into(),
329                )),
330            })?,
331        })
332    }
333}
334
335impl TryFrom<message::UserContent> for UserContent {
336    type Error = message::MessageError;
337
338    fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
339        match value {
340            message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
341            message::UserContent::Image(message::Image {
342                data,
343                detail,
344                media_type,
345                ..
346            }) => match data {
347                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
348                    image_url: ImageUrl {
349                        url,
350                        detail: detail.unwrap_or_default(),
351                    },
352                }),
353                DocumentSourceKind::Base64(data) => {
354                    let url = format!(
355                        "data:{};base64,{}",
356                        media_type.map(|i| i.to_mime_type()).ok_or(
357                            message::MessageError::ConversionError(
358                                "OpenAI Image URI must have media type".into()
359                            )
360                        )?,
361                        data
362                    );
363
364                    let detail = detail.ok_or(message::MessageError::ConversionError(
365                        "OpenAI image URI must have image detail".into(),
366                    ))?;
367
368                    Ok(UserContent::Image {
369                        image_url: ImageUrl { url, detail },
370                    })
371                }
372                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
373                    "Raw files not supported, encode as base64 first".into(),
374                )),
375                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
376                    "Document has no body".into(),
377                )),
378                doc => Err(message::MessageError::ConversionError(format!(
379                    "Unsupported document type: {doc:?}"
380                ))),
381            },
382            message::UserContent::Document(message::Document { data, .. }) => {
383                if let DocumentSourceKind::Base64(text) | DocumentSourceKind::String(text) = data {
384                    Ok(UserContent::Text { text })
385                } else {
386                    Err(message::MessageError::ConversionError(
387                        "Documents must be base64 or a string".into(),
388                    ))
389                }
390            }
391            message::UserContent::Audio(message::Audio {
392                data, media_type, ..
393            }) => match data {
394                DocumentSourceKind::Base64(data) => Ok(UserContent::Audio {
395                    input_audio: InputAudio {
396                        data,
397                        format: match media_type {
398                            Some(media_type) => media_type,
399                            None => AudioMediaType::MP3,
400                        },
401                    },
402                }),
403                DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
404                    "URLs are not supported for audio".into(),
405                )),
406                DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
407                    "Raw files are not supported for audio".into(),
408                )),
409                DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
410                    "Audio has no body".into(),
411                )),
412                audio => Err(message::MessageError::ConversionError(format!(
413                    "Unsupported audio type: {audio:?}"
414                ))),
415            },
416            message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
417                "Tool result is in unsupported format".into(),
418            )),
419            message::UserContent::Video(_) => Err(message::MessageError::ConversionError(
420                "Video is in unsupported format".into(),
421            )),
422        }
423    }
424}
425
426impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
427    type Error = message::MessageError;
428
429    fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
430        let (tool_results, other_content): (Vec<_>, Vec<_>) = value
431            .into_iter()
432            .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
433
434        // If there are messages with both tool results and user content, openai will only
435        //  handle tool results. It's unlikely that there will be both.
436        if !tool_results.is_empty() {
437            tool_results
438                .into_iter()
439                .map(|content| match content {
440                    message::UserContent::ToolResult(tool_result) => tool_result.try_into(),
441                    _ => unreachable!(),
442                })
443                .collect::<Result<Vec<_>, _>>()
444        } else {
445            let other_content: Vec<UserContent> = other_content
446                .into_iter()
447                .map(|content| content.try_into())
448                .collect::<Result<Vec<_>, _>>()?;
449
450            let other_content = OneOrMany::many(other_content)
451                .expect("There must be other content here if there were no tool result content");
452
453            Ok(vec![Message::User {
454                content: other_content,
455                name: None,
456            }])
457        }
458    }
459}
460
461impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
462    type Error = message::MessageError;
463
464    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
465        let (text_content, tool_calls) = value.into_iter().fold(
466            (Vec::new(), Vec::new()),
467            |(mut texts, mut tools), content| {
468                match content {
469                    message::AssistantContent::Text(text) => texts.push(text),
470                    message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
471                    message::AssistantContent::Reasoning(_) => {
472                        unimplemented!("The OpenAI Completions API doesn't support reasoning!");
473                    }
474                    message::AssistantContent::Image(_) => {
475                        unimplemented!(
476                            "The OpenAI Completions API doesn't support image content in assistant messages!"
477                        );
478                    }
479                }
480                (texts, tools)
481            },
482        );
483
484        // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
485        //  so either `content` or `tool_calls` will have some content.
486        Ok(vec![Message::Assistant {
487            content: text_content
488                .into_iter()
489                .map(|content| content.text.into())
490                .collect::<Vec<_>>(),
491            refusal: None,
492            audio: None,
493            name: None,
494            tool_calls: tool_calls
495                .into_iter()
496                .map(|tool_call| tool_call.into())
497                .collect::<Vec<_>>(),
498        }])
499    }
500}
501
502impl TryFrom<message::Message> for Vec<Message> {
503    type Error = message::MessageError;
504
505    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
506        match message {
507            message::Message::User { content } => content.try_into(),
508            message::Message::Assistant { content, .. } => content.try_into(),
509        }
510    }
511}
512
513impl From<message::ToolCall> for ToolCall {
514    fn from(tool_call: message::ToolCall) -> Self {
515        Self {
516            id: tool_call.id,
517            r#type: ToolType::default(),
518            function: Function {
519                name: tool_call.function.name,
520                arguments: tool_call.function.arguments,
521            },
522        }
523    }
524}
525
526impl From<ToolCall> for message::ToolCall {
527    fn from(tool_call: ToolCall) -> Self {
528        Self {
529            id: tool_call.id,
530            call_id: None,
531            function: message::ToolFunction {
532                name: tool_call.function.name,
533                arguments: tool_call.function.arguments,
534            },
535        }
536    }
537}
538
539impl TryFrom<Message> for message::Message {
540    type Error = message::MessageError;
541
542    fn try_from(message: Message) -> Result<Self, Self::Error> {
543        Ok(match message {
544            Message::User { content, .. } => message::Message::User {
545                content: content.map(|content| content.into()),
546            },
547            Message::Assistant {
548                content,
549                tool_calls,
550                ..
551            } => {
552                let mut content = content
553                    .into_iter()
554                    .map(|content| match content {
555                        AssistantContent::Text { text } => message::AssistantContent::text(text),
556
557                        // TODO: Currently, refusals are converted into text, but should be
558                        //  investigated for generalization.
559                        AssistantContent::Refusal { refusal } => {
560                            message::AssistantContent::text(refusal)
561                        }
562                    })
563                    .collect::<Vec<_>>();
564
565                content.extend(
566                    tool_calls
567                        .into_iter()
568                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
569                        .collect::<Result<Vec<_>, _>>()?,
570                );
571
572                message::Message::Assistant {
573                    id: None,
574                    content: OneOrMany::many(content).map_err(|_| {
575                        message::MessageError::ConversionError(
576                            "Neither `content` nor `tool_calls` was provided to the Message"
577                                .to_owned(),
578                        )
579                    })?,
580                }
581            }
582
583            Message::ToolResult {
584                tool_call_id,
585                content,
586            } => message::Message::User {
587                content: OneOrMany::one(message::UserContent::tool_result(
588                    tool_call_id,
589                    content.map(|content| message::ToolResultContent::text(content.text)),
590                )),
591            },
592
593            // System messages should get stripped out when converting messages, this is just a
594            // stop gap to avoid obnoxious error handling or panic occurring.
595            Message::System { content, .. } => message::Message::User {
596                content: content.map(|content| message::UserContent::text(content.text)),
597            },
598        })
599    }
600}
601
602impl From<UserContent> for message::UserContent {
603    fn from(content: UserContent) -> Self {
604        match content {
605            UserContent::Text { text } => message::UserContent::text(text),
606            UserContent::Image { image_url } => {
607                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
608            }
609            UserContent::Audio { input_audio } => {
610                message::UserContent::audio(input_audio.data, Some(input_audio.format))
611            }
612        }
613    }
614}
615
616impl From<String> for UserContent {
617    fn from(s: String) -> Self {
618        UserContent::Text { text: s }
619    }
620}
621
622impl FromStr for UserContent {
623    type Err = Infallible;
624
625    fn from_str(s: &str) -> Result<Self, Self::Err> {
626        Ok(UserContent::Text {
627            text: s.to_string(),
628        })
629    }
630}
631
632impl From<String> for AssistantContent {
633    fn from(s: String) -> Self {
634        AssistantContent::Text { text: s }
635    }
636}
637
638impl FromStr for AssistantContent {
639    type Err = Infallible;
640
641    fn from_str(s: &str) -> Result<Self, Self::Err> {
642        Ok(AssistantContent::Text {
643            text: s.to_string(),
644        })
645    }
646}
647impl From<String> for SystemContent {
648    fn from(s: String) -> Self {
649        SystemContent {
650            r#type: SystemContentType::default(),
651            text: s,
652        }
653    }
654}
655
656impl FromStr for SystemContent {
657    type Err = Infallible;
658
659    fn from_str(s: &str) -> Result<Self, Self::Err> {
660        Ok(SystemContent {
661            r#type: SystemContentType::default(),
662            text: s.to_string(),
663        })
664    }
665}
666
667#[derive(Debug, Deserialize, Serialize)]
668pub struct CompletionResponse {
669    pub id: String,
670    pub object: String,
671    pub created: u64,
672    pub model: String,
673    pub system_fingerprint: Option<String>,
674    pub choices: Vec<Choice>,
675    pub usage: Option<Usage>,
676}
677
678impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
679    type Error = CompletionError;
680
681    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
682        let choice = response.choices.first().ok_or_else(|| {
683            CompletionError::ResponseError("Response contained no choices".to_owned())
684        })?;
685
686        let content = match &choice.message {
687            Message::Assistant {
688                content,
689                tool_calls,
690                ..
691            } => {
692                let mut content = content
693                    .iter()
694                    .filter_map(|c| {
695                        let s = match c {
696                            AssistantContent::Text { text } => text,
697                            AssistantContent::Refusal { refusal } => refusal,
698                        };
699                        if s.is_empty() {
700                            None
701                        } else {
702                            Some(completion::AssistantContent::text(s))
703                        }
704                    })
705                    .collect::<Vec<_>>();
706
707                content.extend(
708                    tool_calls
709                        .iter()
710                        .map(|call| {
711                            completion::AssistantContent::tool_call(
712                                &call.id,
713                                &call.function.name,
714                                call.function.arguments.clone(),
715                            )
716                        })
717                        .collect::<Vec<_>>(),
718                );
719                Ok(content)
720            }
721            _ => Err(CompletionError::ResponseError(
722                "Response did not contain a valid message or tool call".into(),
723            )),
724        }?;
725
726        let choice = OneOrMany::many(content).map_err(|_| {
727            CompletionError::ResponseError(
728                "Response contained no message or tool call (empty)".to_owned(),
729            )
730        })?;
731
732        let usage = response
733            .usage
734            .as_ref()
735            .map(|usage| completion::Usage {
736                input_tokens: usage.prompt_tokens as u64,
737                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
738                total_tokens: usage.total_tokens as u64,
739            })
740            .unwrap_or_default();
741
742        Ok(completion::CompletionResponse {
743            choice,
744            usage,
745            raw_response: response,
746        })
747    }
748}
749
750impl ProviderResponseExt for CompletionResponse {
751    type OutputMessage = Choice;
752    type Usage = Usage;
753
754    fn get_response_id(&self) -> Option<String> {
755        Some(self.id.to_owned())
756    }
757
758    fn get_response_model_name(&self) -> Option<String> {
759        Some(self.model.to_owned())
760    }
761
762    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
763        self.choices.clone()
764    }
765
766    fn get_text_response(&self) -> Option<String> {
767        let Message::User { ref content, .. } = self.choices.last()?.message.clone() else {
768            return None;
769        };
770
771        let UserContent::Text { text } = content.first() else {
772            return None;
773        };
774
775        Some(text)
776    }
777
778    fn get_usage(&self) -> Option<Self::Usage> {
779        self.usage.clone()
780    }
781}
782
783#[derive(Clone, Debug, Serialize, Deserialize)]
784pub struct Choice {
785    pub index: usize,
786    pub message: Message,
787    pub logprobs: Option<serde_json::Value>,
788    pub finish_reason: String,
789}
790
791#[derive(Clone, Debug, Deserialize, Serialize)]
792pub struct Usage {
793    pub prompt_tokens: usize,
794    pub total_tokens: usize,
795}
796
797impl Usage {
798    pub fn new() -> Self {
799        Self {
800            prompt_tokens: 0,
801            total_tokens: 0,
802        }
803    }
804}
805
806impl Default for Usage {
807    fn default() -> Self {
808        Self::new()
809    }
810}
811
812impl fmt::Display for Usage {
813    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
814        let Usage {
815            prompt_tokens,
816            total_tokens,
817        } = self;
818        write!(
819            f,
820            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
821        )
822    }
823}
824
825impl GetTokenUsage for Usage {
826    fn token_usage(&self) -> Option<crate::completion::Usage> {
827        let mut usage = crate::completion::Usage::new();
828        usage.input_tokens = self.prompt_tokens as u64;
829        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
830        usage.total_tokens = self.total_tokens as u64;
831
832        Some(usage)
833    }
834}
835
836#[derive(Clone)]
837pub struct CompletionModel<T = reqwest::Client> {
838    pub(crate) client: Client<T>,
839    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
840    pub model: String,
841}
842
843impl<T> CompletionModel<T>
844where
845    T: Default + std::fmt::Debug + Clone + 'static,
846{
847    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
848        Self {
849            client,
850            model: model.into(),
851        }
852    }
853
854    pub fn with_model(client: Client<T>, model: &str) -> Self {
855        Self {
856            client,
857            model: model.into(),
858        }
859    }
860}
861
862#[derive(Debug, Serialize, Deserialize, Clone)]
863pub struct CompletionRequest {
864    model: String,
865    messages: Vec<Message>,
866    #[serde(skip_serializing_if = "Vec::is_empty")]
867    tools: Vec<ToolDefinition>,
868    #[serde(skip_serializing_if = "Option::is_none")]
869    tool_choice: Option<ToolChoice>,
870    #[serde(skip_serializing_if = "Option::is_none")]
871    temperature: Option<f64>,
872    #[serde(flatten)]
873    additional_params: Option<serde_json::Value>,
874}
875
876impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
877    type Error = CompletionError;
878
879    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
880        let mut partial_history = vec![];
881        if let Some(docs) = req.normalized_documents() {
882            partial_history.push(docs);
883        }
884        let CoreCompletionRequest {
885            preamble,
886            chat_history,
887            tools,
888            temperature,
889            additional_params,
890            tool_choice,
891            ..
892        } = req;
893
894        partial_history.extend(chat_history);
895
896        let mut full_history: Vec<Message> =
897            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
898
899        // Convert and extend the rest of the history
900        full_history.extend(
901            partial_history
902                .into_iter()
903                .map(message::Message::try_into)
904                .collect::<Result<Vec<Vec<Message>>, _>>()?
905                .into_iter()
906                .flatten()
907                .collect::<Vec<_>>(),
908        );
909
910        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;
911
912        let res = Self {
913            model,
914            messages: full_history,
915            tools: tools
916                .into_iter()
917                .map(ToolDefinition::from)
918                .collect::<Vec<_>>(),
919            tool_choice,
920            temperature,
921            additional_params,
922        };
923
924        Ok(res)
925    }
926}
927
928impl crate::telemetry::ProviderRequestExt for CompletionRequest {
929    type InputMessage = Message;
930
931    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
932        self.messages.clone()
933    }
934
935    fn get_system_prompt(&self) -> Option<String> {
936        let first_message = self.messages.first()?;
937
938        let Message::System { ref content, .. } = first_message.clone() else {
939            return None;
940        };
941
942        let SystemContent { text, .. } = content.first();
943
944        Some(text)
945    }
946
947    fn get_prompt(&self) -> Option<String> {
948        let last_message = self.messages.last()?;
949
950        let Message::User { ref content, .. } = last_message.clone() else {
951            return None;
952        };
953
954        let UserContent::Text { text } = content.first() else {
955            return None;
956        };
957
958        Some(text)
959    }
960
961    fn get_model_name(&self) -> String {
962        self.model.clone()
963    }
964}
965
966impl CompletionModel<reqwest::Client> {
967    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
968        crate::agent::AgentBuilder::new(self)
969    }
970}
971
972impl<T> completion::CompletionModel for CompletionModel<T>
973where
974    T: HttpClientExt
975        + Default
976        + std::fmt::Debug
977        + Clone
978        + WasmCompatSend
979        + WasmCompatSync
980        + 'static,
981{
982    type Response = CompletionResponse;
983    type StreamingResponse = StreamingCompletionResponse;
984
985    type Client = super::CompletionsClient<T>;
986
987    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
988        Self::new(client.clone(), model)
989    }
990
991    async fn completion(
992        &self,
993        completion_request: CoreCompletionRequest,
994    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
995        let span = if tracing::Span::current().is_disabled() {
996            info_span!(
997                target: "rig::completions",
998                "chat",
999                gen_ai.operation.name = "chat",
1000                gen_ai.provider.name = "openai",
1001                gen_ai.request.model = self.model,
1002                gen_ai.system_instructions = &completion_request.preamble,
1003                gen_ai.response.id = tracing::field::Empty,
1004                gen_ai.response.model = tracing::field::Empty,
1005                gen_ai.usage.output_tokens = tracing::field::Empty,
1006                gen_ai.usage.input_tokens = tracing::field::Empty,
1007                gen_ai.input.messages = tracing::field::Empty,
1008                gen_ai.output.messages = tracing::field::Empty,
1009            )
1010        } else {
1011            tracing::Span::current()
1012        };
1013
1014        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;
1015
1016        span.record_model_input(&request.messages);
1017
1018        let body = serde_json::to_vec(&request)?;
1019
1020        let req = self
1021            .client
1022            .post("/chat/completions")?
1023            .body(body)
1024            .map_err(|e| CompletionError::HttpError(e.into()))?;
1025
1026        async move {
1027            let response = self.client.send(req).await?;
1028
1029            if response.status().is_success() {
1030                let text = http_client::text(response).await?;
1031
1032                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
1033                    ApiResponse::Ok(response) => {
1034                        let span = tracing::Span::current();
1035                        span.record_model_output(&response.choices);
1036                        span.record_response_metadata(&response);
1037                        span.record_token_usage(&response.usage);
1038                        tracing::debug!("OpenAI response: {response:?}");
1039                        response.try_into()
1040                    }
1041                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1042                }
1043            } else {
1044                let text = http_client::text(response).await?;
1045                Err(CompletionError::ProviderError(text))
1046            }
1047        }
1048        .instrument(span)
1049        .await
1050    }
1051
1052    async fn stream(
1053        &self,
1054        request: CoreCompletionRequest,
1055    ) -> Result<
1056        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1057        CompletionError,
1058    > {
1059        Self::stream(self, request).await
1060    }
1061}