reqwest_enum/
provider.rs

1#[cfg(feature = "jsonrpc")]
2use crate::jsonrpc::{JsonRpcError, JsonRpcRequest, JsonRpcResult};
3#[cfg(feature = "jsonrpc")]
4use crate::target::JsonRpcTarget;
5#[cfg(feature = "jsonrpc")]
6use futures::future::join_all;
7
8use crate::{
9    http::{AuthMethod, HTTPBody, HTTPResponse},
10    target::Target,
11};
12use core::future::Future;
13use reqwest::{Client, Error};
14use serde::de::DeserializeOwned;
15
16pub trait ProviderType<T: Target>: Send {
17    /// request to target and return http response
18    fn request(&self, target: T) -> impl Future<Output = Result<HTTPResponse, Error>>;
19}
20
21pub trait JsonProviderType<T: Target>: ProviderType<T> {
22    /// request and deserialize response to json using serde
23    fn request_json<U: DeserializeOwned>(
24        &self,
25        target: T,
26    ) -> impl Future<Output = Result<U, Error>>;
27}
28
#[cfg(feature = "jsonrpc")]
pub trait JsonRpcProviderType<T>: ProviderType<T>
where
    T: Target,
{
    /// Sends all `targets` together as one JSON-RPC batch ("isomorphic": the calls are
    /// expected to share a single endpoint) and resolves to the decoded result array.
    fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;

    /// Like [`JsonRpcProviderType::batch`], but splits `targets` into batches of at most
    /// `chunk_size` calls each and concatenates the per-batch results.
    fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>;
}

45pub type EndpointFn<T> = fn(target: &T) -> String;
46pub type RequestBuilderFn<T> =
47    fn(request_builder: &reqwest::RequestBuilder, target: &T) -> reqwest::RequestBuilder;
48pub struct Provider<T: Target> {
49    /// endpoint closure to customize the endpoint (url / path)
50    endpoint_fn: Option<EndpointFn<T>>,
51    request_fn: Option<RequestBuilderFn<T>>,
52    client: Client,
53}
54
55impl<T> ProviderType<T> for Provider<T>
56where
57    T: Target + Send,
58{
59    async fn request(&self, target: T) -> Result<HTTPResponse, Error> {
60        let mut request = self.request_builder(&target);
61        request = request.body(target.body().inner);
62        if let Some(timeout) = target.timeout() {
63            request = request.timeout(timeout);
64        }
65        request.send().await
66    }
67}
68
69impl<T> JsonProviderType<T> for Provider<T>
70where
71    T: Target + Send,
72{
73    async fn request_json<U: DeserializeOwned>(&self, target: T) -> Result<U, Error> {
74        let response = self.request(target).await?;
75        let body = response.json::<U>().await?;
76        Ok(body)
77    }
78}
79
#[cfg(feature = "jsonrpc")]
impl<T> JsonRpcProviderType<T> for Provider<T>
where
    T: JsonRpcTarget + Send,
{
    /// Sends every target as one JSON-RPC batch array in a single HTTP request.
    ///
    /// The HTTP-level request (URL, method, query, headers, auth) is built from the
    /// first target only. JSON-RPC ids are assigned sequentially starting at 1.
    ///
    /// # Errors
    /// Returns an `Invalid Request` (-32600) error when `targets` is empty. Transport
    /// and decoding failures are converted into [`JsonRpcError`].
    async fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        }

        let response = self.jsonrpc_batch_request(&targets, 1).send().await?;
        let body = response.json::<Vec<JsonRpcResult<U>>>().await?;
        Ok(body)
    }

    /// Splits `targets` into batches of at most `chunk_size` calls and sends all batches
    /// concurrently, concatenating their results in chunk order.
    ///
    /// Each chunk's HTTP request is built from that chunk's first target. JSON-RPC ids
    /// keep counting across chunks (1, 2, ...), so they are unique over the whole call.
    ///
    /// # Errors
    /// Returns an `Invalid Request` (-32600) error when `targets` is empty or
    /// `chunk_size` is 0. If any chunk fails, returns the error of the first failing
    /// chunk in request order.
    async fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() || chunk_size == 0 {
            return Err(JsonRpcError {
                code: -32600,
                message: "Invalid Request".into(),
            });
        }

        let pending = targets
            .chunks(chunk_size)
            .enumerate()
            .map(|(chunk_idx, chunk)| {
                // usize -> u64 is lossless on every platform Rust supports.
                let first_id = (chunk_idx * chunk_size + 1) as u64;
                let request = self.jsonrpc_batch_request(chunk, first_id);
                async move {
                    let response = request.send().await?;
                    let body = response.json::<Vec<JsonRpcResult<U>>>().await?;
                    // Pin the error type so `?` above knows what to convert into.
                    Ok::<_, JsonRpcError>(body)
                }
            });

        // `join_all` polls every chunk concurrently and yields outputs in input order.
        let bodies = join_all(pending).await;

        let mut results = Vec::with_capacity(targets.len());
        for body in bodies {
            results.extend(body?);
        }
        Ok(results)
    }
}

