async_openai_wasm/
client.rs

1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use future::Future;
8use futures::stream::Filter;
9use futures::{Stream, stream::StreamExt};
10use pin_project::pin_project;
11use reqwest::{Response, multipart::Form};
12use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
13use serde::{Serialize, de::DeserializeOwned};
14
15use crate::error::{ApiError, StreamError};
16use crate::{
17    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
18    Models, Projects, Responses, Threads, Uploads, Users, VectorStores, Videos,
19    config::{Config, OpenAIConfig},
20    error::{OpenAIError, WrappedError, map_deserialization_error},
21    file::Files,
22    image::Images,
23    moderation::Moderations,
24    traits::AsyncTryFrom,
25};
26
27#[derive(Debug, Clone)]
28/// Client is a container for config and http_client
29/// used to make API calls.
30pub struct Client<C: Config> {
31    http_client: reqwest::Client,
32    config: C,
33}
34
35impl Client<OpenAIConfig> {
36    /// Client with default [OpenAIConfig]
37    pub fn new() -> Self {
38        Self {
39            http_client: reqwest::Client::new(),
40            config: OpenAIConfig::default(),
41        }
42    }
43}
44
45impl<C: Config> Client<C> {
46    /// Create client with a custom HTTP client, OpenAI config
47    pub fn build(http_client: reqwest::Client, config: C) -> Self {
48        Self {
49            http_client,
50            config,
51        }
52    }
53
54    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
55    pub fn with_config(config: C) -> Self {
56        Self {
57            http_client: reqwest::Client::new(),
58            config,
59        }
60    }
61
62    /// Provide your own [client] to make HTTP requests with.
63    ///
64    /// [client]: reqwest::Client
65    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
66        self.http_client = http_client;
67        self
68    }
69
70    // API groups
71
72    /// To call [Models] group related APIs using this client.
73    pub fn models(&self) -> Models<'_, C> {
74        Models::new(self)
75    }
76
77    /// To call [Completions] group related APIs using this client.
78    pub fn completions(&self) -> Completions<'_, C> {
79        Completions::new(self)
80    }
81
82    /// To call [Chat] group related APIs using this client.
83    pub fn chat(&self) -> Chat<'_, C> {
84        Chat::new(self)
85    }
86
87    /// To call [Images] group related APIs using this client.
88    pub fn images(&self) -> Images<'_, C> {
89        Images::new(self)
90    }
91
92    /// To call [Moderations] group related APIs using this client.
93    pub fn moderations(&self) -> Moderations<'_, C> {
94        Moderations::new(self)
95    }
96
97    /// To call [Files] group related APIs using this client.
98    pub fn files(&self) -> Files<'_, C> {
99        Files::new(self)
100    }
101
102    /// To call [Uploads] group related APIs using this client.
103    pub fn uploads(&self) -> Uploads<'_, C> {
104        Uploads::new(self)
105    }
106
107    /// To call [FineTuning] group related APIs using this client.
108    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
109        FineTuning::new(self)
110    }
111
112    /// To call [Embeddings] group related APIs using this client.
113    pub fn embeddings(&self) -> Embeddings<'_, C> {
114        Embeddings::new(self)
115    }
116
117    /// To call [Audio] group related APIs using this client.
118    pub fn audio(&self) -> Audio<'_, C> {
119        Audio::new(self)
120    }
121
122    /// To call [Videos] group related APIs using this client.
123    pub fn videos(&self) -> Videos<'_, C> {
124        Videos::new(self)
125    }
126
127    /// To call [Assistants] group related APIs using this client.
128    pub fn assistants(&self) -> Assistants<'_, C> {
129        Assistants::new(self)
130    }
131
132    /// To call [Threads] group related APIs using this client.
133    pub fn threads(&self) -> Threads<'_, C> {
134        Threads::new(self)
135    }
136
137    /// To call [VectorStores] group related APIs using this client.
138    pub fn vector_stores(&self) -> VectorStores<'_, C> {
139        VectorStores::new(self)
140    }
141
142    /// To call [Batches] group related APIs using this client.
143    pub fn batches(&self) -> Batches<'_, C> {
144        Batches::new(self)
145    }
146
147    /// To call [AuditLogs] group related APIs using this client.
148    pub fn audit_logs(&self) -> AuditLogs<'_, C> {
149        AuditLogs::new(self)
150    }
151
152    /// To call [Invites] group related APIs using this client.
153    pub fn invites(&self) -> Invites<'_, C> {
154        Invites::new(self)
155    }
156
157    /// To call [Users] group related APIs using this client.
158    pub fn users(&self) -> Users<'_, C> {
159        Users::new(self)
160    }
161
162    /// To call [Projects] group related APIs using this client.
163    pub fn projects(&self) -> Projects<'_, C> {
164        Projects::new(self)
165    }
166
167    /// To call [Responses] group related APIs using this client.
168    pub fn responses(&self) -> Responses<'_, C> {
169        Responses::new(self)
170    }
171
172    pub fn config(&self) -> &C {
173        &self.config
174    }
175
176    /// Make a GET request to {path} and deserialize the response body
177    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
178    where
179        O: DeserializeOwned,
180    {
181        self.execute(async {
182            Ok(self
183                .http_client
184                .get(self.config.url(path))
185                .query(&self.config.query())
186                .headers(self.config.headers())
187                .build()?)
188        })
189        .await
190    }
191
192    /// Make a GET request to {path} with given Query and deserialize the response body
193    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
194    where
195        O: DeserializeOwned,
196        Q: Serialize + ?Sized,
197    {
198        self.execute(async {
199            Ok(self
200                .http_client
201                .get(self.config.url(path))
202                .query(&self.config.query())
203                .query(query)
204                .headers(self.config.headers())
205                .build()?)
206        })
207        .await
208    }
209
210    /// Make a DELETE request to {path} and deserialize the response body
211    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
212    where
213        O: DeserializeOwned,
214    {
215        self.execute(async {
216            Ok(self
217                .http_client
218                .delete(self.config.url(path))
219                .query(&self.config.query())
220                .headers(self.config.headers())
221                .build()?)
222        })
223        .await
224    }
225
226    /// Make a GET request to {path} and return the response body
227    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
228        self.execute_raw(async {
229            Ok(self
230                .http_client
231                .get(self.config.url(path))
232                .query(&self.config.query())
233                .headers(self.config.headers())
234                .build()?)
235        })
236        .await
237    }
238
239    pub(crate) async fn get_raw_with_query<Q>(
240        &self,
241        path: &str,
242        query: &Q,
243    ) -> Result<Bytes, OpenAIError>
244    where
245        Q: Serialize + ?Sized,
246    {
247        self.execute_raw(async {
248            Ok(self
249                .http_client
250                .get(self.config.url(path))
251                .query(&self.config.query())
252                .query(query)
253                .headers(self.config.headers())
254                .build()?)
255        })
256        .await
257    }
258
259    /// Make a POST request to {path} and return the response body
260    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
261    where
262        I: Serialize,
263    {
264        self.execute_raw(async {
265            Ok(self
266                .http_client
267                .post(self.config.url(path))
268                .query(&self.config.query())
269                .headers(self.config.headers())
270                .json(&request)
271                .build()?)
272        })
273        .await
274    }
275
276    /// Make a POST request to {path} and deserialize the response body
277    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
278    where
279        I: Serialize,
280        O: DeserializeOwned,
281    {
282        self.execute(async {
283            Ok(self
284                .http_client
285                .post(self.config.url(path))
286                .query(&self.config.query())
287                .headers(self.config.headers())
288                .json(&request)
289                .build()?)
290        })
291        .await
292    }
293
294    /// POST a form at {path} and return the response body
295    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
296    where
297        Form: AsyncTryFrom<F, Error = OpenAIError>,
298    {
299        self.execute_raw(async {
300            let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
301            Ok(self
302                .http_client
303                .post(self.config.url(path))
304                .query(&self.config.query())
305                .headers(self.config.headers())
306                .multipart(form)
307                .build()?)
308        })
309        .await
310    }
311
312    /// POST a form at {path} and deserialize the response body
313    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
314    where
315        O: DeserializeOwned,
316        Form: AsyncTryFrom<F, Error = OpenAIError>,
317    {
318        self.execute(async {
319            let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
320            Ok(self
321                .http_client
322                .post(self.config.url(path))
323                .query(&self.config.query())
324                .headers(self.config.headers())
325                .multipart(form)
326                .build()?)
327        })
328        .await
329    }
330
331    /// Execute a HTTP request
332    async fn execute_raw(
333        &self,
334        request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
335    ) -> Result<Bytes, OpenAIError> {
336        let client = self.http_client.clone();
337        let request = request_future.await?;
338        let response = client
339            .execute(request)
340            .await
341            .map_err(OpenAIError::Reqwest)?;
342
343        let status = response.status();
344        match read_response(response).await {
345            Ok(bytes) => Ok(bytes),
346            Err(e) => match e {
347                OpenAIError::ApiError(api_error) => {
348                    if status.as_u16() == 429
349                        && api_error.r#type != Some("insufficient_quota".to_string())
350                    {
351                        // Rate limited retry...
352                        tracing::warn!("Rate limited: {}", api_error.message);
353                    }
354                    Err(OpenAIError::ApiError(api_error))
355                }
356                _ => Err(e),
357            },
358        }
359    }
360
361    /// Execute a HTTP request
362    async fn execute<O>(
363        &self,
364        request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
365    ) -> Result<O, OpenAIError>
366    where
367        O: DeserializeOwned,
368    {
369        let bytes = self.execute_raw(request_future).await?;
370
371        let response: O = serde_json::from_slice(bytes.as_ref())
372            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
373
374        Ok(response)
375    }
376
377    /// Make HTTP POST request to receive SSE
378    pub(crate) async fn post_stream<I, O>(&self, path: &str, request: I) -> OpenAIEventStream<O>
379    where
380        I: Serialize,
381        O: DeserializeOwned + Send + 'static,
382    {
383        let event_source = self
384            .http_client
385            .post(self.config.url(path))
386            .query(&self.config.query())
387            .headers(self.config.headers())
388            .json(&request)
389            .eventsource()
390            .unwrap();
391
392        OpenAIEventStream::new(event_source)
393    }
394
395    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
396        &self,
397        path: &str,
398        request: I,
399        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
400    ) -> OpenAIEventStream<O>
401    where
402        I: Serialize,
403        O: DeserializeOwned + Send + 'static,
404    {
405        let event_source = self
406            .http_client
407            .post(self.config.url(path))
408            .query(&self.config.query())
409            .headers(self.config.headers())
410            .json(&request)
411            .eventsource()
412            .unwrap();
413
414        OpenAIEventStream::with_event_mapping(event_source, event_mapper)
415    }
416
417    /// Make HTTP GET request to receive SSE
418    pub(crate) async fn _get_stream<Q, O>(&self, path: &str, query: &Q) -> OpenAIEventStream<O>
419    where
420        Q: Serialize + ?Sized,
421        O: DeserializeOwned + Send + 'static,
422    {
423        let event_source = self
424            .http_client
425            .get(self.config.url(path))
426            .query(query)
427            .query(&self.config.query())
428            .headers(self.config.headers())
429            .eventsource()
430            .unwrap();
431
432        OpenAIEventStream::new(event_source)
433    }
434}
435
436/// Request which responds with SSE.
437/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
438
439#[pin_project]
440pub struct OpenAIEventStream<O>
441where
442    O: DeserializeOwned + Send + 'static,
443{
444    #[pin]
445    stream: Filter<
446        EventSource,
447        future::Ready<bool>,
448        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
449    >,
450    event_mapper:
451        Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
452    done: bool,
453    _phantom_data: PhantomData<O>,
454}
455
456impl<O> OpenAIEventStream<O>
457where
458    O: DeserializeOwned + Send + 'static,
459{
460    pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
461    where
462        M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
463    {
464        Self {
465            stream: event_source.filter(|result|
466                // filter out the first event which is always Event::Open
467                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
468            done: false,
469            event_mapper: Some(Box::new(event_mapper)),
470            _phantom_data: PhantomData,
471        }
472    }
473
474    pub(crate) fn new(event_source: EventSource) -> Self {
475        Self {
476            stream: event_source.filter(|result|
477                // filter out the first event which is always Event::Open
478                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
479            done: false,
480            event_mapper: None,
481            _phantom_data: PhantomData,
482        }
483    }
484}
485
486impl<O> Stream for OpenAIEventStream<O>
487where
488    O: DeserializeOwned + Send + 'static,
489{
490    type Item = Result<O, OpenAIError>;
491
492    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
493        let this = self.project();
494        if *this.done {
495            return Poll::Ready(None);
496        }
497        let stream: Pin<&mut _> = this.stream;
498        match stream.poll_next(cx) {
499            Poll::Ready(response) => {
500                match response {
501                    None => Poll::Ready(None), // end of the stream
502                    Some(result) => match result {
503                        Ok(event) => match event {
504                            Event::Open => unreachable!(), // it has been filtered out
505                            Event::Message(message) => {
506                                if let Some(event_mapper) = this.event_mapper.as_ref() {
507                                    if message.data == "[DONE]" {
508                                        *this.done = true;
509                                    }
510                                    let response = event_mapper(message);
511                                    match response {
512                                        Ok(output) => Poll::Ready(Some(Ok(output))),
513                                        Err(_) => Poll::Ready(None),
514                                    }
515                                } else {
516                                    if message.data == "[DONE]" {
517                                        *this.done = true;
518                                        Poll::Ready(None) // end of the stream, defined by OpenAI
519                                    } else {
520                                        // deserialize the data
521                                        match serde_json::from_str::<O>(&message.data) {
522                                            Err(e) => {
523                                                *this.done = true;
524                                                Poll::Ready(Some(Err(map_deserialization_error(
525                                                    e,
526                                                    message.data.as_bytes(),
527                                                ))))
528                                            }
529                                            Ok(output) => Poll::Ready(Some(Ok(output))),
530                                        }
531                                    }
532                                }
533                            }
534                        },
535                        Err(e) => {
536                            *this.done = true;
537                            Poll::Ready(Some(Err(OpenAIError::StreamError(
538                                StreamError::ReqwestEventSource(e),
539                            ))))
540                        }
541                    },
542                }
543            }
544            Poll::Pending => Poll::Pending,
545        }
546    }
547}
548
549async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
550    let status = response.status();
551    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
552
553    if status.is_server_error() {
554        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
555        let message: String = String::from_utf8_lossy(&bytes).into_owned();
556        tracing::warn!("Server error: {status} - {message}");
557        return Err(OpenAIError::ApiError(ApiError {
558            message,
559            r#type: None,
560            param: None,
561            code: None,
562        }));
563    }
564
565    // Deserialize response body from either error object or actual response object
566    if !status.is_success() {
567        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
568            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
569
570        return Err(OpenAIError::ApiError(wrapped_error.error));
571    }
572
573    Ok(bytes)
574}