async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    admin::Admin,
11    chatkit::Chatkit,
12    config::{Config, OpenAIConfig},
13    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14    file::Files,
15    image::Images,
16    moderation::Moderations,
17    traits::AsyncTryFrom,
18    Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19    FineTuning, Models, Responses, Threads, Uploads, Usage, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
25#[derive(Debug, Clone, Default)]
26/// Client is a container for config, backoff and http_client
27/// used to make API calls.
28pub struct Client<C: Config> {
29    http_client: reqwest::Client,
30    config: C,
31    backoff: backoff::ExponentialBackoff,
32}
33
34impl Client<OpenAIConfig> {
35    /// Client with default [OpenAIConfig]
36    pub fn new() -> Self {
37        Self::default()
38    }
39}
40
41impl<C: Config> Client<C> {
42    /// Create client with a custom HTTP client, OpenAI config, and backoff.
43    pub fn build(
44        http_client: reqwest::Client,
45        config: C,
46        backoff: backoff::ExponentialBackoff,
47    ) -> Self {
48        Self {
49            http_client,
50            config,
51            backoff,
52        }
53    }
54
55    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
56    pub fn with_config(config: C) -> Self {
57        Self {
58            http_client: reqwest::Client::new(),
59            config,
60            backoff: Default::default(),
61        }
62    }
63
64    /// Provide your own [client] to make HTTP requests with.
65    ///
66    /// [client]: reqwest::Client
67    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68        self.http_client = http_client;
69        self
70    }
71
72    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
73    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74        self.backoff = backoff;
75        self
76    }
77
78    // API groups
79
80    /// To call [Models] group related APIs using this client.
81    pub fn models(&self) -> Models<'_, C> {
82        Models::new(self)
83    }
84
85    /// To call [Completions] group related APIs using this client.
86    pub fn completions(&self) -> Completions<'_, C> {
87        Completions::new(self)
88    }
89
90    /// To call [Chat] group related APIs using this client.
91    pub fn chat(&self) -> Chat<'_, C> {
92        Chat::new(self)
93    }
94
95    /// To call [Images] group related APIs using this client.
96    pub fn images(&self) -> Images<'_, C> {
97        Images::new(self)
98    }
99
100    /// To call [Moderations] group related APIs using this client.
101    pub fn moderations(&self) -> Moderations<'_, C> {
102        Moderations::new(self)
103    }
104
105    /// To call [Files] group related APIs using this client.
106    pub fn files(&self) -> Files<'_, C> {
107        Files::new(self)
108    }
109
110    /// To call [Uploads] group related APIs using this client.
111    pub fn uploads(&self) -> Uploads<'_, C> {
112        Uploads::new(self)
113    }
114
115    /// To call [FineTuning] group related APIs using this client.
116    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
117        FineTuning::new(self)
118    }
119
120    /// To call [Embeddings] group related APIs using this client.
121    pub fn embeddings(&self) -> Embeddings<'_, C> {
122        Embeddings::new(self)
123    }
124
125    /// To call [Audio] group related APIs using this client.
126    pub fn audio(&self) -> Audio<'_, C> {
127        Audio::new(self)
128    }
129
130    /// To call [Videos] group related APIs using this client.
131    pub fn videos(&self) -> Videos<'_, C> {
132        Videos::new(self)
133    }
134
135    /// To call [Assistants] group related APIs using this client.
136    pub fn assistants(&self) -> Assistants<'_, C> {
137        Assistants::new(self)
138    }
139
140    /// To call [Threads] group related APIs using this client.
141    pub fn threads(&self) -> Threads<'_, C> {
142        Threads::new(self)
143    }
144
145    /// To call [VectorStores] group related APIs using this client.
146    pub fn vector_stores(&self) -> VectorStores<'_, C> {
147        VectorStores::new(self)
148    }
149
150    /// To call [Batches] group related APIs using this client.
151    pub fn batches(&self) -> Batches<'_, C> {
152        Batches::new(self)
153    }
154
155    /// To call [Admin] group related APIs using this client.
156    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
157    pub fn admin(&self) -> Admin<'_, C> {
158        Admin::new(self)
159    }
160
161    /// To call [Usage] group related APIs using this client.
162    pub fn usage(&self) -> Usage<'_, C> {
163        Usage::new(self)
164    }
165
166    /// To call [Responses] group related APIs using this client.
167    pub fn responses(&self) -> Responses<'_, C> {
168        Responses::new(self)
169    }
170
171    /// To call [Conversations] group related APIs using this client.
172    pub fn conversations(&self) -> Conversations<'_, C> {
173        Conversations::new(self)
174    }
175
176    /// To call [Containers] group related APIs using this client.
177    pub fn containers(&self) -> Containers<'_, C> {
178        Containers::new(self)
179    }
180
181    /// To call [Evals] group related APIs using this client.
182    pub fn evals(&self) -> Evals<'_, C> {
183        Evals::new(self)
184    }
185
186    pub fn chatkit(&self) -> Chatkit<'_, C> {
187        Chatkit::new(self)
188    }
189
190    #[cfg(feature = "realtime")]
191    /// To call [Realtime] group related APIs using this client.
192    pub fn realtime(&self) -> Realtime<'_, C> {
193        Realtime::new(self)
194    }
195
196    pub fn config(&self) -> &C {
197        &self.config
198    }
199
200    /// Make a GET request to {path} and deserialize the response body
201    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
202    where
203        O: DeserializeOwned,
204    {
205        let request_maker = || async {
206            Ok(self
207                .http_client
208                .get(self.config.url(path))
209                .query(&self.config.query())
210                .headers(self.config.headers())
211                .build()?)
212        };
213
214        self.execute(request_maker).await
215    }
216
217    /// Make a GET request to {path} with given Query and deserialize the response body
218    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
219    where
220        O: DeserializeOwned,
221        Q: Serialize + ?Sized,
222    {
223        let request_maker = || async {
224            Ok(self
225                .http_client
226                .get(self.config.url(path))
227                .query(&self.config.query())
228                .query(query)
229                .headers(self.config.headers())
230                .build()?)
231        };
232
233        self.execute(request_maker).await
234    }
235
236    /// Make a DELETE request to {path} and deserialize the response body
237    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
238    where
239        O: DeserializeOwned,
240    {
241        let request_maker = || async {
242            Ok(self
243                .http_client
244                .delete(self.config.url(path))
245                .query(&self.config.query())
246                .headers(self.config.headers())
247                .build()?)
248        };
249
250        self.execute(request_maker).await
251    }
252
253    /// Make a GET request to {path} and return the response body
254    pub(crate) async fn get_raw(&self, path: &str) -> Result<(Bytes, HeaderMap), OpenAIError> {
255        let request_maker = || async {
256            Ok(self
257                .http_client
258                .get(self.config.url(path))
259                .query(&self.config.query())
260                .headers(self.config.headers())
261                .build()?)
262        };
263
264        self.execute_raw(request_maker).await
265    }
266
267    pub(crate) async fn get_raw_with_query<Q>(
268        &self,
269        path: &str,
270        query: &Q,
271    ) -> Result<(Bytes, HeaderMap), OpenAIError>
272    where
273        Q: Serialize + ?Sized,
274    {
275        let request_maker = || async {
276            Ok(self
277                .http_client
278                .get(self.config.url(path))
279                .query(&self.config.query())
280                .query(query)
281                .headers(self.config.headers())
282                .build()?)
283        };
284
285        self.execute_raw(request_maker).await
286    }
287
288    /// Make a POST request to {path} and return the response body
289    pub(crate) async fn post_raw<I>(
290        &self,
291        path: &str,
292        request: I,
293    ) -> Result<(Bytes, HeaderMap), OpenAIError>
294    where
295        I: Serialize,
296    {
297        let request_maker = || async {
298            Ok(self
299                .http_client
300                .post(self.config.url(path))
301                .query(&self.config.query())
302                .headers(self.config.headers())
303                .json(&request)
304                .build()?)
305        };
306
307        self.execute_raw(request_maker).await
308    }
309
310    /// Make a POST request to {path} and deserialize the response body
311    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
312    where
313        I: Serialize,
314        O: DeserializeOwned,
315    {
316        let request_maker = || async {
317            Ok(self
318                .http_client
319                .post(self.config.url(path))
320                .query(&self.config.query())
321                .headers(self.config.headers())
322                .json(&request)
323                .build()?)
324        };
325
326        self.execute(request_maker).await
327    }
328
329    /// POST a form at {path} and return the response body
330    pub(crate) async fn post_form_raw<F>(
331        &self,
332        path: &str,
333        form: F,
334    ) -> Result<(Bytes, HeaderMap), OpenAIError>
335    where
336        Form: AsyncTryFrom<F, Error = OpenAIError>,
337        F: Clone,
338    {
339        let request_maker = || async {
340            Ok(self
341                .http_client
342                .post(self.config.url(path))
343                .query(&self.config.query())
344                .headers(self.config.headers())
345                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
346                .build()?)
347        };
348
349        self.execute_raw(request_maker).await
350    }
351
352    /// POST a form at {path} and deserialize the response body
353    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
354    where
355        O: DeserializeOwned,
356        Form: AsyncTryFrom<F, Error = OpenAIError>,
357        F: Clone,
358    {
359        let request_maker = || async {
360            Ok(self
361                .http_client
362                .post(self.config.url(path))
363                .query(&self.config.query())
364                .headers(self.config.headers())
365                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
366                .build()?)
367        };
368
369        self.execute(request_maker).await
370    }
371
372    pub(crate) async fn post_form_stream<O, F>(
373        &self,
374        path: &str,
375        form: F,
376    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
377    where
378        F: Clone,
379        Form: AsyncTryFrom<F, Error = OpenAIError>,
380        O: DeserializeOwned + std::marker::Send + 'static,
381    {
382        // Build and execute request manually since multipart::Form is not Clone
383        // and .eventsource() requires cloneability
384        let response = self
385            .http_client
386            .post(self.config.url(path))
387            .query(&self.config.query())
388            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
389            .headers(self.config.headers())
390            .send()
391            .await
392            .map_err(OpenAIError::Reqwest)?;
393
394        // Check for error status
395        if !response.status().is_success() {
396            return Err(read_response(response).await.unwrap_err());
397        }
398
399        // Convert response body to EventSource stream
400        let stream = response
401            .bytes_stream()
402            .map(|result| result.map_err(std::io::Error::other));
403        let event_stream = eventsource_stream::EventStream::new(stream);
404
405        // Convert EventSource stream to our expected format
406        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
407
408        tokio::spawn(async move {
409            use futures::StreamExt;
410            let mut event_stream = std::pin::pin!(event_stream);
411
412            while let Some(event_result) = event_stream.next().await {
413                match event_result {
414                    Err(e) => {
415                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
416                            StreamError::EventStream(e.to_string()),
417                        )))) {
418                            break;
419                        }
420                    }
421                    Ok(event) => {
422                        // eventsource_stream::Event is a struct with data field
423                        if event.data == "[DONE]" {
424                            break;
425                        }
426
427                        let response = match serde_json::from_str::<O>(&event.data) {
428                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
429                            Ok(output) => Ok(output),
430                        };
431
432                        if let Err(_e) = tx.send(response) {
433                            break;
434                        }
435                    }
436                }
437            }
438        });
439
440        Ok(Box::pin(
441            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
442        ))
443    }
444
445    /// Execute a HTTP request and retry on rate limit
446    ///
447    /// request_maker serves one purpose: to be able to create request again
448    /// to retry API call after getting rate limited. request_maker is async because
449    /// reqwest::multipart::Form is created by async calls to read files for uploads.
450    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
451    where
452        M: Fn() -> Fut,
453        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
454    {
455        let client = self.http_client.clone();
456
457        backoff::future::retry(self.backoff.clone(), || async {
458            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
459            let response = client
460                .execute(request)
461                .await
462                .map_err(OpenAIError::Reqwest)
463                .map_err(backoff::Error::Permanent)?;
464
465            let status = response.status();
466
467            match read_response(response).await {
468                Ok((bytes, headers)) => Ok((bytes, headers)),
469                Err(e) => {
470                    match e {
471                        OpenAIError::ApiError(api_error) => {
472                            if status.is_server_error() {
473                                Err(backoff::Error::Transient {
474                                    err: OpenAIError::ApiError(api_error),
475                                    retry_after: None,
476                                })
477                            } else if status.as_u16() == 429
478                                && api_error.r#type != Some("insufficient_quota".to_string())
479                            {
480                                // Rate limited retry...
481                                tracing::warn!("Rate limited: {}", api_error.message);
482                                Err(backoff::Error::Transient {
483                                    err: OpenAIError::ApiError(api_error),
484                                    retry_after: None,
485                                })
486                            } else {
487                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
488                            }
489                        }
490                        _ => Err(backoff::Error::Permanent(e)),
491                    }
492                }
493            }
494        })
495        .await
496    }
497
498    /// Execute a HTTP request and retry on rate limit
499    ///
500    /// request_maker serves one purpose: to be able to create request again
501    /// to retry API call after getting rate limited. request_maker is async because
502    /// reqwest::multipart::Form is created by async calls to read files for uploads.
503    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
504    where
505        O: DeserializeOwned,
506        M: Fn() -> Fut,
507        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
508    {
509        let (bytes, _headers) = self.execute_raw(request_maker).await?;
510
511        let response: O = serde_json::from_slice(bytes.as_ref())
512            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
513
514        Ok(response)
515    }
516
517    /// Make HTTP POST request to receive SSE
518    pub(crate) async fn post_stream<I, O>(
519        &self,
520        path: &str,
521        request: I,
522    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
523    where
524        I: Serialize,
525        O: DeserializeOwned + std::marker::Send + 'static,
526    {
527        let event_source = self
528            .http_client
529            .post(self.config.url(path))
530            .query(&self.config.query())
531            .headers(self.config.headers())
532            .json(&request)
533            .eventsource()
534            .unwrap();
535
536        stream(event_source).await
537    }
538
539    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
540        &self,
541        path: &str,
542        request: I,
543        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
544    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
545    where
546        I: Serialize,
547        O: DeserializeOwned + std::marker::Send + 'static,
548    {
549        let event_source = self
550            .http_client
551            .post(self.config.url(path))
552            .query(&self.config.query())
553            .headers(self.config.headers())
554            .json(&request)
555            .eventsource()
556            .unwrap();
557
558        stream_mapped_raw_events(event_source, event_mapper).await
559    }
560
561    /// Make HTTP GET request to receive SSE
562    pub(crate) async fn _get_stream<Q, O>(
563        &self,
564        path: &str,
565        query: &Q,
566    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
567    where
568        Q: Serialize + ?Sized,
569        O: DeserializeOwned + std::marker::Send + 'static,
570    {
571        let event_source = self
572            .http_client
573            .get(self.config.url(path))
574            .query(query)
575            .query(&self.config.query())
576            .headers(self.config.headers())
577            .eventsource()
578            .unwrap();
579
580        stream(event_source).await
581    }
582}
583
584async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
585    let status = response.status();
586    let headers = response.headers().clone();
587    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
588
589    if status.is_server_error() {
590        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
591        let message: String = String::from_utf8_lossy(&bytes).into_owned();
592        tracing::warn!("Server error: {status} - {message}");
593        return Err(OpenAIError::ApiError(ApiError {
594            message,
595            r#type: None,
596            param: None,
597            code: None,
598        }));
599    }
600
601    // Deserialize response body from either error object or actual response object
602    if !status.is_success() {
603        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
604            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
605
606        return Err(OpenAIError::ApiError(wrapped_error.error));
607    }
608
609    Ok((bytes, headers))
610}
611
612async fn map_stream_error(value: EventSourceError) -> OpenAIError {
613    match value {
614        EventSourceError::InvalidStatusCode(status_code, response) => {
615            read_response(response).await.expect_err(&format!(
616                "Unreachable because read_response returns err when status_code {status_code} is invalid"
617            ))
618        }
619        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
620    }
621}
622
623/// Request which responds with SSE.
624/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
625pub(crate) async fn stream<O>(
626    mut event_source: EventSource,
627) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
628where
629    O: DeserializeOwned + std::marker::Send + 'static,
630{
631    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
632
633    tokio::spawn(async move {
634        while let Some(ev) = event_source.next().await {
635            match ev {
636                Err(e) => {
637                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
638                        // rx dropped
639                        break;
640                    }
641                }
642                Ok(event) => match event {
643                    Event::Message(message) => {
644                        if message.data == "[DONE]" {
645                            break;
646                        }
647
648                        let response = match serde_json::from_str::<O>(&message.data) {
649                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
650                            Ok(output) => Ok(output),
651                        };
652
653                        if let Err(_e) = tx.send(response) {
654                            // rx dropped
655                            break;
656                        }
657                    }
658                    Event::Open => continue,
659                },
660            }
661        }
662
663        event_source.close();
664    });
665
666    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
667}
668
669pub(crate) async fn stream_mapped_raw_events<O>(
670    mut event_source: EventSource,
671    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
672) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
673where
674    O: DeserializeOwned + std::marker::Send + 'static,
675{
676    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
677
678    tokio::spawn(async move {
679        while let Some(ev) = event_source.next().await {
680            match ev {
681                Err(e) => {
682                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
683                        // rx dropped
684                        break;
685                    }
686                }
687                Ok(event) => match event {
688                    Event::Message(message) => {
689                        let mut done = false;
690
691                        if message.data == "[DONE]" {
692                            done = true;
693                        }
694
695                        let response = event_mapper(message);
696
697                        if let Err(_e) = tx.send(response) {
698                            // rx dropped
699                            break;
700                        }
701
702                        if done {
703                            break;
704                        }
705                    }
706                    Event::Open => continue,
707                },
708            }
709        }
710
711        event_source.close();
712    });
713
714    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
715}