async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    traits::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17    Models, Projects, Responses, Threads, Uploads, Users, VectorStores, Videos,
18};
19
20#[derive(Debug, Clone, Default)]
21/// Client is a container for config, backoff and http_client
22/// used to make API calls.
23pub struct Client<C: Config> {
24    http_client: reqwest::Client,
25    config: C,
26    backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30    /// Client with default [OpenAIConfig]
31    pub fn new() -> Self {
32        Self::default()
33    }
34}
35
36impl<C: Config> Client<C> {
37    /// Create client with a custom HTTP client, OpenAI config, and backoff.
38    pub fn build(
39        http_client: reqwest::Client,
40        config: C,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
51    pub fn with_config(config: C) -> Self {
52        Self {
53            http_client: reqwest::Client::new(),
54            config,
55            backoff: Default::default(),
56        }
57    }
58
59    /// Provide your own [client] to make HTTP requests with.
60    ///
61    /// [client]: reqwest::Client
62    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63        self.http_client = http_client;
64        self
65    }
66
67    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
68    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69        self.backoff = backoff;
70        self
71    }
72
73    // API groups
74
75    /// To call [Models] group related APIs using this client.
76    pub fn models(&self) -> Models<C> {
77        Models::new(self)
78    }
79
80    /// To call [Completions] group related APIs using this client.
81    pub fn completions(&self) -> Completions<C> {
82        Completions::new(self)
83    }
84
85    /// To call [Chat] group related APIs using this client.
86    pub fn chat(&self) -> Chat<C> {
87        Chat::new(self)
88    }
89
90    /// To call [Images] group related APIs using this client.
91    pub fn images(&self) -> Images<C> {
92        Images::new(self)
93    }
94
95    /// To call [Moderations] group related APIs using this client.
96    pub fn moderations(&self) -> Moderations<C> {
97        Moderations::new(self)
98    }
99
100    /// To call [Files] group related APIs using this client.
101    pub fn files(&self) -> Files<C> {
102        Files::new(self)
103    }
104
105    /// To call [Uploads] group related APIs using this client.
106    pub fn uploads(&self) -> Uploads<C> {
107        Uploads::new(self)
108    }
109
110    /// To call [FineTuning] group related APIs using this client.
111    pub fn fine_tuning(&self) -> FineTuning<C> {
112        FineTuning::new(self)
113    }
114
115    /// To call [Embeddings] group related APIs using this client.
116    pub fn embeddings(&self) -> Embeddings<C> {
117        Embeddings::new(self)
118    }
119
120    /// To call [Audio] group related APIs using this client.
121    pub fn audio(&self) -> Audio<C> {
122        Audio::new(self)
123    }
124
125    /// To call [Videos] group related APIs using this client.
126    pub fn videos(&self) -> Videos<C> {
127        Videos::new(self)
128    }
129
130    /// To call [Assistants] group related APIs using this client.
131    pub fn assistants(&self) -> Assistants<C> {
132        Assistants::new(self)
133    }
134
135    /// To call [Threads] group related APIs using this client.
136    pub fn threads(&self) -> Threads<C> {
137        Threads::new(self)
138    }
139
140    /// To call [VectorStores] group related APIs using this client.
141    pub fn vector_stores(&self) -> VectorStores<C> {
142        VectorStores::new(self)
143    }
144
145    /// To call [Batches] group related APIs using this client.
146    pub fn batches(&self) -> Batches<C> {
147        Batches::new(self)
148    }
149
150    /// To call [AuditLogs] group related APIs using this client.
151    pub fn audit_logs(&self) -> AuditLogs<C> {
152        AuditLogs::new(self)
153    }
154
155    /// To call [Invites] group related APIs using this client.
156    pub fn invites(&self) -> Invites<C> {
157        Invites::new(self)
158    }
159
160    /// To call [Users] group related APIs using this client.
161    pub fn users(&self) -> Users<C> {
162        Users::new(self)
163    }
164
165    /// To call [Projects] group related APIs using this client.
166    pub fn projects(&self) -> Projects<C> {
167        Projects::new(self)
168    }
169
170    /// To call [Responses] group related APIs using this client.
171    pub fn responses(&self) -> Responses<C> {
172        Responses::new(self)
173    }
174
175    pub fn config(&self) -> &C {
176        &self.config
177    }
178
179    /// Make a GET request to {path} and deserialize the response body
180    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
181    where
182        O: DeserializeOwned,
183    {
184        let request_maker = || async {
185            Ok(self
186                .http_client
187                .get(self.config.url(path))
188                .query(&self.config.query())
189                .headers(self.config.headers())
190                .build()?)
191        };
192
193        self.execute(request_maker).await
194    }
195
196    /// Make a GET request to {path} with given Query and deserialize the response body
197    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
198    where
199        O: DeserializeOwned,
200        Q: Serialize + ?Sized,
201    {
202        let request_maker = || async {
203            Ok(self
204                .http_client
205                .get(self.config.url(path))
206                .query(&self.config.query())
207                .query(query)
208                .headers(self.config.headers())
209                .build()?)
210        };
211
212        self.execute(request_maker).await
213    }
214
215    /// Make a DELETE request to {path} and deserialize the response body
216    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
217    where
218        O: DeserializeOwned,
219    {
220        let request_maker = || async {
221            Ok(self
222                .http_client
223                .delete(self.config.url(path))
224                .query(&self.config.query())
225                .headers(self.config.headers())
226                .build()?)
227        };
228
229        self.execute(request_maker).await
230    }
231
232    /// Make a GET request to {path} and return the response body
233    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
234        let request_maker = || async {
235            Ok(self
236                .http_client
237                .get(self.config.url(path))
238                .query(&self.config.query())
239                .headers(self.config.headers())
240                .build()?)
241        };
242
243        self.execute_raw(request_maker).await
244    }
245
246    pub(crate) async fn get_raw_with_query<Q>(
247        &self,
248        path: &str,
249        query: &Q,
250    ) -> Result<Bytes, OpenAIError>
251    where
252        Q: Serialize + ?Sized,
253    {
254        let request_maker = || async {
255            Ok(self
256                .http_client
257                .get(self.config.url(path))
258                .query(&self.config.query())
259                .query(query)
260                .headers(self.config.headers())
261                .build()?)
262        };
263
264        self.execute_raw(request_maker).await
265    }
266
267    /// Make a POST request to {path} and return the response body
268    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
269    where
270        I: Serialize,
271    {
272        let request_maker = || async {
273            Ok(self
274                .http_client
275                .post(self.config.url(path))
276                .query(&self.config.query())
277                .headers(self.config.headers())
278                .json(&request)
279                .build()?)
280        };
281
282        self.execute_raw(request_maker).await
283    }
284
285    /// Make a POST request to {path} and deserialize the response body
286    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
287    where
288        I: Serialize,
289        O: DeserializeOwned,
290    {
291        let request_maker = || async {
292            Ok(self
293                .http_client
294                .post(self.config.url(path))
295                .query(&self.config.query())
296                .headers(self.config.headers())
297                .json(&request)
298                .build()?)
299        };
300
301        self.execute(request_maker).await
302    }
303
304    /// POST a form at {path} and return the response body
305    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
306    where
307        Form: AsyncTryFrom<F, Error = OpenAIError>,
308        F: Clone,
309    {
310        let request_maker = || async {
311            Ok(self
312                .http_client
313                .post(self.config.url(path))
314                .query(&self.config.query())
315                .headers(self.config.headers())
316                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
317                .build()?)
318        };
319
320        self.execute_raw(request_maker).await
321    }
322
323    /// POST a form at {path} and deserialize the response body
324    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
325    where
326        O: DeserializeOwned,
327        Form: AsyncTryFrom<F, Error = OpenAIError>,
328        F: Clone,
329    {
330        let request_maker = || async {
331            Ok(self
332                .http_client
333                .post(self.config.url(path))
334                .query(&self.config.query())
335                .headers(self.config.headers())
336                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
337                .build()?)
338        };
339
340        self.execute(request_maker).await
341    }
342
343    /// Execute a HTTP request and retry on rate limit
344    ///
345    /// request_maker serves one purpose: to be able to create request again
346    /// to retry API call after getting rate limited. request_maker is async because
347    /// reqwest::multipart::Form is created by async calls to read files for uploads.
348    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
349    where
350        M: Fn() -> Fut,
351        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
352    {
353        let client = self.http_client.clone();
354
355        backoff::future::retry(self.backoff.clone(), || async {
356            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
357            let response = client
358                .execute(request)
359                .await
360                .map_err(OpenAIError::Reqwest)
361                .map_err(backoff::Error::Permanent)?;
362
363            let status = response.status();
364
365            match read_response(response).await {
366                Ok(bytes) => Ok(bytes),
367                Err(e) => {
368                    match e {
369                        OpenAIError::ApiError(api_error) => {
370                            if status.is_server_error() {
371                                Err(backoff::Error::Transient {
372                                    err: OpenAIError::ApiError(api_error),
373                                    retry_after: None,
374                                })
375                            } else if status.as_u16() == 429
376                                && api_error.r#type != Some("insufficient_quota".to_string())
377                            {
378                                // Rate limited retry...
379                                tracing::warn!("Rate limited: {}", api_error.message);
380                                Err(backoff::Error::Transient {
381                                    err: OpenAIError::ApiError(api_error),
382                                    retry_after: None,
383                                })
384                            } else {
385                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
386                            }
387                        }
388                        _ => Err(backoff::Error::Permanent(e)),
389                    }
390                }
391            }
392        })
393        .await
394    }
395
396    /// Execute a HTTP request and retry on rate limit
397    ///
398    /// request_maker serves one purpose: to be able to create request again
399    /// to retry API call after getting rate limited. request_maker is async because
400    /// reqwest::multipart::Form is created by async calls to read files for uploads.
401    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
402    where
403        O: DeserializeOwned,
404        M: Fn() -> Fut,
405        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
406    {
407        let bytes = self.execute_raw(request_maker).await?;
408
409        let response: O = serde_json::from_slice(bytes.as_ref())
410            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
411
412        Ok(response)
413    }
414
415    /// Make HTTP POST request to receive SSE
416    pub(crate) async fn post_stream<I, O>(
417        &self,
418        path: &str,
419        request: I,
420    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
421    where
422        I: Serialize,
423        O: DeserializeOwned + std::marker::Send + 'static,
424    {
425        let event_source = self
426            .http_client
427            .post(self.config.url(path))
428            .query(&self.config.query())
429            .headers(self.config.headers())
430            .json(&request)
431            .eventsource()
432            .unwrap();
433
434        stream(event_source).await
435    }
436
437    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
438        &self,
439        path: &str,
440        request: I,
441        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
442    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
443    where
444        I: Serialize,
445        O: DeserializeOwned + std::marker::Send + 'static,
446    {
447        let event_source = self
448            .http_client
449            .post(self.config.url(path))
450            .query(&self.config.query())
451            .headers(self.config.headers())
452            .json(&request)
453            .eventsource()
454            .unwrap();
455
456        stream_mapped_raw_events(event_source, event_mapper).await
457    }
458
459    /// Make HTTP GET request to receive SSE
460    pub(crate) async fn _get_stream<Q, O>(
461        &self,
462        path: &str,
463        query: &Q,
464    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
465    where
466        Q: Serialize + ?Sized,
467        O: DeserializeOwned + std::marker::Send + 'static,
468    {
469        let event_source = self
470            .http_client
471            .get(self.config.url(path))
472            .query(query)
473            .query(&self.config.query())
474            .headers(self.config.headers())
475            .eventsource()
476            .unwrap();
477
478        stream(event_source).await
479    }
480}
481
482async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
483    let status = response.status();
484    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
485
486    if status.is_server_error() {
487        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
488        let message: String = String::from_utf8_lossy(&bytes).into_owned();
489        tracing::warn!("Server error: {status} - {message}");
490        return Err(OpenAIError::ApiError(ApiError {
491            message,
492            r#type: None,
493            param: None,
494            code: None,
495        }));
496    }
497
498    // Deserialize response body from either error object or actual response object
499    if !status.is_success() {
500        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
501            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
502
503        return Err(OpenAIError::ApiError(wrapped_error.error));
504    }
505
506    Ok(bytes)
507}
508
509async fn map_stream_error(value: EventSourceError) -> OpenAIError {
510    match value {
511        EventSourceError::InvalidStatusCode(status_code, response) => {
512            read_response(response).await.expect_err(&format!(
513                "Unreachable because read_response returns err when status_code {status_code} is invalid"
514            ))
515        }
516        _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
517    }
518}
519
520/// Request which responds with SSE.
521/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
522pub(crate) async fn stream<O>(
523    mut event_source: EventSource,
524) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
525where
526    O: DeserializeOwned + std::marker::Send + 'static,
527{
528    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
529
530    tokio::spawn(async move {
531        while let Some(ev) = event_source.next().await {
532            match ev {
533                Err(e) => {
534                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
535                        // rx dropped
536                        break;
537                    }
538                }
539                Ok(event) => match event {
540                    Event::Message(message) => {
541                        if message.data == "[DONE]" {
542                            break;
543                        }
544
545                        let response = match serde_json::from_str::<O>(&message.data) {
546                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
547                            Ok(output) => Ok(output),
548                        };
549
550                        if let Err(_e) = tx.send(response) {
551                            // rx dropped
552                            break;
553                        }
554                    }
555                    Event::Open => continue,
556                },
557            }
558        }
559
560        event_source.close();
561    });
562
563    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
564}
565
566pub(crate) async fn stream_mapped_raw_events<O>(
567    mut event_source: EventSource,
568    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
569) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
570where
571    O: DeserializeOwned + std::marker::Send + 'static,
572{
573    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
574
575    tokio::spawn(async move {
576        while let Some(ev) = event_source.next().await {
577            match ev {
578                Err(e) => {
579                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
580                        // rx dropped
581                        break;
582                    }
583                }
584                Ok(event) => match event {
585                    Event::Message(message) => {
586                        let mut done = false;
587
588                        if message.data == "[DONE]" {
589                            done = true;
590                        }
591
592                        let response = event_mapper(message);
593
594                        if let Err(_e) = tx.send(response) {
595                            // rx dropped
596                            break;
597                        }
598
599                        if done {
600                            break;
601                        }
602                    }
603                    Event::Open => continue,
604                },
605            }
606        }
607
608        event_source.close();
609    });
610
611    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
612}