async_openai/client.rs

use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::{multipart::Form, Response};
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
    config::{Config, OpenAIConfig},
    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
    file::Files,
    image::Images,
    moderation::Moderations,
    traits::AsyncTryFrom,
    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
    Embeddings, Evals, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
    VectorStores, Videos,
};

/// Client is a container for config, backoff and http_client
/// used to make API calls.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    http_client: reqwest::Client,
    config: C,
    backoff: backoff::ExponentialBackoff,
}

impl Client<OpenAIConfig> {
    /// Client with default [OpenAIConfig]
    pub fn new() -> Self {
        Self::default()
    }
}
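
// A minimal construction sketch (not part of the original file): the default
// `OpenAIConfig` reads `OPENAI_API_KEY` from the environment; the `new()` and
// `with_api_key` builder calls below are assumed from the crate's config module.
//
//     use async_openai::{config::OpenAIConfig, Client};
//
//     // Default client, API key taken from the environment.
//     let client = Client::new();
//
//     // Client with an explicitly configured key.
//     let config = OpenAIConfig::new().with_api_key("sk-...");
//     let client = Client::with_config(config);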

impl<C: Config> Client<C> {
    /// Create client with a custom HTTP client, OpenAI config, and backoff.
    pub fn build(
        http_client: reqwest::Client,
        config: C,
        backoff: backoff::ExponentialBackoff,
    ) -> Self {
        Self {
            http_client,
            config,
            backoff,
        }
    }

    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
    pub fn with_config(config: C) -> Self {
        Self {
            http_client: reqwest::Client::new(),
            config,
            backoff: Default::default(),
        }
    }

    /// Provide your own [client] to make HTTP requests with.
    ///
    /// [client]: reqwest::Client
    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
        self.http_client = http_client;
        self
    }

    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
        self.backoff = backoff;
        self
    }
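
    // A sketch of the builder-style setup (illustrative, not part of the original
    // file). `reqwest::Client::builder()` and `backoff::ExponentialBackoffBuilder`
    // come from the crate's dependencies; the timeout/retry values are arbitrary.
    //
    //     use std::time::Duration;
    //
    //     let http_client = reqwest::Client::builder()
    //         .timeout(Duration::from_secs(30))
    //         .build()
    //         .expect("reqwest client");
    //
    //     let backoff = backoff::ExponentialBackoffBuilder::new()
    //         .with_max_elapsed_time(Some(Duration::from_secs(60)))
    //         .build();
    //
    //     let client = Client::build(http_client, OpenAIConfig::default(), backoff);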

    // API groups

    /// To call APIs in the [Models] group using this client.
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// To call APIs in the [Completions] group using this client.
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// To call APIs in the [Chat] group using this client.
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// To call APIs in the [Images] group using this client.
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// To call APIs in the [Moderations] group using this client.
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// To call APIs in the [Files] group using this client.
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// To call APIs in the [Uploads] group using this client.
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// To call APIs in the [FineTuning] group using this client.
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// To call APIs in the [Embeddings] group using this client.
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// To call APIs in the [Audio] group using this client.
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// To call APIs in the [Videos] group using this client.
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// To call APIs in the [Assistants] group using this client.
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// To call APIs in the [Threads] group using this client.
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// To call APIs in the [VectorStores] group using this client.
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// To call APIs in the [Batches] group using this client.
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// To call APIs in the [AuditLogs] group using this client.
    pub fn audit_logs(&self) -> AuditLogs<'_, C> {
        AuditLogs::new(self)
    }

    /// To call APIs in the [Invites] group using this client.
    pub fn invites(&self) -> Invites<'_, C> {
        Invites::new(self)
    }

    /// To call APIs in the [Users] group using this client.
    pub fn users(&self) -> Users<'_, C> {
        Users::new(self)
    }

    /// To call APIs in the [Projects] group using this client.
    pub fn projects(&self) -> Projects<'_, C> {
        Projects::new(self)
    }

    /// To call APIs in the [Responses] group using this client.
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// To call APIs in the [Conversations] group using this client.
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// To call APIs in the [Containers] group using this client.
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// To call APIs in the [Evals] group using this client.
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// Returns a reference to this client's [Config].
    pub fn config(&self) -> &C {
        &self.config
    }
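
    // Illustrative call pattern (not part of the original file): each accessor
    // above borrows the client, so calls chain directly. `Models::list` is
    // assumed from the crate's models module.
    //
    //     let client = Client::new();
    //     let models = client.models().list().await?;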

    /// Make a GET request to {path} and deserialize the response body
    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a GET request to {path} with the given query and deserialize the response body
    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        Q: Serialize + ?Sized,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .query(query)
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a DELETE request to {path} and deserialize the response body
    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .delete(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a GET request to {path} and return the raw response body
    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// Make a GET request to {path} with the given query and return the raw response body
    pub(crate) async fn get_raw_with_query<Q>(
        &self,
        path: &str,
        query: &Q,
    ) -> Result<Bytes, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .query(query)
                .headers(self.config.headers())
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// Make a POST request to {path} and return the raw response body
    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
    where
        I: Serialize,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .json(&request)
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// Make a POST request to {path} and deserialize the response body
    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
    where
        I: Serialize,
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .json(&request)
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// POST a form at {path} and return the raw response body
    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
    where
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        F: Clone,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// POST a form at {path} and deserialize the response body
    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        F: Clone,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
                .build()?)
        };

        self.execute(request_maker).await
    }

    pub(crate) async fn post_form_stream<O, F>(
        &self,
        path: &str,
        form: F,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
    where
        F: Clone,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        // Build and execute the request manually since multipart::Form is not Clone
        // and .eventsource() requires cloneability
        let response = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
            .headers(self.config.headers())
            .send()
            .await
            .map_err(OpenAIError::Reqwest)?;

        // Check for error status
        if !response.status().is_success() {
            return Err(read_response(response).await.unwrap_err());
        }

        // Convert the response body into an EventSource stream
        let stream = response
            .bytes_stream()
            .map(|result| result.map_err(std::io::Error::other));
        let event_stream = eventsource_stream::EventStream::new(stream);

        // Convert the EventSource stream to our expected format
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

        tokio::spawn(async move {
            use futures::StreamExt;
            let mut event_stream = std::pin::pin!(event_stream);

            while let Some(event_result) = event_stream.next().await {
                match event_result {
                    Err(e) => {
                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
                            StreamError::EventStream(e.to_string()),
                        )))) {
                            break;
                        }
                    }
                    Ok(event) => {
                        // eventsource_stream::Event is a struct with a `data` field
                        if event.data == "[DONE]" {
                            break;
                        }

                        let response = match serde_json::from_str::<O>(&event.data) {
                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
                            Ok(output) => Ok(output),
                        };

                        if let Err(_e) = tx.send(response) {
                            break;
                        }
                    }
                }
            }
        });

        Ok(Box::pin(
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
        ))
    }

    /// Execute an HTTP request and retry on rate limit
    ///
    /// request_maker serves one purpose: to recreate the request when an API call
    /// is retried after being rate limited. request_maker is async because
    /// reqwest::multipart::Form is created by async calls to read files for uploads.
    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
    where
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
    {
        let client = self.http_client.clone();

        backoff::future::retry(self.backoff.clone(), || async {
            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
            let response = client
                .execute(request)
                .await
                .map_err(OpenAIError::Reqwest)
                .map_err(backoff::Error::Permanent)?;

            let status = response.status();

            match read_response(response).await {
                Ok(bytes) => Ok(bytes),
                Err(e) => {
                    match e {
                        OpenAIError::ApiError(api_error) => {
                            if status.is_server_error() {
                                // Server errors are retried with backoff.
                                Err(backoff::Error::Transient {
                                    err: OpenAIError::ApiError(api_error),
                                    retry_after: None,
                                })
                            } else if status.as_u16() == 429
                                && api_error.r#type != Some("insufficient_quota".to_string())
                            {
                                // Rate limited: retry with backoff...
                                tracing::warn!("Rate limited: {}", api_error.message);
                                Err(backoff::Error::Transient {
                                    err: OpenAIError::ApiError(api_error),
                                    retry_after: None,
                                })
                            } else {
                                // All other API errors are permanent.
                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
                            }
                        }
                        _ => Err(backoff::Error::Permanent(e)),
                    }
                }
            }
        })
        .await
    }

    /// Execute an HTTP request, retry on rate limit, and deserialize the response body
    ///
    /// request_maker serves one purpose: to recreate the request when an API call
    /// is retried after being rate limited. request_maker is async because
    /// reqwest::multipart::Form is created by async calls to read files for uploads.
    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
    {
        let bytes = self.execute_raw(request_maker).await?;

        let response: O = serde_json::from_slice(bytes.as_ref())
            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

        Ok(response)
    }

    /// Make an HTTP POST request to receive SSE
    pub(crate) async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request)
            .eventsource()
            .unwrap();

        stream(event_source).await
    }
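
    // Consumption sketch (illustrative, not part of the original file): the
    // returned stream yields deserialized events until `[DONE]`. The
    // `create_stream` method shown below is assumed from the chat API group.
    //
    //     use futures::StreamExt;
    //
    //     let mut stream = client.chat().create_stream(request).await?;
    //     while let Some(item) = stream.next().await {
    //         match item {
    //             Ok(chunk) => println!("{:?}", chunk),
    //             Err(e) => eprintln!("stream error: {e}"),
    //         }
    //     }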

    /// Make an HTTP POST request to receive SSE, mapping each raw event with `event_mapper`
    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
        &self,
        path: &str,
        request: I,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request)
            .eventsource()
            .unwrap();

        stream_mapped_raw_events(event_source, event_mapper).await
    }

    /// Make an HTTP GET request to receive SSE
    pub(crate) async fn _get_stream<Q, O>(
        &self,
        path: &str,
        query: &Q,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        Q: Serialize + ?Sized,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .get(self.config.url(path))
            .query(query)
            .query(&self.config.query())
            .headers(self.config.headers())
            .eventsource()
            .unwrap();

        stream(event_source).await
    }
}

/// Read the response body, mapping non-success statuses to [OpenAIError::ApiError].
async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
    let status = response.status();
    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;

    if status.is_server_error() {
        // OpenAI does not guarantee server errors are returned as JSON, so we cannot deserialize them.
        let message: String = String::from_utf8_lossy(&bytes).into_owned();
        tracing::warn!("Server error: {status} - {message}");
        return Err(OpenAIError::ApiError(ApiError {
            message,
            r#type: None,
            param: None,
            code: None,
        }));
    }

    // Other non-success responses carry a JSON error object; deserialize and return it.
    if !status.is_success() {
        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

        return Err(OpenAIError::ApiError(wrapped_error.error));
    }

    Ok(bytes)
}
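
// For reference, a sketch of the error envelope that `WrappedError` is expected
// to parse (field values are illustrative; exact contents vary by endpoint):
//
//     {
//       "error": {
//         "message": "You exceeded your current quota, ...",
//         "type": "insufficient_quota",
//         "param": null,
//         "code": "insufficient_quota"
//       }
//     }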

async fn map_stream_error(value: EventSourceError) -> OpenAIError {
    match value {
        EventSourceError::InvalidStatusCode(status_code, response) => {
            read_response(response).await.expect_err(&format!(
                "Unreachable because read_response returns err when status_code {status_code} is invalid"
            ))
        }
        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
    }
}

/// Convert an [EventSource] of [server-sent events] into a stream of deserialized values.
///
/// [server-sent events]: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
pub(crate) async fn stream<O>(
    mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            match ev {
                Err(e) => {
                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
                        // rx dropped
                        break;
                    }
                }
                Ok(event) => match event {
                    Event::Message(message) => {
                        if message.data == "[DONE]" {
                            break;
                        }

                        let response = match serde_json::from_str::<O>(&message.data) {
                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
                            Ok(output) => Ok(output),
                        };

                        if let Err(_e) = tx.send(response) {
                            // rx dropped
                            break;
                        }
                    }
                    Event::Open => continue,
                },
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}

/// Like [stream], but maps each raw SSE event with `event_mapper` instead of
/// deserializing JSON directly. The `[DONE]` sentinel is still forwarded to the
/// mapper before the stream ends.
pub(crate) async fn stream_mapped_raw_events<O>(
    mut event_source: EventSource,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            match ev {
                Err(e) => {
                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
                        // rx dropped
                        break;
                    }
                }
                Ok(event) => match event {
                    Event::Message(message) => {
                        let mut done = false;

                        if message.data == "[DONE]" {
                            done = true;
                        }

                        let response = event_mapper(message);

                        if let Err(_e) = tx.send(response) {
                            // rx dropped
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                    Event::Open => continue,
                },
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}