dynamo_async_openai/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
5// Original Copyright (c) 2022 Himanshu Neema
6// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
7//
8// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
9// Licensed under Apache 2.0
10
11use std::pin::Pin;
12
13use bytes::Bytes;
14use futures::{Stream, stream::StreamExt};
15use reqwest::multipart::Form;
16use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
17use serde::{Serialize, de::DeserializeOwned};
18
19use crate::{
20    Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
21    Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
22    config::{Config, OpenAIConfig},
23    error::{ApiError, OpenAIError, WrappedError, map_deserialization_error},
24    file::Files,
25    image::Images,
26    moderation::Moderations,
27    traits::AsyncTryFrom,
28};
29
30#[derive(Debug, Clone, Default)]
31/// Client is a container for config, backoff and http_client
32/// used to make API calls.
33pub struct Client<C: Config> {
34    http_client: reqwest::Client,
35    config: C,
36    backoff: backoff::ExponentialBackoff,
37}
38
39impl Client<OpenAIConfig> {
40    /// Client with default [OpenAIConfig]
41    pub fn new() -> Self {
42        Self::default()
43    }
44}
45
46impl<C: Config> Client<C> {
47    /// Create client with a custom HTTP client, OpenAI config, and backoff.
48    pub fn build(
49        http_client: reqwest::Client,
50        config: C,
51        backoff: backoff::ExponentialBackoff,
52    ) -> Self {
53        Self {
54            http_client,
55            config,
56            backoff,
57        }
58    }
59
60    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
61    pub fn with_config(config: C) -> Self {
62        Self {
63            http_client: reqwest::Client::new(),
64            config,
65            backoff: Default::default(),
66        }
67    }
68
69    /// Provide your own [client] to make HTTP requests with.
70    ///
71    /// [client]: reqwest::Client
72    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
73        self.http_client = http_client;
74        self
75    }
76
77    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
78    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
79        self.backoff = backoff;
80        self
81    }
82
83    // API groups
84
85    /// To call [Models] group related APIs using this client.
86    pub fn models(&self) -> Models<C> {
87        Models::new(self)
88    }
89
90    /// To call [Completions] group related APIs using this client.
91    pub fn completions(&self) -> Completions<C> {
92        Completions::new(self)
93    }
94
95    /// To call [Chat] group related APIs using this client.
96    pub fn chat(&self) -> Chat<C> {
97        Chat::new(self)
98    }
99
100    /// To call [Images] group related APIs using this client.
101    pub fn images(&self) -> Images<C> {
102        Images::new(self)
103    }
104
105    /// To call [Moderations] group related APIs using this client.
106    pub fn moderations(&self) -> Moderations<C> {
107        Moderations::new(self)
108    }
109
110    /// To call [Files] group related APIs using this client.
111    pub fn files(&self) -> Files<C> {
112        Files::new(self)
113    }
114
115    /// To call [Uploads] group related APIs using this client.
116    pub fn uploads(&self) -> Uploads<C> {
117        Uploads::new(self)
118    }
119
120    /// To call [FineTuning] group related APIs using this client.
121    pub fn fine_tuning(&self) -> FineTuning<C> {
122        FineTuning::new(self)
123    }
124
125    /// To call [Embeddings] group related APIs using this client.
126    pub fn embeddings(&self) -> Embeddings<C> {
127        Embeddings::new(self)
128    }
129
130    /// To call [Audio] group related APIs using this client.
131    pub fn audio(&self) -> Audio<C> {
132        Audio::new(self)
133    }
134
135    /// To call [Assistants] group related APIs using this client.
136    pub fn assistants(&self) -> Assistants<C> {
137        Assistants::new(self)
138    }
139
140    /// To call [Threads] group related APIs using this client.
141    pub fn threads(&self) -> Threads<C> {
142        Threads::new(self)
143    }
144
145    /// To call [VectorStores] group related APIs using this client.
146    pub fn vector_stores(&self) -> VectorStores<C> {
147        VectorStores::new(self)
148    }
149
150    /// To call [Batches] group related APIs using this client.
151    pub fn batches(&self) -> Batches<C> {
152        Batches::new(self)
153    }
154
155    /// To call [AuditLogs] group related APIs using this client.
156    pub fn audit_logs(&self) -> AuditLogs<C> {
157        AuditLogs::new(self)
158    }
159
160    /// To call [Invites] group related APIs using this client.
161    pub fn invites(&self) -> Invites<C> {
162        Invites::new(self)
163    }
164
165    /// To call [Users] group related APIs using this client.
166    pub fn users(&self) -> Users<C> {
167        Users::new(self)
168    }
169
170    /// To call [Projects] group related APIs using this client.
171    pub fn projects(&self) -> Projects<C> {
172        Projects::new(self)
173    }
174
175    /// To call [Responses] group related APIs using this client.
176    pub fn responses(&self) -> Responses<C> {
177        Responses::new(self)
178    }
179
180    pub fn config(&self) -> &C {
181        &self.config
182    }
183
184    /// Make a GET request to {path} and deserialize the response body
185    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
186    where
187        O: DeserializeOwned,
188    {
189        let request_maker = || async {
190            Ok(self
191                .http_client
192                .get(self.config.url(path))
193                .query(&self.config.query())
194                .headers(self.config.headers())
195                .build()?)
196        };
197
198        self.execute(request_maker).await
199    }
200
201    /// Make a GET request to {path} with given Query and deserialize the response body
202    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
203    where
204        O: DeserializeOwned,
205        Q: Serialize + ?Sized,
206    {
207        let request_maker = || async {
208            Ok(self
209                .http_client
210                .get(self.config.url(path))
211                .query(&self.config.query())
212                .query(query)
213                .headers(self.config.headers())
214                .build()?)
215        };
216
217        self.execute(request_maker).await
218    }
219
220    /// Make a DELETE request to {path} and deserialize the response body
221    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
222    where
223        O: DeserializeOwned,
224    {
225        let request_maker = || async {
226            Ok(self
227                .http_client
228                .delete(self.config.url(path))
229                .query(&self.config.query())
230                .headers(self.config.headers())
231                .build()?)
232        };
233
234        self.execute(request_maker).await
235    }
236
237    /// Make a GET request to {path} and return the response body
238    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
239        let request_maker = || async {
240            Ok(self
241                .http_client
242                .get(self.config.url(path))
243                .query(&self.config.query())
244                .headers(self.config.headers())
245                .build()?)
246        };
247
248        self.execute_raw(request_maker).await
249    }
250
251    /// Make a POST request to {path} and return the response body
252    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
253    where
254        I: Serialize,
255    {
256        let request_maker = || async {
257            Ok(self
258                .http_client
259                .post(self.config.url(path))
260                .query(&self.config.query())
261                .headers(self.config.headers())
262                .json(&request)
263                .build()?)
264        };
265
266        self.execute_raw(request_maker).await
267    }
268
269    /// Make a POST request to {path} and deserialize the response body
270    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
271    where
272        I: Serialize,
273        O: DeserializeOwned,
274    {
275        let request_maker = || async {
276            Ok(self
277                .http_client
278                .post(self.config.url(path))
279                .query(&self.config.query())
280                .headers(self.config.headers())
281                .json(&request)
282                .build()?)
283        };
284
285        self.execute(request_maker).await
286    }
287
288    /// POST a form at {path} and return the response body
289    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
290    where
291        Form: AsyncTryFrom<F, Error = OpenAIError>,
292        F: Clone,
293    {
294        let request_maker = || async {
295            Ok(self
296                .http_client
297                .post(self.config.url(path))
298                .query(&self.config.query())
299                .headers(self.config.headers())
300                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
301                .build()?)
302        };
303
304        self.execute_raw(request_maker).await
305    }
306
307    /// POST a form at {path} and deserialize the response body
308    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
309    where
310        O: DeserializeOwned,
311        Form: AsyncTryFrom<F, Error = OpenAIError>,
312        F: Clone,
313    {
314        let request_maker = || async {
315            Ok(self
316                .http_client
317                .post(self.config.url(path))
318                .query(&self.config.query())
319                .headers(self.config.headers())
320                .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
321                .build()?)
322        };
323
324        self.execute(request_maker).await
325    }
326
327    /// Execute a HTTP request and retry on rate limit
328    ///
329    /// request_maker serves one purpose: to be able to create request again
330    /// to retry API call after getting rate limited. request_maker is async because
331    /// reqwest::multipart::Form is created by async calls to read files for uploads.
332    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
333    where
334        M: Fn() -> Fut,
335        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
336    {
337        let client = self.http_client.clone();
338
339        backoff::future::retry(self.backoff.clone(), || async {
340            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
341            let response = client
342                .execute(request)
343                .await
344                .map_err(OpenAIError::Reqwest)
345                .map_err(backoff::Error::Permanent)?;
346
347            let status = response.status();
348            let bytes = response
349                .bytes()
350                .await
351                .map_err(OpenAIError::Reqwest)
352                .map_err(backoff::Error::Permanent)?;
353
354            if status.is_server_error() {
355                // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
356                let message: String = String::from_utf8_lossy(&bytes).into_owned();
357                tracing::warn!("Server error: {status} - {message}");
358                return Err(backoff::Error::Transient {
359                    err: OpenAIError::ApiError(ApiError {
360                        message,
361                        r#type: None,
362                        param: None,
363                        code: None,
364                    }),
365                    retry_after: None,
366                });
367            }
368
369            // Deserialize response body from either error object or actual response object
370            if !status.is_success() {
371                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
372                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
373                    .map_err(backoff::Error::Permanent)?;
374
375                if status.as_u16() == 429
376                    // API returns 429 also when:
377                    // "You exceeded your current quota, please check your plan and billing details."
378                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
379                {
380                    // Rate limited retry...
381                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
382                    return Err(backoff::Error::Transient {
383                        err: OpenAIError::ApiError(wrapped_error.error),
384                        retry_after: None,
385                    });
386                } else {
387                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
388                        wrapped_error.error,
389                    )));
390                }
391            }
392
393            Ok(bytes)
394        })
395        .await
396    }
397
398    /// Execute a HTTP request and retry on rate limit
399    ///
400    /// request_maker serves one purpose: to be able to create request again
401    /// to retry API call after getting rate limited. request_maker is async because
402    /// reqwest::multipart::Form is created by async calls to read files for uploads.
403    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
404    where
405        O: DeserializeOwned,
406        M: Fn() -> Fut,
407        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
408    {
409        let bytes = self.execute_raw(request_maker).await?;
410
411        let response: O = serde_json::from_slice(bytes.as_ref())
412            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
413
414        Ok(response)
415    }
416
417    /// Make HTTP POST request to receive SSE
418    pub(crate) async fn post_stream<I, O>(
419        &self,
420        path: &str,
421        request: I,
422    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
423    where
424        I: Serialize,
425        O: DeserializeOwned + std::marker::Send + 'static,
426    {
427        let event_source = self
428            .http_client
429            .post(self.config.url(path))
430            .query(&self.config.query())
431            .headers(self.config.headers())
432            .json(&request)
433            .eventsource()
434            .unwrap();
435
436        stream(event_source).await
437    }
438
439    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
440        &self,
441        path: &str,
442        request: I,
443        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
444    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
445    where
446        I: Serialize,
447        O: DeserializeOwned + std::marker::Send + 'static,
448    {
449        let event_source = self
450            .http_client
451            .post(self.config.url(path))
452            .query(&self.config.query())
453            .headers(self.config.headers())
454            .json(&request)
455            .eventsource()
456            .unwrap();
457
458        stream_mapped_raw_events(event_source, event_mapper).await
459    }
460
461    /// Make HTTP GET request to receive SSE
462    pub(crate) async fn _get_stream<Q, O>(
463        &self,
464        path: &str,
465        query: &Q,
466    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
467    where
468        Q: Serialize + ?Sized,
469        O: DeserializeOwned + std::marker::Send + 'static,
470    {
471        let event_source = self
472            .http_client
473            .get(self.config.url(path))
474            .query(query)
475            .query(&self.config.query())
476            .headers(self.config.headers())
477            .eventsource()
478            .unwrap();
479
480        stream(event_source).await
481    }
482}
483
484/// Request which responds with SSE.
485/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
486pub(crate) async fn stream<O>(
487    mut event_source: EventSource,
488) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
489where
490    O: DeserializeOwned + std::marker::Send + 'static,
491{
492    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
493
494    tokio::spawn(async move {
495        while let Some(ev) = event_source.next().await {
496            match ev {
497                Err(e) => {
498                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
499                        // rx dropped
500                        break;
501                    }
502                }
503                Ok(event) => match event {
504                    Event::Message(message) => {
505                        if message.data == "[DONE]" {
506                            break;
507                        }
508
509                        let response = match serde_json::from_str::<O>(&message.data) {
510                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
511                            Ok(output) => Ok(output),
512                        };
513
514                        if let Err(_e) = tx.send(response) {
515                            // rx dropped
516                            break;
517                        }
518                    }
519                    Event::Open => continue,
520                },
521            }
522        }
523
524        event_source.close();
525    });
526
527    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
528}
529
530pub(crate) async fn stream_mapped_raw_events<O>(
531    mut event_source: EventSource,
532    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
533) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
534where
535    O: DeserializeOwned + std::marker::Send + 'static,
536{
537    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
538
539    tokio::spawn(async move {
540        while let Some(ev) = event_source.next().await {
541            match ev {
542                Err(e) => {
543                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
544                        // rx dropped
545                        break;
546                    }
547                }
548                Ok(event) => match event {
549                    Event::Message(message) => {
550                        let mut done = false;
551
552                        if message.data == "[DONE]" {
553                            done = true;
554                        }
555
556                        let response = event_mapper(message);
557
558                        if let Err(_e) = tx.send(response) {
559                            // rx dropped
560                            break;
561                        }
562
563                        if done {
564                            break;
565                        }
566                    }
567                    Event::Open => continue,
568                },
569            }
570        }
571
572        event_source.close();
573    });
574
575    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
576}