use std::collections::HashMap;
use futures_concurrency::stream::stream_group;
use futures_lite::StreamExt;
use iroh_base::key::NodeId;
use iroh_metrics::inc;
use tokio::{
sync::{mpsc, Notify},
task::JoinHandle,
time::Duration,
};
use tracing::{debug, error, info_span, trace, warn, Instrument};
use crate::{
magicsock::{ConnectionType, ConnectionTypeStream},
metrics::MagicsockMetrics,
};
/// Owner-side handle to the spawned RTT actor task.
#[derive(Debug)]
pub(super) struct RttHandle {
    /// Join handle of the spawned actor task.
    ///
    /// Note: dropping a tokio `JoinHandle` detaches the task; it does not abort it.
    pub(super) _handle: JoinHandle<()>,
    /// Sending half of the channel used to deliver [`RttMessage`]s to the actor.
    pub(super) msg_tx: mpsc::Sender<RttMessage>,
}
impl RttHandle {
    /// Spawns a fresh RTT actor and returns a handle to it.
    ///
    /// The actor starts with no tracked connections and receives [`RttMessage`]s over a
    /// bounded channel with capacity 16. Its event loop runs inside the `rtt-actor`
    /// tracing span.
    ///
    /// Must be called from within a tokio runtime, because it uses [`tokio::spawn`].
    pub(super) fn new() -> Self {
        let (msg_tx, msg_rx) = mpsc::channel(16);

        let mut actor = RttActor {
            connection_events: stream_group::StreamGroup::new().keyed(),
            connections: HashMap::new(),
            tick: Notify::new(),
        };
        let actor_loop = async move { actor.run(msg_rx).await };
        let span = info_span!("rtt-actor");

        Self {
            _handle: tokio::spawn(actor_loop.instrument(span)),
            msg_tx,
        }
    }
}
/// Requests that can be sent to the RTT actor through [`RttHandle::msg_tx`].
#[derive(Debug)]
pub(super) enum RttMessage {
    /// Start tracking a connection.
    NewConnection {
        /// Weak handle to the connection.
        ///
        /// The actor only holds this weak handle; it queries `network_path_changed()`
        /// and `is_alive()` on it.
        connection: quinn::WeakConnectionHandle,
        /// Stream of [`ConnectionType`] changes reported for this connection.
        conn_type_changes: ConnectionTypeStream,
        /// Node ID associated with the connection; used in log output.
        node_id: NodeId,
    },
}
/// State owned by the RTT actor task.
#[derive(Debug)]
struct RttActor {
    /// All registered connection-type-change streams, merged into one keyed stream.
    connection_events: stream_group::Keyed<ConnectionTypeStream>,
    /// Tracked connections, indexed by the key of their stream in `connection_events`.
    ///
    /// Each value is `(weak connection handle, node id, was_direct_before)`. The flag
    /// records whether the `connection_became_direct` metric was already counted for
    /// this connection.
    connections: HashMap<stream_group::Key, (quinn::WeakConnectionHandle, NodeId, bool)>,
    /// Notified after a new connection is registered; it only restarts the event loop.
    tick: Notify,
}
impl RttActor {
    /// Runs the actor's event loop.
    ///
    /// Each iteration waits for the first of these events:
    /// - an [`RttMessage`] from the channel, handled by [`Self::handle_msg`];
    /// - a connection type change from a tracked connection, handled by
    ///   [`Self::do_reset_rtt`] (polled only while at least one stream is registered);
    /// - the 5 second cleanup timer, handled by [`Self::do_connections_cleanup`];
    /// - the `tick` notification, which just restarts the loop.
    ///
    /// Returns once the message channel is closed, i.e. after every sender is dropped.
    async fn run(&mut self, mut msg_rx: mpsc::Receiver<RttMessage>) {
        let mut cleanup_interval = tokio::time::interval(Duration::from_secs(5));
        // If the actor falls behind, skip missed ticks instead of running a burst of
        // catch-up cleanups.
        cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // Match on the `Option` explicitly. A `Some(msg) = ...` pattern would only
                // disable this branch once the channel closes. The timer and `tick` branches
                // are never disabled, so the loop would then never exit.
                msg = msg_rx.recv() => match msg {
                    Some(msg) => self.handle_msg(msg),
                    None => break,
                },
                item = self.connection_events.next(),
                    if !self.connection_events.is_empty() => self.do_reset_rtt(item),
                _ = cleanup_interval.tick() => self.do_connections_cleanup(),
                () = self.tick.notified() => continue,
            }
        }
        debug!("rtt-actor finished");
    }

    /// Dispatches a single message received from the [`RttHandle`].
    fn handle_msg(&mut self, msg: RttMessage) {
        match msg {
            RttMessage::NewConnection {
                connection,
                conn_type_changes,
                node_id,
            } => {
                self.handle_new_connection(connection, conn_type_changes, node_id);
            }
        }
    }

    /// Starts tracking a new connection.
    ///
    /// The type-change stream is added to `connection_events`. The returned key is then
    /// used to store the connection handle and node id in `connections`, with
    /// `was_direct_before = false`. The connection is also counted in the
    /// `connection_handshake_success` metric.
    fn handle_new_connection(
        &mut self,
        connection: quinn::WeakConnectionHandle,
        conn_type_changes: ConnectionTypeStream,
        node_id: NodeId,
    ) {
        let key = self.connection_events.insert(conn_type_changes);
        self.connections.insert(key, (connection, node_id, false));
        // Wake the event loop so it re-enters `select!` with the updated stream group.
        self.tick.notify_one();
        inc!(MagicsockMetrics, connection_handshake_success);
    }

    /// Handles a connection type change reported by one of the tracked streams.
    ///
    /// Calls `network_path_changed()` on the connection's weak handle.
    /// - If it returns `true`, the change is logged. The first time the new type is
    ///   [`ConnectionType::Direct`], `connection_became_direct` is incremented.
    /// - If it returns `false`, the connection is treated as dropped and is removed from
    ///   both `connection_events` and `connections`.
    ///
    /// A `None` item, or a key with no matching entry, is only logged.
    fn do_reset_rtt(&mut self, item: Option<(stream_group::Key, ConnectionType)>) {
        match item {
            Some((key, new_conn_type)) => match self.connections.get_mut(&key) {
                Some((handle, node_id, was_direct_before)) => {
                    if handle.network_path_changed() {
                        debug!(
                            node_id = %node_id.fmt_short(),
                            new_type = ?new_conn_type,
                            "Congestion controller state reset",
                        );
                        if !*was_direct_before && matches!(new_conn_type, ConnectionType::Direct(_))
                        {
                            *was_direct_before = true;
                            inc!(MagicsockMetrics, connection_became_direct);
                        }
                    } else {
                        debug!(
                            node_id = %node_id.fmt_short(),
                            "removing dropped connection",
                        );
                        // Forget the connection in both places. Removing only the stream
                        // would leave its entry in `connections` forever.
                        self.connections.remove(&key);
                        self.connection_events.remove(key);
                    }
                }
                None => error!("No connection found for stream item"),
            },
            None => {
                warn!("self.conn_type_changes is empty but was polled");
            }
        }
    }

    /// Stops tracking every connection whose weak handle reports it is no longer alive.
    ///
    /// Each such connection is removed from `connections`, and its stream is removed
    /// from `connection_events`.
    fn do_connections_cleanup(&mut self) {
        // Borrow the stream group separately so it can be mutated inside `retain`.
        let connection_events = &mut self.connection_events;
        self.connections.retain(|key, (handle, node_id, _)| {
            if handle.is_alive() {
                return true;
            }
            trace!(node_id = %node_id.fmt_short(), "removing stale connection");
            connection_events.remove(*key);
            false
        });
    }
}