1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14 config::{Config, OpenAIConfig},
15 error::{map_deserialization_error, ApiError, ApiErrorResponse, OpenAIError, WrappedError},
16 executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17 traits::AsyncTryFrom,
18 RequestOptions,
19};
20
21struct RequestParts {
22 request_client: reqwest::Client,
23 method: reqwest::Method,
24 url: String,
25 headers: HeaderMap,
26 query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30 fn build_request_builder(&self) -> reqwest::RequestBuilder {
31 self.request_client
32 .request(self.method.clone(), self.url.clone())
33 .query(&self.query)
34 .headers(self.headers.clone())
35 }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88pub struct Client<C: Config> {
91 request_client: reqwest::Client,
92 executor: SharedExecutor,
93 config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98 C: Config + std::fmt::Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Client")
102 .field("request_client", &self.request_client)
103 .field("config", &self.config)
104 .finish()
105 }
106}
107
108impl<C: Config> Default for Client<C>
109where
110 C: Default,
111{
112 fn default() -> Self {
113 let request_client = reqwest::Client::new();
114 Self {
115 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116 request_client,
117 config: C::default(),
118 }
119 }
120}
121
122impl Client<OpenAIConfig> {
123 pub fn new() -> Self {
125 Self::default()
126 }
127}
128
129impl<C: Config> Client<C> {
130 pub fn build(http_client: reqwest::Client, config: C) -> Self {
132 Self {
133 executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134 request_client: http_client,
135 config,
136 }
137 }
138
139 pub fn with_config(config: C) -> Self {
141 let request_client = reqwest::Client::new();
142 Self {
143 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144 request_client,
145 config,
146 }
147 }
148
149 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153 self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154 self.request_client = http_client;
155 self
156 }
157
158 #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160 pub fn with_http_service<S>(mut self, service: S) -> Self
161 where
162 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163 S::Future: Send + 'static,
164 S::Error: Into<OpenAIError> + Send + Sync + 'static,
165 {
166 self.executor = Arc::new(TowerExecutor::new(service));
171 self
172 }
173
174 #[cfg(all(feature = "middleware", target_family = "wasm"))]
176 pub fn with_http_service<S>(mut self, service: S) -> Self
177 where
178 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179 S::Future: 'static,
180 S::Error: Into<OpenAIError> + 'static,
181 {
182 self.executor = Arc::new(TowerExecutor::new(service));
187 self
188 }
189
190 #[cfg(feature = "model")]
194 pub fn models(&self) -> Models<'_, C> {
195 Models::new(self)
196 }
197
198 #[cfg(feature = "completions")]
200 pub fn completions(&self) -> Completions<'_, C> {
201 Completions::new(self)
202 }
203
204 #[cfg(feature = "chat-completion")]
206 pub fn chat(&self) -> Chat<'_, C> {
207 Chat::new(self)
208 }
209
210 #[cfg(feature = "image")]
212 pub fn images(&self) -> Images<'_, C> {
213 Images::new(self)
214 }
215
216 #[cfg(feature = "moderation")]
218 pub fn moderations(&self) -> Moderations<'_, C> {
219 Moderations::new(self)
220 }
221
222 #[cfg(feature = "file")]
224 pub fn files(&self) -> Files<'_, C> {
225 Files::new(self)
226 }
227
228 #[cfg(feature = "upload")]
230 pub fn uploads(&self) -> Uploads<'_, C> {
231 Uploads::new(self)
232 }
233
234 #[cfg(feature = "finetuning")]
236 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237 FineTuning::new(self)
238 }
239
240 #[cfg(feature = "embedding")]
242 pub fn embeddings(&self) -> Embeddings<'_, C> {
243 Embeddings::new(self)
244 }
245
246 #[cfg(feature = "audio")]
248 pub fn audio(&self) -> Audio<'_, C> {
249 Audio::new(self)
250 }
251
252 #[cfg(feature = "video")]
254 pub fn videos(&self) -> Videos<'_, C> {
255 Videos::new(self)
256 }
257
258 #[cfg(feature = "assistant")]
260 #[deprecated(
261 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262 )]
263 #[allow(deprecated)]
264 pub fn assistants(&self) -> Assistants<'_, C> {
265 Assistants::new(self)
266 }
267
268 #[cfg(feature = "assistant")]
270 #[deprecated(
271 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272 )]
273 #[allow(deprecated)]
274 pub fn threads(&self) -> Threads<'_, C> {
275 Threads::new(self)
276 }
277
278 #[cfg(feature = "vectorstore")]
280 pub fn vector_stores(&self) -> VectorStores<'_, C> {
281 VectorStores::new(self)
282 }
283
284 #[cfg(feature = "batch")]
286 pub fn batches(&self) -> Batches<'_, C> {
287 Batches::new(self)
288 }
289
290 #[cfg(feature = "administration")]
293 pub fn admin(&self) -> Admin<'_, C> {
294 Admin::new(self)
295 }
296
297 #[cfg(feature = "responses")]
299 pub fn responses(&self) -> Responses<'_, C> {
300 Responses::new(self)
301 }
302
303 #[cfg(feature = "responses")]
305 pub fn conversations(&self) -> Conversations<'_, C> {
306 Conversations::new(self)
307 }
308
309 #[cfg(feature = "container")]
311 pub fn containers(&self) -> Containers<'_, C> {
312 Containers::new(self)
313 }
314
315 #[cfg(feature = "skill")]
317 pub fn skills(&self) -> Skills<'_, C> {
318 Skills::new(self)
319 }
320
321 #[cfg(feature = "evals")]
323 pub fn evals(&self) -> Evals<'_, C> {
324 Evals::new(self)
325 }
326
327 #[cfg(feature = "chatkit")]
328 pub fn chatkit(&self) -> Chatkit<'_, C> {
329 Chatkit::new(self)
330 }
331
332 #[cfg(feature = "realtime")]
334 pub fn realtime(&self) -> Realtime<'_, C> {
335 Realtime::new(self)
336 }
337
338 pub fn config(&self) -> &C {
339 &self.config
340 }
341
342 fn build_request_parts(
343 &self,
344 method: reqwest::Method,
345 path: &str,
346 request_options: &RequestOptions,
347 ) -> Arc<RequestParts> {
348 let url = if let Some(path) = request_options.path() {
349 self.config.url(path.as_str())
350 } else {
351 self.config.url(path)
352 };
353 let mut headers = self.config.headers();
354 if let Some(request_headers) = request_options.headers() {
355 headers.extend(request_headers.clone());
356 }
357
358 let mut query = self
359 .config
360 .query()
361 .into_iter()
362 .map(|(key, value)| (key.to_string(), value.to_string()))
363 .collect::<Vec<_>>();
364 query.extend_from_slice(request_options.query());
365
366 Arc::new(RequestParts {
367 request_client: self.request_client.clone(),
368 method,
369 url,
370 headers,
371 query,
372 })
373 }
374
375 fn build_request_factory(
376 &self,
377 method: reqwest::Method,
378 path: &str,
379 request_options: &RequestOptions,
380 ) -> HttpRequestFactory {
381 let request_parts = self.build_request_parts(method, path, request_options);
382
383 HttpRequestFactory::new(move || {
384 let request_parts = request_parts.clone();
385
386 async move {
387 let request = request_parts.build_request_builder().build()?;
388 Ok(request)
389 }
390 })
391 }
392
393 fn build_request_factory_with_json<I>(
394 &self,
395 method: reqwest::Method,
396 path: &str,
397 request: I,
398 request_options: &RequestOptions,
399 ) -> Result<HttpRequestFactory, OpenAIError>
400 where
401 I: Serialize,
402 {
403 let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406 OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407 })?);
408 let request_parts = self.build_request_parts(method, path, request_options);
409
410 Ok(HttpRequestFactory::new(move || {
411 let request_parts = request_parts.clone();
412 let request = request.clone();
413
414 async move {
415 let request_builder = request_parts
416 .build_request_builder()
417 .header(reqwest::header::CONTENT_TYPE, "application/json")
418 .body(request.clone());
419
420 Ok(request_builder.build()?)
421 }
422 }))
423 }
424
425 fn build_request_factory_with_form<F>(
426 &self,
427 method: reqwest::Method,
428 path: &str,
429 form: F,
430 request_options: &RequestOptions,
431 ) -> Result<HttpRequestFactory, OpenAIError>
432 where
433 F: Clone + crate::traits::MaybeSend + 'static,
434 Form: AsyncTryFrom<F, Error = OpenAIError>,
435 {
436 #[cfg(not(target_family = "wasm"))]
440 let form = Arc::new(Mutex::new(form));
441 let request_parts = self.build_request_parts(method, path, request_options);
442
443 Ok(HttpRequestFactory::new(move || {
444 let request_parts = request_parts.clone();
445 let form = form.clone();
446
447 async move {
448 #[cfg(not(target_family = "wasm"))]
449 let form = form
450 .lock()
451 .expect("multipart request factory mutex poisoned")
452 .clone();
453 #[cfg(target_family = "wasm")]
454 let form = form.clone();
455 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456 let request_builder = request_parts.build_request_builder().multipart(form);
457
458 Ok(request_builder.build()?)
459 }
460 }))
461 }
462
463 #[allow(unused)]
465 pub(crate) async fn get<O>(
466 &self,
467 path: &str,
468 request_options: &RequestOptions,
469 ) -> Result<O, OpenAIError>
470 where
471 O: DeserializeOwned,
472 {
473 let request_factory =
474 self.build_request_factory(reqwest::Method::GET, path, request_options);
475 self.execute(request_factory).await
476 }
477
478 #[allow(unused)]
480 pub(crate) async fn delete<O>(
481 &self,
482 path: &str,
483 request_options: &RequestOptions,
484 ) -> Result<O, OpenAIError>
485 where
486 O: DeserializeOwned,
487 {
488 let request_factory =
489 self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490 self.execute(request_factory).await
491 }
492
493 #[allow(unused)]
495 pub(crate) async fn get_raw(
496 &self,
497 path: &str,
498 request_options: &RequestOptions,
499 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500 let request_factory =
501 self.build_request_factory(reqwest::Method::GET, path, request_options);
502 self.execute_raw(request_factory).await
503 }
504
505 #[allow(unused)]
507 pub(crate) async fn post_raw<I>(
508 &self,
509 path: &str,
510 request: I,
511 request_options: &RequestOptions,
512 ) -> Result<(Bytes, HeaderMap), OpenAIError>
513 where
514 I: Serialize,
515 {
516 let request_factory = self.build_request_factory_with_json(
517 reqwest::Method::POST,
518 path,
519 request,
520 request_options,
521 )?;
522 self.execute_raw(request_factory).await
523 }
524
525 #[allow(unused)]
527 pub(crate) async fn post<I, O>(
528 &self,
529 path: &str,
530 request: I,
531 request_options: &RequestOptions,
532 ) -> Result<O, OpenAIError>
533 where
534 I: Serialize,
535 O: DeserializeOwned,
536 {
537 let request_factory = self.build_request_factory_with_json(
538 reqwest::Method::POST,
539 path,
540 request,
541 request_options,
542 )?;
543 self.execute(request_factory).await
544 }
545
546 #[allow(unused)]
548 pub(crate) async fn post_form_raw<F>(
549 &self,
550 path: &str,
551 form: F,
552 request_options: &RequestOptions,
553 ) -> Result<(Bytes, HeaderMap), OpenAIError>
554 where
555 F: Clone + crate::traits::MaybeSend + 'static,
556 Form: AsyncTryFrom<F, Error = OpenAIError>,
557 {
558 let request_factory = self.build_request_factory_with_form(
559 reqwest::Method::POST,
560 path,
561 form,
562 request_options,
563 )?;
564 self.execute_raw(request_factory).await
565 }
566
567 #[allow(unused)]
569 pub(crate) async fn post_form<O, F>(
570 &self,
571 path: &str,
572 form: F,
573 request_options: &RequestOptions,
574 ) -> Result<O, OpenAIError>
575 where
576 O: DeserializeOwned,
577 F: Clone + crate::traits::MaybeSend + 'static,
578 Form: AsyncTryFrom<F, Error = OpenAIError>,
579 {
580 let request_factory = self.build_request_factory_with_form(
581 reqwest::Method::POST,
582 path,
583 form,
584 request_options,
585 )?;
586 self.execute(request_factory).await
587 }
588
589 #[allow(unused)]
590 pub(crate) async fn post_form_stream<O, F>(
591 &self,
592 path: &str,
593 form: F,
594 request_options: &RequestOptions,
595 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596 where
597 F: Clone + crate::traits::MaybeSend + 'static,
598 Form: AsyncTryFrom<F, Error = OpenAIError>,
599 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600 {
601 let request_factory = self.build_request_factory_with_form(
602 reqwest::Method::POST,
603 path,
604 form,
605 request_options,
606 )?;
607
608 self.execute_stream(request_factory).await
609 }
610
611 async fn execute_raw(
612 &self,
613 request_factory: HttpRequestFactory,
614 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615 let response = self.execute_response(request_factory).await?;
616 read_response(response).await
617 }
618
619 async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620 where
621 O: DeserializeOwned,
622 {
623 let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625 let response: O = serde_json::from_slice(bytes.as_ref())
626 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628 Ok(response)
629 }
630
631 async fn execute_response(
632 &self,
633 request_factory: HttpRequestFactory,
634 ) -> Result<Response, OpenAIError> {
635 self.executor.execute(request_factory).await
636 }
637
638 async fn execute_stream<O>(
639 &self,
640 request_factory: HttpRequestFactory,
641 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
642 where
643 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
644 {
645 let response = self.execute_response(request_factory).await?;
646 Ok(stream(response).await)
647 }
648
649 async fn execute_stream_mapped_raw_events<O>(
650 &self,
651 request_factory: HttpRequestFactory,
652 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
653 + crate::traits::MaybeSend
654 + 'static,
655 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
656 where
657 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
658 {
659 let response = self.execute_response(request_factory).await?;
660 Ok(stream_mapped_raw_events(response, event_mapper).await)
661 }
662
663 #[allow(unused)]
665 pub(crate) async fn post_stream<I, O>(
666 &self,
667 path: &str,
668 request: I,
669 request_options: &RequestOptions,
670 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
671 where
672 I: Serialize,
673 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
674 {
675 let request_factory = self.build_request_factory_with_json(
676 reqwest::Method::POST,
677 path,
678 request,
679 request_options,
680 )?;
681 self.execute_stream(request_factory).await
684 }
685
686 #[allow(unused)]
687 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
688 &self,
689 path: &str,
690 request: I,
691 request_options: &RequestOptions,
692 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
693 + crate::traits::MaybeSend
694 + 'static,
695 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
696 where
697 I: Serialize,
698 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
699 {
700 let request_factory = self.build_request_factory_with_json(
701 reqwest::Method::POST,
702 path,
703 request,
704 request_options,
705 )?;
706 self.execute_stream_mapped_raw_events(request_factory, event_mapper)
707 .await
708 }
709
710 #[allow(unused)]
712 pub(crate) async fn get_stream<O>(
713 &self,
714 path: &str,
715 request_options: &RequestOptions,
716 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
717 where
718 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
719 {
720 let request_factory =
721 self.build_request_factory(reqwest::Method::GET, path, request_options);
722 self.execute_stream(request_factory).await
723 }
724}
725
726async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
727 let status = response.status();
728 let headers = response.headers().clone();
729 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
730
731 if status.is_server_error() {
732 let message: String = String::from_utf8_lossy(&bytes).into_owned();
734 tracing::warn!("Server error: {status} - {message}");
735 return Err(OpenAIError::ApiError(ApiErrorResponse {
736 status_code: status,
737 api_error: ApiError {
738 message,
739 r#type: None,
740 param: None,
741 code: None,
742 },
743 }));
744 }
745
746 if !status.is_success() {
748 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
749 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
750
751 return Err(OpenAIError::ApiError(ApiErrorResponse {
752 status_code: status,
753 api_error: wrapped_error.error,
754 }));
755 }
756
757 Ok((bytes, headers))
758}
759
760pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
763where
764 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
765{
766 stream_mapped_raw_events(response, |event| {
767 serde_json::from_str::<O>(&event.data)
768 .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
769 })
770 .await
771}
772
/// Wasm implementation: turns a server-sent event response into a stream of mapped items.
///
/// - A non-success status yields exactly one error item.
/// - `keepalive` events are skipped.
/// - A `[DONE]` payload ends the stream.
/// - An event-stream error is yielded once and then ends the stream, matching the native
///   implementation, which stops reading after the first such error.
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    if !response.status().is_success() {
        return Box::pin(futures::stream::once(async move {
            match read_response(response).await {
                Ok(_) => Err(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                )),
                Err(error) => Err(error),
            }
        }));
    }

    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // The state becomes `None` after an error so the following poll ends the stream.
    Box::pin(futures::stream::unfold(
        Some((event_stream, event_mapper)),
        |state| async move {
            let (mut event_stream, event_mapper) = match state {
                Some(state) => state,
                None => return None,
            };

            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        let failure: Result<O, OpenAIError> = Err(OpenAIError::StreamError(
                            Box::new(StreamError::EventStream(error.to_string())),
                        ));
                        return Some((failure, None));
                    }
                    None => return None,
                };

                if event.data == "[DONE]" {
                    return None;
                }

                if event.event == "keepalive" {
                    continue;
                }

                let item = event_mapper(event);
                return Some((item, Some((event_stream, event_mapper))));
            }
        },
    ))
}
828
829#[cfg(not(target_family = "wasm"))]
830pub(crate) async fn stream_mapped_raw_events<O>(
831 response: Response,
832 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
833) -> crate::types::stream::StreamResponse<O>
834where
835 O: DeserializeOwned + std::marker::Send + 'static,
836{
837 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
838
839 tokio::spawn(async move {
840 if !response.status().is_success() {
841 if let Err(e) = read_response(response).await {
842 let _ = tx.send(Err(e));
843 }
844 return;
845 }
846 let byte_stream = response
847 .bytes_stream()
848 .map(|r| r.map_err(std::io::Error::other));
849 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
850
851 while let Some(ev) = event_stream.next().await {
852 let event = match ev {
853 Ok(e) => e,
854 Err(e) => {
855 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
856 StreamError::EventStream(e.to_string()),
857 ))));
858 break;
859 }
860 };
861 if event.data == "[DONE]" {
862 break;
863 }
864
865 if event.event == "keepalive" {
866 continue;
867 }
868
869 let response = event_mapper(event);
870
871 if tx.send(response).is_err() {
872 break;
873 }
874 }
875 });
876
877 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
878}
879
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// OpenAI-style error envelope returned on a rate-limited first attempt.
    const RATE_LIMIT_BODY: &str =
        r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#;

    /// Config pointing at a placeholder host; the test services answer every request.
    fn test_config() -> OpenAIConfig {
        OpenAIConfig::new()
            .with_api_base("http://example.test")
            .with_api_key("test-key")
    }

    /// Builds a canned `reqwest::Response` with the given status, content type and body.
    fn canned_response(status: u16, content_type: &str, body: &'static str) -> reqwest::Response {
        HttpResponse::builder()
            .status(status)
            .header("content-type", content_type)
            .body(reqwest::Body::from(body))
            .unwrap()
            .into()
    }

    /// Canned `application/json` response.
    fn json_response(status: u16, body: &'static str) -> reqwest::Response {
        canned_response(status, "application/json", body)
    }

    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(json_response(
                            200,
                            "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(canned_response(
                            200,
                            "text/event-stream",
                            "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        // `[DONE]` must end the stream rather than be decoded as another item.
        assert!(stream.next().await.is_none());
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let response = if attempt == 0 {
                            json_response(429, RATE_LIMIT_BODY)
                        } else {
                            json_response(200, r#"{"object":"list","data":[{"id":"retry-model"}]}"#)
                        };

                        Ok::<reqwest::Response, OpenAIError>(response)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input that counts how many times it is converted into a `Form`.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let response = if attempt == 0 {
                            json_response(429, RATE_LIMIT_BODY)
                        } else {
                            json_response(200, r#"{"ok":true}"#)
                        };

                        Ok::<reqwest::Response, OpenAIError>(response)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: conversion_count.clone(),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}
1111}