reqwest_enum/
provider.rs

1use std::{time::Duration, future::Future};
2use serde::de::DeserializeOwned;
3use crate::{
4    http::{AuthMethod, HTTPBody, HTTPResponse},
5    error::Error,
6    target::Target,
7};
8
9#[cfg(feature = "jsonrpc")]
10use crate::{
11    jsonrpc::{JsonRpcError, JsonRpcRequest, JsonRpcResult},
12    target::JsonRpcTarget,
13};
14#[cfg(feature = "jsonrpc")]
15use futures::future::join_all;
16
17#[cfg(not(feature = "middleware"))]
18pub type ProviderRequestBuilder = reqwest::RequestBuilder;
19#[cfg(feature = "middleware")]
20pub type ProviderRequestBuilder = reqwest_middleware::RequestBuilder;
21#[cfg(feature = "middleware")]
22use reqwest_middleware::{ClientBuilder as MiddlewareClientBuilder, ClientWithMiddleware};
23
24// Base trait for providers, defining the core request method.
25pub trait ProviderType<T: Target>: Send {
26    /// request to target and return http response
27    fn request(&self, target: T) -> impl Future<Output = Result<HTTPResponse, Error>>;
28}
29
30// Trait for providers that can handle JSON responses, deserializing them.
31pub trait JsonProviderType<T: Target>: ProviderType<T> {
32    /// request and deserialize response to json using serde
33    fn request_json<U: DeserializeOwned>(
34        &self,
35        target: T,
36    ) -> impl Future<Output = Result<U, Error>>;
37}
38
#[cfg(feature = "jsonrpc")]
/// Providers that can send JSON-RPC batch requests.
pub trait JsonRpcProviderType<T: Target>: ProviderType<T> {
    /// Sends all `targets` (all of the same target type `T`) as one JSON-RPC batch
    /// and returns the decoded per-request results.
    fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;

