use atomic_refcell::AtomicRefCell;
use bytes::Bytes;
use futures::{
future, future::BoxFuture, stream, stream::FuturesUnordered, Future, FutureExt, Sink, Stream, StreamExt,
};
use rand::{prelude::*, rngs::SmallRng};
use std::{
collections::{HashSet, VecDeque},
error::Error,
fmt,
future::IntoFuture,
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
select,
sync::{mpsc, oneshot, watch},
};
use crate::{
agg::link_int::{DisconnectInitiator, LinkInt, LinkIntEvent, LinkTest},
alc::{RecvError, SendError},
cfg::{Cfg, ExchangedCfg, LinkPing},
control::{Direction, DisconnectReason, Link, NotWorkingReason, Stats},
exec::time::{interval_stream, sleep_until, timeout, Instant},
id::{ConnId, LinkId, OwnedConnId},
msg::{LinkMsg, RefusedReason, ReliableMsg},
peekable_mpsc::{PeekableReceiver, RecvIfError},
protocol_err,
seq::Seq,
};
/// Error with which the link aggregator task (`Task::run`) ends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskError {
/// Displayed as "all links unconfirmed timeout".
///
/// Within `Task::run` this is the result when no link is present and
/// no new links can be added anymore.
AllUnconfirmedTimeout,
/// No working link was available for the configured no-link timeout.
NoLinksTimeout,
/// A message received over a link violated the protocol.
ProtocolError {
/// Id of the link that caused the error.
link_id: LinkId,
/// Textual description of the protocol error.
error: String,
},
/// A new link connected to another server than the existing links.
ServerIdMismatch,
/// The connection was forcefully terminated, either by local request
/// or by the remote endpoint.
Terminated,
}
impl fmt::Display for TaskError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::AllUnconfirmedTimeout => write!(f, "all links unconfirmed timeout"),
Self::NoLinksTimeout => write!(f, "no links available timeout"),
Self::ProtocolError { link_id, error } => write!(f, "protocol error on link {link_id}: {error}"),
Self::ServerIdMismatch => write!(f, "a new link connected to another server"),
Self::Terminated => write!(f, "connection forcefully terminated"),
}
}
}
// Marker implementation: `TaskError` wraps no underlying source error,
// so the default methods of `Error` are used.
impl Error for TaskError {}
impl From<TaskError> for std::io::Error {
    /// Wraps the task error into an IO error of kind
    /// [`ConnectionReset`](io::ErrorKind::ConnectionReset), keeping the
    /// task error as its payload.
    fn from(err: TaskError) -> Self {
        let kind = io::ErrorKind::ConnectionReset;
        Self::new(kind, err)
    }
}
/// Request from the sending half of the connection to the task.
#[derive(Debug)]
pub(crate) enum SendReq {
/// Send the data over one of the links.
Send(Bytes),
/// Flush all links; the sender is notified once no flushed link is
/// outstanding anymore.
Flush(oneshot::Sender<()>),
}
/// State of the send overrun handling, which lowers the unacknowledged data
/// limit of a link when too much sent data is counted as unconsumable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SendOverrun {
/// Initial state: overrun handling has not triggered, or was re-armed.
Armed,
/// The soft overrun threshold was exceeded and handled.
Soft,
/// The hard overrun threshold was exceeded and handled.
Hard,
}
/// A reliable message that has been sent, together with its delivery status.
///
/// Instances are shared via `Arc` between the queue of sent packets and the
/// resend queue, hence the interior mutability of `status`.
#[derive(Clone)]
struct SentReliable {
/// Sequence number of the message.
seq: Seq,
/// Current delivery status.
status: AtomicRefCell<SentReliableStatus>,
}
impl fmt::Debug for SentReliable {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("SentReliable")
.field("seq", &self.seq)
.field("status", &self.status.try_borrow().map(|b| (*b).clone()))
.finish()
}
}
/// Delivery status of a sent reliable message.
#[derive(Debug, Clone)]
enum SentReliableStatus {
/// The message was sent over a link.
// NOTE(review): presumably this means "sent but not yet acknowledged";
// the state transitions live outside this part of the file.
Sent {
/// Time the message was sent.
sent: Instant,
/// Index of the link (into `Task::links`) the message was sent over.
link_id: usize,
/// The message itself.
msg: ReliableMsg,
/// Whether this was a resend. TODO confirm against the send path.
resent: bool,
},
/// The message is marked as received.
Received {
/// Size associated with the message. Presumably its payload size in
/// bytes; TODO confirm.
size: usize,
},
/// The message is queued for resending.
ResendQueued {
/// The message to resend.
msg: ReliableMsg,
},
}
/// A reliable message received from the remote endpoint.
#[derive(Debug, Clone)]
struct ReceivedReliableMsg {
/// Sequence number of the message.
seq: Seq,
/// The received message.
msg: ReliableMsg,
}
/// Event produced by one iteration of the main loop of `Task::run`.
///
/// The `usize` ids are indices into `Task::links`.
enum TaskEvent<TX, RX, TAG> {
/// Forceful termination was requested locally.
Terminate,
/// A new link is available for the connection.
NewLink(Box<LinkInt<TX, RX, TAG>>),
/// The channel providing new links was closed.
NoNewLinks,
/// A link produced an event.
LinkEvent { id: usize, event: LinkIntEvent },
/// Data from the sender is to be sent over the idle link `id`.
WriteRx { id: usize, data: Bytes },
/// The sender was dropped.
WriteEnd,
/// The sender requested a flush of all links.
Flush(oneshot::Sender<()>),
/// Waiting for an acknowledgement on the link timed out.
ConfirmTimedOut(usize),
/// A packet from the resend queue is to be resent.
Resend(Arc<SentReliable>),
/// The receiver was dropped.
ReadDropped,
/// The receiver was closed.
ReadClosed,
/// A received message is ready for consumption; `permit` is the slot in
/// the receiver's channel, or `None` when the receiver is gone.
ConsumeReceived { received: ReceivedReliableMsg, permit: Option<mpsc::OwnedPermit<Bytes>> },
/// The amount of consumed data must be reported to the remote endpoint.
SendConsumed,
/// It is time to ping the link.
PingLink(usize),
/// The link has been unconfirmed for too long.
LinkUnconfirmedTimeout(usize),
/// Sending over the link has been pending for too long.
LinkSendTimeout(usize),
/// The reply to a ping over the link did not arrive in time.
LinkPingTimeout(usize),
/// The link testing timer fired; handled as a no-op that re-runs the loop.
LinkTesting,
/// No working link has been available for too long.
NoLinksTimeout,
/// A statistics interval elapsed.
PublishLinkStats,
/// A task handling a refused link completed.
RefusedLinkTask,
/// The server id changed.
ServerChanged,
}
/// Whether and how a connection termination is sent over the links when the
/// task ends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SendTerminate {
/// Do not send a termination.
None,
/// Termination was requested locally; initiate it towards the remote
/// endpoint.
Initiate,
/// Termination was initiated by the remote endpoint; reply to it.
Reply,
}
/// Link filter function.
///
/// Called with the new link and the links currently part of the connection;
/// the returned future resolves to `true` if the new link is to be added and
/// to `false` if it is to be refused.
type LinkFilterFn<TAG> = Box<dyn FnMut(Link<TAG>, Vec<Link<TAG>>) -> BoxFuture<'static, bool> + Send>;
/// Link aggregator task managing one connection over a set of links.
///
/// The task does nothing until `run` is awaited.
#[must_use = "the link aggregator task must be run for the connection to work"]
pub struct Task<TX, RX, TAG> {
/// Local configuration.
cfg: Arc<Cfg>,
/// Configuration of the remote endpoint; taken from the first link if not
/// provided at construction.
remote_cfg: Option<Arc<ExchangedCfg>>,
/// Connection id.
conn_id: OwnedConnId,
/// Direction of the connection.
direction: Direction,
/// Receives requests for forceful termination.
terminate_rx: mpsc::Receiver<()>,
/// Link slots, indexed by the internal link id; `None` marks a free slot.
links: Vec<Option<LinkInt<TX, RX, TAG>>>,
/// Channel over which new links arrive; `None` once it has been closed.
link_rx: Option<mpsc::Receiver<LinkInt<TX, RX, TAG>>>,
/// Publishes the current set of links.
links_tx: watch::Sender<Vec<Link<TAG>>>,
/// Time since which no confirmed link has been present.
links_not_working_since: Option<Instant>,
/// Notified with the remote configuration once the first link is present.
connected_tx: Option<oneshot::Sender<Arc<ExchangedCfg>>>,
/// Channel for forwarding received data to the receiver; `None` once the
/// receiver was dropped or the remote finished sending.
read_tx: Option<mpsc::Sender<Bytes>>,
/// Signals that the receiver was closed; `None` once handled.
read_closed_rx: Option<mpsc::Receiver<()>>,
/// Whether `ReceiveClose` has been sent.
receive_close_sent: bool,
/// Whether `ReceiveFinish` has been sent or is no longer required.
receive_finish_sent: bool,
/// Send requests from the sender; `None` once the sender was dropped.
write_rx: Option<PeekableReceiver<SendReq>>,
/// Shared write-closed flag.
// NOTE(review): not accessed in the visible part of the file; presumably
// shared with the sender half. TODO confirm.
write_closed: Arc<AtomicBool>,
/// Whether `SendFinish` has been sent.
send_finish_sent: bool,
/// Publishes the reason why receiving ended.
read_error_tx: watch::Sender<Option<RecvError>>,
/// Publishes the reason why sending ended.
write_error_tx: watch::Sender<SendError>,
/// Sequence number counter for sent reliable messages.
tx_seq: Seq,
/// State of the send overrun handling.
tx_overrun: SendOverrun,
/// Time the send overrun handling last triggered.
tx_overrun_since: Option<Instant>,
/// Sent reliable packets that are still tracked.
txed_packets: VecDeque<Arc<SentReliable>>,
/// Amount of sent data counted against the local send buffer.
txed_unacked: usize,
/// Amount of sent data counted against the remote receive buffer.
txed_unconsumed: usize,
/// Amount of sent data compared against the overrun thresholds.
txed_unconsumable: usize,
/// Sequence number related to the last consumed sent packet.
// NOTE(review): exact meaning not visible in this part of the file.
txed_last_consumed: Seq,
/// Packets waiting to be resent.
resend_queue: VecDeque<Arc<SentReliable>>,
/// Ids of links that are currently idle.
idle_links: Vec<usize>,
/// Sequence number counter for received reliable messages.
rx_seq: Seq,
/// Received reliable messages that are not yet consumable.
// NOTE(review): presumably a reorder buffer with `None` for gaps — TODO confirm.
rxed_reliable: VecDeque<Option<ReceivedReliableMsg>>,
/// Received reliable messages ready for forwarding to the receiver.
rxed_reliable_consumable: VecDeque<ReceivedReliableMsg>,
/// Size of buffered received data; decreased when data is consumed.
rxed_reliable_size: usize,
/// Bytes consumed since the last `Consumed` message was sent.
rxed_reliable_consumed_since_last_ack: usize,
/// Forces sending of a `Consumed` message; set when `SendFinish` is consumed.
rxed_reliable_consumed_force_ack: bool,
/// Ids of links for which a requested flush is still outstanding.
unflushed_links: HashSet<usize>,
/// Notified when the requested flush of all links has completed.
flushed_tx: Option<oneshot::Sender<()>>,
/// Time the task was started.
start_time: Instant,
/// Time since which both sender and receiver are gone.
read_write_closed: Option<Instant>,
/// Time the connection was established.
established: Option<Instant>,
/// Publishes connection statistics.
stats_tx: watch::Sender<Stats>,
/// Time the statistics were last published.
stats_last_sent: Instant,
/// Filter deciding whether a new link is added to the connection.
link_filter: LinkFilterFn<TAG>,
/// Links supplied at construction that are still to be added.
init_links: VecDeque<LinkInt<TX, RX, TAG>>,
/// Tasks notifying refused links of their refusal.
refused_links_tasks: FuturesUnordered<BoxFuture<'static, ()>>,
/// Receives notifications that the server id changed.
server_changed_rx: mpsc::Receiver<()>,
/// Publishes the result of the task.
result_tx: watch::Sender<Result<(), TaskError>>,
/// Channel for sending analysis dumps.
#[cfg(feature = "dump")]
dump_tx: Option<mpsc::Sender<super::dump::ConnDump>>,
}
impl<TX, RX, TAG> fmt::Debug for Task<TX, RX, TAG> {
    /// Formats the task as `Task(<direction arrow><connection id>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let arrow = self.direction.arrow();
        let conn_id = &self.conn_id;
        write!(f, "Task({arrow}{conn_id:?})")
    }
}
impl<TX, RX, TAG> Task<TX, RX, TAG>
where
RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + 'static,
TX: Sink<Bytes, Error = io::Error> + Unpin + Send + 'static,
TAG: Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
cfg: Arc<Cfg>, remote_cfg: Option<Arc<ExchangedCfg>>, conn_id: OwnedConnId, direction: Direction,
terminate_rx: mpsc::Receiver<()>, links_tx: watch::Sender<Vec<Link<TAG>>>,
link_rx: mpsc::Receiver<LinkInt<TX, RX, TAG>>, connected_tx: oneshot::Sender<Arc<ExchangedCfg>>,
read_tx: mpsc::Sender<Bytes>, read_closed_rx: mpsc::Receiver<()>, write_rx: mpsc::Receiver<SendReq>,
read_error_tx: watch::Sender<Option<RecvError>>, write_error_tx: watch::Sender<SendError>,
stats_tx: watch::Sender<Stats>, server_changed_rx: mpsc::Receiver<()>,
result_tx: watch::Sender<Result<(), TaskError>>, links: Vec<LinkInt<TX, RX, TAG>>,
) -> Self {
Self {
cfg,
remote_cfg,
conn_id,
direction,
terminate_rx,
links: Vec::new(),
link_rx: Some(link_rx),
links_tx,
links_not_working_since: None,
connected_tx: Some(connected_tx),
read_tx: Some(read_tx),
read_closed_rx: Some(read_closed_rx),
receive_close_sent: false,
receive_finish_sent: false,
write_rx: Some(write_rx.into()),
write_closed: Arc::new(AtomicBool::new(false)),
send_finish_sent: false,
read_error_tx,
write_error_tx,
tx_seq: Seq::ZERO,
tx_overrun: SendOverrun::Armed,
tx_overrun_since: None,
txed_packets: VecDeque::new(),
txed_unacked: 0,
resend_queue: VecDeque::new(),
idle_links: Vec::new(),
rx_seq: Seq::ZERO,
rxed_reliable: VecDeque::new(),
rxed_reliable_consumable: VecDeque::new(),
rxed_reliable_consumed_since_last_ack: 0,
txed_unconsumed: 0,
txed_unconsumable: 0,
txed_last_consumed: Seq::MINUS_ONE,
rxed_reliable_size: 0,
rxed_reliable_consumed_force_ack: false,
unflushed_links: HashSet::new(),
flushed_tx: None,
start_time: Instant::now(),
read_write_closed: None,
established: None,
stats_tx,
stats_last_sent: Instant::now(),
link_filter: Box::new(|_, _| async { true }.boxed()),
init_links: links.into(),
refused_links_tasks: FuturesUnordered::new(),
server_changed_rx,
result_tx,
#[cfg(feature = "dump")]
dump_tx: None,
}
}
/// Runs the link aggregator task until the connection ends.
///
/// Every loop iteration publishes statistics, checks the termination
/// conditions, builds one future per event source, waits for the first of
/// them with `select!` and handles the resulting [`TaskEvent`]. After the
/// loop the termination reasons are published, all links are notified of
/// the disconnection and the result is sent to `result_tx`.
///
/// # Errors
///
/// Returns the [`TaskError`] chosen by the branch that left the loop.
/// `Ok(())` is returned only when both the sender and the receiver were
/// dropped.
///
/// # Panics
///
/// Contains `unwrap` calls on link slots, idle link ids and the remote
/// configuration, which are expected to be present at those points.
#[tracing::instrument(name = "aggligator::connection", level = "info", skip_all,
fields(conn_id =? self.conn_id, dir =% self.direction), ret)]
pub async fn run(mut self) -> Result<(), TaskError> {
tracing::debug!("link aggregator task starting");
self.start_time = Instant::now();
// One interval stream per configured statistics interval, merged into one.
let mut stat_timers = stream::select_all(self.cfg.stats_intervals.iter().map(|t| interval_stream(*t)));
// Only used for shuffling the polling order of links; seeded with a fixed value.
let mut fast_rng = SmallRng::seed_from_u64(1);
// Termination reasons; assigned exactly once by the branch leaving the loop.
let read_term;
let write_term;
let link_term;
let mut send_terminate = SendTerminate::None;
let result;
loop {
// Snapshot state used by the event futures below.
let is_consume_ack_required = self.is_consume_ack_required();
let tx_seq_avail = self.tx_seq_avail();
let tx_space = self.tx_space();
let resending = !self.resend_queue.is_empty();
let links_idling = !self.idle_links.is_empty();
let links_available = self.links.iter().any(Option::is_some);
self.send_stats();
#[cfg(feature = "dump")]
self.send_dump();
// Graceful shutdown: sender and receiver are both gone. Leave once
// everything has been exchanged, no link is left, or the termination
// timeout has elapsed.
if self.read_tx.is_none() && self.write_rx.is_none() {
let since = self.read_write_closed.get_or_insert_with(Instant::now);
if (self.txed_packets.is_empty()
&& self.txed_unconsumed == 0
&& self.rxed_reliable_size == 0
&& self.rxed_reliable_consumed_since_last_ack == 0
&& self.send_finish_sent
&& self.receive_finish_sent)
|| !links_available
|| since.elapsed() >= self.cfg.termination_timeout
{
tracing::info!("disconnecting because sender and receiver were dropped");
result = Ok(());
read_term = None;
write_term = SendError::Closed;
link_term = DisconnectReason::ConnectionClosed;
break;
}
}
// No link is left and the channel for new links is closed.
if !links_available && self.link_rx.is_none() {
tracing::warn!("disconnecting because no links available and none can be added");
result = Err(TaskError::AllUnconfirmedTimeout);
read_term = Some(RecvError::AllLinksFailed);
write_term = SendError::AllLinksFailed;
link_term = DisconnectReason::AllUnconfirmedTimeout;
break;
}
// Report the connection as established once the first link is present.
// The remote configuration is set when handling the first new link.
if links_available {
if let Some(connected_tx) = self.connected_tx.take() {
tracing::debug!("sending connection established notification");
let _ = connected_tx.send(self.remote_cfg.clone().unwrap());
self.established = Some(Instant::now());
}
}
// Complete a pending flush request once no link is waiting for its flush.
if self.unflushed_links.is_empty() {
if let Some(tx) = self.flushed_tx.take() {
tracing::trace!("flush request completed");
let _ = tx.send(());
}
}
self.adjust_link_tx_limits();
// Fires when no confirmed link has been present for `no_link_timeout`.
let no_link_since = self.links_not_working_since();
let no_link_timeout = self.cfg.no_link_timeout;
let links_timeout = async move {
match no_link_since {
Some(since) => sleep_until(since + no_link_timeout).await,
None => future::pending().await,
}
};
// Fires when the next link is due for a ping.
let next_link_ping = self.next_link_ping();
let next_ping_timeout = async move {
match next_link_ping {
Some((link_id, timeout)) => {
sleep_until(timeout).await;
link_id
}
None => future::pending().await,
}
};
// Per-link timeouts: ping reply, unconfirmed state and pending send.
let next_pong_timeout =
self.earliest_link_specific_timeout(self.cfg.link_ping_timeout, |link| link.current_ping_sent);
let next_unconfirmed_timeout = self
.earliest_link_specific_timeout(self.cfg.link_non_working_timeout, |link| {
link.unconfirmed.as_ref().map(|(since, _)| *since)
});
let next_send_timeout =
self.earliest_link_specific_timeout(self.cfg.link_ping_timeout, |link| link.tx_polling());
// Advance link testing for every link id and wake up at the earliest
// requested time.
let next_link_testing = (0..self.links.len()).filter_map(|id| self.link_testing_step(id)).min();
let link_testing_timeout = async move {
match next_link_testing {
Some(timeout) => sleep_until(timeout).await,
None => future::pending().await,
}
};
// Fires when the earliest acknowledgement timeout elapses.
let earliest_confirm_timeout = self.earliest_confirm_timeout();
let recv_confirm_timeout = async move {
match earliest_confirm_timeout {
Some((link_id, timeout)) => {
sleep_until(timeout).await;
link_id
}
None => future::pending().await,
}
};
// Local termination request; a closed channel is not a request.
let terminate_task = async {
match self.terminate_rx.recv().await {
Some(()) => TaskEvent::Terminate,
None => future::pending().await,
}
};
// New links: links supplied at construction take precedence over links
// arriving through the channel.
let new_link_task = async {
match &mut self.link_rx {
_ if !self.init_links.is_empty() => {
TaskEvent::NewLink(Box::new(self.init_links.pop_front().unwrap()))
}
Some(link_rx) => match link_rx.recv().await {
Some(link) => TaskEvent::NewLink(Box::new(link)),
None => TaskEvent::NoNewLinks,
},
None => future::pending().await,
}
};
// Most recently idled link that can take data.
let sendable_idle_link_id =
self.idle_links.iter().rev().cloned().find(|id| self.links[*id].as_ref().unwrap().is_sendable());
// Sender requests. Reporting consumed data over an idle link takes
// precedence. Data is only accepted when a sequence number is available,
// nothing awaits resending, the data fits into the send space and a
// sendable idle link exists.
let write_rx_task = async {
if links_idling && is_consume_ack_required {
TaskEvent::SendConsumed
} else {
match &mut self.write_rx {
Some(write_rx) if tx_seq_avail && !resending => {
match write_rx
.recv_if(|msg| match msg {
SendReq::Send(data) => {
data.len() <= tx_space && sendable_idle_link_id.is_some()
}
SendReq::Flush(_) => true,
})
.await
{
Ok(SendReq::Send(data)) => {
TaskEvent::WriteRx { id: sendable_idle_link_id.unwrap(), data }
}
Ok(SendReq::Flush(flushed_tx)) => TaskEvent::Flush(flushed_tx),
Err(RecvIfError::NoMatch) => future::pending().await,
Err(RecvIfError::Disconnected) => TaskEvent::WriteEnd,
}
}
_ => future::pending().await,
}
}
};
// Events of all links; the order is shuffled before selecting.
let link_task = async {
if self.links.is_empty() {
future::pending().await
} else {
let mut tasks: Vec<_> = self
.links
.iter_mut()
.enumerate()
.filter_map(|(id, link_opt)| {
link_opt.as_mut().map(|link| async move { (id, link.event().await) }.boxed())
})
.collect();
tasks.shuffle(&mut fast_rng);
future::select_all(tasks).await
}
};
// Receiver was closed.
let read_closed_task = async {
match &mut self.read_closed_rx {
Some(read_closed_tx) => match read_closed_tx.recv().await {
Some(_) => TaskEvent::ReadClosed,
None => future::pending().await,
},
None => future::pending().await,
}
};
// Resend over an idle link when a sendable one exists.
let resend_task = async {
if resending && sendable_idle_link_id.is_some() {
self.resend_queue.pop_front().unwrap()
} else {
future::pending().await
}
};
// Forward consumable messages to the receiver. Without a receiver the
// messages are still consumed, but no permit is reserved.
let consume_task = async {
if !self.rxed_reliable_consumable.is_empty() {
match self.read_tx.as_ref() {
Some(read_tx) => match read_tx.clone().reserve_owned().await {
Ok(permit) => TaskEvent::ConsumeReceived {
received: self.rxed_reliable_consumable.pop_front().unwrap(),
permit: Some(permit),
},
Err(_) => TaskEvent::ReadDropped,
},
None => TaskEvent::ConsumeReceived {
received: self.rxed_reliable_consumable.pop_front().unwrap(),
permit: None,
},
}
} else {
future::pending().await
}
};
// Wait for the first event.
let event = select! {
terminate_event = terminate_task => terminate_event,
new_link_event = new_link_task => new_link_event,
((id, event), _, _) = link_task => TaskEvent::LinkEvent { id, event },
write_event = write_rx_task => write_event,
link_id = recv_confirm_timeout => TaskEvent::ConfirmTimedOut(link_id),
link_id = next_ping_timeout => TaskEvent::PingLink(link_id),
link_id = next_pong_timeout => TaskEvent::LinkPingTimeout(link_id),
link_id = next_unconfirmed_timeout => TaskEvent::LinkUnconfirmedTimeout(link_id),
link_id = next_send_timeout => TaskEvent::LinkSendTimeout(link_id),
packet = resend_task => TaskEvent::Resend (packet),
consume_event = consume_task => consume_event,
event = read_closed_task => event,
() = link_testing_timeout => TaskEvent::LinkTesting,
() = links_timeout => TaskEvent::NoLinksTimeout,
Some(_) = stat_timers.next() => TaskEvent::PublishLinkStats,
Some(()) = self.refused_links_tasks.next(), if !self.refused_links_tasks.is_empty()
=> TaskEvent::RefusedLinkTask,
Some(()) = self.server_changed_rx.recv() => TaskEvent::ServerChanged,
};
// Handle the event.
match event {
TaskEvent::Terminate => {
tracing::info!("forceful connection termination by local request");
result = Err(TaskError::Terminated);
read_term = Some(RecvError::TaskTerminated);
write_term = SendError::TaskTerminated;
link_term = DisconnectReason::TaskTerminated;
send_terminate = SendTerminate::Initiate;
break;
}
TaskEvent::NewLink(mut link) => {
let link_id = link.link_id();
// The first link provides the remote configuration if it is unknown.
if self.remote_cfg.is_none() {
let remote_cfg = link.remote_cfg();
tracing::debug!(?remote_cfg, "obtained remote configuration");
self.remote_cfg = Some(remote_cfg);
}
// Ask the link filter, passing the links already in use.
let others =
self.links.iter().filter_map(|link_opt| link_opt.as_ref().map(Link::from)).collect();
if (self.link_filter)(Link::from(&*link), others).await {
self.add_link(*link);
tracing::info!(?link_id, "added new link");
} else {
tracing::debug!(?link_id, "link was refused by link filter");
let link_non_working_timeout = self.cfg.link_non_working_timeout;
// If the link still expects an Accepted message, tell the remote
// endpoint about the refusal in a background task with a timeout.
if link.needs_tx_accepted {
self.refused_links_tasks.push(
async move {
let _ = timeout(
link_non_working_timeout,
link.send_msg_and_flush(LinkMsg::Refused {
reason: RefusedReason::LinkRefused,
}),
)
.await;
link.notify_disconnected(DisconnectReason::LinkFilter);
}
.boxed(),
);
} else {
link.notify_disconnected(DisconnectReason::LinkFilter);
}
}
}
TaskEvent::NoNewLinks => {
tracing::debug!("no new links can be added");
self.link_rx = None;
}
TaskEvent::LinkEvent { id, event } => {
let link_id = self.links[id].as_ref().unwrap().link_id();
match event {
// The link can accept a message. The branches below are checked in
// order; the first applicable one decides what is sent. Link
// management messages come first, then, for confirmed and unblocked
// links only, reliable messages.
LinkIntEvent::TxReady => {
let link = self.links[id].as_mut().unwrap();
let link_blocked = link.blocked.load(Ordering::SeqCst);
if link.needs_tx_accepted {
tracing::debug!(?link_id, "sending Accepted over link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::Accepted, None);
link.needs_tx_accepted = false;
} else if link.send_pong {
tracing::trace!(?link_id, "sending Pong over link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::Pong, None);
link.send_pong = false;
} else if let Some(initiator) = link.disconnecting {
// Disconnection in progress: send Goodbye once; a remotely
// initiated disconnection removes the link after Goodbye was sent.
if !link.goodbye_sent {
tracing::debug!(?link_id, "sending GoodBye over link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::Goodbye, None);
link.goodbye_sent = true;
} else if initiator == DisconnectInitiator::Remote {
tracing::info!(?link_id, "removing link by remote request");
self.remove_link(id, DisconnectReason::RemotelyRequested);
}
} else if link.send_ping {
tracing::trace!(?link_id, "sending Ping over link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::Ping, None);
link.current_ping_sent = Some(Instant::now());
link.send_ping = false;
} else if link_blocked != link.blocked_sent {
// Local block status differs from what was last sent to the remote.
tracing::debug!(?link_id, %link_blocked, "local block status of link changed");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::SetBlock { blocked: link_blocked }, None);
link.blocked_sent = link_blocked;
} else if let Some(recved_seq) = link.tx_ack_queue.pop_front() {
tracing::trace!(?link_id, "acking sequence {recved_seq} over non-idle link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_send_msg(LinkMsg::Ack { received: recved_seq }, None);
} else if link.unconfirmed.is_none() && !link.is_blocked() {
// Confirmed and unblocked link: reliable messages may be sent.
if is_consume_ack_required {
let consumed = self.rxed_reliable_consumed_since_last_ack as u32;
tracing::trace!(
?link_id,
"acking {consumed} consumed bytes over non-idle link"
);
self.idle_links.retain(|&idle_id| idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::Consumed(consumed));
self.rxed_reliable_consumed_since_last_ack = 0;
self.rxed_reliable_consumed_force_ack = false;
} else if resending && link.is_sendable() {
let packet = self.resend_queue.pop_front().unwrap();
tracing::trace!(
?link_id,
"resending packet {} over non-idle link",
packet.seq
);
self.idle_links.retain(|idle_id| *idle_id != id);
self.resend_reliable_over_link(id, packet);
} else if self.read_closed_rx.is_none() && !self.receive_close_sent {
tracing::trace!(?link_id, "sending ReceiveClose over non-idle link");
self.idle_links.retain(|&idle_id| idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::ReceiveClose);
self.receive_close_sent = true;
} else if self.read_tx.is_none() && !self.receive_finish_sent {
tracing::trace!(?link_id, "sending ReceiveFinish over non-idle link");
self.idle_links.retain(|&idle_id| idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::ReceiveFinish);
self.receive_finish_sent = true;
} else if self.write_rx.is_none() && !self.send_finish_sent {
tracing::trace!(?link_id, "sending SendFinish over non-idle link");
self.idle_links.retain(|&idle_id| idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::SendFinish);
self.send_finish_sent = true;
} else if let Some(SendReq::Send(data)) = self
.write_rx
.as_mut()
.filter(|_| tx_seq_avail && link.is_sendable())
.and_then(|rx| {
rx.try_recv_if(
|msg| matches!(msg, SendReq::Send(data) if data.len() <= tx_space),
)
.ok()
})
{
// Data from the sender is available without waiting and fits into
// the send space.
tracing::trace!(
?link_id,
"sending data of size {} over non-idle link",
data.len()
);
self.idle_links.retain(|idle_id| *idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::Data(data));
} else if link.need_ack_flush() {
tracing::trace!(
?link_id,
"flushing link due to pending acks and no data to send"
);
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_flush();
} else if link.needs_flush() && !link.is_sendable() {
tracing::trace!(?link_id, "flushing link because it is not sendable");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_flush();
} else if !self.idle_links.contains(&id) {
// Nothing to send: the link becomes idle.
tracing::trace!(?link_id, "link has become idle");
link.mark_idle();
self.idle_links.push(id);
}
} else {
// Unconfirmed or blocked link: only flush what is pending.
if link.needs_flush() || link.need_ack_flush() {
tracing::trace!(?link_id, "flushing link because it is now unconfirmed");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_flush();
}
}
}
LinkIntEvent::TxFlushed => {
self.unflushed_links.remove(&id);
}
// A message was received. `Ok(true)` signals termination by the remote
// endpoint; an error is a protocol violation that ends the connection.
LinkIntEvent::Rx { msg, data } => {
match self.handle_received_msg(id, msg, data) {
Ok(false) => (),
Ok(true) => {
tracing::info!("forceful connection termination by remote endpoint");
result = Err(TaskError::Terminated);
read_term = Some(RecvError::TaskTerminated);
write_term = SendError::TaskTerminated;
link_term = DisconnectReason::TaskTerminated;
send_terminate = SendTerminate::Reply;
break;
}
Err(err) => {
tracing::warn!(?link_id, %err, "link caused protocol error");
result = Err(TaskError::ProtocolError { link_id, error: err.to_string() });
read_term = Some(RecvError::ProtocolError);
write_term = SendError::ProtocolError;
link_term = DisconnectReason::ProtocolError(err.to_string());
break;
}
}
}
LinkIntEvent::FlushDelayPassed => {
let link = self.links[id].as_mut().unwrap();
tracing::trace!(?link_id, "flushing link");
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_flush();
}
// IO errors remove the link. When sender and receiver are already gone,
// the error is reported as a regular connection close.
LinkIntEvent::TxError(err) | LinkIntEvent::RxError(err) => {
tracing::warn!(?link_id, %err, "disconnecting link due to IO error");
let reason = if self.read_tx.is_none() && self.write_rx.is_none() {
DisconnectReason::ConnectionClosed
} else {
DisconnectReason::IoError(Arc::new(err))
};
self.remove_link(id, reason);
}
LinkIntEvent::BlockedChanged => {
let link = self.links[id].as_mut().unwrap();
self.idle_links.retain(|&idle_id| idle_id != id);
link.report_ready();
link.blocked_changed_out_tx.send_replace(());
}
LinkIntEvent::Disconnect => {
let link = self.links[id].as_mut().unwrap();
if link.disconnecting.is_none() {
tracing::info!(?link_id, "starting disconnection of link by local request");
link.disconnecting = Some(DisconnectInitiator::Local);
self.idle_links.retain(|&idle_id| idle_id != id);
link.start_flush();
}
}
}
}
TaskEvent::WriteRx { id, data } => {
let link_id = self.links[id].as_ref().unwrap().link_id();
tracing::trace!(?link_id, "sending data of size {} bytes over idle link", data.len());
self.idle_links.retain(|&idle_id| idle_id != id);
self.send_reliable_over_link(id, ReliableMsg::Data(data));
}
TaskEvent::SendConsumed => {
// Only produced when an idle link exists, so the pop cannot fail.
let id = self.idle_links.pop().unwrap();
let link_id = self.links[id].as_ref().unwrap().link_id();
let consumed = self.rxed_reliable_consumed_since_last_ack as u32;
tracing::trace!(?link_id, "acking {consumed} consumed bytes over idle link");
self.send_reliable_over_link(id, ReliableMsg::Consumed(consumed));
self.rxed_reliable_consumed_since_last_ack = 0;
self.rxed_reliable_consumed_force_ack = false;
}
TaskEvent::WriteEnd => {
tracing::debug!("sender was dropped");
self.write_rx = None;
// Without an idle link, SendFinish is sent later from the TxReady handler.
if let Some(id) = self.idle_links.pop() {
let link_id = self.links[id].as_mut().unwrap().link_id();
tracing::debug!(?link_id, "sending SendFinish over idle link");
self.send_reliable_over_link(id, ReliableMsg::SendFinish);
self.send_finish_sent = true;
} else {
tracing::debug!("queueing sending of SendFinish");
}
}
TaskEvent::Flush(tx) => {
// Start flushing all confirmed links and remember which flushes are
// outstanding; the request completes at the top of the loop.
tracing::trace!("starting flush of all links");
self.unflushed_links = self
.links
.iter_mut()
.enumerate()
.filter_map(|(id, link_opt)| {
link_opt.as_mut().and_then(|link| {
if link.unconfirmed.is_none() {
link.start_flush();
Some(id)
} else {
None
}
})
})
.collect();
self.idle_links.retain(|idle_id| !self.unflushed_links.contains(idle_id));
self.flushed_tx = Some(tx);
}
TaskEvent::ConfirmTimedOut(id) => {
tracing::debug!("acknowledgement timeout on link {id}");
self.unconfirm_link(id, NotWorkingReason::AckTimeout);
}
TaskEvent::Resend(packet) => {
// Only produced when a sendable idle link exists.
let id = sendable_idle_link_id.unwrap();
let link_id = self.links[id].as_ref().unwrap().link_id();
self.idle_links.retain(|&idle_id| idle_id != id);
tracing::trace!(?link_id, "resending message {} over idle link", packet.seq);
self.resend_reliable_over_link(id, packet);
}
TaskEvent::ReadDropped => {
tracing::debug!("receiver was dropped");
self.read_tx = None;
self.read_closed_rx = None;
// Without an idle link, ReceiveFinish is sent later from the TxReady handler.
if let Some(id) = self.idle_links.pop() {
let link_id = self.links[id].as_mut().unwrap().link_id();
tracing::debug!(?link_id, "sending ReceiveFinish over idle link");
self.send_reliable_over_link(id, ReliableMsg::ReceiveFinish);
self.receive_finish_sent = true;
} else {
tracing::debug!("queueing sending of ReceiveFinish");
}
}
TaskEvent::ReadClosed => {
tracing::debug!("receiver was closed");
self.read_closed_rx = None;
// Without an idle link, ReceiveClose is sent later from the TxReady handler.
if let Some(id) = self.idle_links.pop() {
self.send_reliable_over_link(id, ReliableMsg::ReceiveClose);
self.receive_close_sent = true;
}
}
TaskEvent::ConsumeReceived { received, permit } => {
tracing::trace!("consuming received data message {:?}", &received.msg);
match received.msg {
ReliableMsg::Data(data) => {
// Account for the consumed data and hand it to the receiver, if any.
self.rxed_reliable_size -= data.len();
self.rxed_reliable_consumed_since_last_ack += data.len();
if let Some(permit) = permit {
permit.send(data);
}
}
ReliableMsg::SendFinish => {
// The remote endpoint finished sending: end the receiver without
// an error and force a Consumed message.
self.read_error_tx.send_replace(None);
self.read_tx = None;
self.receive_finish_sent = true;
self.rxed_reliable_consumed_force_ack = true;
}
// NOTE(review): these messages are expected never to be placed in the
// consumable queue; the code filling the queue is outside this function.
ReliableMsg::ReceiveClose | ReliableMsg::ReceiveFinish | ReliableMsg::Consumed(_) => {
unreachable!()
}
}
}
TaskEvent::PingLink(id) => {
let link = self.links[id].as_mut().unwrap();
tracing::trace!(link_id =? link.link_id(), "requesting ping of link");
link.send_ping = true;
self.flush_link(id);
}
TaskEvent::LinkPingTimeout(id) => {
let link_id = self.links[id].as_mut().unwrap().link_id();
tracing::warn!(?link_id, "removing link due to ping timeout");
self.remove_link(id, DisconnectReason::PingTimeout);
}
TaskEvent::LinkUnconfirmedTimeout(id) => {
let link_id = self.links[id].as_mut().unwrap().link_id();
tracing::warn!(?link_id, "removing link due to unconfirmed timeout");
self.remove_link(id, DisconnectReason::UnconfirmedTimeout);
}
TaskEvent::LinkSendTimeout(id) => {
let link_id = self.links[id].as_mut().unwrap().link_id();
tracing::warn!(?link_id, "removing link due to send timeout");
self.remove_link(id, DisconnectReason::SendTimeout);
}
// Nothing to do here; link testing is advanced at the top of the loop.
TaskEvent::LinkTesting => (),
TaskEvent::NoLinksTimeout => {
tracing::warn!("disconnecting because no links are available for too long");
result = Err(TaskError::NoLinksTimeout);
read_term = Some(RecvError::AllLinksFailed);
write_term = SendError::AllLinksFailed;
link_term = DisconnectReason::AllUnconfirmedTimeout;
break;
}
TaskEvent::PublishLinkStats => {
for link_opt in &mut self.links {
if let Some(link) = link_opt.as_mut() {
link.publish_stats();
}
}
}
// The completed task has already been removed from the set by polling it.
TaskEvent::RefusedLinkTask => (),
TaskEvent::ServerChanged => {
tracing::warn!("disconnecting because server id changed");
result = Err(TaskError::ServerIdMismatch);
read_term = Some(RecvError::ServerIdMismatch);
write_term = SendError::ServerIdMismatch;
link_term = DisconnectReason::ServerIdMismatch;
break;
}
}
// Unconfirm links whose roundtrip exceeds the maximum ping, unless every
// link is unconfirmed, blocked or too slow.
if let Some(max_ping) = self.cfg.link_max_ping {
let all_links_slow = self.links.iter().all(|link_opt| {
link_opt
.as_ref()
.map(|link| link.unconfirmed.is_some() || link.is_blocked() || link.roundtrip > max_ping)
.unwrap_or(true)
});
if !all_links_slow {
let slow: Vec<_> = self
.links
.iter()
.enumerate()
.filter_map(|(id, link_opt)| match link_opt {
Some(link) if link.unconfirmed.is_none() && link.roundtrip > max_ping => {
tracing::debug!(
link_id =? link.link_id(),
"unconfirming link due to slow ping of {} ms",
link.roundtrip.as_millis()
);
Some(id)
}
_ => None,
})
.collect();
for id in slow {
self.unconfirm_link(id, NotWorkingReason::MaxPingExceeded);
}
}
}
}
// Publish the termination reasons, but only if the watches still hold
// `TaskTerminated`, so that a different reason set earlier is kept.
if *self.read_error_tx.borrow() == Some(RecvError::TaskTerminated) {
self.read_error_tx.send_replace(read_term);
}
if *self.write_error_tx.borrow() == SendError::TaskTerminated {
self.write_error_tx.send_replace(write_term);
}
// Disconnect the receiver and the sender.
self.read_tx = None;
self.write_rx = None;
// Forceful termination: inform the remote endpoint over all links,
// limited by the termination timeout.
if send_terminate != SendTerminate::None {
let mut term_tasks = FuturesUnordered::new();
for link in &mut self.links {
let Some(link) = link.as_mut() else { continue };
term_tasks.push(link.terminate_connection(send_terminate == SendTerminate::Initiate));
}
let res =
timeout(self.cfg.termination_timeout, async move { while term_tasks.next().await.is_some() {} })
.await;
if res.is_err() {
tracing::warn!("forceful connection termination timed out");
}
}
// Notify all remaining links of the disconnection reason.
for link in self.links.drain(..) {
let Some(link) = link else { continue };
link.notify_disconnected(link_term.clone());
}
let _ = self.result_tx.send_replace(result.clone());
// Drop the channel for new links only after the result has been published.
#[allow(unused_assignments)]
{
self.link_rx = None;
}
result
}
/// Adds a link to the connection and returns its internal id.
///
/// The link starts out unconfirmed with reason `New`. The first free slot
/// of `links` is reused; if there is none, the link is appended. The
/// updated set of links is published afterwards.
fn add_link(&mut self, mut link: LinkInt<TX, RX, TAG>) -> usize {
    link.report_ready();
    link.unconfirmed = Some((Instant::now(), NotWorkingReason::New));

    let id = match self.links.iter().position(Option::is_none) {
        Some(free_slot) => {
            self.links[free_slot] = Some(link);
            free_slot
        }
        None => {
            self.links.push(Some(link));
            self.links.len() - 1
        }
    };

    self.publish_links();
    id
}
/// Removes the link with the specified id from the connection.
///
/// The link is unconfirmed first, then taken out of its slot and notified
/// of the disconnection reason. Free slots at the end of `links` are
/// dropped and the updated set of links is published.
///
/// # Panics
///
/// Panics if `id` is out of range or the slot is empty.
fn remove_link(&mut self, id: usize, reason: DisconnectReason) {
    let link_id = self.links[id].as_ref().unwrap().link_id();
    tracing::debug!(?link_id, ?reason, "removing link");

    self.unconfirm_link(id, NotWorkingReason::Disconnecting);

    let removed = self.links[id].take().unwrap();
    removed.notify_disconnected(reason);

    // Keep everything up to and including the last occupied slot.
    let occupied_len = self.links.iter().rposition(Option::is_some).map_or(0, |last| last + 1);
    self.links.truncate(occupied_len);

    self.publish_links();
}
fn publish_links(&self) {
let links = self.links.iter().filter_map(|link_opt| link_opt.as_ref().map(Link::from)).collect();
self.links_tx.send_replace(links);
}
/// Returns since when no confirmed link has been present.
///
/// Updates the stored timestamp: it is cleared as soon as at least one
/// link is confirmed, and set to the current time when no link is
/// confirmed and no timestamp is stored yet.
fn links_not_working_since(&mut self) -> Option<Instant> {
    let any_confirmed = self.links.iter().flatten().any(|link| link.unconfirmed.is_none());

    if any_confirmed {
        self.links_not_working_since = None;
    } else if self.links_not_working_since.is_none() {
        self.links_not_working_since = Some(Instant::now());
    }

    self.links_not_working_since
}
/// Returns the receive buffer size from the remote configuration, or
/// `None` while the remote configuration is unknown.
fn remote_recv_buffer(&self) -> Option<usize> {
    let remote = self.remote_cfg.as_ref()?;
    Some(remote.recv_buffer.get() as usize)
}
/// Returns the space available for sending data.
///
/// This is the smaller of the space left in the local send buffer and the
/// space left in the remote receive buffer. It is zero while the remote
/// configuration is unknown.
fn tx_space(&self) -> usize {
    let send_buffer = self.cfg.send_buffer.get() as usize;
    let local_space = send_buffer.saturating_sub(self.txed_unacked);

    let remote_space = match self.remote_recv_buffer() {
        Some(recv_buffer) => recv_buffer.saturating_sub(self.txed_unconsumed),
        None => 0,
    };

    local_space.min(remote_space)
}
/// Returns whether a sequence number is available for sending.
///
/// This holds when no sent packet is tracked, or when the distance from
/// the first tracked packet to the current send sequence number does not
/// exceed `Seq::USABLE_INTERVAL`.
fn tx_seq_avail(&self) -> bool {
    match self.txed_packets.front() {
        Some(first) => self.tx_seq - first.seq <= Seq::USABLE_INTERVAL,
        None => true,
    }
}
/// Adjusts the per-link limits on unacknowledged sent data.
///
/// Does nothing while the remote receive buffer size is unknown. Otherwise, in order:
///
/// 1. If `txed_unconsumable` exceeds a soft (1/3) or hard (3/4) fraction of the smaller
///    of the local send buffer and the remote receive buffer, the limit of the link
///    carrying the oldest packet still in `Sent` state is lowered (to 95 % or 50 %).
///    The overrun state is re-armed once neither threshold is exceeded or after one second.
/// 2. If `cfg.link_max_ping` is set and not all links are slow, the limit of every
///    confirmed link whose round-trip time exceeds 3/4 of the maximum ping is lowered to 95 %.
/// 3. If data is waiting to be sent but no link can take it, the limit of every
///    eligible link that has reached its limit is raised by 1 % to 20 %, depending on
///    how often it was raised consecutively.
/// 4. Unless `txed_unconsumable` is below 1/4 of the limit, the consecutive-increase
///    counters of all links are reset.
///
/// # Panics
/// Panics if a packet in `Sent` state refers to a link slot that is empty.
fn adjust_link_tx_limits(&mut self) {
// Without the remote configuration the remote receive buffer size is unknown.
let Some(remote_recv_buffer) = self.remote_recv_buffer() else { return };
// Sequence number of the next packet to go out: the head of the resend queue
// if there is one, otherwise the next fresh sequence number.
let coming_seq = match self.resend_queue.front() {
Some(packet) => packet.seq,
None => self.tx_seq,
};
// All thresholds are fractions of the smaller of the two buffers.
let unconsumable_limit = (self.cfg.send_buffer.get() as usize).min(remote_recv_buffer);
let low_level = self.txed_unconsumable < unconsumable_limit / 4;
let soft_overrun = self.txed_unconsumable > unconsumable_limit / 3;
let hard_overrun = self.txed_unconsumable > unconsumable_limit * 3 / 4;
// React only on a transition: Armed -> Soft on soft overrun, anything -> Hard on hard overrun.
if (soft_overrun && self.tx_overrun == SendOverrun::Armed)
|| (hard_overrun && self.tx_overrun != SendOverrun::Hard)
{
// Link that carries the oldest packet which is still in `Sent` state.
if let Some(id) = self.txed_packets.iter().find_map(|p| {
if let SentReliableStatus::Sent { link_id, .. } = &*p.status.borrow() {
Some(*link_id)
} else {
None
}
}) {
let link = self.links[id].as_mut().unwrap();
// Base the reduction on what is actually in flight, capped by the current limit.
let current = link.txed_unacked_data.min(link.txed_unacked_data_limit);
if hard_overrun {
link.txed_unacked_data_limit = current / 2;
self.tx_overrun = SendOverrun::Hard;
} else if soft_overrun {
link.txed_unacked_data_limit = current * 95 / 100;
self.tx_overrun = SendOverrun::Soft;
}
self.tx_overrun_since = Some(Instant::now());
tracing::trace!(link_id =? link.link_id(),
"decreasing unacked limit of link to {} bytes",
link.txed_unacked_data_limit
);
// Despite its name this marker is also set on decreases. While it is set, the
// ping-based decrease and the increase below skip this link; `handle_ack`
// clears it once an ack at or past this sequence number arrives.
link.txed_unacked_data_limit_increased = Some(coming_seq);
link.txed_unacked_data_limit_increased_consecutively = 0;
}
} else if self.tx_overrun != SendOverrun::Armed && !soft_overrun && !hard_overrun {
tracing::trace!("re-arming send overrun handling");
self.tx_overrun = SendOverrun::Armed;
self.tx_overrun_since = None;
}
// Re-arm unconditionally one second after the last overrun reaction.
match self.tx_overrun_since {
Some(since) if since.elapsed() >= Duration::from_secs(1) => {
tracing::trace!("re-arming send overrun handling due to timeout");
self.tx_overrun = SendOverrun::Armed;
self.tx_overrun_since = None
}
_ => (),
}
let all_links_slow;
match self.cfg.link_max_ping {
Some(max_ping) => {
// A link is "slow" if it is unconfirmed, blocked or its round-trip time exceeds
// half the maximum ping. Empty slots count as slow.
all_links_slow = self.links.iter().all(|link_opt| {
link_opt
.as_ref()
.map(|link| {
link.unconfirmed.is_some() || link.is_blocked() || link.roundtrip > max_ping / 2
})
.unwrap_or(true)
});
if !all_links_slow {
// At least one fast link exists: throttle confirmed links close to the maximum ping.
for link_opt in &mut self.links {
match link_opt {
Some(link)
if link.unconfirmed.is_none()
&& link.txed_unacked_data_limit_increased.is_none()
&& link.roundtrip > max_ping * 3 / 4 =>
{
let current = link.txed_unacked_data.min(link.txed_unacked_data_limit);
link.txed_unacked_data_limit = current * 95 / 100;
tracing::trace!(link_id =? link.link_id(),
"decreasing unacked limit of link to {} bytes due to ping",
link.txed_unacked_data_limit
);
link.txed_unacked_data_limit_increased = Some(coming_seq);
link.txed_unacked_data_limit_increased_consecutively = 0;
}
_ => (),
}
}
}
}
None => {
// Without a maximum ping no link is preferred over another.
all_links_slow = true;
}
};
// Data is waiting if the write channel has a pending item or packets are queued for resend.
let send_data_avail = self.write_rx.as_mut().map(|rx| rx.try_peek().is_ok()).unwrap_or_default()
|| !self.resend_queue.is_empty();
// A link can take data if it is not busy sending, confirmed, unblocked and below its limit.
let sendable_link_avail = self.links.iter().any(|link_opt| {
link_opt
.as_ref()
.map(|link| {
!link.tx_pending
&& link.unconfirmed.is_none()
&& !link.is_blocked()
&& link.txed_unacked_data < link.txed_unacked_data_limit
})
.unwrap_or_default()
});
if send_data_avail && !sendable_link_avail {
// Data is stalled only by the limits: raise the limit of each link that has hit its
// limit, has no pending adjustment, is below the configured maximum and is fast
// enough (or all links are slow).
for link_opt in &mut self.links {
match link_opt {
Some(link)
if !link.tx_pending
&& link.unconfirmed.is_none()
&& !link.is_blocked()
&& link.txed_unacked_data >= link.txed_unacked_data_limit
&& link.txed_unacked_data_limit_increased.is_none()
&& link.txed_unacked_data_limit < self.cfg.link_unacked_limit.get()
&& self
.cfg
.link_max_ping
.map(|max_ping| link.roundtrip <= max_ping / 2 || all_links_slow)
.unwrap_or(true) =>
{
// The growth factor increases with the number of consecutive increases;
// the resulting limit is at least 100 bytes.
link.txed_unacked_data_limit =
if link.txed_unacked_data_limit_increased_consecutively >= 100 {
link.txed_unacked_data_limit * 120 / 100
} else if link.txed_unacked_data_limit_increased_consecutively >= 50 {
link.txed_unacked_data_limit * 110 / 100
} else if link.txed_unacked_data_limit_increased_consecutively >= 25 {
link.txed_unacked_data_limit * 105 / 100
} else if link.txed_unacked_data_limit_increased_consecutively >= 10 {
link.txed_unacked_data_limit * 102 / 100
} else {
link.txed_unacked_data_limit * 101 / 100
}
.max(100);
tracing::trace!(link_id =? link.link_id(),
"increasing unacked limit of link to {} bytes (done {} times without overrun)",
link.txed_unacked_data_limit,
link.txed_unacked_data_limit_increased_consecutively
);
link.txed_unacked_data_limit_increased = Some(coming_seq);
link.txed_unacked_data_limit_increased_consecutively =
link.txed_unacked_data_limit_increased_consecutively.saturating_add(1);
}
_ => (),
}
}
}
// Accelerated growth is only kept up while the unconsumable amount stays low.
if !low_level {
for link_opt in self.links.iter_mut() {
if let Some(link) = link_opt.as_mut() {
link.txed_unacked_data_limit_increased_consecutively = 0;
}
}
}
}
/// Returns a future that resolves to the index of the link whose timeout expires first.
///
/// For every present link, `since_fn` yields the instant from which `timeout` is
/// measured, or `None` if the link takes no part. The deadline is evaluated when
/// this function is called; if several links share the earliest deadline, the one
/// with the lowest index wins. If no link takes part, the future never resolves.
fn earliest_link_specific_timeout(
    &self, timeout: Duration, since_fn: impl Fn(&LinkInt<TX, RX, TAG>) -> Option<Instant>,
) -> impl Future<Output = usize> {
    let mut earliest: Option<(usize, Instant)> = None;
    for (id, link_opt) in self.links.iter().enumerate() {
        let Some(since) = link_opt.as_ref().and_then(&since_fn) else { continue };
        let deadline = since + timeout;
        // Replace only on a strictly earlier deadline so that the first link wins ties.
        if earliest.map_or(true, |(_, best)| deadline < best) {
            earliest = Some((id, deadline));
        }
    }

    async move {
        if let Some((link_id, deadline)) = earliest {
            sleep_until(deadline).await;
            link_id
        } else {
            future::pending().await
        }
    }
}
/// Returns the link index and acknowledgement deadline of the oldest packet that is
/// still in `Sent` state, or `None` if there is no such packet.
///
/// The acknowledgement timeout is the round-trip time of the carrying link multiplied
/// by `cfg.link_ack_timeout_roundtrip_factor` (and by three for packets that have
/// already been resent), clamped to `cfg.link_ack_timeout_min..=cfg.link_ack_timeout_max`.
///
/// # Panics
/// Panics if the packet refers to a link slot that is empty.
fn earliest_confirm_timeout(&self) -> Option<(usize, Instant)> {
    self.txed_packets.iter().find_map(|packet| {
        let status = packet.status.borrow();
        let SentReliableStatus::Sent { link_id, sent, resent, .. } = &*status else {
            return None;
        };
        let link = self.links[*link_id].as_ref().unwrap();
        let resend_factor = if *resent { 3 } else { 1 };
        let ack_timeout = (link.roundtrip * self.cfg.link_ack_timeout_roundtrip_factor.get() * resend_factor)
            .clamp(self.cfg.link_ack_timeout_min, self.cfg.link_ack_timeout_max);
        Some((*link_id, *sent + ack_timeout))
    })
}
/// Returns the link index and instant of the next ping that is due, if any.
///
/// Only confirmed links without an outstanding or already requested ping are considered.
/// With `LinkPing::Periodic` a ping is due one interval after the last ping; with
/// `LinkPing::WhenIdle` it is due once the timeout has passed since both the last sent
/// message and the last ping. A missing timestamp counts as "due now". With
/// `LinkPing::WhenTimedOut` no ping is scheduled here.
fn next_link_ping(&self) -> Option<(usize, Instant)> {
    self.links
        .iter()
        .enumerate()
        .filter_map(|(id, link_opt)| {
            let link = link_opt.as_ref()?;
            let ping_possible =
                link.current_ping_sent.is_none() && !link.send_ping && link.unconfirmed.is_none();
            if !ping_possible {
                return None;
            }

            let due = match self.cfg.link_ping {
                LinkPing::Periodic(interval) => {
                    link.last_ping.map_or_else(Instant::now, |last| last + interval)
                }
                LinkPing::WhenIdle(idle_timeout) => {
                    let after_msg = link.tx_last_msg.map_or_else(Instant::now, |last| last + idle_timeout);
                    let after_ping = link.last_ping.map_or_else(Instant::now, |last| last + idle_timeout);
                    after_msg.max(after_ping)
                }
                LinkPing::WhenTimedOut => return None,
            };
            Some((id, due))
        })
        .min_by_key(|&(_, due)| due)
}
/// Sends a new reliable message over the link with index `id` and returns the
/// sequence number assigned to it.
///
/// For data messages the payload length is added to the connection-wide unacknowledged
/// and unconsumed counters and to the link's unacknowledged counter. The message is
/// recorded in `txed_packets` in `Sent` state.
///
/// # Panics
/// Panics if no link is present at index `id`.
fn send_reliable_over_link(&mut self, id: usize, reliable_msg: ReliableMsg) -> Seq {
    let seq = self.next_tx_seq();
    let link = self.links[id].as_mut().unwrap();
    tracing::trace!(link_id =? link.link_id(), "sending reliable message {seq} over link: {reliable_msg:?}");

    let (link_msg, payload) = reliable_msg.to_link_msg(seq);
    link.start_send_msg(link_msg, payload);

    if let ReliableMsg::Data(bytes) = &reliable_msg {
        let len = bytes.len();
        self.txed_unacked += len;
        self.txed_unconsumed += len;
        link.txed_unacked_data += len;
    }

    let status =
        SentReliableStatus::Sent { sent: Instant::now(), link_id: id, msg: reliable_msg, resent: false };
    self.txed_packets.push_back(Arc::new(SentReliable { seq, status: AtomicRefCell::new(status) }));
    seq
}
/// Resends a packet that is queued for resending over the link with index `id`.
///
/// The packet keeps its sequence number. For data messages the payload length is added
/// to the link's unacknowledged counter. If the link has a pending limit-adjustment
/// marker at a later sequence number, the marker is moved back to this packet's
/// sequence number. Finally the packet is put back into `Sent` state, flagged as resent.
///
/// # Panics
/// Panics if no link is present at index `id` or if the packet is not in
/// `ResendQueued` state.
fn resend_reliable_over_link(&mut self, id: usize, packet: Arc<SentReliable>) {
    let link = self.links[id].as_mut().unwrap();
    let mut status = packet.status.borrow_mut();
    let reliable_msg = match &*status {
        SentReliableStatus::ResendQueued { msg } => msg,
        _ => unreachable!("message was not queued for resending"),
    };
    tracing::trace!(link_id =? link.link_id(), "resending reliable message {} over link: {:?}", packet.seq, reliable_msg);

    let (link_msg, payload) = reliable_msg.to_link_msg(packet.seq);
    link.start_send_msg(link_msg, payload);
    if let ReliableMsg::Data(bytes) = reliable_msg {
        link.txed_unacked_data += bytes.len();
    }

    if let Some(last_increased) = link.txed_unacked_data_limit_increased.as_mut() {
        if packet.seq < *last_increased {
            *last_increased = packet.seq;
        }
    }

    *status = SentReliableStatus::Sent {
        sent: Instant::now(),
        link_id: id,
        msg: reliable_msg.clone(),
        resent: true,
    };
}
/// Marks the link with index `id` as unconfirmed for the given `reason` and
/// queues everything that is in flight on it for resending.
///
/// # Panics
/// Panics if no link is present at index `id`.
fn unconfirm_link(&mut self, id: usize, reason: NotWorkingReason) {
let link = self.links[id].as_mut().unwrap();
link.unconfirmed = Some((Instant::now(), reason));
// The link no longer takes part in idle or flush tracking.
self.idle_links.retain(|&idle_id| idle_id != id);
self.unflushed_links.remove(&id);
link.start_flush();
link.reset();
// Move every packet that is in `Sent` state on this link to the resend queue.
for p in &mut self.txed_packets {
let mut status = p.status.borrow_mut();
match &*status {
SentReliableStatus::Sent { link_id, msg, .. } if *link_id == id => {
if let ReliableMsg::Data(data) = &msg {
// The data is no longer in flight on this link.
let old_link = self.links[*link_id].as_mut().unwrap();
old_link.txed_unacked_data -= data.len();
}
*status = SentReliableStatus::ResendQueued { msg: msg.clone() };
self.resend_queue.push_back(p.clone());
}
_ => (),
};
}
// Newly queued packets may be older than packets already in the queue.
self.resend_queue.make_contiguous().sort_by_key(|packet| packet.seq);
// Return every link with a failed test to `Inactive`.
// NOTE(review): presumably so that other links can be retested right away — confirm.
for link in self.links.iter_mut().flatten() {
if let LinkTest::Failed(_) = link.test {
link.test = LinkTest::Inactive;
}
}
}
/// Advances the test state machine of the link with index `id` by one step.
///
/// * `Failed` and the retest interval has elapsed: the test returns to `Inactive`.
/// * `Inactive`: if the link is unconfirmed and has no send in progress, no outstanding
///   ping and no outstanding ack, test data is sent, a ping is requested and the test
///   becomes `InProgress`.
/// * `InProgress`: once no ping is outstanding or requested, the measured round-trip
///   time decides. On success the link becomes confirmed; on failure the test becomes
///   `Failed` and the link is (or stays) unconfirmed with reason `TestFailed`.
///
/// Returns the instant at which a failed test may be retried, or `None` if the test is
/// not in failed state or the link slot is empty.
///
/// # Panics
/// Panics if `id` is out of bounds of the link list.
fn link_testing_step(&mut self, id: usize) -> Option<Instant> {
// True if a maximum ping is configured and every link other than `id` is unconfirmed,
// blocked or slower than the maximum ping. Empty slots count as slow.
let others_slow = match self.cfg.link_max_ping {
Some(max_ping) => self.links.iter().enumerate().all(|(link_id, link_opt)| {
link_opt
.as_ref()
.map(|link| {
link_id == id
|| link.unconfirmed.is_some()
|| link.is_blocked()
|| link.roundtrip > max_ping
})
.unwrap_or(true)
}),
None => false,
};
if let Some(link) = self.links[id].as_mut() {
let link_id = link.link_id();
// A failed test becomes eligible again after the retest interval; the match
// below may then start a new test right away.
match link.test {
LinkTest::Failed(when) if when.elapsed() >= self.cfg.link_retest_interval => {
tracing::trace!("link {id} is ready for retry of test");
link.test = LinkTest::Inactive;
}
_ => (),
}
match link.test {
LinkTest::Inactive => {
// Only unconfirmed links are tested, and only while the link has nothing else going on.
if link.unconfirmed.is_some()
&& link.tx_polling().is_none()
&& link.current_ping_sent.is_none()
&& !link.has_outstanding_ack()
{
// With a maximum ping the initial unacked limit is used as the test size, otherwise
// the smaller of the unacked limit and the send buffer; both are capped by the
// configured test data limit.
let test_data_limit = if self.cfg.link_max_ping.is_some() {
self.cfg.link_unacked_init.get()
} else {
self.cfg.link_unacked_limit.get().min(self.cfg.send_buffer.get() as usize)
}
.min(self.cfg.link_test_data_limit);
let test_data = link.send_test_data(self.cfg.io_write_size.get(), test_data_limit);
// Request a ping after the test data.
link.send_ping = true;
link.test = LinkTest::InProgress;
tracing::debug!(?link_id, "started test of link using {test_data} bytes of test data");
}
None
}
LinkTest::InProgress => {
// The test ping has completed when none is outstanding or requested anymore.
if link.current_ping_sent.is_none() && !link.send_ping {
// Pass criteria: round-trip time at most half the maximum ack timeout and, if a
// maximum ping is configured, within it unless all other links are slow as well.
if link.roundtrip <= self.cfg.link_ack_timeout_max / 2
&& self
.cfg
.link_max_ping
.map(|max_ping| link.roundtrip <= max_ping || others_slow)
.unwrap_or(true)
{
tracing::debug!(
?link_id,
"link successfully completed test with ping {} ms",
link.roundtrip.as_millis()
);
link.unconfirmed = None;
link.test = LinkTest::Inactive;
self.idle_links.retain(|&idle_id| idle_id != id);
link.report_ready();
None
} else {
tracing::debug!(
?link_id,
"link failed test with ping {} ms, retrying in {} s",
link.roundtrip.as_millis(),
self.cfg.link_retest_interval.as_secs()
);
let when = Instant::now();
link.test = LinkTest::Failed(when);
// Keep the original "unconfirmed since" instant if there is one.
match &mut link.unconfirmed {
Some((_since, reason)) => *reason = NotWorkingReason::TestFailed,
None => link.unconfirmed = Some((Instant::now(), NotWorkingReason::TestFailed)),
}
Some(when + self.cfg.link_retest_interval)
}
} else {
None
}
}
LinkTest::Failed(when) => Some(when + self.cfg.link_retest_interval),
}
} else {
None
}
}
/// Hands out the next sequence number for sending and advances the counter by one.
fn next_tx_seq(&mut self) -> Seq {
    let assigned = self.tx_seq;
    self.tx_seq += 1;
    assigned
}
/// Starts flushing the link with index `id` and removes it from the idle links.
///
/// # Panics
/// Panics if no link is present at index `id`.
fn flush_link(&mut self, id: usize) {
    self.links[id].as_mut().unwrap().start_flush();
    self.idle_links.retain(|idle_id| *idle_id != id);
}
fn handle_received_msg(&mut self, id: usize, msg: LinkMsg, data: Option<Bytes>) -> Result<bool, io::Error> {
let link = self.links[id].as_mut().unwrap();
let link_id = link.link_id();
match msg {
LinkMsg::Ping => {
tracing::trace!(?link_id, "ping received, requesting sending resposne");
link.send_pong = true;
self.flush_link(id);
}
LinkMsg::Pong => {
if let Some(current_ping_sent) = link.current_ping_sent.take() {
let elapsed = current_ping_sent.elapsed();
tracing::trace!(?link_id, "ping round-trip time is {} ms", elapsed.as_millis());
link.roundtrip = elapsed;
link.last_ping = Some(Instant::now());
self.link_testing_step(id);
}
}
msg @ (LinkMsg::Data { .. }
| LinkMsg::Consumed { .. }
| LinkMsg::SendFinish { .. }
| LinkMsg::ReceiveClose { .. }
| LinkMsg::ReceiveFinish { .. }) => {
let (reliable_msg, seq) = ReliableMsg::from_link_msg(msg, data);
tracing::trace!(?link_id, "received reliable message {seq}: {reliable_msg:?}");
self.handle_received_reliable_msg(id, seq, reliable_msg)?;
}
LinkMsg::Ack { received } => {
tracing::trace!(?link_id, "link acked reception up to {received}");
self.handle_ack(id, received);
}
LinkMsg::TestData { size } => {
tracing::trace!(?link_id, "link received {size} bytes of test data");
}
LinkMsg::SetBlock { blocked } => {
tracing::debug!(?link_id, %blocked, "remote block status of link changed");
link.remotely_blocked.store(blocked, Ordering::SeqCst);
self.idle_links.retain(|&idle_id| idle_id != id);
link.report_ready();
link.blocked_changed_out_tx.send_replace(());
}
LinkMsg::Goodbye => {
match link.disconnecting {
Some(DisconnectInitiator::Local) => {
if link.goodbye_sent {
tracing::info!(?link_id, "removing link due to local request");
self.remove_link(id, DisconnectReason::LocallyRequested);
}
}
Some(DisconnectInitiator::Remote) => {
return Err(protocol_err!("received Goodbye message more than once"));
}
None => {
tracing::debug!(?link_id, "remote requests disconnection of link");
link.disconnecting = Some(DisconnectInitiator::Remote);
}
}
}
LinkMsg::Terminate => {
tracing::trace!(?link_id, "link recevied forceful connection termination request");
return Ok(true);
}
LinkMsg::Welcome { .. } | LinkMsg::Connect { .. } | LinkMsg::Accepted | LinkMsg::Refused { .. } => {
return Err(protocol_err!("received unexpected message"))
}
}
Ok(false)
}
/// Handles a reliable message with sequence number `seq` received over the link with
/// index `id`.
///
/// Every reception is queued for acknowledgement on the receiving link, including
/// duplicates. New messages are stored in the reordering buffer `rxed_reliable`, and
/// their side effects are applied immediately; afterwards all messages that are
/// available in sequence are removed from the front of the buffer, and data and
/// send-finish messages are handed to the consumable queue.
///
/// # Errors
/// Returns a protocol error if the buffered data exceeds the local receive buffer size
/// or if the remote reports more consumed bytes than are outstanding.
///
/// # Panics
/// Panics if no link is present at index `id`, or if a buffered message does not carry
/// the expected sequence number.
fn handle_received_reliable_msg(&mut self, id: usize, seq: Seq, msg: ReliableMsg) -> Result<(), io::Error> {
let link = self.links[id].as_mut().unwrap();
let link_id = link.link_id();
// Acknowledge every reception and mark the link as having something to send.
link.tx_ack_queue.push_back(seq);
self.idle_links.retain(|&idle_id| idle_id != id);
link.report_ready();
if seq < self.rx_seq {
// Already moved out of the reordering buffer; nothing to do besides the ack.
tracing::trace!(?link_id, "rereceived consumed reliable message {}", seq);
} else {
// Slot in the reordering buffer, whose front corresponds to `rx_seq`.
let offset = (seq - self.rx_seq) as usize;
if self.rxed_reliable.len() <= offset {
self.rxed_reliable.resize(offset + 1, None);
}
if self.rxed_reliable[offset].is_none() {
tracing::trace!(?link_id, "received reliable message {}", seq);
match &msg {
ReliableMsg::Data(data) => {
// The remote must not send more than fits into our receive buffer.
self.rxed_reliable_size += data.len();
if self.rxed_reliable_size > self.cfg.recv_buffer.get() as usize {
return Err(protocol_err!("receive buffer overflow"));
}
}
ReliableMsg::SendFinish => {
}
ReliableMsg::Consumed(consumed) => {
// The remote has consumed data: release the corresponding send budget.
tracing::trace!(?link_id, "remote consumed {consumed} bytes");
match self.txed_unconsumed.checked_sub(*consumed as usize) {
Some(txed_unconsumed) => self.txed_unconsumed = txed_unconsumed,
None => return Err(protocol_err!("txed_unconsumed underflow")),
}
}
ReliableMsg::ReceiveClose => {
// Report `SendError::Closed` to the sender side and force an acknowledgement.
self.write_error_tx.send_replace(SendError::Closed);
self.write_closed.store(true, Ordering::Relaxed);
self.rxed_reliable_consumed_force_ack = true;
}
ReliableMsg::ReceiveFinish => {
// Report `SendError::Dropped`, drop the write channel and mark send-finish as sent.
self.write_error_tx.send_replace(SendError::Dropped);
self.write_rx = None;
self.send_finish_sent = true;
self.rxed_reliable_consumed_force_ack = true;
}
}
self.rxed_reliable[offset] = Some(ReceivedReliableMsg { seq, msg });
} else {
// Duplicate of a message that is still buffered.
tracing::trace!(?link_id, "rereceived unconsumed reliable message {}", seq);
}
}
// Move all messages that are now available in sequence out of the reordering buffer.
while let Some(Some(_)) = self.rxed_reliable.front().as_ref() {
let msg = self.rxed_reliable.pop_front().unwrap().unwrap();
assert_eq!(msg.seq, self.rx_seq);
self.rx_seq += 1;
// Only data and send-finish are passed on; the other messages took effect above.
if matches!(&msg.msg, ReliableMsg::Data(_) | ReliableMsg::SendFinish) {
self.rxed_reliable_consumable.push_back(msg);
}
}
Ok(())
}
/// Whether the amount of consumed received data must be reported now.
///
/// This is the case when more than a tenth of the receive buffer has been consumed
/// since the last report, when the receive buffer is empty but something has been
/// consumed since the last report, or when a report has been forced.
fn is_consume_ack_required(&self) -> bool {
    let consumed = self.rxed_reliable_consumed_since_last_ack;
    let threshold = self.cfg.recv_buffer.get() as usize / 10;
    let drained_with_unreported = self.rxed_reliable_size == 0 && consumed > 0;
    consumed > threshold || drained_with_unreported || self.rxed_reliable_consumed_force_ack
}
/// Processes an acknowledgement for sequence number `rxed_seq` received over the link
/// with index `id`.
///
/// Clears the link's limit-adjustment marker if the ack has reached it, marks the
/// acknowledged packet as received, updates the byte accounting, and removes all
/// packets marked as received from the front of `txed_packets`. Acks for sequence
/// numbers that are not tracked (anymore) only affect the marker.
///
/// # Panics
/// Panics if no link is present at index `id`, or if the packet found at the computed
/// position does not carry sequence number `rxed_seq`.
fn handle_ack(&mut self, id: usize, rxed_seq: Seq) {
let link = self.links[id].as_mut().unwrap();
let link_id = link.link_id();
tracing::trace!(?link_id, "processing received ack for {rxed_seq} on link");
// The limit of this link may be adjusted again once the ack has reached the marker.
match link.txed_unacked_data_limit_increased {
Some(last_increased) if last_increased <= rxed_seq => {
tracing::trace!(?link_id, "re-allowing increase of send limit of link");
link.txed_unacked_data_limit_increased = None;
}
_ => (),
}
// Locate the packet by its distance from the next sequence number to be sent. This
// relies on `txed_packets` holding consecutive sequence numbers ending just before
// `tx_seq`; the assertion below checks it.
let back_idx = self.tx_seq - rxed_seq;
if 0 < back_idx && (back_idx as usize) <= self.txed_packets.len() {
let idx = self.txed_packets.len() - back_idx as usize;
let packet = &mut self.txed_packets[idx];
assert_eq!(packet.seq, rxed_seq);
let mut status = packet.status.borrow_mut();
match &*status {
// Acknowledged on the link it was last sent over.
SentReliableStatus::Sent { sent, link_id, msg, .. } if *link_id == id => {
let size = if let ReliableMsg::Data(data) = &msg { data.len() } else { 0 };
link.txed_unacked_data -= size;
self.txed_unacked -= size;
// Counted as unconsumable until the packet leaves `txed_packets` below.
self.txed_unconsumable += size;
// Smooth the round-trip estimate: the new sample has a weight of 1/100.
link.roundtrip = (99 * link.roundtrip + sent.elapsed()) / 100;
*status = SentReliableStatus::Received { size };
}
// Acknowledged while waiting to be resent: the resend is no longer needed.
SentReliableStatus::ResendQueued { msg } => {
let size = if let ReliableMsg::Data(data) = &msg { data.len() } else { 0 };
self.txed_unacked -= size;
self.txed_unconsumable += size;
self.resend_queue.retain(|packet| packet.seq != rxed_seq);
*status = SentReliableStatus::Received { size };
}
// Already received, or sent over another link: nothing to update.
_ => (),
}
}
// Drop all received packets from the front. `txed_last_consumed` is set to the sequence
// number of every front packet examined, including the one that stops the loop.
while let Some(packet) = self.txed_packets.front() {
self.txed_last_consumed = packet.seq;
let status = packet.status.borrow();
if let SentReliableStatus::Received { size, .. } = &*status {
self.txed_unconsumable -= size;
// Release the borrow of the packet before removing it.
drop(status);
self.txed_packets.pop_front();
} else {
break;
}
}
}
/// Publishes connection statistics on the statistics channel.
///
/// Nothing is published if no statistics interval is configured or if less than the
/// shortest configured interval has passed since the last publication.
fn send_stats(&mut self) {
    let due = match self.cfg.stats_intervals.iter().min() {
        Some(shortest) => self.stats_last_sent.elapsed() >= *shortest,
        None => false,
    };
    if !due {
        return;
    }

    self.stats_last_sent = Instant::now();
    let stats = Stats {
        established: self.established,
        not_working_since: self.links_not_working_since,
        send_space: self.tx_space(),
        sent_unacked: self.txed_unacked,
        sent_unconsumed: self.txed_unconsumed,
        sent_unconsumed_count: self.txed_packets.len(),
        sent_unconsumable: self.txed_unconsumable,
        resend_queue_len: self.resend_queue.len(),
        recved_unconsumed: self.rxed_reliable_size,
        recved_unconsumed_count: self.rxed_reliable.len(),
    };
    self.stats_tx.send_replace(stats);
}
/// Returns the id of the connection handled by this task.
pub fn id(&self) -> ConnId {
self.conn_id.get()
}
/// Returns the direction of the connection handled by this task.
pub fn direction(&self) -> Direction {
self.direction
}
/// Installs the link filter, replacing any previously set filter.
///
/// The filter is called with a link and a list of other links and returns a future
/// resolving to a `bool`. The future is boxed before it is stored.
// NOTE(review): presumably `true` accepts the new link and `others` are the existing
// links of the connection — confirm against the call site.
pub fn set_link_filter<F, Fut>(&mut self, mut link_filter: F)
where
    F: FnMut(Link<TAG>, Vec<Link<TAG>>) -> Fut + Send + 'static,
    Fut: Future<Output = bool> + Send + 'static,
{
    self.link_filter = Box::new(move |link: Link<TAG>, others: Vec<Link<TAG>>| {
        let verdict = link_filter(link, others);
        verdict.boxed()
    });
}
/// Sets the channel over which dumps of the connection state are sent.
///
/// Replaces any previously set channel. Only available with the `dump` feature.
#[cfg(feature = "dump")]
#[cfg_attr(docsrs, doc(cfg(feature = "dump")))]
pub fn dump(&mut self, tx: mpsc::Sender<super::dump::ConnDump>) {
self.dump_tx = Some(tx);
}
/// Sends a dump of the connection state over the dump channel, if one is set.
///
/// The dump is skipped when the channel is full; the channel is dropped when its
/// receiver has been closed. Additionally logs a warning when no sequence number is
/// available for sending, and traces internal state once the read or the write
/// channel is gone.
#[cfg(feature = "dump")]
fn send_dump(&mut self) {
    let receiver_gone = match &self.dump_tx {
        Some(tx) => match tx.try_reserve() {
            Ok(permit) => {
                permit.send(super::dump::ConnDump::from(&*self));
                false
            }
            Err(mpsc::error::TrySendError::Full(_)) => false,
            Err(mpsc::error::TrySendError::Closed(_)) => true,
        },
        None => false,
    };
    if receiver_gone {
        self.dump_tx = None;
    }

    if !self.tx_seq_avail() {
        tracing::warn!("no sequence number available for sending");
    }

    let channel_gone = self.read_tx.is_none() || self.write_rx.is_none();
    if channel_gone {
        tracing::trace!(
            "direction={:?} read_tx_none={} write_tx_none={} txed_packets={} txed_packets_front={:?} \
             resend_queue={} resend_queue_front={:?} txed_unconsumed={} rxed_reliable={} rxed_reliable_front={:?} \
             rxed_reliable_size={} rxed_reliable_consumed_since_last_ack={}, send_finish_sent={} receive_finish_sent={}",
            &self.direction,
            self.read_tx.is_none(),
            self.write_rx.is_none(),
            self.txed_packets.len(),
            self.txed_packets.front(),
            self.resend_queue.len(),
            self.resend_queue.front(),
            self.txed_unconsumed,
            self.rxed_reliable.len(),
            self.rxed_reliable.front(),
            self.rxed_reliable_size,
            self.rxed_reliable_consumed_since_last_ack,
            self.send_finish_sent,
            self.receive_finish_sent,
        );
    }
}
}
/// Allows awaiting a task directly: the task is run to completion.
impl<TX, RX, TAG> IntoFuture for Task<TX, RX, TAG>
where
    RX: Stream<Item = Result<Bytes, io::Error>> + Unpin + Send + Sync + 'static,
    TX: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync + 'static,
    TAG: Send + Sync + 'static,
{
    type Output = Result<(), TaskError>;
    type IntoFuture = BoxFuture<'static, Result<(), TaskError>>;

    /// Converts the task into the boxed future returned by its `run` method.
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.run())
    }
}
/// Captures the current state of a task as a connection dump.
#[cfg(feature = "dump")]
impl<TX, RX, TAG> From<&Task<TX, RX, TAG>> for super::dump::ConnDump {
    /// Builds the dump from the task's counters and the first ten link slots.
    ///
    /// Slots that are empty or beyond the end of the link list are filled with
    /// default link dumps.
    fn from(task: &Task<TX, RX, TAG>) -> Self {
        use super::dump::LinkDump;

        let mut link_dumps: Vec<_> =
            task.links.iter().map(|slot| slot.as_ref().map(LinkDump::from)).collect();
        // Moves the dump out of the given slot, falling back to the default dump.
        let mut link_at = |idx: usize| link_dumps.get_mut(idx).and_then(Option::take).unwrap_or_default();

        Self {
            conn_id: task.conn_id.get().0,
            runtime: task.start_time.elapsed().as_secs_f32(),
            txed_unacked: task.txed_unacked,
            txed_unconsumable: task.txed_unconsumable,
            txed_unconsumed: task.txed_unconsumed,
            send_buffer: task.cfg.send_buffer.get(),
            remote_receive_buffer: task
                .remote_cfg
                .as_ref()
                .map(|remote| remote.recv_buffer.get())
                .unwrap_or_default(),
            resend_queue: task.resend_queue.len(),
            rxed_reliable_size: task.rxed_reliable_size,
            rxed_reliable_consumed_since_last_ack: task.rxed_reliable_consumed_since_last_ack,
            link0: link_at(0),
            link1: link_at(1),
            link2: link_at(2),
            link3: link_at(3),
            link4: link_at(4),
            link5: link_at(5),
            link6: link_at(6),
            link7: link_at(7),
            link8: link_at(8),
            link9: link_at(9),
        }
    }
}