1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::{Stream, stream::StreamExt};
7use reqwest::Response;
8use reqwest::header::HeaderMap;
9use reqwest::multipart::Form;
10#[cfg(not(target_family = "wasm"))]
11use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14
15#[cfg(feature = "assistant")]
16use crate::Assistants;
17#[cfg(feature = "audio")]
18use crate::Audio;
19#[cfg(feature = "batch")]
20use crate::Batches;
21#[cfg(feature = "chat-completion")]
22use crate::Chat;
23#[cfg(feature = "completions")]
24use crate::Completions;
25#[cfg(feature = "container")]
26use crate::Containers;
27#[cfg(feature = "responses")]
28use crate::Conversations;
29#[cfg(feature = "embedding")]
30use crate::Embeddings;
31#[cfg(feature = "evals")]
32use crate::Evals;
33#[cfg(feature = "finetuning")]
34use crate::FineTuning;
35#[cfg(feature = "model")]
36use crate::Models;
37#[cfg(feature = "realtime")]
38use crate::Realtime;
39use crate::RequestOptions;
40#[cfg(feature = "responses")]
41use crate::Responses;
42#[cfg(feature = "assistant")]
43use crate::Threads;
44#[cfg(feature = "upload")]
45use crate::Uploads;
46#[cfg(feature = "vectorstore")]
47use crate::VectorStores;
48#[cfg(feature = "video")]
49use crate::Videos;
50#[cfg(feature = "administration")]
51use crate::admin::Admin;
52#[cfg(feature = "chatkit")]
53use crate::chatkit::Chatkit;
54use crate::config::{Config, OpenAIConfig};
55#[cfg(not(target_family = "wasm"))]
56use crate::error::StreamError;
57use crate::error::{ApiError, OpenAIError, WrappedError, map_deserialization_error};
58#[cfg(feature = "file")]
59use crate::file::Files;
60#[cfg(feature = "image")]
61use crate::image::Images;
62#[cfg(feature = "moderation")]
63use crate::moderation::Moderations;
64use crate::traits::AsyncTryFrom;
65
/// Client that issues HTTP requests described by the configuration `C`.
///
/// Bundles the `reqwest` HTTP client used to send requests, the configuration that
/// builds request URLs, query parameters and headers, and — on non-wasm targets — the
/// exponential backoff policy used when retrying failed requests.
#[derive(Debug, Clone, Default)]
pub struct Client<C: Config> {
    /// HTTP client that sends every request.
    http_client: reqwest::Client,
    /// Configuration supplying URLs (`url`), query parameters (`query`) and headers (`headers`).
    config: C,
    /// Retry policy used by the non-wasm request executor; absent on wasm targets.
    #[cfg(not(target_family = "wasm"))]
    backoff: backoff::ExponentialBackoff,
}
75
76impl Client<OpenAIConfig> {
77 pub fn new() -> Self {
79 Self::default()
80 }
81}
82
83impl<C: Config> Client<C> {
84 #[cfg(not(target_family = "wasm"))]
86 pub fn build(
87 http_client: reqwest::Client,
88 config: C,
89 backoff: backoff::ExponentialBackoff,
90 ) -> Self {
91 Self {
92 http_client,
93 config,
94 backoff,
95 }
96 }
97
98 #[cfg(target_family = "wasm")]
100 pub fn build(http_client: reqwest::Client, config: C) -> Self {
101 Self {
102 http_client,
103 config,
104 }
105 }
106
107 pub fn with_config(config: C) -> Self {
109 Self {
110 http_client: reqwest::Client::new(),
111 config,
112 #[cfg(not(target_family = "wasm"))]
113 backoff: Default::default(),
114 }
115 }
116
117 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
121 self.http_client = http_client;
122 self
123 }
124
125 #[cfg(not(target_family = "wasm"))]
127 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
128 self.backoff = backoff;
129 self
130 }
131
    #[cfg(feature = "model")]
    /// Returns a [`Models`] handle that borrows this client (requires the `model` feature).
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    #[cfg(feature = "completions")]
    /// Returns a [`Completions`] handle that borrows this client (requires the
    /// `completions` feature).
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    #[cfg(feature = "chat-completion")]
    /// Returns a [`Chat`] handle that borrows this client (requires the `chat-completion`
    /// feature).
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    #[cfg(feature = "image")]
    /// Returns an [`Images`] handle that borrows this client (requires the `image` feature).
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    #[cfg(feature = "moderation")]
    /// Returns a [`Moderations`] handle that borrows this client (requires the `moderation`
    /// feature).
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    #[cfg(feature = "file")]
    /// Returns a [`Files`] handle that borrows this client (requires the `file` feature).
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    #[cfg(feature = "upload")]
    /// Returns an [`Uploads`] handle that borrows this client (requires the `upload` feature).
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    #[cfg(feature = "finetuning")]
    /// Returns a [`FineTuning`] handle that borrows this client (requires the `finetuning`
    /// feature).
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    #[cfg(feature = "embedding")]
    /// Returns an [`Embeddings`] handle that borrows this client (requires the `embedding`
    /// feature).
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    #[cfg(feature = "audio")]
    /// Returns an [`Audio`] handle that borrows this client (requires the `audio` feature).
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    #[cfg(feature = "video")]
    /// Returns a [`Videos`] handle that borrows this client (requires the `video` feature).
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    #[cfg(feature = "assistant")]
    /// Returns an [`Assistants`] handle that borrows this client (requires the `assistant`
    /// feature).
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    #[cfg(feature = "assistant")]
    /// Returns a [`Threads`] handle that borrows this client (requires the `assistant`
    /// feature).
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    #[cfg(feature = "vectorstore")]
    /// Returns a [`VectorStores`] handle that borrows this client (requires the
    /// `vectorstore` feature).
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    #[cfg(feature = "batch")]
    /// Returns a [`Batches`] handle that borrows this client (requires the `batch` feature).
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    #[cfg(feature = "administration")]
    /// Returns an [`Admin`] handle that borrows this client (requires the `administration`
    /// feature).
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    #[cfg(feature = "responses")]
    /// Returns a [`Responses`] handle that borrows this client (requires the `responses`
    /// feature).
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    #[cfg(feature = "responses")]
    /// Returns a [`Conversations`] handle that borrows this client (gated on the
    /// `responses` feature).
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    #[cfg(feature = "container")]
    /// Returns a [`Containers`] handle that borrows this client (requires the `container`
    /// feature).
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    #[cfg(feature = "evals")]
    /// Returns an [`Evals`] handle that borrows this client (requires the `evals` feature).
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    #[cfg(feature = "chatkit")]
    /// Returns a [`Chatkit`] handle that borrows this client (requires the `chatkit` feature).
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    #[cfg(feature = "realtime")]
    /// Returns a [`Realtime`] handle that borrows this client (requires the `realtime`
    /// feature).
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }

    /// Returns a shared reference to this client's configuration.
    pub fn config(&self) -> &C {
        &self.config
    }
