async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    config::{Config, OpenAIConfig},
11    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12    traits::AsyncTryFrom,
13    RequestOptions,
14};
15
16#[cfg(feature = "administration")]
17use crate::admin::Admin;
18#[cfg(feature = "chatkit")]
19use crate::chatkit::Chatkit;
20#[cfg(feature = "file")]
21use crate::file::Files;
22#[cfg(feature = "image")]
23use crate::image::Images;
24#[cfg(feature = "moderation")]
25use crate::moderation::Moderations;
26#[cfg(feature = "assistant")]
27use crate::Assistants;
28#[cfg(feature = "audio")]
29use crate::Audio;
30#[cfg(feature = "batch")]
31use crate::Batches;
32#[cfg(feature = "chat-completion")]
33use crate::Chat;
34#[cfg(feature = "completions")]
35use crate::Completions;
36#[cfg(feature = "container")]
37use crate::Containers;
38#[cfg(feature = "responses")]
39use crate::Conversations;
40#[cfg(feature = "embedding")]
41use crate::Embeddings;
42#[cfg(feature = "evals")]
43use crate::Evals;
44#[cfg(feature = "finetuning")]
45use crate::FineTuning;
46#[cfg(feature = "model")]
47use crate::Models;
48#[cfg(feature = "realtime")]
49use crate::Realtime;
50#[cfg(feature = "responses")]
51use crate::Responses;
52#[cfg(feature = "assistant")]
53use crate::Threads;
54#[cfg(feature = "upload")]
55use crate::Uploads;
56#[cfg(feature = "vectorstore")]
57use crate::VectorStores;
58#[cfg(feature = "video")]
59use crate::Videos;
60
61#[derive(Debug, Clone, Default)]
62/// Client is a container for config, backoff and http_client
63/// used to make API calls.
64pub struct Client<C: Config> {
65    http_client: reqwest::Client,
66    config: C,
67    backoff: backoff::ExponentialBackoff,
68}
69
70impl Client<OpenAIConfig> {
71    /// Client with default [OpenAIConfig]
72    pub fn new() -> Self {
73        Self::default()
74    }
75}
76
77impl<C: Config> Client<C> {
78    /// Create client with a custom HTTP client, OpenAI config, and backoff.
79    pub fn build(
80        http_client: reqwest::Client,
81        config: C,
82        backoff: backoff::ExponentialBackoff,
83    ) -> Self {
84        Self {
85            http_client,
86            config,
87            backoff,
88        }
89    }
90
91    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
92    pub fn with_config(config: C) -> Self {
93        Self {
94            http_client: reqwest::Client::new(),
95            config,
96            backoff: Default::default(),
97        }
98    }
99
100    /// Provide your own [client] to make HTTP requests with.
101    ///
102    /// [client]: reqwest::Client
103    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
104        self.http_client = http_client;
105        self
106    }
107
108    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
109    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
110        self.backoff = backoff;
111        self
112    }
113
114    // API groups
115
    /// To call [Models] group related APIs using this client.
    /// Only available with the `model` feature.
    #[cfg(feature = "model")]
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// To call [Completions] group related APIs using this client.
    /// Only available with the `completions` feature.
    #[cfg(feature = "completions")]
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// To call [Chat] group related APIs using this client.
    /// Only available with the `chat-completion` feature.
    #[cfg(feature = "chat-completion")]
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// To call [Images] group related APIs using this client.
    /// Only available with the `image` feature.
    #[cfg(feature = "image")]
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// To call [Moderations] group related APIs using this client.
    /// Only available with the `moderation` feature.
    #[cfg(feature = "moderation")]
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// To call [Files] group related APIs using this client.
    /// Only available with the `file` feature.
    #[cfg(feature = "file")]
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// To call [Uploads] group related APIs using this client.
    /// Only available with the `upload` feature.
    #[cfg(feature = "upload")]
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// To call [FineTuning] group related APIs using this client.
    /// Only available with the `finetuning` feature.
    #[cfg(feature = "finetuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// To call [Embeddings] group related APIs using this client.
    /// Only available with the `embedding` feature.
    #[cfg(feature = "embedding")]
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// To call [Audio] group related APIs using this client.
    /// Only available with the `audio` feature.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// To call [Videos] group related APIs using this client.
    /// Only available with the `video` feature.
    #[cfg(feature = "video")]
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// To call [Assistants] group related APIs using this client.
    /// Only available with the `assistant` feature.
    #[cfg(feature = "assistant")]
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// To call [Threads] group related APIs using this client.
    /// Only available with the `assistant` feature.
    #[cfg(feature = "assistant")]
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// To call [VectorStores] group related APIs using this client.
    /// Only available with the `vectorstore` feature.
    #[cfg(feature = "vectorstore")]
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// To call [Batches] group related APIs using this client.
    /// Only available with the `batch` feature.
    #[cfg(feature = "batch")]
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// To call [Admin] group related APIs using this client.
    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
    /// Only available with the `administration` feature.
    #[cfg(feature = "administration")]
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    /// To call [Responses] group related APIs using this client.
    /// Only available with the `responses` feature.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// To call [Conversations] group related APIs using this client.
    /// Only available with the `responses` feature.
    #[cfg(feature = "responses")]
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// To call [Containers] group related APIs using this client.
    /// Only available with the `container` feature.
    #[cfg(feature = "container")]
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// To call [Evals] group related APIs using this client.
    /// Only available with the `evals` feature.
    #[cfg(feature = "evals")]
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// To call [Chatkit] group related APIs using this client.
    /// Only available with the `chatkit` feature.
    #[cfg(feature = "chatkit")]
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    /// To call [Realtime] group related APIs using this client.
    /// Only available with the `realtime` feature.
    #[cfg(feature = "realtime")]
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }
247
    /// Borrow the [Config] held by this client.
    pub fn config(&self) -> &C {
        &self.config
    }
