1use std::sync::Arc;
2#[cfg(not(target_family = "wasm"))]
3use std::sync::Mutex;
4
5use bytes::Bytes;
6use futures::stream::StreamExt;
7use reqwest::{header::HeaderMap, multipart::Form, Response};
8use serde::{de::DeserializeOwned, Serialize};
9
10use crate::error::StreamError;
11#[cfg(feature = "middleware")]
12use crate::executor::TowerExecutor;
13use crate::{
14 config::{Config, OpenAIConfig},
15 error::{map_deserialization_error, ApiError, ApiErrorResponse, OpenAIError, WrappedError},
16 executor::{HttpRequestFactory, ReqwestExecutor, SharedExecutor},
17 traits::AsyncTryFrom,
18 RequestOptions,
19};
20
21struct RequestParts {
22 request_client: reqwest::Client,
23 method: reqwest::Method,
24 url: String,
25 headers: HeaderMap,
26 query: Vec<(String, String)>,
27}
28
29impl RequestParts {
30 fn build_request_builder(&self) -> reqwest::RequestBuilder {
31 self.request_client
32 .request(self.method.clone(), self.url.clone())
33 .query(&self.query)
34 .headers(self.headers.clone())
35 }
36}
37
38#[cfg(feature = "administration")]
39use crate::admin::Admin;
40#[cfg(feature = "chatkit")]
41use crate::chatkit::Chatkit;
42#[cfg(feature = "file")]
43use crate::file::Files;
44#[cfg(feature = "image")]
45use crate::image::Images;
46#[cfg(feature = "moderation")]
47use crate::moderation::Moderations;
48#[cfg(feature = "assistant")]
49#[allow(deprecated)]
50use crate::Assistants;
51#[cfg(feature = "audio")]
52use crate::Audio;
53#[cfg(feature = "batch")]
54use crate::Batches;
55#[cfg(feature = "chat-completion")]
56use crate::Chat;
57#[cfg(feature = "completions")]
58use crate::Completions;
59#[cfg(feature = "container")]
60use crate::Containers;
61#[cfg(feature = "responses")]
62use crate::Conversations;
63#[cfg(feature = "embedding")]
64use crate::Embeddings;
65#[cfg(feature = "evals")]
66use crate::Evals;
67#[cfg(feature = "finetuning")]
68use crate::FineTuning;
69#[cfg(feature = "model")]
70use crate::Models;
71#[cfg(feature = "realtime")]
72use crate::Realtime;
73#[cfg(feature = "responses")]
74use crate::Responses;
75#[cfg(feature = "skill")]
76use crate::Skills;
77#[cfg(feature = "assistant")]
78#[allow(deprecated)]
79use crate::Threads;
80#[cfg(feature = "upload")]
81use crate::Uploads;
82#[cfg(feature = "vectorstore")]
83use crate::VectorStores;
84#[cfg(feature = "video")]
85use crate::Videos;
86
87#[derive(Clone)]
88pub struct Client<C: Config> {
91 request_client: reqwest::Client,
92 executor: SharedExecutor,
93 config: C,
94}
95
96impl<C> std::fmt::Debug for Client<C>
97where
98 C: Config + std::fmt::Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Client")
102 .field("request_client", &self.request_client)
103 .field("config", &self.config)
104 .finish()
105 }
106}
107
108impl<C: Config> Default for Client<C>
109where
110 C: Default,
111{
112 fn default() -> Self {
113 let request_client = reqwest::Client::new();
114 Self {
115 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
116 request_client,
117 config: C::default(),
118 }
119 }
120}
121
122impl Client<OpenAIConfig> {
123 pub fn new() -> Self {
125 Self::default()
126 }
127}
128
129impl<C: Config> Client<C> {
130 pub fn build(http_client: reqwest::Client, config: C) -> Self {
132 Self {
133 executor: Arc::new(ReqwestExecutor::new(http_client.clone())),
134 request_client: http_client,
135 config,
136 }
137 }
138
139 pub fn with_config(config: C) -> Self {
141 let request_client = reqwest::Client::new();
142 Self {
143 executor: Arc::new(ReqwestExecutor::new(request_client.clone())),
144 request_client,
145 config,
146 }
147 }
148
149 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
153 self.executor = Arc::new(ReqwestExecutor::new(http_client.clone()));
154 self.request_client = http_client;
155 self
156 }
157
/// Routes requests through a `tower::Service` (non-wasm targets).
///
/// The service receives an [`HttpRequestFactory`] and must produce a
/// `reqwest::Response`. It is wrapped in a `TowerExecutor`, which replaces
/// the current executor; the stored `reqwest::Client` is left untouched.
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
pub fn with_http_service<S>(mut self, service: S) -> Self
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    let executor = TowerExecutor::new(service);
    self.executor = Arc::new(executor);
    self
}
173
/// Routes requests through a `tower::Service` (wasm targets).
///
/// Same as the non-wasm variant, but without the `Send`/`Sync` bounds on the
/// service, its future and its error type.
#[cfg(all(feature = "middleware", target_family = "wasm"))]
pub fn with_http_service<S>(mut self, service: S) -> Self
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    let executor = TowerExecutor::new(service);
    self.executor = Arc::new(executor);
    self
}
189
/// Returns the [`Models`] API group bound to this client.
#[cfg(feature = "model")]
pub fn models(&self) -> Models<'_, C> {
    Models::new(self)
}
197
/// Returns the [`Completions`] API group bound to this client.
#[cfg(feature = "completions")]
pub fn completions(&self) -> Completions<'_, C> {
    Completions::new(self)
}
203
/// Returns the [`Chat`] API group bound to this client.
#[cfg(feature = "chat-completion")]
pub fn chat(&self) -> Chat<'_, C> {
    Chat::new(self)
}
209
/// Returns the [`Images`] API group bound to this client.
#[cfg(feature = "image")]
pub fn images(&self) -> Images<'_, C> {
    Images::new(self)
}
215
/// Returns the [`Moderations`] API group bound to this client.
#[cfg(feature = "moderation")]
pub fn moderations(&self) -> Moderations<'_, C> {
    Moderations::new(self)
}
221
/// Returns the [`Files`] API group bound to this client.
#[cfg(feature = "file")]
pub fn files(&self) -> Files<'_, C> {
    Files::new(self)
}
227
/// Returns the [`Uploads`] API group bound to this client.
#[cfg(feature = "upload")]
pub fn uploads(&self) -> Uploads<'_, C> {
    Uploads::new(self)
}
233
/// Returns the [`FineTuning`] API group bound to this client.
#[cfg(feature = "finetuning")]
pub fn fine_tuning(&self) -> FineTuning<'_, C> {
    FineTuning::new(self)
}
239
/// Returns the [`Embeddings`] API group bound to this client.
#[cfg(feature = "embedding")]
pub fn embeddings(&self) -> Embeddings<'_, C> {
    Embeddings::new(self)
}
245
/// Returns the [`Audio`] API group bound to this client.
#[cfg(feature = "audio")]
pub fn audio(&self) -> Audio<'_, C> {
    Audio::new(self)
}
251
/// Returns the [`Videos`] API group bound to this client.
#[cfg(feature = "video")]
pub fn videos(&self) -> Videos<'_, C> {
    Videos::new(self)
}
257
/// Returns the [`Assistants`] API group bound to this client.
///
/// Deprecated; see the deprecation note on this method.
#[cfg(feature = "assistant")]
#[deprecated(
    note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
)]
#[allow(deprecated)]
pub fn assistants(&self) -> Assistants<'_, C> {
    Assistants::new(self)
}
267
/// Returns the [`Threads`] API group bound to this client.
///
/// Deprecated; see the deprecation note on this method.
#[cfg(feature = "assistant")]
#[deprecated(
    note = "Assistants API is deprecated and will be removed in August 2026. Use the Responses API."
)]
#[allow(deprecated)]
pub fn threads(&self) -> Threads<'_, C> {
    Threads::new(self)
}
277
/// Returns the [`VectorStores`] API group bound to this client.
#[cfg(feature = "vectorstore")]
pub fn vector_stores(&self) -> VectorStores<'_, C> {
    VectorStores::new(self)
}
283
/// Returns the [`Batches`] API group bound to this client.
#[cfg(feature = "batch")]
pub fn batches(&self) -> Batches<'_, C> {
    Batches::new(self)
}
289
/// Returns the [`Admin`] API group bound to this client.
#[cfg(feature = "administration")]
pub fn admin(&self) -> Admin<'_, C> {
    Admin::new(self)
}
296
/// Returns the [`Responses`] API group bound to this client.
#[cfg(feature = "responses")]
pub fn responses(&self) -> Responses<'_, C> {
    Responses::new(self)
}
302
/// Returns the [`Conversations`] API group bound to this client.
///
/// Gated behind the same `responses` feature as [`Responses`].
#[cfg(feature = "responses")]
pub fn conversations(&self) -> Conversations<'_, C> {
    Conversations::new(self)
}
308
/// Returns the [`Containers`] API group bound to this client.
#[cfg(feature = "container")]
pub fn containers(&self) -> Containers<'_, C> {
    Containers::new(self)
}
314
/// Returns the [`Skills`] API group bound to this client.
#[cfg(feature = "skill")]
pub fn skills(&self) -> Skills<'_, C> {
    Skills::new(self)
}
320
/// Returns the [`Evals`] API group bound to this client.
#[cfg(feature = "evals")]
pub fn evals(&self) -> Evals<'_, C> {
    Evals::new(self)
}
326
/// Returns the [`Chatkit`] API group bound to this client.
#[cfg(feature = "chatkit")]
pub fn chatkit(&self) -> Chatkit<'_, C> {
    Chatkit::new(self)
}
331
/// Returns the [`Realtime`] API group bound to this client.
#[cfg(feature = "realtime")]
pub fn realtime(&self) -> Realtime<'_, C> {
    Realtime::new(self)
}
337
338 pub fn config(&self) -> &C {
339 &self.config
340 }
341
342 fn build_request_parts(
343 &self,
344 method: reqwest::Method,
345 path: &str,
346 request_options: &RequestOptions,
347 ) -> Arc<RequestParts> {
348 let url = if let Some(path) = request_options.path() {
349 self.config.url(path.as_str())
350 } else {
351 self.config.url(path)
352 };
353 let mut headers = self.config.headers();
354 if let Some(request_headers) = request_options.headers() {
355 headers.extend(request_headers.clone());
356 }
357
358 let mut query = self
359 .config
360 .query()
361 .into_iter()
362 .map(|(key, value)| (key.to_string(), value.to_string()))
363 .collect::<Vec<_>>();
364 query.extend_from_slice(request_options.query());
365
366 Arc::new(RequestParts {
367 request_client: self.request_client.clone(),
368 method,
369 url,
370 headers,
371 query,
372 })
373 }
374
375 fn build_request_factory(
376 &self,
377 method: reqwest::Method,
378 path: &str,
379 request_options: &RequestOptions,
380 ) -> HttpRequestFactory {
381 let request_parts = self.build_request_parts(method, path, request_options);
382
383 HttpRequestFactory::new(move || {
384 let request_parts = request_parts.clone();
385
386 async move {
387 let request = request_parts.build_request_builder().build()?;
388 Ok(request)
389 }
390 })
391 }
392
393 fn build_request_factory_with_json<I>(
394 &self,
395 method: reqwest::Method,
396 path: &str,
397 request: I,
398 request_options: &RequestOptions,
399 ) -> Result<HttpRequestFactory, OpenAIError>
400 where
401 I: Serialize,
402 {
403 let request = Bytes::from(serde_json::to_vec(&request).map_err(|error| {
406 OpenAIError::InvalidArgument(format!("failed to serialize request: {error}"))
407 })?);
408 let request_parts = self.build_request_parts(method, path, request_options);
409
410 Ok(HttpRequestFactory::new(move || {
411 let request_parts = request_parts.clone();
412 let request = request.clone();
413
414 async move {
415 let request_builder = request_parts
416 .build_request_builder()
417 .header(reqwest::header::CONTENT_TYPE, "application/json")
418 .body(request.clone());
419
420 Ok(request_builder.build()?)
421 }
422 }))
423 }
424
425 fn build_request_factory_with_form<F>(
426 &self,
427 method: reqwest::Method,
428 path: &str,
429 form: F,
430 request_options: &RequestOptions,
431 ) -> Result<HttpRequestFactory, OpenAIError>
432 where
433 F: Clone + crate::traits::MaybeSend + 'static,
434 Form: AsyncTryFrom<F, Error = OpenAIError>,
435 {
436 #[cfg(not(target_family = "wasm"))]
440 let form = Arc::new(Mutex::new(form));
441 let request_parts = self.build_request_parts(method, path, request_options);
442
443 Ok(HttpRequestFactory::new(move || {
444 let request_parts = request_parts.clone();
445 let form = form.clone();
446
447 async move {
448 #[cfg(not(target_family = "wasm"))]
449 let form = form
450 .lock()
451 .expect("multipart request factory mutex poisoned")
452 .clone();
453 #[cfg(target_family = "wasm")]
454 let form = form.clone();
455 let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
456 let request_builder = request_parts.build_request_builder().multipart(form);
457
458 Ok(request_builder.build()?)
459 }
460 }))
461 }
462
463 #[allow(unused)]
465 pub(crate) async fn get<O>(
466 &self,
467 path: &str,
468 request_options: &RequestOptions,
469 ) -> Result<O, OpenAIError>
470 where
471 O: DeserializeOwned,
472 {
473 let request_factory =
474 self.build_request_factory(reqwest::Method::GET, path, request_options);
475 self.execute(request_factory).await
476 }
477
478 #[allow(unused)]
480 pub(crate) async fn delete<O>(
481 &self,
482 path: &str,
483 request_options: &RequestOptions,
484 ) -> Result<O, OpenAIError>
485 where
486 O: DeserializeOwned,
487 {
488 let request_factory =
489 self.build_request_factory(reqwest::Method::DELETE, path, request_options);
490 self.execute(request_factory).await
491 }
492
493 #[allow(unused)]
495 pub(crate) async fn get_raw(
496 &self,
497 path: &str,
498 request_options: &RequestOptions,
499 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
500 let request_factory =
501 self.build_request_factory(reqwest::Method::GET, path, request_options);
502 self.execute_raw(request_factory).await
503 }
504
505 #[allow(unused)]
507 pub(crate) async fn post_raw<I>(
508 &self,
509 path: &str,
510 request: I,
511 request_options: &RequestOptions,
512 ) -> Result<(Bytes, HeaderMap), OpenAIError>
513 where
514 I: Serialize,
515 {
516 let request_factory = self.build_request_factory_with_json(
517 reqwest::Method::POST,
518 path,
519 request,
520 request_options,
521 )?;
522 self.execute_raw(request_factory).await
523 }
524
525 #[allow(unused)]
527 pub(crate) async fn post<I, O>(
528 &self,
529 path: &str,
530 request: I,
531 request_options: &RequestOptions,
532 ) -> Result<O, OpenAIError>
533 where
534 I: Serialize,
535 O: DeserializeOwned,
536 {
537 let request_factory = self.build_request_factory_with_json(
538 reqwest::Method::POST,
539 path,
540 request,
541 request_options,
542 )?;
543 self.execute(request_factory).await
544 }
545
546 #[allow(unused)]
548 pub(crate) async fn post_form_raw<F>(
549 &self,
550 path: &str,
551 form: F,
552 request_options: &RequestOptions,
553 ) -> Result<(Bytes, HeaderMap), OpenAIError>
554 where
555 F: Clone + crate::traits::MaybeSend + 'static,
556 Form: AsyncTryFrom<F, Error = OpenAIError>,
557 {
558 let request_factory = self.build_request_factory_with_form(
559 reqwest::Method::POST,
560 path,
561 form,
562 request_options,
563 )?;
564 self.execute_raw(request_factory).await
565 }
566
567 #[allow(unused)]
569 pub(crate) async fn post_form<O, F>(
570 &self,
571 path: &str,
572 form: F,
573 request_options: &RequestOptions,
574 ) -> Result<O, OpenAIError>
575 where
576 O: DeserializeOwned,
577 F: Clone + crate::traits::MaybeSend + 'static,
578 Form: AsyncTryFrom<F, Error = OpenAIError>,
579 {
580 let request_factory = self.build_request_factory_with_form(
581 reqwest::Method::POST,
582 path,
583 form,
584 request_options,
585 )?;
586 self.execute(request_factory).await
587 }
588
589 #[allow(unused)]
590 pub(crate) async fn post_form_stream<O, F>(
591 &self,
592 path: &str,
593 form: F,
594 request_options: &RequestOptions,
595 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
596 where
597 F: Clone + crate::traits::MaybeSend + 'static,
598 Form: AsyncTryFrom<F, Error = OpenAIError>,
599 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
600 {
601 let request_factory = self.build_request_factory_with_form(
602 reqwest::Method::POST,
603 path,
604 form,
605 request_options,
606 )?;
607
608 self.execute_stream(request_factory).await
609 }
610
611 async fn execute_raw(
612 &self,
613 request_factory: HttpRequestFactory,
614 ) -> Result<(Bytes, HeaderMap), OpenAIError> {
615 let response = self.execute_response(request_factory).await?;
616 read_response(response).await
617 }
618
619 async fn execute<O>(&self, request_factory: HttpRequestFactory) -> Result<O, OpenAIError>
620 where
621 O: DeserializeOwned,
622 {
623 let (bytes, _headers) = self.execute_raw(request_factory).await?;
624
625 let response: O = serde_json::from_slice(bytes.as_ref())
626 .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
627
628 Ok(response)
629 }
630
631 async fn execute_response(
632 &self,
633 request_factory: HttpRequestFactory,
634 ) -> Result<Response, OpenAIError> {
635 let response = self.executor.execute(request_factory).await?;
636 if !response.status().is_success() {
637 return Err(read_error_response(response).await);
638 }
639 Ok(response)
640 }
641
642 async fn execute_stream<O>(
643 &self,
644 request_factory: HttpRequestFactory,
645 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
646 where
647 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
648 {
649 let response = self.execute_response(request_factory).await?;
650 Ok(stream(response).await)
651 }
652
653 async fn execute_stream_mapped_raw_events<O>(
654 &self,
655 request_factory: HttpRequestFactory,
656 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
657 + crate::traits::MaybeSend
658 + 'static,
659 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
660 where
661 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
662 {
663 let response = self.execute_response(request_factory).await?;
664 Ok(stream_mapped_raw_events(response, event_mapper).await)
665 }
666
667 #[allow(unused)]
669 pub(crate) async fn post_stream<I, O>(
670 &self,
671 path: &str,
672 request: I,
673 request_options: &RequestOptions,
674 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
675 where
676 I: Serialize,
677 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
678 {
679 let request_factory = self.build_request_factory_with_json(
680 reqwest::Method::POST,
681 path,
682 request,
683 request_options,
684 )?;
685 self.execute_stream(request_factory).await
688 }
689
690 #[allow(unused)]
691 pub(crate) async fn post_stream_mapped_raw_events<I, O>(
692 &self,
693 path: &str,
694 request: I,
695 request_options: &RequestOptions,
696 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError>
697 + crate::traits::MaybeSend
698 + 'static,
699 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
700 where
701 I: Serialize,
702 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
703 {
704 let request_factory = self.build_request_factory_with_json(
705 reqwest::Method::POST,
706 path,
707 request,
708 request_options,
709 )?;
710 self.execute_stream_mapped_raw_events(request_factory, event_mapper)
711 .await
712 }
713
714 #[allow(unused)]
716 pub(crate) async fn get_stream<O>(
717 &self,
718 path: &str,
719 request_options: &RequestOptions,
720 ) -> Result<crate::types::stream::StreamResponse<O>, OpenAIError>
721 where
722 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
723 {
724 let request_factory =
725 self.build_request_factory(reqwest::Method::GET, path, request_options);
726 self.execute_stream(request_factory).await
727 }
728}
729
730async fn read_response(response: Response) -> Result<(Bytes, HeaderMap), OpenAIError> {
731 let headers = response.headers().clone();
732 let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
733 Ok((bytes, headers))
734}
735
736async fn read_error_response(response: Response) -> OpenAIError {
737 let status = response.status();
738 let bytes = match response.bytes().await {
739 Ok(b) => b,
740 Err(e) => return OpenAIError::Reqwest(e),
741 };
742
743 if status.is_server_error() {
744 let message: String = String::from_utf8_lossy(&bytes).into_owned();
746 tracing::warn!("Server error: {status} - {message}");
747 return OpenAIError::ApiError(ApiErrorResponse {
748 status_code: status,
749 api_error: ApiError {
750 message,
751 r#type: None,
752 param: None,
753 code: None,
754 },
755 });
756 }
757
758 match serde_json::from_slice::<WrappedError>(bytes.as_ref()) {
760 Ok(wrapped) => OpenAIError::ApiError(ApiErrorResponse {
761 status_code: status,
762 api_error: wrapped.error,
763 }),
764 Err(e) => map_deserialization_error(e, bytes.as_ref()),
765 }
766}
767
768pub(crate) async fn stream<O>(response: Response) -> crate::types::stream::StreamResponse<O>
771where
772 O: DeserializeOwned + crate::traits::MaybeSend + 'static,
773{
774 stream_mapped_raw_events(response, |event| {
775 serde_json::from_str::<O>(&event.data)
776 .map_err(|error| map_deserialization_error(error, event.data.as_bytes()))
777 })
778 .await
779}
780
/// Turns a response into a stream of mapped server-sent events (wasm targets).
///
/// The response body is parsed with `eventsource_stream::EventStream`, and
/// each event is passed to `event_mapper`, whose result becomes a stream item.
///
/// * An event whose data is `[DONE]` ends the stream.
/// * Events named `keepalive` are skipped.
/// * An event-stream error is yielded once as [`OpenAIError::StreamError`],
///   after which the stream ends. This matches the non-wasm implementation,
///   which also stops reading after the first such error.
#[cfg(target_family = "wasm")]
pub(crate) async fn stream_mapped_raw_events<O>(
    response: Response,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + 'static,
) -> crate::types::stream::StreamResponse<O>
where
    O: DeserializeOwned + 'static,
{
    let byte_stream = response
        .bytes_stream()
        .map(|result| result.map_err(std::io::Error::other));
    let event_stream = Box::pin(eventsource_stream::EventStream::new(byte_stream));

    // The state is `None` once an error has been yielded, so the failed
    // event stream is never polled again.
    Box::pin(futures::stream::unfold(
        Some((event_stream, event_mapper)),
        |state: Option<(_, _)>| async move {
            let (mut event_stream, event_mapper) = state?;

            loop {
                let event = match event_stream.next().await {
                    Some(Ok(event)) => event,
                    Some(Err(error)) => {
                        return Some((
                            Err(OpenAIError::StreamError(Box::new(
                                StreamError::EventStream(error.to_string()),
                            ))),
                            None,
                        ));
                    }
                    None => return None,
                };

                if event.data == "[DONE]" {
                    return None;
                }

                if event.event == "keepalive" {
                    continue;
                }

                let response = event_mapper(event);
                return Some((response, Some((event_stream, event_mapper))));
            }
        },
    ))
}
825
826#[cfg(not(target_family = "wasm"))]
827pub(crate) async fn stream_mapped_raw_events<O>(
828 response: Response,
829 event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
830) -> crate::types::stream::StreamResponse<O>
831where
832 O: DeserializeOwned + std::marker::Send + 'static,
833{
834 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
835
836 tokio::spawn(async move {
837 let byte_stream = response
838 .bytes_stream()
839 .map(|r| r.map_err(std::io::Error::other));
840 let mut event_stream = std::pin::pin!(eventsource_stream::EventStream::new(byte_stream));
841
842 while let Some(ev) = event_stream.next().await {
843 let event = match ev {
844 Ok(e) => e,
845 Err(e) => {
846 let _ = tx.send(Err(OpenAIError::StreamError(Box::new(
847 StreamError::EventStream(e.to_string()),
848 ))));
849 break;
850 }
851 };
852 if event.data == "[DONE]" {
853 break;
854 }
855
856 if event.event == "keepalive" {
857 continue;
858 }
859
860 let response = event_mapper(event);
861
862 if tx.send(response).is_err() {
863 break;
864 }
865 }
866 });
867
868 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
869}
870
871#[cfg(all(test, feature = "middleware", not(target_family = "wasm")))]
872mod tests {
873 use std::sync::{
874 atomic::{AtomicUsize, Ordering},
875 Arc,
876 };
877
878 use futures::StreamExt;
879 use http::Response as HttpResponse;
880 use serde_json::json;
881 use tower::{service_fn, ServiceBuilder};
882
883 use super::Client;
884 use crate::{
885 config::OpenAIConfig, error::OpenAIError, executor::HttpRequestFactory,
886 retry::SimpleRetryPolicy, traits::AsyncTryFrom, RequestOptions,
887 };
888
889 #[tokio::test]
890 async fn unary_requests_dispatch_through_middleware_service() {
891 let request_count = Arc::new(AtomicUsize::new(0));
892 let service = {
893 let request_count = request_count.clone();
894 ServiceBuilder::new()
895 .concurrency_limit(1)
896 .service(service_fn(move |factory: HttpRequestFactory| {
897 let request_count = request_count.clone();
898 async move {
899 let request = factory.build().await?;
900 assert_eq!(request.url().path(), "/models");
901 request_count.fetch_add(1, Ordering::SeqCst);
902 Ok::<reqwest::Response, OpenAIError>(
903 HttpResponse::builder()
904 .status(200)
905 .header("content-type", "application/json")
906 .body(reqwest::Body::from(
907 "{\"object\":\"list\",\"data\":[{\"id\":\"model\"}]}",
908 ))
909 .unwrap()
910 .into(),
911 )
912 }
913 }))
914 };
915
916 let client = Client::with_config(
917 OpenAIConfig::new()
918 .with_api_base("http://example.test")
919 .with_api_key("test-key"),
920 )
921 .with_http_service(service);
922
923 let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();
924
925 assert_eq!(value["object"], "list");
926 assert_eq!(request_count.load(Ordering::SeqCst), 1);
927 }
928
929 #[tokio::test]
930 async fn stream_requests_open_through_middleware_service() {
931 let request_count = Arc::new(AtomicUsize::new(0));
932 let service = {
933 let request_count = request_count.clone();
934 ServiceBuilder::new()
935 .concurrency_limit(1)
936 .service(service_fn(move |factory: HttpRequestFactory| {
937 let request_count = request_count.clone();
938 async move {
939 let request = factory.build().await?;
940 assert_eq!(request.url().path(), "/responses");
941 request_count.fetch_add(1, Ordering::SeqCst);
942 Ok::<reqwest::Response, OpenAIError>(
943 HttpResponse::builder()
944 .status(200)
945 .header("content-type", "text/event-stream")
946 .body(reqwest::Body::from(
947 "data: {\"ok\":true}\n\ndata: [DONE]\n\n",
948 ))
949 .unwrap()
950 .into(),
951 )
952 }
953 }))
954 };
955
956 let client = Client::with_config(
957 OpenAIConfig::new()
958 .with_api_base("http://example.test")
959 .with_api_key("test-key"),
960 )
961 .with_http_service(service);
962
963 let mut stream = client
964 .post_stream::<_, serde_json::Value>(
965 "/responses",
966 json!({ "stream": true }),
967 &RequestOptions::new(),
968 )
969 .await
970 .unwrap();
971
972 let first = stream.next().await.unwrap().unwrap();
973
974 assert_eq!(first, json!({ "ok": true }));
975 assert_eq!(request_count.load(Ordering::SeqCst), 1);
976 }
977
978 #[tokio::test]
979 async fn middleware_retry_policy_retries_429_responses() {
980 let request_count = Arc::new(AtomicUsize::new(0));
981 let service = {
982 let request_count = request_count.clone();
983 ServiceBuilder::new()
984 .retry(SimpleRetryPolicy::default())
985 .service(service_fn(move |factory: HttpRequestFactory| {
986 let request_count = request_count.clone();
987 async move {
988 let request = factory.build().await?;
989 assert_eq!(request.url().path(), "/models");
990 let attempt = request_count.fetch_add(1, Ordering::SeqCst);
991
992 let response = if attempt == 0 {
993 HttpResponse::builder()
994 .status(429)
995 .header("content-type", "application/json")
996 .body(reqwest::Body::from(
997 r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
998 ))
999 .unwrap()
1000 } else {
1001 HttpResponse::builder()
1002 .status(200)
1003 .header("content-type", "application/json")
1004 .body(reqwest::Body::from(
1005 r#"{"object":"list","data":[{"id":"retry-model"}]}"#,
1006 ))
1007 .unwrap()
1008 };
1009
1010 Ok::<reqwest::Response, OpenAIError>(response.into())
1011 }
1012 }))
1013 };
1014
1015 let client = Client::with_config(
1016 OpenAIConfig::new()
1017 .with_api_base("http://example.test")
1018 .with_api_key("test-key"),
1019 )
1020 .with_http_service(service);
1021
1022 let value: serde_json::Value = client.get("/models", &RequestOptions::new()).await.unwrap();
1023
1024 assert_eq!(value["data"][0]["id"], "retry-model");
1025 assert_eq!(request_count.load(Ordering::SeqCst), 2);
1026 }
1027
/// Test-only multipart input that counts how often it is converted into a
/// `reqwest::multipart::Form`.
#[derive(Clone)]
struct RetryableMultipartInput {
    /// Shared counter incremented by every conversion.
    conversions: Arc<AtomicUsize>,
}
1032
1033 impl AsyncTryFrom<RetryableMultipartInput> for reqwest::multipart::Form {
1034 type Error = OpenAIError;
1035
1036 async fn try_from(value: RetryableMultipartInput) -> Result<Self, Self::Error> {
1037 value.conversions.fetch_add(1, Ordering::SeqCst);
1038 Ok(reqwest::multipart::Form::new().text("field", "value"))
1039 }
1040 }
1041
1042 #[tokio::test]
1043 async fn middleware_retry_policy_rebuilds_multipart_form_per_attempt() {
1044 let request_count = Arc::new(AtomicUsize::new(0));
1045 let conversion_count = Arc::new(AtomicUsize::new(0));
1046
1047 let service = {
1048 let request_count = request_count.clone();
1049 ServiceBuilder::new()
1050 .retry(SimpleRetryPolicy::default())
1051 .service(service_fn(move |factory: HttpRequestFactory| {
1052 let request_count = request_count.clone();
1053 async move {
1054 let request = factory.build().await?;
1055 assert_eq!(request.method(), reqwest::Method::POST);
1056 assert_eq!(request.url().path(), "/files");
1057 let attempt = request_count.fetch_add(1, Ordering::SeqCst);
1058
1059 let response = if attempt == 0 {
1060 HttpResponse::builder()
1061 .status(429)
1062 .header("content-type", "application/json")
1063 .body(reqwest::Body::from(
1064 r#"{"error":{"message":"retry me","type":"rate_limit_error","param":null,"code":null}}"#,
1065 ))
1066 .unwrap()
1067 } else {
1068 HttpResponse::builder()
1069 .status(200)
1070 .header("content-type", "application/json")
1071 .body(reqwest::Body::from(r#"{"ok":true}"#))
1072 .unwrap()
1073 };
1074
1075 Ok::<reqwest::Response, OpenAIError>(response.into())
1076 }
1077 }))
1078 };
1079
1080 let client = Client::with_config(
1081 OpenAIConfig::new()
1082 .with_api_base("http://example.test")
1083 .with_api_key("test-key"),
1084 )
1085 .with_http_service(service);
1086
1087 let value: serde_json::Value = client
1088 .post_form(
1089 "/files",
1090 RetryableMultipartInput {
1091 conversions: conversion_count.clone(),
1092 },
1093 &RequestOptions::new(),
1094 )
1095 .await
1096 .unwrap();
1097
1098 assert_eq!(value, json!({ "ok": true }));
1099 assert_eq!(request_count.load(Ordering::SeqCst), 2);
1100 assert_eq!(conversion_count.load(Ordering::SeqCst), 2);
1101 }
1102}