1#[cfg(not(target_family = "wasm"))]
2use std::pin::Pin;
3use std::sync::Arc;
4#[cfg(not(target_family = "wasm"))]
5use std::sync::Mutex;
6
7use bytes::Bytes;
8#[cfg(not(target_family = "wasm"))]
9use futures::{stream::StreamExt, Stream};
10use reqwest::{header::HeaderMap, multipart::Form, Response};
11use serde::{de::DeserializeOwned, Serialize};
12
13#[cfg(not(target_family = "wasm"))]
14use crate::error::StreamError;
15#[cfg(feature = "middleware")]
16use crate::executor::TowerExecutor;
17use crate::{
18 config::{Config, OpenAIConfig},
19 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
20 executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
21 traits::AsyncTryFrom,
22 RequestOptions,
23};
24
25struct RequestParts {
26 request_client: reqwest::Client,
27 method: reqwest::Method,
28 url: String,
29 headers: HeaderMap,
30 query: Vec<(String, String)>,
31}
32
33impl RequestParts {
34 fn build_request_builder(&self) -> reqwest::RequestBuilder {
35 self.request_client
36 .request(self.method.clone(), self.url.clone())
37 .query(&self.query)
38 .headers(self.headers.clone())
39 }
40}
41
42#[cfg(feature = "administration")]
43use crate::admin::Admin;
44#[cfg(feature = "chatkit")]
45use crate::chatkit::Chatkit;
46#[cfg(feature = "file")]
47use crate::file::Files;
48#[cfg(feature = "image")]
49use crate::image::Images;
50#[cfg(feature = "moderation")]
51use crate::moderation::Moderations;
52#[cfg(feature = "assistant")]
53#[allow(deprecated)]
54use crate::Assistants;
55#[cfg(feature = "audio")]
56use crate::Audio;
57#[cfg(feature = "batch")]
58use crate::Batches;
59#[cfg(feature = "chat-completion")]
60use crate::Chat;
61#[cfg(feature = "completions")]
62use crate::Completions;
63#[cfg(feature = "container")]
64use crate::Containers;
65#[cfg(feature = "responses")]
66use crate::Conversations;
67#[cfg(feature = "embedding")]
68use crate::Embeddings;
69#[cfg(feature = "evals")]
70use crate::Evals;
71#[cfg(feature = "finetuning")]
72use crate::FineTuning;
73#[cfg(feature = "model")]
74use crate::Models;
75#[cfg(feature = "realtime")]
76use crate::Realtime;
77#[cfg(feature = "responses")]
78use crate::Responses;
79#[cfg(feature = "skill")]
80use crate::Skills;
81#[cfg(feature = "assistant")]
82#[allow(deprecated)]
83use crate::Threads;
84#[cfg(feature = "upload")]
85use crate::Uploads;
86#[cfg(feature = "vectorstore")]
87use crate::VectorStores;
88#[cfg(feature = "video")]
89use crate::Videos;
90
/// Client for the API described by the configuration `C`.
///
/// Requests are assembled with `request_client` and sent through `executor`. The
/// constructors install a `ReqwestExecutor` wrapping the same HTTP client, and
/// `with_http_service` can replace it with a tower-based executor.
#[derive(Clone)]
pub struct Client<C: Config> {
    // Used by `build_request_parts` to create `reqwest::RequestBuilder`s.
    request_client: reqwest::Client,
    // Sends the built requests (see `execute_response`).
    executor: SharedExecutor,
    // Supplies the URL, headers and query pairs applied to every request.
    config: C,
}
99
100impl<C> std::fmt::Debug for Client<C>
101where
102 C: Config + std::fmt::Debug,
103{
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("Client")
106 .field("request_client", &self.request_client)
107 .field("config", &self.config)
108 .finish()
109 }
110}
111
112impl<C: Config> Default for Client<C>
113where
114 C: Default,
115{
116 fn default() -> Self {
117 let request_client = reqwest::Client::new();
118 Self {
119 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
120 request_client,
121 config: C::default(),
122 }
123 }
124}
125
126impl Client<OpenAIConfig> {
127 pub fn new() -> Self {
129 Self::default()
130 }
131}
132
133impl<C: Config> Client<C> {
134 pub fn build(http_client: reqwest::Client, config: C) -> Self {
136 Self {
137 executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
138 request_client: http_client,
139 config,
140 }
141 }
142
143 pub fn with_config(config: C) -> Self {
145 let request_client = reqwest::Client::new();
146 Self {
147 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
148 request_client,
149 config,
150 }
151 }
152
153 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
157 self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
158 self.request_client = http_client;
159 self
160 }
161
162 #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
164 pub fn with_http_service<S>(mut self, service: S) -> Self
165 where
166 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
167 S::Future: Send + 'static,
168 S::Error: Into<OpenAIError> + Send + Sync + 'static,
169 {
170 self.executor = Arc::new(TowerExecutor::new(service));
175 self
176 }
177
178 #[cfg(all(feature = "middleware", target_family = "wasm"))]
180 pub fn with_http_service<S>(mut self, service: S) -> Self
181 where
182 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
183 S::Future: 'static,
184 S::Error: Into<OpenAIError> + 'static,
185 {
186 self.executor = Arc::new(TowerExecutor::new(service));
191 self
192 }
193
    /// Returns a `Models` API group that borrows this client.
    #[cfg(feature = "model")]
    pub fn models(&self) -> Models<'_, C> {
        Models::new(self)
    }

    /// Returns a `Completions` API group that borrows this client.
    #[cfg(feature = "completions")]
    pub fn completions(&self) -> Completions<'_, C> {
        Completions::new(self)
    }

    /// Returns a `Chat` API group that borrows this client.
    #[cfg(feature = "chat-completion")]
    pub fn chat(&self) -> Chat<'_, C> {
        Chat::new(self)
    }

    /// Returns an `Images` API group that borrows this client.
    #[cfg(feature = "image")]
    pub fn images(&self) -> Images<'_, C> {
        Images::new(self)
    }

    /// Returns a `Moderations` API group that borrows this client.
    #[cfg(feature = "moderation")]
    pub fn moderations(&self) -> Moderations<'_, C> {
        Moderations::new(self)
    }

    /// Returns a `Files` API group that borrows this client.
    #[cfg(feature = "file")]
    pub fn files(&self) -> Files<'_, C> {
        Files::new(self)
    }

    /// Returns an `Uploads` API group that borrows this client.
    #[cfg(feature = "upload")]
    pub fn uploads(&self) -> Uploads<'_, C> {
        Uploads::new(self)
    }

    /// Returns a `FineTuning` API group that borrows this client.
    #[cfg(feature = "finetuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_, C> {
        FineTuning::new(self)
    }

    /// Returns an `Embeddings` API group that borrows this client.
    #[cfg(feature = "embedding")]
    pub fn embeddings(&self) -> Embeddings<'_, C> {
        Embeddings::new(self)
    }

    /// Returns an `Audio` API group that borrows this client.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_, C> {
        Audio::new(self)
    }

    /// Returns a `Videos` API group that borrows this client.
    #[cfg(feature = "video")]
    pub fn videos(&self) -> Videos<'_, C> {
        Videos::new(self)
    }

    /// Returns an `Assistants` API group that borrows this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn assistants(&self) -> Assistants<'_, C> {
        Assistants::new(self)
    }

    /// Returns a `Threads` API group that borrows this client.
    #[cfg(feature = "assistant")]
    #[deprecated(
        note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
    )]
    #[allow(deprecated)]
    pub fn threads(&self) -> Threads<'_, C> {
        Threads::new(self)
    }

    /// Returns a `VectorStores` API group that borrows this client.
    #[cfg(feature = "vectorstore")]
    pub fn vector_stores(&self) -> VectorStores<'_, C> {
        VectorStores::new(self)
    }

    /// Returns a `Batches` API group that borrows this client.
    #[cfg(feature = "batch")]
    pub fn batches(&self) -> Batches<'_, C> {
        Batches::new(self)
    }

    /// Returns an `Admin` API group that borrows this client.
    #[cfg(feature = "administration")]
    pub fn admin(&self) -> Admin<'_, C> {
        Admin::new(self)
    }

    /// Returns a `Responses` API group that borrows this client.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_, C> {
        Responses::new(self)
    }

    /// Returns a `Conversations` API group that borrows this client.
    /// Available with the `responses` feature.
    #[cfg(feature = "responses")]
    pub fn conversations(&self) -> Conversations<'_, C> {
        Conversations::new(self)
    }

    /// Returns a `Containers` API group that borrows this client.
    #[cfg(feature = "container")]
    pub fn containers(&self) -> Containers<'_, C> {
        Containers::new(self)
    }

    /// Returns a `Skills` API group that borrows this client.
    #[cfg(feature = "skill")]
    pub fn skills(&self) -> Skills<'_, C> {
        Skills::new(self)
    }

    /// Returns an `Evals` API group that borrows this client.
    #[cfg(feature = "evals")]
    pub fn evals(&self) -> Evals<'_, C> {
        Evals::new(self)
    }

    /// Returns a `Chatkit` API group that borrows this client.
    #[cfg(feature = "chatkit")]
    pub fn chatkit(&self) -> Chatkit<'_, C> {
        Chatkit::new(self)
    }

    /// Returns a `Realtime` API group that borrows this client.
    #[cfg(feature = "realtime")]
    pub fn realtime(&self) -> Realtime<'_, C> {
        Realtime::new(self)
    }

    /// Returns the configuration used to resolve URLs, headers and query parameters
    /// for every request this client sends.
    pub fn config(&self) -> &C {
        &self.config
    }
