rig/providers/openai/
completion.rs

1// ================================================================
2// OpenAI Completion API
3// ================================================================
4
5use super::{ApiErrorResponse, ApiResponse, Client, Usage};
6use crate::completion::{CompletionError, CompletionRequest};
7use crate::message::{AudioMediaType, ImageDetail};
8use crate::one_or_many::string_or_one_or_many;
9use crate::{completion, json_utils, message, OneOrMany};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::convert::Infallible;
13use std::str::FromStr;
14
/// `o4-mini-2025-04-16` completion model
pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16";
/// `o4-mini` completion model
pub const O4_MINI: &str = "o4-mini";
/// `o3` completion model
pub const O3: &str = "o3";
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1-pro` completion model
pub const O1_PRO: &str = "o1-pro";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

/// `gpt-4.1-mini` completion model
pub const GPT_4_1_MINI: &str = "gpt-4.1-mini";
/// `gpt-4.1-nano` completion model
pub const GPT_4_1_NANO: &str = "gpt-4.1-nano";
/// `gpt-4.1-2025-04-14` completion model
pub const GPT_4_1_2025_04_14: &str = "gpt-4.1-2025-04-14";
/// `gpt-4.1` completion model
pub const GPT_4_1: &str = "gpt-4.1";
/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o-2024-11-20` completion model (this is newer than 4o)
pub const GPT_4O_2024_11_20: &str = "gpt-4o-2024-11-20";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
90
91#[derive(Debug, Deserialize)]
92pub struct CompletionResponse {
93    pub id: String,
94    pub object: String,
95    pub created: u64,
96    pub model: String,
97    pub system_fingerprint: Option<String>,
98    pub choices: Vec<Choice>,
99    pub usage: Option<Usage>,
100}
101
102impl From<ApiErrorResponse> for CompletionError {
103    fn from(err: ApiErrorResponse) -> Self {
104        CompletionError::ProviderError(err.message)
105    }
106}
107
108impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
109    type Error = CompletionError;
110
111    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
112        let choice = response.choices.first().ok_or_else(|| {
113            CompletionError::ResponseError("Response contained no choices".to_owned())
114        })?;
115
116        let content = match &choice.message {
117            Message::Assistant {
118                content,
119                tool_calls,
120                ..
121            } => {
122                let mut content = content
123                    .iter()
124                    .filter_map(|c| {
125                        let s = match c {
126                            AssistantContent::Text { text } => text,
127                            AssistantContent::Refusal { refusal } => refusal,
128                        };
129                        if s.is_empty() {
130                            None
131                        } else {
132                            Some(completion::AssistantContent::text(s))
133                        }
134                    })
135                    .collect::<Vec<_>>();
136
137                content.extend(
138                    tool_calls
139                        .iter()
140                        .map(|call| {
141                            completion::AssistantContent::tool_call(
142                                &call.id,
143                                &call.function.name,
144                                call.function.arguments.clone(),
145                            )
146                        })
147                        .collect::<Vec<_>>(),
148                );
149                Ok(content)
150            }
151            _ => Err(CompletionError::ResponseError(
152                "Response did not contain a valid message or tool call".into(),
153            )),
154        }?;
155
156        let choice = OneOrMany::many(content).map_err(|_| {
157            CompletionError::ResponseError(
158                "Response contained no message or tool call (empty)".to_owned(),
159            )
160        })?;
161
162        Ok(completion::CompletionResponse {
163            choice,
164            raw_response: response,
165        })
166    }
167}
168
169#[derive(Debug, Serialize, Deserialize)]
170pub struct Choice {
171    pub index: usize,
172    pub message: Message,
173    pub logprobs: Option<serde_json::Value>,
174    pub finish_reason: String,
175}
176
177#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
178#[serde(tag = "role", rename_all = "lowercase")]
179pub enum Message {
180    #[serde(alias = "developer")]
181    System {
182        #[serde(deserialize_with = "string_or_one_or_many")]
183        content: OneOrMany<SystemContent>,
184        #[serde(skip_serializing_if = "Option::is_none")]
185        name: Option<String>,
186    },
187    User {
188        #[serde(deserialize_with = "string_or_one_or_many")]
189        content: OneOrMany<UserContent>,
190        #[serde(skip_serializing_if = "Option::is_none")]
191        name: Option<String>,
192    },
193    Assistant {
194        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
195        content: Vec<AssistantContent>,
196        #[serde(skip_serializing_if = "Option::is_none")]
197        refusal: Option<String>,
198        #[serde(skip_serializing_if = "Option::is_none")]
199        audio: Option<AudioAssistant>,
200        #[serde(skip_serializing_if = "Option::is_none")]
201        name: Option<String>,
202        #[serde(
203            default,
204            deserialize_with = "json_utils::null_or_vec",
205            skip_serializing_if = "Vec::is_empty"
206        )]
207        tool_calls: Vec<ToolCall>,
208    },
209    #[serde(rename = "tool")]
210    ToolResult {
211        tool_call_id: String,
212        content: OneOrMany<ToolResultContent>,
213    },
214}
215
216impl Message {
217    pub fn system(content: &str) -> Self {
218        Message::System {
219            content: OneOrMany::one(content.to_owned().into()),
220            name: None,
221        }
222    }
223}
224
225#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
226pub struct AudioAssistant {
227    id: String,
228}
229
230#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
231pub struct SystemContent {
232    #[serde(default)]
233    r#type: SystemContentType,
234    text: String,
235}
236
237#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
238#[serde(rename_all = "lowercase")]
239pub enum SystemContentType {
240    #[default]
241    Text,
242}
243
244#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
245#[serde(tag = "type", rename_all = "lowercase")]
246pub enum AssistantContent {
247    Text { text: String },
248    Refusal { refusal: String },
249}
250
251#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
252#[serde(tag = "type", rename_all = "lowercase")]
253pub enum UserContent {
254    Text {
255        text: String,
256    },
257    #[serde(rename = "image_url")]
258    Image {
259        image_url: ImageUrl,
260    },
261    Audio {
262        input_audio: InputAudio,
263    },
264}
265
266#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
267pub struct ImageUrl {
268    pub url: String,
269    #[serde(default)]
270    pub detail: ImageDetail,
271}
272
273#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
274pub struct InputAudio {
275    pub data: String,
276    pub format: AudioMediaType,
277}
278
279#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
280pub struct ToolResultContent {
281    #[serde(default)]
282    r#type: ToolResultContentType,
283    text: String,
284}
285
286#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
287#[serde(rename_all = "lowercase")]
288pub enum ToolResultContentType {
289    #[default]
290    Text,
291}
292
293impl FromStr for ToolResultContent {
294    type Err = Infallible;
295
296    fn from_str(s: &str) -> Result<Self, Self::Err> {
297        Ok(s.to_owned().into())
298    }
299}
300
301impl From<String> for ToolResultContent {
302    fn from(s: String) -> Self {
303        ToolResultContent {
304            r#type: ToolResultContentType::default(),
305            text: s,
306        }
307    }
308}
309
310#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
311pub struct ToolCall {
312    pub id: String,
313    #[serde(default)]
314    pub r#type: ToolType,
315    pub function: Function,
316}
317
318#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
319#[serde(rename_all = "lowercase")]
320pub enum ToolType {
321    #[default]
322    Function,
323}
324
325#[derive(Debug, Deserialize, Serialize, Clone)]
326pub struct ToolDefinition {
327    pub r#type: String,
328    pub function: completion::ToolDefinition,
329}
330
331impl From<completion::ToolDefinition> for ToolDefinition {
332    fn from(tool: completion::ToolDefinition) -> Self {
333        Self {
334            r#type: "function".into(),
335            function: tool,
336        }
337    }
338}
339
340#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
341pub struct Function {
342    pub name: String,
343    #[serde(with = "json_utils::stringified_json")]
344    pub arguments: serde_json::Value,
345}
346
347impl TryFrom<message::Message> for Vec<Message> {
348    type Error = message::MessageError;
349
350    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
351        match message {
352            message::Message::User { content } => {
353                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
354                    .into_iter()
355                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
356
357                // If there are messages with both tool results and user content, openai will only
358                //  handle tool results. It's unlikely that there will be both.
359                if !tool_results.is_empty() {
360                    tool_results
361                        .into_iter()
362                        .map(|content| match content {
363                            message::UserContent::ToolResult(message::ToolResult {
364                                id,
365                                content,
366                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
367                                tool_call_id: id,
368                                content: content.try_map(|content| match content {
369                                    message::ToolResultContent::Text(message::Text { text }) => {
370                                        Ok(text.into())
371                                    }
372                                    _ => Err(message::MessageError::ConversionError(
373                                        "Tool result content does not support non-text".into(),
374                                    )),
375                                })?,
376                            }),
377                            _ => unreachable!(),
378                        })
379                        .collect::<Result<Vec<_>, _>>()
380                } else {
381                    let other_content = OneOrMany::many(other_content).expect(
382                        "There must be other content here if there were no tool result content",
383                    );
384
385                    Ok(vec![Message::User {
386                        content: other_content.map(|content| match content {
387                            message::UserContent::Text(message::Text { text }) => {
388                                UserContent::Text { text }
389                            }
390                            message::UserContent::Image(message::Image {
391                                data, detail, ..
392                            }) => UserContent::Image {
393                                image_url: ImageUrl {
394                                    url: data,
395                                    detail: detail.unwrap_or_default(),
396                                },
397                            },
398                            message::UserContent::Document(message::Document { data, .. }) => {
399                                UserContent::Text { text: data }
400                            }
401                            message::UserContent::Audio(message::Audio {
402                                data,
403                                media_type,
404                                ..
405                            }) => UserContent::Audio {
406                                input_audio: InputAudio {
407                                    data,
408                                    format: match media_type {
409                                        Some(media_type) => media_type,
410                                        None => AudioMediaType::MP3,
411                                    },
412                                },
413                            },
414                            _ => unreachable!(),
415                        }),
416                        name: None,
417                    }])
418                }
419            }
420            message::Message::Assistant { content } => {
421                let (text_content, tool_calls) = content.into_iter().fold(
422                    (Vec::new(), Vec::new()),
423                    |(mut texts, mut tools), content| {
424                        match content {
425                            message::AssistantContent::Text(text) => texts.push(text),
426                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
427                        }
428                        (texts, tools)
429                    },
430                );
431
432                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
433                //  so either `content` or `tool_calls` will have some content.
434                Ok(vec![Message::Assistant {
435                    content: text_content
436                        .into_iter()
437                        .map(|content| content.text.into())
438                        .collect::<Vec<_>>(),
439                    refusal: None,
440                    audio: None,
441                    name: None,
442                    tool_calls: tool_calls
443                        .into_iter()
444                        .map(|tool_call| tool_call.into())
445                        .collect::<Vec<_>>(),
446                }])
447            }
448        }
449    }
450}
451
452impl From<message::ToolCall> for ToolCall {
453    fn from(tool_call: message::ToolCall) -> Self {
454        Self {
455            id: tool_call.id,
456            r#type: ToolType::default(),
457            function: Function {
458                name: tool_call.function.name,
459                arguments: tool_call.function.arguments,
460            },
461        }
462    }
463}
464
465impl From<ToolCall> for message::ToolCall {
466    fn from(tool_call: ToolCall) -> Self {
467        Self {
468            id: tool_call.id,
469            function: message::ToolFunction {
470                name: tool_call.function.name,
471                arguments: tool_call.function.arguments,
472            },
473        }
474    }
475}
476
477impl TryFrom<Message> for message::Message {
478    type Error = message::MessageError;
479
480    fn try_from(message: Message) -> Result<Self, Self::Error> {
481        Ok(match message {
482            Message::User { content, .. } => message::Message::User {
483                content: content.map(|content| content.into()),
484            },
485            Message::Assistant {
486                content,
487                tool_calls,
488                ..
489            } => {
490                let mut content = content
491                    .into_iter()
492                    .map(|content| match content {
493                        AssistantContent::Text { text } => message::AssistantContent::text(text),
494
495                        // TODO: Currently, refusals are converted into text, but should be
496                        //  investigated for generalization.
497                        AssistantContent::Refusal { refusal } => {
498                            message::AssistantContent::text(refusal)
499                        }
500                    })
501                    .collect::<Vec<_>>();
502
503                content.extend(
504                    tool_calls
505                        .into_iter()
506                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
507                        .collect::<Result<Vec<_>, _>>()?,
508                );
509
510                message::Message::Assistant {
511                    content: OneOrMany::many(content).map_err(|_| {
512                        message::MessageError::ConversionError(
513                            "Neither `content` nor `tool_calls` was provided to the Message"
514                                .to_owned(),
515                        )
516                    })?,
517                }
518            }
519
520            Message::ToolResult {
521                tool_call_id,
522                content,
523            } => message::Message::User {
524                content: OneOrMany::one(message::UserContent::tool_result(
525                    tool_call_id,
526                    content.map(|content| message::ToolResultContent::text(content.text)),
527                )),
528            },
529
530            // System messages should get stripped out when converting message's, this is just a
531            // stop gap to avoid obnoxious error handling or panic occuring.
532            Message::System { content, .. } => message::Message::User {
533                content: content.map(|content| message::UserContent::text(content.text)),
534            },
535        })
536    }
537}
538
539impl From<UserContent> for message::UserContent {
540    fn from(content: UserContent) -> Self {
541        match content {
542            UserContent::Text { text } => message::UserContent::text(text),
543            UserContent::Image { image_url } => message::UserContent::image(
544                image_url.url,
545                Some(message::ContentFormat::default()),
546                None,
547                Some(image_url.detail),
548            ),
549            UserContent::Audio { input_audio } => message::UserContent::audio(
550                input_audio.data,
551                Some(message::ContentFormat::default()),
552                Some(input_audio.format),
553            ),
554        }
555    }
556}
557
558impl From<String> for UserContent {
559    fn from(s: String) -> Self {
560        UserContent::Text { text: s }
561    }
562}
563
564impl FromStr for UserContent {
565    type Err = Infallible;
566
567    fn from_str(s: &str) -> Result<Self, Self::Err> {
568        Ok(UserContent::Text {
569            text: s.to_string(),
570        })
571    }
572}
573
574impl From<String> for AssistantContent {
575    fn from(s: String) -> Self {
576        AssistantContent::Text { text: s }
577    }
578}
579
580impl FromStr for AssistantContent {
581    type Err = Infallible;
582
583    fn from_str(s: &str) -> Result<Self, Self::Err> {
584        Ok(AssistantContent::Text {
585            text: s.to_string(),
586        })
587    }
588}
589impl From<String> for SystemContent {
590    fn from(s: String) -> Self {
591        SystemContent {
592            r#type: SystemContentType::default(),
593            text: s,
594        }
595    }
596}
597
598impl FromStr for SystemContent {
599    type Err = Infallible;
600
601    fn from_str(s: &str) -> Result<Self, Self::Err> {
602        Ok(SystemContent {
603            r#type: SystemContentType::default(),
604            text: s.to_string(),
605        })
606    }
607}
608
609#[derive(Clone)]
610pub struct CompletionModel {
611    pub(crate) client: Client,
612    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
613    pub model: String,
614}
615
616impl CompletionModel {
617    pub fn new(client: Client, model: &str) -> Self {
618        Self {
619            client,
620            model: model.to_string(),
621        }
622    }
623
624    pub(crate) fn create_completion_request(
625        &self,
626        completion_request: CompletionRequest,
627    ) -> Result<Value, CompletionError> {
628        // Build up the order of messages (context, chat_history)
629        let mut partial_history = vec![];
630        if let Some(docs) = completion_request.normalized_documents() {
631            partial_history.push(docs);
632        }
633        partial_history.extend(completion_request.chat_history);
634
635        // Initialize full history with preamble (or empty if non-existent)
636        let mut full_history: Vec<Message> = completion_request
637            .preamble
638            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
639
640        // Convert and extend the rest of the history
641        full_history.extend(
642            partial_history
643                .into_iter()
644                .map(message::Message::try_into)
645                .collect::<Result<Vec<Vec<Message>>, _>>()?
646                .into_iter()
647                .flatten()
648                .collect::<Vec<_>>(),
649        );
650
651        let request = if completion_request.tools.is_empty() {
652            json!({
653                "model": self.model,
654                "messages": full_history,
655
656            })
657        } else {
658            json!({
659                "model": self.model,
660                "messages": full_history,
661                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
662                "tool_choice": "auto",
663            })
664        };
665
666        // only include temperature if it exists
667        // because some models don't support temperature
668        let request = if let Some(temperature) = completion_request.temperature {
669            json_utils::merge(
670                request,
671                json!({
672                    "temperature": temperature,
673                }),
674            )
675        } else {
676            request
677        };
678
679        let request = if let Some(params) = completion_request.additional_params {
680            json_utils::merge(request, params)
681        } else {
682            request
683        };
684
685        Ok(request)
686    }
687}
688
689impl completion::CompletionModel for CompletionModel {
690    type Response = CompletionResponse;
691
692    #[cfg_attr(feature = "worker", worker::send)]
693    async fn completion(
694        &self,
695        completion_request: CompletionRequest,
696    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
697        let request = self.create_completion_request(completion_request)?;
698
699        let response = self
700            .client
701            .post("/chat/completions")
702            .json(&request)
703            .send()
704            .await?;
705
706        if response.status().is_success() {
707            let t = response.text().await?;
708            tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
709
710            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
711                ApiResponse::Ok(response) => {
712                    tracing::info!(target: "rig",
713                        "OpenAI completion token usage: {:?}",
714                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
715                    );
716                    response.try_into()
717                }
718                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
719            }
720        } else {
721            Err(CompletionError::ProviderError(response.text().await?))
722        }
723    }
724}