async_openai_wasm/
client.rs

1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use futures::stream::Filter;
8use futures::{Stream, stream::StreamExt};
9use pin_project::pin_project;
10use reqwest::multipart::Form;
11use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{
15    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
16    Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
17    config::{Config, OpenAIConfig},
18    error::{OpenAIError, WrappedError, map_deserialization_error},
19    file::Files,
20    image::Images,
21    moderation::Moderations,
22    traits::AsyncTryFrom,
23};
24
25#[derive(Debug, Clone)]
26/// Client is a container for config and http_client
27/// used to make API calls.
28pub struct Client<C: Config> {
29    http_client: reqwest::Client,
30    config: C,
31}
32
33impl Client<OpenAIConfig> {
34    /// Client with default [OpenAIConfig]
35    pub fn new() -> Self {
36        Self {
37            http_client: reqwest::Client::new(),
38            config: OpenAIConfig::default(),
39        }
40    }
41}
42
43impl<C: Config> Client<C> {
44    /// Create client with a custom HTTP client, OpenAI config
45    pub fn build(http_client: reqwest::Client, config: C) -> Self {
46        Self {
47            http_client,
48            config,
49        }
50    }
51
52    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
53    pub fn with_config(config: C) -> Self {
54        Self {
55            http_client: reqwest::Client::new(),
56            config,
57        }
58    }
59
60    /// Provide your own [client] to make HTTP requests with.
61    ///
62    /// [client]: reqwest::Client
63    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64        self.http_client = http_client;
65        self
66    }
67
68    // API groups
69
70    /// To call [Models] group related APIs using this client.
71    pub fn models(&self) -> Models<C> {
72        Models::new(self)
73    }
74
75    /// To call [Completions] group related APIs using this client.
76    pub fn completions(&self) -> Completions<C> {
77        Completions::new(self)
78    }
79
80    /// To call [Chat] group related APIs using this client.
81    pub fn chat(&self) -> Chat<C> {
82        Chat::new(self)
83    }
84
85    /// To call [Images] group related APIs using this client.
86    pub fn images(&self) -> Images<C> {
87        Images::new(self)
88    }
89
90    /// To call [Moderations] group related APIs using this client.
91    pub fn moderations(&self) -> Moderations<C> {
92        Moderations::new(self)
93    }
94
95    /// To call [Files] group related APIs using this client.
96    pub fn files(&self) -> Files<C> {
97        Files::new(self)
98    }
99
100    /// To call [Uploads] group related APIs using this client.
101    pub fn uploads(&self) -> Uploads<C> {
102        Uploads::new(self)
103    }
104
105    /// To call [FineTuning] group related APIs using this client.
106    pub fn fine_tuning(&self) -> FineTuning<C> {
107        FineTuning::new(self)
108    }
109
110    /// To call [Embeddings] group related APIs using this client.
111    pub fn embeddings(&self) -> Embeddings<C> {
112        Embeddings::new(self)
113    }
114
115    /// To call [Audio] group related APIs using this client.
116    pub fn audio(&self) -> Audio<C> {
117        Audio::new(self)
118    }
119
120    /// To call [Assistants] group related APIs using this client.
121    pub fn assistants(&self) -> Assistants<C> {
122        Assistants::new(self)
123    }
124
125    /// To call [Threads] group related APIs using this client.
126    pub fn threads(&self) -> Threads<C> {
127        Threads::new(self)
128    }
129
130    /// To call [VectorStores] group related APIs using this client.
131    pub fn vector_stores(&self) -> VectorStores<C> {
132        VectorStores::new(self)
133    }
134
135    /// To call [Batches] group related APIs using this client.
136    pub fn batches(&self) -> Batches<C> {
137        Batches::new(self)
138    }
139
140    /// To call [AuditLogs] group related APIs using this client.
141    pub fn audit_logs(&self) -> AuditLogs<C> {
142        AuditLogs::new(self)
143    }
144
145    /// To call [Invites] group related APIs using this client.
146    pub fn invites(&self) -> Invites<C> {
147        Invites::new(self)
148    }
149
150    /// To call [Users] group related APIs using this client.
151    pub fn users(&self) -> Users<C> {
152        Users::new(self)
153    }
154
155    /// To call [Projects] group related APIs using this client.
156    pub fn projects(&self) -> Projects<C> {
157        Projects::new(self)
158    }
159
160    /// To call [Responses] group related APIs using this client.
161    pub fn responses(&self) -> Responses<C> {
162        Responses::new(self)
163    }
164
165    pub fn config(&self) -> &C {
166        &self.config
167    }
168
169    /// Make a GET request to {path} and deserialize the response body
170    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
171    where
172        O: DeserializeOwned,
173    {
174        let request_maker = async || {
175            Ok(self
176                .http_client
177                .get(self.config.url(path))
178                .query(&self.config.query())
179                .headers(self.config.headers())
180                .build()?)
181        };
182
183        self.execute(request_maker).await
184    }
185
186    /// Make a GET request to {path} with given Query and deserialize the response body
187    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
188    where
189        O: DeserializeOwned,
190        Q: Serialize + ?Sized,
191    {
192        let request_maker = async || {
193            Ok(self
194                .http_client
195                .get(self.config.url(path))
196                .query(&self.config.query())
197                .query(query)
198                .headers(self.config.headers())
199                .build()?)
200        };
201
202        self.execute(request_maker).await
203    }
204
205    /// Make a DELETE request to {path} and deserialize the response body
206    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
207    where
208        O: DeserializeOwned,
209    {
210        let request_maker = async || {
211            Ok(self
212                .http_client
213                .delete(self.config.url(path))
214                .query(&self.config.query())
215                .headers(self.config.headers())
216                .build()?)
217        };
218
219        self.execute(request_maker).await
220    }
221
222    /// Make a GET request to {path} and return the response body
223    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
224        let request_maker = async || {
225            Ok(self
226                .http_client
227                .get(self.config.url(path))
228                .query(&self.config.query())
229                .headers(self.config.headers())
230                .build()?)
231        };
232
233        self.execute_raw(request_maker).await
234    }
235
236    /// Make a POST request to {path} and return the response body
237    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
238    where
239        I: Serialize,
240    {
241        let request_maker = async || {
242            Ok(self
243                .http_client
244                .post(self.config.url(path))
245                .query(&self.config.query())
246                .headers(self.config.headers())
247                .json(&request)
248                .build()?)
249        };
250
251        self.execute_raw(request_maker).await
252    }
253
254    /// Make a POST request to {path} and deserialize the response body
255    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
256    where
257        I: Serialize,
258        O: DeserializeOwned,
259    {
260        let request_maker = async || {
261            Ok(self
262                .http_client
263                .post(self.config.url(path))
264                .query(&self.config.query())
265                .headers(self.config.headers())
266                .json(&request)
267                .build()?)
268        };
269
270        self.execute(request_maker).await
271    }
272
273    /// POST a form at {path} and return the response body
274    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
275    where
276        Form: AsyncTryFrom<F, Error = OpenAIError>,
277        F: Clone,
278    {
279        let request_maker = async || {
280            Ok(self
281                .http_client
282                .post(self.config.url(path))
283                .query(&self.config.query())
284                .headers(self.config.headers())
285                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
286                .build()?)
287        };
288
289        self.execute_raw(request_maker).await
290    }
291
292    /// POST a form at {path} and deserialize the response body
293    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
294    where
295        O: DeserializeOwned,
296        Form: AsyncTryFrom<F, Error = OpenAIError>,
297        F: Clone,
298    {
299        let request_maker = async || {
300            Ok(self
301                .http_client
302                .post(self.config.url(path))
303                .query(&self.config.query())
304                .headers(self.config.headers())
305                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
306                .build()?)
307        };
308
309        self.execute(request_maker).await
310    }
311
312    /// Execute a HTTP request and retry on rate limit
313    ///
314    /// request_maker serves one purpose: to be able to create request again
315    /// to retry API call after getting rate limited. request_maker is async because
316    /// reqwest::multipart::Form is created by async calls to read files for uploads.
317    async fn execute_raw(
318        &self,
319        request_maker: impl AsyncFn() -> Result<reqwest::Request, OpenAIError>,
320    ) -> Result<Bytes, OpenAIError> {
321        let client = self.http_client.clone();
322
323        let request = request_maker().await?;
324        let response = client
325            .execute(request)
326            .await
327            .map_err(OpenAIError::Reqwest)?;
328
329        let status = response.status();
330        let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
331
332        // Deserialize response body from either error object or actual response object
333        if !status.is_success() {
334            let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
335                .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
336
337            if status.as_u16() == 429
338                // API returns 429 also when:
339                // "You exceeded your current quota, please check your plan and billing details."
340                && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
341            {
342                // Rate limited retry...
343                tracing::warn!("Rate limited: {}", wrapped_error.error.message);
344                return Err(OpenAIError::ApiError(wrapped_error.error));
345            } else {
346                return Err(OpenAIError::ApiError(wrapped_error.error));
347            }
348        }
349
350        Ok(bytes)
351    }
352
353    /// Execute a HTTP request and retry on rate limit
354    ///
355    /// request_maker serves one purpose: to be able to create request again
356    /// to retry API call after getting rate limited. request_maker is async because
357    /// reqwest::multipart::Form is created by async calls to read files for uploads.
358    async fn execute<O>(
359        &self,
360        request_maker: impl AsyncFn() -> Result<reqwest::Request, OpenAIError>,
361    ) -> Result<O, OpenAIError>
362    where
363        O: DeserializeOwned,
364    {
365        let bytes = self.execute_raw(request_maker).await?;
366
367        let response: O = serde_json::from_slice(bytes.as_ref())
368            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
369
370        Ok(response)
371    }
372
373    /// Make HTTP POST request to receive SSE
374    pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
375    where
376        I: Serialize,
377        O: DeserializeOwned + Send + 'static,
378    {
379        let event_source = self
380            .http_client
381            .post(self.config.url(path))
382            .query(&self.config.query())
383            .headers(self.config.headers())
384            .json(&request)
385            .eventsource()
386            .unwrap();
387
388        OpenAIEventStream::new(event_source)
389    }
390
391    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
392        &self,
393        path: &str,
394        request: I,
395        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
396    ) -> OpenAIEventStream<O>
397    where
398        I: Serialize,
399        O: DeserializeOwned + Send + 'static,
400    {
401        let event_source = self
402            .http_client
403            .post(self.config.url(path))
404            .query(&self.config.query())
405            .headers(self.config.headers())
406            .json(&request)
407            .eventsource()
408            .unwrap();
409
410        OpenAIEventStream::with_event_mapping(event_source, event_mapper)
411    }
412
413    /// Make HTTP GET request to receive SSE
414    pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
415    where
416        Q: Serialize + ?Sized,
417        O: DeserializeOwned + Send + 'static,
418    {
419        let event_source = self
420            .http_client
421            .get(self.config.url(path))
422            .query(query)
423            .query(&self.config.query())
424            .headers(self.config.headers())
425            .eventsource()
426            .unwrap();
427
428        OpenAIEventStream::new(event_source)
429    }
430}
431
432/// Request which responds with SSE.
433/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
434
435#[pin_project]
436pub struct OpenAIEventStream<O>
437where
438    O: DeserializeOwned + Send + 'static,
439{
440    #[pin]
441    stream: Filter<
442        EventSource,
443        future::Ready<bool>,
444        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
445    >,
446    event_mapper:
447        Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
448    done: bool,
449    _phantom_data: PhantomData<O>,
450}
451
452impl<O> OpenAIEventStream<O>
453where
454    O: DeserializeOwned + Send + 'static,
455{
456    pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
457    where
458        M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
459    {
460        Self {
461            stream: event_source.filter(|result|
462                // filter out the first event which is always Event::Open
463                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
464            done: false,
465            event_mapper: Some(Box::new(event_mapper)),
466            _phantom_data: PhantomData,
467        }
468    }
469
470    pub(crate) fn new(event_source: EventSource) -> Self {
471        Self {
472            stream: event_source.filter(|result|
473                // filter out the first event which is always Event::Open
474                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
475            done: false,
476            event_mapper: None,
477            _phantom_data: PhantomData,
478        }
479    }
480}
481
482impl<O> Stream for OpenAIEventStream<O>
483where
484    O: DeserializeOwned + Send + 'static,
485{
486    type Item = Result<O, OpenAIError>;
487
488    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
489        let this = self.project();
490        if *this.done {
491            return Poll::Ready(None);
492        }
493        let stream: Pin<&mut _> = this.stream;
494        match stream.poll_next(cx) {
495            Poll::Ready(response) => {
496                match response {
497                    None => Poll::Ready(None), // end of the stream
498                    Some(result) => match result {
499                        Ok(event) => match event {
500                            Event::Open => unreachable!(), // it has been filtered out
501                            Event::Message(message) => {
502                                if let Some(event_mapper) = this.event_mapper.as_ref() {
503                                    if message.data == "[DONE]" {
504                                        *this.done = true;
505                                    }
506                                    let response = event_mapper(message);
507                                    match response {
508                                        Ok(output) => Poll::Ready(Some(Ok(output))),
509                                        Err(_) => Poll::Ready(None),
510                                    }
511                                } else {
512                                    if message.data == "[DONE]" {
513                                        *this.done = true;
514                                        Poll::Ready(None) // end of the stream, defined by OpenAI
515                                    } else {
516                                        // deserialize the data
517                                        match serde_json::from_str::<O>(&message.data) {
518                                            Err(e) => {
519                                                *this.done = true;
520                                                Poll::Ready(Some(Err(map_deserialization_error(
521                                                    e,
522                                                    &message.data.as_bytes(),
523                                                ))))
524                                            }
525                                            Ok(output) => Poll::Ready(Some(Ok(output))),
526                                        }
527                                    }
528                                }
529                            }
530                        },
531                        Err(e) => {
532                            *this.done = true;
533                            Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
534                        }
535                    },
536                }
537            }
538            Poll::Pending => Poll::Pending,
539        }
540    }
541}
542
543// pub(crate) async fn stream_mapped_raw_events<O>(
544//     mut event_source: EventSource,
545//     event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
546// ) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
547//     where
548//         O: DeserializeOwned + std::marker::Send + 'static,
549// {
550//     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
551//
552//     tokio::spawn(async move {
553//         while let Some(ev) = event_source.next().await {
554//             match ev {
555//                 Err(e) => {
556//                     if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
557//                         // rx dropped
558//                         break;
559//                     }
560//                 }
561//                 Ok(event) => match event {
562//                     Event::Message(message) => {
563//                         let mut done = false;
564//
565//                         if message.data == "[DONE]" {
566//                             done = true;
567//                         }
568//
569//                         let response = event_mapper(message);
570//
571//                         if let Err(_e) = tx.send(response) {
572//                             // rx dropped
573//                             break;
574//                         }
575//
576//                         if done {
577//                             break;
578//                         }
579//                     }
580//                     Event::Open => continue,
581//                 },
582//             }
583//         }
584//
585//         event_source.close();
586//     });
587//
588//     Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
589// }