use bytes::Bytes;
use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use octseq::Octets;
use rand::{random, random_range};
use std::boxed::Box;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::vec::Vec;
use tokio::sync::{mpsc, oneshot};
use tokio::time::{sleep_until, Duration, Instant};
use crate::base::iana::OptRcode;
use crate::base::Message;
use crate::net::client::request::{Error, GetResponse, SendRequest};
/// Capacity of the channel between connection handles/queries and the
/// transport task.
const DEF_CHAN_CAP: usize = 8;
/// Initial response-time estimate of a newly added upstream, in milliseconds.
const DEFAULT_RT_MS: u64 = 300;
/// `DEFAULT_RT_MS` as a `Duration`.
const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS);
/// Divisor of the exponentially weighted moving averages kept per upstream:
/// every new sample contributes `1 / SMOOTH_N` of the new value.
const SMOOTH_N: f64 = 8.0;
/// Probability that a query first tries a randomly chosen upstream other
/// than the one with the lowest estimated response time.
const PROBE_P: f64 = 0.05;
/// Settings for a redundant [`Connection`].
///
/// All options default to `false`. In that case the first result that comes
/// back from any upstream, reply or error, is the one returned.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config {
    /// Hold back transport errors while other upstreams may still answer.
    defer_transport_error: bool,
    /// Hold back replies with rcode REFUSED.
    defer_refused: bool,
    /// Hold back replies with rcode SERVFAIL.
    defer_servfail: bool,
}

impl Config {
    /// Returns whether transport errors are deferred.
    pub fn defer_transport_error(&self) -> bool {
        self.defer_transport_error
    }

    /// Returns whether REFUSED replies are deferred.
    pub fn defer_refused(&self) -> bool {
        self.defer_refused
    }

    /// Returns whether SERVFAIL replies are deferred.
    pub fn defer_servfail(&self) -> bool {
        self.defer_servfail
    }

    /// Sets whether transport errors are deferred.
    ///
    /// When enabled, the first transport error is kept and returned only if
    /// no upstream produces a result that is not deferred and no reply was
    /// deferred either.
    pub fn set_defer_transport_error(&mut self, value: bool) {
        self.defer_transport_error = value;
    }

    /// Sets whether REFUSED replies are deferred.
    ///
    /// When enabled, the first REFUSED reply is kept and returned only if no
    /// upstream produces a result that is not deferred.
    pub fn set_defer_refused(&mut self, value: bool) {
        self.defer_refused = value;
    }

    /// Sets whether SERVFAIL replies are deferred.
    ///
    /// When enabled, the first SERVFAIL reply is kept and returned only if no
    /// upstream produces a result that is not deferred.
    pub fn set_defer_servfail(&mut self, value: bool) {
        self.defer_servfail = value;
    }
}
/// Handle for sending requests over a set of redundant upstream transports.
///
/// Clones of a handle share the channel to the same [`Transport`].
#[derive(Debug)]
pub struct Connection<Req: Send + Sync> {
    /// Settings applied to every request sent through this handle.
    config: Config,
    /// Channel to the transport task.
    sender: mpsc::Sender<ChanReq<Req>>,
}
impl<Req: Clone + Debug + Send + Sync + 'static> Connection<Req> {
pub fn new() -> (Self, Transport<Req>) {
Self::with_config(Default::default())
}
pub fn with_config(config: Config) -> (Self, Transport<Req>) {
let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
(Self { config, sender }, Transport::new(receiver))
}
pub async fn add(
&self,
conn: Box<dyn SendRequest<Req> + Send + Sync>,
) -> Result<(), Error> {
let (tx, rx) = oneshot::channel();
self.sender
.send(ChanReq::Add(AddReq { conn, tx }))
.await
.expect("send should not fail");
rx.await.expect("receive should not fail")
}
async fn request_impl(
self,
request_msg: Req,
) -> Result<Message<Bytes>, Error> {
let (tx, rx) = oneshot::channel();
self.sender
.send(ChanReq::GetRT(RTReq { tx }))
.await
.expect("send should not fail");
let conn_rt = rx.await.expect("receive should not fail")?;
Query::new(self.config, request_msg, conn_rt, self.sender.clone())
.get_response()
.await
}
}
impl<Req> Clone for Connection<Req>
where
Req: Send + Sync,
{
fn clone(&self) -> Self {
Self {
config: self.config,
sender: self.sender.clone(),
}
}
}
impl<Req> SendRequest<Req> for Connection<Req>
where
    Req: Clone + Debug + Send + Sync + 'static,
{
    /// Prepares sending `request_msg` over the redundant upstreams.
    ///
    /// Nothing is sent yet. The work happens when the returned object's
    /// `get_response` future is polled.
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        let fut = self.clone().request_impl(request_msg);
        Box::new(Request { fut: Box::pin(fut) })
    }
}
/// Pending response to a request sent over a redundant [`Connection`].
struct Request {
    /// Future that resolves to the selected response.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}

impl Request {
    /// Awaits the stored future in place.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        self.fut.as_mut().await
    }
}

