rig/providers/openai/
client.rs

1#[cfg(feature = "audio")]
2use super::audio_generation::AudioGenerationModel;
3use super::embedding::{
4    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
5};
6use std::fmt::Debug;
7
8#[cfg(feature = "image")]
9use super::image_generation::ImageGenerationModel;
10use super::transcription::TranscriptionModel;
11
12use crate::{
13    client::{
14        CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
15        VerifyError,
16    },
17    extractor::ExtractorBuilder,
18    http_client::{self, HttpClientExt},
19    providers::openai::CompletionModel,
20};
21
22#[cfg(feature = "audio")]
23use crate::client::AudioGenerationClient;
24#[cfg(feature = "image")]
25use crate::client::ImageGenerationClient;
26
27use bytes::Bytes;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30
31// ================================================================
32// Main OpenAI Client
33// ================================================================
34const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
35
36pub struct ClientBuilder<'a, T = reqwest::Client> {
37    api_key: &'a str,
38    base_url: &'a str,
39    http_client: T,
40}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44    T: Default,
45{
46    pub fn new(api_key: &'a str) -> Self {
47        Self {
48            api_key,
49            base_url: OPENAI_API_BASE_URL,
50            http_client: Default::default(),
51        }
52    }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56    pub fn base_url(mut self, base_url: &'a str) -> Self {
57        self.base_url = base_url;
58        self
59    }
60
61    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
62        ClientBuilder {
63            api_key: self.api_key,
64            base_url: self.base_url,
65            http_client,
66        }
67    }
68    pub fn build(self) -> Client<T> {
69        Client {
70            base_url: self.base_url.to_string(),
71            api_key: self.api_key.to_string(),
72            http_client: self.http_client,
73        }
74    }
75}
76
77#[derive(Clone)]
78pub struct Client<T = reqwest::Client> {
79    base_url: String,
80    api_key: String,
81    http_client: T,
82}
83
84impl<T> Debug for Client<T>
85where
86    T: Debug,
87{
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("Client")
90            .field("base_url", &self.base_url)
91            .field("http_client", &self.http_client)
92            .field("api_key", &"<REDACTED>")
93            .finish()
94    }
95}
96
97impl<T> Client<T>
98where
99    T: Default,
100{
101    /// Create a new OpenAI client builder.
102    ///
103    /// # Example
104    /// ```
105    /// use rig::providers::openai::{ClientBuilder, self};
106    ///
107    /// // Initialize the OpenAI client
108    /// let openai_client = Client::builder("your-open-ai-api-key")
109    ///    .build()
110    /// ```
111    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
112        ClientBuilder::new(api_key)
113    }
114
115    /// Create a new OpenAI client. For more control, use the `builder` method.
116    ///
117    pub fn new(api_key: &str) -> Self {
118        Self::builder(api_key).build()
119    }
120}
121
122impl<T> Client<T>
123where
124    T: HttpClientExt,
125{
126    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
127        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
128
129        http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
130    }
131
132    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
133        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
134
135        http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
136    }
137
138    pub(crate) async fn send<U, R>(
139        &self,
140        req: http_client::Request<U>,
141    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
142    where
143        U: Into<Bytes> + Send,
144        R: From<Bytes> + Send + 'static,
145    {
146        self.http_client.send(req).await
147    }
148}
149
150impl Client<reqwest::Client> {
151    pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder {
152        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
153
154        self.http_client.post(url).bearer_auth(&self.api_key)
155    }
156
157    /// Create an extractor builder with the given completion model.
158    /// Intended for use exclusively with the Chat Completions API.
159    /// Useful for using extractors with Chat Completion compliant APIs.
160    pub fn extractor_completions_api<U>(
161        &self,
162        model: &str,
163    ) -> ExtractorBuilder<CompletionModel<reqwest::Client>, U>
164    where
165        U: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
166    {
167        ExtractorBuilder::new(self.completion_model(model).completions_api())
168    }
169}
170
171impl Client<reqwest::Client> {}
172
173impl ProviderClient for Client<reqwest::Client> {
174    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
175    /// Panics if the environment variable is not set.
176    fn from_env() -> Self {
177        let base_url: Option<String> = std::env::var("OPENAI_BASE_URL").ok();
178        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
179
180        match base_url {
181            Some(url) => Self::builder(&api_key).base_url(&url).build(),
182            None => Self::new(&api_key),
183        }
184    }
185
186    fn from_val(input: crate::client::ProviderValue) -> Self {
187        let crate::client::ProviderValue::Simple(api_key) = input else {
188            panic!("Incorrect provider value type")
189        };
190        Self::new(&api_key)
191    }
192}
193
194impl CompletionClient for Client<reqwest::Client> {
195    type CompletionModel = super::responses_api::ResponsesCompletionModel<reqwest::Client>;
196    /// Create a completion model with the given name.
197    ///
198    /// # Example
199    /// ```
200    /// use rig::providers::openai::{Client, self};
201    ///
202    /// // Initialize the OpenAI client
203    /// let openai = Client::new("your-open-ai-api-key");
204    ///
205    /// let gpt4 = openai.completion_model(openai::GPT_4);
206    /// ```
207    fn completion_model(
208        &self,
209        model: &str,
210    ) -> super::responses_api::ResponsesCompletionModel<reqwest::Client> {
211        super::responses_api::ResponsesCompletionModel::new(self.clone(), model)
212    }
213}
214
215impl EmbeddingsClient for Client<reqwest::Client> {
216    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
217    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
218        let ndims = match model {
219            TEXT_EMBEDDING_3_LARGE => 3072,
220            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
221            _ => 0,
222        };
223        EmbeddingModel::new(self.clone(), model, ndims)
224    }
225
226    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
227        EmbeddingModel::new(self.clone(), model, ndims)
228    }
229}
230
231impl TranscriptionClient for Client<reqwest::Client> {
232    type TranscriptionModel = TranscriptionModel<reqwest::Client>;
233    /// Create a transcription model with the given name.
234    ///
235    /// # Example
236    /// ```
237    /// use rig::providers::openai::{Client, self};
238    ///
239    /// // Initialize the OpenAI client
240    /// let openai = Client::new("your-open-ai-api-key");
241    ///
242    /// let gpt4 = openai.transcription_model(openai::WHISPER_1);
243    /// ```
244    fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
245        TranscriptionModel::new(self.clone(), model)
246    }
247}
248
#[cfg(feature = "image")]
impl ImageGenerationClient for Client<reqwest::Client> {
    type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

