use bytes::Bytes;
use futures::{future, future::poll_fn, FutureExt, Sink, SinkExt, Stream, StreamExt};
use std::{
collections::VecDeque,
fmt,
io::{self, Error, ErrorKind},
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tokio::{
select,
sync::{mpsc, watch},
};
use crate::{
cfg::{Cfg, ExchangedCfg},
control::{Direction, DisconnectReason, Link, LinkIntervalStats, LinkStats, NotWorkingReason},
exec::time::{sleep_until, Instant},
id::{ConnId, LinkId},
msg::LinkMsg,
seq::Seq,
};
/// An event produced by [`LinkInt::event`].
#[derive(Debug)]
pub(crate) enum LinkIntEvent {
    /// The sink is ready for the next message; `tx_polling` has been cleared and no payload is pending.
    TxReady,
    /// A requested flush of the sink has completed.
    TxFlushed,
    /// Sending failed; the link's sender is marked as failed.
    TxError(io::Error),
    /// A message was received.
    Rx {
        /// The decoded link message.
        msg: LinkMsg,
        /// The payload buffer that followed a `Data` message, `None` for all other messages.
        data: Option<Bytes>,
    },
    /// Receiving or decoding failed, or the receive stream ended (reported as `BrokenPipe`).
    RxError(io::Error),
    /// The configured flush delay elapsed since the sender became idle with unflushed data.
    FlushDelayPassed,
    /// A disconnect was requested through the link's disconnect channel.
    Disconnect,
    /// A blocked-state change notification was received.
    BlockedChanged,
}
/// State of the link test.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum LinkTest {
    /// No test is running.
    Inactive,
    /// A test is currently running.
    InProgress,
    /// The test failed; the instant is presumably when the failure was detected (confirm at the setter).
    Failed(Instant),
}
/// Which side of the link started the disconnect.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DisconnectInitiator {
    /// This endpoint initiated the disconnect.
    Local,
    /// The remote endpoint initiated the disconnect.
    Remote,
}
/// Internal state of a single link: a sink `TX` and stream `RX` carrying link messages.
///
/// Public [`Link`] handles are created from it via `From<&LinkInt>` and share some state
/// (tag, blocked flags, channels) with it.
pub(crate) struct LinkInt<TX, RX, TAG> {
    /// User-supplied tag, shared with `Link` handles.
    tag: Arc<TAG>,
    /// Connection this link belongs to.
    conn_id: ConnId,
    /// Identifier of this link, generated on creation.
    link_id: LinkId,
    /// Whether the link was established as incoming or outgoing.
    direction: Direction,
    /// Local configuration.
    cfg: Arc<Cfg>,
    /// Configuration exchanged with the remote endpoint.
    remote_cfg: Arc<ExchangedCfg>,
    /// Initialised to `true` for incoming links; presumably an `Accepted` message is still due (confirm at use site).
    pub(crate) needs_tx_accepted: bool,
    /// Sink for outgoing encoded messages.
    tx: TX,
    /// Payload queued to be sent after the message that was last started.
    tx_data: Option<Bytes>,
    /// Transmit error to be reported by the next call to `event`.
    tx_error: Option<io::Error>,
    /// Set once sending failed; all further sending is refused.
    tx_failed: bool,
    /// Since when `event` should poll the sink; `None` when the sender is not in use.
    tx_polling: Option<Instant>,
    /// Whether the last readiness poll of the sink returned `Pending`.
    pub(crate) tx_pending: bool,
    /// When the last message was handed to the sink by `start_send_msg`.
    pub(crate) tx_last_msg: Option<Instant>,
    /// Sequence number of sent data awaiting acknowledgement; cleared by a covering `Ack`.
    txed_unacked: Option<Seq>,
    /// Since when the sender is idle; cleared when sending starts.
    tx_idle_since: Option<Instant>,
    /// A flush has been requested and not yet completed.
    tx_flushing: bool,
    /// The sink has been flushed since the last message was started.
    tx_flushed: bool,
    /// Amount of sent data not yet acknowledged; compared with the limit by `is_sendable`.
    pub(crate) txed_unacked_data: usize,
    /// Current limit for unacknowledged data; starts at `cfg.link_unacked_init`.
    pub(crate) txed_unacked_data_limit: usize,
    /// Presumably the sequence number at which the limit was last increased (maintained elsewhere).
    pub(crate) txed_unacked_data_limit_increased: Option<Seq>,
    /// Counter of consecutive limit increases; reset to zero by `reset`.
    pub(crate) txed_unacked_data_limit_increased_consecutively: usize,
    /// Queue of sequence numbers, presumably still to be acknowledged (maintained elsewhere).
    pub(crate) tx_ack_queue: VecDeque<Seq>,
    /// Number of `Ack`/`Consumed` messages started since the last flush request.
    txed_acks_unflushed: usize,
    /// Stream of incoming buffers.
    rx: RX,
    /// A received `Data` message waiting for its payload buffer.
    rxed_data_msg: Option<LinkMsg>,
    /// Publishes the final disconnect reason to `Link` handles.
    disconnected_tx: watch::Sender<DisconnectReason>,
    /// Cloned into `Link` handles to request a disconnect.
    disconnect_tx: mpsc::Sender<()>,
    /// Receives disconnect requests; yields `LinkIntEvent::Disconnect`.
    disconnect_rx: mpsc::Receiver<()>,
    /// Locally set blocked flag, shared with `Link` handles.
    pub(crate) blocked: Arc<AtomicBool>,
    /// Presumably whether the current blocked state was sent to the remote (confirm at use site).
    pub(crate) blocked_sent: bool,
    /// Cloned into `Link` handles to notify about blocked-state changes.
    pub(crate) blocked_changed_tx: mpsc::Sender<()>,
    /// Receives blocked-state change notifications; yields `LinkIntEvent::BlockedChanged`.
    blocked_changed_rx: mpsc::Receiver<()>,
    /// Notifies `Link` handles about blocked-state changes.
    pub(crate) blocked_changed_out_tx: watch::Sender<()>,
    /// Cloned into `Link` handles as their blocked-change receiver.
    blocked_changed_out_rx: watch::Receiver<()>,
    /// Blocked flag set by the remote endpoint, shared with `Link` handles.
    pub(crate) remotely_blocked: Arc<AtomicBool>,
    /// `Some` while the link is considered not working; published to `Link` handles by `event`.
    pub(crate) unconfirmed: Option<(Instant, NotWorkingReason)>,
    /// Publishes changes of `unconfirmed`.
    unconfirmed_tx: watch::Sender<Option<(Instant, NotWorkingReason)>>,
    /// Cloned into `Link` handles as their not-working receiver.
    unconfirmed_rx: watch::Receiver<Option<(Instant, NotWorkingReason)>>,
    /// State of the link test.
    pub(crate) test: LinkTest,
    /// Current round-trip time estimate.
    pub(crate) roundtrip: Duration,
    /// Ping bookkeeping (presumably last ping time, send time of the outstanding ping, and
    /// pending ping/pong requests — maintained elsewhere).
    pub(crate) last_ping: Option<Instant>,
    pub(crate) current_ping_sent: Option<Instant>,
    pub(crate) send_ping: bool,
    pub(crate) send_pong: bool,
    /// Set once a disconnect is in progress, recording who started it.
    pub(crate) disconnecting: Option<DisconnectInitiator>,
    /// Whether a `Goodbye` message was sent.
    pub(crate) goodbye_sent: bool,
    /// User data received from the remote endpoint.
    remote_user_data: Arc<Vec<u8>>,
    /// Traffic statistics of this link.
    stats: LinkStatistican,
}
impl<TX, RX, TAG> fmt::Debug for LinkInt<TX, RX, TAG> {
    /// Prints only the identifying fields; all other state is elided with `..`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("LinkInt");
        out.field("conn_id", &self.conn_id);
        out.field("link_id", &self.link_id);
        out.field("direction", &self.direction);
        out.finish_non_exhaustive()
    }
}
impl<TX, RX, TAG> LinkInt<TX, RX, TAG> {
    /// Borrows the user-supplied tag of this link.
    pub(crate) fn tag(&self) -> &TAG {
        &*self.tag
    }

