use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use tokio::sync::mpsc;
use crate::tracing::{TransferDirection, TransferEvent};
/// Process-wide transport metrics collector, created lazily on first access.
pub static TRANSPORT_METRICS: LazyLock<TransportMetrics> = LazyLock::new(TransportMetrics::new);

/// Lock-free aggregate of transport-level transfer statistics.
///
/// Fields without a `cumulative_` prefix (and other than `per_peer_stats`) are
/// per-reporting-window accumulators that `take_snapshot` drains and resets; the
/// `cumulative_*` totals and the per-peer map are never reset by a snapshot.
#[derive(Debug)]
pub struct TransportMetrics {
    /// Outbound transfers completed in the current window.
    transfers_completed: AtomicU32,
    /// Failed transfers in the current window.
    /// NOTE(review): no method on this type increments it — confirm where failures
    /// are supposed to be recorded.
    transfers_failed: AtomicU32,
    /// Bytes sent by completed outbound transfers in the current window.
    bytes_sent: AtomicU64,
    /// Bytes received by completed inbound transfers in the current window.
    bytes_received: AtomicU64,
    /// Lifetime total of bytes sent; not reset by snapshots.
    cumulative_bytes_sent: AtomicU64,
    /// Lifetime total of bytes received; not reset by snapshots.
    cumulative_bytes_received: AtomicU64,
    /// Sum of elapsed milliseconds of completed outbound transfers in the window.
    total_transfer_time_ms: AtomicU64,
    /// Highest per-transfer average throughput seen in the window.
    peak_throughput_bps: AtomicU64,
    /// Largest per-transfer peak congestion window seen in the window.
    peak_cwnd_bytes: AtomicU32,
    /// Smallest final congestion window seen; `u32::MAX` means "no sample yet".
    min_cwnd_bytes: AtomicU32,
    /// Running sum of final congestion-window samples (for the average).
    cwnd_sum: AtomicU64,
    /// Number of congestion-window samples folded into `cwnd_sum`.
    cwnd_samples: AtomicU32,
    /// Total slowdowns reported by completed transfers in the window.
    slowdowns_triggered: AtomicU32,
    /// Lifetime per-remote-address byte totals, bounded by `MAX_TRACKED_PEERS`.
    per_peer_stats: DashMap<SocketAddr, PeerTransferStats>,
    /// Smallest RTT sample in microseconds; `u64::MAX` means "no sample yet".
    min_rtt_us: AtomicU64,
    /// Largest RTT sample in microseconds.
    max_rtt_us: AtomicU64,
    /// Running sum of RTT samples in microseconds (for the average).
    rtt_sum_us: AtomicU64,
    /// Number of RTT samples folded into `rtt_sum_us`.
    rtt_samples: AtomicU32,
}
/// Maximum number of distinct remote addresses kept in the per-peer statistics map.
const MAX_TRACKED_PEERS: usize = 256;

/// Lifetime byte counters for a single remote peer.
///
/// `Default` is derived: `AtomicU64::default()` starts at zero, which is exactly what
/// the previous hand-written impl produced.
#[derive(Debug, Default)]
pub struct PeerTransferStats {
    /// Bytes sent to this peer by completed outbound transfers.
    pub bytes_sent: AtomicU64,
    /// Bytes received from this peer by completed inbound transfers.
    pub bytes_received: AtomicU64,
}
impl Default for TransportMetrics {
fn default() -> Self {
Self::new()
}
}
impl TransportMetrics {
    /// Creates a collector with every counter zeroed.
    ///
    /// The running minimums (`min_cwnd_bytes`, `min_rtt_us`) start at their type's
    /// maximum. That value is the "no sample yet" sentinel and is replaced by the
    /// first recorded sample.
    pub fn new() -> Self {
        Self {
            transfers_completed: AtomicU32::new(0),
            transfers_failed: AtomicU32::new(0),
            bytes_sent: AtomicU64::new(0),
            bytes_received: AtomicU64::new(0),
            cumulative_bytes_sent: AtomicU64::new(0),
            cumulative_bytes_received: AtomicU64::new(0),
            total_transfer_time_ms: AtomicU64::new(0),
            peak_throughput_bps: AtomicU64::new(0),
            peak_cwnd_bytes: AtomicU32::new(0),
            min_cwnd_bytes: AtomicU32::new(u32::MAX),
            cwnd_sum: AtomicU64::new(0),
            cwnd_samples: AtomicU32::new(0),
            slowdowns_triggered: AtomicU32::new(0),
            min_rtt_us: AtomicU64::new(u64::MAX),
            max_rtt_us: AtomicU64::new(0),
            rtt_sum_us: AtomicU64::new(0),
            rtt_samples: AtomicU32::new(0),
            per_peer_stats: DashMap::new(),
        }
    }