    /// Create an image generation model with the given name.
    ///
    /// The returned model holds its own clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();

        ImageGenerationModel::new(client, model)
    }
}
267
#[cfg(feature = "audio")]
impl AudioGenerationClient for Client<reqwest::Client> {
    type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;

    /// Create an audio generation model with the given name.
    ///
    /// The returned model holds its own clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::providers::openai::{Client, self};
    ///
    /// // Initialize the OpenAI client
    /// let openai = Client::new("your-open-ai-api-key");
    ///
    /// let tts = openai.audio_generation_model(openai::TTS_1);
    /// ```
    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        let client = self.clone();

        AudioGenerationModel::new(client, model)
    }
}
286
287impl VerifyClient for Client<reqwest::Client> {
288    #[cfg_attr(feature = "worker", worker::send)]
289    async fn verify(&self) -> Result<(), VerifyError> {
290        let req = self
291            .get("/models")?
292            .body(http_client::NoBody)
293            .map_err(|e| VerifyError::HttpError(e.into()))?;
294
295        let response = self.send(req).await?;
296
297        match response.status() {
298            reqwest::StatusCode::OK => Ok(()),
299            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
300            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
301                let text = http_client::text(response).await?;
302                Err(VerifyError::ProviderError(text))
303            }
304            _ => {
305                //response.error_for_status()?;
306                Ok(())
307            }
308        }
309    }
310}
311
312#[derive(Debug, Deserialize)]
313pub struct ApiErrorResponse {
314    pub(crate) message: String,
315}
316
317#[derive(Debug, Deserialize)]
318#[serde(untagged)]
319pub(crate) enum ApiResponse<T> {
320    Ok(T),
321    Err(ApiErrorResponse),
322}
323
324#[cfg(test)]
325mod tests {
326    use crate::message::ImageDetail;
327    use crate::providers::openai::{
328        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
329    };
330    use crate::{OneOrMany, message};
331    use serde_path_to_error::deserialize;
332
333    #[test]
334    fn test_deserialize_message() {
335        let assistant_message_json = r#"
336        {
337            "role": "assistant",
338            "content": "\n\nHello there, how may I assist you today?"
339        }
340        "#;
341
342        let assistant_message_json2 = r#"
343        {
344            "role": "assistant",
345            "content": [
346                {
347                    "type": "text",
348                    "text": "\n\nHello there, how may I assist you today?"
349                }
350            ],
351            "tool_calls": null
352        }
353        "#;
354
355        let assistant_message_json3 = r#"
356        {
357            "role": "assistant",
358            "tool_calls": [
359                {
360                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
361                    "type": "function",
362                    "function": {
363                        "name": "subtract",
364                        "arguments": "{\"x\": 2, \"y\": 5}"
365                    }
366                }
367            ],
368            "content": null,
369            "refusal": null
370        }
371        "#;
372
373        let user_message_json = r#"
374        {
375            "role": "user",
376            "content": [
377                {
378                    "type": "text",
379                    "text": "What's in this image?"
380                },
381                {
382                    "type": "image_url",
383                    "image_url": {
384                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
385                    }
386                },
387                {
388                    "type": "audio",
389                    "input_audio": {
390                        "data": "...",
391                        "format": "mp3"
392                    }
393                }
394            ]
395        }
396        "#;
397
398        let assistant_message: Message = {
399            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
400            deserialize(jd).unwrap_or_else(|err| {
401                panic!(
402                    "Deserialization error at {} ({}:{}): {}",
403                    err.path(),
404                    err.inner().line(),
405                    err.inner().column(),
406                    err
407                );
408            })
409        };
410
411        let assistant_message2: Message = {
412            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
413            deserialize(jd).unwrap_or_else(|err| {
414                panic!(
415                    "Deserialization error at {} ({}:{}): {}",
416                    err.path(),
417                    err.inner().line(),
418                    err.inner().column(),
419                    err
420                );
421            })
422        };
423
424        let assistant_message3: Message = {
425            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
426                &mut serde_json::Deserializer::from_str(assistant_message_json3);
427            deserialize(jd).unwrap_or_else(|err| {
428                panic!(
429                    "Deserialization error at {} ({}:{}): {}",
430                    err.path(),
431                    err.inner().line(),
432                    err.inner().column(),
433                    err
434                );
435            })
436        };
437
438        let user_message: Message = {
439            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
440            deserialize(jd).unwrap_or_else(|err| {
441                panic!(
442                    "Deserialization error at {} ({}:{}): {}",
443                    err.path(),
444                    err.inner().line(),
445                    err.inner().column(),
446                    err
447                );
448            })
449        };
450
451        match assistant_message {
452            Message::Assistant { content, .. } => {
453                assert_eq!(
454                    content[0],
455                    AssistantContent::Text {
456                        text: "\n\nHello there, how may I assist you today?".to_string()
457                    }
458                );
459            }
460            _ => panic!("Expected assistant message"),
461        }
462
463        match assistant_message2 {
464            Message::Assistant {
465                content,
466                tool_calls,
467                ..
468            } => {
469                assert_eq!(
470                    content[0],
471                    AssistantContent::Text {
472                        text: "\n\nHello there, how may I assist you today?".to_string()
473                    }
474                );
475
476                assert_eq!(tool_calls, vec![]);
477            }
478            _ => panic!("Expected assistant message"),
479        }
480
481        match assistant_message3 {
482            Message::Assistant {
483                content,
484                tool_calls,
485                refusal,
486                ..
487            } => {
488                assert!(content.is_empty());
489                assert!(refusal.is_none());
490                assert_eq!(
491                    tool_calls[0],
492                    ToolCall {
493                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
494                        r#type: ToolType::Function,
495                        function: Function {
496                            name: "subtract".to_string(),
497                            arguments: serde_json::json!({"x": 2, "y": 5}),
498                        },
499                    }
500                );
501            }
502            _ => panic!("Expected assistant message"),
503        }
504
505        match user_message {
506            Message::User { content, .. } => {
507                let (first, second) = {
508                    let mut iter = content.into_iter();
509                    (iter.next().unwrap(), iter.next().unwrap())
510                };
511                assert_eq!(
512                    first,
513                    UserContent::Text {
514                        text: "What's in this image?".to_string()
515                    }
516                );
517                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
518            }
519            _ => panic!("Expected user message"),
520        }
521    }
522
523    #[test]
524    fn test_message_to_message_conversion() {
525        let user_message = message::Message::User {
526            content: OneOrMany::one(message::UserContent::text("Hello")),
527        };
528
529        let assistant_message = message::Message::Assistant {
530            id: None,
531            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
532        };
533
534        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
535        let converted_assistant_message: Vec<Message> =
536            assistant_message.clone().try_into().unwrap();
537
538        match converted_user_message[0].clone() {
539            Message::User { content, .. } => {
540                assert_eq!(
541                    content.first(),
542                    UserContent::Text {
543                        text: "Hello".to_string()
544                    }
545                );
546            }
547            _ => panic!("Expected user message"),
548        }
549
550        match converted_assistant_message[0].clone() {
551            Message::Assistant { content, .. } => {
552                assert_eq!(
553                    content[0].clone(),
554                    AssistantContent::Text {
555                        text: "Hi there!".to_string()
556                    }
557                );
558            }
559            _ => panic!("Expected assistant message"),
560        }
561
562        let original_user_message: message::Message =
563            converted_user_message[0].clone().try_into().unwrap();
564        let original_assistant_message: message::Message =
565            converted_assistant_message[0].clone().try_into().unwrap();
566
567        assert_eq!(original_user_message, user_message);
568        assert_eq!(original_assistant_message, assistant_message);
569    }
570
571    #[test]
572    fn test_message_from_message_conversion() {
573        let user_message = Message::User {
574            content: OneOrMany::one(UserContent::Text {
575                text: "Hello".to_string(),
576            }),
577            name: None,
578        };
579
580        let assistant_message = Message::Assistant {
581            content: vec![AssistantContent::Text {
582                text: "Hi there!".to_string(),
583            }],
584            refusal: None,
585            audio: None,
586            name: None,
587            tool_calls: vec![],
588        };
589
590        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
591        let converted_assistant_message: message::Message =
592            assistant_message.clone().try_into().unwrap();
593
594        match converted_user_message.clone() {
595            message::Message::User { content } => {
596                assert_eq!(content.first(), message::UserContent::text("Hello"));
597            }
598            _ => panic!("Expected user message"),
599        }
600
601        match converted_assistant_message.clone() {
602            message::Message::Assistant { content, .. } => {
603                assert_eq!(
604                    content.first(),
605                    message::AssistantContent::text("Hi there!")
606                );
607            }
608            _ => panic!("Expected assistant message"),
609        }
610
611        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
612        let original_assistant_message: Vec<Message> =
613            converted_assistant_message.try_into().unwrap();
614
615        assert_eq!(original_user_message[0], user_message);
616        assert_eq!(original_assistant_message[0], assistant_message);
617    }
618}