#[cfg(feature = "jsonrpc")]
impl<T> Provider<T>
where
    T: JsonRpcTarget,
{
    /// Builds one HTTP request whose body is the JSON-RPC batch array for `chunk`.
    ///
    /// Ids are `first_id, first_id + 1, ...` in `chunk` order. The HTTP-level settings
    /// come from the first target of `chunk`.
    ///
    /// # Panics
    /// Panics if `chunk` is empty; both callers reject empty input beforehand.
    fn jsonrpc_batch_request(&self, chunk: &[T], first_id: u64) -> reqwest::RequestBuilder {
        let first = chunk
            .first()
            .expect("JSON-RPC batch chunk must contain at least one target");
        let requests: Vec<JsonRpcRequest> = chunk
            .iter()
            .enumerate()
            .map(|(offset, target)| {
                JsonRpcRequest::new(
                    target.method_name(),
                    target.params(),
                    first_id + offset as u64,
                )
            })
            .collect();
        self.request_builder(first)
            .body(HTTPBody::from_array(&requests).inner)
    }
}

168impl<T> Provider<T>
169where
170    T: Target,
171{
172    pub fn new(
173        endpoint_fn: Option<EndpointFn<T>>,
174        request_fn: Option<RequestBuilderFn<T>>,
175    ) -> Self {
176        let client = reqwest::Client::new();
177        Self {
178            client,
179            endpoint_fn,
180            request_fn,
181        }
182    }
183
184    pub(crate) fn request_url(&self, target: &T) -> String {
185        let mut url = format!("{}{}", target.base_url(), target.path());
186        if let Some(func) = &self.endpoint_fn {
187            url = func(target);
188        }
189        url
190    }
191
192    pub(crate) fn request_builder(&self, target: &T) -> reqwest::RequestBuilder {
193        let url = self.request_url(target);
194        let mut request = self.client.request(target.method().into(), url);
195        let query_map = target.query();
196        if !query_map.is_empty() {
197            request = request.query(&query_map);
198        }
199        if !target.headers().is_empty() {
200            for (k, v) in target.headers() {
201                request = request.header(k, v);
202            }
203        }
204        if let Some(auth) = target.authentication() {
205            match auth {
206                AuthMethod::Basic(username, password) => {
207                    request = request.basic_auth(username, Some(password));
208                }
209                AuthMethod::Bearer(token) => {
210                    request = request.bearer_auth(token);
211                }
212            }
213        }
214        if let Some(request_fn) = &self.request_fn {
215            request = request_fn(&mut request, target);
216        }
217        request
218    }
219}
220
221impl<T> Default for Provider<T>
222where
223    T: Target,
224{
225    fn default() -> Self {
226        Self {
227            client: reqwest::Client::new(),
228            endpoint_fn: None,
229            request_fn: None,
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use crate::{
        http::{AuthMethod, HTTPBody, HTTPMethod},
        provider::{JsonProviderType, Provider},
        target::Target,
    };
    use serde::{Deserialize, Serialize};
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashMap;
    use std::hash::{Hash, Hasher};
    use std::time::{Duration, UNIX_EPOCH};
    use tokio_test::block_on;

    #[derive(Serialize, Deserialize)]
    struct Person {
        name: String,
        age: u8,
        phones: Vec<String>,
    }

    enum HttpBin {
        Get,
        Post,
        Bearer,
    }

    impl Target for HttpBin {
        fn base_url(&self) -> &'static str {
            "https://httpbin.org"
        }

        fn method(&self) -> HTTPMethod {
            match self {
                HttpBin::Get => HTTPMethod::GET,
                HttpBin::Post => HTTPMethod::POST,
                HttpBin::Bearer => HTTPMethod::GET,
            }
        }

        fn path(&self) -> String {
            let ts = UNIX_EPOCH + Duration::from_secs(1728044812);
            match self {
                HttpBin::Get => format!(
                    "/get?ts={}",
                    ts.duration_since(UNIX_EPOCH).unwrap().as_secs(),
                ),
                HttpBin::Post => "/post".into(),
                HttpBin::Bearer => "/bearer".into(),
            }
        }

        fn query(&self) -> HashMap<&'static str, &'static str> {
            HashMap::from([("foo", "bar")])
        }

        fn headers(&self) -> HashMap<&'static str, &'static str> {
            HashMap::default()
        }

        fn authentication(&self) -> Option<AuthMethod> {
            match self {
                HttpBin::Bearer => Some(AuthMethod::Bearer("token")),
                _ => None,
            }
        }

        fn body(&self) -> HTTPBody {
            match self {
                HttpBin::Get | HttpBin::Bearer => HTTPBody::default(),
                HttpBin::Post => HTTPBody::from(&Person {
                    name: "test".to_string(),
                    age: 20,
                    phones: vec!["1234567890".to_string()],
                }),
            }
        }

        fn timeout(&self) -> Option<Duration> {
            None
        }
    }

    /// Hashes a target's query string with `DefaultHasher`.
    ///
    /// Shared by the hook under test and the assertion. `DefaultHasher` output is not
    /// guaranteed stable across Rust releases, so the expected value must be computed,
    /// not hard-coded.
    fn query_hash(target: &HttpBin) -> u64 {
        let mut hasher = DefaultHasher::new();
        target.query_string().hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_endpoint_closure() {
        let provider = Provider::<HttpBin>::default();
        assert_eq!(
            provider.request_url(&HttpBin::Get),
            "https://httpbin.org/get?ts=1728044812"
        );

        let provider =
            Provider::<HttpBin>::new(Some(|_: &HttpBin| "http://httpbin.org".to_string()), None);
        assert_eq!(provider.request_url(&HttpBin::Post), "http://httpbin.org");
    }

    #[test]
    fn test_request_fn() {
        let provider = Provider::<HttpBin>::new(
            None,
            Some(|builder: &reqwest::RequestBuilder, target: &HttpBin| {
                let hash = query_hash(target);

                let mut req = builder.try_clone().expect("trying to clone request");
                req = req.header("X-test", "test");
                req = req.header("X-hash", format!("{}", hash));
                req
            }),
        );

        let request = provider.request_builder(&HttpBin::Get).build().unwrap();
        let headers = request.headers();
        let expected_hash = query_hash(&HttpBin::Get).to_string();

        assert_eq!(request.method().to_string(), "GET");
        assert_eq!(headers.get("X-test").unwrap(), "test");
        assert_eq!(headers.get("X-hash").unwrap(), expected_hash.as_str());
    }

    /// Performs a real HTTP request to httpbin.org, so it requires network access.
    #[test]
    fn test_authentication() {
        let provider = Provider::<HttpBin>::default();
        block_on(async {
            let response: serde_json::Value = provider
                .request_json(HttpBin::Bearer)
                .await
                .expect("request error");

            assert!(response["authenticated"].as_bool().unwrap());
        });
    }
}