1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{
9 config::{Config, OpenAIConfig},
10 error::{map_deserialization_error, OpenAIError, WrappedError},
11 file::Files,
12 image::Images,
13 moderation::Moderations,
14 Assistants, Audio, Batches, Chat, Completions, Embeddings, FineTuning, Models, Threads,
15 VectorStores,
16};
17
/// Client bundles the HTTP client, the API configuration and the retry backoff
/// used to make API calls.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    // HTTP client used to send every request built by this client.
    http_client: reqwest::Client,
    // Supplies the URL (via `url(path)`), headers and query parameters applied to each request.
    config: C,
    // Retry policy used for non-streaming requests; only rate-limited (HTTP 429,
    // non-`insufficient_quota`) responses are treated as retryable.
    backoff: backoff::ExponentialBackoff,
}
26
27impl Client<OpenAIConfig> {
28 pub fn new() -> Self {
30 Self::default()
31 }
32}
33
34impl<C: Config> Client<C> {
35 pub fn build(
37 http_client: reqwest::Client,
38 config: C,
39 backoff: backoff::ExponentialBackoff,
40 ) -> Self {
41 Self {
42 http_client,
43 config,
44 backoff,
45 }
46 }
47
48 pub fn with_config(config: C) -> Self {
50 Self {
51 http_client: reqwest::Client::new(),
52 config,
53 backoff: Default::default(),
54 }
55 }
56
57 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
61 self.http_client = http_client;
62 self
63 }
64
65 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
67 self.backoff = backoff;
68 self
69 }
70
71 pub fn models(&self) -> Models<C> {
75 Models::new(self)
76 }
77
78 pub fn completions(&self) -> Completions<C> {
80 Completions::new(self)
81 }
82
83 pub fn chat(&self) -> Chat<C> {
85 Chat::new(self)
86 }
87
88 pub fn images(&self) -> Images<C> {
90 Images::new(self)
91 }
92
93 pub fn moderations(&self) -> Moderations<C> {
95 Moderations::new(self)
96 }
97
98 pub fn files(&self) -> Files<C> {
100 Files::new(self)
101 }
102
103 pub fn fine_tuning(&self) -> FineTuning<C> {
105 FineTuning::new(self)
106 }
107
108 pub fn embeddings(&self) -> Embeddings<C> {
110 Embeddings::new(self)
111 }
112
113 pub fn audio(&self) -> Audio<C> {
115 Audio::new(self)
116 }
117
118 pub fn assistants(&self) -> Assistants<C> {
120 Assistants::new(self)
121 }
122
123 pub fn threads(&self) -> Threads<C> {
125 Threads::new(self)
126 }
127
128 pub fn vector_stores(&self) -> VectorStores<C> {
130 VectorStores::new(self)
131 }
132
133 pub fn batches(&self) -> Batches<C> {
135 Batches::new(self)
136 }
137
138 pub fn config(&self) -> &C {
139 &self.config
140 }
141
142 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
144 where
145 O: DeserializeOwned,
146 {
147 let request_maker = || async {
148 Ok(self
149 .http_client
150 .get(self.config.url(path))
151 .query(&self.config.query())
152 .headers(self.config.headers())
153 .build()?)
154 };
155
156 self.execute(request_maker).await
157 }
158
159 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
161 where
162 O: DeserializeOwned,
163 Q: Serialize + ?Sized,
164 {
165 let request_maker = || async {
166 Ok(self
167 .http_client
168 .get(self.config.url(path))
169 .query(&self.config.query())
170 .query(query)
171 .headers(self.config.headers())
172 .build()?)
173 };
174
175 self.execute(request_maker).await
176 }
177
178 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
180 where
181 O: DeserializeOwned,
182 {
183 let request_maker = || async {
184 Ok(self
185 .http_client
186 .delete(self.config.url(path))
187 .query(&self.config.query())
188 .headers(self.config.headers())
189 .build()?)
190 };
191
192 self.execute(request_maker).await
193 }
194
195 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
197 let request_maker = || async {
198 Ok(self
199 .http_client
200 .get(self.config.url(path))
201 .query(&self.config.query())
202 .headers(self.config.headers())
203 .build()?)
204 };
205
206 self.execute_raw(request_maker).await
207 }
208
209 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
211 where
212 I: Serialize,
213 {
214 let request_maker = || async {
215 Ok(self
216 .http_client
217 .post(self.config.url(path))
218 .query(&self.config.query())
219 .headers(self.config.headers())
220 .json(&request)
221 .build()?)
222 };
223
224 self.execute_raw(request_maker).await
225 }
226
227 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
229 where
230 I: Serialize,
231 O: DeserializeOwned,
232 {
233 let request_maker = || async {
234 Ok(self
235 .http_client
236 .post(self.config.url(path))
237 .query(&self.config.query())
238 .headers(self.config.headers())
239 .json(&request)
240 .build()?)
241 };
242
243 self.execute(request_maker).await
244 }
245
246 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
248 where
249 reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
250 F: Clone,
251 {
252 let request_maker = || async {
253 Ok(self
254 .http_client
255 .post(self.config.url(path))
256 .query(&self.config.query())
257 .headers(self.config.headers())
258 .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
259 .build()?)
260 };
261
262 self.execute_raw(request_maker).await
263 }
264
265 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
267 where
268 O: DeserializeOwned,
269 reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
270 F: Clone,
271 {
272 let request_maker = || async {
273 Ok(self
274 .http_client
275 .post(self.config.url(path))
276 .query(&self.config.query())
277 .headers(self.config.headers())
278 .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
279 .build()?)
280 };
281
282 self.execute(request_maker).await
283 }
284
285 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
291 where
292 M: Fn() -> Fut,
293 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
294 {
295 let client = self.http_client.clone();
296
297 backoff::future::retry(self.backoff.clone(), || async {
298 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
299 let response = client
300 .execute(request)
301 .await
302 .map_err(OpenAIError::Reqwest)
303 .map_err(backoff::Error::Permanent)?;
304
305 let status = response.status();
306 let bytes = response
307 .bytes()
308 .await
309 .map_err(OpenAIError::Reqwest)
310 .map_err(backoff::Error::Permanent)?;
311
312 if !status.is_success() {
314 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
315 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
316 .map_err(backoff::Error::Permanent)?;
317
318 if status.as_u16() == 429
319 && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
322 {
323 tracing::warn!("Rate limited: {}", wrapped_error.error.message);
325 return Err(backoff::Error::Transient {
326 err: OpenAIError::ApiError(wrapped_error.error),
327 retry_after: None,
328 });
329 } else {
330 return Err(backoff::Error::Permanent(OpenAIError::ApiError(
331 wrapped_error.error,
332 )));
333 }
334 }
335
336 Ok(bytes)
337 })
338 .await
339 }
340
341 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
347 where
348 O: DeserializeOwned,
349 M: Fn() -> Fut,
350 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
351 {
352 let bytes = self.execute_raw(request_maker).await?;
353
354 let response: O = serde_json::from_slice(bytes.as_ref())
355 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
356
357 Ok(response)
358 }
359
360 pub(crate) async fn post_stream<I, O>(
362 &self,
363 path: &str,
364 request: I,
365 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
366 where
367 I: Serialize,
368 O: DeserializeOwned + std::marker::Send + 'static,
369 {
370 let event_source = self
371 .http_client
372 .post(self.config.url(path))
373 .query(&self.config.query())
374 .headers(self.config.headers())
375 .json(&request)
376 .eventsource()
377 .unwrap();
378
379 stream(event_source).await
380 }
381
382 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
383 &self,
384 path: &str,
385 request: I,
386 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
387 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
388 where
389 I: Serialize,
390 O: DeserializeOwned + std::marker::Send + 'static,
391 {
392 let event_source = self
393 .http_client
394 .post(self.config.url(path))
395 .query(&self.config.query())
396 .headers(self.config.headers())
397 .json(&request)
398 .eventsource()
399 .unwrap();
400
401 stream_mapped_raw_events(event_source, event_mapper).await
402 }
403
404 pub(crate) async fn _get_stream<Q, O>(
406 &self,
407 path: &str,
408 query: &Q,
409 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
410 where
411 Q: Serialize + ?Sized,
412 O: DeserializeOwned + std::marker::Send + 'static,
413 {
414 let event_source = self
415 .http_client
416 .get(self.config.url(path))
417 .query(query)
418 .query(&self.config.query())
419 .headers(self.config.headers())
420 .eventsource()
421 .unwrap();
422
423 stream(event_source).await
424 }
425}
426
427pub(crate) async fn stream<O>(
430 mut event_source: EventSource,
431) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
432where
433 O: DeserializeOwned + std::marker::Send + 'static,
434{
435 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
436
437 tokio::spawn(async move {
438 while let Some(ev) = event_source.next().await {
439 match ev {
440 Err(e) => {
441 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
442 break;
444 }
445 }
446 Ok(event) => match event {
447 Event::Message(message) => {
448 if message.data == "[DONE]" {
449 break;
450 }
451
452 let response = match serde_json::from_str::<O>(&message.data) {
453 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
454 Ok(output) => Ok(output),
455 };
456
457 if let Err(_e) = tx.send(response) {
458 break;
460 }
461 }
462 Event::Open => continue,
463 },
464 }
465 }
466
467 event_source.close();
468 });
469
470 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
471}
472
473pub(crate) async fn stream_mapped_raw_events<O>(
474 mut event_source: EventSource,
475 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
476) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
477where
478 O: DeserializeOwned + std::marker::Send + 'static,
479{
480 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
481
482 tokio::spawn(async move {
483 while let Some(ev) = event_source.next().await {
484 match ev {
485 Err(e) => {
486 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
487 break;
489 }
490 }
491 Ok(event) => match event {
492 Event::Message(message) => {
493 let mut done = false;
494
495 if message.data == "[DONE]" {
496 done = true;
497 }
498
499 let response = event_mapper(message);
500
501 if let Err(_e) = tx.send(response) {
502 break;
504 }
505
506 if done {
507 break;
508 }
509 }
510 Event::Open => continue,
511 },
512 }
513 }
514
515 event_source.close();
516 });
517
518 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
519}