1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12 file::Files,
13 image::Images,
14 moderation::Moderations,
15 traits::AsyncTryFrom,
16 Assistants, Audio, AuditLogs, Batches, Chat, Completions, Containers, Conversations,
17 Embeddings, Evals, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users,
18 VectorStores, Videos,
19};
20
21#[derive(Debug, Clone, Default)]
22pub struct Client<C: Config> {
25 http_client: reqwest::Client,
26 config: C,
27 backoff: backoff::ExponentialBackoff,
28}
29
30impl Client<OpenAIConfig> {
31 pub fn new() -> Self {
33 Self::default()
34 }
35}
36
37impl<C: Config> Client<C> {
38 pub fn build(
40 http_client: reqwest::Client,
41 config: C,
42 backoff: backoff::ExponentialBackoff,
43 ) -> Self {
44 Self {
45 http_client,
46 config,
47 backoff,
48 }
49 }
50
51 pub fn with_config(config: C) -> Self {
53 Self {
54 http_client: reqwest::Client::new(),
55 config,
56 backoff: Default::default(),
57 }
58 }
59
60 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
64 self.http_client = http_client;
65 self
66 }
67
68 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
70 self.backoff = backoff;
71 self
72 }
73
74 pub fn models(&self) -> Models<'_, C> {
78 Models::new(self)
79 }
80
81 pub fn completions(&self) -> Completions<'_, C> {
83 Completions::new(self)
84 }
85
86 pub fn chat(&self) -> Chat<'_, C> {
88 Chat::new(self)
89 }
90
91 pub fn images(&self) -> Images<'_, C> {
93 Images::new(self)
94 }
95
96 pub fn moderations(&self) -> Moderations<'_, C> {
98 Moderations::new(self)
99 }
100
101 pub fn files(&self) -> Files<'_, C> {
103 Files::new(self)
104 }
105
106 pub fn uploads(&self) -> Uploads<'_, C> {
108 Uploads::new(self)
109 }
110
111 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
113 FineTuning::new(self)
114 }
115
116 pub fn embeddings(&self) -> Embeddings<'_, C> {
118 Embeddings::new(self)
119 }
120
121 pub fn audio(&self) -> Audio<'_, C> {
123 Audio::new(self)
124 }
125
126 pub fn videos(&self) -> Videos<'_, C> {
128 Videos::new(self)
129 }
130
131 pub fn assistants(&self) -> Assistants<'_, C> {
133 Assistants::new(self)
134 }
135
136 pub fn threads(&self) -> Threads<'_, C> {
138 Threads::new(self)
139 }
140
141 pub fn vector_stores(&self) -> VectorStores<'_, C> {
143 VectorStores::new(self)
144 }
145
146 pub fn batches(&self) -> Batches<'_, C> {
148 Batches::new(self)
149 }
150
151 pub fn audit_logs(&self) -> AuditLogs<'_, C> {
153 AuditLogs::new(self)
154 }
155
156 pub fn invites(&self) -> Invites<'_, C> {
158 Invites::new(self)
159 }
160
161 pub fn users(&self) -> Users<'_, C> {
163 Users::new(self)
164 }
165
166 pub fn projects(&self) -> Projects<'_, C> {
168 Projects::new(self)
169 }
170
171 pub fn responses(&self) -> Responses<'_, C> {
173 Responses::new(self)
174 }
175
176 pub fn conversations(&self) -> Conversations<'_, C> {
178 Conversations::new(self)
179 }
180
181 pub fn containers(&self) -> Containers<'_, C> {
183 Containers::new(self)
184 }
185
186 pub fn evals(&self) -> Evals<'_, C> {
188 Evals::new(self)
189 }
190
191 pub fn config(&self) -> &C {
192 &self.config
193 }
194
195 pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
197 where
198 O: DeserializeOwned,
199 {
200 let request_maker = || async {
201 Ok(self
202 .http_client
203 .get(self.config.url(path))
204 .query(&self.config.query())
205 .headers(self.config.headers())
206 .build()?)
207 };
208
209 self.execute(request_maker).await
210 }
211
212 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
214 where
215 O: DeserializeOwned,
216 Q: Serialize + ?Sized,
217 {
218 let request_maker = || async {
219 Ok(self
220 .http_client
221 .get(self.config.url(path))
222 .query(&self.config.query())
223 .query(query)
224 .headers(self.config.headers())
225 .build()?)
226 };
227
228 self.execute(request_maker).await
229 }
230
231 pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
233 where
234 O: DeserializeOwned,
235 {
236 let request_maker = || async {
237 Ok(self
238 .http_client
239 .delete(self.config.url(path))
240 .query(&self.config.query())
241 .headers(self.config.headers())
242 .build()?)
243 };
244
245 self.execute(request_maker).await
246 }
247
248 pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
250 let request_maker = || async {
251 Ok(self
252 .http_client
253 .get(self.config.url(path))
254 .query(&self.config.query())
255 .headers(self.config.headers())
256 .build()?)
257 };
258
259 self.execute_raw(request_maker).await
260 }
261
262 pub(crate) async fn get_raw_with_query<Q>(
263 &self,
264 path: &str,
265 query: &Q,
266 ) -> Result<Bytes, OpenAIError>
267 where
268 Q: Serialize + ?Sized,
269 {
270 let request_maker = || async {
271 Ok(self
272 .http_client
273 .get(self.config.url(path))
274 .query(&self.config.query())
275 .query(query)
276 .headers(self.config.headers())
277 .build()?)
278 };
279
280 self.execute_raw(request_maker).await
281 }
282
283 pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
285 where
286 I: Serialize,
287 {
288 let request_maker = || async {
289 Ok(self
290 .http_client
291 .post(self.config.url(path))
292 .query(&self.config.query())
293 .headers(self.config.headers())
294 .json(&request)
295 .build()?)
296 };
297
298 self.execute_raw(request_maker).await
299 }
300
301 pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
303 where
304 I: Serialize,
305 O: DeserializeOwned,
306 {
307 let request_maker = || async {
308 Ok(self
309 .http_client
310 .post(self.config.url(path))
311 .query(&self.config.query())
312 .headers(self.config.headers())
313 .json(&request)
314 .build()?)
315 };
316
317 self.execute(request_maker).await
318 }
319
320 pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
322 where
323 Form: AsyncTryFrom<F, Error = OpenAIError>,
324 F: Clone,
325 {
326 let request_maker = || async {
327 Ok(self
328 .http_client
329 .post(self.config.url(path))
330 .query(&self.config.query())
331 .headers(self.config.headers())
332 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
333 .build()?)
334 };
335
336 self.execute_raw(request_maker).await
337 }
338
339 pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
341 where
342 O: DeserializeOwned,
343 Form: AsyncTryFrom<F, Error = OpenAIError>,
344 F: Clone,
345 {
346 let request_maker = || async {
347 Ok(self
348 .http_client
349 .post(self.config.url(path))
350 .query(&self.config.query())
351 .headers(self.config.headers())
352 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
353 .build()?)
354 };
355
356 self.execute(request_maker).await
357 }
358
359 pub(crate) async fn post_form_stream<O, F>(
360 &self,
361 path: &str,
362 form: F,
363 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
364 where
365 F: Clone,
366 Form: AsyncTryFrom<F, Error = OpenAIError>,
367 O: DeserializeOwned + std::marker::Send + 'static,
368 {
369 let response = self
372 .http_client
373 .post(self.config.url(path))
374 .query(&self.config.query())
375 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
376 .headers(self.config.headers())
377 .send()
378 .await
379 .map_err(OpenAIError::Reqwest)?;
380
381 if !response.status().is_success() {
383 return Err(read_response(response).await.unwrap_err());
384 }
385
386 let stream = response
388 .bytes_stream()
389 .map(|result| result.map_err(std::io::Error::other));
390 let event_stream = eventsource_stream::EventStream::new(stream);
391
392 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
394
395 tokio::spawn(async move {
396 use futures::StreamExt;
397 let mut event_stream = std::pin::pin!(event_stream);
398
399 while let Some(event_result) = event_stream.next().await {
400 match event_result {
401 Err(e) => {
402 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
403 StreamError::EventStream(e.to_string()),
404 )))) {
405 break;
406 }
407 }
408 Ok(event) => {
409 if event.data == "[DONE]" {
411 break;
412 }
413
414 let response = match serde_json::from_str::<O>(&event.data) {
415 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
416 Ok(output) => Ok(output),
417 };
418
419 if let Err(_e) = tx.send(response) {
420 break;
421 }
422 }
423 }
424 }
425 });
426
427 Ok(Box::pin(
428 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
429 ))
430 }
431
432 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
438 where
439 M: Fn() -> Fut,
440 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
441 {
442 let client = self.http_client.clone();
443
444 backoff::future::retry(self.backoff.clone(), || async {
445 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
446 let response = client
447 .execute(request)
448 .await
449 .map_err(OpenAIError::Reqwest)
450 .map_err(backoff::Error::Permanent)?;
451
452 let status = response.status();
453
454 match read_response(response).await {
455 Ok(bytes) => Ok(bytes),
456 Err(e) => {
457 match e {
458 OpenAIError::ApiError(api_error) => {
459 if status.is_server_error() {
460 Err(backoff::Error::Transient {
461 err: OpenAIError::ApiError(api_error),
462 retry_after: None,
463 })
464 } else if status.as_u16() == 429
465 && api_error.r#type != Some("insufficient_quota".to_string())
466 {
467 tracing::warn!("Rate limited: {}", api_error.message);
469 Err(backoff::Error::Transient {
470 err: OpenAIError::ApiError(api_error),
471 retry_after: None,
472 })
473 } else {
474 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
475 }
476 }
477 _ => Err(backoff::Error::Permanent(e)),
478 }
479 }
480 }
481 })
482 .await
483 }
484
485 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
491 where
492 O: DeserializeOwned,
493 M: Fn() -> Fut,
494 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
495 {
496 let bytes = self.execute_raw(request_maker).await?;
497
498 let response: O = serde_json::from_slice(bytes.as_ref())
499 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
500
501 Ok(response)
502 }
503
504 pub(crate) async fn post_stream<I, O>(
506 &self,
507 path: &str,
508 request: I,
509 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
510 where
511 I: Serialize,
512 O: DeserializeOwned + std::marker::Send + 'static,
513 {
514 let event_source = self
515 .http_client
516 .post(self.config.url(path))
517 .query(&self.config.query())
518 .headers(self.config.headers())
519 .json(&request)
520 .eventsource()
521 .unwrap();
522
523 stream(event_source).await
524 }
525
526 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
527 &self,
528 path: &str,
529 request: I,
530 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
531 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
532 where
533 I: Serialize,
534 O: DeserializeOwned + std::marker::Send + 'static,
535 {
536 let event_source = self
537 .http_client
538 .post(self.config.url(path))
539 .query(&self.config.query())
540 .headers(self.config.headers())
541 .json(&request)
542 .eventsource()
543 .unwrap();
544
545 stream_mapped_raw_events(event_source, event_mapper).await
546 }
547
548 pub(crate) async fn _get_stream<Q, O>(
550 &self,
551 path: &str,
552 query: &Q,
553 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
554 where
555 Q: Serialize + ?Sized,
556 O: DeserializeOwned + std::marker::Send + 'static,
557 {
558 let event_source = self
559 .http_client
560 .get(self.config.url(path))
561 .query(query)
562 .query(&self.config.query())
563 .headers(self.config.headers())
564 .eventsource()
565 .unwrap();
566
567 stream(event_source).await
568 }
569}
570
571async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
572 let status = response.status();
573 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
574
575 if status.is_server_error() {
576 let message: String = String::from_utf8_lossy(&bytes).into_owned();
578 tracing::warn!("Server error: {status} - {message}");
579 return Err(OpenAIError::ApiError(ApiError {
580 message,
581 r#type: None,
582 param: None,
583 code: None,
584 }));
585 }
586
587 if !status.is_success() {
589 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
590 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
591
592 return Err(OpenAIError::ApiError(wrapped_error.error));
593 }
594
595 Ok(bytes)
596}
597
598async fn map_stream_error(value: EventSourceError) -> OpenAIError {
599 match value {
600 EventSourceError::InvalidStatusCode(status_code, response) => {
601 read_response(response).await.expect_err(&format!(
602 "Unreachable because read_response returns err when status_code {status_code} is invalid"
603 ))
604 }
605 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
606 }
607}
608
609pub(crate) async fn stream<O>(
612 mut event_source: EventSource,
613) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
614where
615 O: DeserializeOwned + std::marker::Send + 'static,
616{
617 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
618
619 tokio::spawn(async move {
620 while let Some(ev) = event_source.next().await {
621 match ev {
622 Err(e) => {
623 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
624 break;
626 }
627 }
628 Ok(event) => match event {
629 Event::Message(message) => {
630 if message.data == "[DONE]" {
631 break;
632 }
633
634 let response = match serde_json::from_str::<O>(&message.data) {
635 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
636 Ok(output) => Ok(output),
637 };
638
639 if let Err(_e) = tx.send(response) {
640 break;
642 }
643 }
644 Event::Open => continue,
645 },
646 }
647 }
648
649 event_source.close();
650 });
651
652 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
653}
654
655pub(crate) async fn stream_mapped_raw_events<O>(
656 mut event_source: EventSource,
657 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
658) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
659where
660 O: DeserializeOwned + std::marker::Send + 'static,
661{
662 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
663
664 tokio::spawn(async move {
665 while let Some(ev) = event_source.next().await {
666 match ev {
667 Err(e) => {
668 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
669 break;
671 }
672 }
673 Ok(event) => match event {
674 Event::Message(message) => {
675 let mut done = false;
676
677 if message.data == "[DONE]" {
678 done = true;
679 }
680
681 let response = event_mapper(message);
682
683 if let Err(_e) = tx.send(response) {
684 break;
686 }
687
688 if done {
689 break;
690 }
691 }
692 Event::Open => continue,
693 },
694 }
695 }
696
697 event_source.close();
698 });
699
700 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
701}