1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14 config::{Config, OpenAIConfig},
15 error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
16 executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17 traits::AsyncTryFrom,
18 RequestOptions,
19};
20
21struct RequestParts {
22 request_client: reqwest::Client,
23 method: reqwest::Method,
24 url: String,
25 headers: HeaderMap,
26 query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30 fn build_request_builder(&self) -> reqwest::RequestBuilder {
31 self.request_client
32 .request(self.method.clone(), self.url.clone())
33 .query(&self.query)
34 .headers(self.headers.clone())
35 }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
/// Async API client: builds requests from a [`Config`] and sends them through a pluggable,
/// shared executor.
#[derive(Clone)]
pub struct Client<C: Config> {
    /// Used to construct every outgoing request (its clone is stored in each `RequestParts`).
    request_client: reqwest::Client,
    /// Sends the built requests: a `ReqwestExecutor` by default, or a tower-based executor
    /// after `with_http_service`.
    executor: SharedExecutor,
    /// Supplies the request URL, default headers and default query parameters.
    config: C,
}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98 C: Config + std::fmt::Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Client")
102 .field("request_client", &self.request_client)
103 .field("config", &self.config)
104 .finish()
105 }
106}
107
108impl<C: Config> Default for Client<C>
109where
110 C: Default,
111{
112 fn default() -> Self {
113 let request_client = reqwest::Client::new();
114 Self {
115 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116 request_client,
117 config: C::default(),
118 }
119 }
120}
121
122impl Client<OpenAIConfig> {
123 pub fn new() -> Self {
125 Self::default()
126 }
127}
128
129impl<C: Config> Client<C> {
130 pub fn build(http_client: reqwest::Client, config: C) -> Self {
132 Self {
133 executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134 request_client: http_client,
135 config,
136 }
137 }
138
139 pub fn with_config(config: C) -> Self {
141 let request_client = reqwest::Client::new();
142 Self {
143 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144 request_client,
145 config,
146 }
147 }
148
149 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153 self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154 self.request_client = http_client;
155 self
156 }
157
158 #[cfg(all(feature = "middleware", not(target_family = "wasm")))]
160 pub fn with_http_service<S>(mut self, service: S) -> Self
161 where
162 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
163 S::Future: Send + 'static,
164 S::Error: Into<OpenAIError> + Send + Sync + 'static,
165 {
166 self.executor = Arc::new(TowerExecutor::new(service));
171 self
172 }
173
174 #[cfg(all(feature = "middleware", target_family = "wasm"))]
176 pub fn with_http_service<S>(mut self, service: S) -> Self
177 where
178 S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
179 S::Future: 'static,
180 S::Error: Into<OpenAIError> + 'static,
181 {
182 self.executor = Arc::new(TowerExecutor::new(service));
187 self
188 }
189
190 #[cfg(feature = "model")]
194 pub fn models(&self) -> Models<'_, C> {
195 Models::new(self)
196 }
197
198 #[cfg(feature = "completions")]
200 pub fn completions(&self) -> Completions<'_, C> {
201 Completions::new(self)
202 }
203
204 #[cfg(feature = "chat-completion")]
206 pub fn chat(&self) -> Chat<'_, C> {
207 Chat::new(self)
208 }
209
210 #[cfg(feature = "image")]
212 pub fn images(&self) -> Images<'_, C> {
213 Images::new(self)
214 }
215
216 #[cfg(feature = "moderation")]
218 pub fn moderations(&self) -> Moderations<'_, C> {
219 Moderations::new(self)
220 }
221
222 #[cfg(feature = "file")]
224 pub fn files(&self) -> Files<'_, C> {
225 Files::new(self)
226 }
227
228 #[cfg(feature = "upload")]
230 pub fn uploads(&self) -> Uploads<'_, C> {
231 Uploads::new(self)
232 }
233
234 #[cfg(feature = "finetuning")]
236 pub fn fine_tuning(&self) -> FineTuning<'_, C> {
237 FineTuning::new(self)
238 }
239
240 #[cfg(feature = "embedding")]
242 pub fn embeddings(&self) -> Embeddings<'_, C> {
243 Embeddings::new(self)
244 }
245
246 #[cfg(feature = "audio")]
248 pub fn audio(&self) -> Audio<'_, C> {
249 Audio::new(self)
250 }
251
252 #[cfg(feature = "video")]
254 pub fn videos(&self) -> Videos<'_, C> {
255 Videos::new(self)
256 }
257
258 #[cfg(feature = "assistant")]
260 #[deprecated(
261 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
262 )]
263 #[allow(deprecated)]
264 pub fn assistants(&self) -> Assistants<'_, C> {
265 Assistants::new(self)
266 }
267
268 #[cfg(feature = "assistant")]
270 #[deprecated(
271 note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
272 )]
273 #[allow(deprecated)]
274 pub fn threads(&self) -> Threads<'_, C> {
275 Threads::new(self)
276 }
277
278 #[cfg(feature = "vectorstore")]
280 pub fn vector_stores(&self) -> VectorStores<'_, C> {
281 VectorStores::new(self)
282 }
283
284 #[cfg(feature = "batch")]
286 pub fn batches(&self) -> Batches<'_, C> {
287 Batches::new(self)
288 }
289
290 #[cfg(feature = "administration")]
293 pub fn admin(&self) -> Admin<'_, C> {
294 Admin::new(self)
295 }
296
297 #[cfg(feature = "responses")]
299 pub fn responses(&self) -> Responses<'_, C> {
300 Responses::new(self)
301 }
302
303 #[cfg(feature = "responses")]
305 pub fn conversations(&self) -> Conversations<'_, C> {
306 Conversations::new(self)
307 }
308
309 #[cfg(feature = "container")]
311 pub fn containers(&self) -> Containers<'_, C> {
312 Containers::new(self)
313 }
314
315 #[cfg(feature = "skill")]
317 pub fn skills(&self) -> Skills<'_, C> {
318 Skills::new(self)
319 }
320
321 #[cfg(feature = "evals")]
323 pub fn evals(&self) -> Evals<'_, C> {
324 Evals::new(self)
325 }
326
327 #[cfg(feature = "chatkit")]
328 pub fn chatkit(&self) -> Chatkit<'_, C> {
329 Chatkit::new(self)
330 }
331
332 #[cfg(feature = "realtime")]
334 pub fn realtime(&self) -> Realtime<'_, C> {
335 Realtime::new(self)
336 }
337
338 pub fn config(&self) -> &C {
339 &self.config
340 }
341
342 fn build_request_parts(
343 &self,
344 method: reqwest::Method,
345 path: &str,
346 request_options: &RequestOptions,
347 ) -> Arc<RequestParts> {
348 let url = if let Some(path) = request_options.path() {
349 self.config.url(path.as_str())
350 } else {
351 self.config.url(path)
352 };
353 let mut headers = self.config.headers();
354 if let Some(request_headers) = request_options.headers() {
355 headers.extend(request_headers.clone());
356 }
357
358 let mut query = self
359 .config
360 .query()
361 .into_iter()
362 .map(|(key, value)| (key.to_string(), value.to_string()))
363 .collect::<Vec<_>>();
364 query.extend_from_slice(request_options.query());
365
366 Arc::new(RequestParts {
367 request_client: self.request_client.clone(),
368 method,
369 url,
370 headers,
371 query,
372 })
373 }
374
375 fn build_request_factory(
376 &self,
377 method: reqwest::Method,
378 path: &str,
379 request_options: &RequestOptions,
380 ) -> HttpRequestFactory {
381 let request_parts = self.build_request_parts(method, path, request_options);
382
383 HttpRequestFactory::new(move || {
384 let request_parts = request_parts.clone();
385
386 async move {
387 let request = request_parts.build_request_builder().build()?;
388 Ok(request)
389 }
390 })
391 }
392
393 fn build_request_factory_with_json<I>(
394 &self,
395 method: reqwest::Method,
396 path: &str,
397 request: I,
398 request_options: &RequestOptions,
399 ) -> Result<HttpRequestFactory, OpenAIError>
400 where
401 I: Serialize,
402 {
403 let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
408 OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
409 })?);
410 let request_parts = self.build_request_parts(method, path, request_options);
411
412 Ok(HttpRequestFactory::new(move || {
413 let request_parts = request_parts.clone();
414 let request = request.clone();
415
416 async move {
417 let request_builder = request_parts
418 .build_request_builder()
419 .header(reqwest::header::CONTENT_TYPE, "application/json")
420 .body(request.clone());
421
422 Ok(request_builder.build()?)
423 }
424 }))
425 }
426
427 fn build_request_factory_with_form<F>(
428 &self,
429 method: reqwest::Method,
430 path: &str,
431 form: F,
432 request_options: &RequestOptions,
433 ) -> Result<HttpRequestFactory, OpenAIError>
434 where
435 F: Clone + crate::traits::MaybeSend + 'static,
436 Form: AsyncTryFrom<F, Error = OpenAIError>,
437 {
438 #[cfg(not(target_family = "wasm"))]
442 let form = Arc::new(Mutex::new(form));
443 let request_parts = self.build_request_parts(method, path, request_options);
444
445 Ok(HttpRequestFactory::new(move || {
446 let request_parts = request_parts.clone();
447 let form = form.clone();
448
449 async move {
450 #[cfg(not(target_family = "wasm"))]
451 let form = form
452 .lock()
453 .expect("multipart request factory mutex poisoned")
454 .clone();
455 #[cfg(target_family = "wasm")]
456 let form = form.clone();
457 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
458 let request_builder = request_parts.build_request_builder().multipart(form);
459
460 Ok(request_builder.build()?)
461 }
462 }))
463 }
464
465 #[allow(unused)]
467 pub(crate) async fn get<O>(
468 &self,
469 path: &str,
470 request_options: &RequestOptions,
471 ) -> Result<O, OpenAIError>
472 where
473 O: DeserializeOwned,
474 {
475 let request_factory =
476 self.build_request_factory(reqwest::Method::GET, path, request_options);
477 self.execute(request_factory).await
478 }
479
480 #[allow(unused)]
482 pub(crate) async fn delete<O>(
483 &self,
484 path: &str,
485 request_options: &RequestOptions,
486 ) -> Result<O, OpenAIError>
487 where
488 O: DeserializeOwned,
489 {
490 let request_factory =
491 self.build_request_factory(reqwest::Method::DELETE, path, request_options);
492 self.execute(request_factory).await
493 }
494
495 #[allow(unused)]
497 pub(crate) async fn get_raw(
498 &self,
499 path: &str,
500 request_options: &RequestOptions,
501 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
502 let request_factory =
503 self.build_request_factory(reqwest::Method::GET, path, request_options);
504 self.execute_raw(request_factory).await
505 }
506
507 #[allow(unused)]
509 pub(crate) async fn post_raw<I>(
510 &self,
511 path: &str,
512 request: I,
513 request_options: &RequestOptions,
514 ) -> Result<(Bytes, HeaderMap), OpenAIError>
515 where
516 I: Serialize,
517 {
518 let request_factory = self.build_request_factory_with_json(
519 reqwest::Method::POST,
520 path,
521 request,
522 request_options,
523 )?;
524 self.execute_raw(request_factory).await
525 }
526
527 #[allow(unused)]
529 pub(crate) async fn post<I, O>(
530 &self,
531 path: &str,
532 request: I,
533 request_options: &RequestOptions,
534 ) -> Result<O, OpenAIError>
535 where
536 I: Serialize,
537 O: DeserializeOwned,
538 {
539 let request_factory = self.build_request_factory_with_json(
540 reqwest::Method::POST,
541 path,
542 request,
543 request_options,
544 )?;
545 self.execute(request_factory).await
546 }
547
548 #[allow(unused)]
550 pub(crate) async fn post_form_raw<F>(
551 &self,
552 path: &str,
553 form: F,
554 request_options: &RequestOptions,
555 ) -> Result<(Bytes, HeaderMap), OpenAIError>
556 where
557 F: Clone + crate::traits::MaybeSend + 'static,
558 Form: AsyncTryFrom<F, Error = OpenAIError>,
559 {
560 let request_factory = self.build_request_factory_with_form(
561 reqwest::Method::POST,
562 path,
563 form,
564 request_options,
565 )?;
566 self.execute_raw(request_factory).await
567 }
568
569 #[allow(unused)]
571 pub(crate) async fn post_form<O, F>(
572 &self,
573 path: &str,
574 form: F,
575 request_options: &RequestOptions,
576 ) -> Result<O, OpenAIError>
577 where
578 O: DeserializeOwned,
579 F: Clone + crate::traits::MaybeSend + 'static,
580 Form: AsyncTryFrom<F, Error = OpenAIError>,
581 {
582 let request_factory = self.build_request_factory_with_form(
583 reqwest::Method::POST,
584 path,
585 form,
586 request_options,
587 )?;
588 self.execute(request_factory).await
589 }
590
591 #[allow(unused)]
592 pub(crate) async fn post_form_stream<O, F>(
593 &self,
594 path: &str,
595 form: F,
596 request_options: &RequestOptions,
597 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
598 where
599 F: Clone + crate::traits::MaybeSend + 'static,
600 Form: AsyncTryFrom<F, Error = OpenAIError>,
601 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
602 {
603 let request_factory = self.build_request_factory_with_form(
604 reqwest::Method::POST,
605 path,
606 form,
607 request_options,
608 )?;
609
610 self.execute_stream(request_factory).await
611 }
612
613 async fn execute_raw(
614 &self,
615 request_factory: HttpRequestFactory,
616 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
617 let response = self.execute_response(request_factory).await?;
618 read_response(response).await
619 }
620
621 async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
622 where
623 O: DeserializeOwned,
624 {
625 let (bytes, _headers) = self.execute_raw(request_factory).await?;
626
627 let response: O = serde_json::from_slice(bytes.as_ref())
628 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
629
630 Ok(response)
631 }
632
633 async fn execute_response(
634 &self,
635 request_factory: HttpRequestFactory,
636 ) -> Result<Response, OpenAIError> {
637 self.executor.execute(request_factory).await
638 }
639
640 async fn execute_stream<O>(
641 &self,
642 request_factory: HttpRequestFactory,
643 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
644 where
645 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
646 {
647 let response = self.execute_response(request_factory).await?;
648 Ok(stream(response).await)
649 }
650
651 async fn execute_stream_mapped_raw_events<O>(
652 &self,
653 request_factory: HttpRequestFactory,
654 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
655 + crate::traits::MaybeSend
656 + 'static,
657 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
658 where
659 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
660 {
661 let response = self.execute_response(request_factory).await?;
662 Ok(stream_mapped_raw_events(response, event_mapper).await)
663 }
664
665 #[allow(unused)]
667 pub(crate) async fn post_stream<I, O>(
668 &self,
669 path: &str,
670 request: I,
671 request_options: &RequestOptions,
672 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
673 where
674 I: Serialize,
675 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
676 {
677 let request_factory = self.build_request_factory_with_json(
678 reqwest::Method::POST,
679 path,
680 request,
681 request_options,
682 )?;
683 self.execute_stream(request_factory).await
686 }
687
688 #[allow(unused)]
689 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
690 &self,
691 path: &str,
692 request: I,
693 request_options: &RequestOptions,
694 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
695 + crate::traits::MaybeSend
696 + 'static,
697 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
698 where
699 I: Serialize,
700 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
701 {
702 let request_factory = self.build_request_factory_with_json(
703 reqwest::Method::POST,
704 path,
705 request,
706 request_options,
707 )?;
708 self.execute_stream_mapped_raw_events(request_factory, event_mapper)
709 .await
710 }
711
712 #[allow(unused)]
714 pub(crate) async fn get_stream<O>(
715 &self,
716 path: &str,
717 request_options: &RequestOptions,
718 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
719 where
720 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
721 {
722 let request_factory =
723 self.build_request_factory(reqwest::Method::GET, path, request_options);
724 self.execute_stream(request_factory).await
725 }
726}
727
728async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
729 let status = response.status();
730 let headers = response.headers().clone();
731 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
732
733 if status.is_server_error() {
734 let message: String = String::from_utf8_lossy(&bytes).into_owned();
736 tracing::warn!("Server error: {status} - {message}");
737 return Err(OpenAIError::ApiError(ApiError {
738 message,
739 r#type: None,
740 param: None,
741 code: None,
742 }));
743 }
744
745 if !status.is_success() {
747 let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
748 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
749
750 return Err(OpenAIError::ApiError(wrapped_error.error));
751 }
752
753 Ok((bytes, headers))
754}
755
756pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
759where
760 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
761{
762 stream_mapped_raw_events(response, |event| {
763 serde_json::from_str::<O>(&event.data)
764 .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
765 })
766 .await
767}
768
/// Wasm implementation: maps every raw SSE event of `response` through `event_mapper`.
///
/// - A non-2xx status yields exactly one error item (the error from `read_response`, or an
///   `InvalidArgument` error if it unexpectedly succeeded) and then ends.
/// - Events whose `event` field is `"keepalive"` are skipped.
/// - The `[DONE]` event is passed to `event_mapper` like any other event, and the stream ends
///   after yielding its mapped result.
/// - A transport/parse error from the event stream is yielded once, then the stream ends.
///
/// Built with `futures::stream::unfold` rather than a spawned task (unlike the native
/// version).
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    if !response.status().is_success() {
        return Box::pin(futures::stream::once(async move {
            match read_response(response).await {
                Ok(_) => Err(OpenAIError::InvalidArgument(
                    "stream request failed without an error body".into(),
                )),
                Err(error) => Err(error),
            }
        }));
    }

    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // State: (event source, mapper, `finished`). `finished` is set after `[DONE]` or an
    // error so the *next* poll ends the stream after the last item has been yielded.
    Box::pin(futures::stream::unfold(
        (event_stream, event_mapper, false),
        |(mut event_stream, event_mapper, finished)| async move {
            if finished {
                return None;
            }

            // Loop only to skip keepalive events; every other path returns.
            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            (event_stream, event_mapper, true),
                        ));
                    }
                    None => return None,
                };

                let done = event.data == "[DONE]";

                if event.event == "keepalive" {
                    continue;
                }

                let response = event_mapper(event);
                return Some((response, (event_stream, event_mapper, done)));
            }
        },
    ))
}
826
827#[cfg(not(target_family = "wasm"))]
828pub(crate) async fn stream_mapped_raw_events<O>(
829 response: Response,
830 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
831) -> crate::types::stream::StreamResponse<O>
832where
833 O: DeserializeOwned + std::marker::Send + 'static,
834{
835 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
836
837 tokio::spawn(async move {
838 if !response.status().is_success() {
839 if let Err(e) = read_response(response).await {
840 let _ = tx.send(Err(e));
841 }
842 return;
843 }
844 let byte_stream = response
845 .bytes_stream()
846 .map(|r| r.map_err(std::io::Error::other));
847 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
848
849 while let Some(ev) = event_stream.next().await {
850 let event = match ev {
851 Ok(e) => e,
852 Err(e) => {
853 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
854 StreamError::EventStream(e.to_string()),
855 ))));
856 break;
857 }
858 };
859 let done = event.data == "[DONE]";
860
861 if event.event == "keepalive" {
862 continue;
863 }
864
865 let response = event_mapper(event);
866
867 if tx.send(response).is_err() {
868 break;
869 }
870
871 if done {
872 break;
873 }
874 }
875 });
876
877 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
878}
879
// Tests for routing requests through a tower middleware service (`with_http_service`).
// Each service is a `service_fn` that builds the request from the factory, checks it, and
// returns a canned `http::Response` converted into a `reqwest::Response`; no network is used.
#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use futures::StreamExt;
    use http::Response as HttpResponse;
    use serde_json::json;
    use tower::{service_fn, ServiceBuilder};

    use super::Client;
    use crate::{
        config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
        retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
    };

    /// A unary GET goes through the middleware service exactly once and its JSON body is
    /// deserialized.
    #[tokio::test]
    async fn unary_requests_dispatch_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["object"], "list");
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// A streaming POST is opened through the middleware service and its first SSE event is
    /// deserialized.
    // NOTE(review): only the first item is read, so how the stream terminates after
    // `data: [DONE]` is not checked here.
    #[tokio::test]
    async fn stream_requests_open_through_middleware_service() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .concurrency_limit(1)
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/responses");
                        request_count.fetch_add(1, Ordering::SeqCst);
                        Ok::<reqwest::Response, OpenAIError>(
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "text/event-stream")
                                .body(reqwest::Body::from(
                                    "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
                                ))
                                .unwrap()
                                .into(),
                        )
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let mut stream = client
            .post_stream::<_, serde_json::Value>(
                "/responses",
                json!({ "stream": true }),
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();

        assert_eq!(first, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    /// With `SimpleRetryPolicy` in the stack, a first 429 response is retried and the
    /// second (200) response is returned: two attempts in total.
    #[tokio::test]
    async fn middleware_retry_policy_retries_429_responses() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.url().path(), "/models");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited. Later attempts: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
                                ))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();

        assert_eq!(value["data"][0]["id"], "retry-model");
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
    }

    /// Multipart input that counts how many times it is converted into a `Form`.
    #[derive(Clone)]
    struct RetryableMultipartInput {
        conversions: Arc<AtomicUsize>,
    }

    impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
        type Error = OpenAIError;

        async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
            value.conversions.fetch_add(1, Ordering::SeqCst);
            Ok(reqwest::multipart::Form::new().text("field", "value"))
        }
    }

    /// A retried multipart POST rebuilds its `Form` for every attempt: two requests and two
    /// conversions.
    #[tokio::test]
    async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let conversion_count = Arc::new(AtomicUsize::new(0));

        let service = {
            let request_count = request_count.clone();
            ServiceBuilder::new()
                .retry(SimpleRetryPolicy::default())
                .service(service_fn(move |factory: HttpRequestFactory| {
                    let request_count = request_count.clone();
                    async move {
                        let request = factory.build().await?;
                        assert_eq!(request.method(), reqwest::Method::POST);
                        assert_eq!(request.url().path(), "/files");
                        let attempt = request_count.fetch_add(1, Ordering::SeqCst);

                        // First attempt: rate-limited. Later attempts: success.
                        let response = if attempt == 0 {
                            HttpResponse::builder()
                                .status(429)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(
                                    r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
                                ))
                                .unwrap()
                        } else {
                            HttpResponse::builder()
                                .status(200)
                                .header("content-type", "application/json")
                                .body(reqwest::Body::from(r#"{"ok":true}"#))
                                .unwrap()
                        };

                        Ok::<reqwest::Response, OpenAIError>(response.into())
                    }
                }))
        };

        let client = Client::with_config(
            OpenAIConfig::new()
                .with_api_base("http://example.test")
                .with_api_key("test-key"),
        )
        .with_http_service(service);

        let value: serde_json::Value = client
            .post_form(
                "/files",
                RetryableMultipartInput {
                    conversions: conversion_count.clone(),
                },
                &RequestOptions::new(),
            )
            .await
            .unwrap();

        assert_eq!(value, json!({ "ok": true }));
        assert_eq!(request_count.load(Ordering::SeqCst), 2);
        assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
    }
}