async_openai_compat/client.rs

use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Error, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::{
    config::{Config, OpenAIConfig},
    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
    file::Files,
    image::Images,
    moderation::Moderations,
    traits::AsyncTryFrom,
    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
    Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
};

#[derive(Debug, Clone, Default)]
/// Client is a container for config, backoff and http_client
/// used to make API calls.
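///
/// # Example
///
/// A minimal sketch, assuming the crate is named `async_openai_compat`, re-exports
/// [Client] at its root, and that the default [OpenAIConfig] picks up the API key
/// from the environment:
///
/// ```no_run
/// use async_openai_compat::Client;
///
/// // Build a client with the default configuration, then access an API group.
/// let client = Client::new();
/// let _chat = client.chat();
/// ```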
pub struct Client<C: Config> {
    http_client: reqwest::Client,
    config: C,
    backoff: backoff::ExponentialBackoff,
}

impl Client<OpenAIConfig> {
    /// Client with default [OpenAIConfig]
    pub fn new() -> Self {
        Self::default()
    }
}

/// Error shape returned by some OpenAI-compatible providers: a bare
/// `{"error": "..."}` object instead of the usual wrapped error body.
#[derive(Debug, Deserialize)]
struct CustomError {
    error: String,
}

impl<C: Config> Client<C> {
    /// Create client with a custom HTTP client, OpenAI config, and backoff.
    pub fn build(
        http_client: reqwest::Client,
        config: C,
        backoff: backoff::ExponentialBackoff,
    ) -> Self {
        Self {
            http_client,
            config,
            backoff,
        }
    }

    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
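    ///
    /// A minimal sketch, assuming the crate is named `async_openai_compat` and
    /// exposes [OpenAIConfig] under its public `config` module; any [Config]
    /// implementation (e.g. AzureConfig) can be passed here instead:
    ///
    /// ```no_run
    /// use async_openai_compat::{config::OpenAIConfig, Client};
    ///
    /// let config = OpenAIConfig::default();
    /// let client = Client::with_config(config);
    /// ```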
    pub fn with_config(config: C) -> Self {
        Self {
            http_client: reqwest::Client::new(),
            config,
            backoff: Default::default(),
        }
    }

    /// Provide your own [client] to make HTTP requests with.
    ///
    /// [client]: reqwest::Client
    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
        self.http_client = http_client;
        self
    }

    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
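    ///
    /// A minimal sketch using the `backoff` crate's builder (crate name
    /// `async_openai_compat` assumed); the exact retry policy is up to the caller:
    ///
    /// ```no_run
    /// use std::time::Duration;
    ///
    /// use async_openai_compat::Client;
    /// use backoff::ExponentialBackoffBuilder;
    ///
    /// // Stop retrying rate-limited requests after one minute in total.
    /// let backoff = ExponentialBackoffBuilder::new()
    ///     .with_max_elapsed_time(Some(Duration::from_secs(60)))
    ///     .build();
    /// let client = Client::new().with_backoff(backoff);
    /// ```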
    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
        self.backoff = backoff;
        self
    }

    // API groups

    /// To call [Models] group related APIs using this client.
    pub fn models(&self) -> Models<C> {
        Models::new(self)
    }

    /// To call [Completions] group related APIs using this client.
    pub fn completions(&self) -> Completions<C> {
        Completions::new(self)
    }

    /// To call [Chat] group related APIs using this client.
    pub fn chat(&self) -> Chat<C> {
        Chat::new(self)
    }

    /// To call [Images] group related APIs using this client.
    pub fn images(&self) -> Images<C> {
        Images::new(self)
    }

    /// To call [Moderations] group related APIs using this client.
    pub fn moderations(&self) -> Moderations<C> {
        Moderations::new(self)
    }

    /// To call [Files] group related APIs using this client.
    pub fn files(&self) -> Files<C> {
        Files::new(self)
    }

    /// To call [Uploads] group related APIs using this client.
    pub fn uploads(&self) -> Uploads<C> {
        Uploads::new(self)
    }

    /// To call [FineTuning] group related APIs using this client.
    pub fn fine_tuning(&self) -> FineTuning<C> {
        FineTuning::new(self)
    }

    /// To call [Embeddings] group related APIs using this client.
    pub fn embeddings(&self) -> Embeddings<C> {
        Embeddings::new(self)
    }

    /// To call [Audio] group related APIs using this client.
    pub fn audio(&self) -> Audio<C> {
        Audio::new(self)
    }

    /// To call [Assistants] group related APIs using this client.
    pub fn assistants(&self) -> Assistants<C> {
        Assistants::new(self)
    }

    /// To call [Threads] group related APIs using this client.
    pub fn threads(&self) -> Threads<C> {
        Threads::new(self)
    }

    /// To call [VectorStores] group related APIs using this client.
    pub fn vector_stores(&self) -> VectorStores<C> {
        VectorStores::new(self)
    }

    /// To call [Batches] group related APIs using this client.
    pub fn batches(&self) -> Batches<C> {
        Batches::new(self)
    }

    /// To call [AuditLogs] group related APIs using this client.
    pub fn audit_logs(&self) -> AuditLogs<C> {
        AuditLogs::new(self)
    }

    /// To call [Invites] group related APIs using this client.
    pub fn invites(&self) -> Invites<C> {
        Invites::new(self)
    }

    /// To call [Users] group related APIs using this client.
    pub fn users(&self) -> Users<C> {
        Users::new(self)
    }

    /// To call [Projects] group related APIs using this client.
    pub fn projects(&self) -> Projects<C> {
        Projects::new(self)
    }

    /// To call [Responses] group related APIs using this client.
    pub fn responses(&self) -> Responses<C> {
        Responses::new(self)
    }

    /// Shared reference to the client's configuration.
    pub fn config(&self) -> &C {
        &self.config
    }

    /// Make a GET request to {path} and deserialize the response body
    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a GET request to {path} with given Query and deserialize the response body
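    ///
    /// Hypothetical caller sketch: `Q` can be anything serializable as query
    /// parameters, e.g. a slice of key-value pairs or a `Serialize` struct
    /// (the `/files` path and `purpose` filter below are illustrative only):
    ///
    /// ```ignore
    /// let files: serde_json::Value = client
    ///     .get_with_query("/files", &[("purpose", "fine-tune")])
    ///     .await?;
    /// ```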
    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        Q: Serialize + ?Sized,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .query(query)
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a DELETE request to {path} and deserialize the response body
    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .delete(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Make a GET request to {path} and return the response body
    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
        let request_maker = || async {
            Ok(self
                .http_client
                .get(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// Make a POST request to {path} and return the response body
    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
    where
        I: Serialize,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .json(&request)
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// Make a POST request to {path} and deserialize the response body
    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
    where
        I: Serialize,
        O: DeserializeOwned,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .json(&request)
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// POST a form at {path} and return the response body
    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
    where
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        F: Clone,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
                .build()?)
        };

        self.execute_raw(request_maker).await
    }

    /// POST a form at {path} and deserialize the response body
    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        Form: AsyncTryFrom<F, Error = OpenAIError>,
        F: Clone,
    {
        let request_maker = || async {
            Ok(self
                .http_client
                .post(self.config.url(path))
                .query(&self.config.query())
                .headers(self.config.headers())
                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
                .build()?)
        };

        self.execute(request_maker).await
    }

    /// Execute an HTTP request and retry on rate limit.
    ///
    /// `request_maker` serves one purpose: to be able to create the request again
    /// when retrying the API call after being rate limited. It is async because
    /// reqwest::multipart::Form is created by async calls to read files for uploads.
    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
    where
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
    {
        let client = self.http_client.clone();

        backoff::future::retry(self.backoff.clone(), || async {
            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
            let response = client
                .execute(request)
                .await
                .map_err(OpenAIError::Reqwest)
                .map_err(backoff::Error::Permanent)?;

            let status = response.status();
            let bytes = response
                .bytes()
                .await
                .map_err(OpenAIError::Reqwest)
                .map_err(backoff::Error::Permanent)?;

            if status.is_server_error() {
                // OpenAI does not guarantee server errors are returned as JSON, so we cannot deserialize them.
                let message: String = String::from_utf8_lossy(&bytes).into_owned();
                tracing::warn!("Server error: {status} - {message}");
                return Err(backoff::Error::Transient {
                    err: OpenAIError::ApiError(ApiError {
                        message,
                        r#type: None,
                        param: None,
                        code: None,
                    }),
                    retry_after: None,
                });
            }

            // Deserialize response body from either error object or actual response object
            if !status.is_success() {
                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
                    .map_err(backoff::Error::Permanent)?;

                if status.as_u16() == 429
                    // The API also returns 429 when:
                    // "You exceeded your current quota, please check your plan and billing details."
                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
                {
                    // Rate limited: retry with backoff.
                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
                    return Err(backoff::Error::Transient {
                        err: OpenAIError::ApiError(wrapped_error.error),
                        retry_after: None,
                    });
                } else {
                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
                        wrapped_error.error,
                    )));
                }
            }

            Ok(bytes)
        })
        .await
    }

    /// Execute an HTTP request and retry on rate limit.
    ///
    /// `request_maker` serves one purpose: to be able to create the request again
    /// when retrying the API call after being rate limited. It is async because
    /// reqwest::multipart::Form is created by async calls to read files for uploads.
    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
    where
        O: DeserializeOwned,
        M: Fn() -> Fut,
        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
    {
        let bytes = self.execute_raw(request_maker).await?;

        let response: O = serde_json::from_slice(bytes.as_ref())
            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

        Ok(response)
    }

    /// Make an HTTP POST request to receive SSE
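    ///
    /// Hypothetical caller sketch (`client` and `request` are assumed to exist;
    /// the path is illustrative). Items are deserialized as they arrive over the
    /// SSE connection:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client
    ///     .post_stream::<_, serde_json::Value>("/chat/completions", request)
    ///     .await;
    /// while let Some(item) = stream.next().await {
    ///     match item {
    ///         Ok(chunk) => println!("{chunk:?}"),
    ///         Err(e) => eprintln!("stream error: {e}"),
    ///     }
    /// }
    /// ```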
    pub(crate) async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request)
            .eventsource()
            .unwrap();

        stream(event_source).await
    }

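    /// Make an HTTP POST request to receive SSE, mapping each raw event with `event_mapper`.
    ///
    /// Hypothetical mapper sketch (`MyEvent`, `client`, `request`, and the path are
    /// illustrative only); the mapper sees the raw [eventsource_stream::Event],
    /// including its `event` name and `data` payload:
    ///
    /// ```ignore
    /// let mapper = |event: eventsource_stream::Event| {
    ///     serde_json::from_str::<MyEvent>(&event.data)
    ///         .map_err(|e| map_deserialization_error(e, event.data.as_bytes()))
    /// };
    /// let mut stream = client
    ///     .post_stream_mapped_raw_events::<_, MyEvent>("/responses", request, mapper)
    ///     .await;
    /// ```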
    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
        &self,
        path: &str,
        request: I,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request)
            .eventsource()
            .unwrap();

        stream_mapped_raw_events(event_source, event_mapper).await
    }

    /// Make an HTTP GET request to receive SSE
    pub(crate) async fn _get_stream<Q, O>(
        &self,
        path: &str,
        query: &Q,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        Q: Serialize + ?Sized,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let event_source = self
            .http_client
            .get(self.config.url(path))
            .query(query)
            .query(&self.config.query())
            .headers(self.config.headers())
            .eventsource()
            .unwrap();

        stream(event_source).await
    }
}

/// Map an event-source [Error] to an [OpenAIError], extracting API error details
/// from the HTTP response body when possible.
async fn handle_eventsource_error(e: Error) -> Result<(), OpenAIError> {
    let error_text = e.to_string();
    if let Error::InvalidStatusCode(code, response) = e {
        if code.as_u16() == 401 {
            return Err(OpenAIError::ApiError(ApiError {
                message: "Unauthorized".to_string(),
                r#type: None,
                param: None,
                code: None,
            }));
        }

        if code.as_u16() == 429 {
            return Err(OpenAIError::ApiError(ApiError {
                message: "Rate limited by provider".to_string(),
                r#type: None,
                param: None,
                code: None,
            }));
        }

        if code.as_u16() == 408 {
            return Err(OpenAIError::ApiError(ApiError {
                message: "Request to provider timed out".to_string(),
                r#type: None,
                param: None,
                code: None,
            }));
        }

        if let Ok(text) = response.text().await {
            if code.as_u16() == 400 {
                let custom_error = serde_json::from_str::<CustomError>(&text);
                if let Ok(error) = custom_error {
                    return Err(OpenAIError::ApiError(ApiError {
                        message: error.error,
                        r#type: None,
                        param: None,
                        code: None,
                    }));
                }
            }

            let api_error = serde_json::from_str::<WrappedError>(&text);
            if let Ok(e) = api_error {
                return Err(OpenAIError::ApiError(e.error));
            }
        }
    }

    Err(OpenAIError::StreamError(error_text))
}

/// Convert an [EventSource] for a request which responds with
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
/// into a stream of deserialized items.
pub(crate) async fn stream<O>(
    mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            match ev {
                Err(e) => {
                    if let Err(e) = handle_eventsource_error(e).await {
                        if let Err(_e) = tx.send(Err(e)) {
                            // rx dropped
                            break;
                        }
                    }
                }
                Ok(event) => match event {
                    Event::Message(message) => {
                        if message.data == "[DONE]" {
                            break;
                        }

                        let response = match serde_json::from_str::<O>(&message.data) {
                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
                            Ok(output) => Ok(output),
                        };

                        if let Err(_e) = tx.send(response) {
                            // rx dropped
                            break;
                        }
                    }
                    Event::Open => continue,
                },
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}

/// Like [stream], but maps each raw SSE event with `event_mapper` instead of
/// deserializing `data` directly as JSON.
pub(crate) async fn stream_mapped_raw_events<O>(
    mut event_source: EventSource,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            match ev {
                Err(e) => {
                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
                        // rx dropped
                        break;
                    }
                }
                Ok(event) => match event {
                    Event::Message(message) => {
                        let mut done = false;

                        if message.data == "[DONE]" {
                            done = true;
                        }

                        let response = event_mapper(message);

                        if let Err(_e) = tx.send(response) {
                            // rx dropped
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                    Event::Open => continue,
                },
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}