impl GetResponse for Request {
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let fut = self.get_response_impl();
        Box::pin(fut)
    }
}

impl Debug for Request {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The boxed future has no useful representation; print a placeholder.
        let mut out = f.debug_struct("Request");
        out.field("fut", &format_args!("_"));
        out.finish()
    }
}
/// One request being sent to the redundant upstreams, driven by
/// `Query::get_response`.
#[derive(Debug)]
struct Query<Req: Send + Sync> {
    /// Settings that decide which results are deferred.
    config: Config,
    /// Current state of the state machine.
    state: QueryState,
    /// The request; it is cloned for every upstream it is sent to.
    request_msg: Req,
    /// Upstreams in the order they are tried, with estimate and start time.
    conn_rt: Vec<ConnRT>,
    /// Channel to the transport task.
    sender: mpsc::Sender<ChanReq<Req>>,
    /// Started upstream requests that have not produced a result yet.
    fut_list: FuturesUnordered<
        Pin<Box<dyn Future<Output = FutListOutput> + Send + Sync>>,
    >,
    /// First transport error that was deferred, if any.
    deferred_transport_error: Option<Error>,
    /// First reply that was deferred, if any.
    deferred_reply: Option<Message<Bytes>>,
    /// Result chosen to be returned; set when reporting starts.
    result: Option<Result<Message<Bytes>, Error>>,
    /// Index into `conn_rt` of the upstream that produced `result`.
    res_index: usize,
}

/// States of `Query::get_response`.
#[derive(Debug)]
enum QueryState {
    /// Nothing has been sent yet.
    Init,
    /// Start upstream `n` and wait for a result or its estimated response
    /// time.
    Probe(usize),
    /// Report the response time of upstream `n`, then move on to `n + 1`.
    Report(usize),
    /// Every upstream has been started; wait for the outstanding requests.
    Wait,
}
/// Message sent to the transport task.
enum ChanReq<Req: Send + Sync> {
    /// Register an additional upstream.
    Add(AddReq<Req>),
    /// Request a snapshot of the upstreams' response-time estimates.
    GetRT(RTReq),
    /// Send a request over one specific upstream.
    Query(RequestReq<Req>),
    /// Response time of the upstream whose result was used.
    Report(TimeReport),
    /// Response time of a started upstream whose result was not used.
    Failure(TimeReport),
}

impl<Req: Send + Sync> Debug for ChanReq<Req> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Only the type name is printed, not the variant or its payload.
        f.debug_struct("ChanReq").finish()
    }
}

/// Payload of [`ChanReq::Add`].
struct AddReq<Req> {
    /// The upstream to add.
    conn: Box<dyn SendRequest<Req> + Send + Sync>,
    /// Where the outcome is sent.
    tx: oneshot::Sender<AddReply>,
}

/// Reply to [`ChanReq::Add`].
type AddReply = Result<(), Error>;

/// Payload of [`ChanReq::GetRT`].
struct RTReq {
    /// Where the snapshot is sent.
    tx: oneshot::Sender<RTReply>,
}

/// Reply to [`ChanReq::GetRT`].
type RTReply = Result<Vec<ConnRT>, Error>;

/// Payload of [`ChanReq::Query`].
struct RequestReq<Req: Send + Sync> {
    /// Id of the upstream to use.
    id: u64,
    /// The request to send.
    request_msg: Req,
    /// Where the pending response, or a lookup error, is sent.
    tx: oneshot::Sender<RequestReply>,
}

impl<Req: Debug + Send + Sync> Debug for RequestReq<Req> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RequestReq");
        out.field("id", &self.id);
        out.field("request_msg", &self.request_msg);
        out.finish()
    }
}

/// Reply to [`ChanReq::Query`].
type RequestReply = Result<Box<dyn GetResponse + Send + Sync>, Error>;
/// Observed response time of one upstream for one query.
#[derive(Debug)]
struct TimeReport {
    /// Id of the upstream.
    id: u64,
    /// Time since the request to this upstream was started.
    elapsed: Duration,
}

/// Response-time statistics of one upstream, in seconds.
struct ConnStats {
    /// Exponentially weighted moving average of the response time.
    mean: f64,
    /// Exponentially weighted moving average of the squared response time.
    mean_sq: f64,
}

