use std::cell::Cell;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::time::Duration;
use parking_lot::RwLock;
use crate::config::GlobalExecutor;
use crate::simulation::{RealTime, TimeSource, TimeSourceInterval};
use crate::transport::connection_handler::NAT_TRAVERSAL_MAX_ATTEMPTS;
use crate::transport::crypto::TransportSecretKey;
use crate::transport::packet_data::UnknownEncryption;
use crate::transport::sent_packet_tracker::MESSAGE_CONFIRMATION_TIMEOUT;
use aes_gcm::Aes128Gcm;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::{Instrument, instrument, span};
mod inbound_stream;
mod outbound_stream;
pub(crate) mod piped_stream;
pub(crate) mod streaming;
#[cfg(feature = "bench")]
pub mod streaming_buffer;
#[cfg(not(feature = "bench"))]
pub(crate) mod streaming_buffer;
use super::{
TransportError,
bbr::DeliveryRateToken,
congestion_control::{CongestionControl, CongestionController},
connection_handler::SerializedMessage,
global_bandwidth::GlobalBandwidthManager,
packet_data::{self, PacketData},
received_packet_tracker::ReceivedPacketTracker,
received_packet_tracker::ReportResult,
sent_packet_tracker::{ResendAction, SentPacketTracker},
symmetric_message::{self, SymmetricMessage, SymmetricMessagePayload},
token_bucket::TokenBucket,
};
use crate::operations::orphan_streams::OrphanStreamRegistry;
use crate::util::time_source::InstantTimeSrc;
type Result<T = (), E = TransportError> = std::result::Result<T, E>;
type OutboundStreamResult = std::result::Result<super::TransferStats, TransportError>;
const MAX_DATA_SIZE: usize = packet_data::MAX_DATA_SIZE - 41;
const ACK_CHECK_INTERVAL: Duration = Duration::from_millis(100);
const SIMULATION_ACK_CHECK_INTERVAL: Duration = Duration::from_millis(500);
const SIMULATION_RESEND_CHECK_INTERVAL: Duration = Duration::from_millis(50);
const SIMULATION_RESEND_YIELD_DELAY: Duration = Duration::from_millis(10);
/// Per-peer transport state for one encrypted connection: symmetric keys, packet
/// bookkeeping, congestion control and rate limiting.
///
/// Wrapped by `PeerConnection`, which drives it from `recv`/`send`.
#[must_use]
pub(crate) struct RemoteConnection<S = super::UdpSocket, T: TimeSource = RealTime> {
/// Key used when serializing packets sent to the peer (keep-alive pings, ACKs).
pub(super) outbound_symmetric_key: Aes128Gcm,
/// Address of the remote peer; every outbound packet is sent here.
pub(super) remote_addr: SocketAddr,
/// Tracks sent-but-unacknowledged packets; shared behind a mutex and consulted for resends.
pub(super) sent_tracker: Arc<parking_lot::Mutex<SentPacketTracker<T>>>,
/// Packet-id counter shared (via `Arc`) with the keep-alive task.
pub(super) last_packet_id: Arc<AtomicU32>,
/// Channel delivering raw, still-encrypted packets from the peer.
pub(super) inbound_packet_recv: mpsc::Receiver<PacketData<UnknownEncryption>>,
/// Key used to decrypt inbound packets.
pub(super) inbound_symmetric_key: Aes128Gcm,
/// Raw bytes of the inbound key; echoed in `ack_ok` packets and logged on decrypt failures.
pub(super) inbound_symmetric_key_bytes: [u8; 16],
// Only read through `PeerConnection::my_address`, which is itself marked dead code.
#[allow(dead_code)]
pub(super) my_address: Option<SocketAddr>,
/// Asymmetric key used to decrypt intro packets that fail symmetric decryption.
pub(super) transport_secret_key: TransportSecretKey,
/// Congestion controller fed with ACK/RTT samples; drives the token-bucket rate.
pub(super) congestion_controller: Arc<CongestionController<T>>,
/// Rate limiter whose rate is refreshed periodically from the congestion controller.
pub(super) token_bucket: Arc<TokenBucket<T>>,
/// Socket used for all sends to `remote_addr`.
pub(super) socket: Arc<S>,
/// Optional shared bandwidth pool; caps the per-connection rate and is
/// unregistered when this value is dropped.
pub(super) global_bandwidth: Option<Arc<GlobalBandwidthManager>>,
/// Rolling RTT statistics updated from every ACK that carries an RTT sample.
pub(crate) rolling_rtt_stats: super::rolling_rtt_stats::RollingRttStatsHandle<T>,
/// Clock used for timers and timestamps (real or simulated).
pub(super) time_source: T,
}
impl<S, T: TimeSource> Drop for RemoteConnection<S, T> {
    /// Releases this connection's slot in the global bandwidth pool, if it joined one.
    fn drop(&mut self) {
        let Some(pool) = self.global_bandwidth.as_ref() else {
            // Not part of a shared pool: nothing to release.
            return;
        };
        pool.unregister_connection();
        // `connection_count()` stays inside the macro so it is only evaluated when
        // debug logging is enabled, exactly as before.
        tracing::debug!(
            peer_addr = %self.remote_addr,
            remaining_connections = pool.connection_count(),
            "Unregistered connection from global bandwidth pool"
        );
    }
}
/// Identifier of a stream on a peer connection.
///
/// Represented and serialized as its bare `u32` (`repr(transparent)` /
/// `serde(transparent)`). The high bit (`StreamId::OPS_STREAM_BIT`) marks IDs
/// produced by `StreamId::next_operations`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct StreamId(u32);
/// Size of the stream-ID range reserved for each thread index.
const STREAM_ID_BLOCK: u32 = 100_000;
thread_local! {
// Per-thread stream-ID counter, starting at `thread_index * STREAM_ID_BLOCK`.
// NOTE(review): presumably this keeps IDs from different threads (e.g. simulated
// nodes) disjoint and reproducible; ranges overlap once a thread allocates more
// than STREAM_ID_BLOCK IDs — confirm that bound is acceptable.
static STREAM_ID_COUNTER: Cell<u32> = {
let idx = crate::config::GlobalRng::thread_index();
Cell::new((idx as u32) * STREAM_ID_BLOCK)
};
}
impl StreamId {
    /// High bit that tags a stream ID as belonging to the operations layer.
    const OPS_STREAM_BIT: u32 = 0x8000_0000;

    /// Returns the current value of this thread's counter and advances it by one.
    fn bump_counter() -> u32 {
        STREAM_ID_COUNTER.with(|counter| {
            let current = counter.get();
            counter.set(current + 1);
            current
        })
    }

    /// Allocates the next stream ID from this thread's counter.
    pub fn next() -> Self {
        Self(Self::bump_counter())
    }

    /// Allocates the next stream ID and tags it with the operations-stream bit.
    pub fn next_operations() -> Self {
        let raw = Self::bump_counter();
        Self(raw | Self::OPS_STREAM_BIT)
    }

    /// Whether this ID carries the operations-stream bit.
    pub fn is_operations_stream(&self) -> bool {
        (self.0 & Self::OPS_STREAM_BIT) != 0
    }

    /// Rewinds this thread's counter to the start of its `STREAM_ID_BLOCK` range.
    pub fn reset_counter() {
        let base = (crate::config::GlobalRng::thread_index() as u32) * STREAM_ID_BLOCK;
        STREAM_ID_COUNTER.with(|counter| counter.set(base));
    }
}
impl std::fmt::Display for StreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[cfg(test)]
impl<'a> arbitrary::Arbitrary<'a> for StreamId {
    /// Builds a `StreamId` from an arbitrary `u32` (operations bit included at random).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.arbitrary().map(Self)
    }
}
/// Outcome of an inbound-stream task: the stream ID plus its reassembled message on
/// success, or just the stream ID on failure.
type InboundStreamResult = Result<(StreamId, SerializedMessage), StreamId>;
/// A live, encrypted connection to one remote peer.
///
/// `recv` must be awaited continuously: besides returning inbound messages it
/// services ACKs, resends, rate updates, keep-alive bookkeeping and idle timeouts.
#[must_use = "call await on the `recv` function to start listening for incoming messages"]
pub struct PeerConnection<S = super::UdpSocket, T: TimeSource = RealTime> {
/// Underlying keys, trackers, socket and congestion state.
remote_conn: RemoteConnection<S, T>,
/// Records received packet ids and accumulates receipts to send back.
/// NOTE(review): uses `InstantTimeSrc` rather than `T` — confirm intended in simulation.
received_tracker: ReceivedPacketTracker<InstantTimeSrc>,
/// Per-stream senders feeding `(fragment number, bytes)` to inbound stream tasks.
inbound_streams: HashMap<StreamId, mpsc::Sender<(u32, bytes::Bytes)>>,
/// Inbound stream tasks, polled by `recv`; completed messages are returned from it.
inbound_stream_futures: FuturesUnordered<JoinHandle<InboundStreamResult>>,
/// Outbound stream tasks, polled by `recv` to collect transfer stats or errors.
outbound_stream_futures: FuturesUnordered<JoinHandle<OutboundStreamResult>>,
/// Decryption failures counted within the current failure window.
failure_count: usize,
/// Start of the current decryption-failure window, if any failure occurred.
first_failure_time_nanos: Option<u64>,
/// Last time receipts were force-flushed after MESSAGE_CONFIRMATION_TIMEOUT.
last_packet_report_time_nanos: u64,
/// Last time the token-bucket rate was updated (`None` until the first update).
last_rate_update_nanos: Option<u64>,
/// Time the most recent packet was received; drives the idle timeout.
last_received_nanos: u64,
/// Keep-alive task handle; aborted when the connection is dropped.
keep_alive_handle: Option<JoinHandle<()>>,
/// Outstanding ping sequence numbers mapped to their send time (nanos).
pending_pings: Arc<RwLock<BTreeMap<u64, u64>>>,
/// Registry of inbound streaming handles, shareable with other components.
streaming_registry: Arc<streaming::StreamRegistry>,
/// Streaming handles with their last-activity time (nanos), swept when idle.
streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)>,
/// Clock for timestamps and timers.
time_source: T,
/// Optional registry where unclaimed streams are handed to the operations layer.
orphan_stream_registry:
Option<std::sync::Arc<crate::operations::orphan_streams::OrphanStreamRegistry>>,
/// Bounded LRU of hashes of dispatched short messages, used for duplicate detection.
dispatched_msg_hashes: lru::LruCache<u64, ()>,
}
impl<S, T: TimeSource> std::fmt::Debug for PeerConnection<S, T> {
    /// Shows only the remote address; the rest of the state is too large or opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let peer = &self.remote_conn.remote_addr;
        let mut builder = f.debug_struct("PeerConnection");
        builder.field("remote_conn", peer);
        builder.finish()
    }
}
impl<S, T: TimeSource> Drop for PeerConnection<S, T> {
    /// Cancels every live streaming handle, then aborts the keep-alive task.
    fn drop(&mut self) {
        let peer_addr = self.remote_conn.remote_addr;
        if !self.streaming_handles.is_empty() {
            let stream_count = self.streaming_handles.len();
            tracing::debug!(
                peer_addr = %peer_addr,
                stream_count,
                "Cancelling streaming handles on connection drop"
            );
            self.streaming_handles
                .values()
                .for_each(|(stream_handle, _)| {
                    stream_handle.cancel();
                });
        }
        if let Some(keep_alive) = self.keep_alive_handle.take() {
            tracing::debug!(
                peer_addr = %peer_addr,
                "Cancelling keep-alive task"
            );
            keep_alive.abort();
        }
    }
}
/// Base delay between keep-alive pings.
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(5);
/// Upper bound on the backed-off keep-alive delay.
const MAX_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60);
/// Keep-alive delay given how many pings are still unanswered.
///
/// Up to `MAX_UNANSWERED_PINGS` outstanding pings keep the base interval. Each extra
/// unanswered ping doubles it (at most four doublings), capped at
/// `MAX_KEEPALIVE_INTERVAL`.
fn keepalive_interval_for_pending(pending_count: usize) -> Duration {
    let Some(excess) = pending_count
        .checked_sub(MAX_UNANSWERED_PINGS)
        .filter(|&over| over > 0)
    else {
        return KEEP_ALIVE_INTERVAL;
    };
    let doublings = excess.min(4) as u32;
    let factor = 1u32 << doublings;
    KEEP_ALIVE_INTERVAL
        .saturating_mul(factor)
        .min(MAX_KEEPALIVE_INTERVAL)
}
/// Unanswered pings tolerated before the keep-alive interval starts backing off.
const MAX_UNANSWERED_PINGS: usize = 5;
/// Number of recent message hashes kept for duplicate-dispatch detection.
const DEDUP_CACHE_CAPACITY: NonZeroUsize = {
    let Some(capacity) = NonZeroUsize::new(1000) else {
        panic!("DEDUP_CACHE_CAPACITY must be non-zero");
    };
    capacity
};
/// Production idle limit for streaming handles: six stream-inactivity timeouts
/// (whole seconds only).
const STREAMING_HANDLE_IDLE_TIMEOUT_PROD: Duration =
    Duration::from_secs(streaming::STREAM_INACTIVITY_TIMEOUT.as_secs() * 6);
