use crate::base::Message;
use crate::net::client::protocol::AsyncConnect;
use crate::net::client::request::{
ComposeRequest, Error, GetResponse, RequestMessageMulti, SendRequest,
};
use crate::net::client::stream;
use crate::utils::config::DefMinMax;
use bytes::Bytes;
use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use rand::random;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use std::vec::Vec;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use tokio::time::{sleep_until, Instant};
/// Capacity of the command channel from [`Connection`] handles to the
/// [`Transport`].
const DEF_CHAN_CAP: usize = 8;

/// Error text returned by [`Connection::shutdown`] when the transport can no
/// longer be reached.
const ERR_CONN_CLOSED: &str = "connection closed";

/// Default (30 s) and presumed minimum/maximum bounds (1 ms / 10 min) for
/// the per-request response timeout.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_millis(30_000),
    Duration::from_micros(1_000),
    Duration::from_secs(10 * 60),
);
/// Configuration for a multi-stream [`Connection`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Overall time budget for a single request, including any reconnects
    /// and back-off delays.
    response_timeout: Duration,

    /// Configuration passed to every underlying stream connection.
    stream: stream::Config,
}
impl Config {
    /// Returns the overall response timeout for a request.
    pub fn response_timeout(&self) -> Duration {
        self.response_timeout
    }

    /// Sets the response timeout, clamped to the limits of
    /// `RESPONSE_TIMEOUT`.
    pub fn set_response_timeout(&mut self, timeout: Duration) {
        let clamped = RESPONSE_TIMEOUT.limit(timeout);
        self.response_timeout = clamped;
    }

    /// Returns the configuration used for the underlying stream connections.
    pub fn stream(&self) -> &stream::Config {
        &self.stream
    }

    /// Returns a mutable reference to the stream configuration.
    pub fn stream_mut(&mut self) -> &mut stream::Config {
        &mut self.stream
    }
}
impl From<stream::Config> for Config {
fn from(stream: stream::Config) -> Self {
Self {
stream,
response_timeout: RESPONSE_TIMEOUT.default(),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
stream: Default::default(),
response_timeout: RESPONSE_TIMEOUT.default(),
}
}
}
/// Client handle for a connection that re-establishes its underlying stream
/// connection as needed.
///
/// Cloning is cheap: all clones talk to the same [`Transport`].
#[derive(Debug)]
pub struct Connection<Req> {
    /// Command channel to the transport.
    sender: mpsc::Sender<ChanReq<Req>>,

    /// Overall time budget for each request issued through this handle.
    response_timeout: Duration,
}
impl<Req> Connection<Req> {
pub fn new<Remote>(remote: Remote) -> (Self, Transport<Remote, Req>) {
Self::with_config(remote, Default::default())
}
pub fn with_config<Remote>(
remote: Remote,
config: Config,
) -> (Self, Transport<Remote, Req>) {
let response_timeout = config.response_timeout;
let (sender, transport) = Transport::new(remote, config);
(
Self {
sender,
response_timeout,
},
transport,
)
}
}
impl<Req: ComposeRequest + Clone + 'static> Connection<Req> {
pub async fn request(
&self,
request: Req,
) -> Result<Message<Bytes>, Error> {
Request::new(self.clone(), request).get_response().await
}
async fn _send_request(
&self,
request: &Req,
) -> Result<Box<dyn GetResponse + Send>, Error>
where
Req: 'static,
{
let gr = Request::new(self.clone(), request.clone());
Ok(Box::new(gr))
}
async fn new_conn(
&self,
opt_id: Option<u64>,
) -> Result<oneshot::Receiver<ChanResp<Req>>, Error> {
let (sender, receiver) = oneshot::channel();
let req = ChanReq {
cmd: ReqCmd::NewConn(opt_id, sender),
};
self.sender
.send(req)
.await
.map_err(|_| Error::ConnectionClosed)?;
Ok(receiver)
}
pub async fn shutdown(&self) -> Result<(), &'static str> {
let req = ChanReq {
cmd: ReqCmd::Shutdown,
};
match self.sender.send(req).await {
Err(_) =>
{
Err(ERR_CONN_CLOSED)
}
Ok(_) => Ok(()),
}
}
}
impl<Req> Clone for Connection<Req> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
response_timeout: self.response_timeout,
}
}
}
impl<Req> SendRequest<Req> for Connection<Req>
where
    Req: ComposeRequest + Clone + 'static,
{
    /// Wraps `request` in a boxed [`GetResponse`]; the work is done when the
    /// response is awaited.
    fn send_request(
        &self,
        request: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        let pending = Request::new(self.clone(), request);
        Box::new(pending)
    }
}
/// A single request in flight on a multi-stream [`Connection`].
#[derive(Debug)]
struct Request<Req> {
    /// The request message; cloned for every (re)submission.
    request_msg: Req,

