//! HTTP transport for signed API requests (`lb_rs/io/network.rs`).

1use web_time::{Duration, Instant};
2
3#[cfg(not(target_family = "wasm"))]
4use bytes::Bytes;
5#[cfg(not(target_family = "wasm"))]
6use futures::stream;
7use reqwest::{Body, Client};
8
9use crate::get_code_version;
10use crate::model::account::Account;
11use crate::model::api::*;
12use crate::model::clock::{Timestamp, get_time};
13use crate::model::errors::LbErr;
14use crate::model::pubkey;
15use crate::model::wire::{WIRE_FORMAT_HEADER, WireFormat};
16
/// Size of each piece (4 MiB) that a large request body is split into before streaming.
const STREAM_CHUNK_BYTES: usize = 4 << 20;

/// Request bodies of at least this many bytes (1 GiB) are streamed in
/// [`STREAM_CHUNK_BYTES`]-sized pieces instead of being sent as a single buffer.
const STREAM_BODY_THRESHOLD: usize = 1 << 30;
20
21impl<E> From<ErrorWrapper<E>> for ApiError<E> {
22    fn from(err: ErrorWrapper<E>) -> Self {
23        match err {
24            ErrorWrapper::Endpoint(e) => ApiError::Endpoint(e),
25            ErrorWrapper::ClientUpdateRequired => ApiError::ClientUpdateRequired,
26            ErrorWrapper::InvalidAuth => ApiError::InvalidAuth,
27            ErrorWrapper::ExpiredAuth => ApiError::ExpiredAuth,
28            ErrorWrapper::InternalError => ApiError::InternalError,
29            ErrorWrapper::BadRequest => ApiError::BadRequest,
30        }
31    }
32}
33
34// #[derive(Debug, PartialEq, Eq)]
35#[derive(Debug)]
36pub enum ApiError<E> {
37    Endpoint(E),
38    ClientUpdateRequired,
39    InvalidAuth,
40    ExpiredAuth,
41    InternalError,
42    BadRequest,
43    Sign(LbErr),
44    Serialize(String),
45    SendFailed(String),
46    ReceiveFailed(String),
47    Deserialize(String),
48}
49
50#[derive(Debug, Clone)]
51#[repr(C)]
52pub struct Network {
53    pub client: Client,
54    pub get_code_version: fn() -> &'static str,
55    pub get_time: fn() -> Timestamp,
56}
57
58impl Default for Network {
59    fn default() -> Self {
60        Self { client: Default::default(), get_code_version, get_time }
61    }
62}
63
64impl Network {
65    #[instrument(level = "debug", skip(self, account, request), fields(route=T::ROUTE), err(Debug))]
66    pub async fn request<T: Request>(
67        &self, account: &Account, request: T,
68    ) -> Result<T::Response, ApiError<T::Error>> {
69        let signed_request =
70            pubkey::sign(&account.private_key, &account.public_key(), request, self.get_time)
71                .map_err(ApiError::Sign)?;
72
73        let client_version = String::from((self.get_code_version)());
74
75        let wire_format = WireFormat::CLIENT_DEFAULT;
76        let serialized_request = wire_format
77            .serialize(&RequestWrapper { signed_request, client_version: client_version.clone() })
78            .map_err(|err| ApiError::Serialize(err.to_string()))?;
79
80        if serialized_request.len() > 10 * 1024 * 1024 {
81            warn!(
82                "making network request with {} bytes ({:?})",
83                serialized_request.len(),
84                wire_format
85            );
86        }
87
88        let url = &account.api_url;
89        let start = Instant::now();
90        let body = body_for(serialized_request);
91        let sent = self
92            .client
93            .request(T::METHOD, format!("{}{}", url, T::ROUTE).as_str())
94            .body(body)
95            .header("Accept-Version", client_version)
96            .header(WIRE_FORMAT_HEADER, wire_format.as_str())
97            .send()
98            .await
99            .map_err(|e| {
100                warn!("send failed: {e:?}");
101                ApiError::SendFailed(e.to_string())
102            })?;
103        if start.elapsed() > Duration::from_millis(1000) {
104            warn!("network request took {:?}", start.elapsed());
105        }
106
107        let serialized_response = sent
108            .bytes()
109            .await
110            .map_err(|err| ApiError::ReceiveFailed(err.to_string()))?;
111        let response: Result<T::Response, ErrorWrapper<T::Error>> = wire_format
112            .deserialize(&serialized_response)
113            .map_err(|err| ApiError::Deserialize(err.to_string()))?;
114        response.map_err(ApiError::from)
115    }
116}
117
118#[cfg(not(target_family = "wasm"))]
119fn body_for(serialized_request: Vec<u8>) -> Body {
120    if serialized_request.len() < STREAM_BODY_THRESHOLD {
121        return Body::from(serialized_request);
122    }
123    let mut buf = Bytes::from(serialized_request);
124    let mut chunks: Vec<Result<Bytes, std::io::Error>> =
125        Vec::with_capacity(buf.len().div_ceil(STREAM_CHUNK_BYTES));
126    while !buf.is_empty() {
127        let n = buf.len().min(STREAM_CHUNK_BYTES);
128        chunks.push(Ok(buf.split_to(n)));
129    }
130    Body::wrap_stream(stream::iter(chunks))
131}
132
/// Wraps an encoded request in a reqwest [`Body`].
///
/// On wasm targets the payload is always sent as a single buffer; the chunked
/// streaming path is only compiled for non-wasm targets.
#[cfg(target_family = "wasm")]
fn body_for(serialized_request: Vec<u8>) -> Body {
    serialized_request.into()
}