1#[cfg(feature = "jsonrpc")]
2use crate::jsonrpc::{JsonRpcError, JsonRpcRequest, JsonRpcResult};
3#[cfg(feature = "jsonrpc")]
4use crate::target::JsonRpcTarget;
5#[cfg(feature = "jsonrpc")]
6use futures::future::join_all;
7
8use crate::{
9 http::{AuthMethod, HTTPBody, HTTPResponse},
10 target::Target,
11};
12use core::future::Future;
13use reqwest::{Client, Error};
14use serde::de::DeserializeOwned;
15
16pub trait ProviderType<T: Target>: Send {
17 fn request(&self, target: T) -> impl Future<Output = Result<HTTPResponse, Error>>;
19}
20
21pub trait JsonProviderType<T: Target>: ProviderType<T> {
22 fn request_json<U: DeserializeOwned>(
24 &self,
25 target: T,
26 ) -> impl Future<Output = Result<U, Error>>;
27}
28
/// A [`ProviderType`] that can send JSON-RPC batch calls.
#[cfg(feature = "jsonrpc")]
pub trait JsonRpcProviderType<T>: ProviderType<T>
where
    T: Target,
{
    /// Sends every target in `targets` as a single JSON-RPC batch request and
    /// decodes the array of results.
    fn batch<U>(
        &self,
        targets: Vec<T>,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>
    where
        U: DeserializeOwned;

    /// Splits `targets` into batches of at most `chunk_size` calls, sends
    /// them, and combines the decoded results into one vector.
    fn batch_chunk_by<U>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> impl Future<Output = Result<Vec<JsonRpcResult<U>>, JsonRpcError>>
    where
        U: DeserializeOwned;
}
44
/// Signature of an optional hook that computes the full request URL for a
/// target, used by [`Provider`] instead of the default `base_url()` + `path()`
/// concatenation.
pub type EndpointFn<T> = fn(&T) -> String;
46pub type RequestBuilderFn<T> =
47 fn(request_builder: &reqwest::RequestBuilder, target: &T) -> reqwest::RequestBuilder;
48pub struct Provider<T: Target> {
49 endpoint_fn: Option<EndpointFn<T>>,
51 request_fn: Option<RequestBuilderFn<T>>,
52 client: Client,
53}
54
55impl<T> ProviderType<T> for Provider<T>
56where
57 T: Target + Send,
58{
59 async fn request(&self, target: T) -> Result<HTTPResponse, Error> {
60 let mut request = self.request_builder(&target);
61 request = request.body(target.body().inner);
62 if let Some(timeout) = target.timeout() {
63 request = request.timeout(timeout);
64 }
65 request.send().await
66 }
67}
68
69impl<T> JsonProviderType<T> for Provider<T>
70where
71 T: Target + Send,
72{
73 async fn request_json<U: DeserializeOwned>(&self, target: T) -> Result<U, Error> {
74 let response = self.request(target).await?;
75 let body = response.json::<U>().await?;
76 Ok(body)
77 }
78}
79
#[cfg(feature = "jsonrpc")]
impl<T> JsonRpcProviderType<T> for Provider<T>
where
    T: JsonRpcTarget + Send,
{
    /// Sends all `targets` as one JSON-RPC batch request.
    ///
    /// The HTTP request (URL, query, headers, authentication and timeout) is
    /// built from the first target. Every target contributes one call, with
    /// ids `1..=targets.len()` in input order.
    ///
    /// JSON-RPC 2.0 lets servers return batch responses in any order, so
    /// callers should correlate results with calls by id, not by position.
    ///
    /// # Errors
    ///
    /// Returns an "Invalid Request" (`-32600`) error when `targets` is empty.
    /// Transport and decoding failures are also propagated.
    async fn batch<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() {
            return Err(Self::invalid_batch_request());
        }

        let request = self.jsonrpc_batch_request(&targets, 0);
        Self::send_jsonrpc_batch(request).await
    }

    /// Splits `targets` into chunks of at most `chunk_size` calls and sends
    /// each chunk as its own batch request. All chunks are awaited together,
    /// and the results are concatenated in chunk order.
    ///
    /// Ids are assigned across the whole input (`1..=targets.len()`), so they
    /// stay unique across chunks. Each chunk's HTTP request is built from that
    /// chunk's first target.
    ///
    /// # Errors
    ///
    /// Returns an "Invalid Request" (`-32600`) error when `targets` is empty
    /// or `chunk_size` is zero. If any chunk fails, the error of the first
    /// failing chunk (in input order) is returned.
    async fn batch_chunk_by<U: DeserializeOwned>(
        &self,
        targets: Vec<T>,
        chunk_size: usize,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        if targets.is_empty() || chunk_size == 0 {
            return Err(Self::invalid_batch_request());
        }

        let requests = targets
            .chunks(chunk_size)
            .enumerate()
            .map(|(chunk_idx, chunk)| self.jsonrpc_batch_request(chunk, chunk_idx * chunk_size));
        let bodies = join_all(requests.map(Self::send_jsonrpc_batch::<U>)).await;

        let mut results = Vec::with_capacity(targets.len());
        for body in bodies {
            // Report the first failing chunk rather than silently continuing.
            results.extend(body?);
        }
        Ok(results)
    }
}

