use crate::{common::*, tcp_pool::TcpPool};
use crate::reqs::*;
use async_net::TcpStream;
use dashmap::DashMap;
use lazy_static::lazy_static;
use serde::{de::DeserializeOwned, Serialize};
use smol::lock::Semaphore;
use std::time::{Duration, Instant};
use std::{net::SocketAddr, sync::Arc};
lazy_static! {
/// Process-wide default `Client`, constructed lazily on first access.
/// The free function `request` in this module sends every call through it.
static ref CONN_POOL: Client = Client::default();
}
pub async fn request<TInput: Serialize + Clone, TOutput: DeserializeOwned + std::fmt::Debug>(
addr: SocketAddr,
netname: &str,
verb: &str,
req: TInput,
) -> Result<TOutput> {
CONN_POOL.request(addr, netname, verb, req).await
}
/// Melnet client that keeps a separate `TcpPool` for every remote address
/// it has been asked to contact.
#[derive(Default)]
pub struct Client {
/// Connection pools keyed by peer address; an entry is inserted on demand
/// the first time a request targets that address.
pool: DashMap<SocketAddr, Arc<TcpPool>>,
}
impl Client {
pub async fn request<TInput: Serialize + Clone, TOutput: DeserializeOwned + std::fmt::Debug>(
&self,
addr: SocketAddr,
netname: &str,
verb: &str,
req: TInput,
) -> Result<TOutput> {
for count in 0..5 {
match self.request_inner(addr, netname, verb, req.clone()).await {
Err(MelnetError::Network(err)) => {
log::debug!(
"retrying request {} to {} on transient network error {:?}",
verb,
addr,
err
);
smol::Timer::after(Duration::from_secs_f64(0.1 * 2.0f64.powi(count))).await;
}
x => return x,
}
}
self.request_inner(addr, netname, verb, req).await
}
async fn request_inner<TInput: Serialize, TOutput: DeserializeOwned + std::fmt::Debug>(
&self,
addr: SocketAddr,
netname: &str,
verb: &str,
req: TInput,
) -> Result<TOutput> {
static GLOBAL_LIMIT: Semaphore = Semaphore::new(128);
let _guard = GLOBAL_LIMIT.acquire().await;
let start = Instant::now();
let pool = self
.pool
.entry(addr)
.or_insert_with(|| TcpPool::new(32, Duration::from_secs(5), addr).into())
.clone();
let mut conn = pool.connect().await.map_err(MelnetError::Network)?;
let res = async {
let rr = stdcode::serialize(&RawRequest {
proto_ver: PROTO_VER,
netname: netname.to_owned(),
verb: verb.to_owned(),
payload: stdcode::serialize(&req).unwrap(),
})
.unwrap();
write_len_bts(&mut conn, &rr).await?;
let response: RawResponse = stdcode::deserialize(&read_len_bts(&mut conn).await?)
.map_err(|e| {
MelnetError::Network(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
})?;
let response = match response.kind.as_ref() {
"Ok" => stdcode::deserialize::<TOutput>(&response.body)
.map_err(|_| MelnetError::Custom("stdcode error".to_owned()))?,
"NoVerb" => return Err(MelnetError::VerbNotFound),
_ => {
return Err(MelnetError::Custom(
String::from_utf8_lossy(&response.body).to_string(),
))
}
};
let elapsed = start.elapsed();
if elapsed.as_secs_f64() > 3.0 {
log::warn!(
"melnet req of verb {}/{} to {} took {:?}",
netname,
verb,
addr,
elapsed
)
}
self.pool.get(&addr).unwrap().replenish(conn);
Ok::<_, crate::MelnetError>(response)
};
res.await
}
}