async_openai_wasm/
client.rs

1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use futures::stream::Filter;
8use futures::{stream::StreamExt, Stream};
9use pin_project::pin_project;
10use reqwest::multipart::Form;
11use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
12use serde::{de::DeserializeOwned, Serialize};
13
14use crate::util::async_convert::AsyncTryFrom;
15use crate::{
16    config::{Config, OpenAIConfig},
17    error::{map_deserialization_error, OpenAIError, WrappedError},
18    file::Files,
19    image::Images,
20    moderation::Moderations,
21    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
22    Models, Projects, Threads, Users, VectorStores,
23};
24
#[derive(Debug, Clone)]
/// Client is a container for config and http_client
/// used to make API calls.
///
/// Each API group accessor (for example [Client::chat] or [Client::models])
/// is handed a reference to this client.
pub struct Client<C: Config> {
    /// HTTP client used to build and send every request.
    http_client: reqwest::Client,
    /// Supplies the URL, query parameters and headers for each request.
    config: C,
}
32
33impl Client<OpenAIConfig> {
34    /// Client with default [OpenAIConfig]
35    pub fn new() -> Self {
36        Self {
37            http_client: reqwest::Client::new(),
38            config: OpenAIConfig::default(),
39        }
40    }
41}
42
43impl<C: Config> Client<C> {
    /// Create client with a custom HTTP client, OpenAI config
    ///
    /// Both values are stored exactly as given.
    pub fn build(http_client: reqwest::Client, config: C) -> Self {
        Self {
            http_client,
            config,
        }
    }
51
52    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
53    pub fn with_config(config: C) -> Self {
54        Self {
55            http_client: reqwest::Client::new(),
56            config,
57        }
58    }
59
60    /// Provide your own [client] to make HTTP requests with.
61    ///
62    /// [client]: reqwest::Client
63    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64        self.http_client = http_client;
65        self
66    }
67
68    // API groups
69
    /// To call [Models] group related APIs using this client.
    ///
    /// Builds the group handle with `Models::new`, passing a reference to this client.
    pub fn models(&self) -> Models<C> {
        Models::new(self)
    }
74
    /// To call [Completions] group related APIs using this client.
    ///
    /// Builds the group handle with `Completions::new`, passing a reference to this client.
    pub fn completions(&self) -> Completions<C> {
        Completions::new(self)
    }
79
    /// To call [Chat] group related APIs using this client.
    ///
    /// Builds the group handle with `Chat::new`, passing a reference to this client.
    pub fn chat(&self) -> Chat<C> {
        Chat::new(self)
    }
84
    /// To call [Images] group related APIs using this client.
    ///
    /// Builds the group handle with `Images::new`, passing a reference to this client.
    pub fn images(&self) -> Images<C> {
        Images::new(self)
    }
89
    /// To call [Moderations] group related APIs using this client.
    ///
    /// Builds the group handle with `Moderations::new`, passing a reference to this client.
    pub fn moderations(&self) -> Moderations<C> {
        Moderations::new(self)
    }
94
    /// To call [Files] group related APIs using this client.
    ///
    /// Builds the group handle with `Files::new`, passing a reference to this client.
    pub fn files(&self) -> Files<C> {
        Files::new(self)
    }
99
    /// To call [FineTuning] group related APIs using this client.
    ///
    /// Builds the group handle with `FineTuning::new`, passing a reference to this client.
    pub fn fine_tuning(&self) -> FineTuning<C> {
        FineTuning::new(self)
    }
104
    /// To call [Embeddings] group related APIs using this client.
    ///
    /// Builds the group handle with `Embeddings::new`, passing a reference to this client.
    pub fn embeddings(&self) -> Embeddings<C> {
        Embeddings::new(self)
    }
109
    /// To call [Audio] group related APIs using this client.
    ///
    /// Builds the group handle with `Audio::new`, passing a reference to this client.
    pub fn audio(&self) -> Audio<C> {
        Audio::new(self)
    }
114
    /// To call [Assistants] group related APIs using this client.
    ///
    /// Builds the group handle with `Assistants::new`, passing a reference to this client.
    pub fn assistants(&self) -> Assistants<C> {
        Assistants::new(self)
    }
119
    /// To call [Threads] group related APIs using this client.
    ///
    /// Builds the group handle with `Threads::new`, passing a reference to this client.
    pub fn threads(&self) -> Threads<C> {
        Threads::new(self)
    }