#[cfg(feature = "jsonrpc")]
impl<T> Provider<T>
where
    T: JsonRpcTarget,
{
    /// Builds one JSON-RPC batch HTTP request for `chunk`.
    ///
    /// The URL, query, headers, authentication and timeout come from the
    /// first target in `chunk`. Each target contributes one call whose id is
    /// `id_offset + position + 1`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk` is empty; both callers reject empty input first.
    fn jsonrpc_batch_request(&self, chunk: &[T], id_offset: usize) -> reqwest::RequestBuilder {
        let first = &chunk[0];
        let request = self.request_builder(first);

        let calls = chunk
            .iter()
            .enumerate()
            .map(|(k, target)| {
                // usize -> u64 is lossless on every platform Rust supports.
                let id = (id_offset + k + 1) as u64;
                JsonRpcRequest::new(target.method_name(), target.params(), id)
            })
            .collect::<Vec<JsonRpcRequest>>();

        let request = request.body(HTTPBody::from_array(&calls).inner);
        // Honour the target's timeout, matching `ProviderType::request`.
        match first.timeout() {
            Some(timeout) => request.timeout(timeout),
            None => request,
        }
    }

    /// Sends a prepared batch request and decodes the response body as a JSON
    /// array of results.
    async fn send_jsonrpc_batch<U: DeserializeOwned>(
        request: reqwest::RequestBuilder,
    ) -> Result<Vec<JsonRpcResult<U>>, JsonRpcError> {
        let response = request.send().await?;
        let body = response.json::<Vec<JsonRpcResult<U>>>().await?;
        Ok(body)
    }

    /// Error returned for an empty batch or a zero chunk size.
    fn invalid_batch_request() -> JsonRpcError {
        JsonRpcError {
            code: -32600,
            message: "Invalid Request".into(),
        }
    }
}
167
168impl<T> Provider<T>
169where
170 T: Target,
171{
172 pub fn new(
173 endpoint_fn: Option<EndpointFn<T>>,
174 request_fn: Option<RequestBuilderFn<T>>,
175 ) -> Self {
176 let client = reqwest::Client::new();
177 Self {
178 client,
179 endpoint_fn,
180 request_fn,
181 }
182 }
183
184 pub(crate) fn request_url(&self, target: &T) -> String {
185 let mut url = format!("{}{}", target.base_url(), target.path());
186 if let Some(func) = &self.endpoint_fn {
187 url = func(target);
188 }
189 url
190 }
191
192 pub(crate) fn request_builder(&self, target: &T) -> reqwest::RequestBuilder {
193 let url = self.request_url(target);
194 let mut request = self.client.request(target.method().into(), url);
195 let query_map = target.query();
196 if !query_map.is_empty() {
197 request = request.query(&query_map);
198 }
199 if !target.headers().is_empty() {
200 for (k, v) in target.headers() {
201 request = request.header(k, v);
202 }
203 }
204 if let Some(auth) = target.authentication() {
205 match auth {
206 AuthMethod::Basic(username, password) => {
207 request = request.basic_auth(username, Some(password));
208 }
209 AuthMethod::Bearer(token) => {
210 request = request.bearer_auth(token);
211 }
212 }
213 }
214 if let Some(request_fn) = &self.request_fn {
215 request = request_fn(&mut request, target);
216 }
217 request
218 }
219}
220
221impl<T> Default for Provider<T>
222where
223 T: Target,
224{
225 fn default() -> Self {
226 Self {
227 client: reqwest::Client::new(),
228 endpoint_fn: None,
229 request_fn: None,
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use crate::{
        http::{AuthMethod, HTTPBody, HTTPMethod},
        provider::{JsonProviderType, Provider},
        target::Target,
    };
    use serde::{Deserialize, Serialize};
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashMap;
    use std::hash::{Hash, Hasher};
    use std::time::Duration;
    use tokio_test::block_on;

    /// Fixed timestamp embedded in the `HttpBin::Get` path.
    const GET_TS: u64 = 1728044812;

    #[derive(Serialize, Deserialize)]
    struct Person {
        name: String,
        age: u8,
        phones: Vec<String>,
    }

    enum HttpBin {
        Get,
        Post,
        Bearer,
    }

    impl Target for HttpBin {
        fn base_url(&self) -> &'static str {
            "https://httpbin.org"
        }

        fn method(&self) -> HTTPMethod {
            match self {
                HttpBin::Get => HTTPMethod::GET,
                HttpBin::Post => HTTPMethod::POST,
                HttpBin::Bearer => HTTPMethod::GET,
            }
        }

        fn path(&self) -> String {
            match self {
                HttpBin::Get => format!("/get?ts={GET_TS}"),
                HttpBin::Post => "/post".into(),
                HttpBin::Bearer => "/bearer".into(),
            }
        }

        fn query(&self) -> HashMap<&'static str, &'static str> {
            HashMap::from([("foo", "bar")])
        }

        fn headers(&self) -> HashMap<&'static str, &'static str> {
            HashMap::default()
        }

        fn authentication(&self) -> Option<AuthMethod> {
            match self {
                HttpBin::Bearer => Some(AuthMethod::Bearer("token")),
                _ => None,
            }
        }

        fn body(&self) -> HTTPBody {
            match self {
                HttpBin::Get | HttpBin::Bearer => HTTPBody::default(),
                HttpBin::Post => HTTPBody::from(&Person {
                    name: "test".to_string(),
                    age: 20,
                    phones: vec!["1234567890".to_string()],
                }),
            }
        }

        fn timeout(&self) -> Option<Duration> {
            None
        }
    }

    /// Hashes a target's query string the same way the `request_fn` hook in
    /// `test_request_fn` does.
    ///
    /// `DefaultHasher`'s algorithm is unspecified and may change between Rust
    /// releases, so the expected value is recomputed instead of hard-coded.
    fn query_hash(target: &HttpBin) -> u64 {
        let mut hasher = DefaultHasher::new();
        target.query_string().hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_endpoint_closure() {
        let provider = Provider::<HttpBin>::default();
        assert_eq!(
            provider.request_url(&HttpBin::Get),
            "https://httpbin.org/get?ts=1728044812"
        );

        let provider =
            Provider::<HttpBin>::new(Some(|_: &HttpBin| "http://httpbin.org".to_string()), None);
        assert_eq!(provider.request_url(&HttpBin::Post), "http://httpbin.org");
    }

    #[test]
    fn test_request_fn() {
        let provider = Provider::<HttpBin>::new(
            None,
            Some(|builder: &reqwest::RequestBuilder, target: &HttpBin| {
                let req = builder.try_clone().expect("trying to clone request");
                req.header("X-test", "test")
                    .header("X-hash", query_hash(target).to_string())
            }),
        );

        let request = provider.request_builder(&HttpBin::Get).build().unwrap();
        let headers = request.headers();
        let expected_hash = query_hash(&HttpBin::Get).to_string();

        assert_eq!(request.method().to_string(), "GET");
        assert_eq!(headers.get("X-test").unwrap(), "test");
        assert_eq!(headers.get("X-hash").unwrap(), expected_hash.as_str());
    }

    /// NOTE: performs a real HTTP request to httpbin.org, so it needs network
    /// access and fails when the service is unreachable.
    #[test]
    fn test_authentication() {
        let provider = Provider::<HttpBin>::default();
        block_on(async {
            let response: serde_json::Value = provider
                .request_json(HttpBin::Bearer)
                .await
                .expect("request error");

            assert!(response["authenticated"].as_bool().unwrap());
        });
    }
}