rig/providers/openai/completion/mod.rs

// ================================================================
// OpenAI Completion API
// ================================================================

use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse};
use crate::completion::{
    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
};
use crate::http_client::{self, HttpClientExt};
use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
use crate::one_or_many::string_or_one_or_many;
use crate::telemetry::{ProviderResponseExt, SpanCombinator};
use crate::{OneOrMany, completion, json_utils, message};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::fmt;
use tracing::{Instrument, info_span};

use std::str::FromStr;

pub mod streaming;

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than `gpt-4o`)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";

impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    #[serde(alias = "developer")]
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: OneOrMany::one(content.to_owned().into()),
            name: None,
        }
    }
}
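
// A minimal usage sketch (illustrative, not part of the upstream docs): the
// `system` constructor wraps a plain string into the provider-side system
// message that gets prepended to the chat history when a preamble is set on
// the request.
//
//     let msg = Message::system("You are a helpful assistant.");
//     assert!(matches!(msg, Message::System { .. }));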

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}

impl From<AssistantContent> for completion::AssistantContent {
    fn from(value: AssistantContent) -> Self {
        match value {
            AssistantContent::Text { text } => completion::AssistantContent::text(text),
            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}

impl FromStr for ToolResultContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(s.to_owned().into())
    }
}

impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        ToolResultContent {
            r#type: ToolResultContentType::default(),
            text: s,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}

#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}

impl TryFrom<crate::message::ToolChoice> for ToolChoice {
    type Error = CompletionError;
    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::Specific { .. } => {
                return Err(CompletionError::ProviderError(
                    "Provider doesn't support only using specific tools".to_string(),
                ));
            }
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Required => Self::Required,
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // If a message contains both tool results and other user content, OpenAI will
                // only handle the tool results. It's unlikely that there will be both.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                                ..
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let other_content: Vec<UserContent> = other_content
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                Ok(UserContent::Text { text })
                            }
                            message::UserContent::Image(message::Image {
                                data,
                                detail,
                                media_type,
                                ..
                            }) => match data {
                                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                                    image_url: ImageUrl {
                                        url,
                                        detail: detail.unwrap_or_default(),
                                    },
                                }),
                                DocumentSourceKind::Base64(data) => {
                                    let url = format!(
                                        "data:{};base64,{}",
                                        media_type.map(|i| i.to_mime_type()).ok_or(
                                            message::MessageError::ConversionError(
                                                "OpenAI image URI must have a media type".into()
                                            )
                                        )?,
                                        data
                                    );

                                    let detail =
                                        detail.ok_or(message::MessageError::ConversionError(
                                            "OpenAI image URI must have image detail".into(),
                                        ))?;

                                    Ok(UserContent::Image {
                                        image_url: ImageUrl { url, detail },
                                    })
                                }
                                DocumentSourceKind::Raw(_) => {
                                    Err(message::MessageError::ConversionError(
                                        "Raw files not supported, encode as base64 first".into(),
                                    ))
                                }
                                DocumentSourceKind::Unknown => {
                                    Err(message::MessageError::ConversionError(
                                        "Document has no body".into(),
                                    ))
                                }
                                doc => Err(message::MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                ))),
                            },
                            message::UserContent::Document(message::Document { data, .. }) => {
                                if let DocumentSourceKind::Base64(text)
                                | DocumentSourceKind::String(text) = data
                                {
                                    Ok(UserContent::Text { text })
                                } else {
                                    Err(message::MessageError::ConversionError(
                                        "Documents must be base64 or a string".into(),
                                    ))
                                }
                            }
                            message::UserContent::Audio(message::Audio {
                                data: DocumentSourceKind::Base64(data),
                                media_type,
                                ..
                            }) => Ok(UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            }),
                            _ => Err(message::MessageError::ConversionError(
                                "User content is in an unsupported format".into(),
                            )),
                        })
                        .collect::<Result<Vec<_>, _>>()?;

                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there was no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content,
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content, .. } => {
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                            message::AssistantContent::Reasoning(_) => {
                                unimplemented!(
                                    "The OpenAI Completions API doesn't support reasoning!"
                                );
                            }
                        }
                        (texts, tools)
                    },
                );

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl From<ToolCall> for message::ToolCall {
    fn from(tool_call: ToolCall) -> Self {
        Self {
            id: tool_call.id,
            call_id: None,
            function: message::ToolFunction {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}
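
// Round-trip sketch (illustrative values): a provider tool call maps onto the
// core `message::ToolCall` with `call_id` left unset, and converts back via
// `From`, preserving the id, function name, and arguments.
//
//     let provider_call = ToolCall {
//         id: "call_123".into(),
//         r#type: ToolType::default(),
//         function: Function {
//             name: "get_weather".into(),
//             arguments: serde_json::json!({ "city": "Paris" }),
//         },
//     };
//     let core_call: message::ToolCall = provider_call.clone().into();
//     let back: ToolCall = core_call.into();
//     assert_eq!(back, provider_call);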

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message {
            Message::User { content, .. } => message::Message::User {
                content: content.map(|content| content.into()),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .into_iter()
                    .map(|content| match content {
                        AssistantContent::Text { text } => message::AssistantContent::text(text),

                        // TODO: Currently, refusals are converted into text, but this should be
                        //  investigated for generalization.
                        AssistantContent::Refusal { refusal } => {
                            message::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .into_iter()
                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
                        .collect::<Result<Vec<_>, _>>()?,
                );

                message::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "Neither `content` nor `tool_calls` was provided to the Message"
                                .to_owned(),
                        )
                    })?,
                }
            }

            Message::ToolResult {
                tool_call_id,
                content,
            } => message::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|content| message::ToolResultContent::text(content.text)),
                )),
            },

            // System messages should get stripped out when converting messages; this is just a
            // stopgap to avoid obnoxious error handling or panics.
            Message::System { content, .. } => message::Message::User {
                content: content.map(|content| message::UserContent::text(content.text)),
            },
        })
    }
}

impl From<UserContent> for message::UserContent {
    fn from(content: UserContent) -> Self {
        match content {
            UserContent::Text { text } => message::UserContent::text(text),
            UserContent::Image { image_url } => {
                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
            }
            UserContent::Audio { input_audio } => {
                message::UserContent::audio(input_audio.data, Some(input_audio.format))
            }
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent::Text { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            AssistantContent::Text { text } => text,
                            AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}

impl ProviderResponseExt for CompletionResponse {
    type OutputMessage = Choice;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.to_owned())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.to_owned())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.choices.clone()
    }

    fn get_text_response(&self) -> Option<String> {
        // Chat completion choices come back with the `assistant` role, so the text
        // response is the first non-empty text part of the last assistant message.
        let Message::Assistant { ref content, .. } = self.choices.last()?.message else {
            return None;
        };

        content.iter().find_map(|c| match c {
            AssistantContent::Text { text } if !text.is_empty() => Some(text.clone()),
            _ => None,
        })
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage.clone()
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl Usage {
    pub fn new() -> Self {
        Self {
            prompt_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for Usage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Usage {
            prompt_tokens,
            total_tokens,
        } = self;
        write!(
            f,
            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
        )
    }
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.prompt_tokens as u64;
        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
        usage.total_tokens = self.total_tokens as u64;

        Some(usage)
    }
}
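
// Accounting note (illustrative values): `Usage` only captures `prompt_tokens` and
// `total_tokens` from the response body, so output tokens are derived as the
// difference between the two.
//
//     let usage = Usage { prompt_tokens: 12, total_tokens: 20 };
//     let core = usage.token_usage().unwrap();
//     assert_eq!(core.output_tokens, 8); // 20 - 12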

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Name of the model (e.g., `gpt-3.5-turbo-1106`)
    pub model: String,
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Default + std::fmt::Debug + Clone + 'static,
{
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}
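
// Serialization sketch (assumed values): because `additional_params` is flattened,
// extra OpenAI fields (e.g. `max_tokens`) end up at the top level of the JSON body
// alongside `model`, `messages`, and `temperature`, while empty `tools` and unset
// `tool_choice` are skipped entirely.
//
//     let request = CompletionRequest {
//         model: "gpt-4o".into(),
//         messages: vec![Message::system("You are terse.")],
//         tools: vec![],
//         tool_choice: None,
//         temperature: Some(0.2),
//         additional_params: Some(serde_json::json!({ "max_tokens": 64 })),
//     };
//     let body = serde_json::to_string(&request).unwrap();
//     // body contains "model", "messages", "temperature", and "max_tokens" as
//     // sibling top-level fields.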

impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            preamble,
            chat_history,
            tools,
            temperature,
            additional_params,
            tool_choice,
            ..
        } = req;

        partial_history.extend(chat_history);

        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let res = Self {
            model,
            messages: full_history,
            tools: tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            temperature,
            additional_params,
        };

        Ok(res)
    }
}

impl crate::telemetry::ProviderRequestExt for CompletionRequest {
    type InputMessage = Message;

    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
        self.messages.clone()
    }

    fn get_system_prompt(&self) -> Option<String> {
        let first_message = self.messages.first()?;

        let Message::System { ref content, .. } = first_message.clone() else {
            return None;
        };

        let SystemContent { text, .. } = content.first();

        Some(text)
    }

    fn get_prompt(&self) -> Option<String> {
        let last_message = self.messages.last()?;

        let Message::User { ref content, .. } = last_message.clone() else {
            return None;
        };

        let UserContent::Text { text } = content.first() else {
            return None;
        };

        Some(text)
    }

    fn get_model_name(&self) -> String {
        self.model.clone()
    }
}

impl CompletionModel<reqwest::Client> {
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}

impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;

        span.record_model_input(&request.messages);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!("OpenAI response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
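
// A small self-contained sketch exercising the pure conversions defined above;
// the module name and test values are illustrative and not part of the upstream
// test suite.
#[cfg(test)]
mod conversion_sketch_tests {
    use super::*;

    #[test]
    fn system_message_serializes_with_system_role() {
        // The `role` tag comes from `#[serde(tag = "role", rename_all = "lowercase")]`.
        let msg = Message::system("You are a helpful assistant.");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
    }

    #[test]
    fn tool_result_content_parses_from_str() {
        // `FromStr` is infallible, so any string becomes a text tool result.
        let content: ToolResultContent = "42".parse().unwrap();
        assert_eq!(content.text, "42");
    }

    #[test]
    fn usage_derives_output_tokens_from_totals() {
        // Output tokens are inferred as `total_tokens - prompt_tokens`.
        let usage = Usage {
            prompt_tokens: 12,
            total_tokens: 20,
        };
        let core = usage.token_usage().unwrap();
        assert_eq!(core.input_tokens, 12);
        assert_eq!(core.output_tokens, 8);
        assert_eq!(core.total_tokens, 20);
    }
}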