Skip to main content

aelf_client/
provider.rs

1#[cfg(feature = "native-http")]
2use crate::config::{ClientConfig, RetryPolicy};
3use crate::error::AElfError;
4use async_trait::async_trait;
5use http::Method;
6#[cfg(feature = "native-http")]
7use reqwest::header::{ACCEPT, CONTENT_TYPE};
8#[cfg(feature = "native-http")]
9use reqwest::StatusCode;
10use serde_json::Value;
11#[cfg(feature = "native-http")]
12use std::time::Duration;
13
14/// Abstract transport used by the SDK client.
15#[async_trait]
16pub trait Provider: Send + Sync {
17    /// Sends a request and parses the response body as JSON.
18    async fn request_json(
19        &self,
20        method: Method,
21        path: &str,
22        query: &[(&str, String)],
23        body: Option<Value>,
24    ) -> Result<Value, AElfError>;
25
26    /// Sends a request and returns the raw response text.
27    async fn request_text(
28        &self,
29        method: Method,
30        path: &str,
31        query: &[(&str, String)],
32        body: Option<Value>,
33    ) -> Result<String, AElfError>;
34}
35
/// The default [`Provider`]: an HTTP transport built on top of `reqwest`.
#[cfg(feature = "native-http")]
#[derive(Clone, Debug)]
pub struct HttpProvider {
    /// Client settings read on every request (endpoint, API version, extra
    /// headers, basic auth, retry policy).
    config: ClientConfig,
    /// The `reqwest` client, created once by [`HttpProvider::new`].
    client: reqwest::Client,
}
43
#[cfg(feature = "native-http")]
impl HttpProvider {
    /// Creates a new HTTP transport from client configuration.
    ///
    /// The underlying `reqwest` client is built once, using `config.timeout`
    /// as its timeout.
    ///
    /// # Errors
    ///
    /// Returns [`AElfError::Http`] when the `reqwest` client cannot be built.
    pub fn new(config: ClientConfig) -> Result<Self, AElfError> {
        let client = reqwest::Client::builder()
            .timeout(config.timeout)
            .build()
            .map_err(AElfError::Http)?;
        Ok(Self { config, client })
    }

    /// Joins the configured endpoint and `path` with a single `/`.
    ///
    /// Trailing slashes on the endpoint and leading slashes on `path` are
    /// stripped first, so callers may pass either form.
    fn make_url(&self, path: &str) -> String {
        let endpoint = self.config.endpoint.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        format!("{endpoint}/{path}")
    }

    /// Returns the JSON media type, suffixed with `;v=<version>` when an API
    /// version is configured.
    fn versioned_content_type(&self) -> String {
        match &self.config.api_version {
            Some(version) => format!("application/json;v={version}"),
            None => "application/json".to_owned(),
        }
    }

    /// Assembles a `reqwest` request for one attempt.
    ///
    /// Applied in order: the `Accept` header, a `Content-Type` header for
    /// every method other than `GET`, the configured custom headers, basic
    /// auth (if configured), the query pairs (if any), and the JSON body (if
    /// any).
    fn build_request(
        &self,
        method: &Method,
        url: &str,
        query: &[(&str, String)],
        body: Option<&Value>,
    ) -> reqwest::RequestBuilder {
        // Format the media type once; it is shared by both headers.
        let content_type = self.versioned_content_type();

        let mut request = self
            .client
            .request(method.clone(), url)
            .header(ACCEPT, content_type.as_str());
        if *method != Method::GET {
            request = request.header(CONTENT_TYPE, content_type.as_str());
        }

        for (name, value) in &self.config.headers {
            request = request.header(name, value);
        }

        if let Some(auth) = &self.config.basic_auth {
            request = request.basic_auth(&auth.username, Some(&auth.password));
        }

        if !query.is_empty() {
            request = request.query(&query);
        }

        if let Some(body) = body {
            request = request.json(body);
        }

        request
    }

    /// Returns `true` for response statuses worth retrying: only 5xx.
    fn should_retry_status(status: StatusCode) -> bool {
        status.is_server_error()
    }