345
346 fn build_request_parts(
347 &self,
348 method: reqwest::Method,
349 path: &str,
350 request_options: &RequestOptions,
351 ) -> Arc<RequestParts> {
352 let url = if let Some(path) = request_options.path() {
353 self.config.url(path.as_str())
354 } else {
355 self.config.url(path)
356 };
357 let mut headers = self.config.headers();
358 if let Some(request_headers) = request_options.headers() {
359 headers.extend(request_headers.clone());
360 }
361
362 let mut query = self
363 .config
364 .query()
365 .into_iter()
366 .map(|(key, value)| (key.to_string(), value.to_string()))
367 .collect::<Vec<_>>();
368 query.extend_from_slice(request_options.query());
369
370 Arc::new(RequestParts {
371 request_client: self.request_client.clone(),
372 method,
373 url,
374 headers,
375 query,
376 })
377 }
378
379 fn build_request_factory(
380 &self,
381 method: reqwest::Method,
382 path: &str,
383 request_options: &RequestOptions,
384 ) -> HttpRequestFactory {
385 let request_parts = self.build_request_parts(method, path, request_options);
386
387 HttpRequestFactory::new(move || {
388 let request_parts = request_parts.clone();
389
390 async move {
391 let request = request_parts.build_request_builder().build()?;
392 Ok(request)
393 }
394 })
395 }
396
397 fn build_request_factory_with_json<I>(
398 &self,
399 method: reqwest::Method,
400 path: &str,
401 request: I,
402 request_options: &RequestOptions,
403 ) -> Result<HttpRequestFactory, OpenAIError>
404 where
405 I: Serialize,
406 {
407 let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
412 OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
413 })?);
414 let request_parts = self.build_request_parts(method, path, request_options);
415
416 Ok(HttpRequestFactory::new(move || {
417 let request_parts = request_parts.clone();
418 let request = request.clone();
419
420 async move {
421 let request_builder = request_parts
422 .build_request_builder()
423 .header(reqwest::header::CONTENT_TYPE, "application/json")
424 .body(request.clone());
425
426 Ok(request_builder.build()?)
427 }
428 }))
429 }
430
431 #[cfg(not(target_family = "wasm"))]
432 fn build_request_factory_with_form<F>(
433 &self,
434 method: reqwest::Method,
435 path: &str,
436 form: F,
437 request_options: &RequestOptions,
438 ) -> Result<HttpRequestFactory, OpenAIError>
439 where
440 F: Clone + Send + 'static,
441 Form: AsyncTryFrom<F, Error = OpenAIError>,
442 {
443 let form = Arc::new(Mutex::new(form));
447 let request_parts = self.build_request_parts(method, path, request_options);
448
449 Ok(HttpRequestFactory::new(move || {
450 let request_parts = request_parts.clone();
451 let form = form.clone();
452
453 async move {
454 let form = form
455 .lock()
456 .expect("multipart request factory mutex poisoned")
457 .clone();
458 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
459 let request_builder = request_parts.build_request_builder().multipart(form);
460
461 Ok(request_builder.build()?)
462 }
463 }))
464 }
465
466 #[cfg(target_family = "wasm")]
467 fn build_request_factory_with_form<F>(
468 &self,
469 method: reqwest::Method,
470 path: &str,
471 form: F,
472 request_options: &RequestOptions,
473 ) -> Result<HttpRequestFactory, OpenAIError>
474 where
475 F: Clone + 'static,
476 Form: AsyncTryFrom<F, Error = OpenAIError>,
477 {
478 let request_parts = self.build_request_parts(method, path, request_options);
479
480 Ok(HttpRequestFactory::new(move || {
481 let request_parts = request_parts.clone();
482 let form = form.clone();
483
484 async move {
485 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
486 let request_builder = request_parts.build_request_builder().multipart(form);
487
488 Ok(request_builder.build()?)
489 }
490 }))
491 }
492
493 #[allow(unused)]
495 pub(crate) async fn get<O>(
496 &self,
497 path: &str,
498 request_options: &RequestOptions,
499 ) -> Result<O, OpenAIError>
500 where
501 O: DeserializeOwned,
502 {
503 let request_factory =
504 self.build_request_factory(reqwest::Method::GET, path, request_options);
505 self.execute(request_factory).await
506 }
507
508 #[allow(unused)]
510 pub(crate) async fn delete<O>(
511 &self,
512 path: &str,
513 request_options: &RequestOptions,
514 ) -> Result<O, OpenAIError>
515 where
516 O: DeserializeOwned,
517 {
518 let request_factory =
519 self.build_request_factory(reqwest::Method::DELETE, path, request_options);
520 self.execute(request_factory).await
521 }
522
523 #[allow(unused)]
525 pub(crate) async fn get_raw(
526 &self,
527 path: &str,
528 request_options: &RequestOptions,
529 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
530 let request_factory =
531 self.build_request_factory(reqwest::Method::GET, path, request_options);
532 self.execute_raw(request_factory).await
533 }
534
535 #[allow(unused)]
537 pub(crate) async fn post_raw<I>(
538 &self,
539 path: &str,
540 request: I,
541 request_options: &RequestOptions,
542 ) -> Result<(Bytes, HeaderMap), OpenAIError>
543 where
544 I: Serialize,
545 {
546 let request_factory = self.build_request_factory_with_json(
547 reqwest::Method::POST,
548 path,
549 request,
550 request_options,
551 )?;
552 self.execute_raw(request_factory).await
553 }
554
555 #[allow(unused)]
557 pub(crate) async fn post<I, O>(
558 &self,
559 path: &str,
560 request: I,
561 request_options: &RequestOptions,
562 ) -> Result<O, OpenAIError>
563 where
564 I: Serialize,
565 O: DeserializeOwned,
566 {
567 let request_factory = self.build_request_factory_with_json(
568 reqwest::Method::POST,
569 path,
570 request,
571 request_options,
572 )?;
573 self.execute(request_factory).await
574 }
575
576 #[allow(unused)]
578 #[cfg(not(target_family = "wasm"))]
579 pub(crate) async fn post_form_raw<F>(
580 &self,
581 path: &str,
582 form: F,
583 request_options: &RequestOptions,
584 ) -> Result<(Bytes, HeaderMap), OpenAIError>
585 where
586 F: Clone + Send + 'static,
587 Form: AsyncTryFrom<F, Error = OpenAIError>,
588 {
589 let request_factory = self.build_request_factory_with_form(
590 reqwest::Method::POST,
591 path,
592 form,
593 request_options,
594 )?;
595 self.execute_raw(request_factory).await
596 }
597
598 #[allow(unused)]
600 #[cfg(target_family = "wasm")]
601 pub(crate) async fn post_form_raw<F>(
602 &self,
603 path: &str,
604 form: F,
605 request_options: &RequestOptions,
606 ) -> Result<(Bytes, HeaderMap), OpenAIError>
607 where
608 F: Clone + 'static,
609 Form: AsyncTryFrom<F, Error = OpenAIError>,
610 {
611 let request_factory = self.build_request_factory_with_form(
612 reqwest::Method::POST,
613 path,
614 form,
615 request_options,
616 )?;
617 self.execute_raw(request_factory).await
618 }
619
620 #[allow(unused)]
622 #[cfg(not(target_family = "wasm"))]
623 pub(crate) async fn post_form<O, F>(
624 &self,
625 path: &str,
626 form: F,
627 request_options: &RequestOptions,
628 ) -> Result<O, OpenAIError>
629 where
630 O: DeserializeOwned,
631 F: Clone + Send + 'static,
632 Form: AsyncTryFrom<F, Error = OpenAIError>,
633 {
634 let request_factory = self.build_request_factory_with_form(
635 reqwest::Method::POST,
636 path,
637 form,
638 request_options,
639 )?;
640 self.execute(request_factory).await
641 }
642
643 #[allow(unused)]
645 #[cfg(target_family = "wasm")]
646 pub(crate) async fn post_form<O, F>(
647 &self,
648 path: &str,
649 form: F,
650 request_options: &RequestOptions,
651 ) -> Result<O, OpenAIError>
652 where
653 O: DeserializeOwned,
654 F: Clone + 'static,
655 Form: AsyncTryFrom<F, Error = OpenAIError>,
656 {
657 let request_factory = self.build_request_factory_with_form(
658 reqwest::Method::POST,
659 path,
660 form,
661 request_options,
662 )?;
663 self.execute(request_factory).await
664 }
665
666 #[allow(unused)]
667 #[cfg(not(target_family = "wasm"))]
668 pub(crate) async fn post_form_stream<O, F>(
669 &self,
670 path: &str,
671 form: F,
672 request_options: &RequestOptions,
673 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
674 where
675 F: Clone + Send + 'static,
676 Form: AsyncTryFrom<F, Error = OpenAIError>,
677 O: DeserializeOwned + std::marker::Send + 'static,
678 {
679 let request_factory = self.build_request_factory_with_form(
680 reqwest::Method::POST,
681 path,
682 form,
683 request_options,
684 )?;
685
686 self.execute_stream(request_factory).await
687 }
688
689 async fn execute_raw(
690 &self,
691 request_factory: HttpRequestFactory,
692 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
693 let response = self.execute_response(request_factory).await?;
694 read_response(response).await
695 }
696
697 async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
698 where
699 O: DeserializeOwned,
700 {
701 let (bytes, _headers) = self.execute_raw(request_factory).await?;
702
703 let response: O = serde_json::from_slice(bytes.as_ref())
704 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
705
706 Ok(response)
707 }
708
709 async fn execute_response(
710 &self,
711 request_factory: HttpRequestFactory,
712 ) -> Result<Response, OpenAIError> {
713 self.executor.execute(request_factory).await
714 }
715
716 #[cfg(not(target_family = "wasm"))]
717 async fn execute_stream<O>(
718 &self,
719 request_factory: HttpRequestFactory,
720 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
721 where
722 O: DeserializeOwned + std::marker::Send + 'static,
723 {
724 let response = self.execute_response(request_factory).await?;
725 Ok(stream(response).await)
726 }
727
728 #[cfg(not(target_family = "wasm"))]
729 async fn execute_stream_mapped_raw_events<O>(
730 &self,
731 request_factory: HttpRequestFactory,
732 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
733 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
734 where
735 O: DeserializeOwned + std::marker::Send + 'static,
736 {
737 let response = self.execute_response(request_factory).await?;
738 Ok(stream_mapped_raw_events(response, event_mapper).await)
739 }
740
741 #[allow(unused)]
743 #[cfg(not(target_family = "wasm"))]
744 pub(crate) async fn post_stream<I, O>(
745 &self,
746 path: &str,
747 request: I,
748 request_options: &RequestOptions,
749 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
750 where
751 I: Serialize,
752 O: DeserializeOwned + std::marker::Send + 'static,
753 {
754 let request_factory = self.build_request_factory_with_json(
755 reqwest::Method::POST,
756 path,
757 request,
758 request_options,
759 )?;
760 self.execute_stream(request_factory).await
763 }
764
765 #[allow(unused)]
766 #[cfg(not(target_family = "wasm"))]
767 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
768 &self,
769 path: &str,
770 request: I,
771 request_options: &RequestOptions,
772 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
773 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
774 where
775 I: Serialize,
776 O: DeserializeOwned + std::marker::Send + 'static,
777 {
778 let request_factory = self.build_request_factory_with_json(
779 reqwest::Method::POST,
780 path,
781 request,
782 request_options,
783 )?;
784 self.execute_stream_mapped_raw_events(request_factory, event_mapper)
785 .await
786 }
787
788 #[allow(unused)]
790 #[cfg(not(target_family = "wasm"))]
791 pub(crate) async fn get_stream<O>(
792 &self,
793 path: &str,
794 request_options: &RequestOptions,
795 ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
796 where
797 O: DeserializeOwned + std::marker::Send + 'static,
798 {
799 let request_factory =
800 self.build_request_factory(reqwest::Method::GET, path, request_options);
801 self.execute_stream(request_factory).await
802 }
803}
804
805async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
806 let status = response.status();
807 let headers = response.headers().clone();
808 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
809
810 if status.is_server_error() {
811 let message: String = String::from_utf8_lossy(&bytes).into_owned();
813 tracing::warn!("Server error: {status} - {message}");
814 return Err(OpenAIError::ApiError(ApiError {
815 message,
816 r#type: None,
817 param: None,
818 code: None,
819 }));
820 }
821
822 if !status.is_success() {
824 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
825 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
826
827 return Err(OpenAIError::ApiError(wrapped_error.error));
828 }
829
830 Ok((bytes, headers))
831}
832
833#[cfg(not(target_family = "wasm"))]
836pub(crate) async fn stream<O>(
837 response: Response,
838) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
839where
840 O: DeserializeOwned + std::marker::Send + 'static,
841{
842 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
843
844 tokio::spawn(async move {
845 if !response.status().is_success() {
846 if let Err(e) = read_response(response).await {
847 let _ = tx.send(Err(e));
848 }
849 return;
850 }
851 let byte_stream = response
852 .bytes_stream()
853 .map(|r| r.map_err(std::io::Error::other));
854 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
855
856 while let Some(ev) = event_stream.next().await {
857 let event = match ev {
858 Ok(e) => e,
859 Err(e) => {
860 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
861 StreamError::EventStream(e.to_string()),
862 ))));
863 break;
864 }
865 };
866 if event.data == "[DONE]" {
867 break;
868 }
869 if event.event == "keepalive" {
870 continue;
871 }
872
873 let response = serde_json::from_str::<O>(&event.data)
874 .map_err(|e| map_deserialization_error(e, event.data.as_bytes()));
875
876 if tx.send(response).is_err() {
877 break;
878 }
879 }
880 });
881
882 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
883}
884
885#[cfg(not(target_family = "wasm"))]
886pub(crate) async fn stream_mapped_raw_events<O>(
887 response: Response,
888 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
889) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
890where
891 O: DeserializeOwned + std::marker::Send + 'static,
892{
893 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
894
895 tokio::spawn(async move {
896 if !response.status().is_success() {
897 if let Err(e) = read_response(response).await {
898 let _ = tx.send(Err(e));
899 }
900 return;
901 }
902 let byte_stream = response
903 .bytes_stream()
904 .map(|r| r.map_err(std::io::Error::other));
905 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
906
907 while let Some(ev) = event_stream.next().await {
908 let event = match ev {
909 Ok(e) => e,
910 Err(e) => {
911 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
912 StreamError::EventStream(e.to_string()),
913 ))));
914 break;
915 }
916 };
917 let done = event.data == "[DONE]";
918
919 if event.event == "keepalive" {
920 continue;
921 }
922
923 let response = event_mapper(event);
924
925 if tx.send(response).is_err() {
926 break;
927 }
928
929 if done {
930 break;
931 }
932 }
933 });
934
935 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
936}
937
938#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
939mod tests {
940 use std::sync::{
941 atomic::{AtomicUsize, Ordering},
942 Arc,
943 };
944
945 use futures::StreamExt;
946 use http::Response as HttpResponse;
947 use serde_json::json;
948 use tower::{service_fn, ServiceBuilder};
949
950 use super::Client;
951 use crate::{
952 config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
953 retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
954 };
955
    /// A unary `GET` issued through `Client::get` must be dispatched through the
    /// installed tower service exactly once.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        // Counts how many times the service handles a request.
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        // Build the request from the factory to inspect its URL.
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        // Canned JSON reply; nothing is sent over the network.
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }
995
996 #[tokio::test]
997 async fn stream_requests_open_through_middleware_service() {
998 let request_count = Arc::new(AtomicUsize::new(0));
999 let service = {
1000 let request_count = request_count.clone();
1001 ServiceBuilder::new()
1002 .concurrency_limit(1)
1003 .service(service_fn(move |factory: HttpRequestFactory| {
1004 let request_count = request_count.clone();
1005 async move {
1006 let request = factory.build().await?;
1007 assert_eq!(request.url().path(), "/responses");
1008 request_count.fetch_add(1, Ordering::SeqCst);
1009 Ok::<reqwest::Response, OpenAIError>(
1010 HttpResponse::builder()
1011 .status(200)
1012 .header("content-type", "text/event-stream")
1013 .body(reqwest::Body::from(
1014 "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
1015 ))
1016 .unwrap()
1017 .into(),
1018 )
1019 }
1020 }))
1021 };
1022
1023 let client = Client::with_config(
1024 OpenAIConfig::new()
1025 .with_api_base("http://example.test")
1026 .with_api_key("test-key"),
1027 )
1028 .with_http_service(service);
1029
1030 let mut stream = client
1031 .post_stream::<_, serde_json::Value>(
1032 "/responses",
1033 json!({ "stream": true }),
1034 &RequestOptions::new(),
1035 )
1036 .await
1037 .unwrap();
1038
1039 let first = stream.next().await.unwrap().unwrap();
1040
1041 assert_eq!(first, json!({ "ok": true }));
1042 assert_eq!(request_count.load(Ordering::SeqCst), 1);
1043 }
1044
    /// With `SimpleRetryPolicy` layered in, a 429 on the first attempt must be followed
    /// by a second attempt whose 200 reply is what `Client::get` returns.
    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        // Also serves as the attempt index: 0 on the first call, 1 on the second.
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited. Any later attempt: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                                ))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        // The body comes from the second (successful) attempt.
        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }
1094
1095 #[derive(Clone)]
1096 struct RetryableMultipartInput {
1097 conversions: Arc<AtomicUsize>,
1098 }
1099
1100 impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
1101 type Error = OpenAIError;
1102
1103 async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
1104 value.conversions.fetch_add(1, Ordering::SeqCst);
1105 Ok(reqwest::multipart::Form::new().text("field", "value"))
1106 }
1107 }
1108
    /// A retried multipart `POST` must rebuild its form for every attempt: two
    /// attempts (429 then 200) mean two form conversions.
    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        // Number of requests the service has seen (also the attempt index).
        let request_count = Arc::new(AtomicUsize::new(0));
        // Incremented by `RetryableMultipartInput`'s conversion into a form.
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        // Building the request runs the form conversion.
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited. Any later attempt: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(r#"{"ok":true}"#))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: conversion_count.clone(),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        // One conversion per attempt: the form was not reused across retries.
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
1169}