rig/providers/
openai.rs

1//! OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::openai;
6//!
7//! let client = openai::Client::new("YOUR_API_KEY");
8//!
9//! let gpt4o = client.completion_model(openai::GPT_4O);
10//! ```
11use std::{convert::Infallible, str::FromStr};
12
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError, CompletionRequest},
16    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17    extractor::ExtractorBuilder,
18    json_utils,
19    message::{self, AudioMediaType, ImageDetail},
20    one_or_many::string_or_one_or_many,
21    transcription::{self, TranscriptionError},
22    Embed, OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::json;
28
// ================================================================
// Main OpenAI Client
// ================================================================
/// Default base URL for the OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
33
/// OpenAI API client.
///
/// Holds the resolved base URL and a `reqwest` client whose default headers
/// carry the `Authorization: Bearer <api-key>` header (set in `from_url`).
#[derive(Clone)]
pub struct Client {
    // Base API URL, e.g. `https://api.openai.com/v1`.
    base_url: String,
    // HTTP client pre-configured with the bearer-token default header.
    http_client: reqwest::Client,
}
39
40impl Client {
41    /// Create a new OpenAI client with the given API key.
42    pub fn new(api_key: &str) -> Self {
43        Self::from_url(api_key, OPENAI_API_BASE_URL)
44    }
45
46    /// Create a new OpenAI client with the given API key and base API URL.
47    pub fn from_url(api_key: &str, base_url: &str) -> Self {
48        Self {
49            base_url: base_url.to_string(),
50            http_client: reqwest::Client::builder()
51                .default_headers({
52                    let mut headers = reqwest::header::HeaderMap::new();
53                    headers.insert(
54                        "Authorization",
55                        format!("Bearer {}", api_key)
56                            .parse()
57                            .expect("Bearer token should parse"),
58                    );
59                    headers
60                })
61                .build()
62                .expect("OpenAI reqwest client should build"),
63        }
64    }
65
66    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
67    /// Panics if the environment variable is not set.
68    pub fn from_env() -> Self {
69        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
70        Self::new(&api_key)
71    }
72
73    fn post(&self, path: &str) -> reqwest::RequestBuilder {
74        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75        self.http_client.post(url)
76    }
77
78    /// Create an embedding model with the given name.
79    /// Note: default embedding dimension of 0 will be used if model is not known.
80    /// If this is the case, it's better to use function `embedding_model_with_ndims`
81    ///
82    /// # Example
83    /// ```
84    /// use rig::providers::openai::{Client, self};
85    ///
86    /// // Initialize the OpenAI client
87    /// let openai = Client::new("your-open-ai-api-key");
88    ///
89    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
90    /// ```
91    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
92        let ndims = match model {
93            TEXT_EMBEDDING_3_LARGE => 3072,
94            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
95            _ => 0,
96        };
97        EmbeddingModel::new(self.clone(), model, ndims)
98    }
99
100    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
101    ///
102    /// # Example
103    /// ```
104    /// use rig::providers::openai::{Client, self};
105    ///
106    /// // Initialize the OpenAI client
107    /// let openai = Client::new("your-open-ai-api-key");
108    ///
109    /// let embedding_model = openai.embedding_model("model-unknown-to-rig", 3072);
110    /// ```
111    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
112        EmbeddingModel::new(self.clone(), model, ndims)
113    }
114
115    /// Create an embedding builder with the given embedding model.
116    ///
117    /// # Example
118    /// ```
119    /// use rig::providers::openai::{Client, self};
120    ///
121    /// // Initialize the OpenAI client
122    /// let openai = Client::new("your-open-ai-api-key");
123    ///
124    /// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
125    ///     .simple_document("doc0", "Hello, world!")
126    ///     .simple_document("doc1", "Goodbye, world!")
127    ///     .build()
128    ///     .await
129    ///     .expect("Failed to embed documents");
130    /// ```
131    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
132        EmbeddingsBuilder::new(self.embedding_model(model))
133    }
134
135    /// Create a completion model with the given name.
136    ///
137    /// # Example
138    /// ```
139    /// use rig::providers::openai::{Client, self};
140    ///
141    /// // Initialize the OpenAI client
142    /// let openai = Client::new("your-open-ai-api-key");
143    ///
144    /// let gpt4 = openai.completion_model(openai::GPT_4);
145    /// ```
146    pub fn completion_model(&self, model: &str) -> CompletionModel {
147        CompletionModel::new(self.clone(), model)
148    }
149
150    /// Create an agent builder with the given completion model.
151    ///
152    /// # Example
153    /// ```
154    /// use rig::providers::openai::{Client, self};
155    ///
156    /// // Initialize the OpenAI client
157    /// let openai = Client::new("your-open-ai-api-key");
158    ///
159    /// let agent = openai.agent(openai::GPT_4)
160    ///    .preamble("You are comedian AI with a mission to make people laugh.")
161    ///    .temperature(0.0)
162    ///    .build();
163    /// ```
164    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
165        AgentBuilder::new(self.completion_model(model))
166    }
167
168    /// Create an extractor builder with the given completion model.
169    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
170        &self,
171        model: &str,
172    ) -> ExtractorBuilder<T, CompletionModel> {
173        ExtractorBuilder::new(self.completion_model(model))
174    }
175
176    /// Create a completion model with the given name.
177    ///
178    /// # Example
179    /// ```
180    /// use rig::providers::openai::{Client, self};
181    ///
182    /// // Initialize the OpenAI client
183    /// let openai = Client::new("your-open-ai-api-key");
184    ///
185    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
186    /// ```
187    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
188        TranscriptionModel::new(self.clone(), model)
189    }
190}
191
/// Error payload returned by the OpenAI API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error message from the provider.
    message: String,
}
196
/// Either a successful payload of type `T` or an API error.
///
/// `untagged` means serde tries the `Ok` shape first and falls back to the
/// error shape, based purely on which fields are present.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
203
// ================================================================
// OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model (3072 dimensions by default; see `Client::embedding_model`)
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions by default; see `Client::embedding_model`)
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions by default; see `Client::embedding_model`)
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
213
/// Response body of the OpenAI `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    // One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
221
222impl From<ApiErrorResponse> for EmbeddingError {
223    fn from(err: ApiErrorResponse) -> Self {
224        EmbeddingError::ProviderError(err.message)
225    }
226}
227
228impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
229    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
230        match value {
231            ApiResponse::Ok(response) => Ok(response),
232            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
233        }
234    }
235}
236
/// A single embedding row from the `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input in the request batch.
    pub index: usize,
}
243
/// Token-usage accounting returned by OpenAI endpoints.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
249
250impl std::fmt::Display for Usage {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        write!(
253            f,
254            "Prompt tokens: {} Total tokens: {}",
255            self.prompt_tokens, self.total_tokens
256        )
257    }
258}
259
/// Handle to a specific OpenAI embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Model identifier (e.g. `text-embedding-3-large`).
    pub model: String,
    // Embedding dimension; 0 when the model is unknown to Rig.
    ndims: usize,
}
266
impl embeddings::EmbeddingModel for EmbeddingModel {
    // Maximum number of documents sent in a single embeddings request.
    const MAX_DOCUMENTS: usize = 1024;

