async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    traits::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17    Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21/// Client is a container for config, backoff and http_client
22/// used to make API calls.
23pub struct Client<C: Config> {
24    http_client: reqwest::Client,
25    config: C,
26    backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30    /// Client with default [OpenAIConfig]
31    pub fn new() -> Self {
32        Self::default()
33    }
34}
35
36impl<C: Config> Client<C> {
37    /// Create client with a custom HTTP client, OpenAI config, and backoff.
38    pub fn build(
39        http_client: reqwest::Client,
40        config: C,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
51    pub fn with_config(config: C) -> Self {
52        Self {
53            http_client: reqwest::Client::new(),
54            config,
55            backoff: Default::default(),
56        }
57    }
58
59    /// Provide your own [client] to make HTTP requests with.
60    ///
61    /// [client]: reqwest::Client
62    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63        self.http_client = http_client;
64        self
65    }
66
67    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
68    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69        self.backoff = backoff;
70        self
71    }
72
73    // API groups
74
75    /// To call [Models] group related APIs using this client.
76    pub fn models(&self) -> Models<C> {
77        Models::new(self)
78    }
79
80    /// To call [Completions] group related APIs using this client.
81    pub fn completions(&self) -> Completions<C> {
82        Completions::new(self)
83    }
84
85    /// To call [Chat] group related APIs using this client.
86    pub fn chat(&self) -> Chat<C> {
87        Chat::new(self)
88    }
89
90    /// To call [Images] group related APIs using this client.
91    pub fn images(&self) -> Images<C> {
92        Images::new(self)
93    }
94
95    /// To call [Moderations] group related APIs using this client.
96    pub fn moderations(&self) -> Moderations<C> {
97        Moderations::new(self)
98    }
99
100    /// To call [Files] group related APIs using this client.
101    pub fn files(&self) -> Files<C> {
102        Files::new(self)
103    }
104
105    /// To call [Uploads] group related APIs using this client.
106    pub fn uploads(&self) -> Uploads<C> {
107        Uploads::new(self)
108    }
109
110    /// To call [FineTuning] group related APIs using this client.
111    pub fn fine_tuning(&self) -> FineTuning<C> {
112        FineTuning::new(self)
113    }
114
115    /// To call [Embeddings] group related APIs using this client.
116    pub fn embeddings(&self) -> Embeddings<C> {
117        Embeddings::new(self)
118    }
119
120    /// To call [Audio] group related APIs using this client.
121    pub fn audio(&self) -> Audio<C> {
122        Audio::new(self)
123    }
124
125    /// To call [Assistants] group related APIs using this client.
126    pub fn assistants(&self) -> Assistants<C> {
127        Assistants::new(self)
128    }
129
130    /// To call [Threads] group related APIs using this client.
131    pub fn threads(&self) -> Threads<C> {
132        Threads::new(self)
133    }
134
135    /// To call [VectorStores] group related APIs using this client.
136    pub fn vector_stores(&self) -> VectorStores<C> {
137        VectorStores::new(self)
138    }
139
140    /// To call [Batches] group related APIs using this client.
141    pub fn batches(&self) -> Batches<C> {
142        Batches::new(self)
143    }
144
145    /// To call [AuditLogs] group related APIs using this client.
146    pub fn audit_logs(&self) -> AuditLogs<C> {
147        AuditLogs::new(self)
148    }
149
150    /// To call [Invites] group related APIs using this client.
151    pub fn invites(&self) -> Invites<C> {
152        Invites::new(self)
153    }
154
155    /// To call [Users] group related APIs using this client.
156    pub fn users(&self) -> Users<C> {
157        Users::new(self)
158    }
159
160    /// To call [Projects] group related APIs using this client.
161    pub fn projects(&self) -> Projects<C> {
162        Projects::new(self)
163    }
164
165    /// To call [Responses] group related APIs using this client.
166    pub fn responses(&self) -> Responses<C> {
167        Responses::new(self)
168    }
169
170    pub fn config(&self) -> &C {
171        &self.config
172    }
173
174    /// Make a GET request to {path} and deserialize the response body
175    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
176    where
177        O: DeserializeOwned,
178    {
179        let request_maker = || async {
180            Ok(self
181                .http_client
182                .get(self.config.url(path))
183                .query(&self.config.query())
184                .headers(self.config.headers())
185                .build()?)
186        };
187
188        self.execute(request_maker).await
189    }
190
191    /// Make a GET request to {path} with given Query and deserialize the response body
192    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
193    where
194        O: DeserializeOwned,
195        Q: Serialize + ?Sized,
196    {
197        let request_maker = || async {
198            Ok(self
199                .http_client
200                .get(self.config.url(path))
201                .query(&self.config.query())
202                .query(query)
203                .headers(self.config.headers())
204                .build()?)
205        };
206
207        self.execute(request_maker).await
208    }
209
210    /// Make a DELETE request to {path} and deserialize the response body
211    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
212    where
213        O: DeserializeOwned,
214    {
215        let request_maker = || async {
216            Ok(self
217                .http_client
218                .delete(self.config.url(path))
219                .query(&self.config.query())
220                .headers(self.config.headers())
221                .build()?)
222        };
223
224        self.execute(request_maker).await
225    }
226
227    /// Make a GET request to {path} and return the response body
228    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
229        let request_maker = || async {
230            Ok(self
231                .http_client
232                .get(self.config.url(path))
233                .query(&self.config.query())
234                .headers(self.config.headers())
235                .build()?)
236        };
237
238        self.execute_raw(request_maker).await
239    }
240
241    /// Make a POST request to {path} and return the response body
242    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
243    where
244        I: Serialize,
245    {
246        let request_maker = || async {
247            Ok(self
248                .http_client
249                .post(self.config.url(path))
250                .query(&self.config.query())
251                .headers(self.config.headers())
252                .json(&request)
253                .build()?)
254        };
255
256        self.execute_raw(request_maker).await
257    }
258
259    /// Make a POST request to {path} and deserialize the response body
260    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
261    where
262        I: Serialize,
263        O: DeserializeOwned,
264    {
265        let request_maker = || async {
266            Ok(self
267                .http_client
268                .post(self.config.url(path))
269                .query(&self.config.query())
270                .headers(self.config.headers())
271                .json(&request)
272                .build()?)
273        };
274
275        self.execute(request_maker).await
276    }
277
278    /// POST a form at {path} and return the response body
279    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
280    where
281        Form: AsyncTryFrom<F, Error = OpenAIError>,
282        F: Clone,
283    {
284        let request_maker = || async {
285            Ok(self
286                .http_client
287                .post(self.config.url(path))
288                .query(&self.config.query())
289                .headers(self.config.headers())
290                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
291                .build()?)
292        };
293
294        self.execute_raw(request_maker).await
295    }
296
297    /// POST a form at {path} and deserialize the response body
298    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
299    where
300        O: DeserializeOwned,
301        Form: AsyncTryFrom<F, Error = OpenAIError>,
302        F: Clone,
303    {
304        let request_maker = || async {
305            Ok(self
306                .http_client
307                .post(self.config.url(path))
308                .query(&self.config.query())
309                .headers(self.config.headers())
310                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
311                .build()?)
312        };
313
314        self.execute(request_maker).await
315    }
316
317    /// Execute a HTTP request and retry on rate limit
318    ///
319    /// request_maker serves one purpose: to be able to create request again
320    /// to retry API call after getting rate limited. request_maker is async because
321    /// reqwest::multipart::Form is created by async calls to read files for uploads.
322    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
323    where
324        M: Fn() -> Fut,
325        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
326    {
327        let client = self.http_client.clone();
328
329        backoff::future::retry(self.backoff.clone(), || async {
330            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
331            let response = client
332                .execute(request)
333                .await
334                .map_err(OpenAIError::Reqwest)
335                .map_err(backoff::Error::Permanent)?;
336
337            let status = response.status();
338
339            match read_response(response).await {
340                Ok(bytes) => Ok(bytes),
341                Err(e) => {
342                    match e {
343                        OpenAIError::ApiError(api_error) => {
344                            if status.is_server_error() {
345                                Err(backoff::Error::Transient {
346                                    err: OpenAIError::ApiError(api_error),
347                                    retry_after: None,
348                                })
349                            } else if status.as_u16() == 429
350                                && api_error.r#type != Some("insufficient_quota".to_string())
351                            {
352                                // Rate limited retry...
353                                tracing::warn!("Rate limited: {}", api_error.message);
354                                Err(backoff::Error::Transient {
355                                    err: OpenAIError::ApiError(api_error),
356                                    retry_after: None,
357                                })
358                            } else {
359                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
360                            }
361                        }
362                        _ => Err(backoff::Error::Permanent(e)),
363                    }
364                }
365            }
366        })
367        .await
368    }
369
370    /// Execute a HTTP request and retry on rate limit
371    ///
372    /// request_maker serves one purpose: to be able to create request again
373    /// to retry API call after getting rate limited. request_maker is async because
374    /// reqwest::multipart::Form is created by async calls to read files for uploads.
375    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
376    where
377        O: DeserializeOwned,
378        M: Fn() -> Fut,
379        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
380    {
381        let bytes = self.execute_raw(request_maker).await?;
382
383        let response: O = serde_json::from_slice(bytes.as_ref())
384            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
385
386        Ok(response)
387    }
388
389    /// Make HTTP POST request to receive SSE
390    pub(crate) async fn post_stream<I, O>(
391        &self,
392        path: &str,
393        request: I,
394    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
395    where
396        I: Serialize,
397        O: DeserializeOwned + std::marker::Send + 'static,
398    {
399        let event_source = self
400            .http_client
401            .post(self.config.url(path))
402            .query(&self.config.query())
403            .headers(self.config.headers())
404            .json(&request)
405            .eventsource()
406            .unwrap();
407
408        stream(event_source).await
409    }
410
411    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
412        &self,
413        path: &str,
414        request: I,
415        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
416    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
417    where
418        I: Serialize,
419        O: DeserializeOwned + std::marker::Send + 'static,
420    {
421        let event_source = self
422            .http_client
423            .post(self.config.url(path))
424            .query(&self.config.query())
425            .headers(self.config.headers())
426            .json(&request)
427            .eventsource()
428            .unwrap();
429
430        stream_mapped_raw_events(event_source, event_mapper).await
431    }
432
433    /// Make HTTP GET request to receive SSE
434    pub(crate) async fn _get_stream<Q, O>(
435        &self,
436        path: &str,
437        query: &Q,
438    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
439    where
440        Q: Serialize + ?Sized,
441        O: DeserializeOwned + std::marker::Send + 'static,
442    {
443        let event_source = self
444            .http_client
445            .get(self.config.url(path))
446            .query(query)
447            .query(&self.config.query())
448            .headers(self.config.headers())
449            .eventsource()
450            .unwrap();
451
452        stream(event_source).await
453    }
454}
455
456async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
457    let status = response.status();
458    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
459
460    if status.is_server_error() {
461        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
462        let message: String = String::from_utf8_lossy(&bytes).into_owned();
463        tracing::warn!("Server error: {status} - {message}");
464        return Err(OpenAIError::ApiError(ApiError {
465            message,
466            r#type: None,
467            param: None,
468            code: None,
469        }));
470    }
471
472    // Deserialize response body from either error object or actual response object
473    if !status.is_success() {
474        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
475            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
476
477        return Err(OpenAIError::ApiError(wrapped_error.error));
478    }
479
480    Ok(bytes)
481}
482
483async fn map_stream_error(value: EventSourceError) -> OpenAIError {
484    match value {
485        EventSourceError::InvalidStatusCode(status_code, response) => {
486            read_response(response).await.expect_err(&format!(
487                "Unreachable because read_response returns err when status_code {status_code} is invalid"
488            ))
489        }
490        _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
491    }
492}
493
494/// Request which responds with SSE.
495/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
496pub(crate) async fn stream<O>(
497    mut event_source: EventSource,
498) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
499where
500    O: DeserializeOwned + std::marker::Send + 'static,
501{
502    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
503
504    tokio::spawn(async move {
505        while let Some(ev) = event_source.next().await {
506            match ev {
507                Err(e) => {
508                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
509                        // rx dropped
510                        break;
511                    }
512                }
513                Ok(event) => match event {
514                    Event::Message(message) => {
515                        if message.data == "[DONE]" {
516                            break;
517                        }
518
519                        let response = match serde_json::from_str::<O>(&message.data) {
520                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
521                            Ok(output) => Ok(output),
522                        };
523
524                        if let Err(_e) = tx.send(response) {
525                            // rx dropped
526                            break;
527                        }
528                    }
529                    Event::Open => continue,
530                },
531            }
532        }
533
534        event_source.close();
535    });
536
537    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
538}
539
540pub(crate) async fn stream_mapped_raw_events<O>(
541    mut event_source: EventSource,
542    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
543) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
544where
545    O: DeserializeOwned + std::marker::Send + 'static,
546{
547    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
548
549    tokio::spawn(async move {
550        while let Some(ev) = event_source.next().await {
551            match ev {
552                Err(e) => {
553                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
554                        // rx dropped
555                        break;
556                    }
557                }
558                Ok(event) => match event {
559                    Event::Message(message) => {
560                        let mut done = false;
561
562                        if message.data == "[DONE]" {
563                            done = true;
564                        }
565
566                        let response = event_mapper(message);
567
568                        if let Err(_e) = tx.send(response) {
569                            // rx dropped
570                            break;
571                        }
572
573                        if done {
574                            break;
575                        }
576                    }
577                    Event::Open => continue,
578                },
579            }
580        }
581
582        event_source.close();
583    });
584
585    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
586}