rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13    VerifyClient, VerifyError,
14};
15
16#[cfg(feature = "audio")]
17use crate::client::AudioGenerationClient;
18#[cfg(feature = "image")]
19use crate::client::ImageGenerationClient;
20
21use serde::Deserialize;
22
23// ================================================================
24// Main OpenAI Client
25// ================================================================
26const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
27
28pub struct ClientBuilder<'a> {
29    api_key: &'a str,
30    base_url: &'a str,
31    http_client: Option<reqwest::Client>,
32}
33
34impl<'a> ClientBuilder<'a> {
35    pub fn new(api_key: &'a str) -> Self {
36        Self {
37            api_key,
38            base_url: OPENAI_API_BASE_URL,
39            http_client: None,
40        }
41    }
42
43    pub fn base_url(mut self, base_url: &'a str) -> Self {
44        self.base_url = base_url;
45        self
46    }
47
48    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
49        self.http_client = Some(client);
50        self
51    }
52
53    pub fn build(self) -> Result<Client, ClientBuilderError> {
54        let http_client = if let Some(http_client) = self.http_client {
55            http_client
56        } else {
57            reqwest::Client::builder().build()?
58        };
59
60        Ok(Client {
61            base_url: self.base_url.to_string(),
62            api_key: self.api_key.to_string(),
63            http_client,
64        })
65    }
66}
67
68#[derive(Clone)]
69pub struct Client {
70    base_url: String,
71    api_key: String,
72    http_client: reqwest::Client,
73}
74
75impl std::fmt::Debug for Client {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("Client")
78            .field("base_url", &self.base_url)
79            .field("http_client", &self.http_client)
80            .field("api_key", &"<REDACTED>")
81            .finish()
82    }
83}
84
85impl Client {
86    /// Create a new OpenAI client builder.
87    ///
88    /// # Example
89    /// ```
90    /// use rig::providers::openai::{ClientBuilder, self};
91    ///
92    /// // Initialize the OpenAI client
93    /// let openai_client = Client::builder("your-open-ai-api-key")
94    ///    .build()
95    /// ```
96    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
97        ClientBuilder::new(api_key)
98    }
99
100    /// Create a new OpenAI client. For more control, use the `builder` method.
101    ///
102    /// # Panics
103    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
104    pub fn new(api_key: &str) -> Self {
105        Self::builder(api_key)
106            .build()
107            .expect("OpenAI client should build")
108    }
109
110    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
111        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
112        self.http_client.post(url).bearer_auth(&self.api_key)
113    }
114
115    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
116        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
117        self.http_client.get(url).bearer_auth(&self.api_key)
118    }
119}
120
121impl ProviderClient for Client {
122    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
123    /// Panics if the environment variable is not set.
124    fn from_env() -> Self {
125        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
126        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
127
128        match base_url {
129            Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
130            None => Self::new(&api_key),
131        }
132    }
133
134    fn from_val(input: crate::client::ProviderValue) -> Self {
135        let crate::client::ProviderValue::Simple(api_key) = input else {
136            panic!("Incorrect provider value type")
137        };
138        Self::new(&api_key)
139    }
140}
141
142impl CompletionClient for Client {
143    type CompletionModel = super::responses_api::ResponsesCompletionModel;
144    /// Create a completion model with the given name.
145    ///
146    /// # Example
147    /// ```
148    /// use rig::providers::openai::{Client, self};
149    ///
150    /// // Initialize the OpenAI client
151    /// let openai = Client::new("your-open-ai-api-key");
152    ///
153    /// let gpt4 = openai.completion_model(openai::GPT_4);
154    /// ```
155    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
156        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
157    }
158}
159
160impl EmbeddingsClient for Client {
161    type EmbeddingModel = EmbeddingModel;
162    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
163        let ndims = match model {
164            TEXT_EMBEDDING_3_LARGE => 3072,
165            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
166            _ => 0,
167        };
168        EmbeddingModel::new(self.clone(), model, ndims)
169    }
170
171    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
172        EmbeddingModel::new(self.clone(), model, ndims)
173    }
174}
175
176impl TranscriptionClient for Client {
177    type TranscriptionModel = TranscriptionModel;
178    /// Create a transcription model with the given name.
179    ///
180    /// # Example
181    /// ```
182    /// use rig::providers::openai::{Client, self};
183    ///
184    /// // Initialize the OpenAI client
185    /// let openai = Client::new("your-open-ai-api-key");
186    ///
187    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
188    /// ```
189    fn transcription_model(&self, model: &str) -> TranscriptionModel {
190        TranscriptionModel::new(self.clone(), model)
191    }
192}
193
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Build an image generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dall_e = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
212
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Build an audio generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
231
232impl VerifyClient for Client {
233    #[cfg_attr(feature = "worker", worker::send)]
234    async fn verify(&self) -> Result<(), VerifyError> {
235        let response = self.get("/models").send().await?;
236        match response.status() {
237            reqwest::StatusCode::OK => Ok(()),
238            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
239            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
240                Err(VerifyError::ProviderError(response.text().await?))
241            }
242            _ => {
243                response.error_for_status()?;
244                Ok(())
245            }
246        }
247    }
248}
249
250#[derive(Debug, Deserialize)]
251pub struct ApiErrorResponse {
252    pub(crate) message: String,
253}
254
255#[derive(Debug, Deserialize)]
256#[serde(untagged)]
257pub(crate) enum ApiResponse<T> {
258    Ok(T),
259    Err(ApiErrorResponse),
260}
261
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Assistant text expected from the first two fixtures.
    const GREETING: &str = "\n\nHello there, how may I assist you today?";

    /// Image URL expected from the user-message fixture.
    const IMAGE_URL: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";

    /// Deserialize `json` into a provider [`Message`], panicking with the JSON
    /// path and line/column of the failure when it does not parse.
    fn parse_message(json: &str) -> Message {
        let mut de = serde_json::Deserializer::from_str(json);
        deserialize(&mut de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Provider messages deserialize from string content, array content,
    /// tool-call-only and multi-part user JSON shapes.
    #[test]
    fn test_deserialize_message() {
        let plain_text_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let text_parts_json = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let tool_call_json = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // Parse every fixture up front so a parse failure is reported before
        // any content assertion runs.
        let plain_text_message = parse_message(plain_text_json);
        let text_parts_message = parse_message(text_parts_json);
        let tool_call_message = parse_message(tool_call_json);
        let user_message = parse_message(user_json);

        // String content becomes a single text part.
        let Message::Assistant {
            content: plain_content,
            ..
        } = plain_text_message
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            plain_content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );

        // Array content with a null `tool_calls` yields no tool calls.
        let Message::Assistant {
            content: parts_content,
            tool_calls: parts_tool_calls,
            ..
        } = text_parts_message
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            parts_content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );
        assert_eq!(parts_tool_calls, vec![]);

        // Null content with one tool call.
        let Message::Assistant {
            content: call_content,
            tool_calls,
            refusal,
            ..
        } = tool_call_message
        else {
            panic!("Expected assistant message")
        };
        assert!(call_content.is_empty());
        assert!(refusal.is_none());
        let expected_call = ToolCall {
            id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({"x": 2, "y": 5}),
            },
        };
        assert_eq!(tool_calls[0], expected_call);

        // Only the first two user content parts are checked.
        let Message::User {
            content: user_content,
            ..
        } = user_message
        else {
            panic!("Expected user message")
        };
        let mut parts = user_content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::Image {
                image_url: ImageUrl {
                    url: IMAGE_URL.to_string(),
                    detail: ImageDetail::default()
                }
            }
        );
    }

    /// Generic messages convert to provider messages and back unchanged.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User {
            content: user_content,
            ..
        } = converted_user_message[0].clone()
        else {
            panic!("Expected user message")
        };
        assert_eq!(
            user_content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant {
            content: assistant_content,
            ..
        } = converted_assistant_message[0].clone()
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            assistant_content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Round trip back to the generic message type.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Provider messages convert to generic messages and back unchanged.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User {
            content: user_content,
        } = converted_user_message.clone()
        else {
            panic!("Expected user message")
        };
        assert_eq!(user_content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant {
            content: assistant_content,
            ..
        } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            assistant_content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Round trip back to the provider message type.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}