async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    admin::Admin,
11    chatkit::Chatkit,
12    config::{Config, OpenAIConfig},
13    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14    file::Files,
15    image::Images,
16    moderation::Moderations,
17    traits::AsyncTryFrom,
18    Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19    FineTuning, Models, RequestOptions, Responses, Threads, Uploads, Usage, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
25#[derive(Debug, Clone, Default)]
26/// Client is a container for config, backoff and http_client
27/// used to make API calls.
28pub struct Client<C: Config> {
29    http_client: reqwest::Client,
30    config: C,
31    backoff: backoff::ExponentialBackoff,
32}
33
34impl Client<OpenAIConfig> {
35    /// Client with default [OpenAIConfig]
36    pub fn new() -> Self {
37        Self::default()
38    }
39}
40
41impl<C: Config> Client<C> {
42    /// Create client with a custom HTTP client, OpenAI config, and backoff.
43    pub fn build(
44        http_client: reqwest::Client,
45        config: C,
46        backoff: backoff::ExponentialBackoff,
47    ) -> Self {
48        Self {
49            http_client,
50            config,
51            backoff,
52        }
53    }
54
55    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
56    pub fn with_config(config: C) -> Self {
57        Self {
58            http_client: reqwest::Client::new(),
59            config,
60            backoff: Default::default(),
61        }
62    }
63
64    /// Provide your own [client] to make HTTP requests with.
65    ///
66    /// [client]: reqwest::Client
67    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68        self.http_client = http_client;
69        self
70    }
71
72    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
73    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74        self.backoff = backoff;
75        self
76    }
77
78    // API groups
79
80    /// To call [Models] group related APIs using this client.
81    pub fn models(&self) -> Models<'_, C> {
82        Models::new(self)
83    }
84
85    /// To call [Completions] group related APIs using this client.
86    pub fn completions(&self) -> Completions<'_, C> {
87        Completions::new(self)
88    }
89
90    /// To call [Chat] group related APIs using this client.
91    pub fn chat(&self) -> Chat<'_, C> {
92        Chat::new(self)
93    }
94
95    /// To call [Images] group related APIs using this client.
96    pub fn images(&self) -> Images<'_, C> {
97        Images::new(self)
98    }
99
100    /// To call [Moderations] group related APIs using this client.
101    pub fn moderations(&self) -> Moderations<'_, C> {
102        Moderations::new(self)
103    }
104
105    /// To call [Files] group related APIs using this client.
106    pub fn files(&self) -> Files<'_, C> {
107        Files::new(self)
108    }
109
110    /// To call [Uploads] group related APIs using this client.
111    pub fn uploads(&self) -> Uploads<'_, C> {
112        Uploads::new(self)
113    }
114
115    /// To call [FineTuning] group related APIs using this client.
116    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
117        FineTuning::new(self)
118    }
119
120    /// To call [Embeddings] group related APIs using this client.
121    pub fn embeddings(&self) -> Embeddings<'_, C> {
122        Embeddings::new(self)
123    }
124
125    /// To call [Audio] group related APIs using this client.
126    pub fn audio(&self) -> Audio<'_, C> {
127        Audio::new(self)
128    }
129
130    /// To call [Videos] group related APIs using this client.
131    pub fn videos(&self) -> Videos<'_, C> {
132        Videos::new(self)
133    }
134
135    /// To call [Assistants] group related APIs using this client.
136    pub fn assistants(&self) -> Assistants<'_, C> {
137        Assistants::new(self)
138    }
139
140    /// To call [Threads] group related APIs using this client.
141    pub fn threads(&self) -> Threads<'_, C> {
142        Threads::new(self)
143    }
144
145    /// To call [VectorStores] group related APIs using this client.
146    pub fn vector_stores(&self) -> VectorStores<'_, C> {
147        VectorStores::new(self)
148    }
149
150    /// To call [Batches] group related APIs using this client.
151    pub fn batches(&self) -> Batches<'_, C> {
152        Batches::new(self)
153    }
154
155    /// To call [Admin] group related APIs using this client.
156    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
157    pub fn admin(&self) -> Admin<'_, C> {
158        Admin::new(self)
159    }
160
161    /// To call [Usage] group related APIs using this client.
162    pub fn usage(&self) -> Usage<'_, C> {
163        Usage::new(self)
164    }
165
166    /// To call [Responses] group related APIs using this client.
167    pub fn responses(&self) -> Responses<'_, C> {
168        Responses::new(self)
169    }
170
171    /// To call [Conversations] group related APIs using this client.
172    pub fn conversations(&self) -> Conversations<'_, C> {
173        Conversations::new(self)
174    }
175
176    /// To call [Containers] group related APIs using this client.
177    pub fn containers(&self) -> Containers<'_, C> {
178        Containers::new(self)
179    }
180
181    /// To call [Evals] group related APIs using this client.
182    pub fn evals(&self) -> Evals<'_, C> {
183        Evals::new(self)
184    }
185
186    pub fn chatkit(&self) -> Chatkit<'_, C> {
187        Chatkit::new(self)
188    }
189
190    #[cfg(feature = "realtime")]
191    /// To call [Realtime] group related APIs using this client.
192    pub fn realtime(&self) -> Realtime<'_, C> {
193        Realtime::new(self)
194    }
195
196    pub fn config(&self) -> &C {
197        &self.config
198    }
199
200    /// Helper function to build a request builder with common configuration
201    fn build_request_builder(
202        &self,
203        method: reqwest::Method,
204        path: &str,
205        request_options: &RequestOptions,
206    ) -> reqwest::RequestBuilder {
207        let mut request_builder = if let Some(path) = request_options.path() {
208            self.http_client
209                .request(method, self.config.url(path.as_str()))
210        } else {
211            self.http_client.request(method, self.config.url(path))
212        };
213
214        request_builder = request_builder
215            .query(&self.config.query())
216            .headers(self.config.headers());
217
218        if let Some(headers) = request_options.headers() {
219            request_builder = request_builder.headers(headers.clone());
220        }
221
222        if !request_options.query().is_empty() {
223            request_builder = request_builder.query(request_options.query());
224        }
225
226        request_builder
227    }
228
229    /// Make a GET request to {path} and deserialize the response body
230    pub(crate) async fn get<O>(
231        &self,
232        path: &str,
233        request_options: &RequestOptions,
234    ) -> Result<O, OpenAIError>
235    where
236        O: DeserializeOwned,
237    {
238        let request_maker = || async {
239            Ok(self
240                .build_request_builder(reqwest::Method::GET, path, request_options)
241                .build()?)
242        };
243
244        self.execute(request_maker).await
245    }
246
247    /// Make a DELETE request to {path} and deserialize the response body
248    pub(crate) async fn delete<O>(
249        &self,
250        path: &str,
251        request_options: &RequestOptions,
252    ) -> Result<O, OpenAIError>
253    where
254        O: DeserializeOwned,
255    {
256        let request_maker = || async {
257            Ok(self
258                .build_request_builder(reqwest::Method::DELETE, path, request_options)
259                .build()?)
260        };
261
262        self.execute(request_maker).await
263    }
264
265    /// Make a GET request to {path} and return the response body
266    pub(crate) async fn get_raw(
267        &self,
268        path: &str,
269        request_options: &RequestOptions,
270    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
271        let request_maker = || async {
272            Ok(self
273                .build_request_builder(reqwest::Method::GET, path, request_options)
274                .build()?)
275        };
276
277        self.execute_raw(request_maker).await
278    }
279
280    /// Make a POST request to {path} and return the response body
281    pub(crate) async fn post_raw<I>(
282        &self,
283        path: &str,
284        request: I,
285        request_options: &RequestOptions,
286    ) -> Result<(Bytes, HeaderMap), OpenAIError>
287    where
288        I: Serialize,
289    {
290        let request_maker = || async {
291            Ok(self
292                .build_request_builder(reqwest::Method::POST, path, request_options)
293                .json(&request)
294                .build()?)
295        };
296
297        self.execute_raw(request_maker).await
298    }
299
300    /// Make a POST request to {path} and deserialize the response body
301    pub(crate) async fn post<I, O>(
302        &self,
303        path: &str,
304        request: I,
305        request_options: &RequestOptions,
306    ) -> Result<O, OpenAIError>
307    where
308        I: Serialize,
309        O: DeserializeOwned,
310    {
311        let request_maker = || async {
312            Ok(self
313                .build_request_builder(reqwest::Method::POST, path, request_options)
314                .json(&request)
315                .build()?)
316        };
317
318        self.execute(request_maker).await
319    }
320
321    /// POST a form at {path} and return the response body
322    pub(crate) async fn post_form_raw<F>(
323        &self,
324        path: &str,
325        form: F,
326        request_options: &RequestOptions,
327    ) -> Result<(Bytes, HeaderMap), OpenAIError>
328    where
329        Form: AsyncTryFrom<F, Error = OpenAIError>,
330        F: Clone,
331    {
332        let request_maker = || async {
333            Ok(self
334                .build_request_builder(reqwest::Method::POST, path, request_options)
335                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
336                .build()?)
337        };
338
339        self.execute_raw(request_maker).await
340    }
341
342    /// POST a form at {path} and deserialize the response body
343    pub(crate) async fn post_form<O, F>(
344        &self,
345        path: &str,
346        form: F,
347        request_options: &RequestOptions,
348    ) -> Result<O, OpenAIError>
349    where
350        O: DeserializeOwned,
351        Form: AsyncTryFrom<F, Error = OpenAIError>,
352        F: Clone,
353    {
354        let request_maker = || async {
355            Ok(self
356                .build_request_builder(reqwest::Method::POST, path, request_options)
357                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
358                .build()?)
359        };
360
361        self.execute(request_maker).await
362    }
363
364    pub(crate) async fn post_form_stream<O, F>(
365        &self,
366        path: &str,
367        form: F,
368        request_options: &RequestOptions,
369    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
370    where
371        F: Clone,
372        Form: AsyncTryFrom<F, Error = OpenAIError>,
373        O: DeserializeOwned + std::marker::Send + 'static,
374    {
375        // Build and execute request manually since multipart::Form is not Clone
376        // and .eventsource() requires cloneability
377        let request_builder = self
378            .build_request_builder(reqwest::Method::POST, path, request_options)
379            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
380
381        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
382
383        // Check for error status
384        if !response.status().is_success() {
385            return Err(read_response(response).await.unwrap_err());
386        }
387
388        // Convert response body to EventSource stream
389        let stream = response
390            .bytes_stream()
391            .map(|result| result.map_err(std::io::Error::other));
392        let event_stream = eventsource_stream::EventStream::new(stream);
393
394        // Convert EventSource stream to our expected format
395        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
396
397        tokio::spawn(async move {
398            use futures::StreamExt;
399            let mut event_stream = std::pin::pin!(event_stream);
400
401            while let Some(event_result) = event_stream.next().await {
402                match event_result {
403                    Err(e) => {
404                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
405                            StreamError::EventStream(e.to_string()),
406                        )))) {
407                            break;
408                        }
409                    }
410                    Ok(event) => {
411                        // eventsource_stream::Event is a struct with data field
412                        if event.data == "[DONE]" {
413                            break;
414                        }
415
416                        let response = match serde_json::from_str::<O>(&event.data) {
417                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
418                            Ok(output) => Ok(output),
419                        };
420
421                        if let Err(_e) = tx.send(response) {
422                            break;
423                        }
424                    }
425                }
426            }
427        });
428
429        Ok(Box::pin(
430            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
431        ))
432    }
433
434    /// Execute a HTTP request and retry on rate limit
435    ///
436    /// request_maker serves one purpose: to be able to create request again
437    /// to retry API call after getting rate limited. request_maker is async because
438    /// reqwest::multipart::Form is created by async calls to read files for uploads.
439    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
440    where
441        M: Fn() -> Fut,
442        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
443    {
444        let client = self.http_client.clone();
445
446        backoff::future::retry(self.backoff.clone(), || async {
447            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
448            let response = client
449                .execute(request)
450                .await
451                .map_err(OpenAIError::Reqwest)
452                .map_err(backoff::Error::Permanent)?;
453
454            let status = response.status();
455
456            match read_response(response).await {
457                Ok((bytes, headers)) => Ok((bytes, headers)),
458                Err(e) => {
459                    match e {
460                        OpenAIError::ApiError(api_error) => {
461                            if status.is_server_error() {
462                                Err(backoff::Error::Transient {
463                                    err: OpenAIError::ApiError(api_error),
464                                    retry_after: None,
465                                })
466                            } else if status.as_u16() == 429
467                                && api_error.r#type != Some("insufficient_quota".to_string())
468                            {
469                                // Rate limited retry...
470                                tracing::warn!("Rate limited: {}", api_error.message);
471                                Err(backoff::Error::Transient {
472                                    err: OpenAIError::ApiError(api_error),
473                                    retry_after: None,
474                                })
475                            } else {
476                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
477                            }
478                        }
479                        _ => Err(backoff::Error::Permanent(e)),
480                    }
481                }
482            }
483        })
484        .await
485    }
486
487    /// Execute a HTTP request and retry on rate limit
488    ///
489    /// request_maker serves one purpose: to be able to create request again
490    /// to retry API call after getting rate limited. request_maker is async because
491    /// reqwest::multipart::Form is created by async calls to read files for uploads.
492    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
493    where
494        O: DeserializeOwned,
495        M: Fn() -> Fut,
496        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
497    {
498        let (bytes, _headers) = self.execute_raw(request_maker).await?;
499
500        let response: O = serde_json::from_slice(bytes.as_ref())
501            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
502
503        Ok(response)
504    }
505
506    /// Make HTTP POST request to receive SSE
507    pub(crate) async fn post_stream<I, O>(
508        &self,
509        path: &str,
510        request: I,
511        request_options: &RequestOptions,
512    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
513    where
514        I: Serialize,
515        O: DeserializeOwned + std::marker::Send + 'static,
516    {
517        let request_builder = self
518            .build_request_builder(reqwest::Method::POST, path, request_options)
519            .json(&request);
520
521        let event_source = request_builder.eventsource().unwrap();
522
523        stream(event_source).await
524    }
525
526    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
527        &self,
528        path: &str,
529        request: I,
530        request_options: &RequestOptions,
531        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
532    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
533    where
534        I: Serialize,
535        O: DeserializeOwned + std::marker::Send + 'static,
536    {
537        let request_builder = self
538            .build_request_builder(reqwest::Method::POST, path, request_options)
539            .json(&request);
540
541        let event_source = request_builder.eventsource().unwrap();
542
543        stream_mapped_raw_events(event_source, event_mapper).await
544    }
545
546    /// Make HTTP GET request to receive SSE
547    pub(crate) async fn _get_stream<Q, O>(
548        &self,
549        path: &str,
550        query: &Q,
551    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
552    where
553        Q: Serialize + ?Sized,
554        O: DeserializeOwned + std::marker::Send + 'static,
555    {
556        let event_source = self
557            .http_client
558            .get(self.config.url(path))
559            .query(query)
560            .query(&self.config.query())
561            .headers(self.config.headers())
562            .eventsource()
563            .unwrap();
564
565        stream(event_source).await
566    }
567}
568
569async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
570    let status = response.status();
571    let headers = response.headers().clone();
572    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
573
574    if status.is_server_error() {
575        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
576        let message: String = String::from_utf8_lossy(&bytes).into_owned();
577        tracing::warn!("Server error: {status} - {message}");
578        return Err(OpenAIError::ApiError(ApiError {
579            message,
580            r#type: None,
581            param: None,
582            code: None,
583        }));
584    }
585
586    // Deserialize response body from either error object or actual response object
587    if !status.is_success() {
588        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
589            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
590
591        return Err(OpenAIError::ApiError(wrapped_error.error));
592    }
593
594    Ok((bytes, headers))
595}
596
597async fn map_stream_error(value: EventSourceError) -> OpenAIError {
598    match value {
599        EventSourceError::InvalidStatusCode(status_code, response) => {
600            read_response(response).await.expect_err(&format!(
601                "Unreachable because read_response returns err when status_code {status_code} is invalid"
602            ))
603        }
604        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
605    }
606}
607
608/// Request which responds with SSE.
609/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
610pub(crate) async fn stream<O>(
611    mut event_source: EventSource,
612) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
613where
614    O: DeserializeOwned + std::marker::Send + 'static,
615{
616    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
617
618    tokio::spawn(async move {
619        while let Some(ev) = event_source.next().await {
620            match ev {
621                Err(e) => {
622                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
623                    // https://github.com/64bit/async-openai/issues/456
624                    match &e {
625                        EventSourceError::StreamEnded => {
626                            break;
627                        }
628                        _ => {
629                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
630                                // rx dropped
631                                break;
632                            }
633                        }
634                    }
635                }
636                Ok(event) => match event {
637                    Event::Message(message) => {
638                        if message.data == "[DONE]" {
639                            break;
640                        }
641
642                        let response = match serde_json::from_str::<O>(&message.data) {
643                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
644                            Ok(output) => Ok(output),
645                        };
646
647                        if let Err(_e) = tx.send(response) {
648                            // rx dropped
649                            break;
650                        }
651                    }
652                    Event::Open => continue,
653                },
654            }
655        }
656
657        event_source.close();
658    });
659
660    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
661}
662
663pub(crate) async fn stream_mapped_raw_events<O>(
664    mut event_source: EventSource,
665    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
666) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
667where
668    O: DeserializeOwned + std::marker::Send + 'static,
669{
670    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
671
672    tokio::spawn(async move {
673        while let Some(ev) = event_source.next().await {
674            match ev {
675                Err(e) => {
676                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
677                    // https://github.com/64bit/async-openai/issues/456
678                    match &e {
679                        EventSourceError::StreamEnded => {
680                            break;
681                        }
682                        _ => {
683                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
684                                // rx dropped
685                                break;
686                            }
687                        }
688                    }
689                }
690                Ok(event) => match event {
691                    Event::Message(message) => {
692                        let mut done = false;
693
694                        if message.data == "[DONE]" {
695                            done = true;
696                        }
697
698                        let response = event_mapper(message);
699
700                        if let Err(_e) = tx.send(response) {
701                            // rx dropped
702                            break;
703                        }
704
705                        if done {
706                            break;
707                        }
708                    }
709                    Event::Open => continue,
710                },
711            }
712        }
713
714        event_source.close();
715    });
716
717    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
718}