async_openai_alt/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{
9    config::{Config, OpenAIConfig},
10    error::{map_deserialization_error, OpenAIError, WrappedError},
11    file::Files,
12    image::Images,
13    moderation::Moderations,
14    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
15    Models, Projects, Threads, Users, VectorStores,
16};
17
18#[derive(Debug, Clone, Default)]
19/// Client is a container for config, backoff and http_client
20/// used to make API calls.
21pub struct Client<C: Config> {
22    http_client: reqwest::Client,
23    config: C,
24    backoff: backoff::ExponentialBackoff,
25}
26
27impl Client<OpenAIConfig> {
28    /// Client with default [OpenAIConfig]
29    pub fn new() -> Self {
30        Self::default()
31    }
32}
33
34impl<C: Config> Client<C> {
35    /// Create client with a custom HTTP client, OpenAI config, and backoff.
36    pub fn build(
37        http_client: reqwest::Client,
38        config: C,
39        backoff: backoff::ExponentialBackoff,
40    ) -> Self {
41        Self {
42            http_client,
43            config,
44            backoff,
45        }
46    }
47
48    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
49    pub fn with_config(config: C) -> Self {
50        Self {
51            http_client: reqwest::Client::new(),
52            config,
53            backoff: Default::default(),
54        }
55    }
56
57    /// Provide your own [client] to make HTTP requests with.
58    ///
59    /// [client]: reqwest::Client
60    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
61        self.http_client = http_client;
62        self
63    }
64
65    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
66    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
67        self.backoff = backoff;
68        self
69    }
70
71    // API groups
72
73    /// To call [Models] group related APIs using this client.
74    pub fn models(&self) -> Models<C> {
75        Models::new(self)
76    }
77
78    /// To call [Completions] group related APIs using this client.
79    pub fn completions(&self) -> Completions<C> {
80        Completions::new(self)
81    }
82
83    /// To call [Chat] group related APIs using this client.
84    pub fn chat(&self) -> Chat<C> {
85        Chat::new(self)
86    }
87
88    /// To call [Images] group related APIs using this client.
89    pub fn images(&self) -> Images<C> {
90        Images::new(self)
91    }
92
93    /// To call [Moderations] group related APIs using this client.
94    pub fn moderations(&self) -> Moderations<C> {
95        Moderations::new(self)
96    }
97
98    /// To call [Files] group related APIs using this client.
99    pub fn files(&self) -> Files<C> {
100        Files::new(self)
101    }
102
103    /// To call [FineTuning] group related APIs using this client.
104    pub fn fine_tuning(&self) -> FineTuning<C> {
105        FineTuning::new(self)
106    }
107
108    /// To call [Embeddings] group related APIs using this client.
109    pub fn embeddings(&self) -> Embeddings<C> {
110        Embeddings::new(self)
111    }
112
113    /// To call [Audio] group related APIs using this client.
114    pub fn audio(&self) -> Audio<C> {
115        Audio::new(self)
116    }
117
118    /// To call [Assistants] group related APIs using this client.
119    pub fn assistants(&self) -> Assistants<C> {
120        Assistants::new(self)
121    }
122
123    /// To call [Threads] group related APIs using this client.
124    pub fn threads(&self) -> Threads<C> {
125        Threads::new(self)
126    }
127
128    /// To call [VectorStores] group related APIs using this client.
129    pub fn vector_stores(&self) -> VectorStores<C> {
130        VectorStores::new(self)
131    }
132
133    /// To call [Batches] group related APIs using this client.
134    pub fn batches(&self) -> Batches<C> {
135        Batches::new(self)
136    }
137
138    /// To call [AuditLogs] group related APIs using this client.
139    pub fn audit_logs(&self) -> AuditLogs<C> {
140        AuditLogs::new(self)
141    }
142
143    /// To call [Invites] group related APIs using this client.
144    pub fn invites(&self) -> Invites<C> {
145        Invites::new(self)
146    }
147
148    /// To call [Users] group related APIs using this client.
149    pub fn users(&self) -> Users<C> {
150        Users::new(self)
151    }
152
153    /// To call [Projects] group related APIs using this client.
154    pub fn projects(&self) -> Projects<C> {
155        Projects::new(self)
156    }
157
158    pub fn config(&self) -> &C {
159        &self.config
160    }
161
162    /// Make a GET request to {path} and deserialize the response body
163    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
164    where
165        O: DeserializeOwned,
166    {
167        let request_maker = || async {
168            Ok(self
169                .http_client
170                .get(self.config.url(path))
171                .query(&self.config.query())
172                .headers(self.config.headers())
173                .build()?)
174        };
175
176        self.execute(request_maker).await
177    }
178
179    /// Make a GET request to {path} with given Query and deserialize the response body
180    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
181    where
182        O: DeserializeOwned,
183        Q: Serialize + ?Sized,
184    {
185        let request_maker = || async {
186            Ok(self
187                .http_client
188                .get(self.config.url(path))
189                .query(&self.config.query())
190                .query(query)
191                .headers(self.config.headers())
192                .build()?)
193        };
194
195        self.execute(request_maker).await
196    }
197
198    /// Make a DELETE request to {path} and deserialize the response body
199    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
200    where
201        O: DeserializeOwned,
202    {
203        let request_maker = || async {
204            Ok(self
205                .http_client
206                .delete(self.config.url(path))
207                .query(&self.config.query())
208                .headers(self.config.headers())
209                .build()?)
210        };
211
212        self.execute(request_maker).await
213    }
214
215    /// Make a GET request to {path} and return the response body
216    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
217        let request_maker = || async {
218            Ok(self
219                .http_client
220                .get(self.config.url(path))
221                .query(&self.config.query())
222                .headers(self.config.headers())
223                .build()?)
224        };
225
226        self.execute_raw(request_maker).await
227    }
228
229    /// Make a POST request to {path} and return the response body
230    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
231    where
232        I: Serialize,
233    {
234        let request_maker = || async {
235            Ok(self
236                .http_client
237                .post(self.config.url(path))
238                .query(&self.config.query())
239                .headers(self.config.headers())
240                .json(&request)
241                .build()?)
242        };
243
244        self.execute_raw(request_maker).await
245    }
246
247    /// Make a POST request to {path} and deserialize the response body
248    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
249    where
250        I: Serialize,
251        O: DeserializeOwned,
252    {
253        let request_maker = || async {
254            Ok(self
255                .http_client
256                .post(self.config.url(path))
257                .query(&self.config.query())
258                .headers(self.config.headers())
259                .json(&request)
260                .build()?)
261        };
262
263        self.execute(request_maker).await
264    }
265
266    /// POST a form at {path} and return the response body
267    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
268    where
269        reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
270        F: Clone,
271    {
272        let request_maker = || async {
273            Ok(self
274                .http_client
275                .post(self.config.url(path))
276                .query(&self.config.query())
277                .headers(self.config.headers())
278                .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
279                .build()?)
280        };
281
282        self.execute_raw(request_maker).await
283    }
284
285    /// POST a form at {path} and deserialize the response body
286    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
287    where
288        O: DeserializeOwned,
289        reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
290        F: Clone,
291    {
292        let request_maker = || async {
293            Ok(self
294                .http_client
295                .post(self.config.url(path))
296                .query(&self.config.query())
297                .headers(self.config.headers())
298                .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
299                .build()?)
300        };
301
302        self.execute(request_maker).await
303    }
304
305    /// Execute a HTTP request and retry on rate limit
306    ///
307    /// request_maker serves one purpose: to be able to create request again
308    /// to retry API call after getting rate limited. request_maker is async because
309    /// reqwest::multipart::Form is created by async calls to read files for uploads.
310    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
311    where
312        M: Fn() -> Fut,
313        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
314    {
315        let client = self.http_client.clone();
316
317        backoff::future::retry(self.backoff.clone(), || async {
318            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
319            let response = client
320                .execute(request)
321                .await
322                .map_err(OpenAIError::Reqwest)
323                .map_err(backoff::Error::Permanent)?;
324
325            let status = response.status();
326            let bytes = response
327                .bytes()
328                .await
329                .map_err(OpenAIError::Reqwest)
330                .map_err(backoff::Error::Permanent)?;
331
332            // Deserialize response body from either error object or actual response object
333            if !status.is_success() {
334                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
335                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
336                    .map_err(backoff::Error::Permanent)?;
337
338                if status.as_u16() == 429
339                    // API returns 429 also when:
340                    // "You exceeded your current quota, please check your plan and billing details."
341                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
342                {
343                    // Rate limited retry...
344                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
345                    return Err(backoff::Error::Transient {
346                        err: OpenAIError::ApiError(wrapped_error.error),
347                        retry_after: None,
348                    });
349                } else {
350                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
351                        wrapped_error.error,
352                    )));
353                }
354            }
355
356            Ok(bytes)
357        })
358        .await
359    }
360
361    /// Execute a HTTP request and retry on rate limit
362    ///
363    /// request_maker serves one purpose: to be able to create request again
364    /// to retry API call after getting rate limited. request_maker is async because
365    /// reqwest::multipart::Form is created by async calls to read files for uploads.
366    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
367    where
368        O: DeserializeOwned,
369        M: Fn() -> Fut,
370        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
371    {
372        let bytes = self.execute_raw(request_maker).await?;
373
374        let response: O = serde_json::from_slice(bytes.as_ref())
375            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
376
377        Ok(response)
378    }
379
380    /// Make HTTP POST request to receive SSE
381    pub(crate) async fn post_stream<I, O>(
382        &self,
383        path: &str,
384        request: I,
385    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
386    where
387        I: Serialize,
388        O: DeserializeOwned + std::marker::Send + 'static,
389    {
390        let event_source = self
391            .http_client
392            .post(self.config.url(path))
393            .query(&self.config.query())
394            .headers(self.config.headers())
395            .json(&request)
396            .eventsource()
397            .unwrap();
398
399        stream(event_source).await
400    }
401
402    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
403        &self,
404        path: &str,
405        request: I,
406        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
407    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
408    where
409        I: Serialize,
410        O: DeserializeOwned + std::marker::Send + 'static,
411    {
412        let event_source = self
413            .http_client
414            .post(self.config.url(path))
415            .query(&self.config.query())
416            .headers(self.config.headers())
417            .json(&request)
418            .eventsource()
419            .unwrap();
420
421        stream_mapped_raw_events(event_source, event_mapper).await
422    }
423
424    /// Make HTTP GET request to receive SSE
425    pub(crate) async fn _get_stream<Q, O>(
426        &self,
427        path: &str,
428        query: &Q,
429    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
430    where
431        Q: Serialize + ?Sized,
432        O: DeserializeOwned + std::marker::Send + 'static,
433    {
434        let event_source = self
435            .http_client
436            .get(self.config.url(path))
437            .query(query)
438            .query(&self.config.query())
439            .headers(self.config.headers())
440            .eventsource()
441            .unwrap();
442
443        stream(event_source).await
444    }
445}
446
447/// Request which responds with SSE.
448/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
449pub(crate) async fn stream<O>(
450    mut event_source: EventSource,
451) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
452where
453    O: DeserializeOwned + std::marker::Send + 'static,
454{
455    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
456
457    tokio::spawn(async move {
458        while let Some(ev) = event_source.next().await {
459            match ev {
460                Err(e) => {
461                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
462                        // rx dropped
463                        break;
464                    }
465                }
466                Ok(event) => match event {
467                    Event::Message(message) => {
468                        if message.data == "[DONE]" {
469                            break;
470                        }
471
472                        let response = match serde_json::from_str::<O>(&message.data) {
473                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
474                            Ok(output) => Ok(output),
475                        };
476
477                        if let Err(_e) = tx.send(response) {
478                            // rx dropped
479                            break;
480                        }
481                    }
482                    Event::Open => continue,
483                },
484            }
485        }
486
487        event_source.close();
488    });
489
490    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
491}
492
493pub(crate) async fn stream_mapped_raw_events<O>(
494    mut event_source: EventSource,
495    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
496) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
497where
498    O: DeserializeOwned + std::marker::Send + 'static,
499{
500    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
501
502    tokio::spawn(async move {
503        while let Some(ev) = event_source.next().await {
504            match ev {
505                Err(e) => {
506                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
507                        // rx dropped
508                        break;
509                    }
510                }
511                Ok(event) => match event {
512                    Event::Message(message) => {
513                        let mut done = false;
514
515                        if message.data == "[DONE]" {
516                            done = true;
517                        }
518
519                        let response = event_mapper(message);
520
521                        if let Err(_e) = tx.send(response) {
522                            // rx dropped
523                            break;
524                        }
525
526                        if done {
527                            break;
528                        }
529                    }
530                    Event::Open => continue,
531                },
532            }
533        }
534
535        event_source.close();
536    });
537
538    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
539}