1use std::{time::Duration, future::Future};
2use serde::de::DeserializeOwned;
3use crate::{
4 http::{AuthMethod, HTTPBody, HTTPResponse},
5 error::Error,
6 target::Target,
7};
8
9#[cfg(feature = "jsonrpc")]
10use crate::{
11 jsonrpc::{JsonRpcError, JsonRpcRequest, JsonRpcResult},
12 target::JsonRpcTarget,
13};
14#[cfg(feature = "jsonrpc")]
15use futures::future::join_all;
16
/// Request builder type produced by [`Provider`] and passed to request callbacks.
///
/// Without the `middleware` feature this is reqwest's own `RequestBuilder`.
#[cfg(not(feature = "middleware"))]
pub type ProviderRequestBuilder = reqwest::RequestBuilder;
/// Request builder type produced by [`Provider`] and passed to request callbacks.
///
/// With the `middleware` feature this is `reqwest_middleware::RequestBuilder`,
/// matching the `ClientWithMiddleware` the provider holds.
#[cfg(feature = "middleware")]
pub type ProviderRequestBuilder = reqwest_middleware::RequestBuilder;
21#[cfg(feature = "middleware")]
22use reqwest_middleware::{ClientBuilder as MiddlewareClientBuilder, ClientWithMiddleware};
23
/// Something that can perform the HTTP request described by a [`Target`].
pub trait ProviderType<T: Target>: Send {
    /// Performs the request described by `target` and resolves to the raw HTTP response.
    fn request(&self, target: T) -> impl Future<Output = Result<HTTPResponse, Error>>;
}
29
/// A [`ProviderType`] that can additionally decode responses as JSON.
pub trait JsonProviderType<T: Target>: ProviderType<T> {
    /// Performs the request for `target` and deserializes the response body into `U`.
    fn request_json<U: DeserializeOwned>(
        &self,
        target: T,
    ) -> impl Future<Output = Result<U, Error>>;
}
38
/// A [`ProviderType`] that can send JSON-RPC batch requests.
#[cfg(feature = "jsonrpc")]
pub trait JsonRpcProviderType<T: Target>: ProviderType<T> {
    /// Sends all `targets` as one JSON-RPC batch and returns the decoded results.
    fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;

    /// Sends `targets` as several JSON-RPC batches of at most `chunk_size`
    /// calls each and returns all decoded results.
    fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;
}
53
/// Overrides how the request URL is computed for a target.
///
/// When configured on a [`Provider`], its return value is used as the URL
/// instead of the target's `base_url()` followed by `path()`.
pub type EndpointFn<T> = fn(target: &T) -> String;
/// Final hook applied to every request builder, after the provider has set the
/// URL, query, headers, authentication, body and timeout.
///
/// The hook receives the builder by shared reference and returns the builder
/// that will actually be used; the tests in this file obtain an owned copy via
/// `try_clone()`.
pub type RequestBuilderFn<T> =
    Box<dyn Fn(&ProviderRequestBuilder, &T) -> ProviderRequestBuilder + Send + Sync>;
