async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    traits::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
17    Embeddings, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
18    VectorStores, Videos,
19};
20
21#[derive(Debug, Clone, Default)]
22/// Client is a container for config, backoff and http_client
23/// used to make API calls.
24pub struct Client<C: Config> {
25    http_client: reqwest::Client,
26    config: C,
27    backoff: backoff::ExponentialBackoff,
28}
29
30impl Client<OpenAIConfig> {
31    /// Client with default [OpenAIConfig]
32    pub fn new() -> Self {
33        Self::default()
34    }
35}
36
37impl<C: Config> Client<C> {
38    /// Create client with a custom HTTP client, OpenAI config, and backoff.
39    pub fn build(
40        http_client: reqwest::Client,
41        config: C,
42        backoff: backoff::ExponentialBackoff,
43    ) -> Self {
44        Self {
45            http_client,
46            config,
47            backoff,
48        }
49    }
50
51    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
52    pub fn with_config(config: C) -> Self {
53        Self {
54            http_client: reqwest::Client::new(),
55            config,
56            backoff: Default::default(),
57        }
58    }
59
60    /// Provide your own [client] to make HTTP requests with.
61    ///
62    /// [client]: reqwest::Client
63    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64        self.http_client = http_client;
65        self
66    }
67
68    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
69    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
70        self.backoff = backoff;
71        self
72    }
73
74    // API groups
75
    /// To call [Models] group related APIs using this client.
    ///
    /// Returns a [Models] handle constructed from a shared borrow of this client.
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }
80
    /// To call [Completions] group related APIs using this client.
    ///
    /// Returns a [Completions] handle constructed from a shared borrow of this client.
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }
85
    /// To call [Chat] group related APIs using this client.
    ///
    /// Returns a [Chat] handle constructed from a shared borrow of this client.
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }
90
    /// To call [Images] group related APIs using this client.
    ///
    /// Returns an [Images] handle constructed from a shared borrow of this client.
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }
95
    /// To call [Moderations] group related APIs using this client.
    ///
    /// Returns a [Moderations] handle constructed from a shared borrow of this client.
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }
100
    /// To call [Files] group related APIs using this client.
    ///
    /// Returns a [Files] handle constructed from a shared borrow of this client.
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }
105
    /// To call [Uploads] group related APIs using this client.
    ///
    /// Returns an [Uploads] handle constructed from a shared borrow of this client.
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }
110
    /// To call [FineTuning] group related APIs using this client.
    ///
    /// Returns a [FineTuning] handle constructed from a shared borrow of this client.
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }
115
    /// To call [Embeddings] group related APIs using this client.
    ///
    /// Returns an [Embeddings] handle constructed from a shared borrow of this client.
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }
120
    /// To call [Audio] group related APIs using this client.
    ///
    /// Returns an [Audio] handle constructed from a shared borrow of this client.
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }
125
    /// To call [Videos] group related APIs using this client.
    ///
    /// Returns a [Videos] handle constructed from a shared borrow of this client.
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }
130
    /// To call [Assistants] group related APIs using this client.
    ///
    /// Returns an [Assistants] handle constructed from a shared borrow of this client.
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }
135
    /// To call [Threads] group related APIs using this client.
    ///
    /// Returns a [Threads] handle constructed from a shared borrow of this client.
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }
140
    /// To call [VectorStores] group related APIs using this client.
    ///
    /// Returns a [VectorStores] handle constructed from a shared borrow of this client.
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }
145
    /// To call [Batches] group related APIs using this client.
    ///
    /// Returns a [Batches] handle constructed from a shared borrow of this client.
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }
150
    /// To call [AuditLogs] group related APIs using this client.
    ///
    /// Returns an [AuditLogs] handle constructed from a shared borrow of this client.
    pub fn audit_logs(&self) -> AuditLogs<'_, C> {
        AuditLogs::new(self)
    }
