1use std::pin::Pin;
2
3use bytes::Bytes;
4use futures::{stream::StreamExt, Stream};
5use reqwest::{header::HeaderMap, multipart::Form, Response};
6use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::{
10 config::{Config, OpenAIConfig},
11 error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
12 traits::AsyncTryFrom,
13 RequestOptions,
14};
15
16#[cfg(feature = "administration")]
17use crate::admin::Admin;
18#[cfg(feature = "chatkit")]
19use crate::chatkit::Chatkit;
20#[cfg(feature = "file")]
21use crate::file::Files;
22#[cfg(feature = "image")]
23use crate::image::Images;
24#[cfg(feature = "moderation")]
25use crate::moderation::Moderations;
26#[cfg(feature = "assistant")]
27use crate::Assistants;
28#[cfg(feature = "audio")]
29use crate::Audio;
30#[cfg(feature = "batch")]
31use crate::Batches;
32#[cfg(feature = "chat-completion")]
33use crate::Chat;
34#[cfg(feature = "completions")]
35use crate::Completions;
36#[cfg(feature = "container")]
37use crate::Containers;
38#[cfg(feature = "responses")]
39use crate::Conversations;
40#[cfg(feature = "embedding")]
41use crate::Embeddings;
42#[cfg(feature = "evals")]
43use crate::Evals;
44#[cfg(feature = "finetuning")]
45use crate::FineTuning;
46#[cfg(feature = "model")]
47use crate::Models;
48#[cfg(feature = "realtime")]
49use crate::Realtime;
50#[cfg(feature = "responses")]
51use crate::Responses;
52#[cfg(feature = "assistant")]
53use crate::Threads;
54#[cfg(feature = "upload")]
55use crate::Uploads;
56#[cfg(feature = "vectorstore")]
57use crate::VectorStores;
58#[cfg(feature = "video")]
59use crate::Videos;
60
61#[derive(Debug, Clone, Default)]
62pub struct Client<C: Config> {
65 http_client: reqwest::Client,
66 config: C,
67 backoff: backoff::ExponentialBackoff,
68}
69
70impl Client<OpenAIConfig> {
71 pub fn new() -> Self {
73 Self::default()
74 }
75}
76
77impl<C: Config> Client<C> {
78 pub fn build(
80 http_client: reqwest::Client,
81 config: C,
82 backoff: backoff::ExponentialBackoff,
83 ) -> Self {
84 Self {
85 http_client,
86 config,
87 backoff,
88 }
89 }
90
91 pub fn with_config(config: C) -> Self {
93 Self {
94 http_client: reqwest::Client::new(),
95 config,
96 backoff: Default::default(),
97 }
98 }
99
100 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
104 self.http_client = http_client;
105 self
106 }
107
108 pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
110 self.backoff = backoff;
111 self
112 }
113
114 #[cfg(feature = "model")]
118 pub fn models(&self) -> Models<'_, C> {
119 Models::new(self)
120 }
121
122 #[cfg(feature = "completions")]
124 pub fn completions(&self) -> Completions<'_, C> {
125 Completions::new(self)
126 }
127
128 #[cfg(feature = "chat-completion")]
130 pub fn chat(&self) -> Chat<'_, C> {
131 Chat::new(self)
132 }
133
134 #[cfg(feature = "image")]
136 pub fn images(&self) -> Images<'_, C> {
137 Images::new(self)
138 }
139
140 #[cfg(feature = "moderation")]
142 pub fn moderations(&self) -> Moderations<'_, C> {
143 Moderations::new(self)
144 }
145
146 #[cfg(feature = "file")]
148 pub fn files(&self) -> Files<'_, C> {
149 Files::new(self)
150 }
151
152 #[cfg(feature = "upload")]
154 pub fn uploads(&self) -> Uploads<'_, C> {
155 Uploads::new(self)
156 }
157
158 #[cfg(feature = "finetuning")]
160 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
161 FineTuning::new(self)
162 }
163
164 #[cfg(feature = "embedding")]
166 pub fn embeddings(&self) -> Embeddings<'_, C> {
167 Embeddings::new(self)
168 }
169
170 #[cfg(feature = "audio")]
172 pub fn audio(&self) -> Audio<'_, C> {
173 Audio::new(self)
174 }
175
176 #[cfg(feature = "video")]
178 pub fn videos(&self) -> Videos<'_, C> {
179 Videos::new(self)
180 }
181
182 #[cfg(feature = "assistant")]
184 pub fn assistants(&self) -> Assistants<'_, C> {
185 Assistants::new(self)
186 }
187
188 #[cfg(feature = "assistant")]
190 pub fn threads(&self) -> Threads<'_, C> {
191 Threads::new(self)
192 }
193
194 #[cfg(feature = "vectorstore")]
196 pub fn vector_stores(&self) -> VectorStores<'_, C> {
197 VectorStores::new(self)
198 }
199
200 #[cfg(feature = "batch")]
202 pub fn batches(&self) -> Batches<'_, C> {
203 Batches::new(self)
204 }
205
206 #[cfg(feature = "administration")]
209 pub fn admin(&self) -> Admin<'_, C> {
210 Admin::new(self)
211 }
212
213 #[cfg(feature = "responses")]
215 pub fn responses(&self) -> Responses<'_, C> {
216 Responses::new(self)
217 }
218
219 #[cfg(feature = "responses")]
221 pub fn conversations(&self) -> Conversations<'_, C> {
222 Conversations::new(self)
223 }
224
225 #[cfg(feature = "container")]
227 pub fn containers(&self) -> Containers<'_, C> {
228 Containers::new(self)
229 }
230
231 #[cfg(feature = "evals")]
233 pub fn evals(&self) -> Evals<'_, C> {
234 Evals::new(self)
235 }
236
237 #[cfg(feature = "chatkit")]
238 pub fn chatkit(&self) -> Chatkit<'_, C> {
239 Chatkit::new(self)
240 }
241
242 #[cfg(feature = "realtime")]
244 pub fn realtime(&self) -> Realtime<'_, C> {
245 Realtime::new(self)
246 }
247
248 pub fn config(&self) -> &C {
249 &self.config
250 }
251
252 fn build_request_builder(
254 &self,
255 method: reqwest::Method,
256 path: &str,
257 request_options: &RequestOptions,
258 ) -> reqwest::RequestBuilder {
259 let mut request_builder = if let Some(path) = request_options.path() {
260 self.http_client
261 .request(method, self.config.url(path.as_str()))
262 } else {
263 self.http_client.request(method, self.config.url(path))
264 };
265
266 request_builder = request_builder
267 .query(&self.config.query())
268 .headers(self.config.headers());
269
270 if let Some(headers) = request_options.headers() {
271 request_builder = request_builder.headers(headers.clone());
272 }
273
274 if !request_options.query().is_empty() {
275 request_builder = request_builder.query(request_options.query());
276 }
277
278 request_builder
279 }
280
281 #[allow(unused)]
283 pub(crate) async fn get<O>(
284 &self,
285 path: &str,
286 request_options: &RequestOptions,
287 ) -> Result<O, OpenAIError>
288 where
289 O: DeserializeOwned,
290 {
291 let request_maker = || async {
292 Ok(self
293 .build_request_builder(reqwest::Method::GET, path, request_options)
294 .build()?)
295 };
296
297 self.execute(request_maker).await
298 }
299
300 #[allow(unused)]
302 pub(crate) async fn delete<O>(
303 &self,
304 path: &str,
305 request_options: &RequestOptions,
306 ) -> Result<O, OpenAIError>
307 where
308 O: DeserializeOwned,
309 {
310 let request_maker = || async {
311 Ok(self
312 .build_request_builder(reqwest::Method::DELETE, path, request_options)
313 .build()?)
314 };
315
316 self.execute(request_maker).await
317 }
318
319 #[allow(unused)]
321 pub(crate) async fn get_raw(
322 &self,
323 path: &str,
324 request_options: &RequestOptions,
325 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
326 let request_maker = || async {
327 Ok(self
328 .build_request_builder(reqwest::Method::GET, path, request_options)
329 .build()?)
330 };
331
332 self.execute_raw(request_maker).await
333 }
334
335 #[allow(unused)]
337 pub(crate) async fn post_raw<I>(
338 &self,
339 path: &str,
340 request: I,
341 request_options: &RequestOptions,
342 ) -> Result<(Bytes, HeaderMap), OpenAIError>
343 where
344 I: Serialize,
345 {
346 let request_maker = || async {
347 Ok(self
348 .build_request_builder(reqwest::Method::POST, path, request_options)
349 .json(&request)
350 .build()?)
351 };
352
353 self.execute_raw(request_maker).await
354 }
355
356 #[allow(unused)]
358 pub(crate) async fn post<I, O>(
359 &self,
360 path: &str,
361 request: I,
362 request_options: &RequestOptions,
363 ) -> Result<O, OpenAIError>
364 where
365 I: Serialize,
366 O: DeserializeOwned,
367 {
368 let request_maker = || async {
369 Ok(self
370 .build_request_builder(reqwest::Method::POST, path, request_options)
371 .json(&request)
372 .build()?)
373 };
374
375 self.execute(request_maker).await
376 }
377
378 #[allow(unused)]
380 pub(crate) async fn post_form_raw<F>(
381 &self,
382 path: &str,
383 form: F,
384 request_options: &RequestOptions,
385 ) -> Result<(Bytes, HeaderMap), OpenAIError>
386 where
387 Form: AsyncTryFrom<F, Error = OpenAIError>,
388 F: Clone,
389 {
390 let request_maker = || async {
391 Ok(self
392 .build_request_builder(reqwest::Method::POST, path, request_options)
393 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
394 .build()?)
395 };
396
397 self.execute_raw(request_maker).await
398 }
399
400 #[allow(unused)]
402 pub(crate) async fn post_form<O, F>(
403 &self,
404 path: &str,
405 form: F,
406 request_options: &RequestOptions,
407 ) -> Result<O, OpenAIError>
408 where
409 O: DeserializeOwned,
410 Form: AsyncTryFrom<F, Error = OpenAIError>,
411 F: Clone,
412 {
413 let request_maker = || async {
414 Ok(self
415 .build_request_builder(reqwest::Method::POST, path, request_options)
416 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
417 .build()?)
418 };
419
420 self.execute(request_maker).await
421 }
422
423 #[allow(unused)]
424 pub(crate) async fn post_form_stream<O, F>(
425 &self,
426 path: &str,
427 form: F,
428 request_options: &RequestOptions,
429 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
430 where
431 F: Clone,
432 Form: AsyncTryFrom<F, Error = OpenAIError>,
433 O: DeserializeOwned + std::marker::Send + 'static,
434 {
435 let request_builder = self
438 .build_request_builder(reqwest::Method::POST, path, request_options)
439 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
440
441 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
442
443 if !response.status().is_success() {
445 return Err(read_response(response).await.unwrap_err());
446 }
447
448 let stream = response
450 .bytes_stream()
451 .map(|result| result.map_err(std::io::Error::other));
452 let event_stream = eventsource_stream::EventStream::new(stream);
453
454 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
456
457 tokio::spawn(async move {
458 use futures::StreamExt;
459 let mut event_stream = std::pin::pin!(event_stream);
460
461 while let Some(event_result) = event_stream.next().await {
462 match event_result {
463 Err(e) => {
464 if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
465 StreamError::EventStream(e.to_string()),
466 )))) {
467 break;
468 }
469 }
470 Ok(event) => {
471 if event.data == "[DONE]" {
473 break;
474 }
475
476 let response = match serde_json::from_str::<O>(&event.data) {
477 Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
478 Ok(output) => Ok(output),
479 };
480
481 if let Err(_e) = tx.send(response) {
482 break;
483 }
484 }
485 }
486 }
487 });
488
489 Ok(Box::pin(
490 tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
491 ))
492 }
493
494 async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<(Bytes, HeaderMap), OpenAIError>
500 where
501 M: Fn() -> Fut,
502 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
503 {
504 let client = self.http_client.clone();
505
506 backoff::future::retry(self.backoff.clone(), || async {
507 let request = request_maker().await.map_err(backoff::Error::Permanent)?;
508 let response = client
509 .execute(request)
510 .await
511 .map_err(OpenAIError::Reqwest)
512 .map_err(backoff::Error::Permanent)?;
513
514 let status = response.status();
515
516 match read_response(response).await {
517 Ok((bytes, headers)) => Ok((bytes, headers)),
518 Err(e) => {
519 match e {
520 OpenAIError::ApiError(api_error) => {
521 if status.is_server_error() {
522 Err(backoff::Error::Transient {
523 err: OpenAIError::ApiError(api_error),
524 retry_after: None,
525 })
526 } else if status.as_u16() == 429
527 && api_error.r#type != Some("insufficient_quota".to_string())
528 {
529 tracing::warn!("Rate limited: {}", api_error.message);
531 Err(backoff::Error::Transient {
532 err: OpenAIError::ApiError(api_error),
533 retry_after: None,
534 })
535 } else {
536 Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
537 }
538 }
539 _ => Err(backoff::Error::Permanent(e)),
540 }
541 }
542 }
543 })
544 .await
545 }
546
547 async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
553 where
554 O: DeserializeOwned,
555 M: Fn() -> Fut,
556 Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
557 {
558 let (bytes, _headers) = self.execute_raw(request_maker).await?;
559
560 let response: O = serde_json::from_slice(bytes.as_ref())
561 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
562
563 Ok(response)
564 }
565
566 #[allow(unused)]
568 pub(crate) async fn post_stream<I, O>(
569 &self,
570 path: &str,
571 request: I,
572 request_options: &RequestOptions,
573 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
574 where
575 I: Serialize,
576 O: DeserializeOwned + std::marker::Send + 'static,
577 {
578 let request_builder = self
579 .build_request_builder(reqwest::Method::POST, path, request_options)
580 .json(&request);
581
582 let event_source = request_builder.eventsource().unwrap();
583
584 stream(event_source).await
585 }
586
587 #[allow(unused)]
588 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
589 &self,
590 path: &str,
591 request: I,
592 request_options: &RequestOptions,
593 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
594 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
595 where
596 I: Serialize,
597 O: DeserializeOwned + std::marker::Send + 'static,
598 {
599 let request_builder = self
600 .build_request_builder(reqwest::Method::POST, path, request_options)
601 .json(&request);
602
603 let event_source = request_builder.eventsource().unwrap();
604
605 stream_mapped_raw_events(event_source, event_mapper).await
606 }
607
608 #[allow(unused)]
610 pub(crate) async fn get_stream<O>(
611 &self,
612 path: &str,
613 request_options: &RequestOptions,
614 ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
615 where
616 O: DeserializeOwned + std::marker::Send + 'static,
617 {
618 let request_builder =
619 self.build_request_builder(reqwest::Method::GET, path, request_options);
620
621 let event_source = request_builder.eventsource().unwrap();
622
623 stream(event_source).await
624 }
625}
626
627async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
628 let status = response.status();
629 let headers = response.headers().clone();
630 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
631
632 if status.is_server_error() {
633 let message: String = String::from_utf8_lossy(&bytes).into_owned();
635 tracing::warn!("Server error: {status} - {message}");
636 return Err(OpenAIError::ApiError(ApiError {
637 message,
638 r#type: None,
639 param: None,
640 code: None,
641 }));
642 }
643
644 if !status.is_success() {
646 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
647 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
648
649 return Err(OpenAIError::ApiError(wrapped_error.error));
650 }
651
652 Ok((bytes, headers))
653}
654
655async fn map_stream_error(value: EventSourceError) -> OpenAIError {
656 match value {
657 EventSourceError::InvalidStatusCode(status_code, response) => {
658 read_response(response).await.expect_err(&format!(
659 "Unreachable because read_response returns err when status_code {status_code} is invalid"
660 ))
661 }
662 _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
663 }
664}
665
666pub(crate) async fn stream<O>(
669 mut event_source: EventSource,
670) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
671where
672 O: DeserializeOwned + std::marker::Send + 'static,
673{
674 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
675
676 tokio::spawn(async move {
677 while let Some(ev) = event_source.next().await {
678 match ev {
679 Err(e) => {
680 match &e {
683 EventSourceError::StreamEnded => {
684 break;
685 }
686 _ => {
687 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
688 break;
690 }
691 }
692 }
693 }
694 Ok(event) => match event {
695 Event::Message(message) => {
696 if message.data == "[DONE]" {
697 break;
698 }
699
700 let response = match serde_json::from_str::<O>(&message.data) {
701 Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
702 Ok(output) => Ok(output),
703 };
704
705 if let Err(_e) = tx.send(response) {
706 break;
708 }
709 }
710 Event::Open => continue,
711 },
712 }
713 }
714
715 event_source.close();
716 });
717
718 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
719}
720
721pub(crate) async fn stream_mapped_raw_events<O>(
722 mut event_source: EventSource,
723 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
724) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
725where
726 O: DeserializeOwned + std::marker::Send + 'static,
727{
728 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
729
730 tokio::spawn(async move {
731 while let Some(ev) = event_source.next().await {
732 match ev {
733 Err(e) => {
734 match &e {
737 EventSourceError::StreamEnded => {
738 break;
739 }
740 _ => {
741 if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
742 break;
744 }
745 }
746 }
747 }
748 Ok(event) => match event {
749 Event::Message(message) => {
750 let mut done = false;
751
752 if message.data == "[DONE]" {
753 done = true;
754 }
755
756 let response = event_mapper(message);
757
758 if let Err(_e) = tx.send(response) {
759 break;
761 }
762
763 if done {
764 break;
765 }
766 }
767 Event::Open => continue,
768 },
769 }
770 }
771
772 event_source.close();
773 });
774
775 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
776}