rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13};
14
15#[cfg(feature = "audio")]
16use crate::client::AudioGenerationClient;
17#[cfg(feature = "image")]
18use crate::client::ImageGenerationClient;
19
20use serde::Deserialize;
21
22// ================================================================
23// Main OpenAI Client
24// ================================================================
/// Default OpenAI REST API endpoint; can be overridden via `ClientBuilder::base_url`.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
/// Builder for [`Client`]; lets callers override the base URL or supply
/// a preconfigured `reqwest::Client` before building.
pub struct ClientBuilder<'a> {
    api_key: &'a str,
    base_url: &'a str,
    // `None` means `build()` constructs a default reqwest client.
    http_client: Option<reqwest::Client>,
}
32
33impl<'a> ClientBuilder<'a> {
34    pub fn new(api_key: &'a str) -> Self {
35        Self {
36            api_key,
37            base_url: OPENAI_API_BASE_URL,
38            http_client: None,
39        }
40    }
41
42    pub fn base_url(mut self, base_url: &'a str) -> Self {
43        self.base_url = base_url;
44        self
45    }
46
47    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
48        self.http_client = Some(client);
49        self
50    }
51
52    pub fn build(self) -> Result<Client, ClientBuilderError> {
53        let http_client = if let Some(http_client) = self.http_client {
54            http_client
55        } else {
56            reqwest::Client::builder().build()?
57        };
58
59        Ok(Client {
60            base_url: self.base_url.to_string(),
61            api_key: self.api_key.to_string(),
62            http_client,
63        })
64    }
65}
66
/// OpenAI API client holding the base URL, API key, and HTTP client.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    // Secret credential; deliberately kept out of `Debug` output (see the
    // manual `Debug` impl below, which prints `<REDACTED>` instead).
    api_key: String,
    http_client: reqwest::Client,
}
73
74impl std::fmt::Debug for Client {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("Client")
77            .field("base_url", &self.base_url)
78            .field("http_client", &self.http_client)
79            .field("api_key", &"<REDACTED>")
80            .finish()
81    }
82}
83
84impl Client {
85    /// Create a new OpenAI client builder.
86    ///
87    /// # Example
88    /// ```
89    /// use rig::providers::openai::{ClientBuilder, self};
90    ///
91    /// // Initialize the OpenAI client
92    /// let openai_client = Client::builder("your-open-ai-api-key")
93    ///    .build()
94    /// ```
95    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96        ClientBuilder::new(api_key)
97    }
98
99    /// Create a new OpenAI client. For more control, use the `builder` method.
100    ///
101    /// # Panics
102    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
103    pub fn new(api_key: &str) -> Self {
104        Self::builder(api_key)
105            .build()
106            .expect("OpenAI client should build")
107    }
108
109    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
110        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
111        self.http_client.post(url).bearer_auth(&self.api_key)
112    }
113}
114
115impl ProviderClient for Client {
116    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
117    /// Panics if the environment variable is not set.
118    fn from_env() -> Self {
119        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
120        Self::new(&api_key)
121    }
122
123    fn from_val(input: crate::client::ProviderValue) -> Self {
124        let crate::client::ProviderValue::Simple(api_key) = input else {
125            panic!("Incorrect provider value type")
126        };
127        Self::new(&api_key)
128    }
129}
130
impl CompletionClient for Client {
    // Completions are routed through the OpenAI Responses API wrapper.
    type CompletionModel = super::responses_api::ResponsesCompletionModel;
    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let gpt4 = openai.completion_model(openai::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
    }
}
148
149impl EmbeddingsClient for Client {
150    type EmbeddingModel = EmbeddingModel;
151    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
152        let ndims = match model {
153            TEXT_EMBEDDING_3_LARGE => 3072,
154            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
155            _ => 0,
156        };
157        EmbeddingModel::new(self.clone(), model, ndims)
158    }
159
160    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
161        EmbeddingModel::new(self.clone(), model, ndims)
162    }
163}
164
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;
    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let whisper = openai.transcription_model(openai::WHISPER_1);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
182
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;
    /// Create an image generation model with the given name.
    /// Only available with the `image` feature enabled.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
201
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;
    /// Create an audio generation model with the given name.
    /// Only available with the `audio` feature enabled.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        AudioGenerationModel::new(self.clone(), model)
    }
}
220
/// Error body returned by the OpenAI API on failed requests.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Human-readable error description from the API.
    pub(crate) message: String,
}
225
/// Deserialization target for API replies: either a successful payload of
/// type `T` or an error body. Being `#[serde(untagged)]`, serde tries the
/// variants in declaration order, so `Ok(T)` is attempted before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
232
233#[cfg(test)]
234mod tests {
235    use crate::message::ImageDetail;
236    use crate::providers::openai::{
237        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
238    };
239    use crate::{OneOrMany, message};
240    use serde_path_to_error::deserialize;
241
    /// Checks deserialization of the provider `Message` type from the JSON
    /// shapes OpenAI actually returns: plain-string assistant content,
    /// array-of-parts content with a null `tool_calls`, tool-call-only
    /// messages with null `content`/`refusal`, and multi-part user content
    /// (text + image_url + input audio).
    #[test]
    fn test_deserialize_message() {
        // Assistant content as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant content as an array of typed parts, with explicit null tool_calls.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool-call-only assistant message: content and refusal are null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // Multi-part user content: text, image_url, and input audio.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // `serde_path_to_error::deserialize` reports the exact JSON path of
        // any failure, so a broken field panics with a precise location.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        // Bare-string content should surface as a single Text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Null `tool_calls` should deserialize as an empty vec, not an error.
        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        // Tool call arguments arrive as a JSON-encoded string and should be
        // parsed into a structured `serde_json::Value`.
        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Only the first two user parts are asserted here (text and image);
        // the audio part is exercised by its successful deserialization above.
        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }
431
    /// Round-trips rig's provider-agnostic `message::Message` through the
    /// OpenAI-specific `Message` type and back, asserting the text content
    /// survives in both directions.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // One generic message may map to several provider messages, hence Vec.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original generic messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }
479
    /// The inverse of the test above: starts from the OpenAI-specific
    /// `Message` type, converts to the provider-agnostic `message::Message`,
    /// and back, asserting lossless round-tripping.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back yields a Vec (generic-to-provider is one-to-many);
        // the first element must equal the original provider message.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
527}