    /// Borrows the user data that the remote endpoint sent.
    pub(crate) fn remote_user_data(&self) -> &[u8] {
        self.remote_user_data.as_slice()
    }

    /// Returns a shared handle to the configuration exchanged with the remote endpoint.
    pub(crate) fn remote_cfg(&self) -> Arc<ExchangedCfg> {
        Arc::clone(&self.remote_cfg)
    }
}
impl<TX, RX, TAG> LinkInt<TX, RX, TAG>
where
RX: Stream<Item = Result<Bytes, io::Error>> + Unpin,
TX: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Creates the internal state of a newly established link.
///
/// A fresh [`LinkId`] is generated. Incoming links start with `needs_tx_accepted` set.
/// The sender starts idle and flushed, with no message in flight.
/// The unacknowledged-data limit starts at `cfg.link_unacked_init`. Its growth tracking starts
/// from scratch: no recorded increase and a consecutive-increase counter of zero, which is the
/// same state `reset` restores.
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
    tag: TAG, conn_id: ConnId, tx: TX, rx: RX, cfg: Arc<Cfg>, remote_cfg: ExchangedCfg, direction: Direction,
    roundtrip: Duration, remote_user_data: Vec<u8>,
) -> Self {
    let (disconnected_tx, _) = watch::channel(DisconnectReason::TaskTerminated);
    let (disconnect_tx, disconnect_rx) = mpsc::channel(1);
    let (blocked_changed_tx, blocked_changed_rx) = mpsc::channel(2);
    let stats = LinkStatistican::new(&cfg.stats_intervals, roundtrip);
    let (unconfirmed_tx, unconfirmed_rx) = watch::channel(None);
    let (blocked_changed_out_tx, blocked_changed_out_rx) = watch::channel(());
    Self {
        tag: Arc::new(tag),
        conn_id,
        link_id: LinkId::generate(),
        direction,
        tx,
        tx_data: None,
        tx_error: None,
        tx_failed: false,
        rx,
        remote_cfg: Arc::new(remote_cfg),
        needs_tx_accepted: direction == Direction::Incoming,
        disconnected_tx,
        disconnect_tx,
        disconnect_rx,
        stats,
        goodbye_sent: false,
        tx_polling: None,
        blocked: Arc::new(AtomicBool::new(false)),
        blocked_sent: false,
        blocked_changed_tx,
        blocked_changed_rx,
        blocked_changed_out_tx,
        blocked_changed_out_rx,
        remotely_blocked: Arc::new(AtomicBool::new(false)),
        unconfirmed: None,
        unconfirmed_tx,
        unconfirmed_rx,
        test: LinkTest::Inactive,
        tx_flushing: false,
        tx_flushed: true,
        rxed_data_msg: None,
        tx_last_msg: None,
        txed_unacked: None,
        last_ping: None,
        current_ping_sent: None,
        send_ping: false,
        send_pong: false,
        roundtrip,
        disconnecting: None,
        txed_unacked_data: 0,
        txed_unacked_data_limit: cfg.link_unacked_init.get(),
        txed_unacked_data_limit_increased: None,
        // No increase has happened yet on a fresh link; must match the state set by `reset`.
        txed_unacked_data_limit_increased_consecutively: 0,
        txed_acks_unflushed: 0,
        tx_ack_queue: VecDeque::new(),
        tx_idle_since: None,
        tx_pending: false,
        // `cfg` is moved last because earlier initialisers still read from it.
        cfg,
        remote_user_data: Arc::new(remote_user_data),
    }
}
/// Returns the identifier of this link.
pub(crate) fn link_id(&self) -> LinkId {
    self.link_id
}