    /// Returns `true` for transport errors worth retrying: those `reqwest`
    /// classifies as connect, timeout, request, or body errors.
    fn should_retry_error(error: &reqwest::Error) -> bool {
        error.is_connect() || error.is_timeout() || error.is_request() || error.is_body()
    }

    /// Sleeps for `backoff`, skipping the timer entirely when it is zero.
    async fn sleep_before_retry(backoff: Duration) {
        if !backoff.is_zero() {
            tokio::time::sleep(backoff).await;
        }
    }

    /// Consumes one retry, doubles the backoff for the next attempt, and then
    /// waits out the current backoff.
    ///
    /// Callers are expected to check that retries remain before calling this.
    async fn retry_after(backoff: &mut Duration, retries_remaining: &mut usize) {
        let current = *backoff;
        // Saturate rather than underflow: a stray call with zero retries left
        // must not panic (debug) or wrap to a huge retry budget (release).
        *retries_remaining = retries_remaining.saturating_sub(1);
        *backoff = backoff.saturating_mul(2);
        Self::sleep_before_retry(current).await;
    }
}
122
#[cfg(feature = "native-http")]
#[async_trait]
impl Provider for HttpProvider {
    /// Fetches the response via [`Provider::request_text`] and parses it as
    /// JSON, mapping parse failures to [`AElfError::Json`].
    async fn request_json(
        &self,
        method: Method,
        path: &str,
        query: &[(&str, String)],
        body: Option<Value>,
    ) -> Result<Value, AElfError> {
        let raw = self.request_text(method, path, query, body).await?;
        serde_json::from_str(&raw).map_err(AElfError::Json)
    }

    /// Sends the request, retrying according to the configured retry policy,
    /// and returns the body text of the first successful response.
    ///
    /// A retry is attempted (while retries remain) when sending fails with a
    /// retryable transport error, when reading the body fails with one, or
    /// when the response status is retryable. Any other failure is returned
    /// immediately: transport errors as [`AElfError::Http`], non-success
    /// statuses via `AElfError::from_response`.
    async fn request_text(
        &self,
        method: Method,
        path: &str,
        query: &[(&str, String)],
        body: Option<Value>,
    ) -> Result<String, AElfError> {
        let url = self.make_url(path);
        let mut retries_left = self.config.retry_policy.max_retries;
        let mut backoff = self.config.retry_policy.initial_backoff;

        loop {
            // A fresh request is built for every attempt.
            let sent = self
                .build_request(&method, &url, query, body.as_ref())
                .send()
                .await;
            let response = match sent {
                Ok(response) => response,
                Err(error) => {
                    if retries_left == 0 || !Self::should_retry_error(&error) {
                        return Err(AElfError::Http(error));
                    }
                    Self::retry_after(&mut backoff, &mut retries_left).await;
                    continue;
                }
            };

            // Capture status and request id now; reading the body consumes
            // the response.
            let status = response.status();
            let headers = response.headers();
            let request_id = headers
                .get("x-request-id")
                .or_else(|| headers.get("request-id"))
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned);

            let text = match response.text().await {
                Ok(text) => text,
                Err(error) => {
                    if retries_left == 0 || !Self::should_retry_error(&error) {
                        return Err(AElfError::Http(error));
                    }
                    Self::retry_after(&mut backoff, &mut retries_left).await;
                    continue;
                }
            };

            if status.is_success() {
                return Ok(text);
            }

            let retryable = retries_left > 0 && Self::should_retry_status(status);
            if !retryable {
                return Err(AElfError::from_response(
                    url.clone(),
                    status,
                    request_id,
                    &text,
                ));
            }
            Self::retry_after(&mut backoff, &mut retries_left).await;
        }
    }
}
195
196#[cfg(test)]
197pub(crate) use test_support::{MockCallKind, MockProvider, MockRecordedRequest, MockResponse};
198
#[cfg(test)]
mod test_support {
    //! In-memory [`Provider`] double for unit tests: canned responses are
    //! served in FIFO order and every call is recorded for later inspection.

    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Which [`Provider`] method a recorded call went through.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub(crate) enum MockCallKind {
        Json,
        Text,
    }