    /// Embedding dimension for this model (0 if unknown to Rig).
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of texts via `POST /embeddings`.
    ///
    /// # Errors
    /// Returns `ProviderError` on HTTP or API-level failure, and
    /// `ResponseError` if the number of returned embeddings does not match
    /// the number of inputs.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize the batch: we need its length and to zip it back later.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post("/embeddings")
            .json(&json!({
                "model": self.model,
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {}",
                        response.usage
                    );

                    // Sanity check: one embedding row per input document.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error.
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
322
323impl EmbeddingModel {
324    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
325        Self {
326            client,
327            model: model.to_string(),
328            ndims,
329        }
330    }
331}
332
// ================================================================
// OpenAI Completion API
// ================================================================
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
388
/// Raw response body of the OpenAI chat-completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    // Candidate completions; conversion below uses only the first.
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
399
400impl From<ApiErrorResponse> for CompletionError {
401    fn from(err: ApiErrorResponse) -> Self {
402        CompletionError::ProviderError(err.message)
403    }
404}
405
406impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
407    type Error = CompletionError;
408
409    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
410        let choice = response.choices.first().ok_or_else(|| {
411            CompletionError::ResponseError("Response contained no choices".to_owned())
412        })?;
413
414        let content = match &choice.message {
415            Message::Assistant {
416                content,
417                tool_calls,
418                ..
419            } => {
420                let mut content = content
421                    .iter()
422                    .map(|c| match c {
423                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
424                        AssistantContent::Refusal { refusal } => {
425                            completion::AssistantContent::text(refusal)
426                        }
427                    })
428                    .collect::<Vec<_>>();
429
430                content.extend(
431                    tool_calls
432                        .iter()
433                        .map(|call| {
434                            completion::AssistantContent::tool_call(
435                                &call.id,
436                                &call.function.name,
437                                call.function.arguments.clone(),
438                            )
439                        })
440                        .collect::<Vec<_>>(),
441                );
442                Ok(content)
443            }
444            _ => Err(CompletionError::ResponseError(
445                "Response did not contain a valid message or tool call".into(),
446            )),
447        }?;
448
449        let choice = OneOrMany::many(content).map_err(|_| {
450            CompletionError::ResponseError(
451                "Response contained no message or tool call (empty)".to_owned(),
452            )
453        })?;
454
455        Ok(completion::CompletionResponse {
456            choice,
457            raw_response: response,
458        })
459    }
460}
461
/// A single candidate completion within a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; Rig does not interpret logprobs.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
469
/// An OpenAI chat-completions message, tagged on the wire by its `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"` message.
    System {
        // Deserializes from either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"` message.
    User {
        // Deserializes from either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"` message; may carry text content, tool calls, or both.
    Assistant {
        // Deserializes from a bare string or an array; defaults to empty.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // JSON `null` deserializes to an empty vec; empty vecs are omitted
        // when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "tool"` message reporting the result of a prior tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
507
508impl Message {
509    pub fn system(content: &str) -> Self {
510        Message::System {
511            content: OneOrMany::one(content.to_owned().into()),
512            name: None,
513        }
514    }
515}
516
/// Reference to an audio response previously generated by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    // Provider-assigned identifier of the audio object.
    id: String,
}
521
/// A single content part of a system message (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
528
/// Discriminator for [`SystemContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
535
/// A content part of an assistant message: generated text or a refusal.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
542
/// A content part of a user message, tagged on the wire by `type`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    Audio { input_audio: InputAudio },
}
550
/// Image reference inside a user message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    // URL or data URI of the image (conversion below passes raw data here).
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
557
/// Audio input inside a user message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    // Audio payload (presumably base64-encoded — TODO confirm against API docs).
    pub data: String,
    pub format: AudioMediaType,
}
563
/// A content part of a tool-result message (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
570
/// Discriminator for [`ToolResultContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
577
578impl FromStr for ToolResultContent {
579    type Err = Infallible;
580
581    fn from_str(s: &str) -> Result<Self, Self::Err> {
582        Ok(s.to_owned().into())
583    }
584}
585
586impl From<String> for ToolResultContent {
587    fn from(s: String) -> Self {
588        ToolResultContent {
589            r#type: ToolResultContentType::default(),
590            text: s,
591        }
592    }
593}
594
/// A tool call emitted by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    // The function name and its (JSON) arguments.
    pub function: Function,
}
602
/// Discriminator for [`ToolCall`]; only `function` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
609
/// OpenAI-shaped tool definition: a generic Rig tool definition wrapped with
/// its `type` field (always `"function"`, see the `From` impl below).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
615
616impl From<completion::ToolDefinition> for ToolDefinition {
617    fn from(tool: completion::ToolDefinition) -> Self {
618        Self {
619            r#type: "function".into(),
620            function: tool,
621        }
622    }
623}
624
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire, `arguments` is a JSON-encoded string; this attribute
    // (de)serializes it to/from a structured `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
