flynn_openai/
client.rs

1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{
9    config::{Config, OpenAIConfig},
10    error::{map_deserialization_error, OpenAIError, WrappedError},
11    file::Files,
12    image::Images,
13    moderation::Moderations,
14    Assistants, Audio, Batches, Chat, Completions, Embeddings, FineTuning, Models, Threads,
15    VectorStores,
16};
17
18#[derive(Debug, Clone, Default)]
19/// Client is a container for config, backoff and http_client
20/// used to make API calls.
21pub struct Client<C: Config> {
22    http_client: reqwest::Client,
23    config: C,
24    backoff: backoff::ExponentialBackoff,
25}
26
27impl Client<OpenAIConfig> {
28    /// Client with default [OpenAIConfig]
29    pub fn new() -> Self {
30        Self::default()
31    }
32}
33
34impl<C: Config> Client<C> {
35    /// Create client with a custom HTTP client, OpenAI config, and backoff.
36    pub fn build(
37        http_client: reqwest::Client,
38        config: C,
39        backoff: backoff::ExponentialBackoff,
40    ) -> Self {
41        Self {
42            http_client,
43            config,
44            backoff,
45        }
46    }
47
48    /// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
49    pub fn with_config(config: C) -> Self {
50        Self {
51            http_client: reqwest::Client::new(),
52            config,
53            backoff: Default::default(),
54        }
55    }
56
57    /// Provide your own [client] to make HTTP requests with.
58    ///
59    /// [client]: reqwest::Client
60    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
61        self.http_client = http_client;
62        self
63    }
64
65    /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
66    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
67        self.backoff = backoff;
68        self
69    }
70
71    // API groups
72
73    /// To call [Models] group related APIs using this client.
74    pub fn models(&self) -> Models<C> {
75        Models::new(self)
76    }
77
78    /// To call [Completions] group related APIs using this client.
79    pub fn completions(&self) -> Completions<C> {
80        Completions::new(self)
81    }
82
83    /// To call [Chat] group related APIs using this client.
84    pub fn chat(&self) -> Chat<C> {
85        Chat::new(self)
86    }
87
88    /// To call [Images] group related APIs using this client.
89    pub fn images(&self) -> Images<C> {
90        Images::new(self)
91    }
92
93    /// To call [Moderations] group related APIs using this client.
94    pub fn moderations(&self) -> Moderations<C> {
95        Moderations::new(self)
96    }
97
98    /// To call [Files] group related APIs using this client.
99    pub fn files(&self) -> Files<C> {
100        Files::new(self)
101    }
102
103    /// To call [FineTuning] group related APIs using this client.
104    pub fn fine_tuning(&self) -> FineTuning<C> {
105        FineTuning::new(self)
106    }
107
108    /// To call [Embeddings] group related APIs using this client.
109    pub fn embeddings(&self) -> Embeddings<C> {
110        Embeddings::new(self)
111    }
112
113    /// To call [Audio] group related APIs using this client.
114    pub fn audio(&self) -> Audio<C> {
115        Audio::new(self)
116    }
117
118    /// To call [Assistants] group related APIs using this client.
119    pub fn assistants(&self) -> Assistants<C> {
120        Assistants::new(self)
121    }
122
123    /// To call [Threads] group related APIs using this client.
124    pub fn threads(&self) -> Threads<C> {
125        Threads::new(self)
126    }
127
128    /// To call [VectorStores] group related APIs using this client.
129    pub fn vector_stores(&self) -> VectorStores<C> {
130        VectorStores::new(self)
131    }
132
133    /// To call [Batches] group related APIs using this client.
134    pub fn batches(&self) -> Batches<C> {
135        Batches::new(self)
136    }
137
138    pub fn config(&self) -> &C {
139        &self.config
140    }
141
142    /// Make a GET request to {path} and deserialize the response body
143    pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
144    where
145        O: DeserializeOwned,
146    {
147        let request_maker = || async {
148            Ok(self
149                .http_client
150                .get(self.config.url(path))
151                .query(&self.config.query())
152                .headers(self.config.headers())
153                .build()?)
154        };
155
156        self.execute(request_maker).await
157    }
158
159    /// Make a GET request to {path} with given Query and deserialize the response body
160    pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
161    where
162        O: DeserializeOwned,
163        Q: Serialize + ?Sized,
164    {
165        let request_maker = || async {
166            Ok(self
167                .http_client
168                .get(self.config.url(path))
169                .query(&self.config.query())
170                .query(query)
171                .headers(self.config.headers())
172                .build()?)
173        };
174
175        self.execute(request_maker).await
176    }
177
178    /// Make a DELETE request to {path} and deserialize the response body
179    pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
180    where
181        O: DeserializeOwned,
182    {
183        let request_maker = || async {
184            Ok(self
185                .http_client
186                .delete(self.config.url(path))
187                .query(&self.config.query())
188                .headers(self.config.headers())
189                .build()?)
190        };
191
192        self.execute(request_maker).await
193    }
194
195    /// Make a GET request to {path} and return the response body
196    pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
197        let request_maker = || async {
198            Ok(self
199                .http_client
200                .get(self.config.url(path))
201                .query(&self.config.query())
202                .headers(self.config.headers())
203                .build()?)
204        };
205
206        self.execute_raw(request_maker).await
207    }
208
209    /// Make a POST request to {path} and return the response body
210    pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
211    where
212        I: Serialize,
213    {
214        let request_maker = || async {
215            Ok(self
216                .http_client
217                .post(self.config.url(path))
218                .query(&self.config.query())
219                .headers(self.config.headers())
220                .json(&request)
221                .build()?)
222        };
223
224        self.execute_raw(request_maker).await
225    }
226
227    /// Make a POST request to {path} and deserialize the response body
228    pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
229    where
230        I: Serialize,
231        O: DeserializeOwned,
232    {
233        let request_maker = || async {
234            Ok(self
235                .http_client
236                .post(self.config.url(path))
237                .query(&self.config.query())
238                .headers(self.config.headers())
239                .json(&request)
240                .build()?)
241        };
242
243        self.execute(request_maker).await
244    }
245
246    /// POST a form at {path} and return the response body
247    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
248    where
249        reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
250        F: Clone,
251    {
252        let request_maker = || async {
253            Ok(self
254                .http_client
255                .post(self.config.url(path))
256                .query(&self.config.query())
257                .headers(self.config.headers())
258                .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
259                .build()?)
260        };
261
262        self.execute_raw(request_maker).await
263    }
264
265    /// POST a form at {path} and deserialize the response body
266    pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
267    where
268        O: DeserializeOwned,
269        reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
270        F: Clone,
271    {
272        let request_maker = || async {
273            Ok(self
274                .http_client
275                .post(self.config.url(path))
276                .query(&self.config.query())
277                .headers(self.config.headers())
278                .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
279                .build()?)
280        };
281
282        self.execute(request_maker).await
283    }
284
285    /// Execute a HTTP request and retry on rate limit
286    ///
287    /// request_maker serves one purpose: to be able to create request again
288    /// to retry API call after getting rate limited. request_maker is async because
289    /// reqwest::multipart::Form is created by async calls to read files for uploads.
290    async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
291    where
292        M: Fn() -> Fut,
293        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
294    {
295        let client = self.http_client.clone();
296
297        backoff::future::retry(self.backoff.clone(), || async {
298            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
299            let response = client
300                .execute(request)
301                .await
302                .map_err(OpenAIError::Reqwest)
303                .map_err(backoff::Error::Permanent)?;
304
305            let status = response.status();
306            let bytes = response
307                .bytes()
308                .await
309                .map_err(OpenAIError::Reqwest)
310                .map_err(backoff::Error::Permanent)?;
311
312            // Deserialize response body from either error object or actual response object
313            if !status.is_success() {
314                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
315                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
316                    .map_err(backoff::Error::Permanent)?;
317
318                if status.as_u16() == 429
319                    // API returns 429 also when:
320                    // "You exceeded your current quota, please check your plan and billing details."
321                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
322                {
323                    // Rate limited retry...
324                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
325                    return Err(backoff::Error::Transient {
326                        err: OpenAIError::ApiError(wrapped_error.error),
327                        retry_after: None,
328                    });
329                } else {
330                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
331                        wrapped_error.error,
332                    )));
333                }
334            }
335
336            Ok(bytes)
337        })
338        .await
339    }
340
341    /// Execute a HTTP request and retry on rate limit
342    ///
343    /// request_maker serves one purpose: to be able to create request again
344    /// to retry API call after getting rate limited. request_maker is async because
345    /// reqwest::multipart::Form is created by async calls to read files for uploads.
346    async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
347    where
348        O: DeserializeOwned,
349        M: Fn() -> Fut,
350        Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
351    {
352        let bytes = self.execute_raw(request_maker).await?;
353
354        let response: O = serde_json::from_slice(bytes.as_ref())
355            .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
356
357        Ok(response)
358    }
359
360    /// Make HTTP POST request to receive SSE
361    pub(crate) async fn post_stream<I, O>(
362        &self,
363        path: &str,
364        request: I,
365    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
366    where
367        I: Serialize,
368        O: DeserializeOwned + std::marker::Send + 'static,
369    {
370        let event_source = self
371            .http_client
372            .post(self.config.url(path))
373            .query(&self.config.query())
374            .headers(self.config.headers())
375            .json(&request)
376            .eventsource()
377            .unwrap();
378
379        stream(event_source).await
380    }
381
382    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
383        &self,
384        path: &str,
385        request: I,
386        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
387    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
388    where
389        I: Serialize,
390        O: DeserializeOwned + std::marker::Send + 'static,
391    {
392        let event_source = self
393            .http_client
394            .post(self.config.url(path))
395            .query(&self.config.query())
396            .headers(self.config.headers())
397            .json(&request)
398            .eventsource()
399            .unwrap();
400
401        stream_mapped_raw_events(event_source, event_mapper).await
402    }
403
404    /// Make HTTP GET request to receive SSE
405    pub(crate) async fn _get_stream<Q, O>(
406        &self,
407        path: &str,
408        query: &Q,
409    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
410    where
411        Q: Serialize + ?Sized,
412        O: DeserializeOwned + std::marker::Send + 'static,
413    {
414        let event_source = self
415            .http_client
416            .get(self.config.url(path))
417            .query(query)
418            .query(&self.config.query())
419            .headers(self.config.headers())
420            .eventsource()
421            .unwrap();
422
423        stream(event_source).await
424    }
425}
426
427/// Request which responds with SSE.
428/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
429pub(crate) async fn stream<O>(
430    mut event_source: EventSource,
431) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
432where
433    O: DeserializeOwned + std::marker::Send + 'static,
434{
435    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
436
437    tokio::spawn(async move {
438        while let Some(ev) = event_source.next().await {
439            match ev {
440                Err(e) => {
441                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
442                        // rx dropped
443                        break;
444                    }
445                }
446                Ok(event) => match event {
447                    Event::Message(message) => {
448                        if message.data == "[DONE]" {
449                            break;
450                        }
451
452                        let response = match serde_json::from_str::<O>(&message.data) {
453                            Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
454                            Ok(output) => Ok(output),
455                        };
456
457                        if let Err(_e) = tx.send(response) {
458                            // rx dropped
459                            break;
460                        }
461                    }
462                    Event::Open => continue,
463                },
464            }
465        }
466
467        event_source.close();
468    });
469
470    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
471}
472
473pub(crate) async fn stream_mapped_raw_events<O>(
474    mut event_source: EventSource,
475    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
476) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
477where
478    O: DeserializeOwned + std::marker::Send + 'static,
479{
480    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
481
482    tokio::spawn(async move {
483        while let Some(ev) = event_source.next().await {
484            match ev {
485                Err(e) => {
486                    if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
487                        // rx dropped
488                        break;
489                    }
490                }
491                Ok(event) => match event {
492                    Event::Message(message) => {
493                        let mut done = false;
494
495                        if message.data == "[DONE]" {
496                            done = true;
497                        }
498
499                        let response = event_mapper(message);
500
501                        if let Err(_e) = tx.send(response) {
502                            // rx dropped
503                            break;
504                        }
505
506                        if done {
507                            break;
508                        }
509                    }
510                    Event::Open => continue,
511                },
512            }
513        }
514
515        event_source.close();
516    });
517
518    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
519}