124
    /// To call [VectorStores] group related APIs using this client.
    ///
    /// Builds the group handle with `VectorStores::new`, passing a reference to this client.
    pub fn vector_stores(&self) -> VectorStores<C> {
        VectorStores::new(self)
    }
129
    /// To call [Batches] group related APIs using this client.
    ///
    /// Builds the group handle with `Batches::new`, passing a reference to this client.
    pub fn batches(&self) -> Batches<C> {
        Batches::new(self)
    }
134
    /// To call [AuditLogs] group related APIs using this client.
    ///
    /// Builds the group handle with `AuditLogs::new`, passing a reference to this client.
    pub fn audit_logs(&self) -> AuditLogs<C> {
        AuditLogs::new(self)
    }
139
    /// To call [Invites] group related APIs using this client.
    ///
    /// Builds the group handle with `Invites::new`, passing a reference to this client.
    pub fn invites(&self) -> Invites<C> {
        Invites::new(self)
    }
144
    /// To call [Users] group related APIs using this client.
    ///
    /// Builds the group handle with `Users::new`, passing a reference to this client.
    pub fn users(&self) -> Users<C> {
        Users::new(self)
    }
149
    /// To call [Projects] group related APIs using this client.
    ///
    /// Builds the group handle with `Projects::new`, passing a reference to this client.
    pub fn projects(&self) -> Projects<C> {
        Projects::new(self)
    }
154
    /// Returns a shared reference to the configuration held by this client.
    pub fn config(&self) -> &C {
        &self.config
    }