    /// When the request was created; the response timeout is measured from
    /// here.
    start: Instant,

    /// Current position in the request state machine.
    state: QueryState<Req>,

    /// Handle used to talk to the transport.
    conn: Connection<Req>,

    /// Id of the stream connection last handed to this request, if any.
    conn_id: Option<u64>,

    /// Number of retries so far; drives the back-off delay.
    delayed_retry_count: u64,
}
/// States of the [`Request`] state machine.
#[derive(Debug)]
enum QueryState<Req> {
    /// A connection has to be requested from the transport.
    RequestConn,

    /// Waiting for the transport's answer to a connection request.
    ReceiveConn(oneshot::Receiver<ChanResp<Req>>),

    /// A connection is available; the request still has to be submitted.
    StartQuery(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),

    /// The request has been submitted; waiting for its reply.
    GetResult(stream::Request),

    /// Backing off: sleep until the instant plus the duration.
    Delay(Instant, Duration),

    /// Set after an error that is not retried.
    Done,
}
/// The transport's answer to a `NewConn` command: either a connection or the
/// (shared) connect error.
type ChanResp<Req> = Result<ChanRespOk<Req>, Arc<std::io::Error>>;

/// Successful answer to a `NewConn` command.
#[derive(Debug)]
struct ChanRespOk<Req> {
    /// Id the transport assigned to the connection being handed out.
    id: u64,

