async_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10    admin::Admin,
11    chatkit::Chatkit,
12    config::{Config, OpenAIConfig},
13    error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
14    file::Files,
15    image::Images,
16    moderation::Moderations,
17    traits::AsyncTryFrom,
18    Assistants, Audio, Batches, Chat, Completions, Containers, Conversations, Embeddings, Evals,
19    FineTuning, Models, RequestOptions, Responses, Threads, Uploads, VectorStores, Videos,
20};
21
22#[cfg(feature = "realtime")]
23use crate::Realtime;
24
/// Client is a container for config, backoff and http_client
/// used to make API calls.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// HTTP client used to build and send every request.
    http_client: reqwest::Client,
    /// Supplies the request URL, default query parameters and default headers.
    config: C,
    /// Retry policy handed to `backoff::future::retry` when executing requests.
    backoff: backoff::ExponentialBackoff,
}
33
34impl Client<OpenAIConfig> {
35    /// Client with default [OpenAIConfig]
36    pub fn new() -> Self {
37        Self::default()
38    }
39}
40
41impl<C: Config> Client<C> {
    /// Create client with a custom HTTP client, OpenAI config, and backoff.
    ///
    /// All three components are stored as given; nothing is defaulted.
    pub fn build(
        http_client: reqwest::Client,
        config: C,
        backoff: backoff::ExponentialBackoff,
    ) -> Self {
        Self {
            http_client,
            config,
            backoff,
        }
    }
54
55    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
56    pub fn with_config(config: C) -> Self {
57        Self {
58            http_client: reqwest::Client::new(),
59            config,
60            backoff: Default::default(),
61        }
62    }
63
64    /// Provide your own [client] to make HTTP requests with.
65    ///
66    /// [client]: reqwest::Client
67    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
68        self.http_client = http_client;
69        self
70    }
71
72    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
73    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
74        self.backoff = backoff;
75        self
76    }
77
    // API groups
    //
    // Each accessor below returns a lightweight handle that borrows this client.

    /// To call [Models] group related APIs using this client.
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// To call [Completions] group related APIs using this client.
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// To call [Chat] group related APIs using this client.
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// To call [Images] group related APIs using this client.
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// To call [Moderations] group related APIs using this client.
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// To call [Files] group related APIs using this client.
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// To call [Uploads] group related APIs using this client.
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// To call [FineTuning] group related APIs using this client.
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// To call [Embeddings] group related APIs using this client.
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// To call [Audio] group related APIs using this client.
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// To call [Videos] group related APIs using this client.
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// To call [Assistants] group related APIs using this client.
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// To call [Threads] group related APIs using this client.
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// To call [VectorStores] group related APIs using this client.
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// To call [Batches] group related APIs using this client.
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// To call [Admin] group related APIs using this client.
    /// This groups together admin API keys, invites, users, projects, audit logs, and certificates.
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    /// To call [Responses] group related APIs using this client.
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// To call [Conversations] group related APIs using this client.
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// To call [Containers] group related APIs using this client.
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// To call [Evals] group related APIs using this client.
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// To call [Chatkit] group related APIs using this client.
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    #[cfg(feature = "realtime")]
    /// To call [Realtime] group related APIs using this client.
    ///
    /// Only compiled when the `realtime` feature is enabled.
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }
190
    /// Borrow the config this client was created with.
    pub fn config(&self) -> &C {
        &self.config
    }
