async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    chatkit::Chatkit,
11    config::{Config, OpenAIConfig},
12    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
13    file::Files,
14    image::Images,
15    moderation::Moderations,
16    traits::AsyncTryFrom,
17    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
18    Embeddings, Evals, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
19    VectorStores, Videos,
20};
21
22#[derive(Debug, Clone, Default)]
23/// Client is a container for config, backoff and http_client
24/// used to make API calls.
25pub struct Client<C: Config> {
26    http_client: reqwest::Client,
27    config: C,
28    backoff: backoff::ExponentialBackoff,
29}
30
31impl Client<OpenAIConfig> {
32    /// Client with default [OpenAIConfig]
33    pub fn new() -> Self {
34        Self::default()
35    }
36}
37
38impl<C: Config> Client<C> {
39    /// Create client with a custom HTTP client, OpenAI config, and backoff.
40    pub fn build(
41        http_client: reqwest::Client,
42        config: C,
43        backoff: backoff::ExponentialBackoff,
44    ) -> Self {
45        Self {
46            http_client,
47            config,
48            backoff,
49        }
50    }
51
52    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
53    pub fn with_config(config: C) -> Self {
54        Self {
55            http_client: reqwest::Client::new(),
56            config,
57            backoff: Default::default(),
58        }
59    }
60
61    /// Provide your own [client] to make HTTP requests with.
62    ///
63    /// [client]: reqwest::Client
64    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
65        self.http_client = http_client;
66        self
67    }
68
69    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
70    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
71        self.backoff = backoff;
72        self
73    }
74
75    // API groups
76
77    /// To call [Models] group related APIs using this client.
78    pub fn models(&self) -> Models<'_, C> {
79        Models::new(self)
80    }
81
82    /// To call [Completions] group related APIs using this client.
83    pub fn completions(&self) -> Completions<'_, C> {
84        Completions::new(self)
85    }
86
87    /// To call [Chat] group related APIs using this client.
88    pub fn chat(&self) -> Chat<'_, C> {
89        Chat::new(self)
90    }
91
92    /// To call [Images] group related APIs using this client.
93    pub fn images(&self) -> Images<'_, C> {
94        Images::new(self)
95    }
96
97    /// To call [Moderations] group related APIs using this client.
98    pub fn moderations(&self) -> Moderations<'_, C> {
99        Moderations::new(self)
100    }
101
102    /// To call [Files] group related APIs using this client.
103    pub fn files(&self) -> Files<'_, C> {
104        Files::new(self)
105    }
106
107    /// To call [Uploads] group related APIs using this client.
108    pub fn uploads(&self) -> Uploads<'_, C> {
109        Uploads::new(self)
110    }
111
112    /// To call [FineTuning] group related APIs using this client.
113    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
114        FineTuning::new(self)
115    }
116
117    /// To call [Embeddings] group related APIs using this client.
118    pub fn embeddings(&self) -> Embeddings<'_, C> {
119        Embeddings::new(self)
120    }
121
122    /// To call [Audio] group related APIs using this client.
123    pub fn audio(&self) -> Audio<'_, C> {
124        Audio::new(self)
125    }
126
127    /// To call [Videos] group related APIs using this client.
128    pub fn videos(&self) -> Videos<'_, C> {
129        Videos::new(self)
130    }
131
132    /// To call [Assistants] group related APIs using this client.
133    pub fn assistants(&self) -> Assistants<'_, C> {
134        Assistants::new(self)
135    }
136
137    /// To call [Threads] group related APIs using this client.
138    pub fn threads(&self) -> Threads<'_, C> {
139        Threads::new(self)
140    }
141
142    /// To call [VectorStores] group related APIs using this client.
143    pub fn vector_stores(&self) -> VectorStores<'_, C> {
144        VectorStores::new(self)
145    }
146
147    /// To call [Batches] group related APIs using this client.
148    pub fn batches(&self) -> Batches<'_, C> {
149        Batches::new(self)
150    }
151
152    /// To call [AuditLogs] group related APIs using this client.
153    pub fn audit_logs(&self) -> AuditLogs<'_, C> {
154        AuditLogs::new(self)
155    }
156
157    /// To call [Invites] group related APIs using this client.
158    pub fn invites(&self) -> Invites<'_, C> {
159        Invites::new(self)
160    }
161
162    /// To call [Users] group related APIs using this client.
163    pub fn users(&self) -> Users<'_, C> {
164        Users::new(self)
165    }
166
167    /// To call [Projects] group related APIs using this client.
168    pub fn projects(&self) -> Projects<'_, C> {
169        Projects::new(self)
170    }
171
172    /// To call [Responses] group related APIs using this client.
173    pub fn responses(&self) -> Responses<'_, C> {
174        Responses::new(self)
175    }
176
177    /// To call [Conversations] group related APIs using this client.
178    pub fn conversations(&self) -> Conversations<'_, C> {
179        Conversations::new(self)
180    }
181
182    /// To call [Containers] group related APIs using this client.
183    pub fn containers(&self) -> Containers<'_, C> {
184        Containers::new(self)
185    }
186
187    /// To call [Evals] group related APIs using this client.
188    pub fn evals(&self) -> Evals<'_, C> {
189        Evals::new(self)
190    }
191
192    pub fn chatkit(&self) -> Chatkit<'_, C> {
193        Chatkit::new(self)
194    }
195
196    pub fn config(&self) -> &C {
197        &self.config
198    }
199
200    /// Make a GET request to {path} and deserialize the response body
201    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
202    where
203        O: DeserializeOwned,
204    {
205        let request_maker = || async {
206            Ok(self
207                .http_client
208                .get(self.config.url(path))
209                .query(&self.config.query())
210                .headers(self.config.headers())
211                .build()?)
212        };
213
214        self.execute(request_maker).await
215    }
216
217    /// Make a GET request to {path} with given Query and deserialize the response body
218    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
219    where
220        O: DeserializeOwned,
221        Q: Serialize + ?Sized,
222    {
223        let request_maker = || async {
224            Ok(self
225                .http_client
226                .get(self.config.url(path))
227                .query(&self.config.query())
228                .query(query)
229                .headers(self.config.headers())
230                .build()?)
231        };
232
233        self.execute(request_maker).await
234    }
235
236    /// Make a DELETE request to {path} and deserialize the response body
237    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
238    where
239        O: DeserializeOwned,
240    {
241        let request_maker = || async {
242            Ok(self
243                .http_client
244                .delete(self.config.url(path))
245                .query(&self.config.query())
246                .headers(self.config.headers())
247                .build()?)
248        };
249
250        self.execute(request_maker).await
251    }
252
253    /// Make a GET request to {path} and return the response body
254    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
255        let request_maker = || async {
256            Ok(self
257                .http_client
258                .get(self.config.url(path))
259                .query(&self.config.query())
260                .headers(self.config.headers())
261                .build()?)
262        };
263
264        self.execute_raw(request_maker).await
265    }
266
267    pub(crate) async fn get_raw_with_query<Q>(
268        &self,
269        path: &str,
270        query: &Q,
271    ) -> Result<Bytes, OpenAIError>
272    where
273        Q: Serialize + ?Sized,
274    {
275        let request_maker = || async {
276            Ok(self
277                .http_client
278                .get(self.config.url(path))
279                .query(&self.config.query())
280                .query(query)
281                .headers(self.config.headers())
282                .build()?)
283        };
284
285        self.execute_raw(request_maker).await
286    }
287
288    /// Make a POST request to {path} and return the response body
289    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
290    where
291        I: Serialize,
292    {
293        let request_maker = || async {
294            Ok(self
295                .http_client
296                .post(self.config.url(path))
297                .query(&self.config.query())
298                .headers(self.config.headers())
299                .json(&request)
300                .build()?)
301        };
302
303        self.execute_raw(request_maker).await
304    }
305
306    /// Make a POST request to {path} and deserialize the response body
307    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
308    where
309        I: Serialize,
310        O: DeserializeOwned,
311    {
312        let request_maker = || async {
313            Ok(self
314                .http_client
315                .post(self.config.url(path))
316                .query(&self.config.query())
317                .headers(self.config.headers())
318                .json(&request)
319                .build()?)
320        };
321
322        self.execute(request_maker).await
323    }
324
325    /// POST a form at {path} and return the response body
326    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
327    where
328        Form: AsyncTryFrom<F, Error = OpenAIError>,
329        F: Clone,
330    {
331        let request_maker = || async {
332            Ok(self
333                .http_client
334                .post(self.config.url(path))
335                .query(&self.config.query())
336                .headers(self.config.headers())
337                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
338                .build()?)
339        };
340
341        self.execute_raw(request_maker).await
342    }
343
344    /// POST a form at {path} and deserialize the response body
345    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
346    where
347        O: DeserializeOwned,
348        Form: AsyncTryFrom<F, Error = OpenAIError>,
349        F: Clone,
350    {
351        let request_maker = || async {
352            Ok(self
353                .http_client
354                .post(self.config.url(path))
355                .query(&self.config.query())
356                .headers(self.config.headers())
357                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
358                .build()?)
359        };
360
361        self.execute(request_maker).await
362    }
363
364    pub(crate) async fn post_form_stream<O, F>(
365        &self,
366        path: &str,
367        form: F,
368    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
369    where
370        F: Clone,
371        Form: AsyncTryFrom<F, Error = OpenAIError>,
372        O: DeserializeOwned + std::marker::Send + 'static,
373    {
374        // Build and execute request manually since multipart::Form is not Clone
375        // and .eventsource() requires cloneability
376        let response = self
377            .http_client
378            .post(self.config.url(path))
379            .query(&self.config.query())
380            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
381            .headers(self.config.headers())
382            .send()
383            .await
384            .map_err(OpenAIError::Reqwest)?;
385
386        // Check for error status
387        if !response.status().is_success() {
388            return Err(read_response(response).await.unwrap_err());
389        }
390
391        // Convert response body to EventSource stream
392        let stream = response
393            .bytes_stream()
394            .map(|result| result.map_err(std::io::Error::other));
395        let event_stream = eventsource_stream::EventStream::new(stream);
396
397        // Convert EventSource stream to our expected format
398        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
399
400        tokio::spawn(async move {
401            use futures::StreamExt;
402            let mut event_stream = std::pin::pin!(event_stream);
403
404            while let Some(event_result) = event_stream.next().await {
405                match event_result {
406                    Err(e) => {
407                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
408                            StreamError::EventStream(e.to_string()),
409                        )))) {
410                            break;
411                        }
412                    }
413                    Ok(event) => {
414                        // eventsource_stream::Event is a struct with data field
415                        if event.data == "[DONE]" {
416                            break;
417                        }
418
419                        let response = match serde_json::from_str::<O>(&event.data) {
420                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
421                            Ok(output) => Ok(output),
422                        };
423
424                        if let Err(_e) = tx.send(response) {
425                            break;
426                        }
427                    }
428                }
429            }
430        });
431
432        Ok(Box::pin(
433            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
434        ))
435    }
436
437    /// Execute a HTTP request and retry on rate limit
438    ///
439    /// request_maker serves one purpose: to be able to create request again
440    /// to retry API call after getting rate limited. request_maker is async because
441    /// reqwest::multipart::Form is created by async calls to read files for uploads.
442    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
443    where
444        M: Fn() -> Fut,
445        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
446    {
447        let client = self.http_client.clone();
448
449        backoff::future::retry(self.backoff.clone(), || async {
450            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
451            let response = client
452                .execute(request)
453                .await
454                .map_err(OpenAIError::Reqwest)
455                .map_err(backoff::Error::Permanent)?;
456
457            let status = response.status();
458
459            match read_response(response).await {
460                Ok(bytes) => Ok(bytes),
461                Err(e) => {
462                    match e {
463                        OpenAIError::ApiError(api_error) => {
464                            if status.is_server_error() {
465                                Err(backoff::Error::Transient {
466                                    err: OpenAIError::ApiError(api_error),
467                                    retry_after: None,
468                                })
469                            } else if status.as_u16() == 429
470                                && api_error.r#type != Some("insufficient_quota".to_string())
471                            {
472                                // Rate limited retry...
473                                tracing::warn!("Rate limited: {}", api_error.message);
474                                Err(backoff::Error::Transient {
475                                    err: OpenAIError::ApiError(api_error),
476                                    retry_after: None,
477                                })
478                            } else {
479                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
480                            }
481                        }
482                        _ => Err(backoff::Error::Permanent(e)),
483                    }
484                }
485            }
486        })
487        .await
488    }
489
490    /// Execute a HTTP request and retry on rate limit
491    ///
492    /// request_maker serves one purpose: to be able to create request again
493    /// to retry API call after getting rate limited. request_maker is async because
494    /// reqwest::multipart::Form is created by async calls to read files for uploads.
495    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
496    where
497        O: DeserializeOwned,
498        M: Fn() -> Fut,
499        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
500    {
501        let bytes = self.execute_raw(request_maker).await?;
502
503        let response: O = serde_json::from_slice(bytes.as_ref())
504            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
505
506        Ok(response)
507    }
508
509    /// Make HTTP POST request to receive SSE
510    pub(crate) async fn post_stream<I, O>(
511        &self,
512        path: &str,
513        request: I,
514    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
515    where
516        I: Serialize,
517        O: DeserializeOwned + std::marker::Send + 'static,
518    {
519        let event_source = self
520            .http_client
521            .post(self.config.url(path))
522            .query(&self.config.query())
523            .headers(self.config.headers())
524            .json(&request)
525            .eventsource()
526            .unwrap();
527
528        stream(event_source).await
529    }
530
531    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
532        &self,
533        path: &str,
534        request: I,
535        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
536    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
537    where
538        I: Serialize,
539        O: DeserializeOwned + std::marker::Send + 'static,
540    {
541        let event_source = self
542            .http_client
543            .post(self.config.url(path))
544            .query(&self.config.query())
545            .headers(self.config.headers())
546            .json(&request)
547            .eventsource()
548            .unwrap();
549
550        stream_mapped_raw_events(event_source, event_mapper).await
551    }
552
553    /// Make HTTP GET request to receive SSE
554    pub(crate) async fn _get_stream<Q, O>(
555        &self,
556        path: &str,
557        query: &Q,
558    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
559    where
560        Q: Serialize + ?Sized,
561        O: DeserializeOwned + std::marker::Send + 'static,
562    {
563        let event_source = self
564            .http_client
565            .get(self.config.url(path))
566            .query(query)
567            .query(&self.config.query())
568            .headers(self.config.headers())
569            .eventsource()
570            .unwrap();
571
572        stream(event_source).await
573    }
574}
575
576async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
577    let status = response.status();
578    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
579
580    if status.is_server_error() {
581        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
582        let message: String = String::from_utf8_lossy(&bytes).into_owned();
583        tracing::warn!("Server error: {status} - {message}");
584        return Err(OpenAIError::ApiError(ApiError {
585            message,
586            r#type: None,
587            param: None,
588            code: None,
589        }));
590    }
591
592    // Deserialize response body from either error object or actual response object
593    if !status.is_success() {
594        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
595            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
596
597        return Err(OpenAIError::ApiError(wrapped_error.error));
598    }
599
600    Ok(bytes)
601}
602
603async fn map_stream_error(value: EventSourceError) -> OpenAIError {
604    match value {
605        EventSourceError::InvalidStatusCode(status_code, response) => {
606            read_response(response).await.expect_err(&format!(
607                "Unreachable because read_response returns err when status_code {status_code} is invalid"
608            ))
609        }
610        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
611    }
612}
613
614/// Request which responds with SSE.
615/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
616pub(crate) async fn stream<O>(
617    mut event_source: EventSource,
618) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
619where
620    O: DeserializeOwned + std::marker::Send + 'static,
621{
622    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
623
624    tokio::spawn(async move {
625        while let Some(ev) = event_source.next().await {
626            match ev {
627                Err(e) => {
628                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
629                        // rx dropped
630                        break;
631                    }
632                }
633                Ok(event) => match event {
634                    Event::Message(message) => {
635                        if message.data == "[DONE]" {
636                            break;
637                        }
638
639                        let response = match serde_json::from_str::<O>(&message.data) {
640                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
641                            Ok(output) => Ok(output),
642                        };
643
644                        if let Err(_e) = tx.send(response) {
645                            // rx dropped
646                            break;
647                        }
648                    }
649                    Event::Open => continue,
650                },
651            }
652        }
653
654        event_source.close();
655    });
656
657    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
658}
659
660pub(crate) async fn stream_mapped_raw_events<O>(
661    mut event_source: EventSource,
662    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
663) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
664where
665    O: DeserializeOwned + std::marker::Send + 'static,
666{
667    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
668
669    tokio::spawn(async move {
670        while let Some(ev) = event_source.next().await {
671            match ev {
672                Err(e) => {
673                    if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
674                        // rx dropped
675                        break;
676                    }
677                }
678                Ok(event) => match event {
679                    Event::Message(message) => {
680                        let mut done = false;
681
682                        if message.data == "[DONE]" {
683                            done = true;
684                        }
685
686                        let response = event_mapper(message);
687
688                        if let Err(_e) = tx.send(response) {
689                            // rx dropped
690                            break;
691                        }
692
693                        if done {
694                            break;
695                        }
696                    }
697                    Event::Open => continue,
698                },
699            }
700        }
701
702        event_source.close();
703    });
704
705    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
706}