/// Fails with `ConnectionAborted` once sending over this link has failed.
fn check_tx_failed(&self) -> Result<(), io::Error> {
    if self.tx_failed {
        return Err(Error::new(ErrorKind::ConnectionAborted, "link has failed"));
    }
    Ok(())
}
/// Waits for and returns the next event of this link.
///
/// Transmit failures are reported first. A stored `tx_error` is returned once. Afterwards,
/// every call returns `TxError` immediately, without waiting, while `tx_failed` is set.
/// Before waiting, the current `unconfirmed` state is published to watchers if it changed.
///
/// The following sources are then raced with `select!`, and the first to complete wins:
/// - the sender, only while `tx_polling` is set: flushes on request, sends a queued
///   `tx_data` payload, or reports readiness (`TxReady`, clearing `tx_polling`);
/// - the receiver: decodes incoming messages. A `Data` message is held in `rxed_data_msg`
///   and returned together with the next buffer as its payload;
/// - the flush timer: fires `link_flush_delay` after `tx_idle_since`, but only if the sink
///   is neither flushed nor currently flushing;
/// - disconnect requests and blocked-state change notifications.
pub(crate) async fn event(&mut self) -> LinkIntEvent {
    let link_id = self.link_id();
    // Transmit failures take priority over all other events.
    if let Some(err) = self.tx_error.take() {
        return LinkIntEvent::TxError(err);
    }
    if let Err(err) = self.check_tx_failed() {
        return LinkIntEvent::TxError(err);
    }
    // Publish the "not working" state, notifying watchers only when it actually changed.
    self.unconfirmed_tx.send_if_modified(|m| {
        if *m != self.unconfirmed {
            m.clone_from(&self.unconfirmed);
            true
        } else {
            false
        }
    });
    // Captured once per call: true when the sink is neither flushed nor currently flushing.
    let flushable = !(self.tx_flushing || self.tx_flushed);
    let tx_task = async {
        loop {
            if self.tx_polling.is_none() {
                // Sender not in use: this branch never completes.
                assert!(self.tx_data.is_none());
                future::pending().await
            } else if self.tx_flushing && self.tx_data.is_none() {
                // A flush was requested and no payload is queued ahead of it.
                match self.tx.flush().await {
                    Ok(()) => {
                        self.tx_flushing = false;
                        self.tx_flushed = true;
                        break LinkIntEvent::TxFlushed;
                    }
                    Err(err) => {
                        self.tx_failed = true;
                        break LinkIntEvent::TxError(err);
                    }
                }
            } else {
                // Wait for the sink to accept more data, recording whether it had to wait.
                let tx_ready = |cx: &mut Context| {
                    let res = self.tx.poll_ready_unpin(cx);
                    match &res {
                        Poll::Pending => self.tx_pending = true,
                        Poll::Ready(_) => self.tx_pending = false,
                    }
                    res
                };
                match poll_fn(tx_ready).await {
                    Ok(()) => match self.tx_data.take() {
                        // Send the queued payload, then loop to wait for readiness (or flush) again.
                        Some(data) => {
                            self.tx_flushed = false;
                            if let Err(err) = self.tx.start_send_unpin(data) {
                                self.tx_failed = true;
                                break LinkIntEvent::TxError(err);
                            }
                        }
                        // Nothing queued: the sender is ready for the next message.
                        None => {
                            self.tx_polling = None;
                            break LinkIntEvent::TxReady;
                        }
                    },
                    Err(err) => {
                        tracing::debug!(?link_id, %err, "link poll ready failure");
                        self.tx_failed = true;
                        break LinkIntEvent::TxError(err);
                    }
                }
            }
        }
    };
    let rx_task = async {
        loop {
            match self.rx.next().await {
                Some(Ok(buf)) => {
                    // Every received buffer (message or payload) counts as received traffic.
                    self.stats.record(0, buf.len());
                    match self.rxed_data_msg.take() {
                        // A `Data` message was received before: this buffer is its payload.
                        Some(msg) => {
                            break LinkIntEvent::Rx { msg, data: Some(buf) };
                        }
                        None => {
                            let cursor = io::Cursor::new(buf);
                            match LinkMsg::read(cursor) {
                                Ok(msg) => {
                                    // An acknowledgement covering the outstanding sequence number
                                    // clears it.
                                    match (&msg, self.txed_unacked) {
                                        (LinkMsg::Ack { received }, Some(sent)) if *received >= sent => {
                                            self.txed_unacked = None
                                        }
                                        _ => (),
                                    }
                                    // Hold `Data` messages back until their payload arrives. They
                                    // are kept in `self`, so the pairing survives this future being
                                    // dropped when another `select!` branch wins.
                                    if let LinkMsg::Data { .. } = &msg {
                                        self.rxed_data_msg = Some(msg);
                                    } else {
                                        break LinkIntEvent::Rx { msg, data: None };
                                    }
                                }
                                Err(err) => break LinkIntEvent::RxError(err),
                            }
                        }
                    }
                }
                Some(Err(err)) => {
                    tracing::debug!(?link_id, %err, "link receive failure");
                    break LinkIntEvent::RxError(err);
                }
                None => {
                    // End of stream is reported as a broken pipe.
                    tracing::debug!(?link_id, "link receive end");
                    break LinkIntEvent::RxError(io::ErrorKind::BrokenPipe.into());
                }
            }
        }
    };
    // Requests a flush once the sender has been idle for `link_flush_delay` with unflushed data.
    let flush_req_task = async {
        match self.tx_idle_since {
            Some(idle_since) if flushable => sleep_until(idle_since + self.cfg.link_flush_delay).await,
            _ => future::pending().await,
        }
    };
    // A `None` from either mpsc channel does not match `Some(())`, which disables that branch
    // for this call.
    select! {
        tx_event = tx_task => tx_event,
        rx_event = rx_task => rx_event,
        () = flush_req_task => LinkIntEvent::FlushDelayPassed,
        Some(()) = self.disconnect_rx.recv() => LinkIntEvent::Disconnect,
        Some(()) = self.blocked_changed_rx.recv() => LinkIntEvent::BlockedChanged,
    }
}
/// Sends `msg` and flushes the sink, waiting until both have completed.
///
/// Fails immediately if the sender has already failed. A send or flush error marks the
/// sender as failed and is returned.
pub(crate) async fn send_msg_and_flush(&mut self, msg: LinkMsg) -> Result<(), io::Error> {
    self.check_tx_failed()?;
    self.tx_polling = Some(Instant::now());
    let encoded = msg.encode();
    if let Err(err) = self.tx.send(encoded).await {
        self.tx_failed = true;
        return Err(err);
    }
    self.tx_flushed = true;
    Ok(())
}
/// Starts sending `msg`, optionally followed by `data` as its payload.
///
/// The encoded message is handed to the sink immediately. `data` is queued in `tx_data` and
/// sent by [`Self::event`] once the sink is ready again.
///
/// # Panics
///
/// Panics if the sender is still being polled (`tx_polling` is set) or a payload is still queued.
///
/// Errors are not returned. They are stored in `tx_error` and reported by the next call to
/// [`Self::event`].
pub(crate) fn start_send_msg(&mut self, msg: LinkMsg, data: Option<Bytes>) {
    assert!(self.tx_polling.is_none());
    assert!(self.tx_data.is_none());
    // A failed sender refuses new messages. An error still waiting to be reported is kept.
    if let Err(err) = self.check_tx_failed() {
        if self.tx_error.is_none() {
            self.tx_error = Some(err);
        }
        return;
    }
    // From now on `event` polls the sink until it reports readiness again.
    self.tx_polling = Some(Instant::now());
    self.tx_flushed = false;
    self.tx_idle_since = None;
    let encoded = msg.encode();
    let msg_len = encoded.len();
    let data_len = data.as_ref().map(|data| data.len()).unwrap_or_default();
    if let Err(err) = self.tx.start_send_unpin(encoded) {
        tracing::debug!(link_id =? self.link_id, %err, "link send failure");
        self.tx_error = Some(err);
        self.tx_failed = true;
        return;
    }
    // Message and payload are accounted together here, although the payload is only handed
    // to the sink later by `event`.
    self.stats.record(msg_len + data_len, 0);
    self.tx_data = data;
    self.tx_last_msg = Some(Instant::now());
    match &msg {
        // Acknowledgements are batched and flushed on demand (see `need_ack_flush`).
        LinkMsg::Ack { .. } | LinkMsg::Consumed { .. } => self.txed_acks_unflushed += 1,
        // Track the outstanding sequence number: the stored one is kept if it is greater.
        LinkMsg::Data { seq } => match self.txed_unacked {
            Some(txed_unacked) if txed_unacked > *seq => (),
            _ => self.txed_unacked = Some(*seq),
        },
        // Control messages are flushed right away.
        LinkMsg::Accepted
        | LinkMsg::Ping
        | LinkMsg::Pong
        | LinkMsg::SendFinish { .. }
        | LinkMsg::ReceiveClose { .. }
        | LinkMsg::ReceiveFinish { .. }
        | LinkMsg::Goodbye => self.start_flush(),
        _ => (),
    }
}
/// Requests a flush of the sink. [`Self::event`] performs it once no payload is queued.
///
/// Also resets the count of acknowledgements waiting for a flush.
pub(crate) fn start_flush(&mut self) {
    self.tx_polling = Some(Instant::now());
    self.tx_flushing = true;
    self.txed_acks_unflushed = 0;
}