    /// Adds `delta` to `counter`, clamping at `u64::MAX` instead of wrapping.
    fn saturating_add_u64(counter: &AtomicU64, delta: u64) {
        // The closure always returns `Some`, so the update cannot fail.
        counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                Some(v.saturating_add(delta))
            })
            .ok();
    }

    /// Adds `delta` to `counter`, clamping at `u32::MAX` instead of wrapping.
    fn saturating_add_u32(counter: &AtomicU32, delta: u32) {
        // The closure always returns `Some`, so the update cannot fail.
        counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                Some(v.saturating_add(delta))
            })
            .ok();
    }

    /// Narrows a `u128` (from `Duration::as_millis`/`as_micros`) to `u64`.
    /// Out-of-range values saturate to `u64::MAX` instead of being truncated.
    fn saturating_u64(value: u128) -> u64 {
        u64::try_from(value).unwrap_or(u64::MAX)
    }

    /// Records a completed outbound transfer.
    ///
    /// This updates:
    /// - the per-window transfer, byte, time and slowdown counters;
    /// - the lifetime `cumulative_bytes_sent` total;
    /// - the peak throughput and peak congestion-window trackers;
    /// - the congestion-window and RTT sample aggregates;
    /// - the per-peer `bytes_sent` counter for `stats.remote_addr`.
    pub fn record_transfer_completed(&self, stats: &super::TransferStats) {
        Self::saturating_add_u32(&self.transfers_completed, 1);
        Self::saturating_add_u64(&self.bytes_sent, stats.bytes_transferred);
        Self::saturating_add_u64(&self.cumulative_bytes_sent, stats.bytes_transferred);
        Self::saturating_add_u64(
            &self.total_transfer_time_ms,
            Self::saturating_u64(stats.elapsed.as_millis()),
        );
        self.peak_throughput_bps
            .fetch_max(stats.avg_throughput_bps(), Ordering::Relaxed);
        self.record_cwnd_sample(stats.final_cwnd_bytes);
        self.peak_cwnd_bytes
            .fetch_max(stats.peak_cwnd_bytes, Ordering::Relaxed);
        Self::saturating_add_u32(&self.slowdowns_triggered, stats.slowdowns_triggered);
        // A zero base delay is kept out of the RTT statistics (presumably it means
        // "not measured" — confirm with the producer of `TransferStats`).
        let rtt_us = Self::saturating_u64(stats.base_delay.as_micros());
        if rtt_us > 0 {
            self.record_rtt_sample(rtt_us);
        }
        self.record_per_peer(stats.remote_addr, stats.bytes_transferred, |s| {
            &s.bytes_sent
        });
    }

    /// Folds one final congestion-window sample into the sum, count and minimum.
    fn record_cwnd_sample(&self, cwnd_bytes: u32) {
        Self::saturating_add_u64(&self.cwnd_sum, u64::from(cwnd_bytes));
        Self::saturating_add_u32(&self.cwnd_samples, 1);
        self.min_cwnd_bytes.fetch_min(cwnd_bytes, Ordering::Relaxed);
    }

    /// Folds one RTT sample (microseconds) into the sum, count, minimum and maximum.
    fn record_rtt_sample(&self, rtt_us: u64) {
        Self::saturating_add_u64(&self.rtt_sum_us, rtt_us);
        Self::saturating_add_u32(&self.rtt_samples, 1);
        self.min_rtt_us.fetch_min(rtt_us, Ordering::Relaxed);
        self.max_rtt_us.fetch_max(rtt_us, Ordering::Relaxed);
    }

    /// Lifetime total of bytes sent by completed outbound transfers.
    pub fn cumulative_bytes_sent(&self) -> u64 {
        self.cumulative_bytes_sent.load(Ordering::Relaxed)
    }

    /// Records a completed inbound transfer of `bytes` from `remote_addr`.
    ///
    /// This updates the per-window and lifetime received-byte totals and the peer's
    /// `bytes_received` counter. It does not count toward `transfers_completed`.
    pub fn record_inbound_completed(&self, remote_addr: SocketAddr, bytes: u64) {
        // Saturating, for consistency with the outbound path.
        Self::saturating_add_u64(&self.bytes_received, bytes);
        Self::saturating_add_u64(&self.cumulative_bytes_received, bytes);
        self.record_per_peer(remote_addr, bytes, |s| &s.bytes_received);
    }

    /// Lifetime total of bytes received by completed inbound transfers.
    pub fn cumulative_bytes_received(&self) -> u64 {
        self.cumulative_bytes_received.load(Ordering::Relaxed)
    }

    /// Adds `bytes` to the per-peer counter selected by `field` for `addr`.
    ///
    /// A new peer is only inserted while fewer than `MAX_TRACKED_PEERS` are tracked.
    /// Traffic from peers beyond the cap is not attributed per-peer. The length check
    /// and the insert are separate operations, so concurrent first-time peers may
    /// overshoot the cap slightly.
    fn record_per_peer(
        &self,
        addr: SocketAddr,
        bytes: u64,
        field: impl Fn(&PeerTransferStats) -> &AtomicU64,
    ) {
        if let Some(entry) = self.per_peer_stats.get(&addr) {
            Self::saturating_add_u64(field(&entry), bytes);
        } else if self.per_peer_stats.len() < MAX_TRACKED_PEERS {
            let entry = self.per_peer_stats.entry(addr).or_default();
            Self::saturating_add_u64(field(&entry), bytes);
        }
    }

    /// Returns `(address, bytes_sent, bytes_received)` for every tracked peer.
    ///
    /// The result is in map iteration order (not sorted) and does not reset anything.
    pub fn per_peer_snapshot(&self) -> Vec<(SocketAddr, u64, u64)> {
        self.per_peer_stats
            .iter()
            .map(|entry| {
                let stats = entry.value();
                (
                    *entry.key(),
                    stats.bytes_sent.load(Ordering::Relaxed),
                    stats.bytes_received.load(Ordering::Relaxed),
                )
            })
            .collect()
    }

    /// Drains the per-window accumulators into a [`TransportSnapshot`].
    ///
    /// Returns `None` when nothing was recorded since the previous call: no completed
    /// or failed transfers, and no bytes sent or received. Either way, every
    /// per-window field is reset. Lifetime counters (`cumulative_*` and per-peer
    /// stats) are left untouched.
    pub fn take_snapshot(&self) -> Option<TransportSnapshot> {
        // Swap every accumulator out *before* deciding whether the window is empty.
        // Resetting with plain `store`s after an emptiness check would silently
        // discard anything recorded between the check and the reset.
        let transfers_completed = self.transfers_completed.swap(0, Ordering::Relaxed);
        let transfers_failed = self.transfers_failed.swap(0, Ordering::Relaxed);
        let bytes_sent = self.bytes_sent.swap(0, Ordering::Relaxed);
        let bytes_received = self.bytes_received.swap(0, Ordering::Relaxed);
        let total_transfer_time_ms = self.total_transfer_time_ms.swap(0, Ordering::Relaxed);
        let peak_throughput_bps = self.peak_throughput_bps.swap(0, Ordering::Relaxed);
        let peak_cwnd_bytes = self.peak_cwnd_bytes.swap(0, Ordering::Relaxed);
        let min_cwnd_bytes = self.min_cwnd_bytes.swap(u32::MAX, Ordering::Relaxed);
        let cwnd_sum = self.cwnd_sum.swap(0, Ordering::Relaxed);
        let cwnd_samples = self.cwnd_samples.swap(0, Ordering::Relaxed);
        let slowdowns_triggered = self.slowdowns_triggered.swap(0, Ordering::Relaxed);
        let min_rtt_us = self.min_rtt_us.swap(u64::MAX, Ordering::Relaxed);
        let max_rtt_us = self.max_rtt_us.swap(0, Ordering::Relaxed);
        let rtt_sum_us = self.rtt_sum_us.swap(0, Ordering::Relaxed);
        let rtt_samples = self.rtt_samples.swap(0, Ordering::Relaxed);

        // Inbound transfers only bump `bytes_received`, not `transfers_completed`.
        // A receive-only window must therefore still be reported, or its received
        // bytes would be dropped.
        let window_is_empty = transfers_completed == 0
            && transfers_failed == 0
            && bytes_sent == 0
            && bytes_received == 0;
        if window_is_empty {
            return None;
        }

        Some(TransportSnapshot {
            transfers_completed,
            transfers_failed,
            bytes_sent,
            bytes_received,
            // `checked_div` yields `None` for a zero sample count, reported as 0.
            avg_transfer_time_ms: total_transfer_time_ms
                .checked_div(u64::from(transfers_completed))
                .unwrap_or(0),
            peak_throughput_bps,
            avg_cwnd_bytes: cwnd_sum
                .checked_div(u64::from(cwnd_samples))
                .map_or(0, |avg| u32::try_from(avg).unwrap_or(u32::MAX)),
            peak_cwnd_bytes,
            // The sentinel means no sample was recorded this window.
            min_cwnd_bytes: if min_cwnd_bytes == u32::MAX {
                0
            } else {
                min_cwnd_bytes
            },
            slowdowns_triggered,
            avg_rtt_us: rtt_sum_us
                .checked_div(u64::from(rtt_samples))
                .unwrap_or(0),
            min_rtt_us: if min_rtt_us == u64::MAX { 0 } else { min_rtt_us },
            max_rtt_us,
        })
    }
}
/// Per-window transport statistics drained from `TransportMetrics`.
///
/// Averages are computed when the snapshot is taken. The minimum fields are reported
/// as `0` when no sample was recorded in the window.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(arbitrary::Arbitrary))]
pub struct TransportSnapshot {
    /// Outbound transfers completed during the window.
    pub transfers_completed: u32,
    /// Transfers counted as failed during the window.
    pub transfers_failed: u32,
    /// Bytes sent by completed outbound transfers.
    pub bytes_sent: u64,
    /// Bytes received by completed inbound transfers.
    pub bytes_received: u64,
    /// Mean elapsed time per completed transfer, in milliseconds.
    pub avg_transfer_time_ms: u64,
    /// Highest per-transfer average throughput observed.
    pub peak_throughput_bps: u64,
    /// Mean final congestion window across samples, in bytes.
    pub avg_cwnd_bytes: u32,
    /// Largest peak congestion window observed, in bytes.
    pub peak_cwnd_bytes: u32,
    /// Smallest final congestion window observed, in bytes (`0` if none).
    pub min_cwnd_bytes: u32,
    /// Total slowdowns reported by completed transfers.
    pub slowdowns_triggered: u32,
    /// Mean RTT sample, in microseconds.
    pub avg_rtt_us: u64,
    /// Smallest RTT sample, in microseconds (`0` if none).
    pub min_rtt_us: u64,
    /// Largest RTT sample, in microseconds.
    pub max_rtt_us: u64,
}
/// Bounded capacity of the transfer-event channel.
///
/// The `emit_transfer_*` functions use `try_send`, so events emitted while the
/// channel is full are dropped.
const TRANSFER_EVENT_CHANNEL_CAPACITY: usize = 1000;