    /// Shared handle to the stream connection.
    conn: Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>,
}
impl<Req> Request<Req> {
    /// Creates a request that still needs a connection.
    ///
    /// The response-timeout clock starts now.
    fn new(conn: Connection<Req>, request_msg: Req) -> Self {
        let start = Instant::now();
        Self {
            request_msg,
            start,
            state: QueryState::RequestConn,
            conn,
            conn_id: None,
            delayed_retry_count: 0,
        }
    }
}
impl<Req: ComposeRequest + Clone + 'static> Request<Req> {
    /// Drives the request state machine until a reply arrives, an error
    /// that is not retried occurs, or the response timeout (measured from
    /// `self.start`) runs out.
    ///
    /// - `RequestConn`: ask the transport for a connection, passing the id
    ///   of the connection this request used last (if any).
    /// - `ReceiveConn`: wait for the transport's answer. A connect error
    ///   leads to `Delay`; success records the connection id.
    /// - `StartQuery`: submit the request on the obtained connection.
    /// - `GetResult`: wait for the reply. `ConnectionClosed` is retried
    ///   (immediately the first time, with back-off afterwards). Any other
    ///   error except `WrongReplyForQuery` is retried after a back-off.
    /// - `Delay`: sleep out the back-off, then request a connection again.
    ///
    /// # Errors
    ///
    /// - `Error::StreamReadTimeout` once the overall timeout is exceeded.
    /// - `Error::ConnectionClosed` if the transport's command channel is
    ///   closed.
    /// - `Error::StreamReceiveError` if the transport dropped the reply
    ///   channel without answering.
    /// - `Error::WrongReplyForQuery` when the stream request reports it.
    ///
    /// # Panics
    ///
    /// Panics if called again after the state has been set to `Done`.
    pub async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
        loop {
            // The timeout covers the whole request, including reconnects
            // and back-off delays.
            let elapsed = self.start.elapsed();
            if elapsed >= self.conn.response_timeout {
                return Err(Error::StreamReadTimeout);
            }
            let remaining = self.conn.response_timeout - elapsed;
            match self.state {
                QueryState::RequestConn => {
                    // Passing the last-used id lets the transport discard
                    // that connection if it is still the current one.
                    let to =
                        timeout(remaining, self.conn.new_conn(self.conn_id))
                            .await
                            .map_err(|_| Error::StreamReadTimeout)?;
                    let rx = match to {
                        Ok(rx) => rx,
                        Err(err) => {
                            // The transport's command channel is closed;
                            // this is not retried.
                            self.state = QueryState::Done;
                            return Err(err);
                        }
                    };
                    self.state = QueryState::ReceiveConn(rx);
                }
                QueryState::ReceiveConn(ref mut receiver) => {
                    let to = timeout(remaining, receiver)
                        .await
                        .map_err(|_| Error::StreamReadTimeout)?;
                    let res = match to {
                        Ok(res) => res,
                        Err(_) => {
                            // The reply sender was dropped without an answer.
                            self.state = QueryState::Done;
                            return Err(Error::StreamReceiveError);
                        }
                    };
                    match res {
                        Err(_) => {
                            // Connecting failed (or the transport is still
                            // backing off from an earlier failure): wait
                            // and try again.
                            self.delayed_retry_count += 1;
                            let retry_time =
                                retry_time(self.delayed_retry_count);
                            self.state =
                                QueryState::Delay(Instant::now(), retry_time);
                            continue;
                        }
                        Ok(ok_res) => {
                            let id = ok_res.id;
                            let conn = ok_res.conn;
                            self.conn_id = Some(id);
                            self.state = QueryState::StartQuery(conn);
                            continue;
                        }
                    }
                }
                QueryState::StartQuery(ref mut conn) => {
                    self.state = QueryState::GetResult(
                        conn.get_request(self.request_msg.clone()),
                    );
                    continue;
                }
                QueryState::GetResult(ref mut query) => {
                    let to = timeout(remaining, query.get_response())
                        .await
                        .map_err(|_| Error::StreamReadTimeout)?;
                    match to {
                        Ok(reply) => {
                            // NOTE(review): the state stays `GetResult`
                            // rather than `Done`, so a second call would call
                            // `get_response` on the stream request again —
                            // confirm this is intended.
                            return Ok(reply);
                        }
                        Err(Error::WrongReplyForQuery) => {
                            return Err(Error::WrongReplyForQuery)
                        }
                        Err(Error::ConnectionClosed) => {
                            // Reconnect right away the first time; back off
                            // on subsequent failures.
                            self.delayed_retry_count += 1;
                            if self.delayed_retry_count == 1 {
                                self.state = QueryState::RequestConn;
                            } else {
                                let retry_time =
                                    retry_time(self.delayed_retry_count);
                                self.state = QueryState::Delay(
                                    Instant::now(),
                                    retry_time,
                                );
                            }
                        }
                        Err(_) => {
                            // Any other error: back off, then reconnect.
                            self.delayed_retry_count += 1;
                            let retry_time =
                                retry_time(self.delayed_retry_count);
                            self.state =
                                QueryState::Delay(Instant::now(), retry_time);
                        }
                    }
                }
                QueryState::Delay(instant, duration) => {
                    // Sleep out the back-off, but never past the overall
                    // timeout.
                    if timeout(remaining, sleep_until(instant + duration))
                        .await
                        .is_err()
                    {
                        return Err(Error::StreamReadTimeout);
                    };
                    self.state = QueryState::RequestConn;
                }
                QueryState::Done => {
                    panic!("Already done");
                }
            }
        }
    }
}
impl<Req: ComposeRequest + Clone + 'static> GetResponse for Request<Req> {
    /// Boxes the inherent `Request::get_response` future.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let fut = Self::get_response(self);
        Box::pin(fut)
    }
}
/// Background half of a multi-stream connection.
///
/// Created together with a [`Connection`]. Requests only make progress while
/// [`Transport::run`] is being driven (presumably spawned by the caller).
#[derive(Debug)]
pub struct Transport<Remote, Req> {
    /// Configuration, including the per-stream configuration.
    config: Config,