158
159    /// Make a GET request to {path} and deserialize the response body
160    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
161    where
162        O: DeserializeOwned,
163    {
164        let request_maker = || async {
165            Ok(self
166                .http_client
167                .get(self.config.url(path))
168                .query(&self.config.query())
169                .headers(self.config.headers())
170                .build()?)
171        };
172
173        self.execute(request_maker).await
174    }
175
176    /// Make a GET request to {path} with given Query and deserialize the response body
177    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
178    where
179        O: DeserializeOwned,
180        Q: Serialize + ?Sized,
181    {
182        let request_maker = || async {
183            Ok(self
184                .http_client
185                .get(self.config.url(path))
186                .query(&self.config.query())
187                .query(query)
188                .headers(self.config.headers())
189                .build()?)
190        };
191
192        self.execute(request_maker).await
193    }
194
195    /// Make a DELETE request to {path} and deserialize the response body
196    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
197    where
198        O: DeserializeOwned,
199    {
200        let request_maker = || async {
201            Ok(self
202                .http_client
203                .delete(self.config.url(path))
204                .query(&self.config.query())
205                .headers(self.config.headers())
206                .build()?)
207        };
208
209        self.execute(request_maker).await
210    }
211
212    /// Make a GET request to {path} and return the response body
213    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
214        let request_maker = || async {
215            Ok(self
216                .http_client
217                .get(self.config.url(path))
218                .query(&self.config.query())
219                .headers(self.config.headers())
220                .build()?)
221        };
222
223        self.execute_raw(request_maker).await
224    }
225
226    /// Make a POST request to {path} and return the response body
227    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
228    where
229        I: Serialize,
230    {
231        let request_maker = || async {
232            Ok(self
233                .http_client
234                .post(self.config.url(path))
235                .query(&self.config.query())
236                .headers(self.config.headers())
237                .json(&request)
238                .build()?)
239        };
240
241        self.execute_raw(request_maker).await
242    }
243
244    /// Make a POST request to {path} and deserialize the response body
245    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
246    where
247        I: Serialize,
248        O: DeserializeOwned,
249    {
250        let request_maker = || async {
251            Ok(self
252                .http_client
253                .post(self.config.url(path))
254                .query(&self.config.query())
255                .headers(self.config.headers())
256                .json(&request)
257                .build()?)
258        };
259
260        self.execute(request_maker).await
261    }
262
263    /// POST a form at {path} and return the response body
264    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
265    where
266        Form: AsyncTryFrom<F, Error = OpenAIError>,
267        F: Clone,
268    {
269        let request_maker = || async {
270            Ok(self
271                .http_client
272                .post(self.config.url(path))
273                .query(&self.config.query())
274                .headers(self.config.headers())
275                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
276                .build()?)
277        };
278
279        self.execute_raw(request_maker).await
280    }
281
282    /// POST a form at {path} and deserialize the response body
283    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
284    where
285        O: DeserializeOwned,
286        Form: AsyncTryFrom<F, Error = OpenAIError>,
287        F: Clone,
288    {
289        let request_maker = || async {
290            Ok(self
291                .http_client
292                .post(self.config.url(path))
293                .query(&self.config.query())
294                .headers(self.config.headers())
295                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
296                .build()?)
297        };
298
299        self.execute(request_maker).await
300    }
301
302    /// Execute a HTTP request and retry on rate limit
303    ///
304    /// request_maker serves one purpose: to be able to create request again
305    /// to retry API call after getting rate limited. request_maker is async because
306    /// reqwest::multipart::Form is created by async calls to read files for uploads.
307    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
308    where
309        M: Fn() -> Fut,
310        Fut: future::Future<Output = Result<reqwest::Request, OpenAIError>>,
311    {
312        let client = self.http_client.clone();
313
314        let request = request_maker().await?;
315        let response = client
316            .execute(request)
317            .await
318            .map_err(OpenAIError::Reqwest)?;
319
320        let status = response.status();
321        let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
322
323        // Deserialize response body from either error object or actual response object
324        if !status.is_success() {
325            let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
326                .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
327
328            if status.as_u16() == 429
329                // API returns 429 also when:
330                // "You exceeded your current quota, please check your plan and billing details."
331                && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
332            {
333                // Rate limited retry...
334                tracing::warn!("Rate limited: {}", wrapped_error.error.message);
335                return Err(OpenAIError::ApiError(wrapped_error.error));
336            } else {
337                return Err(OpenAIError::ApiError(wrapped_error.error));
338            }
339        }
340
341        Ok(bytes)
342    }
343
344    /// Execute a HTTP request and retry on rate limit
345    ///
346    /// request_maker serves one purpose: to be able to create request again
347    /// to retry API call after getting rate limited. request_maker is async because
348    /// reqwest::multipart::Form is created by async calls to read files for uploads.
349    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
350    where
351        O: DeserializeOwned,
352        M: Fn() -> Fut,
353        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
354    {
355        let bytes = self.execute_raw(request_maker).await?;
356
357        let response: O = serde_json::from_slice(bytes.as_ref())
358            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
359
360        Ok(response)
361    }
362
363    /// Make HTTP POST request to receive SSE
364    pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
365    where
366        I: Serialize,
367        O: DeserializeOwned + Send + 'static,
368    {
369        let event_source = self
370            .http_client
371            .post(self.config.url(path))
372            .query(&self.config.query())
373            .headers(self.config.headers())
374            .json(&request)
375            .eventsource()
376            .unwrap();
377
378        OpenAIEventStream::new(event_source)
379    }
380
381    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
382        &self,
383        path: &str,
384        request: I,
385        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
386    ) -> OpenAIEventMappedStream<O>
387    where
388        I: Serialize,
389        O: DeserializeOwned + Send + 'static,
390    {
391        let event_source = self
392            .http_client
393            .post(self.config.url(path))
394            .query(&self.config.query())
395            .headers(self.config.headers())
396            .json(&request)
397            .eventsource()
398            .unwrap();
399
400        OpenAIEventMappedStream::new(event_source, event_mapper)
401    }
402
403    /// Make HTTP GET request to receive SSE
404    pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
405    where
406        Q: Serialize + ?Sized,
407        O: DeserializeOwned + Send + 'static,
408    {
409        let event_source = self
410            .http_client
411            .get(self.config.url(path))
412            .query(query)
413            .query(&self.config.query())
414            .headers(self.config.headers())
415            .eventsource()
416            .unwrap();
417
418        OpenAIEventStream::new(event_source)
419    }
420}
421
/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
///
/// As a [Stream], it yields `Result<O, OpenAIError>` items, where each `O` is
/// deserialized from the JSON data of one message event.
#[pin_project]
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
    /// Underlying event source, wrapped in a filter that drops [Event::Open] events.
    #[pin]
    stream: Filter<
        EventSource,
        future::Ready<bool>,
        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
    >,
    /// Once set, every later poll returns `None` without touching `stream`.
    done: bool,
    /// Ties the output type `O` to this stream without storing a value of it.
    _phantom_data: PhantomData<O>,
}
435
436impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
437    pub(crate) fn new(event_source: EventSource) -> Self {
438        Self {
439            stream: event_source.filter(|result|
440                // filter out the first event which is always Event::Open
441                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
442            done: false,
443            _phantom_data: PhantomData,
444        }
445    }
446}
447
448impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
449    type Item = Result<O, OpenAIError>;
450
451    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
452        let this = self.project();
453        if *this.done {
454            return Poll::Ready(None);
455        }
456        let stream: Pin<&mut _> = this.stream;
457        match stream.poll_next(cx) {
458            Poll::Ready(response) => {
459                match response {
460                    None => Poll::Ready(None), // end of the stream
461                    Some(result) => match result {
462                        Ok(event) => match event {
463                            Event::Open => unreachable!(), // it has been filtered out
464                            Event::Message(message) => {
465                                if message.data == "[DONE]" {
466                                    *this.done = true;
467                                    Poll::Ready(None) // end of the stream, defined by OpenAI
468                                } else {
469                                    // deserialize the data
470                                    match serde_json::from_str::<O>(&message.data) {
471                                        Err(e) => {
472                                            *this.done = true;
473                                            Poll::Ready(Some(Err(map_deserialization_error(
474                                                e,
475                                                &message.data.as_bytes(),
476                                            ))))
477                                        }
478                                        Ok(output) => Poll::Ready(Some(Ok(output))),
479                                    }
480                                }
481                            }
482                        },
483                        Err(e) => {
484                            *this.done = true;
485                            Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
486                        }
487                    },
488                }
489            }
490            Poll::Pending => Poll::Pending,
491        }
492    }
493}
494
/// SSE response stream whose raw message events are converted by a
/// caller-supplied mapper instead of being deserialized directly.
///
/// As a [Stream], it yields `Result<O, OpenAIError>` items.
#[pin_project]
pub struct OpenAIEventMappedStream<O>
where
    O: Send + 'static,
{
    /// Underlying event source, wrapped in a filter that drops [Event::Open] events.
    #[pin]
    stream: Filter<
        EventSource,
        future::Ready<bool>,
        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
    >,
    /// Converts each raw message event into an `O` or an error.
    event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
    /// Once set, every later poll returns `None` without touching `stream`.
    done: bool,
    /// Marker for the output type `O`.
    _phantom_data: PhantomData<O>,
}
510
511impl<O> OpenAIEventMappedStream<O>
512where
513    O: Send + 'static,
514{
515    pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
516    where
517        M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
518    {
519        Self {
520            stream: event_source.filter(|result|
521                // filter out the first event which is always Event::Open
522                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
523            done: false,
524            event_mapper: Box::new(event_mapper),
525            _phantom_data: PhantomData,
526        }
527    }
528}
529
530impl<O> Stream for OpenAIEventMappedStream<O>
531where
532    O: Send + 'static,
533{
534    type Item = Result<O, OpenAIError>;
535
536    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
537        let this = self.project();
538        if *this.done {
539            return Poll::Ready(None);
540        }
541        let stream: Pin<&mut _> = this.stream;
542        match stream.poll_next(cx) {
543            Poll::Ready(response) => {
544                match response {
545                    None => Poll::Ready(None), // end of the stream
546                    Some(result) => match result {
547                        Ok(event) => match event {
548                            Event::Open => unreachable!(), // it has been filtered out
549                            Event::Message(message) => {
550                                if message.data == "[DONE]" {
551                                    *this.done = true;
552                                }
553                                let response = (this.event_mapper)(message);
554                                match response {
555                                    Ok(output) => Poll::Ready(Some(Ok(output))),
556                                    Err(_) => Poll::Ready(None),
557                                }
558                            }
559                        },
560                        Err(e) => {
561                            *this.done = true;
562                            Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
563                        }
564                    },
565                }
566            }
567            Poll::Pending => Poll::Pending,
568        }
569    }
570}
571
572// pub(crate) async fn stream_mapped_raw_events<O>(
573//     mut event_source: EventSource,
574//     event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
575// ) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
576//     where
577//         O: DeserializeOwned + std::marker::Send + 'static,
578// {
579//     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
580//
581//     tokio::spawn(async move {
582//         while let Some(ev) = event_source.next().await {
583//             match ev {
584//                 Err(e) => {
585//                     if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
586//                         // rx dropped
587//                         break;
588//                     }
589//                 }
590//                 Ok(event) => match event {
591//                     Event::Message(message) => {
592//                         let mut done = false;
593//
594//                         if message.data == "[DONE]" {
595//                             done = true;
596//                         }
597//
598//                         let response = event_mapper(message);
599//
600//                         if let Err(_e) = tx.send(response) {
601//                             // rx dropped
602//                             break;
603//                         }
604//
605//                         if done {
606//                             break;
607//                         }
608//                     }
609//                     Event::Open => continue,
610//                 },
611//             }
612//         }
613//
614//         event_source.close();
615//     });
616//
617//     Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
618// }