1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use crate::{common::*, pipeline::Pipeline};

use crate::reqs::*;

use async_net::TcpStream;
use dashmap::DashMap;
use lazy_static::lazy_static;

use serde::{de::DeserializeOwned, Serialize};
use smol::lock::Semaphore;
use smol_timeout::TimeoutExt;

use std::net::SocketAddr;
use std::time::{Duration, Instant};

lazy_static! {
    /// Process-wide connection pool used by the free-standing `request` function.
    static ref CONN_POOL: Client = Default::default();
}

/// Does a melnet request to any given endpoint, using the global client.
///
/// The whole call, including any retries performed by `Client::request`, is bounded by a
/// 60-second deadline. If the deadline expires, the result is a `MelnetError::Network` error
/// of kind `TimedOut`.
pub async fn request<TInput: Serialize + Clone, TOutput: DeserializeOwned + std::fmt::Debug>(
    addr: SocketAddr,
    netname: &str,
    verb: &str,
    req: TInput,
) -> Result<TOutput> {
    let deadline = Duration::from_secs(60);
    // `timeout` yields `None` when the deadline elapses before the inner future completes.
    let outcome = CONN_POOL
        .request(addr, netname, verb, req)
        .timeout(deadline)
        .await;
    outcome.unwrap_or_else(|| {
        let timed_out = std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            "long timeout at 60 seconds",
        );
        Err(MelnetError::Network(timed_out))
    })
}

/// Number of independent address-to-connection maps held by every [`Client`].
const POOL_SIZE: usize = 4;

/// A thread-safe pool of keep-alive connections to melnet servers (or to any server speaking
/// an HTTP/1.1-style keepalive protocol).
///
/// Each element of `pool` maps a server address to a cached pipeline together with the
/// instant at which that pipeline was created.
#[derive(Default)]
pub struct Client {
    pool: [DashMap<SocketAddr, (Pipeline, Instant)>; POOL_SIZE],
}

impl Client {
    /// Does a melnet request to any given endpoint.
    pub async fn request<TInput: Serialize + Clone, TOutput: DeserializeOwned + std::fmt::Debug>(
        &self,
        addr: SocketAddr,
        netname: &str,
        verb: &str,
        req: TInput,
    ) -> Result<TOutput> {
        for count in 0..5 {
            match self.request_inner(addr, netname, verb, req.clone()).await {
                Err(MelnetError::Network(err)) => {
                    log::debug!(
                        "retrying request {} to {} on transient network error {:?}",
                        verb,
                        addr,
                        err
                    );
                    smol::Timer::after(Duration::from_secs_f64(0.1 * 2.0f64.powi(count))).await;
                }
                x => return x,
            }
        }
        self.request_inner(addr, netname, verb, req).await
    }

    async fn request_inner<TInput: Serialize, TOutput: DeserializeOwned + std::fmt::Debug>(
        &self,
        addr: SocketAddr,
        netname: &str,
        verb: &str,
        req: TInput,
    ) -> Result<TOutput> {
        // // Semaphore
        static GLOBAL_LIMIT: Semaphore = Semaphore::new(256);
        let start = Instant::now();
        let _guard = GLOBAL_LIMIT.acquire().await;
        log::debug!("acquired semaphore by {:?}", start.elapsed());
        let start = Instant::now();
        let pool = &self.pool[fastrand::usize(0..self.pool.len())];
        let conn = if let Some(v) = pool.get(&addr).filter(|d| d.1.elapsed().as_secs() < 60) {
            v.0.clone()
        } else {
            let t = TcpStream::connect(addr)
                .await
                .map_err(MelnetError::Network)?;
            let pipe = Pipeline::new(t);
            pool.insert(addr, (pipe.clone(), Instant::now()));
            pipe
        };
        log::debug!("acquired connection by {:?}", start.elapsed());

        let res = async {
            // send a request
            let rr = stdcode::serialize(&RawRequest {
                proto_ver: PROTO_VER,
                netname: netname.to_owned(),
                verb: verb.to_owned(),
                payload: stdcode::serialize(&req).unwrap(),
            })
            .unwrap();
            // read the response length
            let response: RawResponse =
                stdcode::deserialize(&conn.request(rr).await?).map_err(|e| {
                    MelnetError::Network(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
                })?;
            let response = match response.kind.as_ref() {
                "Ok" => stdcode::deserialize::<TOutput>(&response.body)
                    .map_err(|_| MelnetError::Custom("stdcode error".to_owned()))?,
                "NoVerb" => return Err(MelnetError::VerbNotFound),
                _ => {
                    return Err(MelnetError::Custom(
                        String::from_utf8_lossy(&response.body).to_string(),
                    ))
                }
            };
            let elapsed = start.elapsed();
            if elapsed.as_secs_f64() > 3.0 {
                log::warn!(
                    "melnet req of verb {}/{} to {} took {:?}",
                    netname,
                    verb,
                    addr,
                    elapsed
                )
            }
            Ok::<_, crate::MelnetError>(response)
        };
        match res.await {
            Ok(v) => Ok(v),
            Err(err) => {
                pool.remove(&addr);
                Err(err)
            }
        }
    }
}