rig/providers/openai/completion/mod.rs

// ================================================================
// OpenAI Completion API
// ================================================================

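//! OpenAI Chat Completions provider support: wire-level request/response types,
//! conversions to and from Rig's core [`message`] types, and the
//! [`CompletionModel`] implementation backed by the `/chat/completions` endpoint.
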
use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse};
use crate::completion::{
    CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage,
};
use crate::http_client::{self, HttpClientExt};
use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType};
use crate::one_or_many::string_or_one_or_many;
use crate::telemetry::{ProviderResponseExt, SpanCombinator};
use crate::{OneOrMany, completion, json_utils, message};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::fmt;
use tracing::{Instrument, info_span};

use std::str::FromStr;

pub mod streaming;

/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (a newer snapshot than `gpt-4o`)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";

impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}

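/// A chat message in OpenAI's Chat Completions wire format, tagged by `role`.
///
/// The `content` fields accept either a bare string or an array of typed parts on
/// deserialization (via `string_or_one_or_many` / `string_or_vec`). A minimal
/// sketch of deserializing a user message, not compiled as a doctest:
///
/// ```ignore
/// // `content` arrives as a plain string here; an array of parts works as well.
/// let raw = r#"{ "role": "user", "content": "Hello!" }"#;
/// let msg: Message = serde_json::from_str(raw).unwrap();
/// assert!(matches!(msg, Message::User { .. }));
/// ```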
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    #[serde(alias = "developer")]
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}

impl Message {
    pub fn system(content: &str) -> Self {
        Message::System {
            content: OneOrMany::one(content.to_owned().into()),
            name: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    pub id: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    pub r#type: SystemContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}

impl From<AssistantContent> for completion::AssistantContent {
    fn from(value: AssistantContent) -> Self {
        match value {
            AssistantContent::Text { text } => completion::AssistantContent::text(text),
            AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        image_url: ImageUrl,
    },
    Audio {
        input_audio: InputAudio,
    },
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    pub text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}

impl FromStr for ToolResultContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(s.to_owned().into())
    }
}

impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        ToolResultContent {
            r#type: ToolResultContentType::default(),
            text: s,
        }
    }
}

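/// A tool (function) call emitted by the assistant.
///
/// On the wire, OpenAI sends `function.arguments` as a *stringified* JSON object;
/// `json_utils::stringified_json` maps it to and from a structured
/// [`serde_json::Value`]. A hedged sketch of the expected wire shape, not compiled
/// as a doctest:
///
/// ```ignore
/// let raw = r#"{
///     "id": "call_abc123",
///     "type": "function",
///     "function": { "name": "get_weather", "arguments": "{\"city\":\"Berlin\"}" }
/// }"#;
/// let call: ToolCall = serde_json::from_str(raw).unwrap();
/// assert_eq!(call.function.name, "get_weather");
/// assert_eq!(call.function.arguments["city"], "Berlin");
/// ```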
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}

#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Required,
}

impl TryFrom<crate::message::ToolChoice> for ToolChoice {
    type Error = CompletionError;
    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::Specific { .. } => {
                return Err(CompletionError::ProviderError(
                    "Provider doesn't support only using specific tools".to_string(),
                ));
            }
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Required => Self::Required,
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // If a message contains both tool results and other user content, only the
                // tool results are forwarded, since OpenAI expects tool results as dedicated
                // `tool`-role messages. Mixed messages are unlikely in practice.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                                ..
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let other_content: Vec<UserContent> = other_content
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                Ok(UserContent::Text { text })
                            }
                            message::UserContent::Image(message::Image {
                                data,
                                detail,
                                media_type,
                                ..
                            }) => match data {
                                DocumentSourceKind::Url(url) => Ok(UserContent::Image {
                                    image_url: ImageUrl {
                                        url,
                                        detail: detail.unwrap_or_default(),
                                    },
                                }),
                                DocumentSourceKind::Base64(data) => {
                                    let url = format!(
                                        "data:{};base64,{}",
                                        media_type.map(|i| i.to_mime_type()).ok_or(
                                            message::MessageError::ConversionError(
                                                "OpenAI Image URI must have media type".into()
                                            )
                                        )?,
                                        data
                                    );

                                    let detail =
                                        detail.ok_or(message::MessageError::ConversionError(
                                            "OpenAI image URI must have image detail".into(),
                                        ))?;

                                    Ok(UserContent::Image {
                                        image_url: ImageUrl { url, detail },
                                    })
                                }
                                DocumentSourceKind::Raw(_) => {
                                    Err(message::MessageError::ConversionError(
                                        "Raw files not supported, encode as base64 first".into(),
                                    ))
                                }
                                DocumentSourceKind::Unknown => {
                                    Err(message::MessageError::ConversionError(
                                        "Document has no body".into(),
                                    ))
                                }
                                doc => Err(message::MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                ))),
                            },
                            message::UserContent::Document(message::Document { data, .. }) => {
                                if let DocumentSourceKind::Base64(text) = data {
                                    Ok(UserContent::Text { text })
                                } else {
                                    Err(message::MessageError::ConversionError(
                                        "Documents must be base64".into(),
                                    ))
                                }
                            }
                            message::UserContent::Audio(message::Audio {
                                data: DocumentSourceKind::Base64(data),
                                media_type,
                                ..
                            }) => Ok(UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            }),
                            _ => Err(message::MessageError::ConversionError(
                                "Unsupported user content format for OpenAI messages".into(),
                            )),
                        })
                        .collect::<Result<Vec<_>, _>>()?;

                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there were no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content,
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content, .. } => {
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                            message::AssistantContent::Reasoning(_) => {
                                unimplemented!(
                                    "The OpenAI Completions API doesn't support reasoning!"
                                );
                            }
                        }
                        (texts, tools)
                    },
                );

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl From<ToolCall> for message::ToolCall {
    fn from(tool_call: ToolCall) -> Self {
        Self {
            id: tool_call.id,
            call_id: None,
            function: message::ToolFunction {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message {
            Message::User { content, .. } => message::Message::User {
                content: content.map(|content| content.into()),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .into_iter()
                    .map(|content| match content {
                        AssistantContent::Text { text } => message::AssistantContent::text(text),

                        // TODO: Currently, refusals are converted into text, but should be
                        //  investigated for generalization.
                        AssistantContent::Refusal { refusal } => {
                            message::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .into_iter()
                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
                        .collect::<Result<Vec<_>, _>>()?,
                );

                message::Message::Assistant {
                    id: None,
                    content: OneOrMany::many(content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "Neither `content` nor `tool_calls` was provided to the Message"
                                .to_owned(),
                        )
                    })?,
                }
            }

            Message::ToolResult {
                tool_call_id,
                content,
            } => message::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|content| message::ToolResultContent::text(content.text)),
                )),
            },

            // System messages should be stripped out before conversion; this is just a
            // stopgap to avoid awkward error handling or a panic.
            Message::System { content, .. } => message::Message::User {
                content: content.map(|content| message::UserContent::text(content.text)),
            },
        })
    }
}

impl From<UserContent> for message::UserContent {
    fn from(content: UserContent) -> Self {
        match content {
            UserContent::Text { text } => message::UserContent::text(text),
            UserContent::Image { image_url } => {
                message::UserContent::image_url(image_url.url, None, Some(image_url.detail))
            }
            UserContent::Audio { input_audio } => {
                message::UserContent::audio(input_audio.data, Some(input_audio.format))
            }
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent::Text { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            AssistantContent::Text { text } => text,
                            AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}

impl ProviderResponseExt for CompletionResponse {
    type OutputMessage = Choice;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.to_owned())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.to_owned())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.choices.clone()
    }

    fn get_text_response(&self) -> Option<String> {
        // Choices returned by the API hold assistant messages, so pull the first
        // non-empty text part out of the last choice.
        let Message::Assistant { content, .. } = &self.choices.last()?.message else {
            return None;
        };

        content.iter().find_map(|c| match c {
            AssistantContent::Text { text } if !text.is_empty() => Some(text.clone()),
            _ => None,
        })
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage.clone()
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

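/// Token usage as reported by the Chat Completions API.
///
/// Only `prompt_tokens` and `total_tokens` are captured here; output tokens are
/// derived as `total_tokens - prompt_tokens` (see the [`GetTokenUsage`] impl below).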
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl Usage {
    pub fn new() -> Self {
        Self {
            prompt_tokens: 0,
            total_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for Usage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Usage {
            prompt_tokens,
            total_tokens,
        } = self;
        write!(
            f,
            "Prompt tokens: {prompt_tokens} Total tokens: {total_tokens}"
        )
    }
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.prompt_tokens as u64;
        usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64;
        usage.total_tokens = self.total_tokens as u64;

        Some(usage)
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Name of the model (e.g. `gpt-3.5-turbo-1106`)
    pub model: String,
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Default + std::fmt::Debug + Clone + 'static,
{
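    /// Creates a completion model that sends requests through the given client.
    ///
    /// A minimal sketch, not compiled as a doctest, assuming the usual
    /// `rig::providers::openai` path and an OpenAI `client` that has already been
    /// constructed elsewhere (how it is built is out of scope for this module):
    ///
    /// ```ignore
    /// use rig::providers::openai::completion::{CompletionModel, GPT_4O};
    ///
    /// let model = CompletionModel::new(client, GPT_4O);
    /// // With the default reqwest-based client this can feed an agent builder:
    /// let agent_builder = model.into_agent_builder();
    /// ```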
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

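/// Request body for OpenAI's `/chat/completions` endpoint.
///
/// `additional_params` is flattened into the top-level JSON object, so any extra
/// provider-specific fields supplied by the caller (for example `max_tokens`) are
/// serialized alongside the known fields rather than nested under a separate key.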
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionRequest {
    model: String,
    messages: Vec<Message>,
    tools: Vec<ToolDefinition>,
    tool_choice: Option<ToolChoice>,
    temperature: Option<f64>,
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}

impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (String, CoreCompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
        let CoreCompletionRequest {
            preamble,
            chat_history,
            tools,
            temperature,
            additional_params,
            tool_choice,
            ..
        } = req;

        partial_history.extend(chat_history);

        let mut full_history: Vec<Message> =
            preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?;

        let res = Self {
            model,
            messages: full_history,
            tools: tools
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            temperature,
            additional_params,
        };

        Ok(res)
    }
}

impl crate::telemetry::ProviderRequestExt for CompletionRequest {
    type InputMessage = Message;

    fn get_input_messages(&self) -> Vec<Self::InputMessage> {
        self.messages.clone()
    }

    fn get_system_prompt(&self) -> Option<String> {
        let first_message = self.messages.first()?;

        let Message::System { ref content, .. } = first_message.clone() else {
            return None;
        };

        let SystemContent { text, .. } = content.first();

        Some(text)
    }

    fn get_prompt(&self) -> Option<String> {
        let last_message = self.messages.last()?;

        let Message::User { ref content, .. } = last_message.clone() else {
            return None;
        };

        let UserContent::Text { text } = content.first() else {
            return None;
        };

        Some(text)
    }

    fn get_model_name(&self) -> String {
        self.model.clone()
    }
}

impl CompletionModel<reqwest::Client> {
    pub fn into_agent_builder(self) -> crate::agent::AgentBuilder<Self> {
        crate::agent::AgentBuilder::new(self)
    }
}

impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CoreCompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?;

        span.record_model_input(&request.messages);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(req).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;

                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        tracing::debug!("OpenAI response: {response:?}");
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CoreCompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}