    /// A canned result, tagged with the [`Provider`] method it may answer.
    #[derive(Debug)]
    pub(crate) enum MockResponse {
        Json(Result<Value, AElfError>),
        Text(Result<String, AElfError>),
    }

    impl MockResponse {
        /// Successful response for a `request_json` call.
        pub(crate) fn json(value: Value) -> Self {
            Self::Json(Ok(value))
        }

        /// Failing response for a `request_json` call.
        pub(crate) fn json_error(error: AElfError) -> Self {
            Self::Json(Err(error))
        }

        /// Successful response for a `request_text` call.
        pub(crate) fn text(value: impl Into<String>) -> Self {
            Self::Text(Ok(value.into()))
        }

        /// Returns the payload-free tag of this response.
        fn kind(&self) -> MockResponseKind {
            match self {
                Self::Json(_) => MockResponseKind::Json,
                Self::Text(_) => MockResponseKind::Text,
            }
        }
    }

    /// Payload-free mirror of the [`MockResponse`] variants, used to check
    /// that a queued response matches the method being called.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub(crate) enum MockResponseKind {
        Json,
        Text,
    }

    /// Snapshot of the arguments one mock call received.
    #[derive(Clone, Debug, PartialEq)]
    pub(crate) struct MockRecordedRequest {
        pub kind: MockCallKind,
        pub method: Method,
        pub path: String,
        pub query: Vec<(String, String)>,
        pub body: Option<Value>,
    }

    /// State shared between all clones of a [`MockProvider`].
    #[derive(Default)]
    struct MockProviderState {
        /// Responses still waiting to be served, front first.
        responses: Mutex<VecDeque<MockResponse>>,
        /// Calls received so far, in arrival order.
        requests: Mutex<Vec<MockRecordedRequest>>,
    }

    /// FIFO provider double used for client/service unit tests.
    #[derive(Clone, Default)]
    pub(crate) struct MockProvider {
        state: Arc<MockProviderState>,
    }

    impl MockProvider {
        /// Creates a mock that serves `responses` in order, one per call.
        pub(crate) fn new(responses: Vec<MockResponse>) -> Self {
            Self {
                state: Arc::new(MockProviderState {
                    responses: Mutex::new(responses.into()),
                    requests: Mutex::new(Vec::new()),
                }),
            }
        }

        /// Returns a copy of every call recorded so far, in arrival order.
        pub(crate) fn requests(&self) -> Vec<MockRecordedRequest> {
            self.state
                .requests
                .lock()
                .expect("recorded requests lock")
                .clone()
        }

        /// Appends one call to the recorded-request log.
        fn record(
            &self,
            kind: MockCallKind,
            method: Method,
            path: &str,
            query: &[(&str, String)],
            body: Option<Value>,
        ) {
            self.state
                .requests
                .lock()
                .expect("recorded requests lock")
                .push(MockRecordedRequest {
                    kind,
                    method,
                    path: path.to_owned(),
                    query: query
                        .iter()
                        .map(|(name, value)| ((*name).to_owned(), value.clone()))
                        .collect(),
                    body,
                });
        }

        /// Removes and returns the next queued response.
        ///
        /// # Panics
        ///
        /// Panics when the queue is empty or when the next response is not of
        /// the `expected` kind.
        fn pop_response(&self, expected: MockResponseKind) -> MockResponse {
            // Pop in its own statement so the lock guard is released before
            // either panic below; panicking while the guard is alive would
            // poison the mutex and bury the real failure under a lock error.
            let next = self
                .state
                .responses
                .lock()
                .expect("mock response lock")
                .pop_front();
            let response =
                next.unwrap_or_else(|| panic!("missing mock response for {expected:?} request"));
            assert_eq!(
                response.kind(),
                expected,
                "mock response kind mismatch: expected {expected:?}, got {:?}",
                response.kind()
            );
            response
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        /// Records the call and returns the next queued JSON response.
        async fn request_json(
            &self,
            method: Method,
            path: &str,
            query: &[(&str, String)],
            body: Option<Value>,
        ) -> Result<Value, AElfError> {
            self.record(MockCallKind::Json, method, path, query, body);
            match self.pop_response(MockResponseKind::Json) {
                MockResponse::Json(result) => result,
                MockResponse::Text(_) => unreachable!("json response kind already validated"),
            }
        }

        /// Records the call and returns the next queued text response.
        async fn request_text(
            &self,
            method: Method,
            path: &str,
            query: &[(&str, String)],
            body: Option<Value>,
        ) -> Result<String, AElfError> {
            self.record(MockCallKind::Text, method, path, query, body);
            match self.pop_response(MockResponseKind::Text) {
                MockResponse::Text(result) => result,
                MockResponse::Json(_) => unreachable!("text response kind already validated"),
            }
        }
    }
}
356
#[cfg(all(test, feature = "native-http"))]
mod tests {
    use super::*;
    use base64::Engine;
    use serde_json::json;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use std::time::Duration;
    use wiremock::matchers::any;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers every request with `responder`.
    async fn serve(responder: impl wiremock::Respond + 'static) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(any())
            .respond_with(responder)
            .mount(&server)
            .await;
        server
    }