/// Whether acknowledgements were sent since the last flush request.
pub(crate) fn need_ack_flush(&self) -> bool {
    self.txed_acks_unflushed > 0
}

/// Whether the sink holds unflushed data and no flush is in progress.
pub(crate) fn needs_flush(&self) -> bool {
    !(self.tx_flushed || self.tx_flushing)
}

/// Whether sent data is still awaiting acknowledgement by the remote endpoint.
pub(crate) fn has_outstanding_ack(&self) -> bool {
    matches!(self.txed_unacked, Some(_))
}

/// Makes [`Self::event`] poll the sink and report `TxReady` once it is ready.
pub(crate) fn report_ready(&mut self) {
    self.tx_polling = Some(Instant::now());
}
/// Queues `TestData` messages of at most `packet_size` each, up to `data_limit` in total,
/// for as long as the sink is immediately ready.
///
/// Returns the total size queued. Errors are stored in `tx_error` and mark the sender failed.
///
/// # Panics
///
/// Panics if a payload is still queued.
pub(crate) fn send_test_data(&mut self, packet_size: usize, data_limit: usize) -> usize {
    assert!(self.tx_data.is_none());
    self.tx_polling = Some(Instant::now());
    self.tx_flushed = false;
    self.tx_idle_since = None;

    if let Err(err) = self.check_tx_failed() {
        // Keep an error that is still waiting to be reported.
        self.tx_error = self.tx_error.take().or(Some(err));
        return 0;
    }

    let mut queued = 0;
    while queued < data_limit {
        let Some(ready) = poll_fn(|cx| self.tx.poll_ready_unpin(cx)).now_or_never() else {
            // The sink is not immediately ready: stop and report what was queued so far.
            break;
        };
        let size = packet_size.min(data_limit - queued);
        let outcome = ready.and_then(|()| self.tx.start_send_unpin(LinkMsg::TestData { size }.encode()));
        if let Err(err) = outcome {
            self.tx_error = Some(err);
            self.tx_failed = true;
            break;
        }
        queued += size;
    }
    queued
}
pub(crate) fn notify_disconnected(mut self, reason: DisconnectReason) {
self.disconnected_tx.send_replace(reason);
self.disconnect_rx.close();
}
/// Forcefully terminates the connection over this link.
///
/// First waits until the sender is ready (or has failed), then sends and flushes a
/// `LinkMsg::Terminate`. If `expect_reply` is set and the remote has not already sent its own
/// `Terminate`, it then waits for the remote's `Terminate` reply.
///
/// The wait for a reply is skipped when sending fails, and it is aborted on a transmit error.
/// Once the sender has failed, [`Self::event`] returns `TxError` immediately on every call.
/// Continuing to wait would therefore spin forever without yielding to the executor.
pub(crate) async fn terminate_connection(&mut self, mut expect_reply: bool) {
    let link_id = self.link_id();

    tracing::debug!(?link_id, "waiting for link to become ready for termination");
    self.report_ready();
    loop {
        match self.event().await {
            LinkIntEvent::TxReady | LinkIntEvent::TxError(_) => break,
            // The remote already sent its own termination, so no reply will follow ours.
            LinkIntEvent::Rx { msg: LinkMsg::Terminate, .. } => expect_reply = false,
            _ => (),
        }
    }

    tracing::debug!(?link_id, "sending forceful connection termination");
    match self.send_msg_and_flush(LinkMsg::Terminate).await {
        Ok(()) => tracing::debug!(?link_id, "forceful connection termination sent"),
        Err(err) => {
            tracing::warn!(?link_id, %err, "sending forceful connection termination failed");
            // The sender is now marked failed, so `event` would only ever return `TxError`.
            expect_reply = false;
        }
    }

    if !expect_reply {
        return;
    }

    tracing::debug!(?link_id, "waiting for forceful connection termination reply");
    loop {
        match self.event().await {
            LinkIntEvent::RxError(err) => {
                tracing::warn!(?link_id, %err, "receiving forceful connection termination reply failed");
                break;
            }
            LinkIntEvent::Rx { msg: LinkMsg::Terminate, .. } => {
                tracing::debug!(?link_id, "forceful connection termination reply received");
                break;
            }
            // After a transmit failure every further `event` call returns `TxError` at once;
            // stop waiting instead of busy-looping.
            LinkIntEvent::TxError(err) => {
                tracing::warn!(?link_id, %err, "link failed while waiting for forceful connection termination reply");
                break;
            }
            _ => (),
        }
    }
}
/// Marks the sender idle from now on and clears the busy flag of all running interval statistics.
pub(crate) fn mark_idle(&mut self) {
    let now = Instant::now();
    self.tx_idle_since = Some(now);
    self.stats.mark_idle();
}

