async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    traits::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
17    Embeddings, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
18    VectorStores, Videos,
19};
20
21#[derive(Debug, Clone, Default)]
22/// Client is a container for config, backoff and http_client
23/// used to make API calls.
24pub struct Client<C: Config> {
25    http_client: reqwest::Client,
26    config: C,
27    backoff: backoff::ExponentialBackoff,
28}
29
30impl Client<OpenAIConfig> {
31    /// Client with default [OpenAIConfig]
32    pub fn new() -> Self {
33        Self::default()
34    }
35}
36
37impl<C: Config> Client<C> {
38    /// Create client with a custom HTTP client, OpenAI config, and backoff.
39    pub fn build(
40        http_client: reqwest::Client,
41        config: C,
42        backoff: backoff::ExponentialBackoff,
43    ) -> Self {
44        Self {
45            http_client,
46            config,
47            backoff,
48        }
49    }
50
51    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
52    pub fn with_config(config: C) -> Self {
53        Self {
54            http_client: reqwest::Client::new(),
55            config,
56            backoff: Default::default(),
57        }
58    }
59
60    /// Provide your own [client] to make HTTP requests with.
61    ///
62    /// [client]: reqwest::Client
63    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64        self.http_client = http_client;
65        self
66    }
67
68    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
69    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
70        self.backoff = backoff;
71        self
72    }
73
74    // API groups
75
76    /// To call [Models] group related APIs using this client.
77    pub fn models(&self) -> Models<'_, C> {
78        Models::new(self)
79    }
80
81    /// To call [Completions] group related APIs using this client.
82    pub fn completions(&self) -> Completions<'_, C> {
83        Completions::new(self)
84    }
85
86    /// To call [Chat] group related APIs using this client.
87    pub fn chat(&self) -> Chat<'_, C> {
88        Chat::new(self)
89    }
90
91    /// To call [Images] group related APIs using this client.
92    pub fn images(&self) -> Images<'_, C> {
93        Images::new(self)
94    }
95
96    /// To call [Moderations] group related APIs using this client.
97    pub fn moderations(&self) -> Moderations<'_, C> {
98        Moderations::new(self)
99    }
100
101    /// To call [Files] group related APIs using this client.
102    pub fn files(&self) -> Files<'_, C> {
103        Files::new(self)
104    }
105
106    /// To call [Uploads] group related APIs using this client.
107    pub fn uploads(&self) -> Uploads<'_, C> {
108        Uploads::new(self)
109    }
110
111    /// To call [FineTuning] group related APIs using this client.
112    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
113        FineTuning::new(self)
114    }
115
116    /// To call [Embeddings] group related APIs using this client.
117    pub fn embeddings(&self) -> Embeddings<'_, C> {
118        Embeddings::new(self)
119    }
120
121    /// To call [Audio] group related APIs using this client.
122    pub fn audio(&self) -> Audio<'_, C> {
123        Audio::new(self)
124    }
125
126    /// To call [Videos] group related APIs using this client.
127    pub fn videos(&self) -> Videos<'_, C> {
128        Videos::new(self)
129    }
130
131    /// To call [Assistants] group related APIs using this client.
132    pub fn assistants(&self) -> Assistants<'_, C> {
133        Assistants::new(self)
134    }
135
136    /// To call [Threads] group related APIs using this client.
137    pub fn threads(&self) -> Threads<'_, C> {
138        Threads::new(self)
139    }
140
141    /// To call [VectorStores] group related APIs using this client.
142    pub fn vector_stores(&self) -> VectorStores<'_, C> {
143        VectorStores::new(self)
144    }
145
146    /// To call [Batches] group related APIs using this client.
147    pub fn batches(&self) -> Batches<'_, C> {
148        Batches::new(self)
149    }
150
151    /// To call [AuditLogs] group related APIs using this client.
152    pub fn audit_logs(&self) -> AuditLogs<'_, C> {
153        AuditLogs::new(self)
154    }
155
156    /// To call [Invites] group related APIs using this client.
157    pub fn invites(&self) -> Invites<'_, C> {
158        Invites::new(self)
159    }
160
161    /// To call [Users] group related APIs using this client.
162    pub fn users(&self) -> Users<'_, C> {
163        Users::new(self)
164    }
165
166    /// To call [Projects] group related APIs using this client.
167    pub fn projects(&self) -> Projects<'_, C> {
168        Projects::new(self)
169    }
170
171    /// To call [Responses] group related APIs using this client.
172    pub fn responses(&self) -> Responses<'_, C> {
173        Responses::new(self)
174    }
175
176    /// To call [Conversations] group related APIs using this client.
177    pub fn conversations(&self) -> Conversations<'_, C> {
178        Conversations::new(self)
179    }
180
181    /// To call [Containers] group related APIs using this client.
182    pub fn containers(&self) -> Containers<'_, C> {
183        Containers::new(self)
184    }
185
186    pub fn config(&self) -> &C {
187        &self.config
188    }
189
190    /// Make a GET request to {path} and deserialize the response body
191    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
192    where
193        O: DeserializeOwned,
194    {
195        let request_maker = || async {
196            Ok(self
197                .http_client
198                .get(self.config.url(path))
199                .query(&self.config.query())
200                .headers(self.config.headers())
201                .build()?)
202        };
203
204        self.execute(request_maker).await
205    }
206
207    /// Make a GET request to {path} with given Query and deserialize the response body
208    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
209    where
210        O: DeserializeOwned,
211        Q: Serialize + ?Sized,
212    {
213        let request_maker = || async {
214            Ok(self
215                .http_client
216                .get(self.config.url(path))
217                .query(&self.config.query())
218                .query(query)
219                .headers(self.config.headers())
220                .build()?)
221        };
222
223        self.execute(request_maker).await
224    }
225
226    /// Make a DELETE request to {path} and deserialize the response body
227    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
228    where
229        O: DeserializeOwned,
230    {
231        let request_maker = || async {
232            Ok(self
233                .http_client
234                .delete(self.config.url(path))
235                .query(&self.config.query())
236                .headers(self.config.headers())
237                .build()?)
238        };
239
240        self.execute(request_maker).await
241    }
242
243    /// Make a GET request to {path} and return the response body
244    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
245        let request_maker = || async {
246            Ok(self
247                .http_client
248                .get(self.config.url(path))
249                .query(&self.config.query())
250                .headers(self.config.headers())
251                .build()?)
252        };
253
254        self.execute_raw(request_maker).await
255    }
256
257    pub(crate) async fn get_raw_with_query<Q>(
258        &self,
259        path: &str,
260        query: &Q,
261    ) -> Result<Bytes, OpenAIError>
262    where
263        Q: Serialize + ?Sized,
264    {
265        let request_maker = || async {
266            Ok(self
267                .http_client
268                .get(self.config.url(path))
269                .query(&self.config.query())
270                .query(query)
271                .headers(self.config.headers())
272                .build()?)
273        };
274
275        self.execute_raw(request_maker).await
276    }
277
278    /// Make a POST request to {path} and return the response body
279    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
280    where
281        I: Serialize,
282    {
283        let request_maker = || async {
284            Ok(self
285                .http_client
286                .post(self.config.url(path))
287                .query(&self.config.query())
288                .headers(self.config.headers())
289                .json(&request)
290                .build()?)
291        };
292
293        self.execute_raw(request_maker).await
294    }
295
296    /// Make a POST request to {path} and deserialize the response body
297    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
298    where
299        I: Serialize,
300        O: DeserializeOwned,
301    {
302        let request_maker = || async {
303            Ok(self
304                .http_client
305                .post(self.config.url(path))
306                .query(&self.config.query())
307                .headers(self.config.headers())
308                .json(&request)
309                .build()?)
310        };
311
312        self.execute(request_maker).await
313    }
314
315    /// POST a form at {path} and return the response body
316    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
317    where
318        Form: AsyncTryFrom<F, Error = OpenAIError>,
319        F: Clone,
320    {
321        let request_maker = || async {
322            Ok(self
323                .http_client
324                .post(self.config.url(path))
325                .query(&self.config.query())
326                .headers(self.config.headers())
327                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
328                .build()?)
329        };
330
331        self.execute_raw(request_maker).await
332    }
333
334    /// POST a form at {path} and deserialize the response body
335    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
336    where
337        O: DeserializeOwned,
338        Form: AsyncTryFrom<F, Error = OpenAIError>,
339        F: Clone,
340    {
341        let request_maker = || async {
342            Ok(self
343                .http_client
344                .post(self.config.url(path))
345                .query(&self.config.query())
346                .headers(self.config.headers())
347                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
348                .build()?)
349        };
350
351        self.execute(request_maker).await
352    }
353
354    pub(crate) async fn post_form_stream<O, F>(
355        &self,
356        path: &str,
357        form: F,
358    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
359    where
360        F: Clone,
361        Form: AsyncTryFrom<F, Error = OpenAIError>,
362        O: DeserializeOwned + std::marker::Send + 'static,
363    {
364        // Build and execute request manually since multipart::Form is not Clone
365        // and .eventsource() requires cloneability
366        let response = self
367            .http_client
368            .post(self.config.url(path))
369            .query(&self.config.query())
370            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
371            .headers(self.config.headers())
372            .send()
373            .await
374            .map_err(OpenAIError::Reqwest)?;
375
376        // Check for error status
377        if !response.status().is_success() {
378            return Err(read_response(response).await.unwrap_err());
379        }
380
381        // Convert response body to EventSource stream
382        let stream = response
383            .bytes_stream()
384            .map(|result| result.map_err(std::io::Error::other));
385        let event_stream = eventsource_stream::EventStream::new(stream);
386
387        // Convert EventSource stream to our expected format
388        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
389
390        tokio::spawn(async move {
391            use futures::StreamExt;
392            let mut event_stream = std::pin::pin!(event_stream);
393
394            while let Some(event_result) = event_stream.next().await {
395                match event_result {
396                    Err(e) => {
397                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
398                            StreamError::EventStream(e.to_string()),
399                        )))) {
400                            break;
401                        }
402                    }
403                    Ok(event) => {
404                        // eventsource_stream::Event is a struct with data field
405                        if event.data == "[DONE]" {
406                            break;
407                        }
408
409                        let response = match serde_json::from_str::<O>(&event.data) {
410                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
411                            Ok(output) => Ok(output),
412                        };
413
414                        if let Err(_e) = tx.send(response) {
415                            break;
416                        }
417                    }
418                }
419            }
420        });
421
422        Ok(Box::pin(
423            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
424        ))
425    }
426
427    /// Execute a HTTP request and retry on rate limit
428    ///
429    /// request_maker serves one purpose: to be able to create request again
430    /// to retry API call after getting rate limited. request_maker is async because
431    /// reqwest::multipart::Form is created by async calls to read files for uploads.
432    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
433    where
434        M: Fn() -> Fut,
435        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
436    {
437        let client = self.http_client.clone();
438
439        backoff::future::retry(self.backoff.clone(), || async {
440            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
441            let response = client
442                .execute(request)
443                .await
444                .map_err(OpenAIError::Reqwest)
445                .map_err(backoff::Error::Permanent)?;
446
447            let status = response.status();
448
449            match read_response(response).await {
450                Ok(bytes) => Ok(bytes),
451                Err(e) => {
452                    match e {
453                        OpenAIError::ApiError(api_error) => {
454                            if status.is_server_error() {
455                                Err(backoff::Error::Transient {
456                                    err: OpenAIError::ApiError(api_error),
457                                    retry_after: None,
458                                })
459                            } else if status.as_u16() == 429
460                                && api_error.r#type != Some("insufficient_quota".to_string())
461                            {
462                                // Rate limited retry...
463                                tracing::warn!("Rate limited: {}", api_error.message);
464                                Err(backoff::Error::Transient {
465                                    err: OpenAIError::ApiError(api_error),
466                                    retry_after: None,
467                                })
468                            } else {
469                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
470                            }
471                        }
472                        _ => Err(backoff::Error::Permanent(e)),
473                    }
474                }
475            }
476        })
477        .await
478    }
479
480    /// Execute a HTTP request and retry on rate limit
481    ///
482    /// request_maker serves one purpose: to be able to create request again
483    /// to retry API call after getting rate limited. request_maker is async because
484    /// reqwest::multipart::Form is created by async calls to read files for uploads.
485    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
486    where
487        O: DeserializeOwned,
488        M: Fn() -> Fut,
489        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
490    {
491        let bytes = self.execute_raw(request_maker).await?;
492
493        let response: O = serde_json::from_slice(bytes.as_ref())
494            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
495
496        Ok(response)
497    }
498
499    /// Make HTTP POST request to receive SSE
500    pub(crate) async fn post_stream<I, O>(
501        &self,
502        path: &str,
503        request: I,
504    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
505    where
506        I: Serialize,
507        O: DeserializeOwned + std::marker::Send + 'static,
508    {
509        let event_source = self
510            .http_client
511            .post(self.config.url(path))
512            .query(&self.config.query())
513            .headers(self.config.headers())
514            .json(&request)
515            .eventsource()
516            .unwrap();
517
518        stream(event_source).await
519    }
520
521    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
522        &self,
523        path: &str,
524        request: I,
525        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
526    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
527    where
528        I: Serialize,
529        O: DeserializeOwned + std::marker::Send + 'static,
530    {
531        let event_source = self
532            .http_client
533            .post(self.config.url(path))
534            .query(&self.config.query())
535            .headers(self.config.headers())
536            .json(&request)
537            .eventsource()
538            .unwrap();
539
540        stream_mapped_raw_events(event_source, event_mapper).await
541    }
542
543    /// Make HTTP GET request to receive SSE
544    pub(crate) async fn _get_stream<Q, O>(
545        &self,
546        path: &str,
547        query: &Q,
548    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
549    where
550        Q: Serialize + ?Sized,
551        O: DeserializeOwned + std::marker::Send + 'static,
552    {
553        let event_source = self
554            .http_client
555            .get(self.config.url(path))
556            .query(query)
557            .query(&self.config.query())
558            .headers(self.config.headers())
559            .eventsource()
560            .unwrap();
561
562        stream(event_source).await
563    }
564}
565
566async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
567    let status = response.status();
568    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
569
570    if status.is_server_error() {
571        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
572        let message: String = String::from_utf8_lossy(&bytes).into_owned();
573        tracing::warn!("Server error: {status} - {message}");
574        return Err(OpenAIError::ApiError(ApiError {
575            message,
576            r#type: None,
577            param: None,
578            code: None,
579        }));
580    }
581
582    // Deserialize response body from either error object or actual response object
583    if !status.is_success() {
584        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
585            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
586
587        return Err(OpenAIError::ApiError(wrapped_error.error));
588    }
589
590    Ok(bytes)
591}
592
593async fn map_stream_error(value: EventSourceError) -> OpenAIError {
594    match value {
595        EventSourceError::InvalidStatusCode(status_code, response) => {
596            read_response(response).await.expect_err(&format!(
597                "Unreachable because read_response returns err when status_code {status_code} is invalid"
598            ))
599        }
600        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
601    }
602}
603
604/// Request which responds with SSE.
605/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
606pub(crate) async fn stream<O>(
607    mut event_source: EventSource,
608) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
609where
610    O: DeserializeOwned + std::marker::Send + 'static,
611{
612    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
613
614    tokio::spawn(async move {
615        while let Some(ev) = event_source.next().await {
616            match ev {
617                Err(e) => {
618                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
619                        // rx dropped
620                        break;
621                    }
622                }
623                Ok(event) => match event {
624                    Event::Message(message) => {
625                        if message.data == "[DONE]" {
626                            break;
627                        }
628
629                        let response = match serde_json::from_str::<O>(&message.data) {
630                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
631                            Ok(output) => Ok(output),
632                        };
633
634                        if let Err(_e) = tx.send(response) {
635                            // rx dropped
636                            break;
637                        }
638                    }
639                    Event::Open => continue,
640                },
641            }
642        }
643
644        event_source.close();
645    });
646
647    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
648}
649
650pub(crate) async fn stream_mapped_raw_events<O>(
651    mut event_source: EventSource,
652    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
653) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
654where
655    O: DeserializeOwned + std::marker::Send + 'static,
656{
657    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
658
659    tokio::spawn(async move {
660        while let Some(ev) = event_source.next().await {
661            match ev {
662                Err(e) => {
663                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
664                        // rx dropped
665                        break;
666                    }
667                }
668                Ok(event) => match event {
669                    Event::Message(message) => {
670                        let mut done = false;
671
672                        if message.data == "[DONE]" {
673                            done = true;
674                        }
675
676                        let response = event_mapper(message);
677
678                        if let Err(_e) = tx.send(response) {
679                            // rx dropped
680                            break;
681                        }
682
683                        if done {
684                            break;
685                        }
686                    }
687                    Event::Open => continue,
688                },
689            }
690        }
691
692        event_source.close();
693    });
694
695    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
696}