/// Sender half of the transfer-event channel.
///
/// It stays `None` until [`init_transfer_event_channel`] installs one.
static TRANSFER_EVENT_SENDER: LazyLock<parking_lot::RwLock<Option<mpsc::Sender<TransferEvent>>>> =
    LazyLock::new(|| parking_lot::RwLock::new(None));

/// Creates the transfer-event channel, installs its sender globally and returns the
/// receiver.
///
/// Calling this again replaces (and drops) the previously installed sender, so an
/// earlier receiver stops getting new events.
pub fn init_transfer_event_channel() -> mpsc::Receiver<TransferEvent> {
    let (event_tx, event_rx) = mpsc::channel(TRANSFER_EVENT_CHANNEL_CAPACITY);
    let mut slot = TRANSFER_EVENT_SENDER.write();
    *slot = Some(event_tx);
    drop(slot);
    event_rx
}
/// Emits a `TransferEvent::Started` if the event channel has been initialised.
///
/// This is best-effort. It does nothing before `init_transfer_event_channel` has
/// run, and the event is dropped if the channel is full or closed.
pub fn emit_transfer_started(
    stream_id: u64,
    peer_addr: SocketAddr,
    expected_bytes: u64,
    direction: TransferDirection,
) {
    let guard = TRANSFER_EVENT_SENDER.read();
    let Some(sender) = guard.as_ref() else {
        return;
    };
    let event = TransferEvent::Started {
        stream_id,
        peer_addr,
        expected_bytes,
        direction,
        tx_id: None,
        timestamp: crate::tracing::telemetry::current_timestamp_ms(),
    };
    // Deliberately ignored: telemetry must never block or fail the transfer path.
    #[allow(clippy::let_underscore_must_use)]
    let _ = sender.try_send(event);
}
/// Emits a `TransferEvent::Completed` if the event channel has been initialised.
///
/// This is best-effort. It does nothing before `init_transfer_event_channel` has
/// run, and the event is dropped if the channel is full or closed. The optional
/// congestion-control fields are forwarded unchanged.
// One argument per event field keeps call sites explicit, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn emit_transfer_completed(
    stream_id: u64,
    peer_addr: SocketAddr,
    bytes_transferred: u64,
    elapsed_ms: u64,
    avg_throughput_bps: u64,
    peak_cwnd_bytes: Option<u32>,
    final_cwnd_bytes: Option<u32>,
    slowdowns_triggered: Option<u32>,
    final_srtt_ms: Option<u32>,
    final_ssthresh_bytes: Option<u32>,
    min_ssthresh_floor_bytes: Option<u32>,
    total_timeouts: Option<u32>,
    direction: TransferDirection,
) {
    let guard = TRANSFER_EVENT_SENDER.read();
    let Some(sender) = guard.as_ref() else {
        return;
    };
    let event = TransferEvent::Completed {
        stream_id,
        peer_addr,
        bytes_transferred,
        elapsed_ms,
        avg_throughput_bps,
        peak_cwnd_bytes,
        final_cwnd_bytes,
        slowdowns_triggered,
        final_srtt_ms,
        final_ssthresh_bytes,
        min_ssthresh_floor_bytes,
        total_timeouts,
        direction,
        timestamp: crate::tracing::telemetry::current_timestamp_ms(),
    };
    // Deliberately ignored: telemetry must never block or fail the transfer path.
    #[allow(clippy::let_underscore_must_use)]
    let _ = sender.try_send(event);
}
/// Emits a `TransferEvent::Failed` if the event channel has been initialised.
///
/// This is best-effort. It does nothing before `init_transfer_event_channel` has
/// run, and the event is dropped if the channel is full or closed.
pub fn emit_transfer_failed(
    stream_id: u64,
    peer_addr: SocketAddr,
    bytes_transferred: u64,
    reason: String,
    elapsed_ms: u64,
    direction: TransferDirection,
) {
    let guard = TRANSFER_EVENT_SENDER.read();
    let Some(sender) = guard.as_ref() else {
        return;
    };
    let event = TransferEvent::Failed {
        stream_id,
        peer_addr,
        bytes_transferred,
        reason,
        elapsed_ms,
        direction,
        timestamp: crate::tracing::telemetry::current_timestamp_ms(),
    };
    // Deliberately ignored: telemetry must never block or fail the transfer path.
    #[allow(clippy::let_underscore_must_use)]
    let _ = sender.try_send(event);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Baseline outbound transfer shared by the tests.
    ///
    /// Individual tests override only the fields they care about, using
    /// struct-update syntax.
    fn baseline_stats() -> crate::transport::TransferStats {
        crate::transport::TransferStats {
            stream_id: 1,
            remote_addr: "127.0.0.1:8080".parse().unwrap(),
            bytes_transferred: 1024,
            elapsed: Duration::from_millis(100),
            peak_cwnd_bytes: 50000,
            final_cwnd_bytes: 40000,
            slowdowns_triggered: 1,
            base_delay: Duration::from_millis(10),
            final_ssthresh_bytes: 100000,
            min_ssthresh_floor_bytes: 5696,
            total_timeouts: 0,
            final_flightsize: 0,
            configured_rate: 0,
        }
    }

    #[test]
    fn test_empty_snapshot_returns_none() {
        assert!(TransportMetrics::new().take_snapshot().is_none());
    }

    #[test]
    fn test_snapshot_after_transfer() {
        let metrics = TransportMetrics::new();
        metrics.record_transfer_completed(&baseline_stats());

        let snapshot = metrics.take_snapshot().expect("should have snapshot");
        assert_eq!(snapshot.transfers_completed, 1);
        assert_eq!(snapshot.bytes_sent, 1024);
        assert_eq!(snapshot.peak_cwnd_bytes, 50000);
        assert_eq!(snapshot.slowdowns_triggered, 1);
    }

    #[test]
    fn test_snapshot_resets_counters() {
        let metrics = TransportMetrics::new();
        metrics.record_transfer_completed(&baseline_stats());
        let _ = metrics.take_snapshot();
        assert!(metrics.take_snapshot().is_none());
    }

    #[test]
    fn test_multiple_transfers_aggregate() {
        let metrics = TransportMetrics::new();
        for i in 0..5 {
            let stats = crate::transport::TransferStats {
                stream_id: i,
                bytes_transferred: 1000,
                peak_cwnd_bytes: 40000 + (i as u32 * 1000),
                final_cwnd_bytes: 35000,
                ..baseline_stats()
            };
            metrics.record_transfer_completed(&stats);
        }

        let snapshot = metrics.take_snapshot().expect("should have snapshot");
        assert_eq!(snapshot.transfers_completed, 5);
        assert_eq!(snapshot.bytes_sent, 5000);
        // The peak is the largest of the per-transfer peaks (the last, i == 4).
        assert_eq!(snapshot.peak_cwnd_bytes, 44000);
        assert_eq!(snapshot.slowdowns_triggered, 5);
    }

    #[test]
    fn test_inbound_completed_tracking() {
        let metrics = TransportMetrics::new();
        let addr: std::net::SocketAddr = "10.0.0.1:5000".parse().unwrap();
        for chunk in [2048, 1024] {
            metrics.record_inbound_completed(addr, chunk);
        }
        assert_eq!(metrics.cumulative_bytes_received(), 3072);

        let peers = metrics.per_peer_snapshot();
        let peer = peers.iter().find(|(a, _, _)| *a == addr);
        assert!(peer.is_some(), "peer should be tracked");
        let &(_, sent, recv) = peer.unwrap();
        assert_eq!(sent, 0);
        assert_eq!(recv, 3072);

        let outbound = crate::transport::TransferStats {
            remote_addr: addr,
            bytes_transferred: 100,
            elapsed: Duration::from_millis(10),
            peak_cwnd_bytes: 1000,
            final_cwnd_bytes: 1000,
            slowdowns_triggered: 0,
            base_delay: Duration::from_millis(5),
            ..baseline_stats()
        };
        metrics.record_transfer_completed(&outbound);

        let snapshot = metrics.take_snapshot().unwrap();
        assert_eq!(snapshot.bytes_received, 3072);
        // Taking a snapshot must not touch the lifetime total.
        assert_eq!(metrics.cumulative_bytes_received(), 3072);
    }

    #[test]
    fn test_per_peer_capacity_bound() {
        let metrics = TransportMetrics::new();
        (0..MAX_TRACKED_PEERS)
            .map(|i| {
                format!("10.0.{}.{}:{}", i / 256, i % 256, 5000 + i)
                    .parse::<std::net::SocketAddr>()
                    .unwrap()
            })
            .for_each(|addr| metrics.record_inbound_completed(addr, 100));
        assert_eq!(metrics.per_peer_snapshot().len(), MAX_TRACKED_PEERS);

        let extra: std::net::SocketAddr = "192.168.1.1:9999".parse().unwrap();
        metrics.record_inbound_completed(extra, 500);

        let snapshot = metrics.per_peer_snapshot();
        assert_eq!(snapshot.len(), MAX_TRACKED_PEERS);
        assert!(
            snapshot.iter().all(|(a, _, _)| *a != extra),
            "257th peer should not be tracked"
        );
        // Untracked peers still count toward the aggregate total.
        let expected_total = MAX_TRACKED_PEERS as u64 * 100 + 500;
        assert_eq!(metrics.cumulative_bytes_received(), expected_total);
    }
}