/// Whether more data may be sent, i.e. the unacknowledged data is below the current limit.
pub(crate) fn is_sendable(&self) -> bool {
    self.txed_unacked_data_limit > self.txed_unacked_data
}

/// Since when the sender is being polled, or `None` when it is not in use.
pub(crate) fn tx_polling(&self) -> Option<Instant> {
    self.tx_polling
}
/// Counts a hang in the link statistics and shrinks the unacknowledged-data limit.
///
/// The limit is brought into the range `[min(128, init), init]`, where `init` is
/// `cfg.link_unacked_init`. The tracking of limit increases then restarts from scratch.
pub(crate) fn reset(&mut self) {
    const MIN_UNACKED_LIMIT: usize = 128;

    self.stats.current.hangs += 1;

    let init = self.cfg.link_unacked_init.get();
    // `Ord::clamp` panics when `min > max`. Capping the lower bound by `init` keeps a configured
    // initial limit below 128 from crashing the link task. For `init >= 128` the result is
    // identical to `clamp(128, init)`.
    self.txed_unacked_data_limit = self.txed_unacked_data_limit.clamp(MIN_UNACKED_LIMIT.min(init), init);
    self.txed_unacked_data_limit_increased = None;
    self.txed_unacked_data_limit_increased_consecutively = 0;
}
/// Whether the link is blocked, either locally or by the remote endpoint.
pub(crate) fn is_blocked(&self) -> bool {
    [&self.blocked, &self.remotely_blocked].iter().any(|flag| flag.load(Ordering::SeqCst))
}

