//! rig/providers/openai/client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
12
13#[cfg(feature = "audio")]
14use crate::client::AudioGenerationClient;
15#[cfg(feature = "image")]
16use crate::client::ImageGenerationClient;
17
18use serde::Deserialize;
19
20// ================================================================
21// Main OpenAI Client
22// ================================================================
/// Default base URL for the hosted OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
24
#[derive(Clone)]
pub struct Client {
    /// Base URL all request paths are joined onto; defaults to `OPENAI_API_BASE_URL`.
    base_url: String,
    /// Secret API key; deliberately redacted by the manual `Debug` impl.
    api_key: String,
    /// Underlying HTTP client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
}
31
32impl std::fmt::Debug for Client {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("Client")
35            .field("base_url", &self.base_url)
36            .field("http_client", &self.http_client)
37            .field("api_key", &"<REDACTED>")
38            .finish()
39    }
40}
41
42impl Client {
43    /// Create a new OpenAI client with the given API key.
44    pub fn new(api_key: &str) -> Self {
45        Self::from_url(api_key, OPENAI_API_BASE_URL)
46    }
47
48    /// Create a new OpenAI client with the given API key and base API URL.
49    pub fn from_url(api_key: &str, base_url: &str) -> Self {
50        Self {
51            base_url: base_url.to_string(),
52            api_key: api_key.to_string(),
53            http_client: reqwest::Client::builder()
54                .build()
55                .expect("OpenAI reqwest client should build"),
56        }
57    }
58
59    /// Use your own `reqwest::Client`.
60    /// The required headers will be automatically attached upon trying to make a request.
61    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
62        self.http_client = client;
63
64        self
65    }
66
67    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
68        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69        self.http_client.post(url).bearer_auth(&self.api_key)
70    }
71}
72
73impl ProviderClient for Client {
74    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
75    /// Panics if the environment variable is not set.
76    fn from_env() -> Self {
77        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
78        Self::new(&api_key)
79    }
80}
81
82impl CompletionClient for Client {
83    type CompletionModel = super::responses_api::ResponsesCompletionModel;
84    /// Create a completion model with the given name.
85    ///
86    /// # Example
87    /// ```
88    /// use rig::providers::openai::{Client, self};
89    ///
90    /// // Initialize the OpenAI client
91    /// let openai = Client::new("your-open-ai-api-key");
92    ///
93    /// let gpt4 = openai.completion_model(openai::GPT_4);
94    /// ```
95    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
96        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
97    }
98}
99
100impl EmbeddingsClient for Client {
101    type EmbeddingModel = EmbeddingModel;
102    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
103        let ndims = match model {
104            TEXT_EMBEDDING_3_LARGE => 3072,
105            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
106            _ => 0,
107        };
108        EmbeddingModel::new(self.clone(), model, ndims)
109    }
110
111    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
112        EmbeddingModel::new(self.clone(), model, ndims)
113    }
114}
115
116impl TranscriptionClient for Client {
117    type TranscriptionModel = TranscriptionModel;
118    /// Create a transcription model with the given name.
119    ///
120    /// # Example
121    /// ```
122    /// use rig::providers::openai::{Client, self};
123    ///
124    /// // Initialize the OpenAI client
125    /// let openai = Client::new("your-open-ai-api-key");
126    ///
127    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
128    /// ```
129    fn transcription_model(&self, model: &str) -> TranscriptionModel {
130        TranscriptionModel::new(self.clone(), model)
131    }
132}
133
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create an image generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
152
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Create an audio generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
171
/// Error payload deserialized from a failed OpenAI API response body.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error message returned by the API.
    pub(crate) message: String,
}
176
/// Untagged wrapper over an OpenAI response body: either the expected payload
/// `T` or an API error. With `untagged`, serde attempts the `Ok(T)` shape
/// first and falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
183
184#[cfg(test)]
185mod tests {
186    use crate::message::ImageDetail;
187    use crate::providers::openai::{
188        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
189    };
190    use crate::{OneOrMany, message};
191    use serde_path_to_error::deserialize;
192
193    #[test]
194    fn test_deserialize_message() {
195        let assistant_message_json = r#"
196        {
197            "role": "assistant",
198            "content": "\n\nHello there, how may I assist you today?"
199        }
200        "#;
201
202        let assistant_message_json2 = r#"
203        {
204            "role": "assistant",
205            "content": [
206                {
207                    "type": "text",
208                    "text": "\n\nHello there, how may I assist you today?"
209                }
210            ],
211            "tool_calls": null
212        }
213        "#;
214
215        let assistant_message_json3 = r#"
216        {
217            "role": "assistant",
218            "tool_calls": [
219                {
220                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
221                    "type": "function",
222                    "function": {
223                        "name": "subtract",
224                        "arguments": "{\"x\": 2, \"y\": 5}"
225                    }
226                }
227            ],
228            "content": null,
229            "refusal": null
230        }
231        "#;
232
233        let user_message_json = r#"
234        {
235            "role": "user",
236            "content": [
237                {
238                    "type": "text",
239                    "text": "What's in this image?"
240                },
241                {
242                    "type": "image_url",
243                    "image_url": {
244                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
245                    }
246                },
247                {
248                    "type": "audio",
249                    "input_audio": {
250                        "data": "...",
251                        "format": "mp3"
252                    }
253                }
254            ]
255        }
256        "#;
257
258        let assistant_message: Message = {
259            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
260            deserialize(jd).unwrap_or_else(|err| {
261                panic!(
262                    "Deserialization error at {} ({}:{}): {}",
263                    err.path(),
264                    err.inner().line(),
265                    err.inner().column(),
266                    err
267                );
268            })
269        };
270
271        let assistant_message2: Message = {
272            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
273            deserialize(jd).unwrap_or_else(|err| {
274                panic!(
275                    "Deserialization error at {} ({}:{}): {}",
276                    err.path(),
277                    err.inner().line(),
278                    err.inner().column(),
279                    err
280                );
281            })
282        };
283
284        let assistant_message3: Message = {
285            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
286                &mut serde_json::Deserializer::from_str(assistant_message_json3);
287            deserialize(jd).unwrap_or_else(|err| {
288                panic!(
289                    "Deserialization error at {} ({}:{}): {}",
290                    err.path(),
291                    err.inner().line(),
292                    err.inner().column(),
293                    err
294                );
295            })
296        };
297
298        let user_message: Message = {
299            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
300            deserialize(jd).unwrap_or_else(|err| {
301                panic!(
302                    "Deserialization error at {} ({}:{}): {}",
303                    err.path(),
304                    err.inner().line(),
305                    err.inner().column(),
306                    err
307                );
308            })
309        };
310
311        match assistant_message {
312            Message::Assistant { content, .. } => {
313                assert_eq!(
314                    content[0],
315                    AssistantContent::Text {
316                        text: "\n\nHello there, how may I assist you today?".to_string()
317                    }
318                );
319            }
320            _ => panic!("Expected assistant message"),
321        }
322
323        match assistant_message2 {
324            Message::Assistant {
325                content,
326                tool_calls,
327                ..
328            } => {
329                assert_eq!(
330                    content[0],
331                    AssistantContent::Text {
332                        text: "\n\nHello there, how may I assist you today?".to_string()
333                    }
334                );
335
336                assert_eq!(tool_calls, vec![]);
337            }
338            _ => panic!("Expected assistant message"),
339        }
340
341        match assistant_message3 {
342            Message::Assistant {
343                content,
344                tool_calls,
345                refusal,
346                ..
347            } => {
348                assert!(content.is_empty());
349                assert!(refusal.is_none());
350                assert_eq!(
351                    tool_calls[0],
352                    ToolCall {
353                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
354                        r#type: ToolType::Function,
355                        function: Function {
356                            name: "subtract".to_string(),
357                            arguments: serde_json::json!({"x": 2, "y": 5}),
358                        },
359                    }
360                );
361            }
362            _ => panic!("Expected assistant message"),
363        }
364
365        match user_message {
366            Message::User { content, .. } => {
367                let (first, second) = {
368                    let mut iter = content.into_iter();
369                    (iter.next().unwrap(), iter.next().unwrap())
370                };
371                assert_eq!(
372                    first,
373                    UserContent::Text {
374                        text: "What's in this image?".to_string()
375                    }
376                );
377                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
378            }
379            _ => panic!("Expected user message"),
380        }
381    }
382
383    #[test]
384    fn test_message_to_message_conversion() {
385        let user_message = message::Message::User {
386            content: OneOrMany::one(message::UserContent::text("Hello")),
387        };
388
389        let assistant_message = message::Message::Assistant {
390            id: None,
391            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
392        };
393
394        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
395        let converted_assistant_message: Vec<Message> =
396            assistant_message.clone().try_into().unwrap();
397
398        match converted_user_message[0].clone() {
399            Message::User { content, .. } => {
400                assert_eq!(
401                    content.first(),
402                    UserContent::Text {
403                        text: "Hello".to_string()
404                    }
405                );
406            }
407            _ => panic!("Expected user message"),
408        }
409
410        match converted_assistant_message[0].clone() {
411            Message::Assistant { content, .. } => {
412                assert_eq!(
413                    content[0].clone(),
414                    AssistantContent::Text {
415                        text: "Hi there!".to_string()
416                    }
417                );
418            }
419            _ => panic!("Expected assistant message"),
420        }
421
422        let original_user_message: message::Message =
423            converted_user_message[0].clone().try_into().unwrap();
424        let original_assistant_message: message::Message =
425            converted_assistant_message[0].clone().try_into().unwrap();
426
427        assert_eq!(original_user_message, user_message);
428        assert_eq!(original_assistant_message, assistant_message);
429    }
430
431    #[test]
432    fn test_message_from_message_conversion() {
433        let user_message = Message::User {
434            content: OneOrMany::one(UserContent::Text {
435                text: "Hello".to_string(),
436            }),
437            name: None,
438        };
439
440        let assistant_message = Message::Assistant {
441            content: vec![AssistantContent::Text {
442                text: "Hi there!".to_string(),
443            }],
444            refusal: None,
445            audio: None,
446            name: None,
447            tool_calls: vec![],
448        };
449
450        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
451        let converted_assistant_message: message::Message =
452            assistant_message.clone().try_into().unwrap();
453
454        match converted_user_message.clone() {
455            message::Message::User { content } => {
456                assert_eq!(content.first(), message::UserContent::text("Hello"));
457            }
458            _ => panic!("Expected user message"),
459        }
460
461        match converted_assistant_message.clone() {
462            message::Message::Assistant { content, .. } => {
463                assert_eq!(
464                    content.first(),
465                    message::AssistantContent::text("Hi there!")
466                );
467            }
468            _ => panic!("Expected assistant message"),
469        }
470
471        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
472        let original_assistant_message: Vec<Message> =
473            converted_assistant_message.try_into().unwrap();
474
475        assert_eq!(original_user_message[0], user_message);
476        assert_eq!(original_assistant_message[0], assistant_message);
477    }
478}