async_openai_wasm/
client.rs

1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use eventsource_stream::EventStreamError;
8use future::Future;
9use futures::stream::Filter;
10use futures::{Stream, stream::StreamExt};
11use pin_project::pin_project;
12use reqwest::header::HeaderMap;
13use reqwest::{Response, multipart::Form};
14use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
15use serde::{Serialize, de::DeserializeOwned};
16
17use crate::error::{ApiError, StreamError};
18use crate::{
19    RequestOptions,
20    config::{Config, OpenAIConfig},
21    error::{OpenAIError, WrappedError, map_deserialization_error},
22    traits::AsyncTryFrom,
23};
24
25#[cfg(feature = "assistant")]
26use crate::Assistants;
27#[cfg(feature = "audio")]
28use crate::Audio;
29#[cfg(feature = "batch")]
30use crate::Batches;
31#[cfg(feature = "chat-completion")]
32use crate::Chat;
33#[cfg(feature = "completions")]
34use crate::Completions;
35#[cfg(feature = "container")]
36use crate::Containers;
37#[cfg(feature = "responses")]
38use crate::Conversations;
39#[cfg(feature = "embedding")]
40use crate::Embeddings;
41#[cfg(feature = "evals")]
42use crate::Evals;
43#[cfg(feature = "finetuning")]
44use crate::FineTuning;
45#[cfg(feature = "model")]
46use crate::Models;
47#[cfg(feature = "realtime")]
48use crate::Realtime;
49#[cfg(feature = "responses")]
50use crate::Responses;
51#[cfg(feature = "assistant")]
52use crate::Threads;
53#[cfg(feature = "upload")]
54use crate::Uploads;
55#[cfg(feature = "vectorstore")]
56use crate::VectorStores;
57#[cfg(feature = "video")]
58use crate::Videos;
59#[cfg(feature = "administration")]
60use crate::admin::Admin;
61#[cfg(feature = "chatkit")]
62use crate::chatkit::Chatkit;
63#[cfg(feature = "file")]
64use crate::file::Files;
65#[cfg(feature = "image")]
66use crate::image::Images;
67#[cfg(feature = "moderation")]
68use crate::moderation::Moderations;
69
70#[derive(Debug, Clone)]
71/// Client is a container for config and http_client
72/// used to make API calls.
73pub struct Client<C: Config> {
74    http_client: reqwest::Client,
75    config: C,
76}
77
78impl Client<OpenAIConfig> {
79    /// Client with default [OpenAIConfig]
80    pub fn new() -> Self {
81        Self {
82            http_client: reqwest::Client::new(),
83            config: OpenAIConfig::default(),
84        }
85    }
86}
87
88impl<C: Config> Client<C> {
89    /// Create client with a custom HTTP client, OpenAI config
90    pub fn build(http_client: reqwest::Client, config: C) -> Self {
91        Self {
92            http_client,
93            config,
94        }
95    }
96
97    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
98    pub fn with_config(config: C) -> Self {
99        Self {
100            http_client: reqwest::Client::new(),
101            config,
102        }
103    }
104
105    /// Provide your own [client] to make HTTP requests with.
106    ///
107    /// [client]: reqwest::Client
108    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
109        self.http_client = http_client;
110        self
111    }
112
113    // API groups
114
115    /// To call [Models] group related APIs using this client.
116    #[cfg(feature = "model")]
117    pub fn models(&self) -> Models<'_, C> {
118        Models::new(self)
119    }
120
121    /// To call [Completions] group related APIs using this client.
122    #[cfg(feature = "completions")]
123    pub fn completions(&self) -> Completions<'_, C> {
124        Completions::new(self)
125    }
126
127    /// To call [Chat] group related APIs using this client.
128    #[cfg(feature = "chat-completion")]
129    pub fn chat(&self) -> Chat<'_, C> {
130        Chat::new(self)
131    }
132
133    /// To call [Images] group related APIs using this client.
134    #[cfg(feature = "image")]
135    pub fn images(&self) -> Images<'_, C> {
136        Images::new(self)
137    }
138
139    /// To call [Moderations] group related APIs using this client.
140    #[cfg(feature = "moderation")]
141    pub fn moderations(&self) -> Moderations<'_, C> {
142        Moderations::new(self)
143    }
144
145    /// To call [Files] group related APIs using this client.
146    #[cfg(feature = "file")]
147    pub fn files(&self) -> Files<'_, C> {
148        Files::new(self)
149    }
150
151    /// To call [Uploads] group related APIs using this client.
152    #[cfg(feature = "upload")]
153    pub fn uploads(&self) -> Uploads<'_, C> {
154        Uploads::new(self)
155    }
156
157    /// To call [FineTuning] group related APIs using this client.
158    #[cfg(feature = "finetuning")]
159    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
160        FineTuning::new(self)
161    }
162
163    /// To call [Embeddings] group related APIs using this client.
164    #[cfg(feature = "embedding")]
165    pub fn embeddings(&self) -> Embeddings<'_, C> {
166        Embeddings::new(self)
167    }
168
169    /// To call [Audio] group related APIs using this client.
170    #[cfg(feature = "audio")]
171    pub fn audio(&self) -> Audio<'_, C> {
172        Audio::new(self)
173    }
174
175    /// To call [Videos] group related APIs using this client.
176    #[cfg(feature = "video")]
177    pub fn videos(&self) -> Videos<'_, C> {
178        Videos::new(self)
179    }
180
181    /// To call [Assistants] group related APIs using this client.
182    #[cfg(feature = "assistant")]
183    pub fn assistants(&self) -> Assistants<'_, C> {
184        Assistants::new(self)
185    }
186
187    /// To call [Threads] group related APIs using this client.
188    #[cfg(feature = "assistant")]
189    pub fn threads(&self) -> Threads<'_, C> {
190        Threads::new(self)
191    }
192
193    /// To call [VectorStores] group related APIs using this client.
194    #[cfg(feature = "vectorstore")]
195    pub fn vector_stores(&self) -> VectorStores<'_, C> {
196        VectorStores::new(self)
197    }
198
199    /// To call [Batches] group related APIs using this client.
200    #[cfg(feature = "batch")]
201    pub fn batches(&self) -> Batches<'_, C> {
202        Batches::new(self)
203    }
204
205    /// To call [Admin] group related APIs using this client.
206    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
207    #[cfg(feature = "administration")]
208    pub fn admin(&self) -> Admin<'_, C> {
209        Admin::new(self)
210    }
211
212    /// To call [Responses] group related APIs using this client.
213    #[cfg(feature = "responses")]
214    pub fn responses(&self) -> Responses<'_, C> {
215        Responses::new(self)
216    }
217
218    /// To call [Conversations] group related APIs using this client.
219    #[cfg(feature = "responses")]
220    pub fn conversations(&self) -> Conversations<'_, C> {
221        Conversations::new(self)
222    }
223
224    /// To call [Containers] group related APIs using this client.
225    #[cfg(feature = "container")]
226    pub fn containers(&self) -> Containers<'_, C> {
227        Containers::new(self)
228    }
229
230    /// To call [Evals] group related APIs using this client.
231    #[cfg(feature = "evals")]
232    pub fn evals(&self) -> Evals<'_, C> {
233        Evals::new(self)
234    }
235
236    #[cfg(feature = "chatkit")]
237    pub fn chatkit(&self) -> Chatkit<'_, C> {
238        Chatkit::new(self)
239    }
240
241    /// To call [Realtime] group related APIs using this client.
242    #[cfg(feature = "realtime")]
243    pub fn realtime(&self) -> Realtime<'_, C> {
244        Realtime::new(self)
245    }
246
247    pub fn config(&self) -> &C {
248        &self.config
249    }
250
251    /// Helper function to build a request builder with common configuration
252    fn build_request_builder(
253        &self,
254        method: reqwest::Method,
255        path: &str,
256        request_options: &RequestOptions,
257    ) -> reqwest::RequestBuilder {
258        let mut request_builder = if let Some(path) = request_options.path() {
259            self.http_client
260                .request(method, self.config.url(path.as_str()))
261        } else {
262            self.http_client.request(method, self.config.url(path))
263        };
264
265        request_builder = request_builder
266            .query(&self.config.query())
267            .headers(self.config.headers());
268
269        if let Some(headers) = request_options.headers() {
270            request_builder = request_builder.headers(headers.clone());
271        }
272
273        if !request_options.query().is_empty() {
274            request_builder = request_builder.query(request_options.query());
275        }
276
277        request_builder
278    }
279
280    /// Make a GET request to {path} and deserialize the response body
281    #[allow(unused)]
282    pub(crate) async fn get<O>(
283        &self,
284        path: &str,
285        request_options: &RequestOptions,
286    ) -> Result<O, OpenAIError>
287    where
288        O: DeserializeOwned,
289    {
290        self.execute(async {
291            Ok(self
292                .build_request_builder(reqwest::Method::GET, path, request_options)
293                .build()?)
294        })
295        .await
296    }
297
298    /// Make a DELETE request to {path} and deserialize the response body
299    #[allow(unused)]
300    pub(crate) async fn delete<O>(
301        &self,
302        path: &str,
303        request_options: &RequestOptions,
304    ) -> Result<O, OpenAIError>
305    where
306        O: DeserializeOwned,
307    {
308        self.execute(async {
309            Ok(self
310                .build_request_builder(reqwest::Method::DELETE, path, request_options)
311                .build()?)
312        })
313        .await
314    }
315
316    /// Make a GET request to {path} and return the response body
317    #[allow(unused)]
318    pub(crate) async fn get_raw(
319        &self,
320        path: &str,
321        request_options: &RequestOptions,
322    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
323        self.execute_raw(async {
324            Ok(self
325                .build_request_builder(reqwest::Method::GET, path, request_options)
326                .build()?)
327        })
328        .await
329    }
330
331    /// Make a POST request to {path} and return the response body
332    #[allow(unused)]
333    pub(crate) async fn post_raw<I>(
334        &self,
335        path: &str,
336        request: I,
337        request_options: &RequestOptions,
338    ) -> Result<(Bytes, HeaderMap), OpenAIError>
339    where
340        I: Serialize,
341    {
342        self.execute_raw(async {
343            Ok(self
344                .build_request_builder(reqwest::Method::POST, path, request_options)
345                .json(&request)
346                .build()?)
347        })
348        .await
349    }
350
351    /// Make a POST request to {path} and deserialize the response body
352    #[allow(unused)]
353    pub(crate) async fn post<I, O>(
354        &self,
355        path: &str,
356        request: I,
357        request_options: &RequestOptions,
358    ) -> Result<O, OpenAIError>
359    where
360        I: Serialize,
361        O: DeserializeOwned,
362    {
363        self.execute(async {
364            Ok(self
365                .build_request_builder(reqwest::Method::POST, path, request_options)
366                .json(&request)
367                .build()?)
368        })
369        .await
370    }
371
372    /// POST a form at {path} and return the response body
373    #[allow(unused)]
374    pub(crate) async fn post_form_raw<F>(
375        &self,
376        path: &str,
377        form: F,
378        request_options: &RequestOptions,
379    ) -> Result<(Bytes, HeaderMap), OpenAIError>
380    where
381        Form: AsyncTryFrom<F, Error = OpenAIError>,
382    {
383        self.execute_raw(async {
384            let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
385            Ok(self
386                .build_request_builder(reqwest::Method::POST, path, request_options)
387                .multipart(form)
388                .build()?)
389        })
390        .await
391    }
392
393    /// POST a form at {path} and deserialize the response body
394    #[allow(unused)]
395    pub(crate) async fn post_form<O, F>(
396        &self,
397        path: &str,
398        form: F,
399        request_options: &RequestOptions,
400    ) -> Result<O, OpenAIError>
401    where
402        O: DeserializeOwned,
403        Form: AsyncTryFrom<F, Error = OpenAIError>,
404    {
405        self.execute(async {
406            let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
407            Ok(self
408                .build_request_builder(reqwest::Method::POST, path, request_options)
409                .multipart(form)
410                .build()?)
411        })
412        .await
413    }
414
415    #[allow(unused)]
416    pub(crate) async fn post_form_stream<O, F>(
417        &self,
418        path: &str,
419        form: F,
420        request_options: &RequestOptions,
421    ) -> Result<OpenAIFormEventStream<O>, OpenAIError>
422    where
423        F: Clone,
424        Form: AsyncTryFrom<F, Error = OpenAIError>,
425        O: DeserializeOwned + Send + 'static,
426    {
427        // Build and execute request manually since multipart::Form is not Clone
428        // and .eventsource() requires cloneability
429        let request_builder = self
430            .build_request_builder(reqwest::Method::POST, path, request_options)
431            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
432
433        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
434
435        // Check for error status
436        if !response.status().is_success() {
437            return Err(read_response(response).await.unwrap_err());
438        }
439
440        // Convert response body to EventSource stream
441        let stream = response
442            .bytes_stream()
443            .map(|result| result.map_err(std::io::Error::other));
444        let event_stream = eventsource_stream::EventStream::new(stream);
445
446        Ok(OpenAIFormEventStream::new(event_stream))
447    }
448
449    /// Execute a HTTP request
450    async fn execute_raw(
451        &self,
452        request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
453    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
454        let client = self.http_client.clone();
455        let request = request_future.await?;
456        let response = client
457            .execute(request)
458            .await
459            .map_err(OpenAIError::Reqwest)?;
460
461        let status = response.status();
462        match read_response(response).await {
463            Ok((bytes, headers)) => Ok((bytes, headers)),
464            Err(e) => match e {
465                OpenAIError::ApiError(api_error) => {
466                    if status.as_u16() == 429
467                        && api_error.r#type != Some("insufficient_quota".to_string())
468                    {
469                        // Rate limited retry...
470                        tracing::warn!("Rate limited: {}", api_error.message);
471                    }
472                    Err(OpenAIError::ApiError(api_error))
473                }
474                _ => Err(e),
475            },
476        }
477    }
478
479    /// Execute a HTTP request
480    async fn execute<O>(
481        &self,
482        request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
483    ) -> Result<O, OpenAIError>
484    where
485        O: DeserializeOwned,
486    {
487        let (bytes, _headers) = self.execute_raw(request_future).await?;
488
489        let response: O = serde_json::from_slice(bytes.as_ref())
490            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
491
492        Ok(response)
493    }
494
495    /// Make HTTP POST request to receive SSE
496    #[allow(unused)]
497    pub(crate) async fn post_stream<I, O>(
498        &self,
499        path: &str,
500        request: I,
501        request_options: &RequestOptions,
502    ) -> OpenAIEventStream<O>
503    where
504        I: Serialize,
505        O: DeserializeOwned + Send + 'static,
506    {
507        let request_builder = self
508            .build_request_builder(reqwest::Method::POST, path, request_options)
509            .json(&request);
510
511        let event_source = request_builder.eventsource().unwrap();
512
513        OpenAIEventStream::new(event_source)
514    }
515
516    #[allow(unused)]
517    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
518        &self,
519        path: &str,
520        request: I,
521        request_options: &RequestOptions,
522        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
523    ) -> OpenAIEventStream<O>
524    where
525        I: Serialize,
526        O: DeserializeOwned + Send + 'static,
527    {
528        let request_builder = self
529            .build_request_builder(reqwest::Method::POST, path, request_options)
530            .json(&request);
531
532        let event_source = request_builder.eventsource().unwrap();
533
534        OpenAIEventStream::with_event_mapping(event_source, event_mapper)
535    }
536
537    /// Make HTTP GET request to receive SSE
538    #[allow(unused)]
539    pub(crate) async fn get_stream<O>(
540        &self,
541        path: &str,
542        request_options: &RequestOptions,
543    ) -> OpenAIEventStream<O>
544    where
545        O: DeserializeOwned + Send + 'static,
546    {
547        let request_builder =
548            self.build_request_builder(reqwest::Method::GET, path, request_options);
549
550        let event_source = request_builder.eventsource().unwrap();
551
552        OpenAIEventStream::new(event_source)
553    }
554}
555
556// TODO: check the implementation of OpenAIEventStream to reflect new changes from #485
557
558/// Request which responds with SSE.
559/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
560
561#[pin_project]
562pub struct OpenAIEventStream<O>
563where
564    O: DeserializeOwned + Send + 'static,
565{
566    #[pin]
567    stream: Filter<
568        EventSource,
569        future::Ready<bool>,
570        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
571    >,
572    event_mapper:
573        Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
574    done: bool,
575    _phantom_data: PhantomData<O>,
576}
577
578impl<O> OpenAIEventStream<O>
579where
580    O: DeserializeOwned + Send + 'static,
581{
582    pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
583    where
584        M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
585    {
586        Self {
587            stream: event_source.filter(|result|
588                // filter out the first event which is always Event::Open
589                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
590            done: false,
591            event_mapper: Some(Box::new(event_mapper)),
592            _phantom_data: PhantomData,
593        }
594    }
595
596    pub(crate) fn new(event_source: EventSource) -> Self {
597        Self {
598            stream: event_source.filter(|result|
599                // filter out the first event which is always Event::Open
600                future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
601            done: false,
602            event_mapper: None,
603            _phantom_data: PhantomData,
604        }
605    }
606}
607
608impl<O> Stream for OpenAIEventStream<O>
609where
610    O: DeserializeOwned + Send + 'static,
611{
612    type Item = Result<O, OpenAIError>;
613
614    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
615        let this = self.project();
616        if *this.done {
617            return Poll::Ready(None);
618        }
619        let stream: Pin<&mut _> = this.stream;
620        match stream.poll_next(cx) {
621            Poll::Ready(response) => {
622                match response {
623                    None => Poll::Ready(None), // end of the stream
624                    Some(result) => match result {
625                        Ok(event) => match event {
626                            Event::Open => unreachable!(), // it has been filtered out
627                            Event::Message(message) => {
628                                if let Some(event_mapper) = this.event_mapper.as_ref() {
629                                    if message.data == "[DONE]" {
630                                        *this.done = true;
631                                    }
632                                    let response = event_mapper(message);
633                                    match response {
634                                        Ok(output) => Poll::Ready(Some(Ok(output))),
635                                        Err(_) => Poll::Ready(None),
636                                    }
637                                } else {
638                                    if message.data == "[DONE]" {
639                                        *this.done = true;
640                                        Poll::Ready(None) // end of the stream, defined by OpenAI
641                                    } else {
642                                        // deserialize the data
643                                        match serde_json::from_str::<O>(&message.data) {
644                                            Err(e) => {
645                                                *this.done = true;
646                                                Poll::Ready(Some(Err(map_deserialization_error(
647                                                    e,
648                                                    message.data.as_bytes(),
649                                                ))))
650                                            }
651                                            Ok(output) => Poll::Ready(Some(Ok(output))),
652                                        }
653                                    }
654                                }
655                            }
656                        },
657                        Err(e) => {
658                            *this.done = true;
659                            Poll::Ready(Some(Err(OpenAIError::StreamError(Box::new(
660                                StreamError::ReqwestEventSource(e),
661                            )))))
662                        }
663                    },
664                }
665            }
666            Poll::Pending => Poll::Pending,
667        }
668    }
669}
670
671#[pin_project]
672pub struct OpenAIFormEventStream<O>
673where
674    O: DeserializeOwned + Send + 'static,
675{
676    #[pin]
677    event_stream: Box<
678        dyn Stream<Item = Result<eventsource_stream::Event, EventStreamError<std::io::Error>>>
679            + Unpin
680            + 'static,
681    >,
682    done: bool,
683    _phantom_data: PhantomData<O>,
684}
685
686impl<O> OpenAIFormEventStream<O>
687where
688    O: DeserializeOwned + Send + 'static,
689{
690    pub fn new(
691        stream: impl Stream<Item = Result<eventsource_stream::Event, EventStreamError<std::io::Error>>>
692        + Unpin
693        + 'static,
694    ) -> Self {
695        Self {
696            event_stream: Box::new(stream),
697            done: false,
698            _phantom_data: PhantomData,
699        }
700    }
701}
702
703impl<O> Stream for OpenAIFormEventStream<O>
704where
705    O: DeserializeOwned + Send + 'static,
706{
707    type Item = Result<O, OpenAIError>;
708
709    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
710        let this = self.project();
711        if *this.done {
712            return Poll::Ready(None);
713        }
714        let stream: Pin<&mut _> = this.event_stream;
715        match stream.poll_next(cx) {
716            Poll::Pending => Poll::Pending,
717            Poll::Ready(response) => match response {
718                None => Poll::Ready(None),
719                Some(result) => match result {
720                    Err(e) => {
721                        // *this.done = true; // TODO: check this
722                        Poll::Ready(Some(Err(OpenAIError::StreamError(Box::new(
723                            StreamError::EventStream(e.to_string()),
724                        )))))
725                    }
726                    Ok(event) => {
727                        if event.data == "[DONE]" {
728                            *this.done = true;
729                            Poll::Ready(None)
730                        } else {
731                            match serde_json::from_str::<O>(&event.data) {
732                                Err(e) => Poll::Ready(Some(Err(map_deserialization_error(
733                                    e,
734                                    event.data.as_bytes(),
735                                )))),
736                                Ok(output) => Poll::Ready(Some(Ok(output))),
737                            }
738                        }
739                    }
740                },
741            },
742        }
743    }
744}
745
746async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
747    let status = response.status();
748    let headers = response.headers().clone();
749    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
750
751    if status.is_server_error() {
752        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
753        let message: String = String::from_utf8_lossy(&bytes).into_owned();
754        tracing::warn!("Server error: {status} - {message}");
755        return Err(OpenAIError::ApiError(ApiError {
756            message,
757            r#type: None,
758            param: None,
759            code: None,
760        }));
761    }
762
763    // Deserialize response body from either error object or actual response object
764    if !status.is_success() {
765        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
766            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
767
768        return Err(OpenAIError::ApiError(wrapped_error.error));
769    }
770
771    Ok((bytes, headers))
772}