/// Copies the current flow-control state into the statistics and publishes them if an
/// interval has rolled over.
pub(crate) fn publish_stats(&mut self) {
    let current = &mut self.stats.current;
    current.roundtrip = self.roundtrip;
    current.unacked_limit = self.txed_unacked_data_limit as _;
    current.sent_unacked = self.txed_unacked_data as _;
    self.stats.publish();
}
}
impl<TX, RX, TAG> From<&LinkInt<TX, RX, TAG>> for Link<TAG> {
fn from(link_int: &LinkInt<TX, RX, TAG>) -> Self {
Self {
conn_id: link_int.conn_id,
link_id: link_int.link_id,
direction: link_int.direction,
tag: link_int.tag.clone(),
cfg: link_int.cfg.clone(),
disconnected_rx: link_int.disconnected_tx.subscribe(),
disconnect_tx: link_int.disconnect_tx.clone(),
stats_rx: link_int.stats.subscribe(),
remote_user_data: link_int.remote_user_data.clone(),
blocked: link_int.blocked.clone(),
blocked_changed_tx: link_int.blocked_changed_tx.clone(),
blocked_changed_rx: link_int.blocked_changed_out_rx.clone(),
not_working_rx: link_int.unconfirmed_rx.clone(),
remotely_blocked: link_int.remotely_blocked.clone(),
}
}
}
/// Accumulates traffic statistics of a link and publishes snapshots to subscribers.
struct LinkStatistican {
    /// Publishes snapshots of `current`.
    tx: watch::Sender<LinkStats>,
    /// Accumulated statistics. `time_stats` holds, per configured interval, the most recently
    /// completed interval (initially a copy of the running ones).
    current: LinkStats,
    /// In-progress statistics, one entry per configured interval length.
    running_stats: Vec<LinkIntervalStats>,
}
impl LinkStatistican {
    /// Starts statistics for a link established now, with one running entry per interval.
    fn new(intervals: &[Duration], roundtrip: Duration) -> Self {
        let running_stats: Vec<_> = intervals.iter().copied().map(LinkIntervalStats::new).collect();
        let current = LinkStats {
            established: Instant::now(),
            total_sent: 0,
            total_recved: 0,
            sent_unacked: 0,
            unacked_limit: 0,
            roundtrip,
            hangs: 0,
            time_stats: running_stats.clone(),
        };
        let (tx, _) = watch::channel(current.clone());
        Self { tx, current, running_stats }
    }

