rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{
12    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
13};
14
15#[cfg(feature = "audio")]
16use crate::client::AudioGenerationClient;
17#[cfg(feature = "image")]
18use crate::client::ImageGenerationClient;
19
20use serde::Deserialize;
21
// ================================================================
// Main OpenAI Client
// ================================================================
/// Root URL of the OpenAI REST API. [`ClientBuilder::new`] starts from this value; it can be
/// replaced with [`ClientBuilder::base_url`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
27pub struct ClientBuilder<'a> {
28    api_key: &'a str,
29    base_url: &'a str,
30    http_client: Option<reqwest::Client>,
31}
32
33impl<'a> ClientBuilder<'a> {
34    pub fn new(api_key: &'a str) -> Self {
35        Self {
36            api_key,
37            base_url: OPENAI_API_BASE_URL,
38            http_client: None,
39        }
40    }
41
42    pub fn base_url(mut self, base_url: &'a str) -> Self {
43        self.base_url = base_url;
44        self
45    }
46
47    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
48        self.http_client = Some(client);
49        self
50    }
51
52    pub fn build(self) -> Result<Client, ClientBuilderError> {
53        let http_client = if let Some(http_client) = self.http_client {
54            http_client
55        } else {
56            reqwest::Client::builder().build()?
57        };
58
59        Ok(Client {
60            base_url: self.base_url.to_string(),
61            api_key: self.api_key.to_string(),
62            http_client,
63        })
64    }
65}
66
67#[derive(Clone)]
68pub struct Client {
69    base_url: String,
70    api_key: String,
71    http_client: reqwest::Client,
72}
73
74impl std::fmt::Debug for Client {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("Client")
77            .field("base_url", &self.base_url)
78            .field("http_client", &self.http_client)
79            .field("api_key", &"<REDACTED>")
80            .finish()
81    }
82}
83
84impl Client {
85    /// Create a new OpenAI client builder.
86    ///
87    /// # Example
88    /// ```
89    /// use rig::providers::openai::{ClientBuilder, self};
90    ///
91    /// // Initialize the OpenAI client
92    /// let openai_client = Client::builder("your-open-ai-api-key")
93    ///    .build()
94    /// ```
95    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96        ClientBuilder::new(api_key)
97    }
98
99    /// Create a new OpenAI client. For more control, use the `builder` method.
100    ///
101    /// # Panics
102    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
103    pub fn new(api_key: &str) -> Self {
104        Self::builder(api_key)
105            .build()
106            .expect("OpenAI client should build")
107    }
108
109    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
110        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
111        self.http_client.post(url).bearer_auth(&self.api_key)
112    }
113}
114
115impl ProviderClient for Client {
116    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
117    /// Panics if the environment variable is not set.
118    fn from_env() -> Self {
119        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
120        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
121
122        match base_url {
123            Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
124            None => Self::new(&api_key),
125        }
126    }
127
128    fn from_val(input: crate::client::ProviderValue) -> Self {
129        let crate::client::ProviderValue::Simple(api_key) = input else {
130            panic!("Incorrect provider value type")
131        };
132        Self::new(&api_key)
133    }
134}
135
136impl CompletionClient for Client {
137    type CompletionModel = super::responses_api::ResponsesCompletionModel;
138    /// Create a completion model with the given name.
139    ///
140    /// # Example
141    /// ```
142    /// use rig::providers::openai::{Client, self};
143    ///
144    /// // Initialize the OpenAI client
145    /// let openai = Client::new("your-open-ai-api-key");
146    ///
147    /// let gpt4 = openai.completion_model(openai::GPT_4);
148    /// ```
149    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
150        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
151    }
152}
153
154impl EmbeddingsClient for Client {
155    type EmbeddingModel = EmbeddingModel;
156    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
157        let ndims = match model {
158            TEXT_EMBEDDING_3_LARGE => 3072,
159            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
160            _ => 0,
161        };
162        EmbeddingModel::new(self.clone(), model, ndims)
163    }
164
165    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
166        EmbeddingModel::new(self.clone(), model, ndims)
167    }
168}
169
170impl TranscriptionClient for Client {
171    type TranscriptionModel = TranscriptionModel;
172    /// Create a transcription model with the given name.
173    ///
174    /// # Example
175    /// ```
176    /// use rig::providers::openai::{Client, self};
177    ///
178    /// // Initialize the OpenAI client
179    /// let openai = Client::new("your-open-ai-api-key");
180    ///
181    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
182    /// ```
183    fn transcription_model(&self, model: &str) -> TranscriptionModel {
184        TranscriptionModel::new(self.clone(), model)
185    }
186}
187
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Build an image generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
206
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Build an audio generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
225
226#[derive(Debug, Deserialize)]
227pub struct ApiErrorResponse {
228    pub(crate) message: String,
229}
230
231#[derive(Debug, Deserialize)]
232#[serde(untagged)]
233pub(crate) enum ApiResponse<T> {
234    Ok(T),
235    Err(ApiErrorResponse),
236}
237
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserialize `json` into an OpenAI [`Message`].
    ///
    /// Panics with the failing JSON path plus line and column when deserialization fails, so
    /// a broken fixture is easy to locate.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Checks that assistant messages (plain string content, content parts, tool calls) and
    /// a multi-part user message deserialize into the expected `Message` variants.
    #[test]
    fn test_deserialize_message() {
        // Assistant message whose content is a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with a list of content parts and an explicit null `tool_calls`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message that only carries a tool call; content and refusal are null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // A null `tool_calls` must come back as an empty list.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                // Only the first two parts (text and image) are asserted on.
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::Image {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                            detail: ImageDetail::default()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Converts generic messages to OpenAI messages and back, expecting the round trip to
    /// reproduce the originals.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Converts OpenAI messages to generic messages and back, expecting the round trip to
    /// reproduce the originals.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}