155
    /// To call [Invites] group related APIs using this client.
    ///
    /// Returns an [Invites] handle constructed from a shared borrow of this client.
    pub fn invites(&self) -> Invites<'_, C> {
        Invites::new(self)
    }
160
    /// To call [Users] group related APIs using this client.
    ///
    /// Returns a [Users] handle constructed from a shared borrow of this client.
    pub fn users(&self) -> Users<'_, C> {
        Users::new(self)
    }
165
    /// To call [Projects] group related APIs using this client.
    ///
    /// Returns a [Projects] handle constructed from a shared borrow of this client.
    pub fn projects(&self) -> Projects<'_, C> {
        Projects::new(self)
    }
170
    /// To call [Responses] group related APIs using this client.
    ///
    /// Returns a [Responses] handle constructed from a shared borrow of this client.
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }
175
    /// To call [Conversations] group related APIs using this client.
    ///
    /// Returns a [Conversations] handle constructed from a shared borrow of this client.
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }
180
    /// To call [Containers] group related APIs using this client.
    ///
    /// Returns a [Containers] handle constructed from a shared borrow of this client.
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }
185
186    pub fn config(&self) -> &C {
187        &self.config
188    }
189
190    /// Make a GET request to {path} and deserialize the response body
191    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
192    where
193        O: DeserializeOwned,
194    {
195        let request_maker = || async {
196            Ok(self
197                .http_client
198                .get(self.config.url(path))
199                .query(&self.config.query())
200                .headers(self.config.headers())
201                .build()?)
202        };
203
204        self.execute(request_maker).await
205    }
206
207    /// Make a GET request to {path} with given Query and deserialize the response body
208    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
209    where
210        O: DeserializeOwned,
211        Q: Serialize + ?Sized,
212    {
213        let request_maker = || async {
214            Ok(self
215                .http_client
216                .get(self.config.url(path))
217                .query(&self.config.query())
218                .query(query)
219                .headers(self.config.headers())
220                .build()?)
221        };
222
223        self.execute(request_maker).await
224    }
225
226    /// Make a DELETE request to {path} and deserialize the response body
227    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
228    where
229        O: DeserializeOwned,
230    {
231        let request_maker = || async {
232            Ok(self
233                .http_client
234                .delete(self.config.url(path))
235                .query(&self.config.query())
236                .headers(self.config.headers())
237                .build()?)
238        };
239
240        self.execute(request_maker).await
241    }
242
243    /// Make a GET request to {path} and return the response body
244    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
245        let request_maker = || async {
246            Ok(self
247                .http_client
248                .get(self.config.url(path))
249                .query(&self.config.query())
250                .headers(self.config.headers())
251                .build()?)
252        };
253
254        self.execute_raw(request_maker).await
255    }
256
257    pub(crate) async fn get_raw_with_query<Q>(
258        &self,
259        path: &str,
260        query: &Q,
261    ) -> Result<Bytes, OpenAIError>
262    where
263        Q: Serialize + ?Sized,
264    {
265        let request_maker = || async {
266            Ok(self
267                .http_client
268                .get(self.config.url(path))
269                .query(&self.config.query())
270                .query(query)
271                .headers(self.config.headers())
272                .build()?)
273        };
274
275        self.execute_raw(request_maker).await
276    }
277
278    /// Make a POST request to {path} and return the response body
279    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
280    where
281        I: Serialize,
282    {
283        let request_maker = || async {
284            Ok(self
285                .http_client
286                .post(self.config.url(path))
287                .query(&self.config.query())
288                .headers(self.config.headers())
289                .json(&request)
290                .build()?)
291        };
292
293        self.execute_raw(request_maker).await
294    }
295
296    /// Make a POST request to {path} and deserialize the response body
297    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
298    where
299        I: Serialize,
300        O: DeserializeOwned,
301    {
302        let request_maker = || async {
303            Ok(self
304                .http_client
305                .post(self.config.url(path))
306                .query(&self.config.query())
307                .headers(self.config.headers())
308                .json(&request)
309                .build()?)
310        };
311
312        self.execute(request_maker).await
313    }
314
315    /// POST a form at {path} and return the response body
316    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
317    where
318        Form: AsyncTryFrom<F, Error = OpenAIError>,
319        F: Clone,
320    {
321        let request_maker = || async {
322            Ok(self
323                .http_client
324                .post(self.config.url(path))
325                .query(&self.config.query())
326                .headers(self.config.headers())
327                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
328                .build()?)
329        };
330
331        self.execute_raw(request_maker).await
332    }
333
334    /// POST a form at {path} and deserialize the response body
335    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
336    where
337        O: DeserializeOwned,
338        Form: AsyncTryFrom<F, Error = OpenAIError>,
339        F: Clone,
340    {
341        let request_maker = || async {
342            Ok(self
343                .http_client
344                .post(self.config.url(path))
345                .query(&self.config.query())
346                .headers(self.config.headers())
347                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
348                .build()?)
349        };
350
351        self.execute(request_maker).await
352    }
353
354    /// Execute a HTTP request and retry on rate limit
355    ///
356    /// request_maker serves one purpose: to be able to create request again
357    /// to retry API call after getting rate limited. request_maker is async because
358    /// reqwest::multipart::Form is created by async calls to read files for uploads.
359    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
360    where
361        M: Fn() -> Fut,
362        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
363    {
364        let client = self.http_client.clone();
365
366        backoff::future::retry(self.backoff.clone(), || async {
367            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
368            let response = client
369                .execute(request)
370                .await
371                .map_err(OpenAIError::Reqwest)
372                .map_err(backoff::Error::Permanent)?;
373
374            let status = response.status();
375
376            match read_response(response).await {
377                Ok(bytes) => Ok(bytes),
378                Err(e) => {
379                    match e {
380                        OpenAIError::ApiError(api_error) => {
381                            if status.is_server_error() {
382                                Err(backoff::Error::Transient {
383                                    err: OpenAIError::ApiError(api_error),
384                                    retry_after: None,
385                                })
386                            } else if status.as_u16() == 429
387                                && api_error.r#type != Some("insufficient_quota".to_string())
388                            {
389                                // Rate limited retry...
390                                tracing::warn!("Rate limited: {}", api_error.message);
391                                Err(backoff::Error::Transient {
392                                    err: OpenAIError::ApiError(api_error),
393                                    retry_after: None,
394                                })
395                            } else {
396                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
397                            }
398                        }
399                        _ => Err(backoff::Error::Permanent(e)),
400                    }
401                }
402            }
403        })
404        .await
405    }
406
407    /// Execute a HTTP request and retry on rate limit
408    ///
409    /// request_maker serves one purpose: to be able to create request again
410    /// to retry API call after getting rate limited. request_maker is async because
411    /// reqwest::multipart::Form is created by async calls to read files for uploads.
412    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
413    where
414        O: DeserializeOwned,
415        M: Fn() -> Fut,
416        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
417    {
418        let bytes = self.execute_raw(request_maker).await?;
419
420        let response: O = serde_json::from_slice(bytes.as_ref())
421            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
422
423        Ok(response)
424    }
425
426    /// Make HTTP POST request to receive SSE
427    pub(crate) async fn post_stream<I, O>(
428        &self,
429        path: &str,
430        request: I,
431    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
432    where
433        I: Serialize,
434        O: DeserializeOwned + std::marker::Send + 'static,
435    {
436        let event_source = self
437            .http_client
438            .post(self.config.url(path))
439            .query(&self.config.query())
440            .headers(self.config.headers())
441            .json(&request)
442            .eventsource()
443            .unwrap();
444
445        stream(event_source).await
446    }
447
448    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
449        &self,
450        path: &str,
451        request: I,
452        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
453    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
454    where
455        I: Serialize,
456        O: DeserializeOwned + std::marker::Send + 'static,
457    {
458        let event_source = self
459            .http_client
460            .post(self.config.url(path))
461            .query(&self.config.query())
462            .headers(self.config.headers())
463            .json(&request)
464            .eventsource()
465            .unwrap();
466
467        stream_mapped_raw_events(event_source, event_mapper).await
468    }
469
470    /// Make HTTP GET request to receive SSE
471    pub(crate) async fn _get_stream<Q, O>(
472        &self,
473        path: &str,
474        query: &Q,
475    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
476    where
477        Q: Serialize + ?Sized,
478        O: DeserializeOwned + std::marker::Send + 'static,
479    {
480        let event_source = self
481            .http_client
482            .get(self.config.url(path))
483            .query(query)
484            .query(&self.config.query())
485            .headers(self.config.headers())
486            .eventsource()
487            .unwrap();
488
489        stream(event_source).await
490    }
491}
492
493async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
494    let status = response.status();
495    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
496
497    if status.is_server_error() {
498        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
499        let message: String = String::from_utf8_lossy(&bytes).into_owned();
500        tracing::warn!("Server error: {status} - {message}");
501        return Err(OpenAIError::ApiError(ApiError {
502            message,
503            r#type: None,
504            param: None,
505            code: None,
506        }));
507    }
508
509    // Deserialize response body from either error object or actual response object
510    if !status.is_success() {
511        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
512            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
513
514        return Err(OpenAIError::ApiError(wrapped_error.error));
515    }
516
517    Ok(bytes)
518}
519
520async fn map_stream_error(value: EventSourceError) -> OpenAIError {
521    match value {
522        EventSourceError::InvalidStatusCode(status_code, response) => {
523            read_response(response).await.expect_err(&format!(
524                "Unreachable because read_response returns err when status_code {status_code} is invalid"
525            ))
526        }
527        _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
528    }
529}
530
531/// Request which responds with SSE.
532/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
533pub(crate) async fn stream<O>(
534    mut event_source: EventSource,
535) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
536where
537    O: DeserializeOwned + std::marker::Send + 'static,
538{
539    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
540
541    tokio::spawn(async move {
542        while let Some(ev) = event_source.next().await {
543            match ev {
544                Err(e) => {
545                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
546                        // rx dropped
547                        break;
548                    }
549                }
550                Ok(event) => match event {
551                    Event::Message(message) => {
552                        if message.data == "[DONE]" {
553                            break;
554                        }
555
556                        let response = match serde_json::from_str::<O>(&message.data) {
557                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
558                            Ok(output) => Ok(output),
559                        };
560
561                        if let Err(_e) = tx.send(response) {
562                            // rx dropped
563                            break;
564                        }
565                    }
566                    Event::Open => continue,
567                },
568            }
569        }
570
571        event_source.close();
572    });
573
574    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
575}
576
577pub(crate) async fn stream_mapped_raw_events<O>(
578    mut event_source: EventSource,
579    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
580) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
581where
582    O: DeserializeOwned + std::marker::Send + 'static,
583{
584    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
585
586    tokio::spawn(async move {
587        while let Some(ev) = event_source.next().await {
588            match ev {
589                Err(e) => {
590                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
591                        // rx dropped
592                        break;
593                    }
594                }
595                Ok(event) => match event {
596                    Event::Message(message) => {
597                        let mut done = false;
598
599                        if message.data == "[DONE]" {
600                            done = true;
601                        }
602
603                        let response = event_mapper(message);
604
605                        if let Err(_e) = tx.send(response) {
606                            // rx dropped
607                            break;
608                        }
609
610                        if done {
611                            break;
612                        }
613                    }
614                    Event::Open => continue,
615                },
616            }
617        }
618
619        event_source.close();
620    });
621
622    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
623}