rig/providers/
openai.rs

1//! OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::openai;
6//!
7//! let client = openai::Client::new("YOUR_API_KEY");
8//!
9//! let gpt4o = client.completion_model(openai::GPT_4O);
10//! ```
11use crate::json_utils::merge;
12use crate::streaming::StreamingResult;
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError, CompletionRequest},
16    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17    extractor::ExtractorBuilder,
18    json_utils,
19    message::{self, AudioMediaType, ImageDetail},
20    one_or_many::string_or_one_or_many,
21    streaming,
22    streaming::StreamingCompletionModel,
23    transcription::{self, TranscriptionError},
24    Embed, OneOrMany,
25};
26use async_stream::stream;
27use futures::StreamExt;
28use reqwest::multipart::Part;
29use reqwest::RequestBuilder;
30use schemars::JsonSchema;
31use serde::{Deserialize, Serialize};
32use serde_json::{json, Value};
33use std::collections::HashMap;
34use std::{convert::Infallible, str::FromStr};
35
36// ================================================================
37// Main OpenAI Client
38// ================================================================
39const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
40
41#[derive(Clone)]
42pub struct Client {
43    base_url: String,
44    http_client: reqwest::Client,
45}
46
47impl Client {
48    /// Create a new OpenAI client with the given API key.
49    pub fn new(api_key: &str) -> Self {
50        Self::from_url(api_key, OPENAI_API_BASE_URL)
51    }
52
53    /// Create a new OpenAI client with the given API key and base API URL.
54    pub fn from_url(api_key: &str, base_url: &str) -> Self {
55        Self {
56            base_url: base_url.to_string(),
57            http_client: reqwest::Client::builder()
58                .default_headers({
59                    let mut headers = reqwest::header::HeaderMap::new();
60                    headers.insert(
61                        "Authorization",
62                        format!("Bearer {}", api_key)
63                            .parse()
64                            .expect("Bearer token should parse"),
65                    );
66                    headers
67                })
68                .build()
69                .expect("OpenAI reqwest client should build"),
70        }
71    }
72
73    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
74    /// Panics if the environment variable is not set.
75    pub fn from_env() -> Self {
76        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
77        Self::new(&api_key)
78    }
79
80    fn post(&self, path: &str) -> reqwest::RequestBuilder {
81        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
82        self.http_client.post(url)
83    }
84
85    /// Create an embedding model with the given name.
86    /// Note: default embedding dimension of 0 will be used if model is not known.
87    /// If this is the case, it's better to use function `embedding_model_with_ndims`
88    ///
89    /// # Example
90    /// ```
91    /// use rig::providers::openai::{Client, self};
92    ///
93    /// // Initialize the OpenAI client
94    /// let openai = Client::new("your-open-ai-api-key");
95    ///
96    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
97    /// ```
98    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
99        let ndims = match model {
100            TEXT_EMBEDDING_3_LARGE => 3072,
101            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
102            _ => 0,
103        };
104        EmbeddingModel::new(self.clone(), model, ndims)
105    }
106
107    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
108    ///
109    /// # Example
110    /// ```
111    /// use rig::providers::openai::{Client, self};
112    ///
113    /// // Initialize the OpenAI client
114    /// let openai = Client::new("your-open-ai-api-key");
115    ///
116    /// let embedding_model = openai.embedding_model("model-unknown-to-rig", 3072);
117    /// ```
118    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
119        EmbeddingModel::new(self.clone(), model, ndims)
120    }
121
122    /// Create an embedding builder with the given embedding model.
123    ///
124    /// # Example
125    /// ```
126    /// use rig::providers::openai::{Client, self};
127    ///
128    /// // Initialize the OpenAI client
129    /// let openai = Client::new("your-open-ai-api-key");
130    ///
131    /// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
132    ///     .simple_document("doc0", "Hello, world!")
133    ///     .simple_document("doc1", "Goodbye, world!")
134    ///     .build()
135    ///     .await
136    ///     .expect("Failed to embed documents");
137    /// ```
138    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
139        EmbeddingsBuilder::new(self.embedding_model(model))
140    }
141
142    /// Create a completion model with the given name.
143    ///
144    /// # Example
145    /// ```
146    /// use rig::providers::openai::{Client, self};
147    ///
148    /// // Initialize the OpenAI client
149    /// let openai = Client::new("your-open-ai-api-key");
150    ///
151    /// let gpt4 = openai.completion_model(openai::GPT_4);
152    /// ```
153    pub fn completion_model(&self, model: &str) -> CompletionModel {
154        CompletionModel::new(self.clone(), model)
155    }
156
157    /// Create an agent builder with the given completion model.
158    ///
159    /// # Example
160    /// ```
161    /// use rig::providers::openai::{Client, self};
162    ///
163    /// // Initialize the OpenAI client
164    /// let openai = Client::new("your-open-ai-api-key");
165    ///
166    /// let agent = openai.agent(openai::GPT_4)
167    ///    .preamble("You are comedian AI with a mission to make people laugh.")
168    ///    .temperature(0.0)
169    ///    .build();
170    /// ```
171    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
172        AgentBuilder::new(self.completion_model(model))
173    }
174
175    /// Create an extractor builder with the given completion model.
176    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
177        &self,
178        model: &str,
179    ) -> ExtractorBuilder<T, CompletionModel> {
180        ExtractorBuilder::new(self.completion_model(model))
181    }
182
183    /// Create a completion model with the given name.
184    ///
185    /// # Example
186    /// ```
187    /// use rig::providers::openai::{Client, self};
188    ///
189    /// // Initialize the OpenAI client
190    /// let openai = Client::new("your-open-ai-api-key");
191    ///
192    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
193    /// ```
194    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
195        TranscriptionModel::new(self.clone(), model)
196    }
197}
198
// NOTE(review): OpenAI's documented error bodies nest the message as
// `{"error": {"message": ...}}`; confirm which endpoints/providers return a top-level
// `message` that this struct can match.
/// Error body returned by the API, deserialized from a top-level `message` field.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]` the variants are tried in declaration order, so `Ok(T)` is
/// attempted first and `Err` only matches when the body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
210
211// ================================================================
212// OpenAI Embedding API
213// ================================================================
/// `text-embedding-3-large` embedding model (3072 dimensions in `Client::embedding_model`)
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions in `Client::embedding_model`)
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions in `Client::embedding_model`)
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
220
/// Successful response body of the `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag reported by the API.
    pub object: String,
    /// Embedding entries; `embed_texts` expects one per input document.
    pub data: Vec<EmbeddingData>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Token usage for the request (logged by `embed_texts`).
    pub usage: Usage,
}
228
229impl From<ApiErrorResponse> for EmbeddingError {
230    fn from(err: ApiErrorResponse) -> Self {
231        EmbeddingError::ProviderError(err.message)
232    }
233}
234
235impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
236    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
237        match value {
238            ApiResponse::Ok(response) => Ok(response),
239            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
240        }
241    }
242}
243
/// A single embedding vector inside an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Index of this embedding in the response list; presumably matches the position of
    /// the corresponding input document — confirm against the API reference.
    pub index: usize,
}
250
/// Token usage reported by the API (rendered by its `Display` impl for logging).
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the request input.
    pub prompt_tokens: usize,
    /// Total tokens counted for the request.
    pub total_tokens: usize,
}
256
257impl std::fmt::Display for Usage {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        write!(
260            f,
261            "Prompt tokens: {} Total tokens: {}",
262            self.prompt_tokens, self.total_tokens
263        )
264    }
265}
266
/// Handle to an OpenAI embedding model, created via `Client::embedding_model` or
/// `Client::embedding_model_with_ndims`.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send `/embeddings` requests.
    client: Client,
    /// Model name sent as `model` in embedding requests.
    pub model: String,
    /// Dimensionality reported by `ndims()`; `Client::embedding_model` sets this to `0` for
    /// model names it does not recognize.
    ndims: usize,
}
273
274impl embeddings::EmbeddingModel for EmbeddingModel {
275    const MAX_DOCUMENTS: usize = 1024;
276
277    fn ndims(&self) -> usize {
278        self.ndims
279    }
280
281    #[cfg_attr(feature = "worker", worker::send)]
282    async fn embed_texts(
283        &self,
284        documents: impl IntoIterator<Item = String>,
285    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
286        let documents = documents.into_iter().collect::<Vec<_>>();
287
288        let response = self
289            .client
290            .post("/embeddings")
291            .json(&json!({
292                "model": self.model,
293                "input": documents,
294            }))
295            .send()
296            .await?;
297
298        if response.status().is_success() {
299            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
300                ApiResponse::Ok(response) => {
301                    tracing::info!(target: "rig",
302                        "OpenAI embedding token usage: {}",
303                        response.usage
304                    );
305
306                    if response.data.len() != documents.len() {
307                        return Err(EmbeddingError::ResponseError(
308                            "Response data length does not match input length".into(),
309                        ));
310                    }
311
312                    Ok(response
313                        .data
314                        .into_iter()
315                        .zip(documents.into_iter())
316                        .map(|(embedding, document)| embeddings::Embedding {
317                            document,
318                            vec: embedding.embedding,
319                        })
320                        .collect())
321                }
322                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
323            }
324        } else {
325            Err(EmbeddingError::ProviderError(response.text().await?))
326        }
327    }
328}
329
330impl EmbeddingModel {
331    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
332        Self {
333            client,
334            model: model.to_string(),
335            ndims,
336        }
337    }
338}
339
340// ================================================================
341// OpenAI Completion API
342// ================================================================
343
// Model name constants for `Client::completion_model` / `Client::agent`.
// NOTE(review): `gpt-3.5-turbo-instruct` is documented by OpenAI as a legacy
// (non-chat) completions model; confirm it works with this chat-based client.
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
400
/// Raw response body of the chat completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation time as reported by the API (presumably a Unix timestamp in seconds —
    /// confirm).
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Generated choices; only the first is used when converting to a rig response.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
411
412impl From<ApiErrorResponse> for CompletionError {
413    fn from(err: ApiErrorResponse) -> Self {
414        CompletionError::ProviderError(err.message)
415    }
416}
417
418impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
419    type Error = CompletionError;
420
421    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
422        let choice = response.choices.first().ok_or_else(|| {
423            CompletionError::ResponseError("Response contained no choices".to_owned())
424        })?;
425
426        let content = match &choice.message {
427            Message::Assistant {
428                content,
429                tool_calls,
430                ..
431            } => {
432                let mut content = content
433                    .iter()
434                    .filter_map(|c| {
435                        let s = match c {
436                            AssistantContent::Text { text } => text,
437                            AssistantContent::Refusal { refusal } => refusal,
438                        };
439                        if s.is_empty() {
440                            None
441                        } else {
442                            Some(completion::AssistantContent::text(s))
443                        }
444                    })
445                    .collect::<Vec<_>>();
446
447                content.extend(
448                    tool_calls
449                        .iter()
450                        .map(|call| {
451                            completion::AssistantContent::tool_call(
452                                &call.id,
453                                &call.function.name,
454                                call.function.arguments.clone(),
455                            )
456                        })
457                        .collect::<Vec<_>>(),
458                );
459                Ok(content)
460            }
461            _ => Err(CompletionError::ResponseError(
462                "Response did not contain a valid message or tool call".into(),
463            )),
464        }?;
465
466        let choice = OneOrMany::many(content).map_err(|_| {
467            CompletionError::ResponseError(
468                "Response contained no message or tool call (empty)".to_owned(),
469            )
470        })?;
471
472        Ok(completion::CompletionResponse {
473            choice,
474            raw_response: response,
475        })
476    }
477}
478
/// One generated alternative inside a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Log-probability payload, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    // NOTE(review): non-optional, so a missing or `null` `finish_reason` fails
    // deserialization of the whole response.
    pub finish_reason: String,
}
486
/// A chat message in OpenAI's wire format, discriminated by the `role` field.
///
/// Variants map to the lowercase roles `system`, `user` and `assistant`; `ToolResult` is
/// (de)serialized with role `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions (e.g. the agent preamble).
    System {
        // Custom deserializer; presumably also accepts a bare string — confirm in
        // `one_or_many::string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user input: text, images or audio.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model output. `content` and `tool_calls` default to empty when absent.
    Assistant {
        // NOTE(review): no `skip_serializing_if`, so an empty `content` is sent as `[]`;
        // confirm the API accepts that alongside `tool_calls`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool call, linked to the call by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
524
525impl Message {
526    pub fn system(content: &str) -> Self {
527        Message::System {
528            content: OneOrMany::one(content.to_owned().into()),
529            name: None,
530        }
531    }
532}
533
/// Reference to audio attached to an assistant message; only its `id` is kept.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    id: String,
}
538
/// A text part of a system message. `type` defaults to `text` when absent.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

/// Content-part type tag for [`SystemContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
552
/// A content part of an assistant message, tagged by `type` (`text` or `refusal`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
559
560#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
561#[serde(tag = "type", rename_all = "lowercase")]
562pub enum UserContent {
563    Text { text: String },
564    Image { image_url: ImageUrl },
565    Audio { input_audio: InputAudio },
566}
567
/// Image reference for a user image part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location. Filled directly from `message::Image::data` when converting from
    /// rig messages.
    pub url: String,
    /// Requested detail level; `ImageDetail::default()` when absent.
    #[serde(default)]
    pub detail: ImageDetail,
}
574
/// Inline audio payload for a user audio part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Audio data as provided by the caller (presumably base64 — confirm).
    pub data: String,
    /// Audio encoding of `data`.
    pub format: AudioMediaType,
}
580
/// A text part of a `tool` message. `type` defaults to `text` when absent.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}