631
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    /// Convert a Rig message into one or more OpenAI wire messages.
    ///
    /// A user message containing tool results fans out into one `tool`-role
    /// message per result; any other user message maps to a single `user`
    /// message, and an assistant message maps to a single `assistant` message.
    ///
    /// # Errors
    /// Returns `ConversionError` if a tool result carries non-text content.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                // Split tool results from regular user content.
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // If there are messages with both tool results and user content, openai will only
                //  handle tool results. It's unlikely that there will be both.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                // Only text tool-result content is representable.
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            // The partition above guarantees only ToolResult here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // `content` is a OneOrMany, so with no tool results the
                    // remainder is guaranteed non-empty.
                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there were no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content.map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                UserContent::Text { text }
                            }
                            message::UserContent::Image(message::Image {
                                data, detail, ..
                            }) => UserContent::Image {
                                image_url: ImageUrl {
                                    url: data,
                                    detail: detail.unwrap_or_default(),
                                },
                            },
                            // Documents are flattened to plain text.
                            message::UserContent::Document(message::Document { data, .. }) => {
                                UserContent::Text { text: data }
                            }
                            message::UserContent::Audio(message::Audio {
                                data,
                                media_type,
                                ..
                            }) => UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    // Fall back to MP3 when no media type is given.
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            },
                            // ToolResult was partitioned out above.
                            _ => unreachable!(),
                        }),
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content } => {
                // Separate text parts from tool calls.
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                        }
                        (texts, tools)
                    },
                );

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                //  so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
736
737impl From<message::ToolCall> for ToolCall {
738    fn from(tool_call: message::ToolCall) -> Self {
739        Self {
740            id: tool_call.id,
741            r#type: ToolType::default(),
742            function: Function {
743                name: tool_call.function.name,
744                arguments: tool_call.function.arguments,
745            },
746        }
747    }
748}
749
750impl From<ToolCall> for message::ToolCall {
751    fn from(tool_call: ToolCall) -> Self {
752        Self {
753            id: tool_call.id,
754            function: message::ToolFunction {
755                name: tool_call.function.name,
756                arguments: tool_call.function.arguments,
757            },
758        }
759    }
760}
761
762impl TryFrom<Message> for message::Message {
763    type Error = message::MessageError;
764
765    fn try_from(message: Message) -> Result<Self, Self::Error> {
766        Ok(match message {
767            Message::User { content, .. } => message::Message::User {
768                content: content.map(|content| content.into()),
769            },
770            Message::Assistant {
771                content,
772                tool_calls,
773                ..
774            } => {
775                let mut content = content
776                    .into_iter()
777                    .map(|content| match content {
778                        AssistantContent::Text { text } => message::AssistantContent::text(text),
779
780                        // TODO: Currently, refusals are converted into text, but should be
781                        //  investigated for generalization.
782                        AssistantContent::Refusal { refusal } => {
783                            message::AssistantContent::text(refusal)
784                        }
785                    })
786                    .collect::<Vec<_>>();
787
788                content.extend(
789                    tool_calls
790                        .into_iter()
791                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
792                        .collect::<Result<Vec<_>, _>>()?,
793                );
794
795                message::Message::Assistant {
796                    content: OneOrMany::many(content).map_err(|_| {
797                        message::MessageError::ConversionError(
798                            "Neither `content` nor `tool_calls` was provided to the Message"
799                                .to_owned(),
800                        )
801                    })?,
802                }
803            }
804
805            Message::ToolResult {
806                tool_call_id,
807                content,
808            } => message::Message::User {
809                content: OneOrMany::one(message::UserContent::tool_result(
810                    tool_call_id,
811                    content.map(|content| message::ToolResultContent::text(content.text)),
812                )),
813            },
814
815            // System messages should get stripped out when converting message's, this is just a
816            // stop gap to avoid obnoxious error handling or panic occuring.
817            Message::System { content, .. } => message::Message::User {
818                content: content.map(|content| message::UserContent::text(content.text)),
819            },
820        })
821    }
822}
823
824impl From<UserContent> for message::UserContent {
825    fn from(content: UserContent) -> Self {
826        match content {
827            UserContent::Text { text } => message::UserContent::text(text),
828            UserContent::Image { image_url } => message::UserContent::image(
829                image_url.url,
830                Some(message::ContentFormat::default()),
831                None,
832                Some(image_url.detail),
833            ),
834            UserContent::Audio { input_audio } => message::UserContent::audio(
835                input_audio.data,
836                Some(message::ContentFormat::default()),
837                Some(input_audio.format),
838            ),
839        }
840    }
841}
842
843impl From<String> for UserContent {
844    fn from(s: String) -> Self {
845        UserContent::Text { text: s }
846    }
847}
848
849impl FromStr for UserContent {
850    type Err = Infallible;
851
852    fn from_str(s: &str) -> Result<Self, Self::Err> {
853        Ok(UserContent::Text {
854            text: s.to_string(),
855        })
856    }
857}
858
859impl From<String> for AssistantContent {
860    fn from(s: String) -> Self {
861        AssistantContent::Text { text: s }
862    }
863}
864
865impl FromStr for AssistantContent {
866    type Err = Infallible;
867
868    fn from_str(s: &str) -> Result<Self, Self::Err> {
869        Ok(AssistantContent::Text {
870            text: s.to_string(),
871        })
872    }
873}
874impl From<String> for SystemContent {
875    fn from(s: String) -> Self {
876        SystemContent {
877            r#type: SystemContentType::default(),
878            text: s,
879        }
880    }
881}
882
883impl FromStr for SystemContent {
884    type Err = Infallible;
885
886    fn from_str(s: &str) -> Result<Self, Self::Err> {
887        Ok(SystemContent {
888            r#type: SystemContentType::default(),
889            text: s.to_string(),
890        })
891    }
892}
893
/// A handle to an OpenAI chat-completion model, bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to issue requests against the OpenAI API.
    client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