    /// Like [`batch`](Self::batch), but splits `targets` into batches of at most
    /// `chunk_size` requests each.
    fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;
}
53
54pub type EndpointFn<T> = fn(target: &T) -> String;
55pub type RequestBuilderFn<T> =
56    Box<dyn Fn(&ProviderRequestBuilder, &T) -> ProviderRequestBuilder + Send + Sync>;
57
58/// Generic provider for HTTP requests to a `Target`. Handles construction, auth, and execution.
59pub struct Provider<T: Target> {
60    /// endpoint closure to customize the endpoint (url / path)
61    endpoint_fn: Option<EndpointFn<T>>,
62    request_fn: Option<RequestBuilderFn<T>>,
63    /// An optional default timeout for all requests made by this provider.
64    /// If set, this timeout is applied to each request unless overridden by more specific timeout logic.
65    timeout: Option<Duration>,
66    #[cfg(not(feature = "middleware"))]
67    client: reqwest::Client,
68    #[cfg(feature = "middleware")]
69    client: ClientWithMiddleware,
70}
71
72impl<T: Target> std::fmt::Debug for Provider<T> {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("Provider")
75            .field("endpoint_fn", &self.endpoint_fn.map(|_| "<function>")) // Print placeholder for fn pointer
76            .field("request_fn", &self.request_fn.as_ref().map(|_| "<function>")) // Print placeholder for Box<dyn Fn>
77            .field("timeout", &self.timeout)
78            .field("client", &self.client) // reqwest::Client and reqwest_middleware::ClientWithMiddleware implement Debug
79            .finish()
80    }
81}
82
83impl<T> ProviderType<T> for Provider<T>
84where
85    T: Target + Send,
86{
87    /// Builds and executes a request to `Target`, returning raw `HTTPResponse`.
88    async fn request(&self, target: T) -> Result<HTTPResponse, Error> {
89        let req = self.request_builder(&target)?.build()?;
90        self.client.execute(req).await.map_err(Error::from)
91    }
92}
93
94impl<T> JsonProviderType<T> for Provider<T>
95where
96    T: Target + Send,
97{
98    async fn request_json<U: DeserializeOwned>(&self, target: T) -> Result<U, Error> {
99        let response = self.request(target).await?;
100
101        // Check status and get Response or reqwest::Error
102        let response = response.error_for_status()?;
103
104        // If error_for_status succeeded, deserialize the JSON.
105        let body: U = response.json().await?;
106
107        Ok(body)
108    }
109}
110
#[cfg(feature = "jsonrpc")]
impl<T> JsonRpcProviderType<T> for Provider<T>
where
    T: JsonRpcTarget + Send,
{
    /// Sends all `targets` as a single JSON-RPC batch request.
    ///
    /// The HTTP request (URL, headers, auth, timeout, `request_fn`) is built from the
    /// first target; its body is then replaced by a JSON array holding one
    /// `JsonRpcRequest` per target, with ids `1..=targets.len()` in input order.
    ///
    /// NOTE(review): `request_fn` runs inside `request_builder`, i.e. *before* the batch
    /// body replaces the first target's body, so a hook that inspects or signs the body
    /// sees the single-target body, not the batch.
    ///
    /// NOTE(review): results are returned in the order the server sent them; they are
    /// not re-sorted by id (JSON-RPC 2.0 allows batch responses in any order).
    ///
    /// # Errors
    /// - code `-32600` ("Invalid Request") if `targets` is empty;
    /// - code `-32700` if the batch cannot be serialized or the response cannot be
    ///   parsed as an array of results;
    /// - code `-32603` if the final request cannot be built or executed;
    /// - errors from `request_builder` are converted into `JsonRpcError` via `From`.
    async fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        let Some(representative_target) = targets.first() else {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        };

        let builder = self.request_builder(representative_target)?;

        let rpc_payload: Vec<JsonRpcRequest> = targets
            .iter()
            .enumerate()
            .map(|(k, individual_target)| {
                JsonRpcRequest::new(
                    individual_target.method_name(),
                    individual_target.params(),
                    (k + 1) as u64,
                )
            })
            .collect();
        let body = HTTPBody::from_array(&rpc_payload).map_err(|e| JsonRpcError {
            code: -32700,
            message: format!("Failed to serialize batch request: {}", e),
        })?;

        // Replace the representative target's body with the batch payload.
        let final_request = builder.body(body.inner).build().map_err(|e| JsonRpcError {
            code: -32603,
            message: format!("Failed to build batch request: {}", e),
        })?;

        let response = self.client.execute(final_request).await.map_err(|e| JsonRpcError {
            code: -32603,
            message: format!("Batch request execution failed: {}", e),
        })?;

        response
            .json::<Vec<JsonRpcResult<U>>>()
            .await
            .map_err(|e| JsonRpcError {
                code: -32700,
                message: format!("Failed to parse batch JSON response: {}", e),
            })
    }

    /// Splits `targets` into chunks of at most `chunk_size`, sends each chunk as its own
    /// JSON-RPC batch concurrently, and concatenates the results in chunk order.
    ///
    /// Each chunk's HTTP request is built from that chunk's first target. Request ids
    /// are 1-based and global across chunks, so they are unique within the whole call.
    ///
    /// # Errors
    /// - code `-32600` ("Invalid Request") if `targets` is empty or `chunk_size` is 0;
    /// - code `-32700` if a chunk cannot be serialized;
    /// - errors from `request_builder`, sending, or response decoding are converted into
    ///   `JsonRpcError` via `From`. If several chunks fail, the error of the first
    ///   failing chunk (in chunk order) is returned.
    async fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() || chunk_size == 0 {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        }

        let mut chunk_requests = Vec::<ProviderRequestBuilder>::new();
        for (chunk_idx, chunk) in targets.chunks(chunk_size).enumerate() {
            // `chunks` never yields an empty slice, so `chunk[0]` always exists.
            let builder = self.request_builder(&chunk[0])?;

            let id_offset = chunk_idx * chunk_size;
            let rpc_payload: Vec<JsonRpcRequest> = chunk
                .iter()
                .enumerate()
                .map(|(k, v)| {
                    JsonRpcRequest::new(v.method_name(), v.params(), (id_offset + k + 1) as u64)
                })
                .collect();
            let http_body = HTTPBody::from_array(&rpc_payload).map_err(|e| JsonRpcError {
                code: -32700,
                message: format!("Failed to serialize batch chunk: {}", e),
            })?;

            chunk_requests.push(builder.body(http_body.inner));
        }

        // Send every chunk concurrently; `join_all` yields outputs in input order.
        let responses = join_all(chunk_requests.into_iter().map(|request| async move {
            #[cfg(feature = "middleware")]
            let response = request.send().await.map_err(crate::Error::ReqwestMiddleware)?;
            #[cfg(not(feature = "middleware"))]
            let response = request.send().await?;
            let body = response.json::<Vec<JsonRpcResult<U>>>().await?;
            Ok::<_, JsonRpcError>(body)
        }))
        .await;

        let per_chunk = responses
            .into_iter()
            .collect::<Result<Vec<Vec<JsonRpcResult<U>>>, JsonRpcError>>()?;
        Ok(per_chunk.into_iter().flatten().collect())
    }
}
212
213impl<T> Provider<T>
214where
215    T: Target,
216{
217    /// Creates a new `Provider` with optional URL, request builder customization, and timeout.
218    pub fn new(
219        endpoint_fn: Option<EndpointFn<T>>,
220        request_fn: Option<RequestBuilderFn<T>>,
221        timeout: Option<Duration>,
222    ) -> Self {
223        #[cfg(not(feature = "middleware"))]
224        let client = reqwest::Client::new();
225        #[cfg(feature = "middleware")]
226        let client = {
227            MiddlewareClientBuilder::new(reqwest::Client::new()).build()
228        };
229        Self {
230            client,
231            endpoint_fn,
232            request_fn,
233            timeout,
234        }
235    }
236
237    #[cfg(not(feature = "middleware"))]
238    pub fn with_client(
239        client: reqwest::Client,
240        endpoint_fn: Option<EndpointFn<T>>,
241        request_fn: Option<RequestBuilderFn<T>>,
242    ) -> Self {
243        Self {
244            endpoint_fn,
245            request_fn,
246            client,
247            timeout: None,
248        }
249    }
250
251    #[cfg(feature = "middleware")]
252    pub fn with_client(
253        client: ClientWithMiddleware,
254        endpoint_fn: Option<EndpointFn<T>>,
255        request_fn: Option<RequestBuilderFn<T>>,
256    ) -> Self {
257        Self {
258            endpoint_fn,
259            request_fn,
260            client,
261            timeout: None,
262        }
263    }
264
265    pub fn request_url(&self, target: &T) -> String {
266        let mut url = format!("{}{}", target.base_url(), target.path());
267        if let Some(func) = &self.endpoint_fn {
268            url = func(target);
269        }
270        url
271    }
272
273    /// Constructs a `reqwest::RequestBuilder` for the `Target`, applying URL, method, query, headers, auth, body, timeout, and custom `request_fn`.
274    pub(crate) fn request_builder(&self, target: &T) -> Result<ProviderRequestBuilder, Error> {
275        let url = self.request_url(target);
276        let mut request_builder = self.client.request(target.method().into(), url.as_str());
277
278        // apply query params
279        request_builder = request_builder.query(&target.query());
280
281        // apply headers
282        for (key, value) in target.headers() {
283            request_builder = request_builder.header(key, value);
284        }
285
286        // apply authentication
287        if let Some(auth) = target.authentication() {
288            request_builder = match auth {
289                AuthMethod::Bearer(token) => request_builder.bearer_auth(token),
290                AuthMethod::Basic(username, password) => request_builder.basic_auth(username, password),
291                AuthMethod::Custom(auth_fn) => auth_fn(request_builder),
292            };
293        }
294
295        // apply body
296        let body = target.body()?;
297        request_builder = request_builder.body(body.inner);
298
299        // apply provider timeout
300        if let Some(provider_timeout) = self.timeout {
301            request_builder = request_builder.timeout(provider_timeout);
302        }
303
304        // apply request_fn closure
305        if let Some(r_fn) = &self.request_fn {
306            request_builder = r_fn(&request_builder, target);
307        }
308
309        Ok(request_builder)
310    }
311}
312
313impl<T> Default for Provider<T>
314where
315    T: Target,
316{
317    fn default() -> Self {
318        #[cfg(not(feature = "middleware"))]
319        let client = reqwest::Client::new();
320        #[cfg(feature = "middleware")]
321        let client = {
322            MiddlewareClientBuilder::new(reqwest::Client::new()).build()
323        };
324        Self {
325            client,
326            endpoint_fn: None,
327            request_fn: None,
328            timeout: None,
329        }
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{AuthMethod, HTTPBody, HTTPMethod},
        provider::{JsonProviderType, Provider},
        target::Target,
    };
    use serde::{Deserialize, Serialize};
    use std::{
        borrow::Cow,
        collections::{hash_map::DefaultHasher, HashMap},
        hash::{Hash, Hasher},
    };

    /// Fixed timestamp embedded in the `HttpBin::Get` path so its URL is deterministic.
    const GET_TS: u64 = 1728044812;

    #[derive(Serialize, Deserialize)]
    struct Person {
        name: String,
        age: u8,
        phones: Vec<String>,
    }

    /// httpbin.org endpoints exercised by the tests.
    enum HttpBin {
        Get,
        Post,
        Bearer,
        HeaderAuth,
    }

    impl Target for HttpBin {
        fn base_url(&self) -> Cow<'_, str> {
            Cow::Borrowed("https://httpbin.org")
        }

        fn method(&self) -> HTTPMethod {
            match self {
                HttpBin::Get => HTTPMethod::GET,
                HttpBin::Post => HTTPMethod::POST,
                HttpBin::Bearer => HTTPMethod::GET,
                HttpBin::HeaderAuth => HTTPMethod::GET,
            }
        }

        fn path(&self) -> String {
            match self {
                HttpBin::Get => format!("/get?ts={}", GET_TS),
                HttpBin::Post => "/post".into(),
                HttpBin::Bearer => "/bearer".into(),
                HttpBin::HeaderAuth => "/headers".into(),
            }
        }

        fn query(&self) -> HashMap<String, String> {
            HashMap::from([("foo".to_string(), "bar".to_string())])
        }

        fn headers(&self) -> HashMap<String, String> {
            HashMap::default()
        }

        fn authentication(&self) -> Option<AuthMethod> {
            match self {
                HttpBin::Bearer => Some(AuthMethod::Bearer("token".to_string())),
                HttpBin::HeaderAuth => Some(AuthMethod::header_api_key(
                    "X-Test-Api-Key".to_string(),
                    "my-secret-key".to_string(),
                )),
                _ => None,
            }
        }

        fn body(&self) -> Result<HTTPBody, crate::Error> {
            match self {
                HttpBin::Get | HttpBin::Bearer | HttpBin::HeaderAuth => Ok(HTTPBody::default()),
                HttpBin::Post => {
                    let person = Person {
                        name: "test".to_string(),
                        age: 20,
                        phones: vec!["1234567890".to_string()],
                    };
                    Ok(HTTPBody::from(&person)?)
                }
            }
        }
    }

    /// Hashes `target.query_string()` with std's `DefaultHasher`.
    ///
    /// Used both by the `request_fn` under test and by its assertion, because
    /// `DefaultHasher`'s output is not guaranteed to be stable across Rust releases
    /// and must not be hard-coded.
    fn query_hash(target: &HttpBin) -> u64 {
        let mut hasher = DefaultHasher::new();
        target.query_string().hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_endpoint_closure() {
        let provider = Provider::<HttpBin>::default();
        assert_eq!(
            provider.request_url(&HttpBin::Get),
            "https://httpbin.org/get?ts=1728044812"
        );

        let provider =
            Provider::<HttpBin>::new(Some(|_: &HttpBin| "http://httpbin.org".to_string()), None, None);
        assert_eq!(provider.request_url(&HttpBin::Post), "http://httpbin.org");
    }

    #[test]
    fn test_request_fn() {
        let provider = Provider::<HttpBin>::new(
            None,
            Some(Box::new(|builder: &ProviderRequestBuilder, target: &HttpBin| {
                builder
                    .try_clone()
                    .expect("trying to clone request")
                    .header("X-test", "test")
                    .header("X-hash", query_hash(target).to_string())
            })),
            None,
        );

        let request = provider.request_builder(&HttpBin::Get).unwrap().build().unwrap();
        let headers = request.headers();
        let expected_hash = query_hash(&HttpBin::Get).to_string();

        assert_eq!(request.method().to_string(), "GET");
        assert_eq!(headers.get("X-test").unwrap(), "test");
        assert_eq!(headers.get("X-hash").unwrap(), expected_hash.as_str());
    }

    // Requires network access to https://httpbin.org.
    #[tokio::test]
    async fn test_authentication() {
        let provider = Provider::<HttpBin>::default();
        let response: serde_json::Value = provider
            .request_json(HttpBin::Bearer)
            .await
            .expect("request error");

        assert!(response["authenticated"].as_bool().unwrap());
    }

    // Requires network access to https://httpbin.org.
    #[tokio::test]
    async fn test_header_api_key_auth() {
        let provider = Provider::<HttpBin>::default();
        let response: serde_json::Value = provider
            .request_json(HttpBin::HeaderAuth)
            .await
            .expect("request error");

        // httpbin /headers returns a JSON object like: {"headers": {"Header-Name": "Header-Value", ...}}
        let headers_map = response.get("headers").unwrap().as_object().unwrap();
        assert_eq!(
            headers_map.get("X-Test-Api-Key").unwrap().as_str().unwrap(),
            "my-secret-key"
        );
    }
}