/// Content-part type tag for [`ToolResultContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
594
595impl FromStr for ToolResultContent {
596    type Err = Infallible;
597
598    fn from_str(s: &str) -> Result<Self, Self::Err> {
599        Ok(s.to_owned().into())
600    }
601}
602
603impl From<String> for ToolResultContent {
604    fn from(s: String) -> Self {
605        ToolResultContent {
606            r#type: ToolResultContentType::default(),
607            text: s,
608        }
609    }
610}
611
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call identifier, echoed back as `tool_call_id` in the matching `tool` message.
    pub id: String,
    /// Tool kind; defaults to `function` when absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Kind of tool call; only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
626
/// A tool advertised to the model: a rig tool definition wrapped with a `type` tag
/// (set to `"function"` by the `From<completion::ToolDefinition>` impl).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
632
633impl From<completion::ToolDefinition> for ToolDefinition {
634    fn from(tool: completion::ToolDefinition) -> Self {
635        Self {
636            r#type: "function".into(),
637            function: tool,
638        }
639    }
640}
641
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Call arguments as JSON. (De)serialized through `json_utils::stringified_json`,
    /// presumably because the API transmits them as a JSON-encoded string — confirm.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
648
649impl TryFrom<message::Message> for Vec<Message> {
650    type Error = message::MessageError;
651
652    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
653        match message {
654            message::Message::User { content } => {
655                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
656                    .into_iter()
657                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
658
659                // If there are messages with both tool results and user content, openai will only
660                //  handle tool results. It's unlikely that there will be both.
661                if !tool_results.is_empty() {
662                    tool_results
663                        .into_iter()
664                        .map(|content| match content {
665                            message::UserContent::ToolResult(message::ToolResult {
666                                id,
667                                content,
668                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
669                                tool_call_id: id,
670                                content: content.try_map(|content| match content {
671                                    message::ToolResultContent::Text(message::Text { text }) => {
672                                        Ok(text.into())
673                                    }
674                                    _ => Err(message::MessageError::ConversionError(
675                                        "Tool result content does not support non-text".into(),
676                                    )),
677                                })?,
678                            }),
679                            _ => unreachable!(),
680                        })
681                        .collect::<Result<Vec<_>, _>>()
682                } else {
683                    let other_content = OneOrMany::many(other_content).expect(
684                        "There must be other content here if there were no tool result content",
685                    );
686
687                    Ok(vec![Message::User {
688                        content: other_content.map(|content| match content {
689                            message::UserContent::Text(message::Text { text }) => {
690                                UserContent::Text { text }
691                            }
692                            message::UserContent::Image(message::Image {
693                                data, detail, ..
694                            }) => UserContent::Image {
695                                image_url: ImageUrl {
696                                    url: data,
697                                    detail: detail.unwrap_or_default(),
698                                },
699                            },
700                            message::UserContent::Document(message::Document { data, .. }) => {
701                                UserContent::Text { text: data }
702                            }
703                            message::UserContent::Audio(message::Audio {
704                                data,
705                                media_type,
706                                ..
707                            }) => UserContent::Audio {
708                                input_audio: InputAudio {
709                                    data,
710                                    format: match media_type {
711                                        Some(media_type) => media_type,
712                                        None => AudioMediaType::MP3,
713                                    },
714                                },
715                            },
716                            _ => unreachable!(),
717                        }),
718                        name: None,
719                    }])
720                }
721            }
722            message::Message::Assistant { content } => {
723                let (text_content, tool_calls) = content.into_iter().fold(
724                    (Vec::new(), Vec::new()),
725                    |(mut texts, mut tools), content| {
726                        match content {
727                            message::AssistantContent::Text(text) => texts.push(text),
728                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
729                        }
730                        (texts, tools)
731                    },
732                );
733
734                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
735                //  so either `content` or `tool_calls` will have some content.
736                Ok(vec![Message::Assistant {
737                    content: text_content
738                        .into_iter()
739                        .map(|content| content.text.into())
740                        .collect::<Vec<_>>(),
741                    refusal: None,
742                    audio: None,
743                    name: None,
744                    tool_calls: tool_calls
745                        .into_iter()
746                        .map(|tool_call| tool_call.into())
747                        .collect::<Vec<_>>(),
748                }])
749            }
750        }
751    }
752}
753
754impl From<message::ToolCall> for ToolCall {
755    fn from(tool_call: message::ToolCall) -> Self {
756        Self {
757            id: tool_call.id,
758            r#type: ToolType::default(),
759            function: Function {
760                name: tool_call.function.name,
761                arguments: tool_call.function.arguments,
762            },
763        }
764    }
765}
766
767impl From<ToolCall> for message::ToolCall {
768    fn from(tool_call: ToolCall) -> Self {
769        Self {
770            id: tool_call.id,
771            function: message::ToolFunction {
772                name: tool_call.function.name,
773                arguments: tool_call.function.arguments,
774            },
775        }
776    }
777}
778
779impl TryFrom<Message> for message::Message {
780    type Error = message::MessageError;
781
782    fn try_from(message: Message) -> Result<Self, Self::Error> {
783        Ok(match message {
784            Message::User { content, .. } => message::Message::User {
785                content: content.map(|content| content.into()),
786            },
787            Message::Assistant {
788                content,
789                tool_calls,
790                ..
791            } => {
792                let mut content = content
793                    .into_iter()
794                    .map(|content| match content {
795                        AssistantContent::Text { text } => message::AssistantContent::text(text),
796
797                        // TODO: Currently, refusals are converted into text, but should be
798                        //  investigated for generalization.
799                        AssistantContent::Refusal { refusal } => {
800                            message::AssistantContent::text(refusal)
801                        }
802                    })
803                    .collect::<Vec<_>>();
804
805                content.extend(
806                    tool_calls
807                        .into_iter()
808                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
809                        .collect::<Result<Vec<_>, _>>()?,
810                );
811
812                message::Message::Assistant {
813                    content: OneOrMany::many(content).map_err(|_| {
814                        message::MessageError::ConversionError(
815                            "Neither `content` nor `tool_calls` was provided to the Message"
816                                .to_owned(),
817                        )
818                    })?,
819                }
820            }
821
822            Message::ToolResult {
823                tool_call_id,
824                content,
825            } => message::Message::User {
826                content: OneOrMany::one(message::UserContent::tool_result(
827                    tool_call_id,
828                    content.map(|content| message::ToolResultContent::text(content.text)),
829                )),
830            },
831
832            // System messages should get stripped out when converting message's, this is just a
833            // stop gap to avoid obnoxious error handling or panic occuring.
834            Message::System { content, .. } => message::Message::User {
835                content: content.map(|content| message::UserContent::text(content.text)),
836            },
837        })
838    }
839}
840
841impl From<UserContent> for message::UserContent {
842    fn from(content: UserContent) -> Self {
843        match content {
844            UserContent::Text { text } => message::UserContent::text(text),
845            UserContent::Image { image_url } => message::UserContent::image(
846                image_url.url,
847                Some(message::ContentFormat::default()),
848                None,
849                Some(image_url.detail),
850            ),
851            UserContent::Audio { input_audio } => message::UserContent::audio(
852                input_audio.data,
853                Some(message::ContentFormat::default()),
854                Some(input_audio.format),
855            ),
856        }
857    }
858}
859
860impl From<String> for UserContent {
861    fn from(s: String) -> Self {
862        UserContent::Text { text: s }
863    }
864}
865
866impl FromStr for UserContent {
867    type Err = Infallible;
868
869    fn from_str(s: &str) -> Result<Self, Self::Err> {
870        Ok(UserContent::Text {
871            text: s.to_string(),
872        })
873    }
874}
875
876impl From<String> for AssistantContent {
877    fn from(s: String) -> Self {
878        AssistantContent::Text { text: s }
879    }
880}
881
882impl FromStr for AssistantContent {
883    type Err = Infallible;
884
885    fn from_str(s: &str) -> Result<Self, Self::Err> {
886        Ok(AssistantContent::Text {
887            text: s.to_string(),
888        })
889    }
890}
891impl From<String> for SystemContent {
892    fn from(s: String) -> Self {
893        SystemContent {
894            r#type: SystemContentType::default(),
895            text: s,
896        }
897    }
898}
899
900impl FromStr for SystemContent {
901    type Err = Infallible;
902
903    fn from_str(s: &str) -> Result<Self, Self::Err> {
904        Ok(SystemContent {
905            r#type: SystemContentType::default(),
906            text: s.to_string(),
907        })
908    }
909}
910
911#[derive(Clone)]
912pub struct CompletionModel {
913    client: Client,
914    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
915    pub model: String,
916}
917
918impl CompletionModel {
919    pub fn new(client: Client, model: &str) -> Self {
920        Self {
921            client,
922            model: model.to_string(),
923        }
924    }
925
926    fn create_completion_request(
927        &self,
928        completion_request: CompletionRequest,
929    ) -> Result<Value, CompletionError> {
930        // Add preamble to chat history (if available)
931        let mut full_history: Vec<Message> = match &completion_request.preamble {
932            Some(preamble) => vec![Message::system(preamble)],
933            None => vec![],
934        };
935
936        // Convert prompt to user message
937        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
938
939        // Convert existing chat history
940        let chat_history: Vec<Message> = completion_request
941            .chat_history
942            .into_iter()
943            .map(|message| message.try_into())
944            .collect::<Result<Vec<Vec<Message>>, _>>()?
945            .into_iter()
946            .flatten()
947            .collect();
948
949        // Combine all messages into a single history
950        full_history.extend(chat_history);
951        full_history.extend(prompt);
952
953        let request = if completion_request.tools.is_empty() {
954            json!({
955                "model": self.model,
956                "messages": full_history,
957
958            })
959        } else {
960            json!({
961                "model": self.model,
962                "messages": full_history,
963                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
964                "tool_choice": "auto",
965            })
966        };
967
968        // only include temperature if it exists
969        // because some models don't support temperature
970        let request = if let Some(temperature) = completion_request.temperature {
971            json_utils::merge(
972                request,
973                json!({
974                    "temperature": temperature,
975                }),
976            )
977        } else {
978            request
979        };
980
981        let request = if let Some(params) = completion_request.additional_params {
982            json_utils::merge(request, params)
983        } else {
984            request
985        };
986
987        Ok(request)
988    }
989}
990
991impl completion::CompletionModel for CompletionModel {
992    type Response = CompletionResponse;
993
994    #[cfg_attr(feature = "worker", worker::send)]
995    async fn completion(
996        &self,
997        completion_request: CompletionRequest,
998    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
999        let request = self.create_completion_request(completion_request)?;
1000
1001        let response = self
1002            .client
1003            .post("/chat/completions")
1004            .json(&request)
1005            .send()
1006            .await?;
1007
1008        if response.status().is_success() {
1009            let t = response.text().await?;
1010            tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
1011
1012            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
1013                ApiResponse::Ok(response) => {
1014                    tracing::info!(target: "rig",
1015                        "OpenAI completion token usage: {:?}",
1016                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
1017                    );
1018                    response.try_into()
1019                }
1020                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1021            }
1022        } else {
1023            Err(CompletionError::ProviderError(response.text().await?))
1024        }
1025    }
1026}
1027
1028// ================================================================
1029// OpenAI Transcription API
1030// ================================================================
1031pub const WHISPER_1: &str = "whisper-1";
1032
1033#[derive(Debug, Deserialize)]
1034pub struct TranscriptionResponse {
1035    pub text: String,
1036}
1037
1038impl TryFrom<TranscriptionResponse>
1039    for transcription::TranscriptionResponse<TranscriptionResponse>
1040{
1041    type Error = TranscriptionError;
1042
1043    fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
1044        Ok(transcription::TranscriptionResponse {
1045            text: value.text.clone(),
1046            response: value,
1047        })
1048    }
1049}
1050
1051#[derive(Clone)]
1052pub struct TranscriptionModel {
1053    client: Client,
1054    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
1055    pub model: String,
1056}
1057
1058impl TranscriptionModel {
1059    pub fn new(client: Client, model: &str) -> Self {
1060        Self {
1061            client,
1062            model: model.to_string(),
1063        }
1064    }
1065}
1066
1067impl transcription::TranscriptionModel for TranscriptionModel {
1068    type Response = TranscriptionResponse;
1069
1070    #[cfg_attr(feature = "worker", worker::send)]
1071    async fn transcription(
1072        &self,
1073        request: transcription::TranscriptionRequest,
1074    ) -> Result<
1075        transcription::TranscriptionResponse<Self::Response>,
1076        transcription::TranscriptionError,
1077    > {
1078        let data = request.data;
1079
1080        let mut body = reqwest::multipart::Form::new()
1081            .text("model", self.model.clone())
1082            .text("language", request.language)
1083            .part(
1084                "file",
1085                Part::bytes(data).file_name(request.filename.clone()),
1086            );
1087
1088        if let Some(prompt) = request.prompt {
1089            body = body.text("prompt", prompt.clone());
1090        }
1091
1092        if let Some(ref temperature) = request.temperature {
1093            body = body.text("temperature", temperature.to_string());
1094        }
1095
1096        if let Some(ref additional_params) = request.additional_params {
1097            for (key, value) in additional_params
1098                .as_object()
1099                .expect("Additional Parameters to OpenAI Transcription should be a map")
1100            {
1101                body = body.text(key.to_owned(), value.to_string());
1102            }
1103        }
1104
1105        let response = self
1106            .client
1107            .post("audio/transcriptions")
1108            .multipart(body)
1109            .send()
1110            .await?;
1111
1112        if response.status().is_success() {
1113            match response
1114                .json::<ApiResponse<TranscriptionResponse>>()
1115                .await?
1116            {
1117                ApiResponse::Ok(response) => response.try_into(),
1118                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
1119                    api_error_response.message,
1120                )),
1121            }
1122        } else {
1123            Err(TranscriptionError::ProviderError(response.text().await?))
1124        }
1125    }
1126}
1127
1128// ================================================================
1129// OpenAI Completion Streaming API
1130// ================================================================
/// Function portion of a streamed tool-call delta.
///
/// Either field may be missing from a given chunk; `#[serde(default)]` makes a
/// missing `name` deserialize as `None` and missing `arguments` as `""`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
    /// Tool name, when this delta carries one.
    #[serde(default)]
    name: Option<String>,
    /// Argument text carried by this delta (possibly only a fragment of the
    /// full JSON-encoded arguments).
    #[serde(default)]
    arguments: String,
}

