rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
12
13#[cfg(feature = "audio")]
14use crate::client::AudioGenerationClient;
15#[cfg(feature = "image")]
16use crate::client::ImageGenerationClient;
17
18use serde::Deserialize;
19
// ================================================================
// Main OpenAI Client
// ================================================================

/// Base URL used when a client is created through `Client::new`.
///
/// Callers that need a different endpoint go through `Client::from_url` instead.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
24
25#[derive(Clone)]
26pub struct Client {
27    base_url: String,
28    api_key: String,
29    http_client: reqwest::Client,
30}
31
32impl std::fmt::Debug for Client {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("Client")
35            .field("base_url", &self.base_url)
36            .field("http_client", &self.http_client)
37            .field("api_key", &"<REDACTED>")
38            .finish()
39    }
40}
41
42impl Client {
43    /// Create a new OpenAI client with the given API key.
44    pub fn new(api_key: &str) -> Self {
45        Self::from_url(api_key, OPENAI_API_BASE_URL)
46    }
47
48    /// Create a new OpenAI client with the given API key and base API URL.
49    pub fn from_url(api_key: &str, base_url: &str) -> Self {
50        Self {
51            base_url: base_url.to_string(),
52            api_key: api_key.to_string(),
53            http_client: reqwest::Client::builder()
54                .build()
55                .expect("OpenAI reqwest client should build"),
56        }
57    }
58
59    /// Use your own `reqwest::Client`.
60    /// The required headers will be automatically attached upon trying to make a request.
61    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
62        self.http_client = client;
63
64        self
65    }
66
67    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69        self.http_client.post(url).bearer_auth(&self.api_key)
70    }
71}
72
73impl ProviderClient for Client {
74    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
75    /// Panics if the environment variable is not set.
76    fn from_env() -> Self {
77        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
78        Self::new(&api_key)
79    }
80
81    fn from_val(input: crate::client::ProviderValue) -> Self {
82        let crate::client::ProviderValue::Simple(api_key) = input else {
83            panic!("Incorrect provider value type")
84        };
85        Self::new(&api_key)
86    }
87}
88
89impl CompletionClient for Client {
90    type CompletionModel = super::responses_api::ResponsesCompletionModel;
91    /// Create a completion model with the given name.
92    ///
93    /// # Example
94    /// ```
95    /// use rig::providers::openai::{Client, self};
96    ///
97    /// // Initialize the OpenAI client
98    /// let openai = Client::new("your-open-ai-api-key");
99    ///
100    /// let gpt4 = openai.completion_model(openai::GPT_4);
101    /// ```
102    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
103        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
104    }
105}
106
107impl EmbeddingsClient for Client {
108    type EmbeddingModel = EmbeddingModel;
109    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
110        let ndims = match model {
111            TEXT_EMBEDDING_3_LARGE => 3072,
112            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
113            _ => 0,
114        };
115        EmbeddingModel::new(self.clone(), model, ndims)
116    }
117
118    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
119        EmbeddingModel::new(self.clone(), model, ndims)
120    }
121}
122
123impl TranscriptionClient for Client {
124    type TranscriptionModel = TranscriptionModel;
125    /// Create a transcription model with the given name.
126    ///
127    /// # Example
128    /// ```
129    /// use rig::providers::openai::{Client, self};
130    ///
131    /// // Initialize the OpenAI client
132    /// let openai = Client::new("your-open-ai-api-key");
133    ///
134    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
135    /// ```
136    fn transcription_model(&self, model: &str) -> TranscriptionModel {
137        TranscriptionModel::new(self.clone(), model)
138    }
139}
140
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Build an image generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
159
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Build an audio generation model handle for the model called `model`.
    ///
    /// The returned model owns a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
178
179#[derive(Debug, Deserialize)]
180pub struct ApiErrorResponse {
181    pub(crate) message: String,
182}
183
184#[derive(Debug, Deserialize)]
185#[serde(untagged)]
186pub(crate) enum ApiResponse<T> {
187    Ok(T),
188    Err(ApiErrorResponse),
189}
190
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into a provider [`Message`].
    ///
    /// On failure, panics with the JSON path plus line and column of the
    /// offending input so a broken fixture is easy to locate.
    fn parse_message(json: &str) -> Message {
        let mut de = serde_json::Deserializer::from_str(json);
        match deserialize(&mut de) {
            Ok(parsed) => parsed,
            Err(err) => panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            ),
        }
    }

    /// Wire-format JSON in several shapes deserializes into the expected
    /// provider messages.
    #[test]
    fn test_deserialize_message() {
        // Assistant reply whose content is a bare string.
        let plain_assistant_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant reply whose content is a list of parts, with null tool calls.
        let structured_assistant_json = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant reply that carries only a tool call; content and refusal are null.
        let tool_call_assistant_json = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio parts.
        let multimodal_user_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // Parse every fixture up front, then check each result in turn.
        let assistant_message = parse_message(plain_assistant_json);
        let assistant_message2 = parse_message(structured_assistant_json);
        let assistant_message3 = parse_message(tool_call_assistant_json);
        let user_message = parse_message(multimodal_user_json);

        let Message::Assistant {
            content: plain_content,
            ..
        } = assistant_message
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            plain_content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        let Message::Assistant {
            content: structured_content,
            tool_calls: structured_tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            structured_content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(structured_tool_calls, vec![]);

        let Message::Assistant {
            content: tool_only_content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message")
        };
        assert!(tool_only_content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User {
            content: user_content,
            ..
        } = user_message
        else {
            panic!("Expected user message")
        };
        // Only the first two parts (text and image) are checked.
        let mut parts = user_content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::Image {
                image_url: ImageUrl {
                    url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                    detail: ImageDetail::default()
                }
            }
        );
    }

    /// Generic messages convert into provider messages and back without loss.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Generic -> provider.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User {
            content: user_content,
            ..
        } = converted_user_message[0].clone()
        else {
            panic!("Expected user message")
        };
        assert_eq!(
            user_content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant {
            content: assistant_content,
            ..
        } = converted_assistant_message[0].clone()
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            assistant_content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Provider -> generic must reproduce the starting values.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Provider messages convert into generic messages and back without loss.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // Provider -> generic.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User {
            content: user_content,
        } = converted_user_message.clone()
        else {
            panic!("Expected user message")
        };
        assert_eq!(user_content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant {
            content: assistant_content,
            ..
        } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message")
        };
        assert_eq!(
            assistant_content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Generic -> provider must reproduce the starting values.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}