1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14 config::{Config, OpenAIConfig},
15 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
16 executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17 traits::AsyncTryFrom,
18 RequestOptions,
19};
20
21struct RequestParts {
22 request_client: reqwest::Client,
23 method: reqwest::Method,
24 url: String,
25 headers: HeaderMap,
26 query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30 fn build_request_builder(&self) -> reqwest::RequestBuilder {
31 self.request_client
32 .request(self.method.clone(), self.url.clone())
33 .query(&self.query)
34 .headers(self.headers.clone())
35 }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88pub struct Client<C: Config> {
91 request_client: reqwest::Client,
92 executor: SharedExecutor,
93 config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98 C: Config + std::fmt::Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Client")
102 .field("request_client", &self.request_client)
103 .field("config", &self.config)
104 .finish()
105 }
106}
107
108impl<C: Config> Default for Client<C>
109where
110 C: Default,
111{
112 fn default() -> Self {
113 let request_client = reqwest::Client::new();
114 Self {
115 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116 request_client,
117 config: C::default(),
118 }
119 }
120}
121
122impl Client<OpenAIConfig> {
123 pub fn new() -> Self {
125 Self::default()
126 }
127}
128
129impl<C: Config> Client<C> {
130 pub fn build(http_client: reqwest::Client, config: C) -> Self {
132 Self {
133 executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134 request_client: http_client,
135 config,
136 }
137 }
138
139 pub fn with_config(config: C) -> Self {
141 let request_client = reqwest::Client::new();
142 Self {
143 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144 request_client,
145 config,
146 }
147 }
148
149 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153 self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154 self.request_client = http_client;
155 self
156 }
157
158 #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160 pub fn with_http_service<S>(mut self, service: S) -> Self
161 where
162 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163 S::Future: Send + 'static,
164 S::Error: Into<OpenAIError> + Send + Sync + 'static,
165 {
166 self.executor = Arc::new(TowerExecutor::new(service));
171 self
172 }
173
174 #[cfg(all(feature = "middleware", target_family = "wasm"))]
176 pub fn with_http_service<S>(mut self, service: S) -> Self
177 where
178 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179 S::Future: 'static,
180 S::Error: Into<OpenAIError> + 'static,
181 {
182 self.executor = Arc::new(TowerExecutor::new(service));
187 self
188 }
189
190 #[cfg(feature = "model")]
194 pub fn models(&self) -> Models<'_, C> {
195 Models::new(self)
196 }
197
198 #[cfg(feature = "completions")]
200 pub fn completions(&self) -> Completions<'_, C> {
201 Completions::new(self)
202 }
203
204 #[cfg(feature = "chat-completion")]
206 pub fn chat(&self) -> Chat<'_, C> {
207 Chat::new(self)
208 }
209
210 #[cfg(feature = "image")]
212 pub fn images(&self) -> Images<'_, C> {
213 Images::new(self)
214 }
215
216 #[cfg(feature = "moderation")]
218 pub fn moderations(&self) -> Moderations<'_, C> {
219 Moderations::new(self)
220 }
221
222 #[cfg(feature = "file")]
224 pub fn files(&self) -> Files<'_, C> {
225 Files::new(self)
226 }
227
228 #[cfg(feature = "upload")]
230 pub fn uploads(&self) -> Uploads<'_, C> {
231 Uploads::new(self)
232 }
233
234 #[cfg(feature = "finetuning")]
236 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237 FineTuning::new(self)
238 }
239
240 #[cfg(feature = "embedding")]
242 pub fn embeddings(&self) -> Embeddings<'_, C> {
243 Embeddings::new(self)
244 }
245
246 #[cfg(feature = "audio")]
248 pub fn audio(&self) -> Audio<'_, C> {
249 Audio::new(self)
250 }
251
252 #[cfg(feature = "video")]
254 pub fn videos(&self) -> Videos<'_, C> {
255 Videos::new(self)
256 }
257
258 #[cfg(feature = "assistant")]
260 #[deprecated(
261 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262 )]
263 #[allow(deprecated)]
264 pub fn assistants(&self) -> Assistants<'_, C> {
265 Assistants::new(self)
266 }
267
268 #[cfg(feature = "assistant")]
270 #[deprecated(
271 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272 )]
273 #[allow(deprecated)]
274 pub fn threads(&self) -> Threads<'_, C> {
275 Threads::new(self)
276 }
277
278 #[cfg(feature = "vectorstore")]
280 pub fn vector_stores(&self) -> VectorStores<'_, C> {
281 VectorStores::new(self)
282 }
283
284 #[cfg(feature = "batch")]
286 pub fn batches(&self) -> Batches<'_, C> {
287 Batches::new(self)
288 }
289
290 #[cfg(feature = "administration")]
293 pub fn admin(&self) -> Admin<'_, C> {
294 Admin::new(self)
295 }
296
297 #[cfg(feature = "responses")]
299 pub fn responses(&self) -> Responses<'_, C> {
300 Responses::new(self)
301 }
302
303 #[cfg(feature = "responses")]
305 pub fn conversations(&self) -> Conversations<'_, C> {
306 Conversations::new(self)
307 }
308
309 #[cfg(feature = "container")]
311 pub fn containers(&self) -> Containers<'_, C> {
312 Containers::new(self)
313 }
314
315 #[cfg(feature = "skill")]
317 pub fn skills(&self) -> Skills<'_, C> {
318 Skills::new(self)
319 }
320
321 #[cfg(feature = "evals")]
323 pub fn evals(&self) -> Evals<'_, C> {
324 Evals::new(self)
325 }
326
327 #[cfg(feature = "chatkit")]
328 pub fn chatkit(&self) -> Chatkit<'_, C> {
329 Chatkit::new(self)
330 }
331
332 #[cfg(feature = "realtime")]
334 pub fn realtime(&self) -> Realtime<'_, C> {
335 Realtime::new(self)
336 }
337
338 pub fn config(&self) -> &C {
339 &self.config
340 }
341
342 fn build_request_parts(
343 &self,
344 method: reqwest::Method,
345 path: &str,
346 request_options: &RequestOptions,
347 ) -> Arc<RequestParts> {
348 let url = if let Some(path) = request_options.path() {
349 self.config.url(path.as_str())
350 } else {
351 self.config.url(path)
352 };
353 let mut headers = self.config.headers();
354 if let Some(request_headers) = request_options.headers() {
355 headers.extend(request_headers.clone());
356 }
357
358 let mut query = self
359 .config
360 .query()
361 .into_iter()
362 .map(|(key, value)| (key.to_string(), value.to_string()))
363 .collect::<Vec<_>>();
364 query.extend_from_slice(request_options.query());
365
366 Arc::new(RequestParts {
367 request_client: self.request_client.clone(),
368 method,
369 url,
370 headers,
371 query,
372 })
373 }
374
375 fn build_request_factory(
376 &self,
377 method: reqwest::Method,
378 path: &str,
379 request_options: &RequestOptions,
380 ) -> HttpRequestFactory {
381 let request_parts = self.build_request_parts(method, path, request_options);
382
383 HttpRequestFactory::new(move || {
384 let request_parts = request_parts.clone();
385
386 async move {
387 let request = request_parts.build_request_builder().build()?;
388 Ok(request)
389 }
390 })
391 }
392
393 fn build_request_factory_with_json<I>(
394 &self,
395 method: reqwest::Method,
396 path: &str,
397 request: I,
398 request_options: &RequestOptions,
399 ) -> Result<HttpRequestFactory, OpenAIError>
400 where
401 I: Serialize,
402 {
403 let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406 OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407 })?);
408 let request_parts = self.build_request_parts(method, path, request_options);
409
410 Ok(HttpRequestFactory::new(move || {
411 let request_parts = request_parts.clone();
412 let request = request.clone();
413
414 async move {
415 let request_builder = request_parts
416 .build_request_builder()
417 .header(reqwest::header::CONTENT_TYPE, "application/json")
418 .body(request.clone());
419
420 Ok(request_builder.build()?)
421 }
422 }))
423 }
424
425 fn build_request_factory_with_form<F>(
426 &self,
427 method: reqwest::Method,
428 path: &str,
429 form: F,
430 request_options: &RequestOptions,
431 ) -> Result<HttpRequestFactory, OpenAIError>
432 where
433 F: Clone + crate::traits::MaybeSend + 'static,
434 Form: AsyncTryFrom<F, Error = OpenAIError>,
435 {
436 #[cfg(not(target_family = "wasm"))]
440 let form = Arc::new(Mutex::new(form));
441 let request_parts = self.build_request_parts(method, path, request_options);
442
443 Ok(HttpRequestFactory::new(move || {
444 let request_parts = request_parts.clone();
445 let form = form.clone();
446
447 async move {
448 #[cfg(not(target_family = "wasm"))]
449 let form = form
450 .lock()
451 .expect("multipart request factory mutex poisoned")
452 .clone();
453 #[cfg(target_family = "wasm")]
454 let form = form.clone();
455 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456 let request_builder = request_parts.build_request_builder().multipart(form);
457
458 Ok(request_builder.build()?)
459 }
460 }))
461 }
462
463 #[allow(unused)]
465 pub(crate) async fn get<O>(
466 &self,
467 path: &str,
468 request_options: &RequestOptions,
469 ) -> Result<O, OpenAIError>
470 where
471 O: DeserializeOwned,
472 {
473 let request_factory =
474 self.build_request_factory(reqwest::Method::GET, path, request_options);
475 self.execute(request_factory).await
476 }
477
478 #[allow(unused)]
480 pub(crate) async fn delete<O>(
481 &self,
482 path: &str,
483 request_options: &RequestOptions,
484 ) -> Result<O, OpenAIError>
485 where
486 O: DeserializeOwned,
487 {
488 let request_factory =
489 self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490 self.execute(request_factory).await
491 }
492
493 #[allow(unused)]
495 pub(crate) async fn get_raw(
496 &self,
497 path: &str,
498 request_options: &RequestOptions,
499 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500 let request_factory =
501 self.build_request_factory(reqwest::Method::GET, path, request_options);
502 self.execute_raw(request_factory).await
503 }
504
505 #[allow(unused)]
507 pub(crate) async fn post_raw<I>(
508 &self,
509 path: &str,
510 request: I,
511 request_options: &RequestOptions,
512 ) -> Result<(Bytes, HeaderMap), OpenAIError>
513 where
514 I: Serialize,
515 {
516 let request_factory = self.build_request_factory_with_json(
517 reqwest::Method::POST,
518 path,
519 request,
520 request_options,
521 )?;
522 self.execute_raw(request_factory).await
523 }
524
525 #[allow(unused)]
527 pub(crate) async fn post<I, O>(
528 &self,
529 path: &str,
530 request: I,
531 request_options: &RequestOptions,
532 ) -> Result<O, OpenAIError>
533 where
534 I: Serialize,
535 O: DeserializeOwned,
536 {
537 let request_factory = self.build_request_factory_with_json(
538 reqwest::Method::POST,
539 path,
540 request,
541 request_options,
542 )?;
543 self.execute(request_factory).await
544 }
545
546 #[allow(unused)]
548 pub(crate) async fn post_form_raw<F>(
549 &self,
550 path: &str,
551 form: F,
552 request_options: &RequestOptions,
553 ) -> Result<(Bytes, HeaderMap), OpenAIError>
554 where
555 F: Clone + crate::traits::MaybeSend + 'static,
556 Form: AsyncTryFrom<F, Error = OpenAIError>,
557 {
558 let request_factory = self.build_request_factory_with_form(
559 reqwest::Method::POST,
560 path,
561 form,
562 request_options,
563 )?;
564 self.execute_raw(request_factory).await
565 }
566
567 #[allow(unused)]
569 pub(crate) async fn post_form<O, F>(
570 &self,
571 path: &str,
572 form: F,
573 request_options: &RequestOptions,
574 ) -> Result<O, OpenAIError>
575 where
576 O: DeserializeOwned,
577 F: Clone + crate::traits::MaybeSend + 'static,
578 Form: AsyncTryFrom<F, Error = OpenAIError>,
579 {
580 let request_factory = self.build_request_factory_with_form(
581 reqwest::Method::POST,
582 path,
583 form,
584 request_options,
585 )?;
586 self.execute(request_factory).await
587 }
588
589 #[allow(unused)]
590 pub(crate) async fn post_form_stream<O, F>(
591 &self,
592 path: &str,
593 form: F,
594 request_options: &RequestOptions,
595 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596 where
597 F: Clone + crate::traits::MaybeSend + 'static,
598 Form: AsyncTryFrom<F, Error = OpenAIError>,
599 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600 {
601 let request_factory = self.build_request_factory_with_form(
602 reqwest::Method::POST,
603 path,
604 form,
605 request_options,
606 )?;
607
608 self.execute_stream(request_factory).await
609 }
610
611 async fn execute_raw(
612 &self,
613 request_factory: HttpRequestFactory,
614 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615 let response = self.execute_response(request_factory).await?;
616 read_response(response).await
617 }
618
619 async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620 where
621 O: DeserializeOwned,
622 {
623 let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625 let response: O = serde_json::from_slice(bytes.as_ref())
626 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628 Ok(response)
629 }
630
631 async fn execute_response(
632 &self,
633 request_factory: HttpRequestFactory,
634 ) -> Result<Response, OpenAIError> {
635 self.executor.execute(request_factory).await
636 }
637
638 async fn execute_stream<O>(
639 &self,
640 request_factory: HttpRequestFactory,
641 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
642 where
643 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
644 {
645 let response = self.execute_response(request_factory).await?;
646 Ok(stream(response).await)
647 }
648
649 async fn execute_stream_mapped_raw_events<O>(
650 &self,
651 request_factory: HttpRequestFactory,
652 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
653 + crate::traits::MaybeSend
654 + 'static,
655 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
656 where
657 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
658 {
659 let response = self.execute_response(request_factory).await?;
660 Ok(stream_mapped_raw_events(response, event_mapper).await)
661 }
662
663 #[allow(unused)]
665 pub(crate) async fn post_stream<I, O>(
666 &self,
667 path: &str,
668 request: I,
669 request_options: &RequestOptions,
670 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
671 where
672 I: Serialize,
673 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
674 {
675 let request_factory = self.build_request_factory_with_json(
676 reqwest::Method::POST,
677 path,
678 request,
679 request_options,
680 )?;
681 self.execute_stream(request_factory).await
684 }
685
686 #[allow(unused)]
687 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
688 &self,
689 path: &str,
690 request: I,
691 request_options: &RequestOptions,
692 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
693 + crate::traits::MaybeSend
694 + 'static,
695 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
696 where
697 I: Serialize,
698 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
699 {
700 let request_factory = self.build_request_factory_with_json(
701 reqwest::Method::POST,
702 path,
703 request,
704 request_options,
705 )?;
706 self.execute_stream_mapped_raw_events(request_factory, event_mapper)
707 .await
708 }
709
710 #[allow(unused)]
712 pub(crate) async fn get_stream<O>(
713 &self,
714 path: &str,
715 request_options: &RequestOptions,
716 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
717 where
718 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
719 {
720 let request_factory =
721 self.build_request_factory(reqwest::Method::GET, path, request_options);
722 self.execute_stream(request_factory).await
723 }
724}
725
726async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
727 let status = response.status();
728 let headers = response.headers().clone();
729 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
730
731 if status.is_server_error() {
732 let message: String = String::from_utf8_lossy(&bytes).into_owned();
734 tracing::warn!("Server error: {status} - {message}");
735 return Err(OpenAIError::ApiError(ApiError {
736 message,
737 r#type: None,
738 param: None,
739 code: None,
740 }));
741 }
742
743 if !status.is_success() {
745 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
746 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
747
748 return Err(OpenAIError::ApiError(wrapped_error.error));
749 }
750
751 Ok((bytes, headers))
752}
753
754pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
757where
758 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
759{
760 stream_mapped_raw_events(response, |event| {
761 serde_json::from_str::<O>(&event.data)
762 .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
763 })
764 .await
765}
766
/// Turns a server-sent-events response into a stream of mapped events (wasm
/// targets; no task is spawned, the body is read as the stream is polled).
///
/// - A non-2xx response yields the single error produced by `read_response`.
/// - Otherwise each event is passed to `event_mapper`, except that events
///   named `keepalive` are skipped.
/// - The stream ends on a `[DONE]` data payload, when the body ends, or right
///   after the first event-stream (transport/parse) error, matching the
///   native implementation.
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    if !response.status().is_success() {
        return Box::pin(futures::stream::once(async move {
            match read_response(response).await {
                Ok(_) => Err(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                )),
                Err(error) => Err(error),
            }
        }));
    }

    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // The state is `None` once an error has been yielded, so the next poll
    // ends the stream instead of polling a broken event stream again.
    Box::pin(futures::stream::unfold(
        Some((event_stream, event_mapper)),
        |state| async move {
            let Some((mut event_stream, event_mapper)) = state else {
                return None;
            };

            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            None,
                        ));
                    }
                    None => return None,
                };

                if event.data == "[DONE]" {
                    return None;
                }

                if event.event == "keepalive" {
                    continue;
                }

                let response = event_mapper(event);
                return Some((response, Some((event_stream, event_mapper))));
            }
        },
    ))
}
822
823#[cfg(not(target_family = "wasm"))]
824pub(crate) async fn stream_mapped_raw_events<O>(
825 response: Response,
826 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
827) -> crate::types::stream::StreamResponse<O>
828where
829 O: DeserializeOwned + std::marker::Send + 'static,
830{
831 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
832
833 tokio::spawn(async move {
834 if !response.status().is_success() {
835 if let Err(e) = read_response(response).await {
836 let _ = tx.send(Err(e));
837 }
838 return;
839 }
840 let byte_stream = response
841 .bytes_stream()
842 .map(|r| r.map_err(std::io::Error::other));
843 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
844
845 while let Some(ev) = event_stream.next().await {
846 let event = match ev {
847 Ok(e) => e,
848 Err(e) => {
849 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
850 StreamError::EventStream(e.to_string()),
851 ))));
852 break;
853 }
854 };
855 if event.data == "[DONE]" {
856 break;
857 }
858
859 if event.event == "keepalive" {
860 continue;
861 }
862
863 let response = event_mapper(event);
864
865 if tx.send(response).is_err() {
866 break;
867 }
868 }
869 });
870
871 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
872}
873
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// 429 error body returned on the first attempt by the retry tests.
    const RATE_LIMITED_BODY: &str =
        r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#;

    /// Config with a fake API base; the test services answer every request
    /// themselves.
    fn test_config() -> OpenAIConfig {
        OpenAIConfig::new()
            .with_api_base("http://example.test")
            .with_api_key("test-key")
    }

    /// Builds a canned `reqwest::Response` with the given status, content type and body.
    fn canned_response(status: u16, content_type: &str, body: &'static str) -> reqwest::Response {
        HttpResponse::builder()
            .status(status)
            .header("content-type", content_type)
            .body(reqwest::Body::from(body))
            .unwrap()
            .into()
    }

    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(canned_response(
                            200,
                            "application/json",
                            "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(canned_response(
                            200,
                            "text/event-stream",
                            "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                        ))
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        // The `[DONE]` sentinel must end the stream rather than surface as an item.
        assert!(stream.next().await.is_none());
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let reply = if attempt == 0 {
                            canned_response(429, "application/json", RATE_LIMITED_BODY)
                        } else {
                            canned_response(
                                200,
                                "application/json",
                                r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                            )
                        };

                        Ok::<reqwest::Response, OpenAIError>(reply)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input that counts how many times it is converted into a `Form`.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        let reply = if attempt == 0 {
                            canned_response(429, "application/json", RATE_LIMITED_BODY)
                        } else {
                            canned_response(200, "application/json", r#"{"ok":true}"#)
                        };

                        Ok::<reqwest::Response, OpenAIError>(reply)
                    }
                }))
        };

        let client = Client::with_config(test_config()).with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: conversion_count.clone(),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}