    /// Used to open new streams via `connect()`.
    stream: Remote,

    /// The current connection, nothing, or the last connect error.
    conn_state: SingleConnState3<Req>,

    /// Id of the current connection; bumped when it is replaced.
    conn_id: u64,

    /// Commands from the `Connection` handles.
    receiver: mpsc::Receiver<ChanReq<Req>>,
}
/// A message sent from a [`Connection`] handle to the [`Transport`].
#[derive(Debug)]
struct ChanReq<Req> {
    /// The command to execute.
    cmd: ReqCmd<Req>,
}
/// Commands understood by [`Transport::run`].
#[derive(Debug)]
enum ReqCmd<Req> {
    /// Ask for a stream connection. The id is the connection the requester
    /// used last (if any); the answer goes to the sender.
    NewConn(Option<u64>, ReplySender<Req>),

    /// Stop the transport.
    Shutdown,
}

/// One-shot channel on which the transport answers a `NewConn` command.
type ReplySender<Req> = oneshot::Sender<ChanResp<Req>>;
/// Connection state tracked by the [`Transport`].
#[derive(Debug)]
enum SingleConnState3<Req> {
    /// No connection: none was opened yet, or it was discarded because a
    /// request that used it asked for a new one.
    None,

    /// The current stream connection.
    Some(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),

    /// The last connect attempt failed.
    Err(ErrorState),
}
/// Back-off bookkeeping after a failed connect attempt.
#[derive(Debug, Clone)]
struct ErrorState {
    /// Error handed to new connection requests while backing off.
    error: Arc<std::io::Error>,

    /// Number of consecutive failures beyond the first (0 on the first).
    retries: u64,

    /// When this state was entered.
    timer: Instant,

