1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8#[cfg(not(target_family = "wasm"))]
9use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Serialize};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::error::StreamError;
14use crate::{
15 config::{Config, OpenAIConfig},
16 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
17 traits::AsyncTryFrom,
18 RequestOptions,
19};
20
21#[cfg(feature = "administration")]
22use crate::admin::Admin;
23#[cfg(feature = "chatkit")]
24use crate::chatkit::Chatkit;
25#[cfg(feature = "file")]
26use crate::file::Files;
27#[cfg(feature = "image")]
28use crate::image::Images;
29#[cfg(feature = "moderation")]
30use crate::moderation::Moderations;
31#[cfg(feature = "assistant")]
32use crate::Assistants;
33#[cfg(feature = "audio")]
34use crate::Audio;
35#[cfg(feature = "batch")]
36use crate::Batches;
37#[cfg(feature = "chat-completion")]
38use crate::Chat;
39#[cfg(feature = "completions")]
40use crate::Completions;
41#[cfg(feature = "container")]
42use crate::Containers;
43#[cfg(feature = "responses")]
44use crate::Conversations;
45#[cfg(feature = "embedding")]
46use crate::Embeddings;
47#[cfg(feature = "evals")]
48use crate::Evals;
49#[cfg(feature = "finetuning")]
50use crate::FineTuning;
51#[cfg(feature = "model")]
52use crate::Models;
53#[cfg(feature = "realtime")]
54use crate::Realtime;
55#[cfg(feature = "responses")]
56use crate::Responses;
57#[cfg(feature = "assistant")]
58use crate::Threads;
59#[cfg(feature = "upload")]
60use crate::Uploads;
61#[cfg(feature = "vectorstore")]
62use crate::VectorStores;
63#[cfg(feature = "video")]
64use crate::Videos;
65
66#[derive(Debug, Clone)]
67pub struct Client<C: Config> {
70 http_client: reqwest::Client,
71 config: C,
72 #[cfg(not(target_family = "wasm"))]
73 backoff: backoff::ExponentialBackoff,
74}
75
76impl<C: Config> Default for Client<C>
77where
78 C: Default,
79{
80 fn default() -> Self {
81 Self {
82 http_client: reqwest::Client::new(),
83 config: C::default(),
84 #[cfg(not(target_family = "wasm"))]
85 backoff: Default::default(),
86 }
87 }
88}
89
90impl Client<OpenAIConfig> {
91 pub fn new() -> Self {
93 Self::default()
94 }
95}
96
97impl<C: Config> Client<C> {
98 #[cfg(not(target_family = "wasm"))]
100 pub fn build(
101 http_client: reqwest::Client,
102 config: C,
103 backoff: backoff::ExponentialBackoff,
104 ) -> Self {
105 Self {
106 http_client,
107 config,
108 backoff,
109 }
110 }
111
112 #[cfg(target_family = "wasm")]
114 pub fn build(http_client: reqwest::Client, config: C) -> Self {
115 Self {
116 http_client,
117 config,
118 }
119 }
120
121 pub fn with_config(config: C) -> Self {
123 Self {
124 http_client: reqwest::Client::new(),
125 config,
126 #[cfg(not(target_family = "wasm"))]
127 backoff: Default::default(),
128 }
129 }
130
131 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
135 self.http_client = http_client;
136 self
137 }
138
139 #[cfg(not(target_family = "wasm"))]
141 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
142 self.backoff = backoff;
143 self
144 }
145
146 #[cfg(feature = "model")]
150 pub fn models(&self) -> Models<'_, C> {
151 Models::new(self)
152 }
153
154 #[cfg(feature = "completions")]
156 pub fn completions(&self) -> Completions<'_, C> {
157 Completions::new(self)
158 }
159
160 #[cfg(feature = "chat-completion")]
162 pub fn chat(&self) -> Chat<'_, C> {
163 Chat::new(self)
164 }
165
166 #[cfg(feature = "image")]
168 pub fn images(&self) -> Images<'_, C> {
169 Images::new(self)
170 }
171
172 #[cfg(feature = "moderation")]
174 pub fn moderations(&self) -> Moderations<'_, C> {
175 Moderations::new(self)
176 }
177
178 #[cfg(feature = "file")]
180 pub fn files(&self) -> Files<'_, C> {
181 Files::new(self)
182 }
183
184 #[cfg(feature = "upload")]
186 pub fn uploads(&self) -> Uploads<'_, C> {
187 Uploads::new(self)
188 }
189
190 #[cfg(feature = "finetuning")]
192 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
193 FineTuning::new(self)
194 }
195
196 #[cfg(feature = "embedding")]
198 pub fn embeddings(&self) -> Embeddings<'_, C> {
199 Embeddings::new(self)
200 }
201
202 #[cfg(feature = "audio")]
204 pub fn audio(&self) -> Audio<'_, C> {
205 Audio::new(self)
206 }
207
208 #[cfg(feature = "video")]
210 pub fn videos(&self) -> Videos<'_, C> {
211 Videos::new(self)
212 }
213
214 #[cfg(feature = "assistant")]
216 pub fn assistants(&self) -> Assistants<'_, C> {
217 Assistants::new(self)
218 }
219
220 #[cfg(feature = "assistant")]
222 pub fn threads(&self) -> Threads<'_, C> {
223 Threads::new(self)
224 }
225
226 #[cfg(feature = "vectorstore")]
228 pub fn vector_stores(&self) -> VectorStores<'_, C> {
229 VectorStores::new(self)
230 }
231
232 #[cfg(feature = "batch")]
234 pub fn batches(&self) -> Batches<'_, C> {
235 Batches::new(self)
236 }
237
238 #[cfg(feature = "administration")]
241 pub fn admin(&self) -> Admin<'_, C> {
242 Admin::new(self)
243 }
244
245 #[cfg(feature = "responses")]
247 pub fn responses(&self) -> Responses<'_, C> {
248 Responses::new(self)
249 }
250
251 #[cfg(feature = "responses")]
253 pub fn conversations(&self) -> Conversations<'_, C> {
254 Conversations::new(self)
255 }
256
257 #[cfg(feature = "container")]
259 pub fn containers(&self) -> Containers<'_, C> {
260 Containers::new(self)
261 }
262
263 #[cfg(feature = "evals")]
265 pub fn evals(&self) -> Evals<'_, C> {
266 Evals::new(self)
267 }
268
269 #[cfg(feature = "chatkit")]
270 pub fn chatkit(&self) -> Chatkit<'_, C> {
271 Chatkit::new(self)
272 }
273
274 #[cfg(feature = "realtime")]
276 pub fn realtime(&self) -> Realtime<'_, C> {
277 Realtime::new(self)
278 }
279
280 pub fn config(&self) -> &C {
281 &self.config
282 }
283
284 fn build_request_builder(
286 &self,
287 method: reqwest::Method,
288 path: &str,
289 request_options: &RequestOptions,
290 ) -> reqwest::RequestBuilder {
291 let mut request_builder = if let Some(path) = request_options.path() {
292 self.http_client
293 .request(method, self.config.url(path.as_str()))
294 } else {
295 self.http_client.request(method, self.config.url(path))
296 };
297
298 request_builder = request_builder
299 .query(&self.config.query())
300 .headers(self.config.headers());
301
302 if let Some(headers) = request_options.headers() {
303 request_builder = request_builder.headers(headers.clone());
304 }
305
306 if !request_options.query().is_empty() {
307 request_builder = request_builder.query(request_options.query());
308 }
309
310 request_builder
311 }
312
313 #[allow(unused)]
315 pub(crate) async fn get<O>(
316 &self,
317 path: &str,
318 request_options: &RequestOptions,
319 ) -> Result<O, OpenAIError>
320 where
321 O: DeserializeOwned,
322 {
323 let request_maker = || async {
324 Ok(self
325 .build_request_builder(reqwest::Method::GET, path, request_options)
326 .build()?)
327 };
328
329 self.execute(request_maker).await
330 }
331
332 #[allow(unused)]
334 pub(crate) async fn delete<O>(
335 &self,
336 path: &str,
337 request_options: &RequestOptions,
338 ) -> Result<O, OpenAIError>
339 where
340 O: DeserializeOwned,
341 {
342 let request_maker = || async {
343 Ok(self
344 .build_request_builder(reqwest::Method::DELETE, path, request_options)
345 .build()?)
346 };
347
348 self.execute(request_maker).await
349 }
350
351 #[allow(unused)]
353 pub(crate) async fn get_raw(
354 &self,
355 path: &str,
356 request_options: &RequestOptions,
357 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
358 let request_maker = || async {
359 Ok(self
360 .build_request_builder(reqwest::Method::GET, path, request_options)
361 .build()?)
362 };
363
364 self.execute_raw(request_maker).await
365 }
366
367 #[allow(unused)]
369 pub(crate) async fn post_raw<I>(
370 &self,
371 path: &str,
372 request: I,
373 request_options: &RequestOptions,
374 ) -> Result<(Bytes, HeaderMap), OpenAIError>
375 where
376 I: Serialize,
377 {
378 let request_maker = || async {
379 Ok(self
380 .build_request_builder(reqwest::Method::POST, path, request_options)
381 .json(&request)
382 .build()?)
383 };
384
385 self.execute_raw(request_maker).await
386 }
387
388 #[allow(unused)]
390 pub(crate) async fn post<I, O>(
391 &self,
392 path: &str,
393 request: I,
394 request_options: &RequestOptions,
395 ) -> Result<O, OpenAIError>
396 where
397 I: Serialize,
398 O: DeserializeOwned,
399 {
400 let request_maker = || async {
401 Ok(self
402 .build_request_builder(reqwest::Method::POST, path, request_options)
403 .json(&request)
404 .build()?)
405 };
406
407 self.execute(request_maker).await
408 }
409
410 #[allow(unused)]
412 pub(crate) async fn post_form_raw<F>(
413 &self,
414 path: &str,
415 form: F,
416 request_options: &RequestOptions,
417 ) -> Result<(Bytes, HeaderMap), OpenAIError>
418 where
419 Form: AsyncTryFrom<F, Error = OpenAIError>,
420 F: Clone,
421 {
422 let request_maker = || async {
423 Ok(self
424 .build_request_builder(reqwest::Method::POST, path, request_options)
425 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
426 .build()?)
427 };
428
429 self.execute_raw(request_maker).await
430 }
431
432 #[allow(unused)]
434 pub(crate) async fn post_form<O, F>(
435 &self,
436 path: &str,
437 form: F,
438 request_options: &RequestOptions,
439 ) -> Result<O, OpenAIError>
440 where
441 O: DeserializeOwned,
442 Form: AsyncTryFrom<F, Error = OpenAIError>,
443 F: Clone,
444 {
445 let request_maker = || async {
446 Ok(self
447 .build_request_builder(reqwest::Method::POST, path, request_options)
448 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
449 .build()?)
450 };
451
452 self.execute(request_maker).await
453 }
454
455 #[allow(unused)]
456 #[cfg(not(target_family = "wasm"))]
457 pub(crate) async fn post_form_stream<O, F>(
458 &self,
459 path: &str,
460 form: F,
461 request_options: &RequestOptions,
462 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
463 where
464 F: Clone,
465 Form: AsyncTryFrom<F, Error = OpenAIError>,
466 O: DeserializeOwned + std::marker::Send + 'static,
467 {
468 let request_builder = self
471 .build_request_builder(reqwest::Method::POST, path, request_options)
472 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
473
474 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
475
476 if !response.status().is_success() {
478 return Err(read_response(response).await.unwrap_err());
479 }
480
481 let stream = response
483 .bytes_stream()
484 .map(|result| result.map_err(std::io::Error::other));
485 let event_stream = eventsource_stream::EventStream::new(stream);
486
487 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
489
490 tokio::spawn(async move {
491 use futures::StreamExt;
492 let mut event_stream = std::pin::pin!(event_stream);
493
494 while let Some(event_result) = event_stream.next().await {
495 match event_result {
496 Err(e) => {
497 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
498 StreamError::EventStream(e.to_string()),
499 )))) {
500 break;
501 }
502 }
503 Ok(event) => {
504 if event.data == "[DONE]" {
506 break;
507 }
508
509 let response = match serde_json::from_str::<O>(&event.data) {
510 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
511 Ok(output) => Ok(output),
512 };
513
514 if let Err(_e) = tx.send(response) {
515 break;
516 }
517 }
518 }
519 }
520 });
521
522 Ok(Box::pin(
523 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
524 ))
525 }
526
527 #[cfg(not(target_family = "wasm"))]
533 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
534 where
535 M: Fn() -> Fut,
536 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
537 {
538 let client = self.http_client.clone();
539
540 backoff::future::retry(self.backoff.clone(), || async {
541 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
542 let response = client
543 .execute(request)
544 .await
545 .map_err(OpenAIError::Reqwest)
546 .map_err(backoff::Error::Permanent)?;
547
548 let status = response.status();
549
550 match read_response(response).await {
551 Ok((bytes, headers)) => Ok((bytes, headers)),
552 Err(e) => {
553 match e {
554 OpenAIError::ApiError(api_error) => {
555 if status.is_server_error() {
556 Err(backoff::Error::Transient {
557 err: OpenAIError::ApiError(api_error),
558 retry_after: None,
559 })
560 } else if status.as_u16() == 429
561 && api_error.r#type != Some("insufficient_quota".to_string())
562 {
563 tracing::warn!("Rate limited: {}", api_error.message);
565 Err(backoff::Error::Transient {
566 err: OpenAIError::ApiError(api_error),
567 retry_after: None,
568 })
569 } else {
570 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
571 }
572 }
573 _ => Err(backoff::Error::Permanent(e)),
574 }
575 }
576 }
577 })
578 .await
579 }
580
581 #[cfg(target_family = "wasm")]
583 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
584 where
585 M: Fn() -> Fut,
586 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
587 {
588 let request = request_maker().await?;
589 let response = self
590 .http_client
591 .execute(request)
592 .await
593 .map_err(OpenAIError::Reqwest)?;
594
595 read_response(response).await
596 }
597
598 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
604 where
605 O: DeserializeOwned,
606 M: Fn() -> Fut,
607 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
608 {
609 let (bytes, _headers) = self.execute_raw(request_maker).await?;
610
611 let response: O = serde_json::from_slice(bytes.as_ref())
612 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
613
614 Ok(response)
615 }
616
617 #[allow(unused)]
619 #[cfg(not(target_family = "wasm"))]
620 pub(crate) async fn post_stream<I, O>(
621 &self,
622 path: &str,
623 request: I,
624 request_options: &RequestOptions,
625 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
626 where
627 I: Serialize,
628 O: DeserializeOwned + std::marker::Send + 'static,
629 {
630 let request_builder = self
631 .build_request_builder(reqwest::Method::POST, path, request_options)
632 .json(&request);
633
634 let event_source = request_builder.eventsource().unwrap();
635
636 stream(event_source).await
637 }
638
639 #[allow(unused)]
640 #[cfg(not(target_family = "wasm"))]
641 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
642 &self,
643 path: &str,
644 request: I,
645 request_options: &RequestOptions,
646 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
647 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
648 where
649 I: Serialize,
650 O: DeserializeOwned + std::marker::Send + 'static,
651 {
652 let request_builder = self
653 .build_request_builder(reqwest::Method::POST, path, request_options)
654 .json(&request);
655
656 let event_source = request_builder.eventsource().unwrap();
657
658 stream_mapped_raw_events(event_source, event_mapper).await
659 }
660
661 #[allow(unused)]
663 #[cfg(not(target_family = "wasm"))]
664 pub(crate) async fn get_stream<O>(
665 &self,
666 path: &str,
667 request_options: &RequestOptions,
668 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
669 where
670 O: DeserializeOwned + std::marker::Send + 'static,
671 {
672 let request_builder =
673 self.build_request_builder(reqwest::Method::GET, path, request_options);
674
675 let event_source = request_builder.eventsource().unwrap();
676
677 stream(event_source).await
678 }
679}
680
681async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
682 let status = response.status();
683 let headers = response.headers().clone();
684 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
685
686 if status.is_server_error() {
687 let message: String = String::from_utf8_lossy(&bytes).into_owned();
689 tracing::warn!("Server error: {status} - {message}");
690 return Err(OpenAIError::ApiError(ApiError {
691 message,
692 r#type: None,
693 param: None,
694 code: None,
695 }));
696 }
697
698 if !status.is_success() {
700 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
701 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
702
703 return Err(OpenAIError::ApiError(wrapped_error.error));
704 }
705
706 Ok((bytes, headers))
707}
708
709#[cfg(not(target_family = "wasm"))]
710async fn map_stream_error(value: EventSourceError) -> OpenAIError {
711 match value {
712 EventSourceError::InvalidStatusCode(status_code, response) => {
713 read_response(response).await.expect_err(&format!(
714 "Unreachable because read_response returns err when status_code {status_code} is invalid"
715 ))
716 }
717 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
718 }
719}
720
721#[cfg(not(target_family = "wasm"))]
724pub(crate) async fn stream<O>(
725 mut event_source: EventSource,
726) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
727where
728 O: DeserializeOwned + std::marker::Send + 'static,
729{
730 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
731
732 tokio::spawn(async move {
733 while let Some(ev) = event_source.next().await {
734 match ev {
735 Err(e) => {
736 match &e {
739 EventSourceError::StreamEnded => {
740 break;
741 }
742 _ => {
743 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
744 break;
746 }
747 }
748 }
749 }
750 Ok(event) => match event {
751 Event::Message(message) => {
752 if message.data == "[DONE]" {
753 break;
754 }
755
756 let response = match serde_json::from_str::<O>(&message.data) {
757 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
758 Ok(output) => Ok(output),
759 };
760
761 if let Err(_e) = tx.send(response) {
762 break;
764 }
765 }
766 Event::Open => continue,
767 },
768 }
769 }
770
771 event_source.close();
772 });
773
774 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
775}
776
777#[cfg(not(target_family = "wasm"))]
778pub(crate) async fn stream_mapped_raw_events<O>(
779 mut event_source: EventSource,
780 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
781) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
782where
783 O: DeserializeOwned + std::marker::Send + 'static,
784{
785 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
786
787 tokio::spawn(async move {
788 while let Some(ev) = event_source.next().await {
789 match ev {
790 Err(e) => {
791 match &e {
794 EventSourceError::StreamEnded => {
795 break;
796 }
797 _ => {
798 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
799 break;
801 }
802 }
803 }
804 }
805 Ok(event) => match event {
806 Event::Message(message) => {
807 let mut done = false;
808
809 if message.data == "[DONE]" {
810 done = true;
811 }
812
813 let response = event_mapper(message);
814
815 if let Err(_e) = tx.send(response) {
816 break;
818 }
819
820 if done {
821 break;
822 }
823 }
824 Event::Open => continue,
825 },
826 }
827 }
828
829 event_source.close();
830 });
831
832 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
833}