async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    traits::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17    Models, Projects, Threads, Uploads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21/// Client is a container for config, backoff and http_client
22/// used to make API calls.
23pub struct Client<C: Config> {
24    http_client: reqwest::Client,
25    config: C,
26    backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30    /// Client with default [OpenAIConfig]
31    pub fn new() -> Self {
32        Self::default()
33    }
34}
35
36impl<C: Config> Client<C> {
37    /// Create client with a custom HTTP client, OpenAI config, and backoff.
38    pub fn build(
39        http_client: reqwest::Client,
40        config: C,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
51    pub fn with_config(config: C) -> Self {
52        Self {
53            http_client: reqwest::Client::new(),
54            config,
55            backoff: Default::default(),
56        }
57    }
58
59    /// Provide your own [client] to make HTTP requests with.
60    ///
61    /// [client]: reqwest::Client
62    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63        self.http_client = http_client;
64        self
65    }
66
67    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
68    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69        self.backoff = backoff;
70        self
71    }
72
73    // API groups
74
75    /// To call [Models] group related APIs using this client.
76    pub fn models(&self) -> Models<C> {
77        Models::new(self)
78    }
79
80    /// To call [Completions] group related APIs using this client.
81    pub fn completions(&self) -> Completions<C> {
82        Completions::new(self)
83    }
84
85    /// To call [Chat] group related APIs using this client.
86    pub fn chat(&self) -> Chat<C> {
87        Chat::new(self)
88    }
89
90    /// To call [Images] group related APIs using this client.
91    pub fn images(&self) -> Images<C> {
92        Images::new(self)
93    }
94
95    /// To call [Moderations] group related APIs using this client.
96    pub fn moderations(&self) -> Moderations<C> {
97        Moderations::new(self)
98    }
99
100    /// To call [Files] group related APIs using this client.
101    pub fn files(&self) -> Files<C> {
102        Files::new(self)
103    }
104
105    /// To call [Uploads] group related APIs using this client.
106    pub fn uploads(&self) -> Uploads<C> {
107        Uploads::new(self)
108    }
109
110    /// To call [FineTuning] group related APIs using this client.
111    pub fn fine_tuning(&self) -> FineTuning<C> {
112        FineTuning::new(self)
113    }
114
115    /// To call [Embeddings] group related APIs using this client.
116    pub fn embeddings(&self) -> Embeddings<C> {
117        Embeddings::new(self)
118    }
119
120    /// To call [Audio] group related APIs using this client.
121    pub fn audio(&self) -> Audio<C> {
122        Audio::new(self)
123    }
124
125    /// To call [Assistants] group related APIs using this client.
126    pub fn assistants(&self) -> Assistants<C> {
127        Assistants::new(self)
128    }
129
130    /// To call [Threads] group related APIs using this client.
131    pub fn threads(&self) -> Threads<C> {
132        Threads::new(self)
133    }
134
135    /// To call [VectorStores] group related APIs using this client.
136    pub fn vector_stores(&self) -> VectorStores<C> {
137        VectorStores::new(self)
138    }
139
140    /// To call [Batches] group related APIs using this client.
141    pub fn batches(&self) -> Batches<C> {
142        Batches::new(self)
143    }
144
145    /// To call [AuditLogs] group related APIs using this client.
146    pub fn audit_logs(&self) -> AuditLogs<C> {
147        AuditLogs::new(self)
148    }
149
150    /// To call [Invites] group related APIs using this client.
151    pub fn invites(&self) -> Invites<C> {
152        Invites::new(self)
153    }
154
155    /// To call [Users] group related APIs using this client.
156    pub fn users(&self) -> Users<C> {
157        Users::new(self)
158    }
159
160    /// To call [Projects] group related APIs using this client.
161    pub fn projects(&self) -> Projects<C> {
162        Projects::new(self)
163    }
164
165    pub fn config(&self) -> &C {
166        &self.config
167    }
168
169    /// Make a GET request to {path} and deserialize the response body
170    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
171    where
172        O: DeserializeOwned,
173    {
174        let request_maker = || async {
175            Ok(self
176                .http_client
177                .get(self.config.url(path))
178                .query(&self.config.query())
179                .headers(self.config.headers())
180                .build()?)
181        };
182
183        self.execute(request_maker).await
184    }
185
186    /// Make a GET request to {path} with given Query and deserialize the response body
187    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
188    where
189        O: DeserializeOwned,
190        Q: Serialize + ?Sized,
191    {
192        let request_maker = || async {
193            Ok(self
194                .http_client
195                .get(self.config.url(path))
196                .query(&self.config.query())
197                .query(query)
198                .headers(self.config.headers())
199                .build()?)
200        };
201
202        self.execute(request_maker).await
203    }
204
205    /// Make a DELETE request to {path} and deserialize the response body
206    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
207    where
208        O: DeserializeOwned,
209    {
210        let request_maker = || async {
211            Ok(self
212                .http_client
213                .delete(self.config.url(path))
214                .query(&self.config.query())
215                .headers(self.config.headers())
216                .build()?)
217        };
218
219        self.execute(request_maker).await
220    }
221
222    /// Make a GET request to {path} and return the response body
223    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
224        let request_maker = || async {
225            Ok(self
226                .http_client
227                .get(self.config.url(path))
228                .query(&self.config.query())
229                .headers(self.config.headers())
230                .build()?)
231        };
232
233        self.execute_raw(request_maker).await
234    }
235
236    /// Make a POST request to {path} and return the response body
237    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
238    where
239        I: Serialize,
240    {
241        let request_maker = || async {
242            Ok(self
243                .http_client
244                .post(self.config.url(path))
245                .query(&self.config.query())
246                .headers(self.config.headers())
247                .json(&request)
248                .build()?)
249        };
250
251        self.execute_raw(request_maker).await
252    }
253
254    /// Make a POST request to {path} and deserialize the response body
255    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
256    where
257        I: Serialize,
258        O: DeserializeOwned,
259    {
260        let request_maker = || async {
261            Ok(self
262                .http_client
263                .post(self.config.url(path))
264                .query(&self.config.query())
265                .headers(self.config.headers())
266                .json(&request)
267                .build()?)
268        };
269
270        self.execute(request_maker).await
271    }
272
273    /// POST a form at {path} and return the response body
274    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
275    where
276        Form: AsyncTryFrom<F, Error = OpenAIError>,
277        F: Clone,
278    {
279        let request_maker = || async {
280            Ok(self
281                .http_client
282                .post(self.config.url(path))
283                .query(&self.config.query())
284                .headers(self.config.headers())
285                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
286                .build()?)
287        };
288
289        self.execute_raw(request_maker).await
290    }
291
292    /// POST a form at {path} and deserialize the response body
293    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
294    where
295        O: DeserializeOwned,
296        Form: AsyncTryFrom<F, Error = OpenAIError>,
297        F: Clone,
298    {
299        let request_maker = || async {
300            Ok(self
301                .http_client
302                .post(self.config.url(path))
303                .query(&self.config.query())
304                .headers(self.config.headers())
305                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
306                .build()?)
307        };
308
309        self.execute(request_maker).await
310    }
311
312    /// Execute a HTTP request and retry on rate limit
313    ///
314    /// request_maker serves one purpose: to be able to create request again
315    /// to retry API call after getting rate limited. request_maker is async because
316    /// reqwest::multipart::Form is created by async calls to read files for uploads.
317    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
318    where
319        M: Fn() -> Fut,
320        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
321    {
322        let client = self.http_client.clone();
323
324        backoff::future::retry(self.backoff.clone(), || async {
325            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
326            let response = client
327                .execute(request)
328                .await
329                .map_err(OpenAIError::Reqwest)
330                .map_err(backoff::Error::Permanent)?;
331
332            let status = response.status();
333            let bytes = response
334                .bytes()
335                .await
336                .map_err(OpenAIError::Reqwest)
337                .map_err(backoff::Error::Permanent)?;
338
339            if status.is_server_error() {
340                // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
341                let message: String = String::from_utf8_lossy(&bytes).into_owned();
342                tracing::warn!("Server error: {status} - {message}");
343                return Err(backoff::Error::Transient {
344                    err: OpenAIError::ApiError(ApiError { message, r#type: None, param: None, code: None }),
345                    retry_after: None,
346                });
347            }
348
349            // Deserialize response body from either error object or actual response object
350            if !status.is_success() {
351                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
352                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
353                    .map_err(backoff::Error::Permanent)?;
354
355                if status.as_u16() == 429
356                    // API returns 429 also when:
357                    // "You exceeded your current quota, please check your plan and billing details."
358                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
359                {
360                    // Rate limited retry...
361                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
362                    return Err(backoff::Error::Transient {
363                        err: OpenAIError::ApiError(wrapped_error.error),
364                        retry_after: None,
365                    });
366                } else {
367                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
368                        wrapped_error.error,
369                    )));
370                }
371            }
372
373            Ok(bytes)
374        })
375        .await
376    }
377
378    /// Execute a HTTP request and retry on rate limit
379    ///
380    /// request_maker serves one purpose: to be able to create request again
381    /// to retry API call after getting rate limited. request_maker is async because
382    /// reqwest::multipart::Form is created by async calls to read files for uploads.
383    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
384    where
385        O: DeserializeOwned,
386        M: Fn() -> Fut,
387        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
388    {
389        let bytes = self.execute_raw(request_maker).await?;
390
391        let response: O = serde_json::from_slice(bytes.as_ref())
392            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
393
394        Ok(response)
395    }
396
397    /// Make HTTP POST request to receive SSE
398    pub(crate) async fn post_stream<I, O>(
399        &self,
400        path: &str,
401        request: I,
402    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
403    where
404        I: Serialize,
405        O: DeserializeOwned + std::marker::Send + 'static,
406    {
407        let event_source = self
408            .http_client
409            .post(self.config.url(path))
410            .query(&self.config.query())
411            .headers(self.config.headers())
412            .json(&request)
413            .eventsource()
414            .unwrap();
415
416        stream(event_source).await
417    }
418
419    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
420        &self,
421        path: &str,
422        request: I,
423        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
424    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
425    where
426        I: Serialize,
427        O: DeserializeOwned + std::marker::Send + 'static,
428    {
429        let event_source = self
430            .http_client
431            .post(self.config.url(path))
432            .query(&self.config.query())
433            .headers(self.config.headers())
434            .json(&request)
435            .eventsource()
436            .unwrap();
437
438        stream_mapped_raw_events(event_source, event_mapper).await
439    }
440
441    /// Make HTTP GET request to receive SSE
442    pub(crate) async fn _get_stream<Q, O>(
443        &self,
444        path: &str,
445        query: &Q,
446    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
447    where
448        Q: Serialize + ?Sized,
449        O: DeserializeOwned + std::marker::Send + 'static,
450    {
451        let event_source = self
452            .http_client
453            .get(self.config.url(path))
454            .query(query)
455            .query(&self.config.query())
456            .headers(self.config.headers())
457            .eventsource()
458            .unwrap();
459
460        stream(event_source).await
461    }
462}
463
464/// Request which responds with SSE.
465/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
466pub(crate) async fn stream<O>(
467    mut event_source: EventSource,
468) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
469where
470    O: DeserializeOwned + std::marker::Send + 'static,
471{
472    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
473
474    tokio::spawn(async move {
475        while let Some(ev) = event_source.next().await {
476            match ev {
477                Err(e) => {
478                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
479                        // rx dropped
480                        break;
481                    }
482                }
483                Ok(event) => match event {
484                    Event::Message(message) => {
485                        if message.data == "[DONE]" {
486                            break;
487                        }
488
489                        let response = match serde_json::from_str::<O>(&message.data) {
490                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
491                            Ok(output) => Ok(output),
492                        };
493
494                        if let Err(_e) = tx.send(response) {
495                            // rx dropped
496                            break;
497                        }
498                    }
499                    Event::Open => continue,
500                },
501            }
502        }
503
504        event_source.close();
505    });
506
507    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
508}
509
510pub(crate) async fn stream_mapped_raw_events<O>(
511    mut event_source: EventSource,
512    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
513) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
514where
515    O: DeserializeOwned + std::marker::Send + 'static,
516{
517    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
518
519    tokio::spawn(async move {
520        while let Some(ev) = event_source.next().await {
521            match ev {
522                Err(e) => {
523                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
524                        // rx dropped
525                        break;
526                    }
527                }
528                Ok(event) => match event {
529                    Event::Message(message) => {
530                        let mut done = false;
531
532                        if message.data == "[DONE]" {
533                            done = true;
534                        }
535
536                        let response = event_mapper(message);
537
538                        if let Err(_e) = tx.send(response) {
539                            // rx dropped
540                            break;
541                        }
542
543                        if done {
544                            break;
545                        }
546                    }
547                    Event::Open => continue,
548                },
549            }
550        }
551
552        event_source.close();
553    });
554
555    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
556}