57
/// HTTP provider that turns [`Target`] values into requests and executes them.
pub struct Provider<T: Target> {
    /// Optional override for the request URL (see [`EndpointFn`]).
    endpoint_fn: Option<EndpointFn<T>>,
    /// Optional last-step hook applied to each request builder (see [`RequestBuilderFn`]).
    request_fn: Option<RequestBuilderFn<T>>,
    /// Timeout applied to every request when set.
    timeout: Option<Duration>,
    /// Underlying HTTP client (plain reqwest when the `middleware` feature is off).
    #[cfg(not(feature = "middleware"))]
    client: reqwest::Client,
    /// Underlying HTTP client wrapped by reqwest-middleware.
    #[cfg(feature = "middleware")]
    client: ClientWithMiddleware,
}
71
72impl<T: Target> std::fmt::Debug for Provider<T> {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("Provider")
75 .field("endpoint_fn", &self.endpoint_fn.map(|_| "<function>")) .field("request_fn", &self.request_fn.as_ref().map(|_| "<function>")) .field("timeout", &self.timeout)
78 .field("client", &self.client) .finish()
80 }
81}
82
83impl<T> ProviderType<T> for Provider<T>
84where
85 T: Target + Send,
86{
87 async fn request(&self, target: T) -> Result<HTTPResponse, Error> {
89 let req = self.request_builder(&target)?.build()?;
90 self.client.execute(req).await.map_err(Error::from)
91 }
92}
93
94impl<T> JsonProviderType<T> for Provider<T>
95where
96 T: Target + Send,
97{
98 async fn request_json<U: DeserializeOwned>(&self, target: T) -> Result<U, Error> {
99 let response = self.request(target).await?;
100
101 let response = response.error_for_status()?;
103
104 let body: U = response.json().await?;
106
107 Ok(body)
108 }
109}
110
#[cfg(feature = "jsonrpc")]
impl<T> JsonRpcProviderType<T> for Provider<T>
where
    T: JsonRpcTarget + Send,
{
    /// Sends every target as one JSON-RPC batch in a single HTTP request.
    ///
    /// The HTTP-level settings (URL, headers, auth, timeout, request hook) are
    /// taken from the first target; the body is then replaced with the JSON
    /// array of calls. Call ids are 1-based positions in `targets`.
    ///
    /// # Errors
    /// * `-32600` if `targets` is empty.
    /// * `-32700` if the batch cannot be serialized or the response cannot be
    ///   parsed as an array of results.
    /// * `-32603` if the HTTP request cannot be built or executed.
    /// * Errors from assembling the base request are converted into `JsonRpcError`.
    async fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        let Some(first) = targets.first() else {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        };
        let builder = self.request_builder(first)?;

        // `usize -> u64` is lossless: usize is at most 64 bits on every platform
        // Rust supports.
        let rpc_payload = targets
            .iter()
            .enumerate()
            .map(|(idx, target)| {
                JsonRpcRequest::new(target.method_name(), target.params(), (idx + 1) as u64)
            })
            .collect::<Vec<_>>();

        let body = HTTPBody::from_array(&rpc_payload).map_err(|e| JsonRpcError {
            code: -32700,
            message: format!("Failed to serialize batch request: {}", e),
        })?;

        let request = builder.body(body.inner).build().map_err(|e| JsonRpcError {
            code: -32603,
            message: format!("Failed to build batch request: {}", e),
        })?;

        let response = self
            .client
            .execute(request)
            .await
            .map_err(|e| JsonRpcError {
                code: -32603,
                message: format!("Batch request execution failed: {}", e),
            })?;

        response
            .json::<Vec<JsonRpcResult<U>>>()
            .await
            .map_err(|e| JsonRpcError {
                code: -32700,
                message: format!("Failed to parse batch JSON response: {}", e),
            })
    }

    /// Splits `targets` into chunks of at most `chunk_size` calls and sends
    /// each chunk as its own JSON-RPC batch, all chunks concurrently.
    ///
    /// Call ids are numbered globally (1-based across all chunks), matching the
    /// ids `batch` would assign. Results are concatenated in chunk order.
    ///
    /// # Errors
    /// * `-32600` if `targets` is empty or `chunk_size` is 0 (`slice::chunks`
    ///   would panic on 0).
    /// * `-32700` if a chunk cannot be serialized.
    /// * Otherwise the error of the *first* failing chunk (request assembly,
    ///   transport or response decoding), converted into `JsonRpcError`.
    async fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() || chunk_size == 0 {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        }

        // Build every chunk's request before sending anything, so assembly and
        // serialization errors surface before network traffic starts.
        let mut chunk_requests = Vec::<ProviderRequestBuilder>::new();
        for (chunk_idx, chunk) in targets.chunks(chunk_size).enumerate() {
            // `chunks` never yields an empty slice, so `chunk[0]` is in bounds.
            let base = self.request_builder(&chunk[0]);
            let first_id = chunk_idx * chunk_size;
            let calls = chunk
                .iter()
                .enumerate()
                .map(|(k, target)| {
                    JsonRpcRequest::new(
                        target.method_name(),
                        target.params(),
                        (first_id + k + 1) as u64,
                    )
                })
                .collect::<Vec<JsonRpcRequest>>();

            let http_body = HTTPBody::from_array(&calls).map_err(|e| JsonRpcError {
                code: -32700,
                message: format!("Failed to serialize batch chunk: {}", e),
            })?;
            chunk_requests.push(base?.body(http_body.inner));
        }

        let responses = join_all(chunk_requests.into_iter().map(|request| async move {
            #[cfg(feature = "middleware")]
            let response = request.send().await.map_err(crate::Error::ReqwestMiddleware)?;
            #[cfg(not(feature = "middleware"))]
            let response = request.send().await?;
            let body = response.json::<Vec<JsonRpcResult<U>>>().await?;
            Ok::<_, JsonRpcError>(body)
        }))
        .await;

        // `join_all` preserves input order, so collecting into `Result` yields
        // the first failing chunk's error, or every chunk's results in order.
        let per_chunk = responses
            .into_iter()
            .collect::<Result<Vec<_>, JsonRpcError>>()?;
        Ok(per_chunk.into_iter().flatten().collect())
    }
}
212
213impl<T> Provider<T>
214where
215 T: Target,
216{
217 pub fn new(
219 endpoint_fn: Option<EndpointFn<T>>,
220 request_fn: Option<RequestBuilderFn<T>>,
221 timeout: Option<Duration>,
222 ) -> Self {
223 #[cfg(not(feature = "middleware"))]
224 let client = reqwest::Client::new();
225 #[cfg(feature = "middleware")]
226 let client = {
227 MiddlewareClientBuilder::new(reqwest::Client::new()).build()
228 };
229 Self {
230 client,
231 endpoint_fn,
232 request_fn,
233 timeout,
234 }
235 }
236
237 #[cfg(not(feature = "middleware"))]
238 pub fn with_client(
239 client: reqwest::Client,
240 endpoint_fn: Option<EndpointFn<T>>,
241 request_fn: Option<RequestBuilderFn<T>>,
242 ) -> Self {
243 Self {
244 endpoint_fn,
245 request_fn,
246 client,
247 timeout: None,
248 }
249 }
250
251 #[cfg(feature = "middleware")]
252 pub fn with_client(
253 client: ClientWithMiddleware,
254 endpoint_fn: Option<EndpointFn<T>>,
255 request_fn: Option<RequestBuilderFn<T>>,
256 ) -> Self {
257 Self {
258 endpoint_fn,
259 request_fn,
260 client,
261 timeout: None,
262 }
263 }
264
265 pub fn request_url(&self, target: &T) -> String {
266 let mut url = format!("{}{}", target.base_url(), target.path());
267 if let Some(func) = &self.endpoint_fn {
268 url = func(target);
269 }
270 url
271 }
272
273 pub(crate) fn request_builder(&self, target: &T) -> Result<ProviderRequestBuilder, Error> {
275 let url = self.request_url(target);
276 let mut request_builder = self.client.request(target.method().into(), url.as_str());
277
278 request_builder = request_builder.query(&target.query());
280
281 for (key, value) in target.headers() {
283 request_builder = request_builder.header(key, value);
284 }
285
286 if let Some(auth) = target.authentication() {
288 request_builder = match auth {
289 AuthMethod::Bearer(token) => request_builder.bearer_auth(token),
290 AuthMethod::Basic(username, password) => request_builder.basic_auth(username, password),
291 AuthMethod::Custom(auth_fn) => auth_fn(request_builder),
292 };
293 }
294
295 let body = target.body()?;
297 request_builder = request_builder.body(body.inner);
298
299 if let Some(provider_timeout) = self.timeout {
301 request_builder = request_builder.timeout(provider_timeout);
302 }
303
304 if let Some(r_fn) = &self.request_fn {
306 request_builder = r_fn(&request_builder, target);
307 }
308
309 Ok(request_builder)
310 }
311}
312
313impl<T> Default for Provider<T>
314where
315 T: Target,
316{
317 fn default() -> Self {
318 #[cfg(not(feature = "middleware"))]
319 let client = reqwest::Client::new();
320 #[cfg(feature = "middleware")]
321 let client = {
322 MiddlewareClientBuilder::new(reqwest::Client::new()).build()
323 };
324 Self {
325 client,
326 endpoint_fn: None,
327 request_fn: None,
328 timeout: None,
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{AuthMethod, HTTPBody, HTTPMethod},
        provider::{JsonProviderType, Provider},
        target::Target,
    };
    use serde::{Deserialize, Serialize};
    use std::{
        borrow::Cow,
        collections::hash_map::DefaultHasher,
        collections::HashMap,
        hash::{Hash, Hasher},
        time::{Duration, UNIX_EPOCH},
    };

    #[derive(Serialize, Deserialize)]
    struct Person {
        name: String,
        age: u8,
        phones: Vec<String>,
    }

    /// httpbin.org endpoints used as test targets.
    enum HttpBin {
        Get,
        Post,
        Bearer,
        HeaderAuth,
    }

    impl Target for HttpBin {
        fn base_url(&self) -> Cow<'_, str> {
            Cow::Borrowed("https://httpbin.org")
        }

        fn method(&self) -> HTTPMethod {
            match self {
                HttpBin::Get => HTTPMethod::GET,
                HttpBin::Post => HTTPMethod::POST,
                HttpBin::Bearer => HTTPMethod::GET,
                HttpBin::HeaderAuth => HTTPMethod::GET,
            }
        }

        fn path(&self) -> String {
            let ts = UNIX_EPOCH + Duration::from_secs(1728044812);
            match self {
                HttpBin::Get => format!(
                    "/get?ts={}",
                    ts.duration_since(UNIX_EPOCH).unwrap().as_secs(),
                ),
                HttpBin::Post => "/post".into(),
                HttpBin::Bearer => "/bearer".into(),
                HttpBin::HeaderAuth => "/headers".into(),
            }
        }

        fn query(&self) -> HashMap<String, String> {
            HashMap::from([("foo".to_string(), "bar".to_string())])
        }

        fn headers(&self) -> HashMap<String, String> {
            HashMap::default()
        }

        fn authentication(&self) -> Option<AuthMethod> {
            match self {
                HttpBin::Bearer => Some(AuthMethod::Bearer("token".to_string())),
                HttpBin::HeaderAuth => Some(AuthMethod::header_api_key(
                    "X-Test-Api-Key".to_string(),
                    "my-secret-key".to_string(),
                )),
                _ => None,
            }
        }

        fn body(&self) -> Result<HTTPBody, crate::Error> {
            match self {
                HttpBin::Get | HttpBin::Bearer | HttpBin::HeaderAuth => Ok(HTTPBody::default()),
                HttpBin::Post => {
                    let person = Person {
                        name: "test".to_string(),
                        age: 20,
                        phones: vec!["1234567890".to_string()],
                    };
                    Ok(HTTPBody::from(&person)?)
                }
            }
        }
    }

    #[test]
    fn test_endpoint_closure() {
        let provider = Provider::<HttpBin>::default();
        assert_eq!(
            provider.request_url(&HttpBin::Get),
            "https://httpbin.org/get?ts=1728044812"
        );

        let provider =
            Provider::<HttpBin>::new(Some(|_: &HttpBin| "http://httpbin.org".to_string()), None, None);
        assert_eq!(provider.request_url(&HttpBin::Post), "http://httpbin.org");
    }

    #[test]
    fn test_request_fn() {
        let provider = Provider::<HttpBin>::new(
            None,
            Some(Box::new(|builder: &ProviderRequestBuilder, target: &HttpBin| {
                let mut hasher = DefaultHasher::new();
                target.query_string().hash(&mut hasher);
                let hash = hasher.finish();

                let mut req = builder.try_clone().expect("trying to clone request");
                req = req.header("X-test", "test");
                req = req.header("X-hash", format!("{}", hash));
                req
            })),
            None,
        );

        let request = provider.request_builder(&HttpBin::Get).unwrap().build().unwrap();
        let headers = request.headers();

        // DefaultHasher's algorithm is unspecified and may change between Rust
        // releases, so derive the expected value the same way the hook does
        // instead of hard-coding a hash literal.
        let mut hasher = DefaultHasher::new();
        HttpBin::Get.query_string().hash(&mut hasher);
        let expected_hash = hasher.finish().to_string();

        assert_eq!(request.method().to_string(), "GET");
        assert_eq!(headers.get("X-test").unwrap(), "test");
        assert_eq!(headers.get("X-hash").unwrap(), expected_hash.as_str());
    }

    // Requires network access to httpbin.org.
    #[tokio::test]
    async fn test_authentication() {
        let provider = Provider::<HttpBin>::default();
        let response: serde_json::Value = provider
            .request_json(HttpBin::Bearer)
            .await
            .expect("request error");

        assert!(response["authenticated"].as_bool().unwrap());
    }

    // Requires network access to httpbin.org.
    #[tokio::test]
    async fn test_header_api_key_auth() {
        let provider = Provider::<HttpBin>::default();
        let response: serde_json::Value = provider
            .request_json(HttpBin::HeaderAuth)
            .await
            .expect("request error");

        let headers_map = response.get("headers").unwrap().as_object().unwrap();
        assert_eq!(
            headers_map.get("X-Test-Api-Key").unwrap().as_str().unwrap(),
            "my-secret-key"
        );
    }
}