194
195    /// Helper function to build a request builder with common configuration
196    fn build_request_builder(
197        &self,
198        method: reqwest::Method,
199        path: &str,
200        request_options: &RequestOptions,
201    ) -> reqwest::RequestBuilder {
202        let mut request_builder = if let Some(path) = request_options.path() {
203            self.http_client
204                .request(method, self.config.url(path.as_str()))
205        } else {
206            self.http_client.request(method, self.config.url(path))
207        };
208
209        request_builder = request_builder
210            .query(&self.config.query())
211            .headers(self.config.headers());
212
213        if let Some(headers) = request_options.headers() {
214            request_builder = request_builder.headers(headers.clone());
215        }
216
217        if !request_options.query().is_empty() {
218            request_builder = request_builder.query(request_options.query());
219        }
220
221        request_builder
222    }
223
224    /// Make a GET request to {path} and deserialize the response body
225    pub(crate) async fn get<O>(
226        &self,
227        path: &str,
228        request_options: &RequestOptions,
229    ) -> Result<O, OpenAIError>
230    where
231        O: DeserializeOwned,
232    {
233        let request_maker = || async {
234            Ok(self
235                .build_request_builder(reqwest::Method::GET, path, request_options)
236                .build()?)
237        };
238
239        self.execute(request_maker).await
240    }
241
242    /// Make a DELETE request to {path} and deserialize the response body
243    pub(crate) async fn delete<O>(
244        &self,
245        path: &str,
246        request_options: &RequestOptions,
247    ) -> Result<O, OpenAIError>
248    where
249        O: DeserializeOwned,
250    {
251        let request_maker = || async {
252            Ok(self
253                .build_request_builder(reqwest::Method::DELETE, path, request_options)
254                .build()?)
255        };
256
257        self.execute(request_maker).await
258    }
259
260    /// Make a GET request to {path} and return the response body
261    pub(crate) async fn get_raw(
262        &self,
263        path: &str,
264        request_options: &RequestOptions,
265    ) -> Result<(Bytes, HeaderMap), OpenAIError> {
266        let request_maker = || async {
267            Ok(self
268                .build_request_builder(reqwest::Method::GET, path, request_options)
269                .build()?)
270        };
271
272        self.execute_raw(request_maker).await
273    }
274
275    /// Make a POST request to {path} and return the response body
276    pub(crate) async fn post_raw<I>(
277        &self,
278        path: &str,
279        request: I,
280        request_options: &RequestOptions,
281    ) -> Result<(Bytes, HeaderMap), OpenAIError>
282    where
283        I: Serialize,
284    {
285        let request_maker = || async {
286            Ok(self
287                .build_request_builder(reqwest::Method::POST, path, request_options)
288                .json(&request)
289                .build()?)
290        };
291
292        self.execute_raw(request_maker).await
293    }
294
295    /// Make a POST request to {path} and deserialize the response body
296    pub(crate) async fn post<I, O>(
297        &self,
298        path: &str,
299        request: I,
300        request_options: &RequestOptions,
301    ) -> Result<O, OpenAIError>
302    where
303        I: Serialize,
304        O: DeserializeOwned,
305    {
306        let request_maker = || async {
307            Ok(self
308                .build_request_builder(reqwest::Method::POST, path, request_options)
309                .json(&request)
310                .build()?)
311        };
312
313        self.execute(request_maker).await
314    }
315
316    /// POST a form at {path} and return the response body
317    pub(crate) async fn post_form_raw<F>(
318        &self,
319        path: &str,
320        form: F,
321        request_options: &RequestOptions,
322    ) -> Result<(Bytes, HeaderMap), OpenAIError>
323    where
324        Form: AsyncTryFrom<F, Error = OpenAIError>,
325        F: Clone,
326    {
327        let request_maker = || async {
328            Ok(self
329                .build_request_builder(reqwest::Method::POST, path, request_options)
330                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
331                .build()?)
332        };
333
334        self.execute_raw(request_maker).await
335    }
336
337    /// POST a form at {path} and deserialize the response body
338    pub(crate) async fn post_form<O, F>(
339        &self,
340        path: &str,
341        form: F,
342        request_options: &RequestOptions,
343    ) -> Result<O, OpenAIError>
344    where
345        O: DeserializeOwned,
346        Form: AsyncTryFrom<F, Error = OpenAIError>,
347        F: Clone,
348    {
349        let request_maker = || async {
350            Ok(self
351                .build_request_builder(reqwest::Method::POST, path, request_options)
352                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
353                .build()?)
354        };
355
356        self.execute(request_maker).await
357    }
358
359    pub(crate) async fn post_form_stream<O, F>(
360        &self,
361        path: &str,
362        form: F,
363        request_options: &RequestOptions,
364    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
365    where
366        F: Clone,
367        Form: AsyncTryFrom<F, Error = OpenAIError>,
368        O: DeserializeOwned + std::marker::Send + 'static,
369    {
370        // Build and execute request manually since multipart::Form is not Clone
371        // and .eventsource() requires cloneability
372        let request_builder = self
373            .build_request_builder(reqwest::Method::POST, path, request_options)
374            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
375
376        let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
377
378        // Check for error status
379        if !response.status().is_success() {
380            return Err(read_response(response).await.unwrap_err());
381        }
382
383        // Convert response body to EventSource stream
384        let stream = response
385            .bytes_stream()
386            .map(|result| result.map_err(std::io::Error::other));
387        let event_stream = eventsource_stream::EventStream::new(stream);
388
389        // Convert EventSource stream to our expected format
390        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
391
392        tokio::spawn(async move {
393            use futures::StreamExt;
394            let mut event_stream = std::pin::pin!(event_stream);
395
396            while let Some(event_result) = event_stream.next().await {
397                match event_result {
398                    Err(e) => {
399                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
400                            StreamError::EventStream(e.to_string()),
401                        )))) {
402                            break;
403                        }
404                    }
405                    Ok(event) => {
406                        // eventsource_stream::Event is a struct with data field
407                        if event.data == "[DONE]" {
408                            break;
409                        }
410
411                        let response = match serde_json::from_str::<O>(&event.data) {
412                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
413                            Ok(output) => Ok(output),
414                        };
415
416                        if let Err(_e) = tx.send(response) {
417                            break;
418                        }
419                    }
420                }
421            }
422        });
423
424        Ok(Box::pin(
425            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
426        ))
427    }
428
429    /// Execute a HTTP request and retry on rate limit
430    ///
431    /// request_maker serves one purpose: to be able to create request again
432    /// to retry API call after getting rate limited. request_maker is async because
433    /// reqwest::multipart::Form is created by async calls to read files for uploads.
434    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
435    where
436        M: Fn() -> Fut,
437        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
438    {
439        let client = self.http_client.clone();
440
441        backoff::future::retry(self.backoff.clone(), || async {
442            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
443            let response = client
444                .execute(request)
445                .await
446                .map_err(OpenAIError::Reqwest)
447                .map_err(backoff::Error::Permanent)?;
448
449            let status = response.status();
450
451            match read_response(response).await {
452                Ok((bytes, headers)) => Ok((bytes, headers)),
453                Err(e) => {
454                    match e {
455                        OpenAIError::ApiError(api_error) => {
456                            if status.is_server_error() {
457                                Err(backoff::Error::Transient {
458                                    err: OpenAIError::ApiError(api_error),
459                                    retry_after: None,
460                                })
461                            } else if status.as_u16() == 429
462                                && api_error.r#type != Some("insufficient_quota".to_string())
463                            {
464                                // Rate limited retry...
465                                tracing::warn!("Rate limited: {}", api_error.message);
466                                Err(backoff::Error::Transient {
467                                    err: OpenAIError::ApiError(api_error),
468                                    retry_after: None,
469                                })
470                            } else {
471                                Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
472                            }
473                        }
474                        _ => Err(backoff::Error::Permanent(e)),
475                    }
476                }
477            }
478        })
479        .await
480    }
481
482    /// Execute a HTTP request and retry on rate limit
483    ///
484    /// request_maker serves one purpose: to be able to create request again
485    /// to retry API call after getting rate limited. request_maker is async because
486    /// reqwest::multipart::Form is created by async calls to read files for uploads.
487    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
488    where
489        O: DeserializeOwned,
490        M: Fn() -> Fut,
491        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
492    {
493        let (bytes, _headers) = self.execute_raw(request_maker).await?;
494
495        let response: O = serde_json::from_slice(bytes.as_ref())
496            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
497
498        Ok(response)
499    }
500
501    /// Make HTTP POST request to receive SSE
502    pub(crate) async fn post_stream<I, O>(
503        &self,
504        path: &str,
505        request: I,
506        request_options: &RequestOptions,
507    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
508    where
509        I: Serialize,
510        O: DeserializeOwned + std::marker::Send + 'static,
511    {
512        let request_builder = self
513            .build_request_builder(reqwest::Method::POST, path, request_options)
514            .json(&request);
515
516        let event_source = request_builder.eventsource().unwrap();
517
518        stream(event_source).await
519    }
520
521    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
522        &self,
523        path: &str,
524        request: I,
525        request_options: &RequestOptions,
526        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
527    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
528    where
529        I: Serialize,
530        O: DeserializeOwned + std::marker::Send + 'static,
531    {
532        let request_builder = self
533            .build_request_builder(reqwest::Method::POST, path, request_options)
534            .json(&request);
535
536        let event_source = request_builder.eventsource().unwrap();
537
538        stream_mapped_raw_events(event_source, event_mapper).await
539    }
540
541    /// Make HTTP GET request to receive SSE
542    pub(crate) async fn get_stream<O>(
543        &self,
544        path: &str,
545        request_options: &RequestOptions,
546    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
547    where
548        O: DeserializeOwned + std::marker::Send + 'static,
549    {
550        let request_builder =
551            self.build_request_builder(reqwest::Method::GET, path, request_options);
552
553        let event_source = request_builder.eventsource().unwrap();
554
555        stream(event_source).await
556    }
557}
558
559async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
560    let status = response.status();
561    let headers = response.headers().clone();
562    let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
563
564    if status.is_server_error() {
565        // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
566        let message: String = String::from_utf8_lossy(&bytes).into_owned();
567        tracing::warn!("Server error: {status} - {message}");
568        return Err(OpenAIError::ApiError(ApiError {
569            message,
570            r#type: None,
571            param: None,
572            code: None,
573        }));
574    }
575
576    // Deserialize response body from either error object or actual response object
577    if !status.is_success() {
578        let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
579            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
580
581        return Err(OpenAIError::ApiError(wrapped_error.error));
582    }
583
584    Ok((bytes, headers))
585}
586
587async fn map_stream_error(value: EventSourceError) -> OpenAIError {
588    match value {
589        EventSourceError::InvalidStatusCode(status_code, response) => {
590            read_response(response).await.expect_err(&format!(
591                "Unreachable because read_response returns err when status_code {status_code} is invalid"
592            ))
593        }
594        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
595    }
596}
597
598/// Request which responds with SSE.
599/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
600pub(crate) async fn stream<O>(
601    mut event_source: EventSource,
602) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
603where
604    O: DeserializeOwned + std::marker::Send + 'static,
605{
606    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
607
608    tokio::spawn(async move {
609        while let Some(ev) = event_source.next().await {
610            match ev {
611                Err(e) => {
612                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
613                    // https://github.com/64bit/async-openai/issues/456
614                    match &e {
615                        EventSourceError::StreamEnded => {
616                            break;
617                        }
618                        _ => {
619                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
620                                // rx dropped
621                                break;
622                            }
623                        }
624                    }
625                }
626                Ok(event) => match event {
627                    Event::Message(message) => {
628                        if message.data == "[DONE]" {
629                            break;
630                        }
631
632                        let response = match serde_json::from_str::<O>(&message.data) {
633                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
634                            Ok(output) => Ok(output),
635                        };
636
637                        if let Err(_e) = tx.send(response) {
638                            // rx dropped
639                            break;
640                        }
641                    }
642                    Event::Open => continue,
643                },
644            }
645        }
646
647        event_source.close();
648    });
649
650    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
651}
652
653pub(crate) async fn stream_mapped_raw_events<O>(
654    mut event_source: EventSource,
655    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
656) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
657where
658    O: DeserializeOwned + std::marker::Send + 'static,
659{
660    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
661
662    tokio::spawn(async move {
663        while let Some(ev) = event_source.next().await {
664            match ev {
665                Err(e) => {
666                    // Handle StreamEnded gracefully - it's a normal end of stream, not an error
667                    // https://github.com/64bit/async-openai/issues/456
668                    match &e {
669                        EventSourceError::StreamEnded => {
670                            break;
671                        }
672                        _ => {
673                            if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
674                                // rx dropped
675                                break;
676                            }
677                        }
678                    }
679                }
680                Ok(event) => match event {
681                    Event::Message(message) => {
682                        let mut done = false;
683
684                        if message.data == "[DONE]" {
685                            done = true;
686                        }
687
688                        let response = event_mapper(message);
689
690                        if let Err(_e) = tx.send(response) {
691                            // rx dropped
692                            break;
693                        }
694
695                        if done {
696                            break;
697                        }
698                    }
699                    Event::Open => continue,
700                },
701            }
702        }
703
704        event_source.close();
705    });
706
707    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
708}