/// Idle limit used when simulation idle timeouts are enabled: one day.
const STREAMING_HANDLE_IDLE_TIMEOUT_SIM: Duration = Duration::from_secs(86_400);
/// Picks the streaming-handle idle limit for the current run mode.
fn streaming_handle_idle_timeout() -> Duration {
    match crate::config::SimulationIdleTimeout::is_enabled() {
        true => STREAMING_HANDLE_IDLE_TIMEOUT_SIM,
        false => STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
    }
}
/// Cancels and forgets every streaming handle idle for longer than `threshold`.
///
/// A handle's idle time is `now_nanos - last_activity_nanos` (saturating at zero).
/// Swept handles are cancelled and also removed from `streaming_registry`. Emits an
/// info log with the number swept when at least one handle was removed.
///
/// # Parameters
/// - `streaming_handles`: handles keyed by stream, each with its last-activity time (nanos).
/// - `streaming_registry`: registry the swept stream IDs are removed from.
/// - `now_nanos`: current time from the connection's time source.
/// - `threshold`: maximum allowed idle time; thresholds beyond `u64::MAX` nanoseconds
///   saturate (effectively "never sweep") instead of wrapping.
/// - `peer_addr`: used only for logging.
fn sweep_streaming_handles_inner(
    streaming_handles: &mut HashMap<StreamId, (streaming::StreamHandle, u64)>,
    streaming_registry: &streaming::StreamRegistry,
    now_nanos: u64,
    threshold: Duration,
    peer_addr: SocketAddr,
) {
    // `as_nanos()` is a u128; an `as u64` cast would keep only the low 64 bits and could
    // turn a huge threshold into a tiny one, sweeping live streams. Saturate instead.
    let threshold_nanos = u64::try_from(threshold.as_nanos()).unwrap_or(u64::MAX);
    let mut swept_count: usize = 0;
    streaming_handles.retain(|&stream_id, (handle, last_activity_nanos)| {
        let idle_nanos = now_nanos.saturating_sub(*last_activity_nanos);
        if idle_nanos <= threshold_nanos {
            return true;
        }
        tracing::debug!(
            %peer_addr,
            stream_id = %stream_id,
            idle_secs = Duration::from_nanos(idle_nanos).as_secs_f64(),
            "Sweeping idle streaming handle (issue #4079)"
        );
        handle.cancel();
        streaming_registry.remove(stream_id);
        swept_count += 1;
        false
    });
    if swept_count > 0 {
        tracing::info!(
            %peer_addr,
            swept_count,
            remaining = streaming_handles.len(),
            "Swept idle streaming handles"
        );
    }
}
#[allow(private_bounds)]
impl<S: super::Socket, T: TimeSource> PeerConnection<S, T> {
/// Wraps an established `RemoteConnection` into a `PeerConnection` and spawns its
/// keep-alive task.
///
/// The task (spawned via `GlobalExecutor::spawn`) returns immediately when the time
/// source reports `supports_keepalive() == false`. Otherwise it loops: sleep for
/// `keepalive_interval_for_pending(<unanswered ping count>)`, send a `Ping` with the
/// next sequence number, and record its send time in the shared `pending_pings`
/// map, from which `recv` removes answered and stale entries. The loop exits if the
/// `Ping` cannot be serialized or the socket send fails. The task's `JoinHandle` is
/// kept in `keep_alive_handle`, which the `Drop` impl aborts.
pub(super) fn new(remote_conn: RemoteConnection<S, T>) -> Self {
let remote_addr = remote_conn.remote_addr;
let socket = remote_conn.socket.clone();
let outbound_key = remote_conn.outbound_symmetric_key.clone();
// Shares the connection's packet-id counter (same Arc as `remote_conn.last_packet_id`).
let last_packet_id = remote_conn.last_packet_id.clone();
let time_source = remote_conn.time_source.clone();
// Ping sequence -> send time (nanos); shared between the task and `recv`.
let pending_pings: Arc<RwLock<BTreeMap<u64, u64>>> = Arc::new(RwLock::new(BTreeMap::new()));
let pending_pings_for_task = pending_pings.clone();
let task_time_source = time_source.clone();
// Evaluated once, before spawning; the task exits early when this is false.
let keepalive_enabled = time_source.supports_keepalive();
let keep_alive_handle = GlobalExecutor::spawn(async move {
if !keepalive_enabled {
tracing::debug!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
"Keep-alive task SKIPPED (time source does not support keepalive)"
);
return;
}
tracing::info!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
"Keep-alive task STARTED for connection"
);
let mut tick_count = 0u64;
let mut ping_seq = 0u64;
let task_start_nanos = task_time_source.now_nanos();
loop {
// Back off the interval while many pings remain unanswered.
let pending_count = pending_pings_for_task.read().len();
let wait_duration = keepalive_interval_for_pending(pending_count);
task_time_source.sleep(wait_duration).await;
tick_count += 1;
let now_nanos = task_time_source.now_nanos();
let elapsed_since_start_nanos = now_nanos.saturating_sub(task_start_nanos);
let current_ping_seq = ping_seq;
ping_seq += 1;
tracing::debug!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
tick_count,
ping_sequence = current_ping_seq,
elapsed_since_start_secs = Duration::from_nanos(elapsed_since_start_nanos).as_secs_f64(),
wait_interval_secs = wait_duration.as_secs_f64(),
unanswered_pings = pending_count,
"Keep-alive tick - sending Ping"
);
let packet_id = last_packet_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let ping_packet = match SymmetricMessage::serialize_msg_to_packet_data(
packet_id,
SymmetricMessagePayload::Ping {
sequence: current_ping_seq,
},
&outbound_key,
// NOTE(review): presumably the (empty) confirm-receipt list — confirm against serialize_msg_to_packet_data.
vec![], ) {
Ok(packet) => packet.prepared_send(),
Err(e) => {
tracing::error!(?e, "Failed to create keep-alive Ping packet");
break;
}
};
// Recorded before sending, so a Pong handled before the send completes still finds its entry.
{
let mut pending = pending_pings_for_task.write();
pending.insert(current_ping_seq, task_time_source.now_nanos());
}
tracing::debug!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
packet_id,
ping_sequence = current_ping_seq,
"Sending keep-alive Ping packet"
);
match socket.send_to(&ping_packet, remote_addr).await {
Ok(_) => {
tracing::debug!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
packet_id,
ping_sequence = current_ping_seq,
"Keep-alive Ping packet sent successfully"
);
}
Err(e) => {
// A socket error ends the keep-alive task permanently.
let elapsed = Duration::from_nanos(
task_time_source
.now_nanos()
.saturating_sub(task_start_nanos),
);
tracing::warn!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
error = ?e,
elapsed_since_start_secs = elapsed.as_secs_f64(),
total_ticks = tick_count,
"Keep-alive task STOPPING - socket error"
);
break;
}
}
}
let elapsed = Duration::from_nanos(
task_time_source
.now_nanos()
.saturating_sub(task_start_nanos),
);
tracing::warn!(
target: "freenet_core::transport::keepalive_lifecycle",
remote = ?remote_addr,
total_lifetime_secs = elapsed.as_secs_f64(),
total_ticks = tick_count,
"Keep-alive task EXITING"
);
});
tracing::debug!(
peer_addr = %remote_addr,
"PeerConnection created with persistent keep-alive task"
);
// Receipt-report and idle-timeout clocks both start at construction time.
let now_nanos = time_source.now_nanos();
Self {
remote_conn,
received_tracker: ReceivedPacketTracker::new(),
inbound_streams: HashMap::new(),
inbound_stream_futures: FuturesUnordered::new(),
outbound_stream_futures: FuturesUnordered::new(),
failure_count: 0,
first_failure_time_nanos: None,
last_packet_report_time_nanos: now_nanos,
last_rate_update_nanos: None,
last_received_nanos: now_nanos,
keep_alive_handle: Some(keep_alive_handle),
pending_pings,
streaming_registry: Arc::new(streaming::StreamRegistry::new()),
streaming_handles: HashMap::new(),
time_source,
orphan_stream_registry: None,
dispatched_msg_hashes: lru::LruCache::new(DEDUP_CACHE_CAPACITY),
}
}
/// Hashes a message payload with std's `DefaultHasher` for the dedup cache.
fn msg_hash(bytes: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    Hash::hash(bytes, &mut hasher);
    hasher.finish()
}
/// Records `bytes` in the dedup cache and reports whether its hash was already present.
fn is_duplicate_dispatch(&mut self, bytes: &[u8]) -> bool {
    let key = Self::msg_hash(bytes);
    let previous = self.dispatched_msg_hashes.put(key, ());
    matches!(previous, Some(()))
}
/// Installs the registry that receives streams nobody has claimed yet, replacing
/// any previously installed one.
pub fn set_orphan_stream_registry(
    &mut self,
    registry: std::sync::Arc<crate::operations::orphan_streams::OrphanStreamRegistry>,
) {
    let _ = self.orphan_stream_registry.replace(registry);
}
/// Drops streaming handles idle for longer than the mode-dependent idle timeout.
fn sweep_idle_streaming_handles(&mut self) {
    // Read the clock first, before consulting the timeout configuration.
    let now_nanos = self.time_source.now_nanos();
    let idle_limit = streaming_handle_idle_timeout();
    let peer_addr = self.remote_conn.remote_addr;
    sweep_streaming_handles_inner(
        &mut self.streaming_handles,
        &self.streaming_registry,
        now_nanos,
        idle_limit,
        peer_addr,
    );
}
/// Serializes `data` with bincode and sends it to the peer.
///
/// If the serialized bytes plus the short-message overhead fit in `MAX_DATA_SIZE`,
/// they go out as a single short message. Otherwise they are handed to
/// `outbound_stream`.
///
/// # Errors
/// Returns an error if serialization fails or sending the short message fails.
#[instrument(name = "peer_connection", skip_all)]
pub async fn send<D>(&mut self, data: D) -> Result
where
    D: Serialize + Send + std::fmt::Debug,
{
    let serialized = bincode::serialize(&data)?;
    let framed_len = serialized.len() + SymmetricMessage::short_message_overhead();
    if framed_len <= MAX_DATA_SIZE {
        tracing::trace!(
            peer_addr = %self.remote_conn.remote_addr,
            "Sending as short message"
        );
        self.outbound_short_message(serialized).await?;
    } else {
        tracing::trace!(
            peer_addr = %self.remote_conn.remote_addr,
            total_size_bytes = serialized.len(),
            "Sending as stream"
        );
        self.outbound_stream(serialized).await;
    }
    Ok(())
}
/// Waits for and returns the next complete inbound message from the peer.
///
/// While waiting, the loop also services connection housekeeping:
/// - decrypts inbound packets, answers intro packets with an ACK, and closes the
///   connection after more than `NAT_TRAVERSAL_MAX_ATTEMPTS` decryption failures
///   inside a 30 s window;
/// - drops packets that decrypt but fail to deserialize, instead of panicking;
/// - answers `Ping` with `Pong` and clears answered pings from `pending_pings`;
/// - feeds received receipts into the sent-packet tracker and congestion controller,
///   and flushes our own receipts when the queue is full, the confirmation timeout
///   has elapsed, or the background ACK timer fires;
/// - never re-processes a packet id it has already seen, even on a receipt flush;
/// - collects finished inbound/outbound stream tasks;
/// - closes the connection after the time source's idle timeout and sweeps idle
///   streaming handles;
/// - resends unacknowledged packets and periodically updates the send rate.
///
/// # Errors
/// Returns `TransportError::ConnectionClosed` when the inbound channel closes, the
/// idle timeout elapses, or decryption keeps failing. Errors from
/// `process_inbound`, non-transient receipt-send failures, failed outbound streams
/// and failed stream tasks (join errors) are propagated.
#[instrument(name = "peer_connection", skip(self))]
pub async fn recv(&mut self) -> Result<Vec<u8>> {
    let in_simulation = crate::config::SimulationTransportOpt::is_enabled();
    // Simulation runs use coarser timers than production.
    let ack_interval = if in_simulation {
        SIMULATION_ACK_CHECK_INTERVAL
    } else {
        ACK_CHECK_INTERVAL
    };
    let resend_initial = if in_simulation {
        SIMULATION_RESEND_CHECK_INTERVAL
    } else {
        Duration::from_millis(10)
    };
    let resend_yield = if in_simulation {
        SIMULATION_RESEND_YIELD_DELAY
    } else {
        Duration::from_millis(2)
    };
    // `None` means the resend branch consumed the timer. It is re-armed at the top of
    // every loop iteration.
    let mut resend_check_sleep: Option<
        std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
    > = Some(self.time_source.sleep(resend_initial));
    let kill_connection_after = self.time_source.connection_idle_timeout();
    // Saturate instead of truncating. `as u64` keeps only the low 64 bits of the u128,
    // so an idle timeout above ~584 years could become an arbitrarily small threshold.
    let kill_connection_after_nanos =
        u64::try_from(kill_connection_after.as_nanos()).unwrap_or(u64::MAX);
    let mut timeout_check =
        TimeSourceInterval::new(self.time_source.clone(), Duration::from_secs(5));
    let ack_start_nanos = self.time_source.now_nanos() + ack_interval.as_nanos() as u64;
    let mut ack_check =
        TimeSourceInterval::new_at(self.time_source.clone(), ack_start_nanos, ack_interval);
    let rate_start_nanos = self.time_source.now_nanos() + ack_interval.as_nanos() as u64;
    let mut rate_update_check =
        TimeSourceInterval::new_at(self.time_source.clone(), rate_start_nanos, ack_interval);
    const FAILURE_TIME_WINDOW: Duration = Duration::from_secs(30);
    const FAILURE_TIME_WINDOW_NANOS: u64 = FAILURE_TIME_WINDOW.as_nanos() as u64;
    loop {
        if resend_check_sleep.is_none() {
            resend_check_sleep = Some(self.time_source.sleep(resend_yield));
        }
        crate::deterministic_select! {
            inbound = self.remote_conn.inbound_packet_recv.recv() => {
                let packet_data = inbound.ok_or_else(|| TransportError::ConnectionClosed(self.remote_addr()))?;
                // Keep the previous arrival time so the NoOp keep-alive log below reports
                // the real gap instead of ~0.
                let previous_received_nanos = self.last_received_nanos;
                self.last_received_nanos = self.time_source.now_nanos();
                if packet_data.is_intro_packet() {
                    tracing::debug!(
                        peer_addr = %self.remote_conn.remote_addr,
                        packet_bytes = ?&packet_data.data()[..std::cmp::min(32, packet_data.data().len())], packet_len = packet_data.data().len(),
                        "Received intro packet"
                    );
                }
                let Ok(decrypted) = packet_data.try_decrypt_sym(&self.remote_conn.inbound_symmetric_key).inspect_err(|error| {
                    tracing::debug!(
                        error = %error,
                        peer_addr = %self.remote_conn.remote_addr,
                        inbound_key = ?self.remote_conn.inbound_symmetric_key_bytes,
                        packet_len = packet_data.data().len(),
                        packet_first_bytes = ?&packet_data.data()[..std::cmp::min(32, packet_data.data().len())],
                        "Failed to decrypt packet, might be an intro packet or a partial packet"
                    );
                }) else {
                    if packet_data.is_intro_packet() {
                        tracing::debug!(
                            peer_addr = %self.remote_conn.remote_addr,
                            "Attempting to decrypt intro packet"
                        );
                        match self.remote_conn.transport_secret_key.decrypt(packet_data.data()) {
                            Ok(_decrypted_intro) => {
                                super::TRANSPORT_METRICS.record_packet_received(
                                    self.remote_conn.remote_addr,
                                    packet_data.data().len() as u64,
                                );
                                tracing::debug!(
                                    peer_addr = %self.remote_conn.remote_addr,
                                    "Successfully decrypted intro packet, sending ACK"
                                );
                                let ack_packet = SymmetricMessage::ack_ok(
                                    &self.remote_conn.outbound_symmetric_key,
                                    self.remote_conn.inbound_symmetric_key_bytes,
                                    self.remote_conn.remote_addr,
                                );
                                if let Ok(ack) = ack_packet {
                                    if let Err(send_err) = self.remote_conn
                                        .socket
                                        .send_to(ack.data(), self.remote_conn.remote_addr)
                                        .await
                                    {
                                        tracing::warn!(
                                            peer_addr = %self.remote_conn.remote_addr,
                                            error = ?send_err,
                                            "Failed to send ACK for intro packet"
                                        );
                                    } else {
                                        tracing::debug!(
                                            peer_addr = %self.remote_conn.remote_addr,
                                            "Successfully sent ACK for intro packet"
                                        );
                                    }
                                } else {
                                    tracing::warn!(
                                        peer_addr = %self.remote_conn.remote_addr,
                                        "Failed to create ACK packet for intro"
                                    );
                                }
                                continue;
                            }
                            Err(decrypt_err) => {
                                tracing::trace!(
                                    peer_addr = %self.remote_conn.remote_addr,
                                    error = ?decrypt_err,
                                    "Packet with intro type marker failed decryption"
                                );
                            }
                        }
                    }
                    // Count failures within a sliding window that restarts once it expires.
                    let now_nanos = self.time_source.now_nanos();
                    if let Some(first_failure_time_nanos) = self.first_failure_time_nanos {
                        if now_nanos.saturating_sub(first_failure_time_nanos) <= FAILURE_TIME_WINDOW_NANOS {
                            self.failure_count += 1;
                        } else {
                            self.failure_count = 1;
                            self.first_failure_time_nanos = Some(now_nanos);
                        }
                    } else {
                        self.failure_count = 1;
                        self.first_failure_time_nanos = Some(now_nanos);
                    }
                    if self.failure_count > NAT_TRAVERSAL_MAX_ATTEMPTS {
                        tracing::warn!(
                            peer_addr = %self.remote_conn.remote_addr,
                            failure_count = self.failure_count,
                            max_attempts = NAT_TRAVERSAL_MAX_ATTEMPTS,
                            "Dropping connection due to repeated decryption failures"
                        );
                        return Err(TransportError::ConnectionClosed(self.remote_addr()));
                    }
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        "Ignoring packet"
                    );
                    continue;
                };
                super::TRANSPORT_METRICS.record_packet_received(
                    self.remote_conn.remote_addr,
                    packet_data.data().len() as u64,
                );
                // A packet that decrypts but does not deserialize (corruption, protocol
                // mismatch) used to panic via `unwrap`. Drop it instead; its packet id is
                // unknown, so it cannot be acknowledged.
                let msg = match SymmetricMessage::deser(decrypted.data()) {
                    Ok(msg) => msg,
                    Err(error) => {
                        tracing::warn!(
                            peer_addr = %self.remote_conn.remote_addr,
                            error = ?error,
                            packet_len = decrypted.data().len(),
                            "Failed to deserialize decrypted packet, dropping it"
                        );
                        continue;
                    }
                };
                let SymmetricMessage {
                    packet_id,
                    confirm_receipt,
                    payload,
                } = msg;
                match &payload {
                    SymmetricMessagePayload::NoOp => {
                        if confirm_receipt.is_empty() {
                            let elapsed_nanos = self.time_source.now_nanos().saturating_sub(previous_received_nanos);
                            tracing::debug!(
                                target: "freenet_core::transport::keepalive_received",
                                remote = ?self.remote_conn.remote_addr,
                                packet_id,
                                time_since_last_received_ms = Duration::from_nanos(elapsed_nanos).as_millis(),
                                "Received NoOp keep-alive packet (no receipts)"
                            );
                        } else {
                            tracing::debug!(
                                target: "freenet_core::transport::keepalive_received",
                                remote = ?self.remote_conn.remote_addr,
                                packet_id,
                                receipt_count = confirm_receipt.len(),
                                "Received NoOp receipt packet"
                            );
                        }
                    }
                    SymmetricMessagePayload::Ping { sequence } => {
                        tracing::debug!(
                            target: "freenet_core::transport::keepalive_received",
                            remote = ?self.remote_conn.remote_addr,
                            packet_id,
                            ping_sequence = sequence,
                            "Received Ping, sending Pong response"
                        );
                        if let Err(e) = self.send_pong(*sequence).await {
                            tracing::warn!(
                                target: "freenet_core::transport::keepalive_received",
                                remote = ?self.remote_conn.remote_addr,
                                ping_sequence = sequence,
                                error = ?e,
                                "Failed to send Pong response"
                            );
                        }
                    }
                    SymmetricMessagePayload::Pong { sequence } => {
                        tracing::debug!(
                            target: "freenet_core::transport::keepalive_received",
                            remote = ?self.remote_conn.remote_addr,
                            packet_id,
                            pong_sequence = sequence,
                            "Received Pong, confirming bidirectional liveness"
                        );
                        let mut pending = self.pending_pings.write();
                        if pending.remove(sequence).is_some() {
                            tracing::trace!(
                                target: "freenet_core::transport::keepalive_received",
                                remote = ?self.remote_conn.remote_addr,
                                pong_sequence = sequence,
                                remaining_pending = pending.len(),
                                "Removed acknowledged ping from pending set"
                            );
                        }
                    }
                    SymmetricMessagePayload::AckConnection { .. } | SymmetricMessagePayload::ShortMessage { .. } | SymmetricMessagePayload::StreamFragment { .. } => {}
                }
                tracing::trace!(
                    peer_addr = %self.remote_conn.remote_addr,
                    packet_id,
                    confirm_receipts_count = confirm_receipt.len(),
                    "Received inbound packet with confirmations"
                );
                let current_time_nanos = self.time_source.now_nanos();
                let message_confirmation_timeout_nanos = MESSAGE_CONFIRMATION_TIMEOUT.as_nanos() as u64;
                let should_send_receipts = if current_time_nanos > self.last_packet_report_time_nanos + message_confirmation_timeout_nanos {
                    let elapsed_nanos = current_time_nanos.saturating_sub(self.last_packet_report_time_nanos);
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        elapsed_ms = Duration::from_nanos(elapsed_nanos).as_millis(),
                        timeout_ms = MESSAGE_CONFIRMATION_TIMEOUT.as_millis(),
                        "Timeout reached, should send receipts"
                    );
                    self.last_packet_report_time_nanos = current_time_nanos;
                    true
                } else {
                    false
                };
                let (ack_info, _loss_rate) = self.remote_conn
                    .sent_tracker
                    .lock()
                    .report_received_receipts(&confirm_receipt);
                for (rtt_sample_opt, packet_size, token) in ack_info {
                    match rtt_sample_opt {
                        Some(rtt_sample) => {
                            self.remote_conn.rolling_rtt_stats.record(rtt_sample);
                            self.remote_conn.congestion_controller.on_ack_with_token(
                                rtt_sample,
                                packet_size,
                                token,
                            );
                        }
                        None => {
                            self.remote_conn.congestion_controller.on_ack_without_rtt(packet_size);
                        }
                    }
                }
                let report_result = self.received_tracker.report_received_packet(packet_id);
                let already_received = matches!(report_result, ReportResult::AlreadyReceived);
                if matches!(report_result, ReportResult::QueueFull) || should_send_receipts {
                    let receipts = self.received_tracker.get_receipts();
                    if !receipts.is_empty() {
                        if let Err(e) = self.noop(receipts).await {
                            if e.is_transient_send_failure() {
                                tracing::warn!(
                                    peer_addr = %self.remote_conn.remote_addr,
                                    "ACK noop send failed, will retry"
                                );
                            } else {
                                return Err(e);
                            }
                        }
                    }
                }
                // Previously a duplicate arriving while receipts were due matched the
                // flush arm first and was then processed (and delivered) a second time.
                if already_received {
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        packet_id,
                        "Already received packet"
                    );
                    continue;
                }
                if let Some(msg) = self.process_inbound(payload).await.map_err(|error| {
                    tracing::error!(
                        error = %error,
                        packet_id,
                        peer_addr = %self.remote_conn.remote_addr,
                        "Error processing inbound packet"
                    );
                    error
                })? {
                    super::TRANSPORT_METRICS.record_inbound_completed(msg.len() as u64);
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        packet_id,
                        "Returning full stream message"
                    );
                    return Ok(msg);
                }
            },
            inbound_stream = self.inbound_stream_futures.next(), if !self.inbound_stream_futures.is_empty() => {
                let Some(res) = inbound_stream else {
                    tracing::error!(
                        peer_addr = %self.remote_conn.remote_addr,
                        "Unexpected no-stream from ongoing_inbound_streams"
                    );
                    continue
                };
                let Ok((stream_id, msg)) = res.map_err(|e| TransportError::Other(e.into()))? else {
                    tracing::error!(
                        peer_addr = %self.remote_conn.remote_addr,
                        "Unexpected error from ongoing_inbound_streams"
                    );
                    continue;
                };
                self.inbound_streams.remove(&stream_id);
                self.streaming_handles.remove(&stream_id);
                self.streaming_registry.remove(stream_id);
                let bytes_received = msg.len() as u64;
                super::TRANSPORT_METRICS.record_inbound_completed(bytes_received);
                tracing::trace!(
                    peer_addr = %self.remote_conn.remote_addr,
                    stream_id = %stream_id,
                    bytes = bytes_received,
                    "Stream finished"
                );
                return Ok(msg);
            },
            outbound_stream = self.outbound_stream_futures.next(), if !self.outbound_stream_futures.is_empty() => {
                let Some(res) = outbound_stream else {
                    tracing::error!(
                        peer_addr = %self.remote_conn.remote_addr,
                        "Unexpected no-stream from ongoing_outbound_streams"
                    );
                    continue
                };
                let transfer_result = res.map_err(|e| TransportError::Other(e.into()))?;
                match transfer_result {
                    Ok(stats) => {
                        tracing::trace!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = stats.stream_id,
                            bytes = stats.bytes_transferred,
                            elapsed_ms = stats.elapsed.as_millis(),
                            throughput_kbps = stats.avg_throughput_bps() / 1024,
                            "Outbound stream completed with stats"
                        );
                        super::TRANSPORT_METRICS.record_transfer_completed(&stats);
                    }
                    Err(e) if e.is_transient_send_failure() => {
                        tracing::warn!(
                            peer_addr = %self.remote_conn.remote_addr,
                            error = %e,
                            "Outbound stream send failed, operation layer will timeout and retry"
                        );
                    }
                    Err(e) => return Err(e),
                }
            },
            _ = timeout_check.tick() => {
                let now_nanos = self.time_source.now_nanos();
                let elapsed_nanos = now_nanos.saturating_sub(self.last_received_nanos);
                let elapsed = Duration::from_nanos(elapsed_nanos);
                if elapsed_nanos > kill_connection_after_nanos {
                    tracing::warn!(
                        target: "freenet_core::transport::keepalive_timeout",
                        remote = ?self.remote_conn.remote_addr,
                        elapsed_seconds = elapsed.as_secs_f64(),
                        timeout_threshold_secs = kill_connection_after.as_secs(),
                        "CONNECTION TIMEOUT - no packets received for {:.8}s",
                        elapsed.as_secs_f64()
                    );
                    if let Some(ref handle) = self.keep_alive_handle {
                        let task_state = if handle.is_finished() { "finished" } else { "running" };
                        tracing::debug!(
                            target: "freenet_core::transport::keepalive_timeout",
                            remote = ?self.remote_conn.remote_addr,
                            keepalive_task = task_state,
                            "Connection timed out, keepalive task was {task_state}"
                        );
                    }
                    return Err(TransportError::ConnectionClosed(self.remote_addr()));
                }
                // Forget pings older than the idle timeout so the map cannot grow forever.
                {
                    let mut pending = self.pending_pings.write();
                    let stale_threshold_nanos = now_nanos.saturating_sub(kill_connection_after_nanos);
                    pending.retain(|_, sent_at_nanos| *sent_at_nanos > stale_threshold_nanos);
                }
                let pending_ping_count = self.pending_pings.read().len();
                if pending_ping_count > MAX_UNANSWERED_PINGS {
                    tracing::debug!(
                        target: "freenet_core::transport::keepalive_health",
                        remote = ?self.remote_conn.remote_addr,
                        pending_pings = pending_ping_count,
                        max_unanswered = MAX_UNANSWERED_PINGS,
                        time_since_last_received_secs = elapsed.as_secs_f64(),
                        "Many unanswered pings ({} > {}), relying on idle timeout for liveness",
                        pending_ping_count,
                        MAX_UNANSWERED_PINGS
                    );
                }
                let remaining_nanos = kill_connection_after_nanos.saturating_sub(elapsed_nanos);
                tracing::trace!(
                    target: "freenet_core::transport::keepalive_health",
                    remote = ?self.remote_conn.remote_addr,
                    elapsed_seconds = elapsed.as_secs_f64(),
                    remaining_seconds = Duration::from_nanos(remaining_nanos).as_secs_f64(),
                    pending_pings = pending_ping_count,
                    "Connection health check - still alive"
                );
                self.sweep_idle_streaming_handles();
            },
            _ = async { resend_check_sleep.take().unwrap_or_else(|| Box::pin(std::future::ready(()))).await } => {
                // Bound the work per wake-up so other branches are not starved.
                const MAX_RESENDS_PER_ITERATION: usize = 4;
                let mut resend_count = 0;
                loop {
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        "Checking for resends"
                    );
                    let maybe_resend = self.remote_conn
                        .sent_tracker
                        .lock()
                        .get_resend();
                    let (idx, packet) = match maybe_resend {
                        ResendAction::WaitUntil(deadline_nanos) => {
                            resend_check_sleep = Some(self.time_source.sleep_until(deadline_nanos));
                            break;
                        }
                        ResendAction::Resend(idx, packet) => {
                            self.remote_conn.congestion_controller.on_timeout();
                            (idx, packet)
                        }
                        ResendAction::TlpProbe(idx, packet) => {
                            tracing::trace!(
                                peer_addr = %self.remote_conn.remote_addr,
                                packet_id = idx,
                                "Sending TLP probe"
                            );
                            (idx, packet)
                        }
                    };
                    match self.remote_conn
                        .socket
                        .send_to(&packet, self.remote_conn.remote_addr)
                        .await
                    {
                        Ok(_) => {
                            self.remote_conn.sent_tracker.lock().report_sent_packet(idx, packet);
                        }
                        Err(e) => {
                            tracing::warn!(
                                peer_addr = %self.remote_conn.remote_addr,
                                packet_id = idx,
                                error = %e,
                                "Resend send failed, will retry on next RTO"
                            );
                            self.remote_conn
                                .sent_tracker
                                .lock()
                                .report_sent_packet(idx, packet);
                            break;
                        }
                    }
                    resend_count += 1;
                    if resend_count >= MAX_RESENDS_PER_ITERATION {
                        resend_check_sleep = Some(self.time_source.sleep(resend_yield));
                        break;
                    }
                }
            },
            _ = ack_check.tick() => {
                let receipts = self.received_tracker.get_receipts();
                if !receipts.is_empty() {
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        receipt_count = receipts.len(),
                        "Background ACK timer: sending pending receipts"
                    );
                    if let Err(e) = self.noop(receipts).await {
                        if e.is_transient_send_failure() {
                            tracing::warn!(
                                peer_addr = %self.remote_conn.remote_addr,
                                "Background ACK send failed, will retry next tick"
                            );
                        } else {
                            return Err(e);
                        }
                    }
                }
            },
            _ = rate_update_check.tick() => {
                let now_nanos = self.time_source.now_nanos();
                let base_delay = self.remote_conn.congestion_controller.base_delay();
                let rtt = if base_delay.is_zero() {
                    self.remote_conn.sent_tracker.lock().min_rtt()
                } else {
                    base_delay
                };
                // Update at most once per RTT, clamped to the 50-500 ms range.
                let should_update = match self.last_rate_update_nanos {
                    None => true, Some(last_update_nanos) => {
                        let elapsed_nanos = now_nanos.saturating_sub(last_update_nanos);
                        let min_interval = rtt.max(Duration::from_millis(50)).min(Duration::from_millis(500));
                        elapsed_nanos >= min_interval.as_nanos() as u64
                    }
                };
                if should_update {
                    let cc_rate = self.remote_conn.congestion_controller.current_rate(rtt) as u64;
                    let cwnd = self.remote_conn.congestion_controller.current_cwnd();
                    let queuing_delay = self.remote_conn.congestion_controller.queuing_delay();
                    let (new_rate, global_limit) = if let Some(ref global) =
                        self.remote_conn.global_bandwidth
                    {
                        let global_rate = global.current_per_connection_rate() as u64;
                        (cc_rate.min(global_rate), Some(global_rate))
                    } else {
                        (cc_rate, None)
                    };
                    let since_last_update_ms = self.last_rate_update_nanos
                        .map(|last| Duration::from_nanos(now_nanos.saturating_sub(last)).as_millis())
                        .unwrap_or(0);
                    self.remote_conn.token_bucket.set_rate(new_rate as usize);
                    self.last_rate_update_nanos = Some(now_nanos);
                    tracing::debug!(
                        peer_addr = %self.remote_conn.remote_addr,
                        new_rate_bytes_per_sec = new_rate,
                        new_rate_mbps = (new_rate as f64) / 1_000_000.0,
                        cc_rate_mbps = (cc_rate as f64) / 1_000_000.0,
                        global_limit_mbps = global_limit.map(|r| (r as f64) / 1_000_000.0),
                        cwnd_bytes = cwnd,
                        cwnd_packets = cwnd / MAX_DATA_SIZE,
                        base_delay_ms = base_delay.as_millis(),
                        rtt_ms = rtt.as_millis(),
                        queuing_delay_ms = queuing_delay.as_millis(),
                        since_last_update_ms = since_last_update_ms,
                        "Congestion control metrics (RTT-adaptive rate update)"
                    );
                }
            }
        }
    }
}
#[allow(dead_code)]
pub fn my_address(&self) -> Option<SocketAddr> {
self.remote_conn.my_address
}
pub fn remote_addr(&self) -> SocketAddr {
self.remote_conn.remote_addr
}
/// Returns a clone of the streaming handle registered for `stream_id`, if any.
// Currently unused inside the crate; kept for callers that consume streams incrementally.
#[allow(dead_code)]
pub(crate) fn recv_stream_handle(
    &self,
    stream_id: StreamId,
) -> Option<streaming::StreamHandle> {
    let (stream_handle, _last_activity_nanos) = self.streaming_handles.get(&stream_id)?;
    Some(stream_handle.clone())
}
/// Shared handle to this connection's stream registry (cheap `Arc` clone).
pub fn streaming_registry(&self) -> Arc<streaming::StreamRegistry> {
    let registry = &self.streaming_registry;
    registry.clone()
}
/// Handles one decrypted inbound payload.
///
/// Returns:
/// - `Ok(Some(bytes))` when something must be delivered to the caller: a
///   `ShortMessage`, not-yet-dispatched metadata embedded in an operations
///   stream fragment, or a legacy stream that completed with its first
///   fragment;
/// - `Ok(None)` when the payload was consumed internally (connection acks,
///   pings/pongs, no-ops, intermediate stream fragments);
/// - `Err` when the peer reported a connection-establishment failure, when
///   building the `AckOk` reply fails, or when the receiver task of a legacy
///   inbound stream has gone away.
async fn process_inbound(
    &mut self,
    payload: SymmetricMessagePayload,
) -> Result<Option<Vec<u8>>> {
    use SymmetricMessagePayload::*;
    match payload {
        ShortMessage { payload } => {
            let bytes = payload.to_vec();
            // Record the fingerprint so an identical copy embedded in a
            // stream's first fragment can be suppressed later.
            self.dispatched_msg_hashes.put(Self::msg_hash(&bytes), ());
            Ok(Some(bytes))
        }
        AckConnection { result: Err(cause) } => {
            Err(TransportError::ConnectionEstablishmentFailure { cause })
        }
        AckConnection { result: Ok(_) } => {
            let packet = SymmetricMessage::ack_ok(
                &self.remote_conn.outbound_symmetric_key,
                self.remote_conn.inbound_symmetric_key_bytes,
                self.remote_conn.remote_addr,
            )?;
            // Best effort: a send failure is only logged.
            if let Err(e) = self
                .remote_conn
                .socket
                .send_to(packet.data(), self.remote_conn.remote_addr)
                .await
            {
                tracing::warn!(
                    peer_addr = %self.remote_conn.remote_addr,
                    error = %e,
                    "AckOk send failed, peer will retransmit if needed"
                );
            }
            Ok(None)
        }
        StreamFragment {
            stream_id,
            total_length_bytes,
            fragment_number,
            payload,
            metadata_bytes,
        } => {
            let now_nanos = self.time_source.now_nanos();
            if let Some((streaming_handle, last_activity_nanos)) =
                self.streaming_handles.get_mut(&stream_id)
            {
                match streaming_handle.push_fragment(fragment_number, payload.clone()) {
                    // Only an `Ok(true)` push refreshes the idle timestamp used by
                    // the sweep (presumably `false` means "not new" — confirm).
                    Ok(true) => *last_activity_nanos = now_nanos,
                    Ok(false) => {}
                    Err(streaming::StreamError::Cancelled) => {
                        self.streaming_handles.remove(&stream_id);
                        self.streaming_registry.remove(stream_id);
                        tracing::debug!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            fragment_number,
                            "Stream cancelled, removed from handles and registry"
                        );
                    }
                    Err(e) => {
                        tracing::warn!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            fragment_number,
                            error = %e,
                            "Failed to push fragment to streaming handle"
                        );
                    }
                }
            } else {
                let streaming_handle = self
                    .streaming_registry
                    .register(stream_id, total_length_bytes);
                match streaming_handle.push_fragment(fragment_number, payload.clone()) {
                    Ok(_) => {
                        self.adopt_new_streaming_handle(stream_id, streaming_handle, now_nanos);
                    }
                    Err(streaming::StreamError::Cancelled) => {
                        tracing::debug!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            fragment_number,
                            "New stream already cancelled, not registering"
                        );
                    }
                    Err(e) => {
                        tracing::warn!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            fragment_number,
                            error = %e,
                            "Failed to push first fragment to streaming handle"
                        );
                        // Even after a non-cancellation error the handle is still
                        // registered and tracked, exactly like the success path.
                        self.adopt_new_streaming_handle(stream_id, streaming_handle, now_nanos);
                    }
                }
            }

            if stream_id.is_operations_stream() {
                // Operations streams are consumed through the streaming handle /
                // orphan registry; only embedded metadata is surfaced here.
                if let Some(meta) = metadata_bytes {
                    let bytes = meta.to_vec();
                    if self.is_duplicate_dispatch(&bytes) {
                        tracing::debug!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            "Suppressing duplicate embedded metadata (already dispatched via ShortMessage)"
                        );
                    } else {
                        tracing::debug!(
                            peer_addr = %self.remote_conn.remote_addr,
                            stream_id = %stream_id,
                            meta_len = bytes.len(),
                            "Dispatching embedded metadata from fragment #1"
                        );
                        return Ok(Some(bytes));
                    }
                }
                tracing::trace!(
                    peer_addr = %self.remote_conn.remote_addr,
                    stream_id = %stream_id,
                    fragment_number,
                    "Operations stream fragment - skipping legacy InboundStream path"
                );
                return Ok(None);
            }

            // Legacy InboundStream reassembly path.
            if let Some(sender) = self.inbound_streams.get(&stream_id) {
                sender
                    .send((fragment_number, payload))
                    .await
                    .map_err(|_| TransportError::ConnectionClosed(self.remote_addr()))?;
                tracing::trace!(
                    peer_addr = %self.remote_conn.remote_addr,
                    stream_id = %stream_id,
                    fragment_number,
                    "Fragment pushed to existing stream"
                );
            } else {
                tracing::trace!(
                    peer_addr = %self.remote_conn.remote_addr,
                    stream_id = %stream_id,
                    fragment_number,
                    "New stream"
                );
                let mut stream = inbound_stream::InboundStream::new(total_length_bytes);
                if let Some(msg) = stream.push_fragment(fragment_number, payload) {
                    // Completed by its first fragment: nothing to keep tracking, so
                    // no receiver task is spawned and no sender is registered.
                    self.streaming_handles.remove(&stream_id);
                    self.streaming_registry.remove(stream_id);
                    tracing::trace!(
                        peer_addr = %self.remote_conn.remote_addr,
                        stream_id = %stream_id,
                        fragment_number,
                        "Stream finished"
                    );
                    return Ok(Some(msg));
                }
                let (sender, receiver) = mpsc::channel(64);
                self.inbound_streams.insert(stream_id, sender);
                self.inbound_stream_futures.push(GlobalExecutor::spawn(
                    inbound_stream::recv_stream(stream_id, receiver, stream),
                ));
            }
            Ok(None)
        }
        NoOp => Ok(None),
        Ping { .. } | Pong { .. } => Ok(None),
    }
}

