lb_rs/io/network.rs

1use std::time::{Duration, Instant};
2
3use reqwest::Client;
4use tokio::time::sleep;
5
6use crate::get_code_version;
7use crate::model::account::Account;
8use crate::model::api::*;
9use crate::model::clock::{Timestamp, get_time};
10use crate::model::errors::LbErr;
11use crate::model::pubkey;
12
13impl<E> From<ErrorWrapper<E>> for ApiError<E> {
14    fn from(err: ErrorWrapper<E>) -> Self {
15        match err {
16            ErrorWrapper::Endpoint(e) => ApiError::Endpoint(e),
17            ErrorWrapper::ClientUpdateRequired => ApiError::ClientUpdateRequired,
18            ErrorWrapper::InvalidAuth => ApiError::InvalidAuth,
19            ErrorWrapper::ExpiredAuth => ApiError::ExpiredAuth,
20            ErrorWrapper::InternalError => ApiError::InternalError,
21            ErrorWrapper::BadRequest => ApiError::BadRequest,
22        }
23    }
24}
25
26// #[derive(Debug, PartialEq, Eq)]
27#[derive(Debug)]
28pub enum ApiError<E> {
29    Endpoint(E),
30    ClientUpdateRequired,
31    InvalidAuth,
32    ExpiredAuth,
33    InternalError,
34    BadRequest,
35    Sign(LbErr),
36    Serialize(String),
37    SendFailed(String),
38    ReceiveFailed(String),
39    Deserialize(String),
40}
41
/// Client-side handle for talking to the server API.
///
/// Bundles the HTTP client with the two hooks that [`Network::request`] depends on. The hooks
/// are plain function pointers, presumably so tests or embedders can substitute them
/// (TODO confirm intended use).
#[derive(Debug, Clone)]
// NOTE(review): `repr(C)` pins field order and layout, but `Client` is an opaque reqwest type,
// so this struct is presumably not FFI-safe by value. Confirm whether any foreign caller relies
// on this layout before reordering fields or dropping the attribute.
#[repr(C)]
pub struct Network {
    /// HTTP client used to send every request.
    pub client: Client,
    /// Returns this build's version string. It is embedded in each request body and sent as the
    /// `Accept-Version` header.
    pub get_code_version: fn() -> &'static str,
    /// Clock passed to the signer when signing each request.
    pub get_time: fn() -> Timestamp,
}
49
50impl Default for Network {
51    fn default() -> Self {
52        Self { client: Default::default(), get_code_version, get_time }
53    }
54}
55
56impl Network {
57    #[instrument(level = "debug", skip(self, account, request), fields(route=T::ROUTE), err(Debug))]
58    pub async fn request<T: Request>(
59        &self, account: &Account, request: T,
60    ) -> Result<T::Response, ApiError<T::Error>> {
61        let signed_request =
62            pubkey::sign(&account.private_key, &account.public_key(), request, self.get_time)
63                .map_err(ApiError::Sign)?;
64
65        let client_version = String::from((self.get_code_version)());
66
67        let serialized_request = serde_json::to_vec(&RequestWrapper {
68            signed_request,
69            client_version: client_version.clone(),
70        })
71        .map_err(|err| ApiError::Serialize(err.to_string()))?;
72
73        if serialized_request.len() > 10 * 1024 * 1024 {
74            warn!("making network request with {} bytes", serialized_request.len());
75        }
76
77        let mut retries = 0;
78        let start = Instant::now();
79        let sent = loop {
80            match self
81                .client
82                .request(T::METHOD, format!("{}{}", account.api_url, T::ROUTE).as_str())
83                .body(serialized_request.clone())
84                .header("Accept-Version", client_version.clone())
85                .send()
86                .await
87            {
88                Ok(o) => {
89                    if start.elapsed() > Duration::from_millis(1000) {
90                        warn!("network request took {:?}", start.elapsed());
91                    }
92                    break o;
93                }
94                Err(e) => {
95                    if retries < 3 {
96                        warn!(
97                            "network request send failed; retrying after {}ms; error = {:?}",
98                            retries * 100,
99                            e.to_string()
100                        );
101                        sleep(Duration::from_millis(retries * 100)).await;
102                        retries += 1;
103                        continue;
104                    } else {
105                        return Err(ApiError::SendFailed(e.to_string()));
106                    }
107                }
108            }
109        };
110        let serialized_response = sent
111            .bytes()
112            .await
113            .map_err(|err| ApiError::ReceiveFailed(err.to_string()))?;
114        let response: Result<T::Response, ErrorWrapper<T::Error>> =
115            serde_json::from_slice(&serialized_response)
116                .map_err(|err| ApiError::Deserialize(err.to_string()))?;
117        response.map_err(ApiError::from)
118    }
119}