1use std::future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use eventsource_stream::EventStreamError;
8use future::Future;
9use futures::stream::Filter;
10use futures::{Stream, stream::StreamExt};
11use pin_project::pin_project;
12use reqwest::header::HeaderMap;
13use reqwest::{Response, multipart::Form};
14use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
15use serde::{Serialize, de::DeserializeOwned};
16
17use crate::error::{ApiError, StreamError};
18use crate::{
19 RequestOptions,
20 config::{Config, OpenAIConfig},
21 error::{OpenAIError, WrappedError, map_deserialization_error},
22 traits::AsyncTryFrom,
23};
24
25#[cfg(feature = "assistant")]
26use crate::Assistants;
27#[cfg(feature = "audio")]
28use crate::Audio;
29#[cfg(feature = "batch")]
30use crate::Batches;
31#[cfg(feature = "chat-completion")]
32use crate::Chat;
33#[cfg(feature = "completions")]
34use crate::Completions;
35#[cfg(feature = "container")]
36use crate::Containers;
37#[cfg(feature = "responses")]
38use crate::Conversations;
39#[cfg(feature = "embedding")]
40use crate::Embeddings;
41#[cfg(feature = "evals")]
42use crate::Evals;
43#[cfg(feature = "finetuning")]
44use crate::FineTuning;
45#[cfg(feature = "model")]
46use crate::Models;
47#[cfg(feature = "realtime")]
48use crate::Realtime;
49#[cfg(feature = "responses")]
50use crate::Responses;
51#[cfg(feature = "assistant")]
52use crate::Threads;
53#[cfg(feature = "upload")]
54use crate::Uploads;
55#[cfg(feature = "vectorstore")]
56use crate::VectorStores;
57#[cfg(feature = "video")]
58use crate::Videos;
59#[cfg(feature = "administration")]
60use crate::admin::Admin;
61#[cfg(feature = "chatkit")]
62use crate::chatkit::Chatkit;
63#[cfg(feature = "file")]
64use crate::file::Files;
65#[cfg(feature = "image")]
66use crate::image::Images;
67#[cfg(feature = "moderation")]
68use crate::moderation::Moderations;
69
/// Client used to call the API.
///
/// Holds the [`reqwest::Client`] that sends requests and the configuration `C` that
/// supplies the URL, default headers and default query parameters for every request.
#[derive(Debug, Clone)]
pub struct Client<C: Config> {
    /// HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Configuration providing base URL, headers and query parameters.
    config: C,
}
77
78impl Client<OpenAIConfig> {
79 pub fn new() -> Self {
81 Self {
82 http_client: reqwest::Client::new(),
83 config: OpenAIConfig::default(),
84 }
85 }
86}
87
88impl<C: Config> Client<C> {
89 pub fn build(http_client: reqwest::Client, config: C) -> Self {
91 Self {
92 http_client,
93 config,
94 }
95 }
96
97 pub fn with_config(config: C) -> Self {
99 Self {
100 http_client: reqwest::Client::new(),
101 config,
102 }
103 }
104
105 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
109 self.http_client = http_client;
110 self
111 }
112
113 #[cfg(feature = "model")]
117 pub fn models(&self) -> Models<'_, C> {
118 Models::new(self)
119 }
120
121 #[cfg(feature = "completions")]
123 pub fn completions(&self) -> Completions<'_, C> {
124 Completions::new(self)
125 }
126
127 #[cfg(feature = "chat-completion")]
129 pub fn chat(&self) -> Chat<'_, C> {
130 Chat::new(self)
131 }
132
133 #[cfg(feature = "image")]
135 pub fn images(&self) -> Images<'_, C> {
136 Images::new(self)
137 }
138
139 #[cfg(feature = "moderation")]
141 pub fn moderations(&self) -> Moderations<'_, C> {
142 Moderations::new(self)
143 }
144
145 #[cfg(feature = "file")]
147 pub fn files(&self) -> Files<'_, C> {
148 Files::new(self)
149 }
150
151 #[cfg(feature = "upload")]
153 pub fn uploads(&self) -> Uploads<'_, C> {
154 Uploads::new(self)
155 }
156
157 #[cfg(feature = "finetuning")]
159 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
160 FineTuning::new(self)
161 }
162
163 #[cfg(feature = "embedding")]
165 pub fn embeddings(&self) -> Embeddings<'_, C> {
166 Embeddings::new(self)
167 }
168
169 #[cfg(feature = "audio")]
171 pub fn audio(&self) -> Audio<'_, C> {
172 Audio::new(self)
173 }
174
175 #[cfg(feature = "video")]
177 pub fn videos(&self) -> Videos<'_, C> {
178 Videos::new(self)
179 }
180
181 #[cfg(feature = "assistant")]
183 pub fn assistants(&self) -> Assistants<'_, C> {
184 Assistants::new(self)
185 }
186
187 #[cfg(feature = "assistant")]
189 pub fn threads(&self) -> Threads<'_, C> {
190 Threads::new(self)
191 }
192
193 #[cfg(feature = "vectorstore")]
195 pub fn vector_stores(&self) -> VectorStores<'_, C> {
196 VectorStores::new(self)
197 }
198
199 #[cfg(feature = "batch")]
201 pub fn batches(&self) -> Batches<'_, C> {
202 Batches::new(self)
203 }
204
205 #[cfg(feature = "administration")]
208 pub fn admin(&self) -> Admin<'_, C> {
209 Admin::new(self)
210 }
211
212 #[cfg(feature = "responses")]
214 pub fn responses(&self) -> Responses<'_, C> {
215 Responses::new(self)
216 }
217
218 #[cfg(feature = "responses")]
220 pub fn conversations(&self) -> Conversations<'_, C> {
221 Conversations::new(self)
222 }
223
224 #[cfg(feature = "container")]
226 pub fn containers(&self) -> Containers<'_, C> {
227 Containers::new(self)
228 }
229
230 #[cfg(feature = "evals")]
232 pub fn evals(&self) -> Evals<'_, C> {
233 Evals::new(self)
234 }
235
236 #[cfg(feature = "chatkit")]
237 pub fn chatkit(&self) -> Chatkit<'_, C> {
238 Chatkit::new(self)
239 }
240
241 #[cfg(feature = "realtime")]
243 pub fn realtime(&self) -> Realtime<'_, C> {
244 Realtime::new(self)
245 }
246
247 pub fn config(&self) -> &C {
248 &self.config
249 }
250
251 fn build_request_builder(
253 &self,
254 method: reqwest::Method,
255 path: &str,
256 request_options: &RequestOptions,
257 ) -> reqwest::RequestBuilder {
258 let mut request_builder = if let Some(path) = request_options.path() {
259 self.http_client
260 .request(method, self.config.url(path.as_str()))
261 } else {
262 self.http_client.request(method, self.config.url(path))
263 };
264
265 request_builder = request_builder
266 .query(&self.config.query())
267 .headers(self.config.headers());
268
269 if let Some(headers) = request_options.headers() {
270 request_builder = request_builder.headers(headers.clone());
271 }
272
273 if !request_options.query().is_empty() {
274 request_builder = request_builder.query(request_options.query());
275 }
276
277 request_builder
278 }
279
280 #[allow(unused)]
282 pub(crate) async fn get<O>(
283 &self,
284 path: &str,
285 request_options: &RequestOptions,
286 ) -> Result<O, OpenAIError>
287 where
288 O: DeserializeOwned,
289 {
290 self.execute(async {
291 Ok(self
292 .build_request_builder(reqwest::Method::GET, path, request_options)
293 .build()?)
294 })
295 .await
296 }
297
298 #[allow(unused)]
300 pub(crate) async fn delete<O>(
301 &self,
302 path: &str,
303 request_options: &RequestOptions,
304 ) -> Result<O, OpenAIError>
305 where
306 O: DeserializeOwned,
307 {
308 self.execute(async {
309 Ok(self
310 .build_request_builder(reqwest::Method::DELETE, path, request_options)
311 .build()?)
312 })
313 .await
314 }
315
316 #[allow(unused)]
318 pub(crate) async fn get_raw(
319 &self,
320 path: &str,
321 request_options: &RequestOptions,
322 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
323 self.execute_raw(async {
324 Ok(self
325 .build_request_builder(reqwest::Method::GET, path, request_options)
326 .build()?)
327 })
328 .await
329 }
330
331 #[allow(unused)]
333 pub(crate) async fn post_raw<I>(
334 &self,
335 path: &str,
336 request: I,
337 request_options: &RequestOptions,
338 ) -> Result<(Bytes, HeaderMap), OpenAIError>
339 where
340 I: Serialize,
341 {
342 self.execute_raw(async {
343 Ok(self
344 .build_request_builder(reqwest::Method::POST, path, request_options)
345 .json(&request)
346 .build()?)
347 })
348 .await
349 }
350
351 #[allow(unused)]
353 pub(crate) async fn post<I, O>(
354 &self,
355 path: &str,
356 request: I,
357 request_options: &RequestOptions,
358 ) -> Result<O, OpenAIError>
359 where
360 I: Serialize,
361 O: DeserializeOwned,
362 {
363 self.execute(async {
364 Ok(self
365 .build_request_builder(reqwest::Method::POST, path, request_options)
366 .json(&request)
367 .build()?)
368 })
369 .await
370 }
371
372 #[allow(unused)]
374 pub(crate) async fn post_form_raw<F>(
375 &self,
376 path: &str,
377 form: F,
378 request_options: &RequestOptions,
379 ) -> Result<(Bytes, HeaderMap), OpenAIError>
380 where
381 Form: AsyncTryFrom<F, Error = OpenAIError>,
382 {
383 self.execute_raw(async {
384 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
385 Ok(self
386 .build_request_builder(reqwest::Method::POST, path, request_options)
387 .multipart(form)
388 .build()?)
389 })
390 .await
391 }
392
393 #[allow(unused)]
395 pub(crate) async fn post_form<O, F>(
396 &self,
397 path: &str,
398 form: F,
399 request_options: &RequestOptions,
400 ) -> Result<O, OpenAIError>
401 where
402 O: DeserializeOwned,
403 Form: AsyncTryFrom<F, Error = OpenAIError>,
404 {
405 self.execute(async {
406 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
407 Ok(self
408 .build_request_builder(reqwest::Method::POST, path, request_options)
409 .multipart(form)
410 .build()?)
411 })
412 .await
413 }
414
415 #[allow(unused)]
416 pub(crate) async fn post_form_stream<O, F>(
417 &self,
418 path: &str,
419 form: F,
420 request_options: &RequestOptions,
421 ) -> Result<OpenAIFormEventStream<O>, OpenAIError>
422 where
423 F: Clone,
424 Form: AsyncTryFrom<F, Error = OpenAIError>,
425 O: DeserializeOwned + Send + 'static,
426 {
427 let request_builder = self
430 .build_request_builder(reqwest::Method::POST, path, request_options)
431 .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
432
433 let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
434
435 if !response.status().is_success() {
437 return Err(read_response(response).await.unwrap_err());
438 }
439
440 let stream = response
442 .bytes_stream()
443 .map(|result| result.map_err(std::io::Error::other));
444 let event_stream = eventsource_stream::EventStream::new(stream);
445
446 Ok(OpenAIFormEventStream::new(event_stream))
447 }
448
449 async fn execute_raw(
451 &self,
452 request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
453 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
454 let client = self.http_client.clone();
455 let request = request_future.await?;
456 let response = client
457 .execute(request)
458 .await
459 .map_err(OpenAIError::Reqwest)?;
460
461 let status = response.status();
462 match read_response(response).await {
463 Ok((bytes, headers)) => Ok((bytes, headers)),
464 Err(e) => match e {
465 OpenAIError::ApiError(api_error) => {
466 if status.as_u16() == 429
467 && api_error.r#type != Some("insufficient_quota".to_string())
468 {
469 tracing::warn!("Rate limited: {}", api_error.message);
471 }
472 Err(OpenAIError::ApiError(api_error))
473 }
474 _ => Err(e),
475 },
476 }
477 }
478
479 async fn execute<O>(
481 &self,
482 request_future: impl Future<Output = Result<reqwest::Request, OpenAIError>>,
483 ) -> Result<O, OpenAIError>
484 where
485 O: DeserializeOwned,
486 {
487 let (bytes, _headers) = self.execute_raw(request_future).await?;
488
489 let response: O = serde_json::from_slice(bytes.as_ref())
490 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
491
492 Ok(response)
493 }
494
495 #[allow(unused)]
497 pub(crate) async fn post_stream<I, O>(
498 &self,
499 path: &str,
500 request: I,
501 request_options: &RequestOptions,
502 ) -> OpenAIEventStream<O>
503 where
504 I: Serialize,
505 O: DeserializeOwned + Send + 'static,
506 {
507 let request_builder = self
508 .build_request_builder(reqwest::Method::POST, path, request_options)
509 .json(&request);
510
511 let event_source = request_builder.eventsource().unwrap();
512
513 OpenAIEventStream::new(event_source)
514 }
515
516 #[allow(unused)]
517 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
518 &self,
519 path: &str,
520 request: I,
521 request_options: &RequestOptions,
522 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
523 ) -> OpenAIEventStream<O>
524 where
525 I: Serialize,
526 O: DeserializeOwned + Send + 'static,
527 {
528 let request_builder = self
529 .build_request_builder(reqwest::Method::POST, path, request_options)
530 .json(&request);
531
532 let event_source = request_builder.eventsource().unwrap();
533
534 OpenAIEventStream::with_event_mapping(event_source, event_mapper)
535 }
536
537 #[allow(unused)]
539 pub(crate) async fn get_stream<O>(
540 &self,
541 path: &str,
542 request_options: &RequestOptions,
543 ) -> OpenAIEventStream<O>
544 where
545 O: DeserializeOwned + Send + 'static,
546 {
547 let request_builder =
548 self.build_request_builder(reqwest::Method::GET, path, request_options);
549
550 let event_source = request_builder.eventsource().unwrap();
551
552 OpenAIEventStream::new(event_source)
553 }
554}
555
/// Stream of server-sent events from an [`EventSource`], converted into `O`.
///
/// Each SSE message becomes one `Result<O, OpenAIError>`. Messages are either deserialized
/// from JSON or passed through a custom event mapper. The stream ends on the `[DONE]`
/// sentinel, when the underlying source ends, or after an error.
#[pin_project]
pub struct OpenAIEventStream<O>
where
    O: DeserializeOwned + Send + 'static,
{
    /// Underlying SSE source; the constructors filter out `Event::Open` events.
    #[pin]
    stream: Filter<
        EventSource,
        future::Ready<bool>,
        fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>,
    >,
    /// Optional conversion from a raw SSE event to `O`; when `None`, the message data is
    /// deserialized as JSON.
    event_mapper:
        Option<Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>>,
    /// Set once the stream has terminated; later polls return `Ready(None)`.
    done: bool,
    _phantom_data: PhantomData<O>,
}
577
578impl<O> OpenAIEventStream<O>
579where
580 O: DeserializeOwned + Send + 'static,
581{
582 pub(crate) fn with_event_mapping<M>(event_source: EventSource, event_mapper: M) -> Self
583 where
584 M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
585 {
586 Self {
587 stream: event_source.filter(|result|
588 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
590 done: false,
591 event_mapper: Some(Box::new(event_mapper)),
592 _phantom_data: PhantomData,
593 }
594 }
595
596 pub(crate) fn new(event_source: EventSource) -> Self {
597 Self {
598 stream: event_source.filter(|result|
599 future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))),
601 done: false,
602 event_mapper: None,
603 _phantom_data: PhantomData,
604 }
605 }
606}
607
608impl<O> Stream for OpenAIEventStream<O>
609where
610 O: DeserializeOwned + Send + 'static,
611{
612 type Item = Result<O, OpenAIError>;
613
614 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
615 let this = self.project();
616 if *this.done {
617 return Poll::Ready(None);
618 }
619 let stream: Pin<&mut _> = this.stream;
620 match stream.poll_next(cx) {
621 Poll::Ready(response) => {
622 match response {
623 None => Poll::Ready(None), Some(result) => match result {
625 Ok(event) => match event {
626 Event::Open => unreachable!(), Event::Message(message) => {
628 if let Some(event_mapper) = this.event_mapper.as_ref() {
629 if message.data == "[DONE]" {
630 *this.done = true;
631 }
632 let response = event_mapper(message);
633 match response {
634 Ok(output) => Poll::Ready(Some(Ok(output))),
635 Err(_) => Poll::Ready(None),
636 }
637 } else {
638 if message.data == "[DONE]" {
639 *this.done = true;
640 Poll::Ready(None) } else {
642 match serde_json::from_str::<O>(&message.data) {
644 Err(e) => {
645 *this.done = true;
646 Poll::Ready(Some(Err(map_deserialization_error(
647 e,
648 message.data.as_bytes(),
649 ))))
650 }
651 Ok(output) => Poll::Ready(Some(Ok(output))),
652 }
653 }
654 }
655 }
656 },
657 Err(e) => {
658 *this.done = true;
659 Poll::Ready(Some(Err(OpenAIError::StreamError(Box::new(
660 StreamError::ReqwestEventSource(e),
661 )))))
662 }
663 },
664 }
665 }
666 Poll::Pending => Poll::Pending,
667 }
668 }
669}
670
/// Stream of server-sent events parsed from the response to a multipart form request,
/// with each message's data deserialized as JSON into `O`.
///
/// The stream ends when the `[DONE]` sentinel is received.
///
/// NOTE(review): the boxed inner stream has no `Send` bound, so this type is not `Send`.
#[pin_project]
pub struct OpenAIFormEventStream<O>
where
    O: DeserializeOwned + Send + 'static,
{
    /// Parsed SSE events from the response body.
    #[pin]
    event_stream: Box<
        dyn Stream<Item = Result<eventsource_stream::Event, EventStreamError<std::io::Error>>>
            + Unpin
            + 'static,
    >,
    /// Set once the stream has terminated; later polls return `Ready(None)`.
    done: bool,
    _phantom_data: PhantomData<O>,
}
685
686impl<O> OpenAIFormEventStream<O>
687where
688 O: DeserializeOwned + Send + 'static,
689{
690 pub fn new(
691 stream: impl Stream<Item = Result<eventsource_stream::Event, EventStreamError<std::io::Error>>>
692 + Unpin
693 + 'static,
694 ) -> Self {
695 Self {
696 event_stream: Box::new(stream),
697 done: false,
698 _phantom_data: PhantomData,
699 }
700 }
701}
702
703impl<O> Stream for OpenAIFormEventStream<O>
704where
705 O: DeserializeOwned + Send + 'static,
706{
707 type Item = Result<O, OpenAIError>;
708
709 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
710 let this = self.project();
711 if *this.done {
712 return Poll::Ready(None);
713 }
714 let stream: Pin<&mut _> = this.event_stream;
715 match stream.poll_next(cx) {
716 Poll::Pending => Poll::Pending,
717 Poll::Ready(response) => match response {
718 None => Poll::Ready(None),
719 Some(result) => match result {
720 Err(e) => {
721 Poll::Ready(Some(Err(OpenAIError::StreamError(Box::new(
723 StreamError::EventStream(e.to_string()),
724 )))))
725 }
726 Ok(event) => {
727 if event.data == "[DONE]" {
728 *this.done = true;
729 Poll::Ready(None)
730 } else {
731 match serde_json::from_str::<O>(&event.data) {
732 Err(e) => Poll::Ready(Some(Err(map_deserialization_error(
733 e,
734 event.data.as_bytes(),
735 )))),
736 Ok(output) => Poll::Ready(Some(Ok(output))),
737 }
738 }
739 }
740 },
741 },
742 }
743 }
744}
745
746async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
747 let status = response.status();
748 let headers = response.headers().clone();
749 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
750
751 if status.is_server_error() {
752 let message: String = String::from_utf8_lossy(&bytes).into_owned();
754 tracing::warn!("Server error: {status} - {message}");
755 return Err(OpenAIError::ApiError(ApiError {
756 message,
757 r#type: None,
758 param: None,
759 code: None,
760 }));
761 }
762
763 if !status.is_success() {
765 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
766 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
767
768 return Err(OpenAIError::ApiError(wrapped_error.error));
769 }
770
771 Ok((bytes, headers))
772}