rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6
7#[cfg(feature = "image")]
8use super::image_generation::ImageGenerationModel;
9use super::transcription::TranscriptionModel;
10
11use crate::{
12    client::{
13        ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient,
14        TranscriptionClient, VerifyClient, VerifyError,
15    },
16    extractor::ExtractorBuilder,
17    providers::openai::CompletionModel,
18};
19
20#[cfg(feature = "audio")]
21use crate::client::AudioGenerationClient;
22#[cfg(feature = "image")]
23use crate::client::ImageGenerationClient;
24
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27
28// ================================================================
29// Main OpenAI Client
30// ================================================================
/// Default root of the OpenAI REST API; request paths are appended to this.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
32
33pub struct ClientBuilder<'a> {
34    api_key: &'a str,
35    base_url: &'a str,
36    http_client: Option<reqwest::Client>,
37}
38
39impl<'a> ClientBuilder<'a> {
40    pub fn new(api_key: &'a str) -> Self {
41        Self {
42            api_key,
43            base_url: OPENAI_API_BASE_URL,
44            http_client: None,
45        }
46    }
47
48    pub fn base_url(mut self, base_url: &'a str) -> Self {
49        self.base_url = base_url;
50        self
51    }
52
53    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
54        self.http_client = Some(client);
55        self
56    }
57
58    pub fn build(self) -> Result<Client, ClientBuilderError> {
59        let http_client = if let Some(http_client) = self.http_client {
60            http_client
61        } else {
62            reqwest::Client::builder().build()?
63        };
64
65        Ok(Client {
66            base_url: self.base_url.to_string(),
67            api_key: self.api_key.to_string(),
68            http_client,
69        })
70    }
71}
72
73#[derive(Clone)]
74pub struct Client {
75    base_url: String,
76    api_key: String,
77    http_client: reqwest::Client,
78}
79
80impl std::fmt::Debug for Client {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("Client")
83            .field("base_url", &self.base_url)
84            .field("http_client", &self.http_client)
85            .field("api_key", &"<REDACTED>")
86            .finish()
87    }
88}
89
90impl Client {
91    /// Create a new OpenAI client builder.
92    ///
93    /// # Example
94    /// ```
95    /// use rig::providers::openai::{ClientBuilder, self};
96    ///
97    /// // Initialize the OpenAI client
98    /// let openai_client = Client::builder("your-open-ai-api-key")
99    ///    .build()
100    /// ```
101    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
102        ClientBuilder::new(api_key)
103    }
104
105    /// Create a new OpenAI client. For more control, use the `builder` method.
106    ///
107    /// # Panics
108    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
109    pub fn new(api_key: &str) -> Self {
110        Self::builder(api_key)
111            .build()
112            .expect("OpenAI client should build")
113    }
114
115    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
116        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
117        self.http_client.post(url).bearer_auth(&self.api_key)
118    }
119
120    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
121        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
122        self.http_client.get(url).bearer_auth(&self.api_key)
123    }
124
125    /// Create an extractor builder with the given completion model.
126    /// Intended for use exclusively with the Chat Completions API.
127    /// Useful for using extractors with Chat Completion compliant APIs.
128    pub fn extractor_completions_api<T>(&self, model: &str) -> ExtractorBuilder<CompletionModel, T>
129    where
130        T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
131    {
132        ExtractorBuilder::new(self.completion_model(model).completions_api())
133    }
134}
135
136impl ProviderClient for Client {
137    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
138    /// Panics if the environment variable is not set.
139    fn from_env() -> Self {
140        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
141        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
142
143        match base_url {
144            Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(),
145            None => Self::new(&api_key),
146        }
147    }
148
149    fn from_val(input: crate::client::ProviderValue) -> Self {
150        let crate::client::ProviderValue::Simple(api_key) = input else {
151            panic!("Incorrect provider value type")
152        };
153        Self::new(&api_key)
154    }
155}
156
157impl CompletionClient for Client {
158    type CompletionModel = super::responses_api::ResponsesCompletionModel;
159    /// Create a completion model with the given name.
160    ///
161    /// # Example
162    /// ```
163    /// use rig::providers::openai::{Client, self};
164    ///
165    /// // Initialize the OpenAI client
166    /// let openai = Client::new("your-open-ai-api-key");
167    ///
168    /// let gpt4 = openai.completion_model(openai::GPT_4);
169    /// ```
170    fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel {
171        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
172    }
173}
174
175impl EmbeddingsClient for Client {
176    type EmbeddingModel = EmbeddingModel;
177    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
178        let ndims = match model {
179            TEXT_EMBEDDING_3_LARGE => 3072,
180            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
181            _ => 0,
182        };
183        EmbeddingModel::new(self.clone(), model, ndims)
184    }
185
186    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
187        EmbeddingModel::new(self.clone(), model, ndims)
188    }
189}
190
191impl TranscriptionClient for Client {
192    type TranscriptionModel = TranscriptionModel;
193    /// Create a transcription model with the given name.
194    ///
195    /// # Example
196    /// ```
197    /// use rig::providers::openai::{Client, self};
198    ///
199    /// // Initialize the OpenAI client
200    /// let openai = Client::new("your-open-ai-api-key");
201    ///
202    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
203    /// ```
204    fn transcription_model(&self, model: &str) -> TranscriptionModel {
205        TranscriptionModel::new(self.clone(), model)
206    }
207}
208
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create an image generation model for the given model name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
227
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client {
    type AudioGenerationModel = AudioGenerationModel;

    /// Create an audio generation model for the given model name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();
        AudioGenerationModel::new(client, model)
    }
}
246
247impl VerifyClient for Client {
248    #[cfg_attr(feature = "worker", worker::send)]
249    async fn verify(&self) -> Result<(), VerifyError> {
250        let response = self.get("/models").send().await?;
251        match response.status() {
252            reqwest::StatusCode::OK => Ok(()),
253            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
254            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
255                Err(VerifyError::ProviderError(response.text().await?))
256            }
257            _ => {
258                response.error_for_status()?;
259                Ok(())
260            }
261        }
262    }
263}
264
265#[derive(Debug, Deserialize)]
266pub struct ApiErrorResponse {
267    pub(crate) message: String,
268}
269
270#[derive(Debug, Deserialize)]
271#[serde(untagged)]
272pub(crate) enum ApiResponse<T> {
273    Ok(T),
274    Err(ApiErrorResponse),
275}
276
#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    /// Deserialize an OpenAI `Message` from `json`.
    ///
    /// On failure, panics with the failing JSON path and line/column so a broken fixture is
    /// easy to locate.
    fn parse_message(json: &str) -> Message {
        let mut de = serde_json::Deserializer::from_str(json);
        deserialize(&mut de).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // Parse every fixture up front, then inspect each one.
        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        // Plain-string content is normalized into a single text part.
        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        // A `null` tool_calls field becomes an empty list.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        // Tool-call-only message: `null` content is empty and arguments are parsed as JSON.
        let Message::Assistant {
            content,
            tool_calls,
            refusal,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert!(refusal.is_none());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        // Multi-part user content: check the text part and the image part, in order.
        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0].clone(),
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Round-trip back to the generic message type must be lossless.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Round-trip back to the OpenAI message type must be lossless.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}