/// Starts tracking a freshly registered inbound streaming handle.
///
/// If an orphan stream registry is attached, the handle is also registered
/// there as an orphan for the operations layer. If no registry is attached
/// and this is an operations stream, an error is logged. In both cases the
/// handle is inserted into `streaming_handles` with `now_nanos` as its last
/// activity time.
fn adopt_new_streaming_handle(
    &mut self,
    stream_id: StreamId,
    streaming_handle: streaming::StreamHandle,
    now_nanos: u64,
) {
    if let Some(orphan_registry) = &self.orphan_stream_registry {
        orphan_registry.register_orphan(
            self.remote_conn.remote_addr,
            stream_id,
            streaming_handle.clone(),
        );
        tracing::trace!(
            peer_addr = %self.remote_conn.remote_addr,
            stream_id = %stream_id,
            "Registered stream as orphan for operations layer"
        );
    } else if stream_id.is_operations_stream() {
        tracing::error!(
            peer_addr = %self.remote_conn.remote_addr,
            stream_id = %stream_id,
            "Operations stream fragment arrived but orphan_stream_registry is None! \
This will cause claim_or_wait() to timeout. Check connection setup."
        );
    }
    self.streaming_handles
        .insert(stream_id, (streaming_handle, now_nanos));
}
/// Sends a packet with the unit `()` payload (presumably `NoOp`) that carries
/// `receipts` to the peer.
///
/// The send is charged to the congestion controller as 50 bytes. On success,
/// `packet_sending` records the packet in the sent tracker.
#[inline]
async fn noop(&mut self, receipts: Vec<u32>) -> Result<()> {
    let conn = &self.remote_conn;
    let token = conn.congestion_controller.on_send_with_token(50);
    let packet_id = conn
        .last_packet_id
        .fetch_add(1, std::sync::atomic::Ordering::Release);
    packet_sending(
        conn.remote_addr,
        &conn.socket,
        packet_id,
        &conn.outbound_symmetric_key,
        receipts,
        (),
        &conn.sent_tracker,
        token,
    )
    .await
}
/// Sends a `Pong` carrying `sequence` (presumably in reply to a peer `Ping`).
///
/// The pong goes straight to the socket: it is not charged to congestion
/// control and not recorded in the sent tracker. Serialization errors are
/// returned. A socket error is only logged, and the method still returns
/// `Ok(())`.
async fn send_pong(&mut self, sequence: u64) -> Result<()> {
    let conn = &self.remote_conn;
    let packet_id = conn
        .last_packet_id
        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let pong_packet = SymmetricMessage::serialize_msg_to_packet_data(
        packet_id,
        SymmetricMessagePayload::Pong { sequence },
        &conn.outbound_symmetric_key,
        Vec::new(),
    )?
    .prepared_send();
    let send_result = conn.socket.send_to(&pong_packet, conn.remote_addr).await;
    if let Err(e) = send_result {
        tracing::warn!(
            peer_addr = %conn.remote_addr,
            error = %e,
            pong_sequence = sequence,
            "Pong send failed, keepalive will handle liveness"
        );
    } else {
        tracing::trace!(
            peer_addr = %conn.remote_addr,
            packet_id,
            pong_sequence = sequence,
            "Pong packet sent"
        );
    }
    Ok(())
}
/// Sends `data` as a single `ShortMessage`, piggybacking the receipts
/// currently pending in the received-packet tracker.
///
/// The congestion controller is charged `data.len() + 40` bytes (presumably an
/// estimate of the per-packet overhead on top of the payload).
#[inline]
pub(crate) async fn outbound_short_message(&mut self, data: SerializedMessage) -> Result<()> {
    let receipts = self.received_tracker.get_receipts();
    let conn = &self.remote_conn;
    let packet_id = conn
        .last_packet_id
        .fetch_add(1, std::sync::atomic::Ordering::Release);
    let estimated_wire_size = data.len() + 40;
    let token = conn
        .congestion_controller
        .on_send_with_token(estimated_wire_size);
    packet_sending(
        conn.remote_addr,
        &conn.socket,
        packet_id,
        &conn.outbound_symmetric_key,
        receipts,
        symmetric_message::ShortMessage(data.into()),
        &conn.sent_tracker,
        token,
    )
    .await
}
async fn outbound_stream(&mut self, data: SerializedMessage) {
let stream_id = StreamId::next();
self.outbound_stream_with_id(stream_id, data.into(), None, None)
.await;
}
/// Spawns a background task that sends `data` as stream `stream_id` via
/// `outbound_stream::send_stream`.
///
/// The task shares this connection's packet-id counter, socket, outbound key,
/// sent tracker, token bucket, congestion controller and time source.
/// `metadata` and `completion_tx` are forwarded unchanged. The join handle is
/// queued in `outbound_stream_futures`.
async fn outbound_stream_with_id(
    &mut self,
    stream_id: StreamId,
    data: bytes::Bytes,
    metadata: Option<bytes::Bytes>,
    completion_tx: Option<tokio::sync::oneshot::Sender<()>>,
) {
    let conn = &self.remote_conn;
    let send_future = outbound_stream::send_stream(
        stream_id,
        conn.last_packet_id.clone(),
        conn.socket.clone(),
        conn.remote_addr,
        data,
        conn.outbound_symmetric_key.clone(),
        conn.sent_tracker.clone(),
        conn.token_bucket.clone(),
        conn.congestion_controller.clone(),
        self.time_source.clone(),
        metadata,
        completion_tx,
    )
    .instrument(span!(tracing::Level::DEBUG, "outbound_stream"));
    let join_handle = GlobalExecutor::spawn(send_future);
    self.outbound_stream_futures.push(join_handle);
}
/// Spawns a background task that forwards the inbound stream behind
/// `inbound_handle` to this peer as stream `outbound_stream_id`, via
/// `outbound_stream::pipe_stream`.
///
/// The task shares this connection's sending state, as `outbound_stream_with_id`
/// does. The join handle is queued in `outbound_stream_futures`.
async fn pipe_stream_to_remote(
    &mut self,
    outbound_stream_id: StreamId,
    inbound_handle: streaming::StreamHandle,
    metadata: Option<bytes::Bytes>,
) {
    let conn = &self.remote_conn;
    let pipe_future = outbound_stream::pipe_stream(
        inbound_handle,
        outbound_stream_id,
        conn.last_packet_id.clone(),
        conn.socket.clone(),
        conn.remote_addr,
        conn.outbound_symmetric_key.clone(),
        conn.sent_tracker.clone(),
        conn.token_bucket.clone(),
        conn.congestion_controller.clone(),
        self.time_source.clone(),
        metadata,
    )
    .instrument(span!(tracing::Level::DEBUG, "pipe_stream"));
    let join_handle = GlobalExecutor::spawn(pipe_future);
    self.outbound_stream_futures.push(join_handle);
}
/// Sends one forwarded stream fragment directly on this connection.
///
/// 1. Waits until the congestion window has room for the fragment. The wait
///    yields 10 times, then sleeps 100 µs per check up to 100 checks, then 1 ms.
/// 2. Waits for the token bucket to admit `packet_size` bytes.
/// 3. Sends the fragment through `packet_sending`, with pending receipts attached.
///
/// When nothing is in flight the fragment is always admitted. Otherwise a
/// congestion window smaller than one fragment would stall the sender forever,
/// because no ACK could ever arrive to open the window.
///
/// # Errors
/// Propagates errors from `packet_sending`.
#[allow(dead_code)]
pub(crate) async fn send_fragment(
    &mut self,
    fragment: piped_stream::ForwardFragment,
) -> Result<()> {
    let packet_size = fragment.payload.len();
    let mut cwnd_wait_iterations = 0;
    loop {
        let flightsize = self.remote_conn.congestion_controller.flightsize();
        let cwnd = self.remote_conn.congestion_controller.current_cwnd();
        // An empty pipe must always accept one fragment: waiting for ACKs that
        // can never come would otherwise loop forever when `packet_size > cwnd`.
        if flightsize == 0 || flightsize + packet_size <= cwnd {
            break;
        }
        cwnd_wait_iterations += 1;
        if cwnd_wait_iterations == 1 {
            tracing::trace!(
                stream_id = %fragment.stream_id.0,
                fragment_number = fragment.fragment_number,
                flightsize_kb = flightsize / 1024,
                cwnd_kb = cwnd / 1024,
                "Waiting for cwnd space in send_fragment"
            );
        }
        if cwnd_wait_iterations <= 10 {
            tokio::task::yield_now().await;
        } else if cwnd_wait_iterations <= 100 {
            tokio::time::sleep(Duration::from_micros(100)).await;
        } else {
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }
    let wait_time = self.remote_conn.token_bucket.reserve(packet_size);
    if !wait_time.is_zero() {
        tracing::trace!(
            stream_id = %fragment.stream_id.0,
            fragment_number = fragment.fragment_number,
            wait_time_ms = wait_time.as_millis(),
            "Rate limiting fragment send"
        );
        tokio::time::sleep(wait_time).await;
    }
    let receipts = self.received_tracker.get_receipts();
    let packet_id = self
        .remote_conn
        .last_packet_id
        .fetch_add(1, std::sync::atomic::Ordering::Release);
    let token = self
        .remote_conn
        .congestion_controller
        .on_send_with_token(packet_size);
    packet_sending(
        self.remote_conn.remote_addr,
        &self.remote_conn.socket,
        packet_id,
        &self.remote_conn.outbound_symmetric_key,
        receipts,
        symmetric_message::StreamFragment {
            stream_id: fragment.stream_id,
            total_length_bytes: fragment.total_bytes,
            fragment_number: fragment.fragment_number,
            payload: fragment.payload,
            metadata_bytes: None,
        },
        &self.remote_conn.sent_tracker,
        token,
    )
    .await?;
    tracing::trace!(
        stream_id = %fragment.stream_id.0,
        fragment_number = fragment.fragment_number,
        packet_id,
        "Fragment sent"
    );
    Ok(())
}
}
/// Serializes, encrypts and sends `payload` as packet `packet_id` to
/// `remote_addr`, piggybacking the receipts in `confirm_receipt`.
///
/// Every datagram handed to the socket is recorded in `sent_tracker` together
/// with `delivery_token`.
///
/// If the payload and receipts fit into one packet, a single datagram is sent.
/// Otherwise the payload goes out on its own and the receipts follow in
/// `NoOp` packets. Each `NoOp` holds at most
/// `SymmetricMessage::max_num_of_confirm_receipts_of_noop_message()` receipts.
///
/// # Errors
/// - Serialization errors are propagated.
/// - Socket errors become `TransportError::SendFailed(remote_addr, kind)`.
/// - In the multi-packet case the first socket error aborts the remaining
///   sends. Packets already sent stay recorded.
#[allow(clippy::too_many_arguments)]
async fn packet_sending<S: super::Socket, T: crate::simulation::TimeSource>(
    remote_addr: SocketAddr,
    socket: &Arc<S>,
    packet_id: u32,
    outbound_sym_key: &Aes128Gcm,
    confirm_receipt: Vec<u32>,
    payload: impl Into<SymmetricMessagePayload>,
    sent_tracker: &parking_lot::Mutex<SentPacketTracker<T>>,
    delivery_token: Option<DeliveryRateToken>,
) -> Result<()> {
    let start_time = tokio::time::Instant::now();
    tracing::trace!(
        peer_addr = %remote_addr,
        packet_id,
        "Attempting to send packet"
    );
    match SymmetricMessage::try_serialize_msg_to_packet_data(
        packet_id,
        payload,
        outbound_sym_key,
        confirm_receipt,
    )? {
        either::Either::Left(packet) => {
            let packet_size = packet.data().len();
            tracing::trace!(
                peer_addr = %remote_addr,
                packet_id,
                packet_size,
                "Sending single packet"
            );
            let packet_data = packet.prepared_send();
            match socket.send_to(&packet_data, remote_addr).await {
                Ok(_) => {
                    let elapsed = start_time.elapsed();
                    tracing::trace!(
                        peer_addr = %remote_addr,
                        packet_id,
                        elapsed_ms = elapsed.as_millis(),
                        "Successfully sent packet"
                    );
                    sent_tracker.lock().report_sent_packet_with_token(
                        packet_id,
                        packet_data,
                        delivery_token,
                    );
                    Ok(())
                }
                Err(e) => {
                    tracing::warn!(
                        peer_addr = %remote_addr,
                        packet_id,
                        error = %e,
                        "Failed to send packet (transient)"
                    );
                    Err(TransportError::SendFailed(remote_addr, e.kind()))
                }
            }
        }
        either::Either::Right((payload, confirm_receipt)) => {
            tracing::trace!(
                peer_addr = %remote_addr,
                packet_id,
                "Sending multi-packet message"
            );
            // NOTE(review): every datagram below reuses `packet_id` and the same
            // `delivery_token`. If the sent tracker keys entries by packet id,
            // later reports overwrite earlier ones — confirm this is intended.
            macro_rules! send {
                ($packets:ident) => {{
                    for packet in $packets {
                        let packet_data = packet.prepared_send();
                        socket
                            .send_to(&packet_data, remote_addr)
                            .await
                            .map_err(|e| TransportError::SendFailed(remote_addr, e.kind()))?;
                        sent_tracker.lock().report_sent_packet_with_token(
                            packet_id,
                            packet_data,
                            delivery_token,
                        );
                    }
                }};
            }
            let max_num = SymmetricMessage::max_num_of_confirm_receipts_of_noop_message();
            let packet = SymmetricMessage::serialize_msg_to_packet_data(
                packet_id,
                payload,
                outbound_sym_key,
                vec![],
            )?;
            if max_num > confirm_receipt.len() {
                // All receipts fit into a single NoOp.
                let packets = [
                    packet,
                    SymmetricMessage::serialize_msg_to_packet_data(
                        packet_id,
                        SymmetricMessagePayload::NoOp,
                        outbound_sym_key,
                        confirm_receipt,
                    )?,
                ];
                send!(packets);
                return Ok(());
            }
            // Split the receipts into consecutive chunks of at most `max_num` so
            // no NoOp exceeds its receipt capacity. `max(1)` keeps `chunks` from
            // panicking if the capacity is ever reported as zero.
            let chunk_len = max_num.max(1);
            let mut packets = Vec::with_capacity(2 + confirm_receipt.len() / chunk_len);
            packets.push(packet);
            for receipts in confirm_receipt.chunks(chunk_len) {
                packets.push(SymmetricMessage::serialize_msg_to_packet_data(
                    packet_id,
                    SymmetricMessagePayload::NoOp,
                    outbound_sym_key,
                    receipts.to_vec(),
                )?);
            }
            send!(packets);
            Ok(())
        }
    }
}
/// Adapter exposing `PeerConnection` through `PeerConnectionApi`.
///
/// Each async operation is returned as a pinned, boxed `Send` future
/// (presumably so the trait can be used as a trait object).
impl<S: super::Socket> super::PeerConnectionApi for PeerConnection<S> {
    fn remote_addr(&self) -> std::net::SocketAddr {
        PeerConnection::remote_addr(self)
    }
    fn send_message(
        &mut self,
        msg: crate::message::NetMessage,
    ) -> std::pin::Pin<
        Box<dyn futures::Future<Output = Result<(), super::TransportError>> + Send + '_>,
    > {
        let send_future = async move { self.send(msg).await };
        Box::pin(send_future)
    }
    fn recv(
        &mut self,
    ) -> std::pin::Pin<
        Box<dyn futures::Future<Output = Result<Vec<u8>, super::TransportError>> + Send + '_>,
    > {
        // Fully qualified call: the inherent `recv`, not this trait method.
        let recv_future = async move { PeerConnection::recv(self).await };
        Box::pin(recv_future)
    }
    fn set_orphan_stream_registry(&mut self, registry: std::sync::Arc<OrphanStreamRegistry>) {
        PeerConnection::set_orphan_stream_registry(self, registry);
    }
    fn send_stream_data(
        &mut self,
        stream_id: StreamId,
        data: bytes::Bytes,
        metadata: Option<bytes::Bytes>,
        completion_tx: Option<tokio::sync::oneshot::Sender<()>>,
    ) -> std::pin::Pin<
        Box<dyn futures::Future<Output = Result<(), super::TransportError>> + Send + '_>,
    > {
        let spawn_future = async move {
            self.outbound_stream_with_id(stream_id, data, metadata, completion_tx)
                .await;
            Ok::<(), super::TransportError>(())
        };
        Box::pin(spawn_future)
    }
    fn pipe_stream_data(
        &mut self,
        outbound_stream_id: StreamId,
        inbound_handle: streaming::StreamHandle,
        metadata: Option<bytes::Bytes>,
    ) -> std::pin::Pin<
        Box<dyn futures::Future<Output = Result<(), super::TransportError>> + Send + '_>,
    > {
        let spawn_future = async move {
            self.pipe_stream_to_remote(outbound_stream_id, inbound_handle, metadata)
                .await;
            Ok::<(), super::TransportError>(())
        };
        Box::pin(spawn_future)
    }
}
#[cfg(test)]
mod tests {
use aes_gcm::KeyInit;
use futures::TryFutureExt;
use std::net::Ipv4Addr;
use super::{
inbound_stream::{InboundStream, recv_stream},
outbound_stream::send_stream,
*,
};
use crate::transport::packet_data::MAX_PACKET_SIZE;
use crate::transport::received_packet_tracker::MAX_PENDING_RECEIPTS;
use crate::transport::sent_packet_tracker::MAX_CONFIRMATION_DELAY;
/// In-memory test socket: every datagram given to `send_to` /
/// `send_to_blocking` is forwarded with its target address to `sender`.
/// Binding and receiving are not supported.
struct TestSocket {
    sender: mpsc::Sender<(SocketAddr, Arc<[u8]>)>,
}
impl TestSocket {
    fn new(sender: mpsc::Sender<(SocketAddr, Arc<[u8]>)>) -> Self {
        TestSocket { sender }
    }
}
impl crate::transport::Socket for TestSocket {
    async fn bind(_addr: SocketAddr) -> std::io::Result<Self> {
        unimplemented!()
    }
    async fn recv_from(&self, _buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        unimplemented!()
    }
    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
        // A closed channel is reported as a connection abort.
        match self.sender.send((target, Arc::from(buf))).await {
            Ok(()) => Ok(buf.len()),
            Err(_) => Err(std::io::ErrorKind::ConnectionAborted.into()),
        }
    }
    fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
        self.sender
            .blocking_send((target, Arc::from(buf)))
            .map(|()| buf.len())
            .map_err(|_| std::io::ErrorKind::ConnectionAborted.into())
    }
}
/// Test socket like `TestSocket`, except that while `fail_sends` is `true`
/// every send fails with a simulated `NetworkUnreachable` error instead of
/// being forwarded to `sender`.
struct FailableTestSocket {
    sender: mpsc::Sender<(SocketAddr, Arc<[u8]>)>,
    fail_sends: Arc<std::sync::atomic::AtomicBool>,
}
impl FailableTestSocket {
    fn new(
        sender: mpsc::Sender<(SocketAddr, Arc<[u8]>)>,
        fail_sends: Arc<std::sync::atomic::AtomicBool>,
    ) -> Self {
        FailableTestSocket { sender, fail_sends }
    }
    /// Returns the simulated `ENETUNREACH` error while failure injection is on.
    fn injected_failure(&self) -> std::io::Result<()> {
        if !self.fail_sends.load(std::sync::atomic::Ordering::Relaxed) {
            return Ok(());
        }
        Err(std::io::Error::new(
            std::io::ErrorKind::NetworkUnreachable,
            "simulated ENETUNREACH",
        ))
    }
}
impl crate::transport::Socket for FailableTestSocket {
    async fn bind(_addr: SocketAddr) -> std::io::Result<Self> {
        unimplemented!()
    }
    async fn recv_from(&self, _buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        unimplemented!()
    }
    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
        self.injected_failure()?;
        match self.sender.send((target, Arc::from(buf))).await {
            Ok(()) => Ok(buf.len()),
            Err(_) => Err(std::io::ErrorKind::ConnectionAborted.into()),
        }
    }
    fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
        self.injected_failure()?;
        self.sender
            .blocking_send((target, Arc::from(buf)))
            .map(|()| buf.len())
            .map_err(|_| std::io::ErrorKind::ConnectionAborted.into())
    }
}
/// A socket-level send error must surface from `packet_sending` as a transient
/// `SendFailed`, never as `ConnectionClosed`.
#[test]
fn send_failure_returns_transient_error() {
    let remote_addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
    let always_fail = Arc::new(std::sync::atomic::AtomicBool::new(true));
    let (tx, _rx) = mpsc::channel(16);
    let socket = Arc::new(FailableTestSocket::new(tx, always_fail));
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let err = runtime.block_on(async {
        let outbound_key = Aes128Gcm::new(&[0u8; 16].into());
        let sent_tracker = Arc::new(parking_lot::Mutex::new(
            crate::transport::sent_packet_tracker::tests::mock_sent_packet_tracker(),
        ));
        packet_sending(
            remote_addr,
            &socket,
            1,
            &outbound_key,
            vec![],
            (),
            &sent_tracker,
            None,
        )
        .await
        .expect_err("should fail")
    });
    assert!(
        err.is_transient_send_failure(),
        "expected SendFailed, got: {err:?}"
    );
    assert!(
        !matches!(err, TransportError::ConnectionClosed(_)),
        "should NOT be ConnectionClosed"
    );
}
/// The ACK check must run at least as often as the peer's maximum
/// confirmation delay.
#[test]
fn ack_check_interval_is_within_confirmation_window() {
    let within_window = ACK_CHECK_INTERVAL <= MAX_CONFIRMATION_DELAY;
    assert!(
        within_window,
        "ACK_CHECK_INTERVAL ({:?}) must not exceed MAX_CONFIRMATION_DELAY ({:?})",
        ACK_CHECK_INTERVAL, MAX_CONFIRMATION_DELAY
    );
}
/// Keeps the ACK check interval between 10 ms and 100 ms inclusive.
#[test]
fn ack_check_interval_is_reasonable() {
    let floor = Duration::from_millis(10);
    let ceiling = Duration::from_millis(100);
    assert!(
        ACK_CHECK_INTERVAL >= floor,
        "ACK_CHECK_INTERVAL ({:?}) should be at least 10ms to avoid excessive CPU usage",
        ACK_CHECK_INTERVAL
    );
    assert!(
        ACK_CHECK_INTERVAL <= ceiling,
        "ACK_CHECK_INTERVAL ({:?}) should be at most 100ms to ensure timely ACK delivery",
        ACK_CHECK_INTERVAL
    );
}
/// Tripwire: fails whenever `MAX_PENDING_RECEIPTS` changes, so ACK timing gets
/// re-reviewed.
#[test]
fn pending_receipts_buffer_size_documented() {
    let expected = 20;
    assert_eq!(
        MAX_PENDING_RECEIPTS, expected,
        "MAX_PENDING_RECEIPTS changed - verify ACK timing behavior is still correct"
    );
}
/// End-to-end check of the legacy stream path.
///
/// `send_stream` fragments and encrypts a 1000-byte random message over a
/// `TestSocket`. Each captured datagram is decrypted and its `StreamFragment`
/// is forwarded to `recv_stream`. The reassembled bytes must equal the
/// original message.
#[tokio::test]
async fn test_inbound_outbound_interaction() -> Result<(), Box<dyn std::error::Error>> {
const MSG_LEN: usize = 1000;
// Capacity-1 channel: the socket's sends wait until the loop below has taken
// the previous datagram.
let (sender, mut receiver) = mpsc::channel(1);
let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
let mut message = vec![0u8; MSG_LEN];
crate::config::GlobalRng::fill_bytes(&mut message);
let mut key = [0u8; 16];
crate::config::GlobalRng::fill_bytes(&mut key);
let cipher = Aes128Gcm::new(&key.into());
let time_source = crate::simulation::VirtualTime::new();
let sent_tracker = Arc::new(parking_lot::Mutex::new(
SentPacketTracker::new_with_time_source(time_source.clone()),
));
// Fixed small starting window (initial == min cwnd) with a very large ceiling.
let congestion_controller =
crate::transport::congestion_control::CongestionControlConfig::default()
.with_initial_cwnd(2928)
.with_min_cwnd(2928)
.with_max_cwnd(1_000_000_000)
.build_arc_with_time_source(time_source.clone());
let token_bucket = Arc::new(TokenBucket::new_with_time_source(
10_000,
10_000_000,
time_source.clone(),
));
let stream_id = StreamId::next();
// Sender side: runs as a spawned task that owns the only `TestSocket`.
let outbound = GlobalExecutor::spawn(send_stream(
stream_id,
Arc::new(AtomicU32::new(0)),
Arc::new(TestSocket::new(sender)),
remote_addr,
bytes::Bytes::from(message.clone()),
cipher.clone(),
sent_tracker,
token_bucket,
congestion_controller,
time_source,
None,
None,
))
.map_err(|e| e.into());
// Receiver side: decrypt each datagram and feed the fragment to `recv_stream`.
let inbound = async {
let (tx, rx) = mpsc::channel(1);
let stream = InboundStream::new(MSG_LEN as u64);
let inbound_msg = GlobalExecutor::spawn(recv_stream(stream_id, rx, stream));
// Ends once every `Sender` for `receiver` is dropped (presumably when the
// send_stream task finishes and drops its socket).
while let Some((_, network_packet)) = receiver.recv().await {
let decrypted = PacketData::<_, MAX_PACKET_SIZE>::from_buf(&network_packet)
.try_decrypt_sym(&cipher)
.map_err(|e| e.to_string())?;
// Anything other than a StreamFragment fails the test.
let SymmetricMessage {
payload:
SymmetricMessagePayload::StreamFragment {
fragment_number,
payload,
..
},
..
} = SymmetricMessage::deser(decrypted.data()).expect("symmetric message")
else {
return Err("unexpected message".into());
};
tx.send((fragment_number, payload)).await?;
}
let (_, msg) = inbound_msg
.await?
.map_err(|_| anyhow::anyhow!("stream failed"))?;
Ok::<_, Box<dyn std::error::Error>>(msg)
};
// Drive both sides concurrently; fail fast if either errors.
let (out_res, inbound_msg) = tokio::try_join!(outbound, inbound)?;
out_res?;
assert_eq!(message, inbound_msg);
Ok(())
}
/// Guards the assumption that bincode serialization of typical transport
/// payloads finishes in under 10 ms, so it can run inline on the async
/// executor instead of through `spawn_blocking`.
#[test]
fn bincode_serialization_is_fast_enough_for_async() {
    use crate::message::{NeighborHostingMessage, NetMessage, NetMessageV1};
    use freenet_stdlib::prelude::ContractInstanceId;
    use std::time::Instant;

    fn assert_fast_serialize<T: serde::Serialize>(name: &str, value: &T) {
        let started = Instant::now();
        let encoded = bincode::serialize(value).expect("serialization failed");
        let took = started.elapsed();
        assert!(
            took.as_millis() < 10,
            "{} serialization ({} bytes) took {:?}, expected < 10ms. \
             If this fails consistently, reconsider whether spawn_blocking is needed.",
            name,
            encoded.len(),
            took
        );
    }

    // Raw byte payloads up to the maximum data size of one packet.
    [100, 1000, MAX_DATA_SIZE / 2, MAX_DATA_SIZE]
        .into_iter()
        .for_each(|size| {
            let payload: Vec<u8> = (0..size).map(|i| i as u8).collect();
            assert_fast_serialize(&format!("Vec<u8>[{}]", size), &payload);
        });

    let announce = NetMessage::V1(NetMessageV1::NeighborHosting {
        message: NeighborHostingMessage::HostingAnnounce {
            added: vec![ContractInstanceId::new([1u8; 32])],
            removed: vec![],
            is_response: false,
        },
    });
    assert_fast_serialize("NeighborHostingMessage", &announce);

    let state_response = NetMessage::V1(NetMessageV1::NeighborHosting {
        message: NeighborHostingMessage::HostingStateResponse {
            contracts: (0..100)
                .map(|i| ContractInstanceId::new([i as u8; 32]))
                .collect(),
        },
    });
    assert_fast_serialize("Large HostingStateResponse", &state_response);
}
/// Flightsize must be decremented for every ACKed packet. This includes ACKs
/// of retransmitted packets, which (Karn's algorithm) carry no RTT sample.
#[test]
fn test_flightsize_accounting_for_retransmitted_packets() {
    use crate::transport::bbr::{BbrConfig, BbrController};
    use crate::transport::sent_packet_tracker::SentPacketTracker;
    use std::sync::Arc;
    let packet_size = 1424;
    let congestion = Arc::new(BbrController::new(BbrConfig {
        initial_cwnd: 38_000,
        min_cwnd: 2_000,
        max_cwnd: 10_000_000,
        ..Default::default()
    }));
    let mut tracker = SentPacketTracker::new();

    // Send packets 0..5 and charge each one to the controller.
    for packet_id in 0..5u32 {
        tracker.report_sent_packet(packet_id, vec![0u8; packet_size].into_boxed_slice());
        congestion.on_send(packet_size);
    }
    assert_eq!(
        congestion.flightsize(),
        5 * packet_size,
        "Initial flightsize should be 5 * packet_size"
    );

    // ACK 0..=2: never retransmitted, so each ACK carries an RTT sample.
    let (ack_info, _) = tracker.report_received_receipts(&[0, 1, 2]);
    assert_eq!(ack_info.len(), 3, "Should have 3 ACK entries");
    for (rtt_opt, size, _token) in &ack_info {
        assert!(
            rtt_opt.is_some(),
            "Non-retransmitted packets should have RTT samples"
        );
        assert_eq!(*size, packet_size, "Packet size should match");
    }
    for (rtt_opt, size, _token) in ack_info {
        if let Some(rtt) = rtt_opt {
            congestion.on_ack(rtt, size);
        } else {
            congestion.on_ack_without_rtt(size);
        }
    }
    assert_eq!(
        congestion.flightsize(),
        2 * packet_size,
        "Flightsize should be decremented to 2 * packet_size"
    );

    // Retransmit 3, then ACK it: no RTT sample, but the size is still reported.
    tracker.mark_retransmitted(3);
    tracker.report_sent_packet(3, vec![0u8; packet_size].into_boxed_slice());
    let (ack_info, _) = tracker.report_received_receipts(&[3]);
    assert_eq!(
        ack_info.len(),
        1,
        "Should have 1 ACK entry for retransmitted packet"
    );
    let (rtt_opt, size, _token) = &ack_info[0];
    assert!(
        rtt_opt.is_none(),
        "Retransmitted packets should NOT have RTT samples (Karn's algorithm)"
    );
    assert_eq!(
        *size, packet_size,
        "Retransmitted packet ACK MUST still return packet size for flightsize decrement"
    );
    for (rtt_opt, size, _token) in ack_info {
        if let Some(rtt) = rtt_opt {
            congestion.on_ack(rtt, size);
        } else {
            congestion.on_ack_without_rtt(size);
        }
    }
    assert_eq!(
        congestion.flightsize(),
        packet_size,
        "Flightsize should be decremented even for retransmitted packet ACKs"
    );

    // ACK the last packet: the pipe must be empty afterwards.
    let (ack_info, _) = tracker.report_received_receipts(&[4]);
    for (rtt_opt, size, _token) in ack_info {
        if let Some(rtt) = rtt_opt {
            congestion.on_ack(rtt, size);
        } else {
            congestion.on_ack_without_rtt(size);
        }
    }
    assert_eq!(
        congestion.flightsize(),
        0,
        "All packets ACKed, flightsize should be 0"
    );
}
/// Pins the keepalive back-off schedule:
/// - the base interval up to `MAX_UNANSWERED_PINGS` pending pings;
/// - then 10 s / 20 s / 40 s for 1 / 2 / 3 extra pending pings;
/// - `MAX_KEEPALIVE_INTERVAL` from 4 extra onward.
#[test]
fn keepalive_interval_backs_off_for_unanswered_pings() {
    use super::{
        KEEP_ALIVE_INTERVAL, MAX_KEEPALIVE_INTERVAL, MAX_UNANSWERED_PINGS,
        keepalive_interval_for_pending,
    };
    (0..=MAX_UNANSWERED_PINGS).for_each(|count| {
        assert_eq!(
            keepalive_interval_for_pending(count),
            KEEP_ALIVE_INTERVAL,
            "pending_count={count} should use base interval"
        );
    });
    let backoff_table = [
        (1, Duration::from_secs(10)),
        (2, Duration::from_secs(20)),
        (3, Duration::from_secs(40)),
        (4, MAX_KEEPALIVE_INTERVAL),
        (100, MAX_KEEPALIVE_INTERVAL),
    ];
    for (extra, want) in backoff_table {
        let pending = MAX_UNANSWERED_PINGS + extra;
        assert_eq!(
            keepalive_interval_for_pending(pending),
            want,
            "pending_count={} should produce {:?}",
            pending,
            want,
        );
    }
}
/// Test-local message fingerprint: the `DefaultHasher` digest of the byte
/// slice, including the length prefix that `Hash for [u8]` writes.
fn msg_hash(bytes: &[u8]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(bytes, &mut hasher);
    hasher.finish()
}
/// Records the fingerprint of `bytes` in `cache`. Returns `true` when it was
/// already present, i.e. the same bytes were seen before.
fn is_duplicate_dispatch(cache: &mut lru::LruCache<u64, ()>, bytes: &[u8]) -> bool {
    let previous = cache.put(msg_hash(bytes), ());
    previous.is_some()
}
/// Builds an LRU fingerprint cache holding at most `cap` entries.
///
/// # Panics
/// Panics if `cap` is zero.
fn dedup_cache(cap: usize) -> lru::LruCache<u64, ()> {
    let capacity = NonZeroUsize::new(cap).unwrap();
    lru::LruCache::new(capacity)
}
// NOTE(review): this test carries a bare `#[ignore]` with no stated reason.
// It only exercises `LruCache::put` directly, not the ShortMessage dispatch
// path its name describes. Confirm whether it should be re-enabled (ideally as
// `#[ignore = "reason"]` if it must stay ignored) or rewritten against
// `process_inbound`.
#[ignore]
#[test]
fn dispatched_short_message_always_recorded() {
let mut cache = dedup_cache(1000);
let bytes = b"metadata-payload";
// `put` returns the previous value: `None` means the fingerprint was new.
assert!(
cache.put(msg_hash(bytes), ()).is_none(),
"first insert should return None (not present)"
);
assert!(
cache.put(msg_hash(bytes), ()).is_some(),
"second insert returns Some (was present)"
);
}
/// The second dispatch of identical bytes is reported as a duplicate.
#[test]
fn duplicate_embedded_metadata_suppressed() {
    let mut seen = dedup_cache(1000);
    let metadata = b"metadata-payload";
    let first_is_dup = is_duplicate_dispatch(&mut seen, metadata);
    assert!(!first_is_dup, "first insert is not a duplicate");
    let second_is_dup = is_duplicate_dispatch(&mut seen, metadata);
    assert!(second_is_dup, "second insert is a duplicate");
}
/// Distinct payloads never count as duplicates of each other.
#[test]
fn different_metadata_not_suppressed() {
    let mut cache = dedup_cache(1000);
    assert!(!is_duplicate_dispatch(&mut cache, b"first-metadata"));
    assert!(!is_duplicate_dispatch(&mut cache, b"different-metadata"));
}
/// A capacity-3 cache that receives four fingerprints evicts the oldest one
/// and keeps the newest ones.
#[test]
fn lru_eviction_preserves_recent_entries() {
    let mut cache = dedup_cache(3);
    (1..=4)
        .map(|i| format!("msg-{i}"))
        .for_each(|key| {
            cache.put(msg_hash(key.as_bytes()), ());
        });
    assert!(cache.contains(&msg_hash(b"msg-4")), "msg-4 should survive");
    assert!(cache.contains(&msg_hash(b"msg-3")), "msg-3 should survive");
    assert!(
        !cache.contains(&msg_hash(b"msg-1")),
        "msg-1 should be evicted"
    );
}
/// Regression test for bug #3369.
///
/// `last_received_nanos` is set when the `PeerConnection` is created and must
/// stay unchanged while the mock clock advances with no inbound packets.
#[tokio::test]
async fn last_received_nanos_is_persistent_field() {
use crate::transport::crypto::TransportKeypair;
use crate::util::time_source::SharedMockTimeSource;
let time_source = SharedMockTimeSource::new();
// `_inbound_tx` is kept alive so the inbound channel stays open; nothing is sent.
let (_inbound_tx, inbound_rx) = mpsc::channel(16);
let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 9999);
let mut key = [0u8; 16];
crate::config::GlobalRng::fill_bytes(&mut key);
let cipher = Aes128Gcm::new(&key.into());
let keypair = TransportKeypair::new();
let sent_tracker = Arc::new(parking_lot::Mutex::new(
SentPacketTracker::new_with_time_source(time_source.clone()),
));
let congestion_controller =
crate::transport::congestion_control::CongestionControlConfig::default()
.build_arc_with_time_source(time_source.clone());
let token_bucket = Arc::new(TokenBucket::new_with_time_source(
10_000,
10_000_000,
time_source.clone(),
));
// Outbound socket whose channel receiver is dropped immediately; unused here.
let socket = Arc::new(TestSocket::new(
mpsc::channel::<(SocketAddr, Arc<[u8]>)>(16).0,
));
let rolling_rtt_stats = crate::transport::rolling_rtt_stats::RollingRttStatsHandle::new(
remote_addr,
time_source.clone(),
);
// The same key serves as both inbound and outbound symmetric key.
let remote_conn = RemoteConnection {
outbound_symmetric_key: cipher.clone(),
remote_addr,
sent_tracker,
last_packet_id: Arc::new(AtomicU32::new(0)),
inbound_packet_recv: inbound_rx,
inbound_symmetric_key: cipher,
inbound_symmetric_key_bytes: key,
my_address: None,
transport_secret_key: keypair.secret,
congestion_controller,
token_bucket,
socket,
global_bandwidth: None,
rolling_rtt_stats,
time_source: time_source.clone(),
};
let creation_time = time_source.now_nanos();
let conn = PeerConnection::new(remote_conn);
assert_eq!(
conn.last_received_nanos, creation_time,
"last_received_nanos should be set to creation time"
);
// Advance the mock clock 5 × 10 s; the stored timestamp must not move.
for i in 0..5 {
time_source.advance_time(Duration::from_secs(10));
let now = time_source.now_nanos();
assert_ne!(now, creation_time, "time should have advanced");
assert_eq!(
conn.last_received_nanos, creation_time,
"last_received_nanos must not change without inbound packets \
(iteration {i}, bug #3369). A local variable in recv() would \
return now()={now} instead of the stored creation_time={creation_time}"
);
}
}
/// `recv()` must record transport metrics for a successfully decrypted packet.
///
/// The per-peer `bytes_received` must equal the encrypted wire size, and the
/// global cumulative counter must grow by at least that amount.
#[tokio::test]
async fn recv_records_packet_metrics_post_authentication() {
use crate::transport::crypto::TransportKeypair;
use crate::transport::packet_data::PacketData;
use crate::transport::symmetric_message::SymmetricMessagePayload;
use crate::util::time_source::SharedMockTimeSource;
use bytes::Bytes;
let time_source = SharedMockTimeSource::new();
let (inbound_tx, inbound_rx) = mpsc::channel(16);
// Unusual address so this test's entry in the global TRANSPORT_METRICS is
// not shared with other tests (checked as a precondition below).
let remote_addr = SocketAddr::new(Ipv4Addr::new(10, 99, 99, 1).into(), 50001);
let mut key = [0u8; 16];
crate::config::GlobalRng::fill_bytes(&mut key);
let cipher = Aes128Gcm::new(&key.into());
let keypair = TransportKeypair::new();
let sent_tracker = Arc::new(parking_lot::Mutex::new(
SentPacketTracker::new_with_time_source(time_source.clone()),
));
let congestion_controller =
crate::transport::congestion_control::CongestionControlConfig::default()
.build_arc_with_time_source(time_source.clone());
let token_bucket = Arc::new(TokenBucket::new_with_time_source(
10_000,
10_000_000,
time_source.clone(),
));
// Outbound socket whose channel receiver is dropped immediately; only the
// receive path is exercised here.
let socket = Arc::new(TestSocket::new(
mpsc::channel::<(SocketAddr, Arc<[u8]>)>(16).0,
));
let rolling_rtt_stats = crate::transport::rolling_rtt_stats::RollingRttStatsHandle::new(
remote_addr,
time_source.clone(),
);
let remote_conn = RemoteConnection {
outbound_symmetric_key: cipher.clone(),
remote_addr,
sent_tracker,
last_packet_id: Arc::new(AtomicU32::new(0)),
inbound_packet_recv: inbound_rx,
inbound_symmetric_key: cipher.clone(),
inbound_symmetric_key_bytes: key,
my_address: None,
transport_secret_key: keypair.secret,
congestion_controller,
token_bucket,
socket,
global_bandwidth: None,
rolling_rtt_stats,
time_source,
};
// Encrypt a 100-byte ShortMessage with the connection's inbound key.
let plaintext: Vec<u8> = (0..100u8).collect();
let payload = SymmetricMessagePayload::ShortMessage {
payload: Bytes::from(plaintext.clone()),
};
let encrypted = SymmetricMessage::serialize_msg_to_packet_data(1, payload, &cipher, vec![])
.expect("encrypt");
let encrypted_bytes = encrypted.data().to_vec();
let wire_size = encrypted_bytes.len() as u64;
assert!(
wire_size > plaintext.len() as u64,
"encrypted wire size ({wire_size}) must exceed plaintext size ({}) — \
AEAD tag + nonce + framing always grows the packet",
plaintext.len()
);
// Queue the raw datagram as if it had arrived from the network.
let packet = PacketData::<crate::transport::packet_data::UnknownEncryption>::from_buf(
&encrypted_bytes,
);
inbound_tx.send(packet).await.expect("send");
let metrics_before_recv = crate::transport::TRANSPORT_METRICS
.per_peer_snapshot()
.into_iter()
.find(|(a, _, _)| *a == remote_addr);
assert!(
metrics_before_recv.is_none(),
"test precondition: no entry for our unique remote_addr"
);
let cumulative_before = crate::transport::TRANSPORT_METRICS.cumulative_bytes_received();
let mut conn = PeerConnection::new(remote_conn);
let _msg = conn.recv().await.expect("recv");
let entry = crate::transport::TRANSPORT_METRICS
.per_peer_snapshot()
.into_iter()
.find(|(a, _, _)| *a == remote_addr)
.expect("post-auth recv must create a per-peer entry");
assert_eq!(
entry.2, wire_size,
"per-peer bytes_received must equal the encrypted wire size"
);
// `>=` because other tests may add to the global counter concurrently.
let cumulative_after = crate::transport::TRANSPORT_METRICS.cumulative_bytes_received();
assert!(
cumulative_after >= cumulative_before + wire_size,
"cumulative_bytes_received must include this packet (before={cumulative_before}, \
after={cumulative_after}, wire_size={wire_size})"
);
// Clean up the global per-peer entry. NOTE(review): this is skipped if any
// assertion above panics.
crate::transport::TRANSPORT_METRICS.remove_peer(remote_addr);
}
/// `sweep_streaming_handles_inner` must evict a handle idle for longer than the
/// timeout, removing it from both `streaming_handles` and the registry. The
/// evicted handle must also be cancelled, so a late fragment is rejected. A
/// recently active handle must be kept.
#[test]
fn sweep_streaming_handles_inner_drops_idle_entries() {
    use streaming::{StreamError, StreamRegistry};
    let now_nanos = STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64 * 4;
    // Last activity at t = 0, four idle timeouts before `now_nanos`.
    let stale_nanos = 0u64;
    // Last activity one second ago; expected to be inside the idle timeout.
    let fresh_nanos = now_nanos.saturating_sub(Duration::from_secs(1).as_nanos() as u64);
    let registry = StreamRegistry::new();
    let stale_id = StreamId::next();
    let fresh_id = StreamId::next();
    let stale_handle = registry.register(stale_id, 64 * 1024);
    let fresh_handle = registry.register(fresh_id, 64 * 1024);
    // Keep a clone of the stale handle to observe its cancellation afterwards.
    let stale_clone = stale_handle.clone();
    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    streaming_handles.insert(stale_id, (stale_handle, stale_nanos));
    streaming_handles.insert(fresh_id, (fresh_handle, fresh_nanos));
    assert_eq!(streaming_handles.len(), 2);
    assert_eq!(registry.stream_count(), 2);
    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9000);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );
    assert_eq!(
        streaming_handles.len(),
        1,
        "stale entry should have been swept out of streaming_handles"
    );
    assert!(streaming_handles.contains_key(&fresh_id));
    assert!(!streaming_handles.contains_key(&stale_id));
    assert_eq!(
        registry.stream_count(),
        1,
        "stale entry should have been removed from streaming_registry"
    );
    let push_result =
        stale_clone.push_fragment(1, bytes::Bytes::from_static(b"would-leak-without-fix"));
    assert!(
        matches!(push_result, Err(StreamError::Cancelled)),
        "stale handle should be cancelled; got {:?}",
        push_result
    );
}
/// An entry idle for one nanosecond less than the threshold must be left in
/// both the map and the registry, and its handle must remain usable.
#[test]
fn sweep_streaming_handles_inner_keeps_recent_entries() {
    use streaming::StreamRegistry;

    let registry = StreamRegistry::new();
    let id = StreamId::next();
    let handle = registry.register(id, 64 * 1024);

    // idle = now - last_activity = threshold - 1 ns, i.e. just inside the window.
    let now_nanos = Duration::from_secs(3600).as_nanos() as u64;
    let last_activity_nanos =
        now_nanos.saturating_sub(STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64) + 1;

    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    streaming_handles.insert(id, (handle.clone(), last_activity_nanos));

    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9001);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );

    assert_eq!(
        streaming_handles.len(),
        1,
        "in-window entry must not be swept"
    );
    assert_eq!(registry.stream_count(), 1);

    let result = handle.push_fragment(1, bytes::Bytes::from_static(b"still-alive"));
    assert!(
        result.is_ok(),
        "in-window handle must not be cancelled by sweep, got {:?}",
        result
    );
}
/// Pins the boundary semantics: an entry idle for exactly the threshold is
/// kept, and one idle for threshold + 1 ns is dropped (strict `>`).
#[test]
fn sweep_streaming_handles_inner_threshold_boundary_is_strict_greater_than() {
    use streaming::StreamRegistry;

    let registry = StreamRegistry::new();
    let exact_id = StreamId::next();
    let past_id = StreamId::next();
    let exact_handle = registry.register(exact_id, 64 * 1024);
    let past_handle = registry.register(past_id, 64 * 1024);

    let threshold_nanos = STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64;
    // Start far enough from zero that the subtractions below cannot underflow.
    let now_nanos = threshold_nanos * 10;
    let exact_last_activity = now_nanos - threshold_nanos;
    let past_last_activity = now_nanos - threshold_nanos - 1;

    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    streaming_handles.insert(exact_id, (exact_handle, exact_last_activity));
    streaming_handles.insert(past_id, (past_handle, past_last_activity));

    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9002);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );

    assert!(
        streaming_handles.contains_key(&exact_id),
        "entry at idle = threshold must be KEPT (strict greater-than)"
    );
    assert!(
        !streaming_handles.contains_key(&past_id),
        "entry at idle = threshold + 1 ns must be DROPPED"
    );
    // The registry must agree with the map about which entry was swept.
    assert_eq!(
        registry.stream_count(),
        1,
        "only the threshold + 1 ns entry should be removed from the registry"
    );
}
/// Sweeping twice at the same instant must leave the map and registry empty
/// after the first pass; the second pass over nothing must be a no-op.
#[test]
fn sweep_streaming_handles_inner_is_idempotent_on_repeat() {
    use streaming::StreamRegistry;

    let registry = StreamRegistry::new();
    let id = StreamId::next();
    let handle = registry.register(id, 64 * 1024);
    let now_nanos = STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64 * 4;

    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    // Last activity at t = 0: four windows idle, so the first sweep removes it.
    streaming_handles.insert(id, (handle, 0));

    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9003);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );
    assert_eq!(streaming_handles.len(), 0);
    assert_eq!(registry.stream_count(), 0);

    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );
    assert_eq!(streaming_handles.len(), 0);
    assert_eq!(registry.stream_count(), 0);
}
/// After a sweep removes an idle entry, neither the map nor the registry may
/// keep a strong reference to the stream buffer; only the test's own handle
/// clone should remain.
#[test]
fn sweep_streaming_handles_inner_releases_buffer_arc() {
    use streaming::StreamRegistry;

    let registry = StreamRegistry::new();
    let id = StreamId::next();
    let handle = registry.register(id, 64 * 1024);
    let initial_strong = handle.buffer_strong_count();
    assert_eq!(
        initial_strong, 2,
        "registry + local handle = 2 buffer clones at start"
    );

    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    // The map's clone adds a third strong reference.
    streaming_handles.insert(id, (handle.clone(), 0));
    assert_eq!(handle.buffer_strong_count(), 3);

    let now_nanos = STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64 * 4;
    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9004);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );

    assert_eq!(
        handle.buffer_strong_count(),
        1,
        "post-sweep, only the test's local handle clone of \
         Arc<LockFreeStreamBuffer> should remain — production code \
         must release the slot-array allocation, see issue #4079"
    );
}
/// With five fresh and five stale entries, one sweep must keep exactly the
/// fresh ones and drop exactly the stale ones, in both map and registry.
#[test]
fn sweep_streaming_handles_inner_partitions_mixed_age_entries() {
    use streaming::StreamRegistry;

    let registry = StreamRegistry::new();
    let threshold_nanos = STREAMING_HANDLE_IDLE_TIMEOUT_PROD.as_nanos() as u64;
    let now_nanos = threshold_nanos * 100;

    let mut streaming_handles: HashMap<StreamId, (streaming::StreamHandle, u64)> =
        HashMap::new();
    let mut expected_to_keep = Vec::with_capacity(5);
    let mut expected_to_drop = Vec::with_capacity(5);

    for i in 0..10u64 {
        let id = StreamId::next();
        let handle = registry.register(id, 64 * 1024);
        // First half: idle for half a window (fresh).
        // Second half: idle for two windows (stale).
        let is_fresh = i < 5;
        let last_activity_nanos = if is_fresh {
            now_nanos - threshold_nanos / 2
        } else {
            now_nanos - threshold_nanos * 2
        };
        streaming_handles.insert(id, (handle, last_activity_nanos));
        if is_fresh {
            expected_to_keep.push(id);
        } else {
            expected_to_drop.push(id);
        }
    }
    assert_eq!(streaming_handles.len(), 10);
    assert_eq!(registry.stream_count(), 10);

    let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 9005);
    sweep_streaming_handles_inner(
        &mut streaming_handles,
        &registry,
        now_nanos,
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        peer_addr,
    );

    assert_eq!(streaming_handles.len(), 5, "5 fresh entries should survive");
    assert_eq!(
        registry.stream_count(),
        5,
        "5 fresh registry rows should survive"
    );
    for id in &expected_to_keep {
        assert!(streaming_handles.contains_key(id), "kept entry missing");
    }
    for id in &expected_to_drop {
        assert!(
            !streaming_handles.contains_key(id),
            "dropped entry still present"
        );
    }
}
/// Source pin: the production `recv()` loop must still call the idle-handle
/// sweep.
///
/// `include_str!` embeds this very file, including this test module. If the
/// needle were written as one contiguous string literal, that literal would
/// itself satisfy `contains`, and the pin could never fail. Building the
/// needle with `concat!` keeps the searched text out of this test's source,
/// so only text elsewhere in the file can match. NOTE(review): that is
/// presumably the production call site; any other test that spells out the
/// call verbatim would also match.
#[test]
fn sweep_idle_streaming_handles_is_invoked_from_timeout_check() {
    const CALL_SITE: &str = concat!("self.sweep_idle_streaming_handles", "();");
    let source = include_str!("peer_connection.rs");
    assert!(
        source.contains(CALL_SITE),
        "PR #4083 / issue #4079: recv()'s timeout_check.tick() arm must \
         call self.sweep_idle_streaming_handles() — without it the \
         streaming_handles leak comes back. If you intentionally moved \
         the call, update this pin (and ensure the new site fires at \
         least as often as the 5 s timeout_check tick)."
    );
}
/// `streaming_handle_idle_timeout()` must return the production threshold
/// while the simulation flag is off and the simulation threshold while it is on.
#[test]
fn streaming_handle_idle_timeout_switches_to_sim_value_when_flag_enabled() {
    use crate::config::SimulationIdleTimeout;

    // `enable`/`disable` take no receiver, so the flag looks like shared
    // global state. Restore it to "disabled" on drop, so that a failing
    // assertion (which unwinds past any trailing `disable()` call) cannot
    // leave it enabled for tests that run later in the same process.
    struct DisableOnDrop;
    impl Drop for DisableOnDrop {
        fn drop(&mut self) {
            crate::config::SimulationIdleTimeout::disable();
        }
    }

    SimulationIdleTimeout::disable();
    let _restore = DisableOnDrop;

    assert_eq!(
        streaming_handle_idle_timeout(),
        STREAMING_HANDLE_IDLE_TIMEOUT_PROD,
        "production threshold required when flag is OFF"
    );

    SimulationIdleTimeout::enable();
    assert_eq!(
        streaming_handle_idle_timeout(),
        STREAMING_HANDLE_IDLE_TIMEOUT_SIM,
        "simulation threshold required when flag is ON"
    );
    // `_restore` disables the flag again when it goes out of scope.
}
/// `push_fragment` must distinguish new (`Ok(true)`), duplicate (`Ok(false)`)
/// and invalid (`Err`) fragments. The production refresh of
/// `last_activity_nanos` must key off the new-fragment case only.
///
/// The source pins below search this file via `include_str!`, which also
/// embeds this test. Each needle is therefore assembled with `concat!`, so
/// the test's own literals cannot satisfy `contains` and hide a regression
/// in the production code.
#[test]
fn last_activity_refresh_strictly_gated_on_ok_true() {
    use streaming::{StreamError, StreamRegistry};
    use streaming_buffer::FRAGMENT_PAYLOAD_SIZE;

    const REFRESH_GUARD: &str = concat!("if matches!(push_result, ", "Ok(true)) {");
    const RATIONALE_COMMENT: &str =
        concat!("Refresh `last_activity_nanos` ", "ONLY on real forward");

    let registry = StreamRegistry::new();
    let id = StreamId::next();
    // Two fragments' worth of capacity, so fragment 1 is valid and 99 is not.
    let handle = registry.register(id, (FRAGMENT_PAYLOAD_SIZE * 2) as u64);

    let r_new = handle.push_fragment(1, bytes::Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]));
    assert!(
        matches!(r_new, Ok(true)),
        "first push of fragment 1 must be Ok(true); got {:?}",
        r_new
    );

    let r_dup = handle.push_fragment(1, bytes::Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]));
    assert!(
        matches!(r_dup, Ok(false)),
        "duplicate push of fragment 1 must be Ok(false); got {:?}",
        r_dup
    );

    let r_invalid = handle.push_fragment(99, bytes::Bytes::from_static(b"oob"));
    assert!(
        matches!(r_invalid, Err(StreamError::InvalidFragment { .. })),
        "out-of-range push must be Err(InvalidFragment); got {:?}",
        r_invalid
    );

    let source = include_str!("peer_connection.rs");
    assert!(
        source.contains(REFRESH_GUARD),
        "PR #4083 / issue #4079: `last_activity_nanos` refresh in \
         process_inbound must be gated on Ok(true) ONLY. Loosening \
         this to Ok(_) or removing the guard re-opens the replay \
         vector that lets a peer keep a dead stream alive \
         indefinitely by replaying the same fragment every <30 s. \
         See the comment block above the guard in process_inbound."
    );
    assert!(
        source.contains(RATIONALE_COMMENT),
        "PR #4083 / issue #4079: the production code block explaining \
         why the refresh is gated on Ok(true) must remain in place. \
         If you intentionally rewrote it, update this pin to match \
         the new comment."
    );
}
}