rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::completion::CompletionModel;
4use super::embedding::{
5    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
6};
7
8#[cfg(feature = "image")]
9use super::image_generation::ImageGenerationModel;
10use super::transcription::TranscriptionModel;
11
12use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient};
13
14#[cfg(feature = "audio")]
15use crate::client::AudioGenerationClient;
16#[cfg(feature = "image")]
17use crate::client::ImageGenerationClient;
18
19use serde::Deserialize;
20
// ================================================================
// Main OpenAI Client
// ================================================================
/// Default base URL for the OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
/// OpenAI API client.
///
/// Wraps a pre-configured `reqwest::Client` whose default headers carry the
/// `Authorization: Bearer <api key>` header (installed in [`Client::from_url`]),
/// so every request built from it is already authenticated.
#[derive(Clone, Debug)]
pub struct Client {
    // Base URL that request paths are joined against
    // (defaults to `OPENAI_API_BASE_URL`).
    base_url: String,
    // HTTP client with the bearer-token auth header pre-installed.
    http_client: reqwest::Client,
}
30
31impl Client {
32    /// Create a new OpenAI client with the given API key.
33    pub fn new(api_key: &str) -> Self {
34        Self::from_url(api_key, OPENAI_API_BASE_URL)
35    }
36
37    /// Create a new OpenAI client with the given API key and base API URL.
38    pub fn from_url(api_key: &str, base_url: &str) -> Self {
39        Self {
40            base_url: base_url.to_string(),
41            http_client: reqwest::Client::builder()
42                .default_headers({
43                    let mut headers = reqwest::header::HeaderMap::new();
44                    headers.insert(
45                        "Authorization",
46                        format!("Bearer {api_key}")
47                            .parse()
48                            .expect("Bearer token should parse"),
49                    );
50                    headers
51                })
52                .build()
53                .expect("OpenAI reqwest client should build"),
54        }
55    }
56
57    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
58        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
59        self.http_client.post(url)
60    }
61}
62
63impl ProviderClient for Client {
64    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
65    /// Panics if the environment variable is not set.
66    fn from_env() -> Self {
67        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
68        Self::new(&api_key)
69    }
70}
71
72impl CompletionClient for Client {
73    type CompletionModel = CompletionModel;
74    /// Create a completion model with the given name.
75    ///
76    /// # Example
77    /// ```
78    /// use rig::providers::openai::{Client, self};
79    ///
80    /// // Initialize the OpenAI client
81    /// let openai = Client::new("your-open-ai-api-key");
82    ///
83    /// let gpt4 = openai.completion_model(openai::GPT_4);
84    /// ```
85    fn completion_model(&self, model: &str) -> CompletionModel {
86        CompletionModel::new(self.clone(), model)
87    }
88}
89
90impl EmbeddingsClient for Client {
91    type EmbeddingModel = EmbeddingModel;
92    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
93        let ndims = match model {
94            TEXT_EMBEDDING_3_LARGE => 3072,
95            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
96            _ => 0,
97        };
98        EmbeddingModel::new(self.clone(), model, ndims)
99    }
100
101    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
102        EmbeddingModel::new(self.clone(), model, ndims)
103    }
104}
105
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;
    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let whisper = openai.transcription_model(openai::WHISPER_1);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
123
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;
    /// Create an image generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
142
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;
    /// Create an audio generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        AudioGenerationModel::new(self.clone(), model)
    }
}
161
/// Token usage statistics as reported by the OpenAI API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
167
168impl std::fmt::Display for Usage {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        write!(
171            f,
172            "Prompt tokens: {} Total tokens: {}",
173            self.prompt_tokens, self.total_tokens
174        )
175    }
176}
177
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Human-readable error message from the API.
    pub(crate) message: String,
}
182
/// Either a successful payload of type `T` or an OpenAI API error body.
///
/// With `#[serde(untagged)]`, deserialization attempts the variants in
/// declaration order: the success shape `T` first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
#[cfg(test)]
mod tests {
    //! Tests for deserializing OpenAI wire-format messages and for the
    //! lossless conversions between provider `Message` values and rig's
    //! internal `message::Message` types.

    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{message, OneOrMany};
    use serde_path_to_error::deserialize;

    /// Deserializes the different message shapes the OpenAI API can return
    /// (string content, content-part arrays, tool calls with null content,
    /// mixed user content) and checks the resulting `Message` values.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string `content`.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with `content` as an array of typed parts and an
        // explicitly null `tool_calls` field.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only a tool call; `content` and
        // `refusal` are explicitly null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image-url and audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // `serde_path_to_error::deserialize` is used throughout so a failure
        // reports the exact JSON path (plus line/column) that broke.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        // String content should deserialize into a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // A null `tool_calls` field should become an empty vec, not an error.
        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        // Tool-call-only message: empty content, no refusal, and the tool
        // call's JSON-string `arguments` parsed into a JSON value.
        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Only the first two user content parts (text + image) are asserted
        // here; the audio part is exercised by the deserialization above.
        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Round-trips rig-internal messages through the provider `Message`
    /// representation (internal -> provider -> internal) and checks the
    /// result equals the original.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Internal -> provider: one internal message may map to several
        // provider messages, hence Vec<Message>.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Provider -> internal: converting back must reproduce the originals.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Round-trips provider messages through the rig-internal representation
    /// (provider -> internal -> provider) and checks the result equals the
    /// original.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        // Provider -> internal.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Internal -> provider: converting back must reproduce the originals.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}