1use std::{future::Future, pin::Pin, sync::Arc};
2
3use reqwest::{Request, Response};
4
5use crate::{error::OpenAIError, retry::OpenAIRetry};
6
7#[cfg(not(target_family = "wasm"))]
8type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + Send + 'static>>;
9#[cfg(target_family = "wasm")]
10type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + 'static>>;
11
12#[cfg(not(target_family = "wasm"))]
13pub(crate) type HttpFuture =
14 Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16pub(crate) type HttpFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
18#[cfg(not(target_family = "wasm"))]
19type RequestFn = dyn Fn() -> RequestFuture + Send + Sync + 'static;
20#[cfg(target_family = "wasm")]
21type RequestFn = dyn Fn() -> RequestFuture + 'static;
22
/// Shorthand bound that bundles `Send + Sync + 'static` into one trait.
///
/// It is implemented automatically for every type meeting those bounds, so it
/// never needs to be implemented by hand.
// NOTE(review): the call sites that use this bound are outside this module excerpt.
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
pub trait MiddlewareInput: Send + Sync + 'static {}

#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<T: Send + Sync + 'static> MiddlewareInput for T {}

/// Shorthand bound for wasm targets, where only `'static` is required.
///
/// It is implemented automatically for every `'static` type.
#[cfg(all(feature = "middleware", target_family = "wasm"))]
pub trait MiddlewareInput: 'static {}

#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<T: 'static> MiddlewareInput for T {}
33
34#[derive(Clone)]
47pub struct HttpRequestFactory {
48 make_request: Arc<RequestFn>,
49}
50
51impl std::fmt::Debug for HttpRequestFactory {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("HttpRequestFactory").finish_non_exhaustive()
54 }
55}
56
57impl HttpRequestFactory {
58 #[cfg(not(target_family = "wasm"))]
63 pub fn new<F, Fut>(make_request: F) -> Self
64 where
65 F: Fn() -> Fut + Send + Sync + 'static,
66 Fut: Future<Output = Result<Request, OpenAIError>> + Send + 'static,
67 {
68 Self {
69 make_request: Arc::new(move || Box::pin(make_request())),
70 }
71 }
72
73 #[cfg(target_family = "wasm")]
74 pub fn new<F, Fut>(make_request: F) -> Self
75 where
76 F: Fn() -> Fut + 'static,
77 Fut: Future<Output = Result<Request, OpenAIError>> + 'static,
78 {
79 Self {
80 make_request: Arc::new(move || Box::pin(make_request())),
81 }
82 }
83
84 pub async fn build(&self) -> Result<Request, OpenAIError> {
90 (self.make_request)().await
91 }
92}
93
94#[cfg(not(target_family = "wasm"))]
101pub trait HttpExecutor: Send + Sync {
102 fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
103}
104
105#[cfg(target_family = "wasm")]
106pub trait HttpExecutor {
107 fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
108}
109
110#[derive(Clone, Debug, Default)]
116pub struct ReqwestService {
117 client: reqwest::Client,
118}
119
120impl ReqwestService {
121 pub fn new(client: reqwest::Client) -> Self {
122 Self { client }
123 }
124}
125
126impl tower::Service<HttpRequestFactory> for ReqwestService {
127 type Response = Response;
128 type Error = OpenAIError;
129 type Future = HttpFuture;
130
131 fn poll_ready(
132 &mut self,
133 _cx: &mut std::task::Context<'_>,
134 ) -> std::task::Poll<Result<(), Self::Error>> {
135 std::task::Poll::Ready(Ok(()))
136 }
137
138 fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
139 let client = self.client.clone();
140 Box::pin(async move {
141 let request = request.build().await?;
144 client.execute(request).await.map_err(OpenAIError::Reqwest)
145 })
146 }
147}
148
149#[derive(Clone, Debug)]
150pub(crate) struct ReqwestExecutor {
151 service: OpenAIRetry<ReqwestService>,
152}
153
154impl ReqwestExecutor {
155 pub(crate) fn new(client: reqwest::Client) -> Self {
156 Self {
157 service: tower::ServiceBuilder::new()
158 .layer(crate::retry::OpenAIRetryLayer::default())
159 .service(ReqwestService::new(client)),
160 }
161 }
162}
163
164impl HttpExecutor for ReqwestExecutor {
165 fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
166 use tower::ServiceExt;
167
168 let service = self.service.clone();
169 Box::pin(async move { service.oneshot(request).await })
170 }
171}
172
/// Adapter that lets an arbitrary [`tower::Service`] act as an [`HttpExecutor`].
#[cfg(feature = "middleware")]
pub(crate) struct TowerExecutor<S> {
    service: S,
}

#[cfg(feature = "middleware")]
impl<S> TowerExecutor<S> {
    /// Wraps `service`. Service bounds are only enforced by the
    /// `HttpExecutor` impls.
    pub(crate) fn new(service: S) -> Self {
        TowerExecutor { service }
    }
}

/// Non-wasm adapter: the service, its future and its error must all be
/// thread-safe so the boxed future is `Send`.
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        // `oneshot` consumes the service, so each execution drives its own clone.
        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = tower::ServiceExt::oneshot(svc, request).await;
            // Normalise the service's error type into the crate error.
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}

/// Wasm adapter: identical behaviour without the thread-safety bounds.
#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        // `oneshot` consumes the service, so each execution drives its own clone.
        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = tower::ServiceExt::oneshot(svc, request).await;
            // Normalise the service's error type into the crate error.
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}
222
223pub(crate) type SharedExecutor = Arc<dyn HttpExecutor>;