async_openai_compat/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Error, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, OpenAIError, WrappedError},
12    file::Files,
13    image::Images,
14    moderation::Moderations,
15    util::AsyncTryFrom,
16    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
17    Models, Projects, Threads, Users, VectorStores,
18};
19
20#[derive(Debug, Clone, Default)]
21/// Client is a container for config, backoff and http_client
22/// used to make API calls.
23pub struct Client<C: Config> {
24    http_client: reqwest::Client,
25    config: C,
26    backoff: backoff::ExponentialBackoff,
27}
28
29impl Client<OpenAIConfig> {
30    /// Client with default [OpenAIConfig]
31    pub fn new() -> Self {
32        Self::default()
33    }
34}
35
36impl<C: Config> Client<C> {
37    /// Create client with a custom HTTP client, OpenAI config, and backoff.
38    pub fn build(
39        http_client: reqwest::Client,
40        config: C,
41        backoff: backoff::ExponentialBackoff,
42    ) -> Self {
43        Self {
44            http_client,
45            config,
46            backoff,
47        }
48    }
49
50    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
51    pub fn with_config(config: C) -> Self {
52        Self {
53            http_client: reqwest::Client::new(),
54            config,
55            backoff: Default::default(),
56        }
57    }
58
59    /// Provide your own [client] to make HTTP requests with.
60    ///
61    /// [client]: reqwest::Client
62    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
63        self.http_client = http_client;
64        self
65    }
66
67    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
68    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
69        self.backoff = backoff;
70        self
71    }
72
73    // API groups
74
75    /// To call [Models] group related APIs using this client.
76    pub fn models(&self) -> Models<C> {
77        Models::new(self)
78    }
79
80    /// To call [Completions] group related APIs using this client.
81    pub fn completions(&self) -> Completions<C> {
82        Completions::new(self)
83    }
84
85    /// To call [Chat] group related APIs using this client.
86    pub fn chat(&self) -> Chat<C> {
87        Chat::new(self)
88    }
89
90    /// To call [Images] group related APIs using this client.
91    pub fn images(&self) -> Images<C> {
92        Images::new(self)
93    }
94
95    /// To call [Moderations] group related APIs using this client.
96    pub fn moderations(&self) -> Moderations<C> {
97        Moderations::new(self)
98    }
99
100    /// To call [Files] group related APIs using this client.
101    pub fn files(&self) -> Files<C> {
102        Files::new(self)
103    }
104
105    /// To call [FineTuning] group related APIs using this client.
106    pub fn fine_tuning(&self) -> FineTuning<C> {
107        FineTuning::new(self)
108    }
109
110    /// To call [Embeddings] group related APIs using this client.
111    pub fn embeddings(&self) -> Embeddings<C> {
112        Embeddings::new(self)
113    }
114
115    /// To call [Audio] group related APIs using this client.
116    pub fn audio(&self) -> Audio<C> {
117        Audio::new(self)
118    }
119
120    /// To call [Assistants] group related APIs using this client.
121    pub fn assistants(&self) -> Assistants<C> {
122        Assistants::new(self)
123    }
124
125    /// To call [Threads] group related APIs using this client.
126    pub fn threads(&self) -> Threads<C> {
127        Threads::new(self)
128    }
129
130    /// To call [VectorStores] group related APIs using this client.
131    pub fn vector_stores(&self) -> VectorStores<C> {
132        VectorStores::new(self)
133    }
134
135    /// To call [Batches] group related APIs using this client.
136    pub fn batches(&self) -> Batches<C> {
137        Batches::new(self)
138    }
139
140    /// To call [AuditLogs] group related APIs using this client.
141    pub fn audit_logs(&self) -> AuditLogs<C> {
142        AuditLogs::new(self)
143    }
144
145    /// To call [Invites] group related APIs using this client.
146    pub fn invites(&self) -> Invites<C> {
147        Invites::new(self)
148    }
149
150    /// To call [Users] group related APIs using this client.
151    pub fn users(&self) -> Users<C> {
152        Users::new(self)
153    }
154
155    /// To call [Projects] group related APIs using this client.
156    pub fn projects(&self) -> Projects<C> {
157        Projects::new(self)
158    }
159
160    pub fn config(&self) -> &C {
161        &self.config
162    }
163
164    /// Make a GET request to {path} and deserialize the response body
165    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
166    where
167        O: DeserializeOwned,
168    {
169        let request_maker = || async {
170            Ok(self
171                .http_client
172                .get(self.config.url(path))
173                .query(&self.config.query())
174                .headers(self.config.headers())
175                .build()?)
176        };
177
178        self.execute(request_maker).await
179    }
180
181    /// Make a GET request to {path} with given Query and deserialize the response body
182    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
183    where
184        O: DeserializeOwned,
185        Q: Serialize + ?Sized,
186    {
187        let request_maker = || async {
188            Ok(self
189                .http_client
190                .get(self.config.url(path))
191                .query(&self.config.query())
192                .query(query)
193                .headers(self.config.headers())
194                .build()?)
195        };
196
197        self.execute(request_maker).await
198    }
199
200    /// Make a DELETE request to {path} and deserialize the response body
201    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
202    where
203        O: DeserializeOwned,
204    {
205        let request_maker = || async {
206            Ok(self
207                .http_client
208                .delete(self.config.url(path))
209                .query(&self.config.query())
210                .headers(self.config.headers())
211                .build()?)
212        };
213
214        self.execute(request_maker).await
215    }
216
217    /// Make a GET request to {path} and return the response body
218    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
219        let request_maker = || async {
220            Ok(self
221                .http_client
222                .get(self.config.url(path))
223                .query(&self.config.query())
224                .headers(self.config.headers())
225                .build()?)
226        };
227
228        self.execute_raw(request_maker).await
229    }
230
231    /// Make a POST request to {path} and return the response body
232    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
233    where
234        I: Serialize,
235    {
236        let request_maker = || async {
237            Ok(self
238                .http_client
239                .post(self.config.url(path))
240                .query(&self.config.query())
241                .headers(self.config.headers())
242                .json(&request)
243                .build()?)
244        };
245
246        self.execute_raw(request_maker).await
247    }
248
249    /// Make a POST request to {path} and deserialize the response body
250    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
251    where
252        I: Serialize,
253        O: DeserializeOwned,
254    {
255        let request_maker = || async {
256            Ok(self
257                .http_client
258                .post(self.config.url(path))
259                .query(&self.config.query())
260                .headers(self.config.headers())
261                .json(&request)
262                .build()?)
263        };
264
265        self.execute(request_maker).await
266    }
267
268    /// POST a form at {path} and return the response body
269    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
270    where
271        Form: AsyncTryFrom<F, Error = OpenAIError>,
272        F: Clone,
273    {
274        let request_maker = || async {
275            Ok(self
276                .http_client
277                .post(self.config.url(path))
278                .query(&self.config.query())
279                .headers(self.config.headers())
280                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
281                .build()?)
282        };
283
284        self.execute_raw(request_maker).await
285    }
286
287    /// POST a form at {path} and deserialize the response body
288    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
289    where
290        O: DeserializeOwned,
291        Form: AsyncTryFrom<F, Error = OpenAIError>,
292        F: Clone,
293    {
294        let request_maker = || async {
295            Ok(self
296                .http_client
297                .post(self.config.url(path))
298                .query(&self.config.query())
299                .headers(self.config.headers())
300                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
301                .build()?)
302        };
303
304        self.execute(request_maker).await
305    }
306
307    /// Execute a HTTP request and retry on rate limit
308    ///
309    /// request_maker serves one purpose: to be able to create request again
310    /// to retry API call after getting rate limited. request_maker is async because
311    /// reqwest::multipart::Form is created by async calls to read files for uploads.
312    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
313    where
314        M: Fn() -> Fut,
315        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
316    {
317        let client = self.http_client.clone();
318
319        backoff::future::retry(self.backoff.clone(), || async {
320            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
321            let response = client
322                .execute(request)
323                .await
324                .map_err(OpenAIError::Reqwest)
325                .map_err(backoff::Error::Permanent)?;
326
327            let status = response.status();
328            let bytes = response
329                .bytes()
330                .await
331                .map_err(OpenAIError::Reqwest)
332                .map_err(backoff::Error::Permanent)?;
333
334            // Deserialize response body from either error object or actual response object
335            if !status.is_success() {
336                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
337                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
338                    .map_err(backoff::Error::Permanent)?;
339
340                if status.as_u16() == 429
341                    // API returns 429 also when:
342                    // "You exceeded your current quota, please check your plan and billing details."
343                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
344                {
345                    // Rate limited retry...
346                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
347                    return Err(backoff::Error::Transient {
348                        err: OpenAIError::ApiError(wrapped_error.error),
349                        retry_after: None,
350                    });
351                } else {
352                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
353                        wrapped_error.error,
354                    )));
355                }
356            }
357
358            Ok(bytes)
359        })
360        .await
361    }
362
363    /// Execute a HTTP request and retry on rate limit
364    ///
365    /// request_maker serves one purpose: to be able to create request again
366    /// to retry API call after getting rate limited. request_maker is async because
367    /// reqwest::multipart::Form is created by async calls to read files for uploads.
368    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
369    where
370        O: DeserializeOwned,
371        M: Fn() -> Fut,
372        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
373    {
374        let bytes = self.execute_raw(request_maker).await?;
375
376        let response: O = serde_json::from_slice(bytes.as_ref())
377            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
378
379        Ok(response)
380    }
381
382    /// Make HTTP POST request to receive SSE
383    pub(crate) async fn post_stream<I, O>(
384        &self,
385        path: &str,
386        request: I,
387    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
388    where
389        I: Serialize,
390        O: DeserializeOwned + std::marker::Send + 'static,
391    {
392        let event_source = self
393            .http_client
394            .post(self.config.url(path))
395            .query(&self.config.query())
396            .headers(self.config.headers())
397            .json(&request)
398            .eventsource()
399            .unwrap();
400
401        stream(event_source).await
402    }
403
404    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
405        &self,
406        path: &str,
407        request: I,
408        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
409    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
410    where
411        I: Serialize,
412        O: DeserializeOwned + std::marker::Send + 'static,
413    {
414        let event_source = self
415            .http_client
416            .post(self.config.url(path))
417            .query(&self.config.query())
418            .headers(self.config.headers())
419            .json(&request)
420            .eventsource()
421            .unwrap();
422
423        stream_mapped_raw_events(event_source, event_mapper).await
424    }
425
426    /// Make HTTP GET request to receive SSE
427    pub(crate) async fn _get_stream<Q, O>(
428        &self,
429        path: &str,
430        query: &Q,
431    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
432    where
433        Q: Serialize + ?Sized,
434        O: DeserializeOwned + std::marker::Send + 'static,
435    {
436        let event_source = self
437            .http_client
438            .get(self.config.url(path))
439            .query(query)
440            .query(&self.config.query())
441            .headers(self.config.headers())
442            .eventsource()
443            .unwrap();
444
445        stream(event_source).await
446    }
447}
448
449async fn handle_eventsource_error(e: Error) -> Result<(), OpenAIError> {
450    let error_text = e.to_string();
451    if let Error::InvalidStatusCode(_code, response) = e {
452        if let Ok(text) = response.text().await {
453            let api_error = serde_json::from_str::<WrappedError>(&text);
454            if let Ok(e) = api_error {
455                return Err(OpenAIError::ApiError(e.error));
456            }
457        }
458    }
459
460    Err(OpenAIError::StreamError(error_text))
461}
462
463/// Request which responds with SSE.
464/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
465pub(crate) async fn stream<O>(
466    mut event_source: EventSource,
467) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
468where
469    O: DeserializeOwned + std::marker::Send + 'static,
470{
471    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
472
473    tokio::spawn(async move {
474        while let Some(ev) = event_source.next().await {
475            match ev {
476                Err(e) => {
477                    if let Err(e) = handle_eventsource_error(e).await {
478                        if let Err(_e) = tx.send(Err(e)) {
479                            // rx dropped
480                            break;
481                        }
482                    }
483                }
484                Ok(event) => match event {
485                    Event::Message(message) => {
486                        if message.data == "[DONE]" {
487                            break;
488                        }
489
490                        let response = match serde_json::from_str::<O>(&message.data) {
491                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
492                            Ok(output) => Ok(output),
493                        };
494
495                        if let Err(_e) = tx.send(response) {
496                            // rx dropped
497                            break;
498                        }
499                    }
500                    Event::Open => continue,
501                },
502            }
503        }
504
505        event_source.close();
506    });
507
508    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
509}
510
511pub(crate) async fn stream_mapped_raw_events<O>(
512    mut event_source: EventSource,
513    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
514) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
515where
516    O: DeserializeOwned + std::marker::Send + 'static,
517{
518    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
519
520    tokio::spawn(async move {
521        while let Some(ev) = event_source.next().await {
522            match ev {
523                Err(e) => {
524                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
525                        // rx dropped
526                        break;
527                    }
528                }
529                Ok(event) => match event {
530                    Event::Message(message) => {
531                        let mut done = false;
532
533                        if message.data == "[DONE]" {
534                            done = true;
535                        }
536
537                        let response = event_mapper(message);
538
539                        if let Err(_e) = tx.send(response) {
540                            // rx dropped
541                            break;
542                        }
543
544                        if done {
545                            break;
546                        }
547                    }
548                    Event::Open => continue,
549                },
550            }
551        }
552
553        event_source.close();
554    });
555
556    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
557}