251
252    /// Helper function to build a request builder with common configuration
253    fn build_request_builder(
254        &self,
255        method: reqwest::Method,
256        path: &str,
257        request_options: &RequestOptions,
258    ) -> reqwest::RequestBuilder {
259        let mut request_builder = if let Some(path) = request_options.path() {
260            self.http_client
261                .request(method, self.config.url(path.as_str()))
262        } else {
263            self.http_client.request(method, self.config.url(path))
264        };
265
266        request_builder = request_builder
267            .query(&self.config.query())
268            .headers(self.config.headers());
269
270        if let Some(headers) = request_options.headers() {
271            request_builder = request_builder.headers(headers.clone());
272        }
273
274        if !request_options.query().is_empty() {
275            request_builder = request_builder.query(request_options.query());
276        }
277
278        request_builder
279    }
280
281    /// Make a GET request to {path} and deserialize the response body
282    #[allow(unused)]
283    pub(crate) async fn get<O>(
284        &self,
285        path: &str,
286        request_options: &RequestOptions,
287    ) -> Result<O, OpenAIError>
288    where
289        O: DeserializeOwned,
290    {
291        let request_maker = || async {
292            Ok(self
293                .build_request_builder(reqwest::Method::GET, path, request_options)
294                .build()?)
295        };
296
297        self.execute(request_maker).await
298    }
299
300    /// Make a DELETE request to {path} and deserialize the response body
301    #[allow(unused)]
302    pub(crate) async fn delete<O>(
303        &self,
304        path: &str,
305        request_options: &RequestOptions,
306    ) -> Result<O, OpenAIError>
307    where
308        O: DeserializeOwned,
309    {
310        let request_maker = || async {
311            Ok(self
312                .build_request_builder(reqwest::Method::DELETE, path, request_options)
313                .build()?)
314        };
315
316        self.execute(request_maker).await
317    }
318
319    /// Make a GET request to {path} and return the response body
320    #[allow(unused)]
321    pub(crate) async fn get_raw(
322        &self,
323        path: &str,
324        request_options: &RequestOptions,
325    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
326        let request_maker = || async {
327            Ok(self
328                .build_request_builder(reqwest::Method::GET, path, request_options)
329                .build()?)
330        };
331
332        self.execute_raw(request_maker).await
333    }
334
335    /// Make a POST request to {path} and return the response body
336    #[allow(unused)]
337    pub(crate) async fn post_raw<I>(
338        &self,
339        path: &str,
340        request: I,
341        request_options: &RequestOptions,
342    ) -> Result<(Bytes, HeaderMap), OpenAIError>
343    where
344        I: Serialize,
345    {
346        let request_maker = || async {
347            Ok(self
348                .build_request_builder(reqwest::Method::POST, path, request_options)
349                .json(&request)
350                .build()?)
351        };
352
353        self.execute_raw(request_maker).await
354    }
355
356    /// Make a POST request to {path} and deserialize the response body
357    #[allow(unused)]
358    pub(crate) async fn post<I, O>(
359        &self,
360        path: &str,
361        request: I,
362        request_options: &RequestOptions,
363    ) -> Result<O, OpenAIError>
364    where
365        I: Serialize,
366        O: DeserializeOwned,
367    {
368        let request_maker = || async {
369            Ok(self
370                .build_request_builder(reqwest::Method::POST, path, request_options)
371                .json(&request)
372                .build()?)
373        };
374
375        self.execute(request_maker).await
376    }
377
378    /// POST a form at {path} and return the response body
379    #[allow(unused)]
380    pub(crate) async fn post_form_raw<F>(
381        &self,
382        path: &str,
383        form: F,
384        request_options: &RequestOptions,
385    ) -> Result<(Bytes, HeaderMap), OpenAIError>
386    where
387        Form: AsyncTryFrom<F, Error = OpenAIError>,
388        F: Clone,
389    {
390        let request_maker = || async {
391            Ok(self
392                .build_request_builder(reqwest::Method::POST, path, request_options)
393                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
394                .build()?)
395        };
396
397        self.execute_raw(request_maker).await
398    }
399
400    /// POST a form at {path} and deserialize the response body
401    #[allow(unused)]
402    pub(crate) async fn post_form<O, F>(
403        &self,
404        path: &str,
405        form: F,
406        request_options: &RequestOptions,
407    ) -> Result<O, OpenAIError>
408    where
409        O: DeserializeOwned,
410        Form: AsyncTryFrom<F, Error = OpenAIError>,
411        F: Clone,
412    {
413        let request_maker = || async {
414            Ok(self
415                .build_request_builder(reqwest::Method::POST, path, request_options)
416                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
417                .build()?)
418        };
419
420        self.execute(request_maker).await
421    }
422
423    #[allow(unused)]
424    pub(crate) async fn post_form_stream<O, F>(
425        &self,
426        path: &str,
427        form: F,
428        request_options: &RequestOptions,
429    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
430    where
431        F: Clone,
432        Form: AsyncTryFrom<F, Error = OpenAIError>,
433        O: DeserializeOwned + std::marker::Send + 'static,
434    {
435        // Build and execute request manually since multipart::Form is not Clone
436        // and .eventsource() requires cloneability
437        let request_builder = self
438            .build_request_builder(reqwest::Method::POST, path, request_options)
439            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
440
441        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
442
443        // Check for error status
444        if !response.status().is_success() {
445            return Err(read_response(response).await.unwrap_err());
446        }
447
448        // Convert response body to EventSource stream
449        let stream = response
450            .bytes_stream()
451            .map(|result| result.map_err(std::io::Error::other));
452        let event_stream = eventsource_stream::EventStream::new(stream);
453
454        // Convert EventSource stream to our expected format
455        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
456
457        tokio::spawn(async move {
458            use futures::StreamExt;
459            let mut event_stream = std::pin::pin!(event_stream);
460
461            while let Some(event_result) = event_stream.next().await {
462                match event_result {
463                    Err(e) => {
464                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
465                            StreamError::EventStream(e.to_string()),
466                        )))) {
467                            break;
468                        }
469                    }
470                    Ok(event) => {
471                        // eventsource_stream::Event is a struct with data field
472                        if event.data == "[DONE]" {
473                            break;
474                        }
475
476                        let response = match serde_json::from_str::<O>(&event.data) {
477                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
478                            Ok(output) => Ok(output),
479                        };
480
481                        if let Err(_e) = tx.send(response) {
482                            break;
483                        }
484                    }
485                }
486            }
487        });
488
489        Ok(Box::pin(
490            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
491        ))
492    }
493
494    /// Execute a HTTP request and retry on rate limit
495    ///
496    /// request_maker serves one purpose: to be able to create request again
497    /// to retry API call after getting rate limited. request_maker is async because
498    /// reqwest::multipart::Form is created by async calls to read files for uploads.
499    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
500    where
501        M: Fn() -> Fut,
502        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
503    {
504        let client = self.http_client.clone();
505
506        backoff::future::retry(self.backoff.clone(), || async {
507            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
508            let response = client
509                .execute(request)
510                .await
511                .map_err(OpenAIError::Reqwest)
512                .map_err(backoff::Error::Permanent)?;
513
514            let status = response.status();
515
516            match read_response(response).await {
517                Ok((bytes, headers)) => Ok((bytes, headers)),
518                Err(e) => {
519                    match e {
520                        OpenAIError::ApiError(api_error) => {
521                            if status.is_server_error() {
522                                Err(backoff::Error::Transient {
523                                    err: OpenAIError::ApiError(api_error),
524                                    retry_after: None,
525                                })
526                            } else if status.as_u16() == 429
527                                && api_error.r#type != Some("insufficient_quota".to_string())
528                            {
529                                // Rate limited retry...
530                                tracing::warn!("Rate limited: {}", api_error.message);
531                                Err(backoff::Error::Transient {
532                                    err: OpenAIError::ApiError(api_error),
533                                    retry_after: None,
534                                })
535                            } else {
536                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
537                            }
538                        }
539                        _ => Err(backoff::Error::Permanent(e)),
540                    }
541                }
542            }
543        })
544        .await
545    }
546
547    /// Execute a HTTP request and retry on rate limit
548    ///
549    /// request_maker serves one purpose: to be able to create request again
550    /// to retry API call after getting rate limited. request_maker is async because
551    /// reqwest::multipart::Form is created by async calls to read files for uploads.
552    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
553    where
554        O: DeserializeOwned,
555        M: Fn() -> Fut,
556        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
557    {
558        let (bytes, _headers) = self.execute_raw(request_maker).await?;
559
560        let response: O = serde_json::from_slice(bytes.as_ref())
561            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
562
563        Ok(response)
564    }
565
566    /// Make HTTP POST request to receive SSE
567    #[allow(unused)]
568    pub(crate) async fn post_stream<I, O>(
569        &self,
570        path: &str,
571        request: I,
572        request_options: &RequestOptions,
573    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
574    where
575        I: Serialize,
576        O: DeserializeOwned + std::marker::Send + 'static,
577    {
578        let request_builder = self
579            .build_request_builder(reqwest::Method::POST, path, request_options)
580            .json(&request);
581
582        let event_source = request_builder.eventsource().unwrap();
583
584        stream(event_source).await
585    }
586
587    #[allow(unused)]
588    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
589        &self,
590        path: &str,
591        request: I,
592        request_options: &RequestOptions,
593        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
594    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
595    where
596        I: Serialize,
597        O: DeserializeOwned + std::marker::Send + 'static,
598    {
599        let request_builder = self
600            .build_request_builder(reqwest::Method::POST, path, request_options)
601            .json(&request);
602
603        let event_source = request_builder.eventsource().unwrap();
604
605        stream_mapped_raw_events(event_source, event_mapper).await
606    }
607
608    /// Make HTTP GET request to receive SSE
609    #[allow(unused)]
610    pub(crate) async fn get_stream<O>(
611        &self,
612        path: &str,
613        request_options: &RequestOptions,
614    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
615    where
616        O: DeserializeOwned + std::marker::Send + 'static,
617    {
618        let request_builder =
619            self.build_request_builder(reqwest::Method::GET, path, request_options);
620
621        let event_source = request_builder.eventsource().unwrap();
622
623        stream(event_source).await
624    }
625}
626
627async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
628    let status = response.status();
629    let headers = response.headers().clone();
630    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
631
632    if status.is_server_error() {
633        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
634        let message: String = String::from_utf8_lossy(&bytes).into_owned();
635        tracing::warn!("Server error: {status} - {message}");
636        return Err(OpenAIError::ApiError(ApiError {
637            message,
638            r#type: None,
639            param: None,
640            code: None,
641        }));
642    }
643
644    // Deserialize response body from either error object or actual response object
645    if !status.is_success() {
646        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
647            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
648
649        return Err(OpenAIError::ApiError(wrapped_error.error));
650    }
651
652    Ok((bytes, headers))
653}
654
655async fn map_stream_error(value: EventSourceError) -> OpenAIError {
656    match value {
657        EventSourceError::InvalidStatusCode(status_code, response) => {
658            read_response(response).await.expect_err(&format!(
659                "Unreachable because read_response returns err when status_code {status_code} is invalid"
660            ))
661        }
662        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
663    }
664}
665
666/// Request which responds with SSE.
667/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
668pub(crate) async fn stream<O>(
669    mut event_source: EventSource,
670) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
671where
672    O: DeserializeOwned + std::marker::Send + 'static,
673{
674    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
675
676    tokio::spawn(async move {
677        while let Some(ev) = event_source.next().await {
678            match ev {
679                Err(e) => {
680                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
681                    // https://github.com/64bit/async-openai/issues/456
682                    match &e {
683                        EventSourceError::StreamEnded => {
684                            break;
685                        }
686                        _ => {
687                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
688                                // rx dropped
689                                break;
690                            }
691                        }
692                    }
693                }
694                Ok(event) => match event {
695                    Event::Message(message) => {
696                        if message.data == "[DONE]" {
697                            break;
698                        }
699
700                        let response = match serde_json::from_str::<O>(&message.data) {
701                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
702                            Ok(output) => Ok(output),
703                        };
704
705                        if let Err(_e) = tx.send(response) {
706                            // rx dropped
707                            break;
708                        }
709                    }
710                    Event::Open => continue,
711                },
712            }
713        }
714
715        event_source.close();
716    });
717
718    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
719}
720
721pub(crate) async fn stream_mapped_raw_events<O>(
722    mut event_source: EventSource,
723    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
724) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
725where
726    O: DeserializeOwned + std::marker::Send + 'static,
727{
728    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
729
730    tokio::spawn(async move {
731        while let Some(ev) = event_source.next().await {
732            match ev {
733                Err(e) => {
734                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
735                    // https://github.com/64bit/async-openai/issues/456
736                    match &e {
737                        EventSourceError::StreamEnded => {
738                            break;
739                        }
740                        _ => {
741                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
742                                // rx dropped
743                                break;
744                            }
745                        }
746                    }
747                }
748                Ok(event) => match event {
749                    Event::Message(message) => {
750                        let mut done = false;
751
752                        if message.data == "[DONE]" {
753                            done = true;
754                        }
755
756                        let response = event_mapper(message);
757
758                        if let Err(_e) = tx.send(response) {
759                            // rx dropped
760                            break;
761                        }
762
763                        if done {
764                            break;
765                        }
766                    }
767                    Event::Open => continue,
768                },
769            }
770        }
771
772        event_source.close();
773    });
774
775    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
776}