268
269 fn build_request_builder(
271 &self,
272 method: reqwest::Method,
273 path: &str,
274 request_options: &RequestOptions,
275 ) -> reqwest::RequestBuilder {
276 let mut request_builder = if let Some(path) = request_options.path() {
277 self.http_client
278 .request(method, self.config.url(path.as_str()))
279 } else {
280 self.http_client.request(method, self.config.url(path))
281 };
282
283 request_builder = request_builder
284 .query(&self.config.query())
285 .headers(self.config.headers());
286
287 if let Some(headers) = request_options.headers() {
288 request_builder = request_builder.headers(headers.clone());
289 }
290
291 if !request_options.query().is_empty() {
292 request_builder = request_builder.query(request_options.query());
293 }
294
295 request_builder
296 }
297
298 #[allow(unused)]
300 pub(crate) async fn get<O>(
301 &self,
302 path: &str,
303 request_options: &RequestOptions,
304 ) -> Result<O, OpenAIError>
305 where
306 O: DeserializeOwned,
307 {
308 let request_maker = || async {
309 Ok(self
310 .build_request_builder(reqwest::Method::GET, path, request_options)
311 .build()?)
312 };
313
314 self.execute(request_maker).await
315 }
316
317 #[allow(dead_code)]
319 pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
320 where
321 O: DeserializeOwned,
322 Q: Serialize + ?Sized,
323 {
324 let request_maker = || async {
325 Ok(self
326 .http_client
327 .get(self.config.url(path))
328 .query(&self.config.query())
329 .query(query)
330 .headers(self.config.headers())
331 .build()?)
332 };
333
334 self.execute(request_maker).await
335 }
336
337 #[allow(unused)]
339 pub(crate) async fn delete<O>(
340 &self,
341 path: &str,
342 request_options: &RequestOptions,
343 ) -> Result<O, OpenAIError>
344 where
345 O: DeserializeOwned,
346 {
347 let request_maker = || async {
348 Ok(self
349 .build_request_builder(reqwest::Method::DELETE, path, request_options)
350 .build()?)
351 };
352
353 self.execute(request_maker).await
354 }
355
356 #[allow(unused)]
358 pub(crate) async fn get_raw(
359 &self,
360 path: &str,
361 request_options: &RequestOptions,
362 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
363 let request_maker = || async {
364 Ok(self
365 .build_request_builder(reqwest::Method::GET, path, request_options)
366 .build()?)
367 };
368
369 self.execute_raw(request_maker).await
370 }
371
372 #[allow(unused)]
374 pub(crate) async fn post_raw<I>(
375 &self,
376 path: &str,
377 request: I,
378 request_options: &RequestOptions,
379 ) -> Result<(Bytes, HeaderMap), OpenAIError>
380 where
381 I: Serialize,
382 {
383 let request_maker = || async {
384 Ok(self
385 .build_request_builder(reqwest::Method::POST, path, request_options)
386 .json(&request)
387 .build()?)
388 };
389
390 self.execute_raw(request_maker).await
391 }
392
393 #[allow(unused)]
395 pub(crate) async fn post<I, O>(
396 &self,
397 path: &str,
398 request: I,
399 request_options: &RequestOptions,
400 ) -> Result<O, OpenAIError>
401 where
402 I: Serialize,
403 O: DeserializeOwned,
404 {
405 let request_maker = || async {
406 Ok(self
407 .build_request_builder(reqwest::Method::POST, path, request_options)
408 .json(&request)
409 .build()?)
410 };
411
412 self.execute(request_maker).await
413 }
414
415 #[allow(unused)]
417 pub(crate) async fn post_form_raw<F>(
418 &self,
419 path: &str,
420 form: F,
421 request_options: &RequestOptions,
422 ) -> Result<(Bytes, HeaderMap), OpenAIError>
423 where
424 Form: AsyncTryFrom<F, Error = OpenAIError>,
425 F: Clone,
426 {
427 let request_maker = || async {
428 Ok(self
429 .build_request_builder(reqwest::Method::POST, path, request_options)
430 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
431 .build()?)
432 };
433
434 self.execute_raw(request_maker).await
435 }
436
437 #[allow(unused)]
439 pub(crate) async fn post_form<O, F>(
440 &self,
441 path: &str,
442 form: F,
443 request_options: &RequestOptions,
444 ) -> Result<O, OpenAIError>
445 where
446 O: DeserializeOwned,
447 Form: AsyncTryFrom<F, Error = OpenAIError>,
448 F: Clone,
449 {
450 let request_maker = || async {
451 Ok(self
452 .build_request_builder(reqwest::Method::POST, path, request_options)
453 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
454 .build()?)
455 };
456
457 self.execute(request_maker).await
458 }
459
460 #[allow(unused)]
461 #[cfg(not(target_family = "wasm"))]
462 pub(crate) async fn post_form_stream<O, F>(
463 &self,
464 path: &str,
465 form: F,
466 request_options: &RequestOptions,
467 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
468 where
469 F: Clone,
470 Form: AsyncTryFrom<F, Error = OpenAIError>,
471 O: DeserializeOwned + std::marker::Send + 'static,
472 {
473 let request_builder = self
476 .build_request_builder(reqwest::Method::POST, path, request_options)
477 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
478
479 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
480
481 if !response.status().is_success() {
483 return Err(read_response(response).await.unwrap_err());
484 }
485
486 let stream = response
488 .bytes_stream()
489 .map(|result| result.map_err(std::io::Error::other));
490 let event_stream = eventsource_stream::EventStream::new(stream);
491
492 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
494
495 tokio::spawn(async move {
496 use futures::StreamExt;
497 let mut event_stream = std::pin::pin!(event_stream);
498
499 while let Some(event_result) = event_stream.next().await {
500 match event_result {
501 Err(e) => {
502 if let Err(_e) = tx.send(Err(OpenAIError::Stream(
503 StreamError::EventStream(e.to_string()).to_string(),
504 ))) {
505 break;
506 }
507 }
508 Ok(event) => {
509 if event.data == "[DONE]" {
511 break;
512 }
513
514 let response = match serde_json::from_str::<O>(&event.data) {
515 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
516 Ok(output) => Ok(output),
517 };
518
519 if let Err(_e) = tx.send(response) {
520 break;
521 }
522 }
523 }
524 }
525 });
526
527 Ok(Box::pin(
528 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
529 ))
530 }
531
532 #[cfg(not(target_family = "wasm"))]
538 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
539 where
540 M: Fn() -> Fut,
541 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
542 {
543 let client = self.http_client.clone();
544
545 backoff::future::retry(self.backoff.clone(), || async {
546 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
547 let response = client
548 .execute(request)
549 .await
550 .map_err(OpenAIError::Reqwest)
551 .map_err(backoff::Error::Permanent)?;
552
553 let status = response.status();
554
555 match read_response(response).await {
556 Ok((bytes, headers)) => Ok((bytes, headers)),
557 Err(e) => {
558 match e {
559 OpenAIError::ApiError(api_error) => {
560 if status.is_server_error() {
561 Err(backoff::Error::Transient {
562 err: OpenAIError::ApiError(api_error),
563 retry_after: None,
564 })
565 } else if status.as_u16() == 429
566 && api_error.kind != Some("insufficient_quota".to_string())
567 {
568 tracing::warn!("Rate limited: {}", api_error.message);
570 Err(backoff::Error::Transient {
571 err: OpenAIError::ApiError(api_error),
572 retry_after: None,
573 })
574 } else {
575 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
576 }
577 }
578 _ => Err(backoff::Error::Permanent(e)),
579 }
580 }
581 }
582 })
583 .await
584 }
585
586 #[cfg(target_family = "wasm")]
588 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
589 where
590 M: Fn() -> Fut,
591 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
592 {
593 let request = request_maker().await?;
594 let response = self
595 .http_client
596 .execute(request)
597 .await
598 .map_err(OpenAIError::Reqwest)?;
599
600 read_response(response).await
601 }
602
603 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
609 where
610 O: DeserializeOwned,
611 M: Fn() -> Fut,
612 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
613 {
614 let (bytes, _headers) = self.execute_raw(request_maker).await?;
615
616 let response: O = serde_json::from_slice(bytes.as_ref())
617 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
618
619 Ok(response)
620 }
621
622 #[allow(unused)]
624 #[cfg(not(target_family = "wasm"))]
625 pub(crate) async fn post_stream<I, O>(
626 &self,
627 path: &str,
628 request: I,
629 request_options: &RequestOptions,
630 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
631 where
632 I: Serialize,
633 O: DeserializeOwned + std::marker::Send + 'static,
634 {
635 let request_builder = self
636 .build_request_builder(reqwest::Method::POST, path, request_options)
637 .json(&request);
638
639 let event_source = request_builder.eventsource().unwrap();
640
641 stream(event_source).await
642 }
643
644 #[allow(unused)]
645 #[cfg(not(target_family = "wasm"))]
646 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
647 &self,
648 path: &str,
649 request: I,
650 request_options: &RequestOptions,
651 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
652 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
653 where
654 I: Serialize,
655 O: DeserializeOwned + std::marker::Send + 'static,
656 {
657 let request_builder = self
658 .build_request_builder(reqwest::Method::POST, path, request_options)
659 .json(&request);
660
661 let event_source = request_builder.eventsource().unwrap();
662
663 stream_mapped_raw_events(event_source, event_mapper).await
664 }
665
666 #[allow(unused)]
668 #[cfg(not(target_family = "wasm"))]
669 pub(crate) async fn get_stream<O>(
670 &self,
671 path: &str,
672 request_options: &RequestOptions,
673 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
674 where
675 O: DeserializeOwned + std::marker::Send + 'static,
676 {
677 let request_builder =
678 self.build_request_builder(reqwest::Method::GET, path, request_options);
679
680 let event_source = request_builder.eventsource().unwrap();
681
682 stream(event_source).await
683 }
684}
685
686async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
687 let status = response.status();
688 let headers = response.headers().clone();
689 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
690
691 if status.is_server_error() {
692 let message: String = String::from_utf8_lossy(&bytes).into_owned();
695 tracing::warn!("Server error: {status} - {message}");
696 return Err(OpenAIError::ApiError(ApiError {
697 message,
698 kind: None,
699 param: None,
700 code: None,
701 }));
702 }
703
704 if !status.is_success() {
706 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
707 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
708
709 return Err(OpenAIError::ApiError(wrapped_error.error));
710 }
711
712 Ok((bytes, headers))
713}
714
715#[cfg(not(target_family = "wasm"))]
716async fn map_stream_error(value: EventSourceError) -> OpenAIError {
717 match value {
718 EventSourceError::InvalidStatusCode(status_code, response) => {
719 read_response(response).await.expect_err(&format!(
720 "Unreachable because read_response returns err when status_code {status_code} is invalid"
721 ))
722 }
723 _ => OpenAIError::Stream(StreamError::ReqwestEventSource(value).to_string()),
724 }
725}
726
727#[cfg(not(target_family = "wasm"))]
730pub(crate) async fn stream<O>(
731 mut event_source: EventSource,
732) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
733where
734 O: DeserializeOwned + std::marker::Send + 'static,
735{
736 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
737
738 tokio::spawn(async move {
739 while let Some(ev) = event_source.next().await {
740 match ev {
741 Err(e) => {
742 match &e {
745 EventSourceError::StreamEnded => {
746 break;
747 }
748 _ => {
749 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
750 break;
752 }
753 }
754 }
755 }
756 Ok(event) => match event {
757 Event::Message(message) => {
758 if message.data == "[DONE]" {
759 break;
760 }
761
762 let response = match serde_json::from_str::<O>(&message.data) {
763 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
764 Ok(output) => Ok(output),
765 };
766
767 if let Err(_e) = tx.send(response) {
768 break;
770 }
771 }
772 Event::Open => continue,
773 },
774 }
775 }
776
777 event_source.close();
778 });
779
780 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
781}
782
783#[cfg(not(target_family = "wasm"))]
784pub(crate) async fn stream_mapped_raw_events<O>(
785 mut event_source: EventSource,
786 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
787) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
788where
789 O: DeserializeOwned + std::marker::Send + 'static,
790{
791 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
792
793 tokio::spawn(async move {
794 while let Some(ev) = event_source.next().await {
795 match ev {
796 Err(e) => {
797 match &e {
800 EventSourceError::StreamEnded => {
801 break;
802 }
803 _ => {
804 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
805 break;
807 }
808 }
809 }
810 }
811 Ok(event) => match event {
812 Event::Message(message) => {
813 let mut done = false;
814
815 if message.data == "[DONE]" {
816 done = true;
817 }
818
819 let response = event_mapper(message);
820
821 if let Err(_e) = tx.send(response) {
822 break;
824 }
825
826 if done {
827 break;
828 }
829 }
830 Event::Open => continue,
831 },
832 }
833 }
834
835 event_source.close();
836 });
837
838 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
839}