    /// Builds a provider pointed at `server` with the default configuration.
    fn provider_for(server: &MockServer) -> HttpProvider {
        HttpProvider::new(ClientConfig::new(server.uri())).expect("provider")
    }

    /// Builds a provider pointed at `server` that allows two retries with a
    /// one-millisecond initial backoff.
    fn retrying_provider_for(server: &MockServer) -> HttpProvider {
        let policy = RetryPolicy::new(2, Duration::from_millis(1));
        HttpProvider::new(ClientConfig::new(server.uri()).with_retry_policy(policy))
            .expect("provider")
    }

    /// Reads header `name` from a captured request as a string, if it is
    /// present and convertible.
    fn header_value<'a>(request: &'a wiremock::Request, name: &str) -> Option<&'a str> {
        let raw = request.headers.get(name)?;
        raw.to_str().ok()
    }

    #[tokio::test]
    async fn get_request_uses_accept_header_without_content_type() {
        let server = serve(ResponseTemplate::new(200).set_body_string("\"123\"")).await;
        let provider = provider_for(&server);

        let query = [("includeTransactions", "true".to_owned())];
        let body_text = provider
            .request_text(Method::GET, "api/blockChain/blockHeight", &query, None)
            .await
            .expect("response");
        assert_eq!(body_text, "\"123\"");

        let received = server.received_requests().await.expect("requests");
        let sent = received.first().expect("request");
        assert_eq!(sent.method.as_str(), "GET");
        assert_eq!(sent.url.path(), "/api/blockChain/blockHeight");
        assert_eq!(sent.url.query(), Some("includeTransactions=true"));
        assert_eq!(header_value(sent, "accept"), Some("application/json"));
        assert!(header_value(sent, "content-type").is_none());
    }

    #[tokio::test]
    async fn post_request_includes_version_auth_and_custom_headers() {
        let server = serve(ResponseTemplate::new(200).set_body_json(json!({ "ok": true }))).await;

        let config = ClientConfig::new(server.uri())
            .with_api_version("1.0")
            .with_basic_auth("open", "sesame")
            .with_header("x-sdk-test", "yes")
            .expect("header");
        let provider = HttpProvider::new(config).expect("provider");

        let query = [("verbose", "true".to_owned())];
        let payload = json!({ "From": "from-address" });
        let parsed = provider
            .request_json(
                Method::POST,
                "api/blockChain/rawTransaction",
                &query,
                Some(payload),
            )
            .await
            .expect("response");
        assert_eq!(parsed, json!({ "ok": true }));

        let received = server.received_requests().await.expect("requests");
        let sent = received.first().expect("request");
        let credentials = base64::engine::general_purpose::STANDARD.encode("open:sesame");
        let expected_auth = format!("Basic {}", credentials);

        assert_eq!(sent.method.as_str(), "POST");
        assert_eq!(sent.url.path(), "/api/blockChain/rawTransaction");
        assert_eq!(sent.url.query(), Some("verbose=true"));
        assert_eq!(
            header_value(sent, "accept"),
            Some("application/json;v=1.0")
        );
        assert_eq!(
            header_value(sent, "content-type"),
            Some("application/json;v=1.0")
        );
        assert_eq!(
            header_value(sent, "authorization"),
            Some(expected_auth.as_str())
        );
        assert_eq!(header_value(sent, "x-sdk-test"), Some("yes"));
        assert_eq!(
            serde_json::from_slice::<Value>(&sent.body).expect("body"),
            json!({ "From": "from-address" })
        );
    }

    #[tokio::test]
    async fn request_error_uses_webapp_error_shape() {
        let template = ResponseTemplate::new(400)
            .insert_header("x-request-id", "req-123")
            .set_body_json(json!({
                "error": {
                    "code": "InvalidTransaction",
                    "message": "bad transaction",
                    "details": "signature mismatch"
                }
            }));
        let server = serve(template).await;
        let provider = provider_for(&server);

        let failure = provider
            .request_text(Method::POST, "api/blockChain/sendTransaction", &[], None)
            .await
            .expect_err("request should fail");
        let expected_endpoint = format!("{}/api/blockChain/sendTransaction", server.uri());

        match failure {
            AElfError::Request(error) => {
                assert_eq!(error.message, "bad transaction");
                assert_eq!(error.request_id.as_deref(), Some("req-123"));
                assert_eq!(error.endpoint.as_deref(), Some(expected_endpoint.as_str()));
                assert_eq!(error.status, Some(400));
                assert_eq!(error.chain_code.as_deref(), Some("InvalidTransaction"));
                assert_eq!(error.details.as_deref(), Some("signature mismatch"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn request_error_falls_back_to_plain_text_message() {
        let template = ResponseTemplate::new(503)
            .insert_header("request-id", "req-plain")
            .set_body_string("service unavailable");
        let server = serve(template).await;
        let provider = provider_for(&server);

        let failure = provider
            .request_text(Method::GET, "api/net/peers", &[], None)
            .await
            .expect_err("request should fail");
        let expected_endpoint = format!("{}/api/net/peers", server.uri());

        match failure {
            AElfError::Request(error) => {
                assert_eq!(
                    error.message,
                    "request failed with status 503 Service Unavailable: service unavailable"
                );
                assert_eq!(error.request_id.as_deref(), Some("req-plain"));
                assert_eq!(error.endpoint.as_deref(), Some(expected_endpoint.as_str()));
                assert_eq!(error.status, Some(503));
                assert_eq!(error.chain_code, None);
                assert_eq!(error.details, None);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn request_retries_server_errors_until_success() {
        let hits = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&hits);

        // The first two attempts fail with 503; the third succeeds.
        let server = serve(move |_request: &wiremock::Request| {
            match counter.fetch_add(1, Ordering::SeqCst) {
                0 | 1 => ResponseTemplate::new(503).set_body_string("temporary outage"),
                _ => ResponseTemplate::new(200).set_body_string("\"ok\""),
            }
        })
        .await;
        let provider = retrying_provider_for(&server);

        let body_text = provider
            .request_text(Method::GET, "api/blockChain/blockHeight", &[], None)
            .await
            .expect("response");

        assert_eq!(body_text, "\"ok\"");
        assert_eq!(hits.load(Ordering::SeqCst), 3);
        let received = server.received_requests().await.expect("requests");
        assert_eq!(received.len(), 3);
    }

    #[tokio::test]
    async fn request_does_not_retry_client_errors() {
        let hits = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&hits);

        let server = serve(move |_request: &wiremock::Request| {
            counter.fetch_add(1, Ordering::SeqCst);
            ResponseTemplate::new(400).set_body_json(json!({
                "error": {
                    "code": "BadRequest",
                    "message": "no retry",
                    "details": null
                }
            }))
        })
        .await;
        let provider = retrying_provider_for(&server);

        let failure = provider
            .request_text(Method::GET, "api/net/peers", &[], None)
            .await
            .expect_err("request should fail");

        assert!(matches!(failure, AElfError::Request(_)));
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        let received = server.received_requests().await.expect("requests");
        assert_eq!(received.len(), 1);
    }
}