900
901impl CompletionModel {
902    pub fn new(client: Client, model: &str) -> Self {
903        Self {
904            client,
905            model: model.to_string(),
906        }
907    }
908}
909
910impl completion::CompletionModel for CompletionModel {
911    type Response = CompletionResponse;
912
913    #[cfg_attr(feature = "worker", worker::send)]
914    async fn completion(
915        &self,
916        completion_request: CompletionRequest,
917    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
918        // Add preamble to chat history (if available)
919        let mut full_history: Vec<Message> = match &completion_request.preamble {
920            Some(preamble) => vec![Message::system(preamble)],
921            None => vec![],
922        };
923
924        // Convert prompt to user message
925        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
926
927        // Convert existing chat history
928        let chat_history: Vec<Message> = completion_request
929            .chat_history
930            .into_iter()
931            .map(|message| message.try_into())
932            .collect::<Result<Vec<Vec<Message>>, _>>()?
933            .into_iter()
934            .flatten()
935            .collect();
936
937        // Combine all messages into a single history
938        full_history.extend(chat_history);
939        full_history.extend(prompt);
940
941        let request = if completion_request.tools.is_empty() {
942            json!({
943                "model": self.model,
944                "messages": full_history,
945
946            })
947        } else {
948            json!({
949                "model": self.model,
950                "messages": full_history,
951                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
952                "tool_choice": "auto",
953            })
954        };
955
956        // only include temperature if it exists
957        // because some models don't support temperature
958        let request = if let Some(temperature) = completion_request.temperature {
959            json_utils::merge(
960                request,
961                json!({
962                    "temperature": temperature,
963                }),
964            )
965        } else {
966            request
967        };
968
969        let response = self
970            .client
971            .post("/chat/completions")
972            .json(
973                &if let Some(params) = completion_request.additional_params {
974                    json_utils::merge(request, params)
975                } else {
976                    request
977                },
978            )
979            .send()
980            .await?;
981
982        if response.status().is_success() {
983            let t = response.text().await?;
984            tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
985
986            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
987                ApiResponse::Ok(response) => {
988                    tracing::info!(target: "rig",
989                        "OpenAI completion token usage: {:?}",
990                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
991                    );
992                    response.try_into()
993                }
994                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
995            }
996        } else {
997            Err(CompletionError::ProviderError(response.text().await?))
998        }
999    }
1000}
1001
1002// ================================================================
1003// OpenAI Transcription API
1004// ================================================================
1005pub const WHISPER_1: &str = "whisper-1";
1006
/// Raw response body returned by OpenAI's `audio/transcriptions` endpoint.
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
    /// The transcribed text.
    pub text: String,
}
1011
1012impl TryFrom<TranscriptionResponse>
1013    for transcription::TranscriptionResponse<TranscriptionResponse>
1014{
1015    type Error = TranscriptionError;
1016
1017    fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
1018        Ok(transcription::TranscriptionResponse {
1019            text: value.text.clone(),
1020            response: value,
1021        })
1022    }
1023}
1024
/// A handle to an OpenAI transcription model, bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    // Client used to issue requests against the OpenAI API.
    client: Client,
    /// Name of the model (e.g.: whisper-1)
    pub model: String,
}
1031
1032impl TranscriptionModel {
1033    pub fn new(client: Client, model: &str) -> Self {
1034        Self {
1035            client,
1036            model: model.to_string(),
1037        }
1038    }
1039}
1040impl transcription::TranscriptionModel for TranscriptionModel {
1041    type Response = TranscriptionResponse;
1042
1043    #[cfg_attr(feature = "worker", worker::send)]
1044    async fn transcription(
1045        &self,
1046        request: transcription::TranscriptionRequest,
1047    ) -> Result<
1048        transcription::TranscriptionResponse<Self::Response>,
1049        transcription::TranscriptionError,
1050    > {
1051        let data = request.data;
1052
1053        let mut body = reqwest::multipart::Form::new()
1054            .text("model", self.model.clone())
1055            .text("language", request.language)
1056            .part(
1057                "file",
1058                Part::bytes(data).file_name(request.filename.clone()),
1059            );
1060
1061        if let Some(prompt) = request.prompt {
1062            body = body.text("prompt", prompt.clone());
1063        }
1064
1065        if let Some(ref temperature) = request.temperature {
1066            body = body.text("temperature", temperature.to_string());
1067        }
1068
1069        if let Some(ref additional_params) = request.additional_params {
1070            for (key, value) in additional_params
1071                .as_object()
1072                .expect("Additional Parameters to OpenAI Transcription should be a map")
1073            {
1074                body = body.text(key.to_owned(), value.to_string());
1075            }
1076        }
1077
1078        let response = self
1079            .client
1080            .post("audio/transcriptions")
1081            .multipart(body)
1082            .send()
1083            .await?;
1084
1085        if response.status().is_success() {
1086            match response
1087                .json::<ApiResponse<TranscriptionResponse>>()
1088                .await?
1089            {
1090                ApiResponse::Ok(response) => response.try_into(),
1091                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
1092                    api_error_response.message,
1093                )),
1094            }
1095        } else {
1096            Err(TranscriptionError::ProviderError(response.text().await?))
1097        }
1098    }
1099}
1100
1101#[cfg(test)]
1102mod tests {
1103    use super::*;
1104    use serde_path_to_error::deserialize;
1105
1106    #[test]
1107    fn test_deserialize_message() {
1108        let assistant_message_json = r#"
1109        {
1110            "role": "assistant",
1111            "content": "\n\nHello there, how may I assist you today?"
1112        }
1113        "#;
1114
1115        let assistant_message_json2 = r#"
1116        {
1117            "role": "assistant",
1118            "content": [
1119                {
1120                    "type": "text",
1121                    "text": "\n\nHello there, how may I assist you today?"
1122                }
1123            ],
1124            "tool_calls": null
1125        }
1126        "#;
1127
1128        let assistant_message_json3 = r#"
1129        {
1130            "role": "assistant",
1131            "tool_calls": [
1132                {
1133                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
1134                    "type": "function",
1135                    "function": {
1136                        "name": "subtract",
1137                        "arguments": "{\"x\": 2, \"y\": 5}"
1138                    }
1139                }
1140            ],
1141            "content": null,
1142            "refusal": null
1143        }
1144        "#;
1145
1146        let user_message_json = r#"
1147        {
1148            "role": "user",
1149            "content": [
1150                {
1151                    "type": "text",
1152                    "text": "What's in this image?"
1153                },
1154                {
1155                    "type": "image",
1156                    "image_url": {
1157                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
1158                    }
1159                },
1160                {
1161                    "type": "audio",
1162                    "input_audio": {
1163                        "data": "...",
1164                        "format": "mp3"
1165                    }
1166                }
1167            ]
1168        }
1169        "#;
1170
1171        let assistant_message: Message = {
1172            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
1173            deserialize(jd).unwrap_or_else(|err| {
1174                panic!(
1175                    "Deserialization error at {} ({}:{}): {}",
1176                    err.path(),
1177                    err.inner().line(),
1178                    err.inner().column(),
1179                    err
1180                );
1181            })
1182        };
1183
1184        let assistant_message2: Message = {
1185            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
1186            deserialize(jd).unwrap_or_else(|err| {
1187                panic!(
1188                    "Deserialization error at {} ({}:{}): {}",
1189                    err.path(),
1190                    err.inner().line(),
1191                    err.inner().column(),
1192                    err
1193                );
1194            })
1195        };
1196
1197        let assistant_message3: Message = {
1198            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
1199                &mut serde_json::Deserializer::from_str(assistant_message_json3);
1200            deserialize(jd).unwrap_or_else(|err| {
1201                panic!(
1202                    "Deserialization error at {} ({}:{}): {}",
1203                    err.path(),
1204                    err.inner().line(),
1205                    err.inner().column(),
1206                    err
1207                );
1208            })
1209        };
1210
1211        let user_message: Message = {
1212            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
1213            deserialize(jd).unwrap_or_else(|err| {
1214                panic!(
1215                    "Deserialization error at {} ({}:{}): {}",
1216                    err.path(),
1217                    err.inner().line(),
1218                    err.inner().column(),
1219                    err
1220                );
1221            })
1222        };
1223
1224        match assistant_message {
1225            Message::Assistant { content, .. } => {
1226                assert_eq!(
1227                    content[0],
1228                    AssistantContent::Text {
1229                        text: "\n\nHello there, how may I assist you today?".to_string()
1230                    }
1231                );
1232            }
1233            _ => panic!("Expected assistant message"),
1234        }
1235
1236        match assistant_message2 {
1237            Message::Assistant {
1238                content,
1239                tool_calls,
1240                ..
1241            } => {
1242                assert_eq!(
1243                    content[0],
1244                    AssistantContent::Text {
1245                        text: "\n\nHello there, how may I assist you today?".to_string()
1246                    }
1247                );
1248
1249                assert_eq!(tool_calls, vec![]);
1250            }
1251            _ => panic!("Expected assistant message"),
1252        }
1253
1254        match assistant_message3 {
1255            Message::Assistant {
1256                content,
1257                tool_calls,
1258                refusal,
1259                ..
1260            } => {
1261                assert!(content.is_empty());
1262                assert!(refusal.is_none());
1263                assert_eq!(
1264                    tool_calls[0],
1265                    ToolCall {
1266                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1267                        r#type: ToolType::Function,
1268                        function: Function {
1269                            name: "subtract".to_string(),
1270                            arguments: serde_json::json!({"x": 2, "y": 5}),
1271                        },
1272                    }
1273                );
1274            }
1275            _ => panic!("Expected assistant message"),
1276        }
1277
1278        match user_message {
1279            Message::User { content, .. } => {
1280                let (first, second) = {
1281                    let mut iter = content.into_iter();
1282                    (iter.next().unwrap(), iter.next().unwrap())
1283                };
1284                assert_eq!(
1285                    first,
1286                    UserContent::Text {
1287                        text: "What's in this image?".to_string()
1288                    }
1289                );
1290                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
1291            }
1292            _ => panic!("Expected user message"),
1293        }
1294    }
1295
1296    #[test]
1297    fn test_message_to_message_conversion() {
1298        let user_message = message::Message::User {
1299            content: OneOrMany::one(message::UserContent::text("Hello")),
1300        };
1301
1302        let assistant_message = message::Message::Assistant {
1303            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1304        };
1305
1306        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1307        let converted_assistant_message: Vec<Message> =
1308            assistant_message.clone().try_into().unwrap();
1309
1310        match converted_user_message[0].clone() {
1311            Message::User { content, .. } => {
1312                assert_eq!(
1313                    content.first(),
1314                    UserContent::Text {
1315                        text: "Hello".to_string()
1316                    }
1317                );
1318            }
1319            _ => panic!("Expected user message"),
1320        }
1321
1322        match converted_assistant_message[0].clone() {
1323            Message::Assistant { content, .. } => {
1324                assert_eq!(
1325                    content[0].clone(),
1326                    AssistantContent::Text {
1327                        text: "Hi there!".to_string()
1328                    }
1329                );
1330            }
1331            _ => panic!("Expected assistant message"),
1332        }
1333
1334        let original_user_message: message::Message =
1335            converted_user_message[0].clone().try_into().unwrap();
1336        let original_assistant_message: message::Message =
1337            converted_assistant_message[0].clone().try_into().unwrap();
1338
1339        assert_eq!(original_user_message, user_message);
1340        assert_eq!(original_assistant_message, assistant_message);
1341    }
1342
1343    #[test]
1344    fn test_message_from_message_conversion() {
1345        let user_message = Message::User {
1346            content: OneOrMany::one(UserContent::Text {
1347                text: "Hello".to_string(),
1348            }),
1349            name: None,
1350        };
1351
1352        let assistant_message = Message::Assistant {
1353            content: vec![AssistantContent::Text {
1354                text: "Hi there!".to_string(),
1355            }],
1356            refusal: None,
1357            audio: None,
1358            name: None,
1359            tool_calls: vec![],
1360        };
1361
1362        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1363        let converted_assistant_message: message::Message =
1364            assistant_message.clone().try_into().unwrap();
1365
1366        match converted_user_message.clone() {
1367            message::Message::User { content } => {
1368                assert_eq!(content.first(), message::UserContent::text("Hello"));
1369            }
1370            _ => panic!("Expected user message"),
1371        }
1372
1373        match converted_assistant_message.clone() {
1374            message::Message::Assistant { content } => {
1375                assert_eq!(
1376                    content.first(),
1377                    message::AssistantContent::text("Hi there!")
1378                );
1379            }
1380            _ => panic!("Expected assistant message"),
1381        }
1382
1383        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1384        let original_assistant_message: Vec<Message> =
1385            converted_assistant_message.try_into().unwrap();
1386
1387        assert_eq!(original_user_message[0], user_message);
1388        assert_eq!(original_assistant_message[0], assistant_message);
1389    }
1390}