rig/providers/openai/client.rs

#[cfg(feature = "audio")]
use super::audio_generation::AudioGenerationModel;
use super::embedding::{
    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
};
use std::fmt::Debug;

#[cfg(feature = "image")]
use super::image_generation::ImageGenerationModel;
use super::transcription::TranscriptionModel;

use crate::{
    client::{
        CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
        VerifyError,
    },
    extractor::ExtractorBuilder,
    http_client::{self, HttpClientExt},
    providers::openai::CompletionModel,
};

#[cfg(feature = "audio")]
use crate::client::AudioGenerationClient;
#[cfg(feature = "image")]
use crate::client::ImageGenerationClient;

use bytes::Bytes;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// ================================================================
// Main OpenAI Client
// ================================================================
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";

pub struct ClientBuilder<'a, T = reqwest::Client> {
    api_key: &'a str,
    base_url: &'a str,
    http_client: T,
}

impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
    pub fn new(api_key: &'a str) -> Self {
        Self {
            api_key,
            base_url: OPENAI_API_BASE_URL,
            http_client: Default::default(),
        }
    }
}

impl<'a, T> ClientBuilder<'a, T> {
    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
        ClientBuilder {
            api_key,
            base_url: OPENAI_API_BASE_URL,
            http_client,
        }
    }

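    /// Override the API base URL (defaults to `https://api.openai.com/v1`), e.g. to point at
    /// an OpenAI-compatible proxy or gateway.
    ///
    /// A minimal sketch; the proxy URL below is a placeholder, not a real endpoint:
    /// ```
    /// use rig::providers::openai::Client;
    ///
    /// let openai_client = Client::builder("your-open-ai-api-key")
    ///     .base_url("https://my-openai-proxy.example.com/v1")
    ///     .build();
    /// ```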
    pub fn base_url(mut self, base_url: &'a str) -> Self {
        self.base_url = base_url;
        self
    }

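    /// Replace the HTTP client used by the builder, switching its type parameter if needed.
    ///
    /// A minimal sketch with a preconfigured `reqwest` client (the timeout value is illustrative):
    /// ```
    /// use rig::providers::openai::Client;
    ///
    /// let http_client = reqwest::Client::builder()
    ///     .timeout(std::time::Duration::from_secs(30))
    ///     .build()
    ///     .unwrap();
    ///
    /// let openai_client = Client::builder("your-open-ai-api-key")
    ///     .with_client(http_client)
    ///     .build();
    /// ```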
    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
        ClientBuilder {
            api_key: self.api_key,
            base_url: self.base_url,
            http_client,
        }
    }

    pub fn build(self) -> Client<T> {
        Client {
            base_url: self.base_url.to_string(),
            api_key: self.api_key.to_string(),
            http_client: self.http_client,
        }
    }
}

#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    base_url: String,
    api_key: String,
    pub(crate) http_client: T,
}

impl<T> Debug for Client<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}

impl Client<reqwest::Client> {
    /// Create a new OpenAI client builder.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::Client;
    ///
    /// // Initialize the OpenAI client
    /// let openai_client = Client::builder("your-open-ai-api-key")
    ///     .build();
    /// ```
    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
        ClientBuilder::new(api_key)
    }

    /// Create a new OpenAI client. For more control, use the `builder` method.
    pub fn new(api_key: &str) -> Self {
        Self::builder(api_key).build()
    }

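    /// Create a new OpenAI client from environment variables
    /// (see [`ProviderClient::from_env`] for the variables that are read).
    ///
    /// A minimal sketch, marked `no_run` because it panics if `OPENAI_API_KEY` is not set:
    /// ```no_run
    /// use rig::providers::openai::Client;
    ///
    /// let openai_client = Client::from_env();
    /// ```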
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }
}

impl<T> Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
    }

    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
    }

    pub(crate) async fn send<U, R>(
        &self,
        req: http_client::Request<U>,
    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
    where
        U: Into<Bytes> + Send,
        R: From<Bytes> + Send + 'static,
    {
        self.http_client.send(req).await
    }

    /// Create an extractor builder with the given completion model, using the
    /// Chat Completions API rather than the Responses API. Useful for running
    /// extractors against Chat Completions-compatible providers.
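    ///
    /// A minimal sketch, assuming a target type deriving `Deserialize`, `Serialize` and
    /// `JsonSchema`, and a hypothetical model name:
    /// ```ignore
    /// use rig::providers::openai::Client;
    ///
    /// #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
    /// struct Person {
    ///     name: String,
    /// }
    ///
    /// let openai = Client::new("your-open-ai-api-key");
    /// let extractor = openai
    ///     .extractor_completions_api::<Person>("gpt-4o")
    ///     .build();
    /// ```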
    pub fn extractor_completions_api<U>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<CompletionModel<T>, U>
    where
        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
        CompletionModel<T>: crate::completion::CompletionModel,
    {
        ExtractorBuilder::new(self.completion_model(model).completions_api())
    }
}

impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
    /// If `OPENAI_BASE_URL` is also set, it overrides the default base URL.
    /// Panics if `OPENAI_API_KEY` is not set.
    fn from_env() -> Self {
        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");

        match base_url {
            Some(url) => ClientBuilder::<T>::new(&api_key).base_url(&url).build(),
            None => ClientBuilder::<T>::new(&api_key).build(),
        }
    }

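    /// Create a new OpenAI client from a [`crate::client::ProviderValue`].
    /// Panics if the value is not a simple API key.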
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };

        ClientBuilder::<T>::new(&api_key).build()
    }
}

impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + std::fmt::Debug + Clone + Default + Send + 'static,
{
    type CompletionModel = super::responses_api::ResponsesCompletionModel<T>;
    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::CompletionClient;
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let gpt4 = openai.completion_model(openai::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
    }
}

impl<T> EmbeddingsClient for Client<T>
where
    T: HttpClientExt + std::fmt::Debug + Clone + Default + Send + 'static,
{
    type EmbeddingModel = EmbeddingModel<T>;
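    /// Create an embedding model with the given name. Known models get their default
    /// dimension count; unknown models default to 0 dimensions, in which case
    /// `embedding_model_with_ndims` should be used instead.
    ///
    /// # Example
    /// ```
    /// use rig::client::EmbeddingsClient;
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
    /// ```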
    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}

impl<T> TranscriptionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type TranscriptionModel = TranscriptionModel<T>;
    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::TranscriptionClient;
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let whisper = openai.transcription_model(openai::WHISPER_1);
    /// ```
    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}

#[cfg(feature = "image")]
impl<T> ImageGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type ImageGenerationModel = ImageGenerationModel<T>;
    /// Create an image generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dall_e = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}

#[cfg(feature = "audio")]
impl<T> AudioGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Send + Default + 'static,
{
    type AudioGenerationModel = AudioGenerationModel<T>;
    /// Create an audio generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::AudioGenerationClient;
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        AudioGenerationModel::new(self.clone(), model)
    }
}

impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Send + Default + 'static,
{
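    /// Verify that the configured API key is accepted by listing the available models.
    ///
    /// A minimal sketch, marked `no_run` because it performs a network request:
    /// ```no_run
    /// use rig::client::VerifyClient;
    /// use rig::providers::openai::Client;
    ///
    /// # async fn check() -> Result<(), rig::client::VerifyError> {
    /// let openai = Client::new("your-open-ai-api-key");
    /// openai.verify().await?;
    /// # Ok(())
    /// # }
    /// ```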
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("/models")?
            .body(http_client::NoBody)
            .map_err(|e| VerifyError::HttpError(e.into()))?;

        let response = self.send(req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => {
                //response.error_for_status()?;
                Ok(())
            }
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    pub(crate) message: String,
}

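/// Either the expected response payload or the provider's error body.
/// With `#[serde(untagged)]`, deserialization tries `Ok` first and falls back to `Err`.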
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

#[cfg(test)]
mod tests {
    use crate::message::ImageDetail;
    use crate::providers::openai::{
        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
    };
    use crate::{OneOrMany, message};
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}