    /// Returns a receiver for published statistics snapshots.
    fn subscribe(&self) -> watch::Receiver<LinkStats> {
        self.tx.subscribe()
    }

    /// Rolls over every running interval whose duration has elapsed and publishes a snapshot
    /// if at least one interval rolled over.
    fn publish(&mut self) {
        let mut rolled_over = false;
        let pairs = self.running_stats.iter_mut().zip(self.current.time_stats.iter_mut());
        for (running, completed) in pairs {
            if running.start.elapsed() <= running.interval {
                continue;
            }
            // An interval without any sent data is not considered busy.
            if running.sent == 0 {
                running.busy = false;
            }
            let fresh = LinkIntervalStats::new(running.interval);
            *completed = mem::replace(running, fresh);
            rolled_over = true;
        }
        if rolled_over {
            self.tx.send_replace(self.current.clone());
        }
    }

    /// Adds sent and received amounts to the totals and to every running interval (wrapping on overflow).
    fn record(&mut self, sent: usize, received: usize) {
        let totals = &mut self.current;
        totals.total_sent = totals.total_sent.wrapping_add(sent as _);
        totals.total_recved = totals.total_recved.wrapping_add(received as _);
        self.running_stats.iter_mut().for_each(|interval| {
            interval.sent = interval.sent.wrapping_add(sent as _);
            interval.recved = interval.recved.wrapping_add(received as _);
        });
    }

    /// Clears the busy flag of every running interval.
    fn mark_idle(&mut self) {
        self.running_stats.iter_mut().for_each(|interval| interval.busy = false);
    }
}
#[cfg(feature = "dump")]
impl<TX, RX, TAG> From<&LinkInt<TX, RX, TAG>> for super::dump::LinkDump {
    /// Captures a diagnostic snapshot of the link's transmit and flow-control state.
    fn from(link: &LinkInt<TX, RX, TAG>) -> Self {
        let totals = &link.stats.current;
        Self {
            present: true,
            link_id: link.link_id.0,
            unconfirmed: link.unconfirmed.is_some(),
            roundtrip: link.roundtrip.as_secs_f32(),
            tx_idle: link.tx_idle_since.is_some(),
            tx_pending: link.tx_pending,
            tx_flushing: link.tx_flushing,
            tx_flushed: link.tx_flushed,
            tx_ack_queue: link.tx_ack_queue.len(),
            txed_unacked_data: link.txed_unacked_data,
            txed_unacked_data_limit: link.txed_unacked_data_limit,
            txed_unacked_data_limit_increased_consecutively: link.txed_unacked_data_limit_increased_consecutively,
            total_sent: totals.total_sent,
            total_recved: totals.total_recved,
        }
    }
}