use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::time::Duration;
use aes_gcm::Aes128Gcm;
use bytes::Bytes;
use tokio::sync::oneshot;
use crate::{
simulation::TimeSource,
tracing::TransferDirection,
transport::{
TransferStats, TransportError,
congestion_control::{CongestionControl, CongestionController},
metrics::{emit_transfer_completed, emit_transfer_failed, emit_transfer_started},
packet_data,
sent_packet_tracker::SentPacketTracker,
symmetric_message::{self},
},
};
use futures::StreamExt;
use super::StreamId;
use super::streaming::StreamHandle;
/// How long a sender may spin waiting for congestion-window space before
/// aborting the transfer.
///
/// Derived at compile time from the stream inactivity timeout minus a 10 s
/// margin, so the sending side gives up (and reports a failure) before the
/// receiving side would silently drop the stream as inactive.
const CWND_WAIT_TIMEOUT: Duration = {
    let margin_secs: u64 = 10;
    let inactivity_secs = super::streaming::STREAM_INACTIVITY_TIMEOUT.as_secs();
    // Compile-time guard: const evaluation panics during the build if the
    // inactivity timeout is ever reduced below the margin.
    assert!(
        inactivity_secs > margin_secs,
        "STREAM_INACTIVITY_TIMEOUT must be > 10s"
    );
    Duration::from_secs(inactivity_secs - margin_secs)
};
/// A fully serialized stream payload, ready to be fragmented and transmitted.
pub(crate) type SerializedStream = Bytes;
/// Maximum payload bytes carried by a single stream fragment. The 41 bytes
/// subtracted from the raw packet capacity cover per-fragment framing —
/// presumably the serialized `StreamFragment`/`SymmetricMessage` header
/// overhead; TODO confirm against `symmetric_message`'s wire format.
const MAX_DATA_SIZE: usize = packet_data::MAX_DATA_SIZE - 41;
/// Sends an already-serialized stream to `destination_addr`, splitting it into
/// encrypted `StreamFragment` packets that respect both the congestion
/// controller's window and the token-bucket rate limit.
///
/// Fragment #1 may carry the optional `metadata`; its payload capacity is
/// reduced by the metadata's serialized overhead (`1 + 8 + meta.len()` bytes)
/// so the packet still fits within `packet_data::MAX_DATA_SIZE`.
///
/// Returns per-transfer statistics on success. Fails with
/// `TransportError::ConnectionClosed` when no congestion-window space frees up
/// within `CWND_WAIT_TIMEOUT` (i.e. ACKs likely stopped arriving), or
/// propagates the underlying send error. `completion_tx`, when provided, is
/// signalled both on success and on a send failure (but NOT on the cwnd-wait
/// timeout path).
#[allow(clippy::too_many_arguments)]
pub(super) async fn send_stream<S: super::super::Socket, T: TimeSource>(
    stream_id: StreamId,
    // Shared counter producing connection-unique packet ids.
    last_packet_id: Arc<AtomicU32>,
    socket: Arc<S>,
    destination_addr: SocketAddr,
    mut stream_to_send: SerializedStream,
    outbound_symmetric_key: Aes128Gcm,
    sent_packet_tracker: Arc<parking_lot::Mutex<SentPacketTracker<T>>>,
    token_bucket: Arc<super::super::token_bucket::TokenBucket<T>>,
    congestion_controller: Arc<CongestionController<T>>,
    time_source: T,
    metadata: Option<Bytes>,
    completion_tx: Option<oneshot::Sender<()>>,
) -> Result<TransferStats, TransportError> {
    let start_time = time_source.now();
    let bytes_to_send = stream_to_send.len() as u64;
    emit_transfer_started(
        stream_id.0 as u64,
        destination_addr,
        bytes_to_send,
        TransferDirection::Send,
    );
    tracing::debug!(
        stream_id = %stream_id.0,
        length_bytes = stream_to_send.len(),
        initial_rate_bytes_per_sec = token_bucket.rate(),
        cwnd = congestion_controller.current_cwnd(),
        "Sending stream"
    );
    let total_length_bytes = stream_to_send.len() as u32;
    // Precompute how many fragments this stream needs. When metadata rides on
    // fragment #1, that fragment has reduced payload capacity, so the
    // remaining bytes are spread over full-size fragments.
    let total_packets = if let Some(ref meta) = metadata {
        let meta_overhead = 1 + 8 + meta.len();
        let first_frag_capacity = MAX_DATA_SIZE.saturating_sub(meta_overhead);
        if stream_to_send.len() <= first_frag_capacity {
            1
        } else {
            let remaining = stream_to_send.len() - first_frag_capacity;
            1 + remaining.div_ceil(MAX_DATA_SIZE)
        }
    } else {
        stream_to_send.len().div_ceil(MAX_DATA_SIZE)
    };
    let mut sent_so_far = 0;
    // Fragment numbering is 1-based; metadata is consumed by fragment #1 only.
    let mut next_fragment_number = 1;
    let mut pending_metadata = metadata;
    loop {
        if sent_so_far == total_packets {
            break;
        }
        // NOTE(review): this is the *maximum* fragment size. When fragment #1
        // carries metadata the actual payload sliced below is smaller, so the
        // cwnd and token-bucket reservations slightly over-reserve for that
        // one packet — confirm this is intentional (conservative) accounting.
        let packet_size = stream_to_send.len().min(MAX_DATA_SIZE);
        let cwnd_wait_start = time_source.now();
        let mut cwnd_wait_iterations = 0;
        // Spin (with escalating backoff) until the congestion window has room
        // for this packet, or bail out after CWND_WAIT_TIMEOUT.
        loop {
            let flightsize = congestion_controller.flightsize();
            let cwnd = congestion_controller.current_cwnd();
            if flightsize + packet_size <= cwnd {
                break;
            }
            cwnd_wait_iterations += 1;
            // Log only on the first blocked iteration to avoid trace spam.
            if cwnd_wait_iterations == 1 {
                tracing::trace!(
                    stream_id = %stream_id.0,
                    flightsize_kb = flightsize / 1024,
                    cwnd_kb = cwnd / 1024,
                    packet_size,
                    "Waiting for cwnd space (ensure recv() is being called to process ACKs)"
                );
            }
            let cwnd_elapsed = time_source.now().saturating_sub(cwnd_wait_start);
            if cwnd_elapsed >= CWND_WAIT_TIMEOUT {
                let elapsed = time_source.now().saturating_sub(start_time);
                tracing::warn!(
                    stream_id = %stream_id.0,
                    destination = %destination_addr,
                    sent_so_far,
                    total_packets,
                    flightsize_kb = flightsize / 1024,
                    cwnd_kb = cwnd / 1024,
                    cwnd_wait_ms = cwnd_elapsed.as_millis(),
                    elapsed_ms = elapsed.as_millis(),
                    "send_stream cwnd wait timed out — ACKs likely stopped arriving"
                );
                emit_transfer_failed(
                    stream_id.0 as u64,
                    destination_addr,
                    sent_so_far as u64,
                    format!(
                        "cwnd wait timeout after {}s (sent {sent_so_far}/{total_packets} packets, \
                        flightsize={flightsize}B, cwnd={cwnd}B)",
                        cwnd_elapsed.as_secs()
                    ),
                    elapsed.as_millis() as u64,
                    TransferDirection::Send,
                );
                return Err(TransportError::ConnectionClosed(destination_addr));
            }
            // Escalating backoff: busy-yield first (fast ACK turnaround),
            // then short sleeps, then 1 ms sleeps for long waits.
            if cwnd_wait_iterations <= 10 {
                tokio::task::yield_now().await;
            } else if cwnd_wait_iterations <= 100 {
                time_source.sleep(Duration::from_micros(100)).await;
            } else {
                time_source.sleep(Duration::from_millis(1)).await;
            }
        }
        if cwnd_wait_iterations > 0 {
            tracing::trace!(
                stream_id = %stream_id.0,
                wait_iterations = cwnd_wait_iterations,
                "Acquired cwnd space"
            );
        }
        // Token-bucket pacing, applied after cwnd admission.
        let wait_time = token_bucket.reserve(packet_size);
        if !wait_time.is_zero() {
            tracing::trace!(
                stream_id = %stream_id.0,
                wait_time_ms = wait_time.as_millis(),
                packet_size,
                "Rate limiting stream transmission"
            );
            time_source.sleep(wait_time).await;
        }
        // Metadata goes out exactly once, on fragment #1.
        let metadata_bytes = if next_fragment_number == 1 {
            pending_metadata.take()
        } else {
            None
        };
        let available_payload = if let Some(ref meta) = metadata_bytes {
            let meta_overhead = 1 + 8 + meta.len();
            MAX_DATA_SIZE.saturating_sub(meta_overhead)
        } else {
            MAX_DATA_SIZE
        };
        // Slice off the next fragment; the final fragment takes the remainder
        // (mem::take leaves an empty Bytes so the outer loop terminates).
        let fragment = {
            if stream_to_send.len() > available_payload {
                let fragment = stream_to_send.slice(..available_payload);
                stream_to_send = stream_to_send.slice(available_payload..);
                fragment
            } else {
                std::mem::take(&mut stream_to_send)
            }
        };
        let packet_id = last_packet_id.fetch_add(1, std::sync::atomic::Ordering::Release);
        // Registers the bytes as in-flight; the token ties this send to the
        // congestion controller's accounting.
        let token = congestion_controller.on_send_with_token(packet_size);
        if let Err(e) = super::packet_sending(
            destination_addr,
            &socket,
            packet_id,
            &outbound_symmetric_key,
            vec![],
            symmetric_message::StreamFragment {
                stream_id,
                total_length_bytes: total_length_bytes as u64,
                fragment_number: next_fragment_number,
                payload: fragment,
                metadata_bytes,
            },
            sent_packet_tracker.as_ref(),
            token,
        )
        .await
        {
            // Approximate bytes-sent from full-size fragments; clamped so it
            // never exceeds the actual stream length.
            let bytes_sent = (sent_so_far * MAX_DATA_SIZE) as u64;
            let elapsed = time_source.now().saturating_sub(start_time);
            emit_transfer_failed(
                stream_id.0 as u64,
                destination_addr,
                bytes_sent.min(bytes_to_send),
                e.to_string(),
                elapsed.as_millis() as u64,
                TransferDirection::Send,
            );
            // Signal completion even on failure; receiver may have been
            // dropped, so the send result is intentionally ignored.
            if let Some(tx) = completion_tx {
                let _ignored = tx.send(());
            }
            return Err(e);
        }
        next_fragment_number += 1;
        sent_so_far += 1;
    }
    let generic_stats = congestion_controller.stats();
    let ledbat_stats = congestion_controller.ledbat_stats();
    let bbr_stats = congestion_controller.bbr_stats();
    let elapsed = time_source.now().saturating_sub(start_time);
    tracing::debug!(
        stream_id = %stream_id.0,
        total_packets = %sent_so_far,
        bytes = bytes_to_send,
        elapsed_ms = elapsed.as_millis(),
        peak_cwnd_kb = generic_stats.peak_cwnd / 1024,
        final_cwnd_kb = generic_stats.cwnd / 1024,
        slowdowns = ledbat_stats.as_ref().map(|s| s.periodic_slowdowns).unwrap_or(0),
        bbr_state = ?bbr_stats.as_ref().map(|s| s.state),
        "Stream sent"
    );
    emit_transfer_completed(
        stream_id.0 as u64,
        destination_addr,
        bytes_to_send,
        elapsed.as_millis() as u64,
        // Integer throughput in bytes/sec; sub-second transfers fall back to
        // a millisecond-based computation (max(1) avoids division by zero).
        if elapsed.as_secs() > 0 {
            bytes_to_send / elapsed.as_secs()
        } else {
            bytes_to_send * 1000 / elapsed.as_millis().max(1) as u64
        },
        Some(generic_stats.peak_cwnd as u32),
        Some(generic_stats.cwnd as u32),
        ledbat_stats.as_ref().map(|s| s.periodic_slowdowns as u32),
        Some(generic_stats.base_delay.as_millis() as u32),
        Some(generic_stats.ssthresh as u32),
        ledbat_stats.as_ref().map(|s| s.min_ssthresh_floor as u32),
        Some(generic_stats.total_timeouts as u32),
        TransferDirection::Send,
    );
    if let Some(tx) = completion_tx {
        let _ignored = tx.send(());
    }
    Ok(TransferStats {
        stream_id: stream_id.0 as u64,
        remote_addr: destination_addr,
        bytes_transferred: bytes_to_send,
        elapsed,
        peak_cwnd_bytes: generic_stats.peak_cwnd as u32,
        final_cwnd_bytes: generic_stats.cwnd as u32,
        slowdowns_triggered: ledbat_stats
            .as_ref()
            .map(|s| s.periodic_slowdowns as u32)
            .unwrap_or(0),
        base_delay: generic_stats.base_delay,
        final_ssthresh_bytes: generic_stats.ssthresh as u32,
        min_ssthresh_floor_bytes: ledbat_stats
            .as_ref()
            .map(|s| s.min_ssthresh_floor as u32)
            .unwrap_or(0),
        total_timeouts: generic_stats.total_timeouts as u32,
        final_flightsize: generic_stats.flightsize as u32,
        configured_rate: congestion_controller.configured_rate() as u32,
    })
}
/// Forwards an inbound stream to the next hop as it arrives, re-encrypting
/// each received fragment under `outbound_symmetric_key` and sending it to
/// `destination_addr` with fresh fragment numbers under
/// `outbound_stream_id`.
///
/// Unlike `send_stream`, fragments are consumed from `inbound_handle` as they
/// arrive; the loop aborts with `TransportError::ConnectionClosed` if no
/// fragment arrives within `STREAM_INACTIVITY_TIMEOUT`, if the inbound stream
/// yields an error, or if congestion-window space does not free up within
/// `CWND_WAIT_TIMEOUT`.
///
/// Metadata, when present, is attached to outbound fragment #1 only — and
/// only if the piped payload plus metadata still fits in a packet; otherwise
/// it is dropped with a debug log rather than re-fragmented.
#[allow(clippy::too_many_arguments)]
pub(super) async fn pipe_stream<S: super::super::Socket, T: TimeSource>(
    inbound_handle: StreamHandle,
    outbound_stream_id: StreamId,
    // Shared counter producing connection-unique packet ids.
    last_packet_id: Arc<AtomicU32>,
    socket: Arc<S>,
    destination_addr: SocketAddr,
    outbound_symmetric_key: Aes128Gcm,
    sent_packet_tracker: Arc<parking_lot::Mutex<SentPacketTracker<T>>>,
    token_bucket: Arc<super::super::token_bucket::TokenBucket<T>>,
    congestion_controller: Arc<CongestionController<T>>,
    time_source: T,
    metadata: Option<Bytes>,
) -> Result<TransferStats, TransportError> {
    let start_time = time_source.now();
    let total_bytes = inbound_handle.total_bytes();
    emit_transfer_started(
        outbound_stream_id.0 as u64,
        destination_addr,
        total_bytes,
        TransferDirection::Send,
    );
    tracing::debug!(
        stream_id = %outbound_stream_id.0,
        total_bytes,
        "Piping stream to next hop"
    );
    let mut stream = inbound_handle.stream();
    let mut sent_so_far = 0u64;
    let mut fragment_number = 1u32;
    let mut pending_metadata = metadata;
    use super::streaming::STREAM_INACTIVITY_TIMEOUT;
    let inactivity_timeout = STREAM_INACTIVITY_TIMEOUT;
    loop {
        // Race the next inbound fragment against the inactivity timer; the
        // sleep future is recreated on each iteration, so the timeout applies
        // per-fragment, not to the whole transfer.
        let next_fragment = tokio::select! {
            result = stream.next() => {
                match result {
                    Some(r) => r,
                    // Inbound stream exhausted: the pipe is complete.
                    None => break,
                }
            }
            _ = time_source.sleep(inactivity_timeout) => {
                let elapsed = time_source.now().saturating_sub(start_time);
                tracing::warn!(
                    stream_id = %outbound_stream_id.0,
                    destination = %destination_addr,
                    sent_so_far,
                    total_bytes,
                    fragment_number,
                    elapsed_ms = elapsed.as_millis(),
                    "pipe_stream stalled: no fragment received within {}s",
                    inactivity_timeout.as_secs()
                );
                emit_transfer_failed(
                    outbound_stream_id.0 as u64,
                    destination_addr,
                    sent_so_far,
                    format!(
                        "pipe stalled: no fragment for {}s (sent {sent_so_far}/{total_bytes} bytes)",
                        inactivity_timeout.as_secs()
                    ),
                    elapsed.as_millis() as u64,
                    TransferDirection::Send,
                );
                return Err(TransportError::ConnectionClosed(destination_addr));
            }
        };
        let payload = match next_fragment {
            Ok(data) => data,
            Err(e) => {
                let elapsed = time_source.now().saturating_sub(start_time);
                emit_transfer_failed(
                    outbound_stream_id.0 as u64,
                    destination_addr,
                    sent_so_far,
                    format!("inbound stream error: {e}"),
                    elapsed.as_millis() as u64,
                    TransferDirection::Send,
                );
                return Err(TransportError::ConnectionClosed(destination_addr));
            }
        };
        let packet_size = payload.len();
        let cwnd_wait_start = time_source.now();
        let mut cwnd_wait_iterations = 0;
        // Same cwnd admission loop as send_stream: spin with escalating
        // backoff until the window has room, or abort after CWND_WAIT_TIMEOUT.
        loop {
            let flightsize = congestion_controller.flightsize();
            let cwnd = congestion_controller.current_cwnd();
            if flightsize + packet_size <= cwnd {
                break;
            }
            cwnd_wait_iterations += 1;
            if cwnd_wait_iterations == 1 {
                tracing::trace!(
                    stream_id = %outbound_stream_id.0,
                    fragment_number,
                    flightsize_kb = flightsize / 1024,
                    cwnd_kb = cwnd / 1024,
                    "Waiting for cwnd space in pipe_stream"
                );
            }
            let cwnd_elapsed = time_source.now().saturating_sub(cwnd_wait_start);
            if cwnd_elapsed >= CWND_WAIT_TIMEOUT {
                let elapsed = time_source.now().saturating_sub(start_time);
                tracing::warn!(
                    stream_id = %outbound_stream_id.0,
                    destination = %destination_addr,
                    fragment_number,
                    sent_so_far,
                    total_bytes,
                    flightsize_kb = flightsize / 1024,
                    cwnd_kb = cwnd / 1024,
                    cwnd_wait_ms = cwnd_elapsed.as_millis(),
                    elapsed_ms = elapsed.as_millis(),
                    "pipe_stream cwnd wait timed out — ACKs likely stopped arriving"
                );
                emit_transfer_failed(
                    outbound_stream_id.0 as u64,
                    destination_addr,
                    sent_so_far,
                    format!(
                        "cwnd wait timeout after {}s (sent {sent_so_far}/{total_bytes} bytes, \
                        flightsize={flightsize}B, cwnd={cwnd}B)",
                        cwnd_elapsed.as_secs()
                    ),
                    elapsed.as_millis() as u64,
                    TransferDirection::Send,
                );
                return Err(TransportError::ConnectionClosed(destination_addr));
            }
            if cwnd_wait_iterations <= 10 {
                tokio::task::yield_now().await;
            } else if cwnd_wait_iterations <= 100 {
                time_source.sleep(Duration::from_micros(100)).await;
            } else {
                time_source.sleep(Duration::from_millis(1)).await;
            }
        }
        // Token-bucket pacing after cwnd admission.
        let wait_time = token_bucket.reserve(packet_size);
        if !wait_time.is_zero() {
            time_source.sleep(wait_time).await;
        }
        // Attach metadata to fragment #1 only if it still fits alongside the
        // piped payload; piped fragments are already sized by the upstream
        // sender, so an oversized combination drops the metadata instead.
        let metadata_bytes = if fragment_number == 1 {
            if let Some(meta) = pending_metadata.take() {
                let required_size = payload.len() + 41 + meta.len();
                if required_size <= packet_data::MAX_DATA_SIZE {
                    Some(meta)
                } else {
                    tracing::debug!(
                        stream_id = %outbound_stream_id.0,
                        payload_len = payload.len(),
                        meta_len = meta.len(),
                        required_size,
                        max_size = packet_data::MAX_DATA_SIZE,
                        "Skipping metadata embedding in piped fragment #1 - would exceed MAX_DATA_SIZE"
                    );
                    None
                }
            } else {
                None
            }
        } else {
            None
        };
        let packet_id = last_packet_id.fetch_add(1, std::sync::atomic::Ordering::Release);
        let token = congestion_controller.on_send_with_token(packet_size);
        if let Err(e) = super::packet_sending(
            destination_addr,
            &socket,
            packet_id,
            &outbound_symmetric_key,
            vec![],
            symmetric_message::StreamFragment {
                stream_id: outbound_stream_id,
                total_length_bytes: total_bytes,
                fragment_number,
                payload,
                metadata_bytes,
            },
            sent_packet_tracker.as_ref(),
            token,
        )
        .await
        {
            let elapsed = time_source.now().saturating_sub(start_time);
            emit_transfer_failed(
                outbound_stream_id.0 as u64,
                destination_addr,
                sent_so_far,
                e.to_string(),
                elapsed.as_millis() as u64,
                TransferDirection::Send,
            );
            return Err(e);
        }
        sent_so_far += packet_size as u64;
        fragment_number += 1;
    }
    let generic_stats = congestion_controller.stats();
    let ledbat_stats = congestion_controller.ledbat_stats();
    let elapsed = time_source.now().saturating_sub(start_time);
    tracing::debug!(
        stream_id = %outbound_stream_id.0,
        fragments = fragment_number - 1,
        bytes = sent_so_far,
        elapsed_ms = elapsed.as_millis(),
        "Pipe stream complete"
    );
    emit_transfer_completed(
        outbound_stream_id.0 as u64,
        destination_addr,
        sent_so_far,
        elapsed.as_millis() as u64,
        // Integer throughput in bytes/sec; sub-second transfers fall back to
        // a millisecond-based computation (max(1) avoids division by zero).
        if elapsed.as_secs() > 0 {
            sent_so_far / elapsed.as_secs()
        } else {
            sent_so_far * 1000 / elapsed.as_millis().max(1) as u64
        },
        Some(generic_stats.peak_cwnd as u32),
        Some(generic_stats.cwnd as u32),
        ledbat_stats.as_ref().map(|s| s.periodic_slowdowns as u32),
        Some(generic_stats.base_delay.as_millis() as u32),
        Some(generic_stats.ssthresh as u32),
        ledbat_stats.as_ref().map(|s| s.min_ssthresh_floor as u32),
        Some(generic_stats.total_timeouts as u32),
        TransferDirection::Send,
    );
    Ok(TransferStats {
        stream_id: outbound_stream_id.0 as u64,
        remote_addr: destination_addr,
        bytes_transferred: sent_so_far,
        elapsed,
        peak_cwnd_bytes: generic_stats.peak_cwnd as u32,
        final_cwnd_bytes: generic_stats.cwnd as u32,
        slowdowns_triggered: ledbat_stats
            .as_ref()
            .map(|s| s.periodic_slowdowns as u32)
            .unwrap_or(0),
        base_delay: generic_stats.base_delay,
        final_ssthresh_bytes: generic_stats.ssthresh as u32,
        min_ssthresh_floor_bytes: ledbat_stats
            .as_ref()
            .map(|s| s.min_ssthresh_floor as u32)
            .unwrap_or(0),
        total_timeouts: generic_stats.total_timeouts as u32,
        final_flightsize: generic_stats.flightsize as u32,
        configured_rate: congestion_controller.configured_rate() as u32,
    })
}
#[cfg(test)]
mod tests {
    //! Tests for `send_stream` / `pipe_stream` using an in-memory socket that
    //! forwards every outbound packet through a channel, so tests can decrypt
    //! and inspect exactly what would have gone on the wire.
    use aes_gcm::KeyInit;
    use std::net::Ipv4Addr;
    use tests::packet_data::MAX_PACKET_SIZE;
    use tracing::debug;
    use super::{
        symmetric_message::{SymmetricMessage, SymmetricMessagePayload},
        *,
    };
    use crate::config::GlobalExecutor;
    use crate::simulation::{RealTime, VirtualTime};
    use crate::transport::congestion_control::CongestionControlConfig;
    use crate::transport::fast_channel::{self, FastSender};
    use crate::transport::ledbat::LedbatConfig;
    use crate::transport::packet_data::PacketData;
    use crate::transport::token_bucket::TokenBucket;
    /// Test double for the transport socket: instead of touching the network,
    /// every `send_to` forwards `(target, bytes)` into a channel.
    struct TestSocket {
        sender: fast_channel::FastSender<(SocketAddr, Arc<[u8]>)>,
    }
    impl TestSocket {
        fn new(sender: fast_channel::FastSender<(SocketAddr, Arc<[u8]>)>) -> Self {
            Self { sender }
        }
    }
    impl crate::transport::Socket for TestSocket {
        // Only the send paths are exercised by these tests; binding and
        // receiving are intentionally unimplemented.
        async fn bind(_addr: SocketAddr) -> std::io::Result<Self> {
            unimplemented!()
        }
        async fn recv_from(&self, _buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
            unimplemented!()
        }
        async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
            self.sender
                .send_async((target, buf.into()))
                .await
                .map_err(|_| std::io::ErrorKind::ConnectionAborted)?;
            Ok(buf.len())
        }
        fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> std::io::Result<usize> {
            self.sender
                .send((target, buf.into()))
                .map_err(|_| std::io::ErrorKind::ConnectionAborted)?;
            Ok(buf.len())
        }
    }
    /// Happy path: a 100 kB random stream is fragmented, encrypted, sent, then
    /// decrypted and reassembled from the captured packets byte-for-byte.
    #[tokio::test]
    async fn test_send_stream_success() -> Result<(), Box<dyn std::error::Error>> {
        let (outbound_sender, outbound_receiver) = fast_channel::bounded(1);
        let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
        let mut message = vec![0u8; 100_000];
        crate::config::GlobalRng::fill_bytes(&mut message);
        let cipher = {
            let mut key = [0u8; 16];
            crate::config::GlobalRng::fill_bytes(&mut key);
            Aes128Gcm::new(&key.into())
        };
        let time_source = VirtualTime::new();
        let sent_tracker = Arc::new(parking_lot::Mutex::new(
            SentPacketTracker::new_with_time_source(time_source.clone()),
        ));
        // Huge fixed cwnd so congestion control never blocks this test.
        let congestion_controller = CongestionControlConfig::from_ledbat_config(LedbatConfig {
            initial_cwnd: 1_000_000,
            min_cwnd: 1_000_000,
            max_cwnd: 1_000_000_000,
            ..Default::default()
        })
        .build_arc_with_time_source(time_source.clone());
        let token_bucket = Arc::new(TokenBucket::new_with_time_source(
            1_000_000,
            10_000_000,
            time_source.clone(),
        ));
        let background_task = GlobalExecutor::spawn(send_stream(
            StreamId::next(),
            Arc::new(AtomicU32::new(0)),
            Arc::new(TestSocket::new(outbound_sender)),
            remote_addr,
            Bytes::from(message.clone()),
            cipher.clone(),
            sent_tracker,
            token_bucket,
            congestion_controller,
            time_source,
            None,
            None,
        ));
        // Drain the channel until the sender completes and drops its end,
        // decrypting and reassembling the stream as we go.
        let mut inbound_bytes = Vec::with_capacity(message.len());
        while let Ok((_, packet)) = outbound_receiver.recv_async().await {
            let decrypted_packet = PacketData::<_, MAX_PACKET_SIZE>::from_buf(packet.as_ref())
                .try_decrypt_sym(&cipher)
                .map_err(|e| e.to_string())?;
            let deserialized = SymmetricMessage::deser(decrypted_packet.data())?;
            let SymmetricMessagePayload::StreamFragment { payload, .. } = deserialized.payload
            else {
                panic!("Expected a StreamFragment, got {:?}", deserialized.payload);
            };
            inbound_bytes.extend_from_slice(payload.as_ref());
        }
        let result = background_task.await?;
        assert!(result.is_ok());
        // Spot-check head and tail plus total length rather than the full 100 kB.
        assert_eq!(&message[..10], &inbound_bytes[..10]);
        assert_eq!(inbound_bytes.len(), 100_000);
        assert_eq!(&message[99_990..], &inbound_bytes[99_990..]);
        Ok(())
    }
    /// With a 100 kB/s token bucket, a 10 kB transfer must take a measurable
    /// amount of wall-clock time (uses RealTime, not VirtualTime).
    #[tokio::test]
    async fn test_send_stream_with_bandwidth_limit() -> Result<(), Box<dyn std::error::Error>> {
        let (outbound_sender, outbound_receiver) = fast_channel::bounded(100);
        let destination_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 1234));
        let key = Aes128Gcm::new_from_slice(&[0u8; 16])?;
        let last_packet_id = Arc::new(AtomicU32::new(0));
        let stream_id = StreamId::next();
        let message = vec![0u8; 10_000];
        let bandwidth_limit = 100_000;
        let time_source = RealTime::new();
        let sent_tracker = Arc::new(parking_lot::Mutex::new(
            SentPacketTracker::new_with_time_source(time_source.clone()),
        ));
        let congestion_controller = CongestionControlConfig::from_ledbat_config(LedbatConfig {
            initial_cwnd: 1_000_000,
            min_cwnd: 1_000_000,
            max_cwnd: 1_000_000_000,
            ..Default::default()
        })
        .build_arc();
        // Small initial burst (1 kB) so the rate limit actually bites.
        let token_bucket = Arc::new(TokenBucket::new_with_time_source(
            1_000, bandwidth_limit,
            time_source.clone(),
        ));
        let start_time = tokio::time::Instant::now();
        // Keep a sender clone alive so the receiver loop only terminates once
        // we explicitly drop it after the transfer.
        let sender_clone: FastSender<(SocketAddr, Arc<[u8]>)> = outbound_sender.clone();
        let key_clone = key.clone();
        let receiver_task = GlobalExecutor::spawn(async move {
            let mut packet_count = 0;
            let mut total_bytes = 0;
            while let Ok((addr, packet)) = outbound_receiver.recv_async().await {
                assert_eq!(addr, destination_addr);
                packet_count += 1;
                total_bytes += packet.len();
                let packet_data = PacketData::<_, MAX_PACKET_SIZE>::from_buf(&packet);
                let decrypted = packet_data.try_decrypt_sym(&key_clone).unwrap();
                let msg = SymmetricMessage::deser(decrypted.data()).unwrap();
                // Every captured packet must be a stream fragment.
                match msg.payload {
                    SymmetricMessagePayload::StreamFragment { .. } => {
                    }
                    SymmetricMessagePayload::AckConnection { .. }
                    | SymmetricMessagePayload::ShortMessage { .. }
                    | SymmetricMessagePayload::NoOp
                    | SymmetricMessagePayload::Ping { .. }
                    | SymmetricMessagePayload::Pong { .. } => panic!("Expected stream fragment"),
                }
            }
            (packet_count, total_bytes)
        });
        let send_task = GlobalExecutor::spawn(send_stream(
            stream_id,
            last_packet_id.clone(),
            Arc::new(TestSocket::new(outbound_sender)),
            destination_addr,
            Bytes::from(message.clone()),
            key.clone(),
            sent_tracker.clone(),
            token_bucket,
            congestion_controller,
            time_source,
            None,
            None,
        ));
        send_task.await??;
        let elapsed = start_time.elapsed();
        drop(sender_clone);
        let (packet_count, _total_bytes) = receiver_task.await?;
        let expected_packets = message.len().div_ceil(MAX_DATA_SIZE);
        assert_eq!(packet_count, expected_packets);
        debug!(
            "Transfer took: {elapsed:?}, packets sent: {packet_count}, expected: {expected_packets}"
        );
        debug!("Bytes per packet: ~{MAX_DATA_SIZE}");
        assert!(
            elapsed.as_millis() >= 50,
            "Transfer completed too quickly: {elapsed:?}"
        );
        Ok(())
    }
    /// With a near-unlimited token bucket the same 10 kB transfer should
    /// complete essentially instantly.
    #[tokio::test]
    async fn test_send_stream_without_bandwidth_limit() -> Result<(), Box<dyn std::error::Error>> {
        let (outbound_sender, outbound_receiver) = fast_channel::bounded(100);
        let destination_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 1234));
        let key = Aes128Gcm::new_from_slice(&[0u8; 16])?;
        let last_packet_id = Arc::new(AtomicU32::new(0));
        let stream_id = StreamId::next();
        let message = vec![0u8; 10_000];
        let time_source = RealTime::new();
        let sent_tracker = Arc::new(parking_lot::Mutex::new(
            SentPacketTracker::new_with_time_source(time_source.clone()),
        ));
        let congestion_controller = CongestionControlConfig::from_ledbat_config(LedbatConfig {
            initial_cwnd: 1_000_000,
            min_cwnd: 1_000_000,
            max_cwnd: 1_000_000_000,
            ..Default::default()
        })
        .build_arc();
        let token_bucket = Arc::new(TokenBucket::new_with_time_source(
            100_000, 1_000_000_000, time_source.clone(),
        ));
        let start_time = tokio::time::Instant::now();
        let sender_clone: FastSender<(SocketAddr, Arc<[u8]>)> = outbound_sender.clone();
        let receiver_task = GlobalExecutor::spawn(async move {
            let mut packet_count = 0;
            while let Ok((addr, _packet)) = outbound_receiver.recv_async().await {
                assert_eq!(addr, destination_addr);
                packet_count += 1;
            }
            packet_count
        });
        let send_task = GlobalExecutor::spawn(send_stream(
            stream_id,
            last_packet_id.clone(),
            Arc::new(TestSocket::new(outbound_sender)),
            destination_addr,
            Bytes::from(message.clone()),
            key.clone(),
            sent_tracker.clone(),
            token_bucket,
            congestion_controller,
            time_source,
            None,
            None,
        ));
        send_task.await??;
        let elapsed = start_time.elapsed();
        drop(sender_clone);
        let packet_count = receiver_task.await?;
        let expected_packets = message.len().div_ceil(MAX_DATA_SIZE);
        assert_eq!(packet_count, expected_packets);
        assert!(
            elapsed.as_millis() < 50,
            "Transfer took too long without rate limit: {elapsed:?}"
        );
        Ok(())
    }
    /// A 1-byte cwnd can never admit a packet, so send_stream must time out
    /// waiting for window space and fail with ConnectionClosed.
    /// `start_paused = true` lets tokio auto-advance time through the sleeps.
    #[tokio::test(start_paused = true)]
    async fn test_send_stream_cwnd_wait_timeout() -> Result<(), Box<dyn std::error::Error>> {
        let (outbound_sender, _outbound_receiver) = fast_channel::bounded(100);
        let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
        let message = vec![0u8; 10_000];
        let cipher = {
            let mut key = [0u8; 16];
            crate::config::GlobalRng::fill_bytes(&mut key);
            Aes128Gcm::new(&key.into())
        };
        let time_source = RealTime::new();
        let congestion_controller = CongestionControlConfig::from_ledbat_config(LedbatConfig {
            initial_cwnd: 1,
            min_cwnd: 1,
            max_cwnd: 1,
            ..Default::default()
        })
        .build_arc();
        let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new()));
        let token_bucket = Arc::new(TokenBucket::new(1_000_000, 100_000_000));
        let send_task = GlobalExecutor::spawn(send_stream(
            StreamId::next(),
            Arc::new(AtomicU32::new(0)),
            Arc::new(TestSocket::new(outbound_sender)),
            remote_addr,
            Bytes::from(message),
            cipher,
            sent_tracker,
            token_bucket,
            congestion_controller,
            time_source,
            None,
            None,
        ));
        let result = send_task.await.expect("join error");
        assert!(
            matches!(result, Err(TransportError::ConnectionClosed(_))),
            "Expected ConnectionClosed after cwnd wait timeout, got: {result:?}",
        );
        Ok(())
    }
    /// Verifies the MAX_DATA_SIZE budget: a fragment whose payload is sized by
    /// the `1 + 8 + meta.len()` metadata-overhead formula must serialize to at
    /// most `packet_data::MAX_DATA_SIZE`, with and without metadata.
    #[test]
    fn test_fragment_1_with_metadata_respects_max_size() {
        use crate::transport::symmetric_message::{StreamFragment, SymmetricMessage};
        // Case 1: typical (200-byte) metadata on fragment #1.
        let typical_metadata = bytes::Bytes::from(vec![0u8; 200]);
        let meta_overhead = 1 + 8 + typical_metadata.len();
        let available_payload = MAX_DATA_SIZE.saturating_sub(meta_overhead);
        let fragment_with_typical_meta = StreamFragment {
            stream_id: StreamId::next_operations(),
            total_length_bytes: 10000,
            fragment_number: 1,
            payload: bytes::Bytes::from(vec![0u8; available_payload]),
            metadata_bytes: Some(typical_metadata),
        };
        let msg = SymmetricMessage {
            packet_id: 1,
            confirm_receipt: vec![],
            payload: SymmetricMessagePayload::from(fragment_with_typical_meta),
        };
        let serialized = bincode::serialize(&msg).expect("serialization should succeed");
        assert!(
            serialized.len() <= packet_data::MAX_DATA_SIZE,
            "Fragment #1 with typical metadata ({} bytes) exceeds MAX_DATA_SIZE: {} > {}",
            200,
            serialized.len(),
            packet_data::MAX_DATA_SIZE
        );
        // Case 2: larger (500-byte) metadata.
        let large_metadata = bytes::Bytes::from(vec![0u8; 500]);
        let large_meta_overhead = 1 + 8 + large_metadata.len();
        let available_payload_large = MAX_DATA_SIZE.saturating_sub(large_meta_overhead);
        let fragment_with_large_meta = StreamFragment {
            stream_id: StreamId::next_operations(),
            total_length_bytes: 10000,
            fragment_number: 1,
            payload: bytes::Bytes::from(vec![0u8; available_payload_large]),
            metadata_bytes: Some(large_metadata),
        };
        let msg_large = SymmetricMessage {
            packet_id: 2,
            confirm_receipt: vec![],
            payload: SymmetricMessagePayload::from(fragment_with_large_meta),
        };
        let serialized_large =
            bincode::serialize(&msg_large).expect("serialization should succeed");
        assert!(
            serialized_large.len() <= packet_data::MAX_DATA_SIZE,
            "Fragment #1 with large metadata ({} bytes) exceeds MAX_DATA_SIZE: {} > {}",
            500,
            serialized_large.len(),
            packet_data::MAX_DATA_SIZE
        );
        // Case 3: later fragment with a full MAX_DATA_SIZE payload and no
        // metadata — the fixed 41-byte margin must still be sufficient.
        let fragment_2 = StreamFragment {
            stream_id: StreamId::next_operations(),
            total_length_bytes: 10000,
            fragment_number: 2,
            payload: bytes::Bytes::from(vec![0u8; MAX_DATA_SIZE]),
            metadata_bytes: None,
        };
        let msg_frag2 = SymmetricMessage {
            packet_id: 3,
            confirm_receipt: vec![],
            payload: SymmetricMessagePayload::from(fragment_2),
        };
        let serialized_frag2 =
            bincode::serialize(&msg_frag2).expect("serialization should succeed");
        assert!(
            serialized_frag2.len() <= packet_data::MAX_DATA_SIZE,
            "Fragment #2 (no metadata, full payload) exceeds MAX_DATA_SIZE: {} > {}",
            serialized_frag2.len(),
            packet_data::MAX_DATA_SIZE
        );
    }
    /// Forces pipe_stream into the cwnd-wait timeout path: a fragment is
    /// queued, but the controller is put in a state (large in-flight send
    /// followed by a timeout) where the window never admits it.
    #[tokio::test(start_paused = true)]
    async fn test_pipe_stream_cwnd_wait_timeout() -> Result<(), Box<dyn std::error::Error>> {
        use crate::transport::peer_connection::streaming::StreamHandle;
        let (outbound_sender, _outbound_receiver) = fast_channel::bounded(100);
        let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
        let cipher = {
            let mut key = [0u8; 16];
            crate::config::GlobalRng::fill_bytes(&mut key);
            Aes128Gcm::new(&key.into())
        };
        let time_source = RealTime::new();
        let stream_id = StreamId::next();
        let handle = StreamHandle::new(stream_id, 10_000);
        handle
            .push_fragment(1, Bytes::from(vec![0u8; 1000]))
            .unwrap();
        let congestion_controller = CongestionControlConfig::default().build_arc();
        // Inflate flightsize then trigger a timeout so cwnd collapses and the
        // queued fragment can never be admitted.
        congestion_controller.on_send(1_000_000);
        congestion_controller.on_timeout();
        let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new()));
        let token_bucket = Arc::new(TokenBucket::new(1_000_000, 100_000_000));
        let pipe_task = GlobalExecutor::spawn(pipe_stream(
            handle,
            StreamId::next(),
            Arc::new(AtomicU32::new(0)),
            Arc::new(TestSocket::new(outbound_sender)),
            remote_addr,
            cipher,
            sent_tracker,
            token_bucket,
            congestion_controller,
            time_source,
            None,
        ));
        let result = pipe_task.await.expect("join error");
        assert!(
            matches!(result, Err(TransportError::ConnectionClosed(_))),
            "Expected ConnectionClosed after pipe_stream cwnd wait timeout, got: {:?}",
            result
        );
        Ok(())
    }
}