1use std::{future::Future, pin::Pin, sync::Arc};
2
3use reqwest::{Request, Response};
4
5use crate::{error::OpenAIError, retry::OpenAIRetry};
6
7#[cfg(not(target_family = "wasm"))]
8type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + Send + 'static>>;
9#[cfg(target_family = "wasm")]
10type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + 'static>>;
11
12#[cfg(not(target_family = "wasm"))]
13pub(crate) type HttpFuture =
14 Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16pub(crate) type HttpFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
18#[cfg(not(target_family = "wasm"))]
19type RequestFn = dyn Fn() -> RequestFuture + Send + Sync + 'static;
20#[cfg(target_family = "wasm")]
21type RequestFn = dyn Fn() -> RequestFuture + 'static;
22
23#[derive(Clone)]
36pub struct HttpRequestFactory {
37 make_request: Arc<RequestFn>,
38}
39
40impl std::fmt::Debug for HttpRequestFactory {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("HttpRequestFactory").finish_non_exhaustive()
43 }
44}
45
46impl HttpRequestFactory {
47 #[cfg(not(target_family = "wasm"))]
52 pub fn new<F, Fut>(make_request: F) -> Self
53 where
54 F: Fn() -> Fut + Send + Sync + 'static,
55 Fut: Future<Output = Result<Request, OpenAIError>> + Send + 'static,
56 {
57 Self {
58 make_request: Arc::new(move || Box::pin(make_request())),
59 }
60 }
61
62 #[cfg(target_family = "wasm")]
63 pub fn new<F, Fut>(make_request: F) -> Self
64 where
65 F: Fn() -> Fut + 'static,
66 Fut: Future<Output = Result<Request, OpenAIError>> + 'static,
67 {
68 Self {
69 make_request: Arc::new(move || Box::pin(make_request())),
70 }
71 }
72
73 pub async fn build(&self) -> Result<Request, OpenAIError> {
79 (self.make_request)().await
80 }
81}
82
/// Executes an HTTP request described by an [`HttpRequestFactory`].
///
/// Native variant: implementors must be `Send + Sync`, and the returned
/// [`HttpFuture`] is `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait HttpExecutor: Send + Sync {
    /// Returns a boxed future that resolves to the response for `request`.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
}

/// Executes an HTTP request described by an [`HttpRequestFactory`].
///
/// wasm variant: no `Send`/`Sync` requirements on implementors or futures.
#[cfg(target_family = "wasm")]
pub trait HttpExecutor {
    /// Returns a boxed future that resolves to the response for `request`.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
}
98
99#[derive(Clone, Debug, Default)]
105pub struct ReqwestService {
106 client: reqwest::Client,
107}
108
109impl ReqwestService {
110 pub fn new(client: reqwest::Client) -> Self {
111 Self { client }
112 }
113}
114
115impl tower::Service<HttpRequestFactory> for ReqwestService {
116 type Response = Response;
117 type Error = OpenAIError;
118 type Future = HttpFuture;
119
120 fn poll_ready(
121 &mut self,
122 _cx: &mut std::task::Context<'_>,
123 ) -> std::task::Poll<Result<(), Self::Error>> {
124 std::task::Poll::Ready(Ok(()))
125 }
126
127 fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
128 let client = self.client.clone();
129 Box::pin(async move {
130 let request = request.build().await?;
133 client.execute(request).await.map_err(OpenAIError::Reqwest)
134 })
135 }
136}
137
138#[derive(Clone, Debug)]
139pub(crate) struct ReqwestExecutor {
140 service: OpenAIRetry<ReqwestService>,
141}
142
143impl ReqwestExecutor {
144 pub(crate) fn new(client: reqwest::Client) -> Self {
145 Self {
146 service: tower::ServiceBuilder::new()
147 .layer(crate::retry::OpenAIRetryLayer::default())
148 .service(ReqwestService::new(client)),
149 }
150 }
151}
152
153impl HttpExecutor for ReqwestExecutor {
154 fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
155 use tower::ServiceExt;
156
157 let service = self.service.clone();
158 Box::pin(async move { service.oneshot(request).await })
159 }
160}
161
/// [`HttpExecutor`] that forwards requests to a caller-supplied [`tower::Service`].
#[cfg(feature = "middleware")]
pub(crate) struct TowerExecutor<S> {
    service: S,
}

#[cfg(feature = "middleware")]
impl<S> TowerExecutor<S> {
    /// Wraps `service`; `execute` clones it once per request.
    pub(crate) fn new(service: S) -> Self {
        TowerExecutor { service }
    }
}

#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    /// Drives one request through a clone of the wrapped service and converts the
    /// service's error type into [`OpenAIError`].
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;

        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = svc.oneshot(request).await;
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}

#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    /// Drives one request through a clone of the wrapped service and converts the
    /// service's error type into [`OpenAIError`] (wasm: no `Send` requirements).
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;

        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = svc.oneshot(request).await;
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}
211
212pub(crate) type SharedExecutor = Arc<dyn HttpExecutor>;