Skip to main content

lb_rs/io/
network.rs

1use web_time::{Duration, Instant};
2
3#[cfg(not(target_family = "wasm"))]
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::stream;
7use reqwest::{Body, Client};
8
9use crate::get_code_version;
10use crate::model::account::Account;
11use crate::model::api::*;
12use crate::model::clock::{Timestamp, get_time};
13use crate::model::core_config::ClientType;
14use crate::model::errors::LbErr;
15use crate::model::pubkey;
16use crate::model::wire::{CLIENT_HEADER, OS_HEADER, WIRE_FORMAT_HEADER, WireFormat};
17
18const STREAM_CHUNK_BYTES: usize = 4 * 1024 * 1024;
19
20const STREAM_BODY_THRESHOLD: usize = 1024 * 1024 * 1024;
21
22impl<E> From<ErrorWrapper<E>> for ApiError<E> {
23    fn from(err: ErrorWrapper<E>) -> Self {
24        match err {
25            ErrorWrapper::Endpoint(e) => ApiError::Endpoint(e),
26            ErrorWrapper::ClientUpdateRequired => ApiError::ClientUpdateRequired,
27            ErrorWrapper::InvalidAuth => ApiError::InvalidAuth,
28            ErrorWrapper::ExpiredAuth => ApiError::ExpiredAuth,
29            ErrorWrapper::InternalError => ApiError::InternalError,
30            ErrorWrapper::BadRequest => ApiError::BadRequest,
31        }
32    }
33}
34
35// #[derive(Debug, PartialEq, Eq)]
36#[derive(Debug)]
37pub enum ApiError<E> {
38    Endpoint(E),
39    ClientUpdateRequired,
40    InvalidAuth,
41    ExpiredAuth,
42    InternalError,
43    BadRequest,
44    Sign(LbErr),
45    Serialize(String),
46    SendFailed(String),
47    ReceiveFailed(String),
48    Deserialize(String),
49}
50
51#[derive(Debug, Clone)]
52#[repr(C)]
53pub struct Network {
54    pub client: Client,
55    pub get_code_version: fn() -> &'static str,
56    pub get_time: fn() -> Timestamp,
57    pub client_type: ClientType,
58}
59
60impl Default for Network {
61    fn default() -> Self {
62        Self {
63            client: Default::default(),
64            get_code_version,
65            get_time,
66            client_type: ClientType::Unknown,
67        }
68    }
69}
70
71impl Network {
72    #[instrument(level = "debug", skip(self, account, request), fields(route=T::ROUTE), err(Debug))]
73    pub async fn request<T: Request>(
74        &self, account: &Account, request: T,
75    ) -> Result<T::Response, ApiError<T::Error>> {
76        let signed_request =
77            pubkey::sign(&account.private_key, &account.public_key(), request, self.get_time)
78                .map_err(ApiError::Sign)?;
79
80        let client_version = String::from((self.get_code_version)());
81
82        let wire_format = WireFormat::CLIENT_DEFAULT;
83        let serialized_request = wire_format
84            .serialize(&RequestWrapper { signed_request, client_version: client_version.clone() })
85            .map_err(|err| ApiError::Serialize(err.to_string()))?;
86
87        if serialized_request.len() > 10 * 1024 * 1024 {
88            warn!(
89                "making network request with {} bytes ({:?})",
90                serialized_request.len(),
91                wire_format
92            );
93        }
94
95        let url = &account.api_url;
96        let start = Instant::now();
97        let body = body_for(serialized_request);
98        let sent = self
99            .client
100            .request(T::METHOD, format!("{}{}", url, T::ROUTE).as_str())
101            .body(body)
102            .header("Accept-Version", client_version)
103            .header(WIRE_FORMAT_HEADER, wire_format.as_str())
104            .header(OS_HEADER, client_os())
105            .header(CLIENT_HEADER, self.client_type.as_str())
106            .send()
107            .await
108            .map_err(|e| {
109                warn!("send failed: {e:?}");
110                ApiError::SendFailed(e.to_string())
111            })?;
112        if start.elapsed() > Duration::from_millis(1000) {
113            warn!("network request took {:?}", start.elapsed());
114        }
115
116        let serialized_response = sent
117            .bytes()
118            .await
119            .map_err(|err| ApiError::ReceiveFailed(err.to_string()))?;
120        let response: Result<T::Response, ErrorWrapper<T::Error>> = wire_format
121            .deserialize(&serialized_response)
122            .map_err(|err| ApiError::Deserialize(err.to_string()))?;
123        response.map_err(ApiError::from)
124    }
125}
126
/// Name of the compile-time target OS as reported to the server in the OS header, or
/// `"unknown"` for targets not listed.
fn client_os() -> &'static str {
    // Every `cfg!` here is a compile-time constant; at most one entry is `true`.
    const KNOWN_TARGETS: [(bool, &str); 5] = [
        (cfg!(target_os = "windows"), "windows"),
        (cfg!(target_os = "ios"), "iOS"),
        (cfg!(target_os = "macos"), "macOS"),
        (cfg!(target_os = "android"), "android"),
        (cfg!(target_os = "linux"), "linux"),
    ];

    KNOWN_TARGETS
        .iter()
        .find(|(is_current, _)| *is_current)
        .map(|&(_, label)| label)
        .unwrap_or("unknown")
}
142
143#[cfg(not(target_family = "wasm"))]
144fn body_for(serialized_request: Vec<u8>) -> Body {
145    if serialized_request.len() < STREAM_BODY_THRESHOLD {
146        return Body::from(serialized_request);
147    }
148    let mut buf = Bytes::from(serialized_request);
149    let mut chunks: Vec<Result<Bytes, std::io::Error>> =
150        Vec::with_capacity(buf.len().div_ceil(STREAM_CHUNK_BYTES));
151    while !buf.is_empty() {
152        let n = buf.len().min(STREAM_CHUNK_BYTES);
153        chunks.push(Ok(buf.split_to(n)));
154    }
155    Body::wrap_stream(stream::iter(chunks))
156}
157
/// Wraps an encoded request in an HTTP body; on wasm targets the request is always sent as a
/// single in-memory buffer.
#[cfg(target_family = "wasm")]
fn body_for(serialized_request: Vec<u8>) -> Body {
    serialized_request.into()
}