/// One tool-call entry inside a streamed delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
    /// Position of the tool call; argument fragments are accumulated per index.
    pub index: usize,
    pub function: StreamingFunction,
}

/// Incremental content of one streamed choice.
#[derive(Deserialize)]
struct StreamingDelta {
    /// Text fragment, if this delta carries one.
    #[serde(default)]
    content: Option<String>,
    // NOTE(review): `null_or_vec` presumably maps a JSON `null` to an empty
    // vec — confirm in `json_utils`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

/// A single choice within a streamed chunk; only its `delta` is read.
#[derive(Deserialize)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// JSON payload of one SSE `data:` line in a streaming chat completion.
/// Only `choices` is deserialized; other fields in the payload are ignored.
#[derive(Deserialize)]
struct StreamingCompletionResponse {
    choices: Vec<StreamingChoice>,
}
1162
1163impl StreamingCompletionModel for CompletionModel {
1164    async fn stream(
1165        &self,
1166        completion_request: CompletionRequest,
1167    ) -> Result<StreamingResult, CompletionError> {
1168        let mut request = self.create_completion_request(completion_request)?;
1169        request = merge(request, json!({"stream": true}));
1170
1171        let builder = self.client.post("/chat/completions").json(&request);
1172        send_compatible_streaming_request(builder).await
1173    }
1174}
1175
1176pub async fn send_compatible_streaming_request(
1177    request_builder: RequestBuilder,
1178) -> Result<StreamingResult, CompletionError> {
1179    let response = request_builder.send().await?;
1180
1181    if !response.status().is_success() {
1182        return Err(CompletionError::ProviderError(format!(
1183            "{}: {}",
1184            response.status(),
1185            response.text().await?
1186        )));
1187    }
1188
1189    // Handle OpenAI Compatible SSE chunks
1190    Ok(Box::pin(stream! {
1191        let mut stream = response.bytes_stream();
1192
1193        let mut partial_data = None;
1194        let mut calls: HashMap<usize, (String, String)> = HashMap::new();
1195
1196        while let Some(chunk_result) = stream.next().await {
1197            let chunk = match chunk_result {
1198                Ok(c) => c,
1199                Err(e) => {
1200                    yield Err(CompletionError::from(e));
1201                    break;
1202                }
1203            };
1204
1205            let text = match String::from_utf8(chunk.to_vec()) {
1206                Ok(t) => t,
1207                Err(e) => {
1208                    yield Err(CompletionError::ResponseError(e.to_string()));
1209                    break;
1210                }
1211            };
1212
1213
1214            for line in text.lines() {
1215                let mut line = line.to_string();
1216
1217
1218
1219                // If there was a remaining part, concat with current line
1220                if partial_data.is_some() {
1221                    line = format!("{}{}", partial_data.unwrap(), line);
1222                    partial_data = None;
1223                }
1224                // Otherwise full data line
1225                else {
1226                    let Some(data) = line.strip_prefix("data: ") else {
1227                        continue;
1228                    };
1229
1230                    // Partial data, split somewhere in the middle
1231                    if !line.ends_with("}") {
1232                        partial_data = Some(data.to_string());
1233                    } else {
1234                        line = data.to_string();
1235                    }
1236                }
1237
1238                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
1239
1240                let Ok(data) = data else {
1241                    continue;
1242                };
1243
1244                let choice = data.choices.first().expect("Should have at least one choice");
1245
1246                let delta = &choice.delta;
1247
1248                if !delta.tool_calls.is_empty() {
1249                    for tool_call in &delta.tool_calls {
1250                        let function = tool_call.function.clone();
1251
1252                        // Start of tool call
1253                        // name: Some(String)
1254                        // arguments: None
1255                        if function.name.is_some() && function.arguments.is_empty() {
1256                            calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
1257                        }
1258                        // Part of tool call
1259                        // name: None
1260                        // arguments: Some(String)
1261                        else if function.name.is_none() && !function.arguments.is_empty() {
1262                            let Some((name, arguments)) = calls.get(&tool_call.index) else {
1263                                continue;
1264                            };
1265
1266                            let new_arguments = &tool_call.function.arguments;
1267                            let arguments = format!("{}{}", arguments, new_arguments);
1268
1269                            calls.insert(tool_call.index, (name.clone(), arguments));
1270                        }
1271                        // Entire tool call
1272                        else {
1273                            let name = function.name.unwrap();
1274                            let arguments = function.arguments;
1275                            let Ok(arguments) = serde_json::from_str(&arguments) else {
1276                                continue;
1277                            };
1278
1279                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
1280                        }
1281                    }
1282                }
1283
1284                if let Some(content) = &choice.delta.content {
1285                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
1286                }
1287            }
1288        }
1289
1290        for (_, (name, arguments)) in calls {
1291            let Ok(arguments) = serde_json::from_str(&arguments) else {
1292                continue;
1293            };
1294
1295            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
1296        }
1297    }))
1298}
1299
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into an OpenAI [`Message`].
    ///
    /// On failure it panics with the JSON path and line/column of the first
    /// error, so a failing fixture points at the offending field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let greeting = AssistantContent::Text {
            text: "\n\nHello there, how may I assist you today?".to_string(),
        };

        // A plain string `content` becomes a single text part.
        let Message::Assistant { content, .. } = parse_message(assistant_message_json) else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting);

        // Array `content` with `"tool_calls": null` yields no tool calls.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = parse_message(assistant_message_json2)
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting);
        assert_eq!(tool_calls, vec![]);

        // `null` content with a tool call; the arguments string is compared
        // against a parsed JSON value.
        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = parse_message(assistant_message_json3)
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        // Mixed user content; only the text and image parts are checked here.
        let Message::User { content, .. } = parse_message(user_message_json) else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let (first, second) = (parts.next().unwrap(), parts.next().unwrap());
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Rig -> OpenAI.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // OpenAI -> Rig must round-trip to the original messages.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // OpenAI -> Rig.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content } = converted_assistant_message.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Rig -> OpenAI must round-trip to the original messages.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}