    /// Back-off period. While `timer.elapsed() < timeout`, new connection
    /// requests get `error` without a new connect attempt.
    timeout: Duration,
}
impl<Remote, Req> Transport<Remote, Req> {
fn new(
stream: Remote,
config: Config,
) -> (mpsc::Sender<ChanReq<Req>>, Self) {
let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
(
sender,
Self {
config,
stream,
conn_state: SingleConnState3::None,
conn_id: 0,
receiver,
},
)
}
}
impl<Remote, Req: ComposeRequest> Transport<Remote, Req>
where
Remote: AsyncConnect,
Remote::Connection: AsyncRead + AsyncWrite,
Req: ComposeRequest,
{
pub async fn run(mut self) {
let mut curr_cmd: Option<ReqCmd<Req>> = None;
let mut do_stream = false;
let mut runners = FuturesUnordered::new();
let mut stream_fut: Pin<
Box<
dyn Future<
Output = Result<Remote::Connection, std::io::Error>,
> + Send,
>,
> = Box::pin(stream_nop());
let mut opt_chan = None;
loop {
if let Some(req) = curr_cmd {
assert!(!do_stream);
curr_cmd = None;
match req {
ReqCmd::NewConn(opt_id, chan) => {
if let SingleConnState3::Err(error_state) =
&self.conn_state
{
if error_state.timer.elapsed()
< error_state.timeout
{
let resp =
ChanResp::Err(error_state.error.clone());
_ = chan.send(resp);
continue;
}
}
if let Some(id) = opt_id {
if id >= self.conn_id {
self.conn_id += 1;
self.conn_state = SingleConnState3::None;
}
}
if let SingleConnState3::Some(conn) = &self.conn_state
{
let resp = ChanResp::Ok(ChanRespOk {
id: self.conn_id,
conn: conn.clone(),
});
_ = chan.send(resp);
} else {
opt_chan = Some(chan);
stream_fut = Box::pin(self.stream.connect());
do_stream = true;
}
}
ReqCmd::Shutdown => break,
}
}
if do_stream {
let runners_empty = runners.is_empty();
loop {
tokio::select! {
res_conn = stream_fut.as_mut() => {
do_stream = false;
stream_fut = Box::pin(stream_nop());
let stream = match res_conn {
Ok(stream) => stream,
Err(error) => {
let error = Arc::new(error);
match self.conn_state {
SingleConnState3::None =>
self.conn_state =
SingleConnState3::Err(ErrorState {
error: error.clone(),
retries: 0,
timer: Instant::now(),
timeout: retry_time(0),
}),
SingleConnState3::Some(_) =>
panic!("Illegal Some state"),
SingleConnState3::Err(error_state) => {
self.conn_state =
SingleConnState3::Err(ErrorState {
error:
error_state.error.clone(),
retries: error_state.retries+1,
timer: Instant::now(),
timeout: retry_time(
error_state.retries+1),
});
}
}
let resp = ChanResp::Err(error);
let loc_opt_chan = opt_chan.take();
_ = loc_opt_chan.expect("weird, no channel?")
.send(resp);
break;
}
};
let (conn, tran) = stream::Connection::with_config(
stream, self.config.stream.clone()
);
let conn = Arc::new(conn);
runners.push(Box::pin(tran.run()));
let resp = ChanResp::Ok(ChanRespOk {
id: self.conn_id,
conn: conn.clone(),
});
self.conn_state = SingleConnState3::Some(conn);
let loc_opt_chan = opt_chan.take();
_ = loc_opt_chan.expect("weird, no channel?")
.send(resp);
break;
}
_ = runners.next(), if !runners_empty => {
}
}
}
continue;
}
assert!(curr_cmd.is_none());
let recv_fut = self.receiver.recv();
let runners_empty = runners.is_empty();
tokio::select! {
msg = recv_fut => {
if msg.is_none() {
break;
}
curr_cmd = Some(msg.expect("None is checked before").cmd);
}
_ = runners.next(), if !runners_empty => {
}
}
}
drop(self.receiver);
while !runners.is_empty() {
runners.next().await;
}
}
}
/// Returns a randomized back-off delay for the given retry count.
///
/// The back-off window is `2^retries` seconds for `retries < 6` (1 s, 2 s,
/// ..., 32 s) and 60 s from then on, so it never shrinks as the retry count
/// grows. The returned delay is the window scaled by a random factor from
/// `rand::random::<f64>()`.
fn retry_time(retries: u64) -> Duration {
    // Upper bound on the back-off window, in seconds.
    const MAX_WINDOW_SECS: u64 = 60;
    // `1 << 6` (64 s) would already exceed the cap, so switch to the cap at
    // 6 to keep the window monotonically non-decreasing.
    let window_secs = if retries < 6 {
        1u64 << retries
    } else {
        MAX_WINDOW_SECS
    };
    let window_usecs = window_secs * 1_000_000;
    let rnd: f64 = random();
    Duration::from_micros((window_usecs as f64 * rnd) as u64)
}
/// Placeholder "connect" future that fails immediately with a `nop` error.
///
/// Keeps the transport's connect future initialized while no real connect
/// is in progress.
async fn stream_nop<IO>() -> Result<IO, std::io::Error> {
    let placeholder = io::Error::other("nop");
    Err(placeholder)
}