/// Response-time estimate of one upstream, as handed to queries.
#[derive(Debug, Clone)]
struct ConnRT {
    /// Estimated response time; used as the probe timeout.
    est_rt: Duration,
    /// Id of the upstream.
    id: u64,
    /// When the current query started this upstream; `None` if not started.
    start: Option<Instant>,
}
type FutListOutput = (usize, Result<Message<Bytes>, Error>);
impl<Req: Clone + Send + Sync + 'static> Query<Req> {
fn new(
config: Config,
request_msg: Req,
mut conn_rt: Vec<ConnRT>,
sender: mpsc::Sender<ChanReq<Req>>,
) -> Self {
let conn_rt_len = conn_rt.len();
conn_rt.sort_unstable_by(conn_rt_cmp);
if conn_rt_len > 1 && random::<f64>() < PROBE_P {
let index = random_range(1..=conn_rt_len - 1);
let min_rt = conn_rt.iter().map(|e| e.est_rt).min().unwrap();
let mut e = conn_rt.remove(index);
e.est_rt = min_rt;
conn_rt.insert(0, e);
}
Self {
config,
request_msg,
conn_rt,
sender,
state: QueryState::Init,
fut_list: FuturesUnordered::new(),
deferred_transport_error: None,
deferred_reply: None,
result: None,
res_index: 0,
}
}
async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
loop {
match self.state {
QueryState::Init => {
if self.conn_rt.is_empty() {
return Err(Error::NoTransportAvailable);
}
self.state = QueryState::Probe(0);
continue;
}
QueryState::Probe(ind) => {
self.conn_rt[ind].start = Some(Instant::now());
let fut = start_request(
ind,
self.conn_rt[ind].id,
self.sender.clone(),
self.request_msg.clone(),
);
self.fut_list.push(Box::pin(fut));
let timeout = Instant::now() + self.conn_rt[ind].est_rt;
loop {
tokio::select! {
res = self.fut_list.next() => {
let res = res.expect("res should not be empty");
match res.1 {
Err(ref err) => {
if self.config.defer_transport_error {
if self.deferred_transport_error.is_none() {
self.deferred_transport_error = Some(err.clone());
}
if res.0 == ind {
self.state =
if ind+1 < self.conn_rt.len() {
QueryState::Probe(ind+1)
}
else
{
QueryState::Wait
};
break;
}
continue;
}
}
Ok(ref msg) => {
if skip(msg, &self.config) {
if self.deferred_reply.is_none() {
self.deferred_reply = Some(msg.clone());
}
if res.0 == ind {
self.state =
if ind+1 < self.conn_rt.len() {
QueryState::Probe(ind+1)
}
else
{
QueryState::Wait
};
break;
}
continue;
}
}
}
self.result = Some(res.1);
self.res_index= res.0;
self.state = QueryState::Report(0);
break;
}
_ = sleep_until(timeout) => {
self.state =
if ind+1 < self.conn_rt.len() {
QueryState::Probe(ind+1)
}
else {
QueryState::Wait
};
break;
}
}
}
continue;
}
QueryState::Report(ind) => {
if ind >= self.conn_rt.len()
|| self.conn_rt[ind].start.is_none()
{
let res = self
.result
.take()
.expect("result should not be empty");
return res;
}
let start = self.conn_rt[ind]
.start
.expect("start time should not be empty");
let elapsed = start.elapsed();
let time_report = TimeReport {
id: self.conn_rt[ind].id,
elapsed,
};
let report = if ind == self.res_index {
ChanReq::Report(time_report)
} else {
ChanReq::Failure(time_report)
};
let _ = self.sender.send(report).await;
self.state = QueryState::Report(ind + 1);
continue;
}
QueryState::Wait => {
loop {
if self.fut_list.is_empty() {
if self.deferred_reply.is_some() {
let msg = self
.deferred_reply
.take()
.expect("just checked for Some");
return Ok(msg);
}
if self.deferred_transport_error.is_some() {
let err = self
.deferred_transport_error
.take()
.expect("just checked for Some");
return Err(err);
}
panic!("either deferred_reply or deferred_error should be present");
}
let res = self.fut_list.next().await;
let res = res.expect("res should not be empty");
match res.1 {
Err(ref err) => {
if self.config.defer_transport_error {
if self.deferred_transport_error.is_none()
{
self.deferred_transport_error =
Some(err.clone());
}
continue;
}
}
Ok(ref msg) => {
if skip(msg, &self.config) {
if self.deferred_reply.is_none() {
self.deferred_reply =
Some(msg.clone());
}
continue;
}
}
}
self.result = Some(res.1);
self.res_index = res.0;
self.state = QueryState::Report(0);
break;
}
continue;
}
}
}
}
}
/// Background side of a redundant [`Connection`].
///
/// It owns the upstream transports and their response-time statistics.
/// Requests from the connection handles are processed only while
/// [`Transport::run`] is being awaited.
#[derive(Debug)]
pub struct Transport<Req: Send + Sync> {
    /// Receiving end of the channel shared by all handles and queries.
    receiver: mpsc::Receiver<ChanReq<Req>>,
}
impl<Req: Clone + Send + Sync + 'static> Transport<Req> {
fn new(receiver: mpsc::Receiver<ChanReq<Req>>) -> Self {
Self { receiver }
}
pub async fn run(mut self) {
let mut next_id: u64 = 10;
let mut conn_stats: Vec<ConnStats> = Vec::new();
let mut conn_rt: Vec<ConnRT> = Vec::new();
let mut conns: Vec<Box<dyn SendRequest<Req> + Send + Sync>> =
Vec::new();
loop {
let req = match self.receiver.recv().await {
Some(req) => req,
None => break, };
match req {
ChanReq::Add(add_req) => {
let id = next_id;
next_id += 1;
conn_stats.push(ConnStats {
mean: (DEFAULT_RT_MS as f64) / 1000.,
mean_sq: 0.,
});
conn_rt.push(ConnRT {
id,
est_rt: DEFAULT_RT,
start: None,
});
conns.push(add_req.conn);
let _ = add_req.tx.send(Ok(()));
}
ChanReq::GetRT(rt_req) => {
let _ = rt_req.tx.send(Ok(conn_rt.clone()));
}
ChanReq::Query(request_req) => {
let opt_ind =
conn_rt.iter().position(|e| e.id == request_req.id);
match opt_ind {
Some(ind) => {
let query = conns[ind]
.send_request(request_req.request_msg);
let _ = request_req.tx.send(Ok(query));
}
None => {
let _ = request_req
.tx
.send(Err(Error::RedundantTransportNotFound));
}
}
}
ChanReq::Report(time_report) => {
let opt_ind =
conn_rt.iter().position(|e| e.id == time_report.id);
if let Some(ind) = opt_ind {
let elapsed = time_report.elapsed.as_secs_f64();
conn_stats[ind].mean +=
(elapsed - conn_stats[ind].mean) / SMOOTH_N;
let elapsed_sq = elapsed * elapsed;
conn_stats[ind].mean_sq +=
(elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
let mean = conn_stats[ind].mean;
let var = conn_stats[ind].mean_sq - mean * mean;
let std_dev =
if var < 0. { 0. } else { f64::sqrt(var) };
let est_rt = mean + 3. * std_dev;
conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
}
}
ChanReq::Failure(time_report) => {
let opt_ind =
conn_rt.iter().position(|e| e.id == time_report.id);
if let Some(ind) = opt_ind {
let elapsed = time_report.elapsed.as_secs_f64();
if elapsed < conn_stats[ind].mean {
continue;
}
conn_stats[ind].mean +=
(elapsed - conn_stats[ind].mean) / SMOOTH_N;
let elapsed_sq = elapsed * elapsed;
conn_stats[ind].mean_sq +=
(elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
let mean = conn_stats[ind].mean;
let var = conn_stats[ind].mean_sq - mean * mean;
let std_dev =
if var < 0. { 0. } else { f64::sqrt(var) };
let est_rt = mean + 3. * std_dev;
conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
}
}
}
}
}
}
/// Asks the transport task to send `request_msg` over the upstream with
/// `id`, then waits for that upstream's response.
///
/// `index` is returned unchanged alongside the result, so the caller can
/// tell which upstream the result belongs to.
///
/// # Panics
///
/// Panics if the transport task's end of the channel is gone, or if it
/// dropped the reply channel without answering.
async fn start_request<Req>(
    index: usize,
    id: u64,
    sender: mpsc::Sender<ChanReq<Req>>,
    request_msg: Req,
) -> (usize, Result<Message<Bytes>, Error>)
where
    Req: Send + Sync,
{
    let (tx, rx) = oneshot::channel();
    let query = ChanReq::Query(RequestReq {
        id,
        request_msg,
        tx,
    });
    sender.send(query).await.expect("send is expected to work");
    let reply = rx.await.expect("receive is expected to work");
    let result = match reply {
        Ok(mut request) => request.get_response().await,
        Err(err) => Err(err),
    };
    (index, result)
}
/// Orders upstreams by increasing estimated response time.
fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering {
    Ord::cmp(&e1.est_rt, &e2.est_rt)
}
/// Returns whether reply `msg` is to be deferred under `config`.
///
/// That is the case for a REFUSED reply when `defer_refused` is set, and for
/// a SERVFAIL reply when `defer_servfail` is set.
fn skip<Octs: Octets>(msg: &Message<Octs>, config: &Config) -> bool {
    // Skip inspecting the message when no rcode is configured to be deferred.
    if !(config.defer_refused || config.defer_servfail) {
        return false;
    }
    match msg.opt_rcode() {
        OptRcode::REFUSED => config.defer_refused,
        OptRcode::SERVFAIL => config.defer_servfail,
        _ => false,
    }
}