use crate::{
constants::ssu2::MIN_MTU,
i2np::{Message, MessageType},
primitives::RouterId,
runtime::{Counter, Histogram, Instant, MetricsHandle, Runtime},
transport::ssu2::{
message::data::{
DataMessageBuilder, MessageKind, PathValidationBlock, PeerTestBlock, RelayBlock,
},
metrics::*,
session::{
active::{ack::AckInfo, RemoteAckManager},
KeyContext,
},
},
};
use bytes::BytesMut;
use alloc::{
collections::{BTreeMap, VecDeque},
sync::Arc,
vec::Vec,
};
use core::{
cmp::{max, min},
fmt,
net::SocketAddr,
ops::Deref,
sync::atomic::{AtomicU32, Ordering},
time::Duration,
};
/// Logging target.
const LOG_TARGET: &str = "emissary::ssu2::active::transmission";

/// Per-datagram overhead added to a payload size when checking it against the maximum payload
/// size in `TransmissionManager::fits_in_datagram()`.
///
/// NOTE(review): presumably 16-byte header + 1 byte + 16-byte MAC — confirm against the SSU2
/// specification.
const SSU2_OVERHEAD: usize = 16usize + 1usize + 16usize;

/// Extra bytes accounted for when a router info is bundled into the same packet as a peer test
/// or relay block.
///
/// NOTE(review): likely block type (1) + size (2) + flag (1) + fragment (1) of the SSU2
/// RouterInfo block — confirm.
const RI_BLOCK_OVERHEAD: usize = 1usize + 2usize + 1usize + 1usize;

/// An immediate ACK is requested once more than this many packet numbers have been handed out
/// since the last packet that requested one.
const IMMEDIATE_ACK_INTERVAL: u32 = 10u32;

/// A segment is dropped instead of retransmitted once it would be sent more than this many times.
const RESEND_TERMINATION_THRESHOLD: usize = 7usize;

/// Retransmission timeout used before the first RTT sample has been taken.
const INITIAL_RTO: Duration = Duration::from_millis(540);

/// Lower bound for a computed retransmission timeout.
const MIN_RTO: Duration = Duration::from_millis(100);

/// Upper bound for a computed retransmission timeout.
const MAX_RTO: Duration = Duration::from_millis(2500);

/// Weight of a new sample in the smoothed RTT (the "alpha" of RFC 6298, which also uses 1/8).
const RTT_DAMPENING_FACTOR: f64 = 0.125f64;

/// Weight of a new deviation in the RTT variation (the "beta" of RFC 6298, which also uses 1/4).
const RTTDEV_DAMPENING_FACTOR: f64 = 0.25;

/// Smallest (and initial) window, counted in in-flight segments.
const MIN_WINDOW_SIZE: usize = 16usize;

/// Largest window, counted in in-flight segments.
const MAX_WINDOW_SIZE: usize = 256usize;

/// Retransmission timeout estimator.
enum RetransmissionTimeout {
    /// No RTT sample has been recorded yet; `INITIAL_RTO` is in effect.
    Unsampled,

    /// At least one RTT sample has been recorded.
    Sampled {
        /// Current retransmission timeout, always within `[MIN_RTO, MAX_RTO]`.
        rto: Duration,

        /// Smoothed round-trip time.
        rtt: Duration,

        /// Round-trip time variation.
        rtt_var: Duration,
    },
}

impl RetransmissionTimeout {
    /// Clamp a computed retransmission timeout into `[MIN_RTO, MAX_RTO]`.
    fn bounded(rto: Duration) -> Duration {
        min(MAX_RTO, max(rto, MIN_RTO))
    }

    /// Fold a new round-trip-time `sample` into the estimator and recompute the RTO.
    ///
    /// The first sample seeds the estimator with `rtt = sample`, `rtt_var = sample / 2` and
    /// `rto = 2 * sample`. Later samples update the smoothed RTT and the RTT variation with
    /// exponential moving averages (computed in whole milliseconds) and set
    /// `rto = rtt + 4 * rtt_var`.
    ///
    /// In both cases the RTO is clamped to `[MIN_RTO, MAX_RTO]`. The first sample used to be
    /// left unclamped, so a fast link (sample below 50ms, or 0ms) produced an RTO below
    /// `MIN_RTO`, even zero, which made every in-flight segment look expired immediately.
    fn calculate_rto(&mut self, sample: Duration) {
        match *self {
            Self::Unsampled => {
                *self = Self::Sampled {
                    rto: Self::bounded(sample * 2),
                    rtt: sample,
                    rtt_var: sample / 2,
                };
            }
            Self::Sampled {
                rtt: old_rtt,
                rtt_var: old_rtt_var,
                ..
            } => {
                let rtt = Duration::from_millis(
                    ((1f64 - RTT_DAMPENING_FACTOR) * old_rtt.as_millis() as f64
                        + RTT_DAMPENING_FACTOR * sample.as_millis() as f64) as u64,
                );

                // NOTE(review): the deviation is measured against the *updated* smoothed RTT,
                // whereas RFC 6298 uses the value from before this sample.
                let srtt = rtt.as_millis() as i64;
                let deviation = RTTDEV_DAMPENING_FACTOR
                    * i64::abs(srtt - sample.as_millis() as i64) as f64;
                let rtt_var = (1f64 - RTTDEV_DAMPENING_FACTOR) * old_rtt_var.as_millis() as f64
                    + deviation;
                let rto = Duration::from_millis((srtt as f64 + 4f64 * rtt_var) as u64);

                *self = Self::Sampled {
                    rto: Self::bounded(rto),
                    rtt,
                    rtt_var: Duration::from_millis(rtt_var as u64),
                };
            }
        }
    }
}

/// Dereferences to the retransmission timeout currently in effect.
impl Deref for RetransmissionTimeout {
    type Target = Duration;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Unsampled => &INITIAL_RTO,
            Self::Sampled { rto, .. } => rto,
        }
    }
}
/// Payload of one outbound packet.
///
/// Segments are queued in `TransmissionManager::pending`, moved into the in-flight
/// `TransmissionManager::segments` map when sent, and kept there until acknowledged so they
/// can be rebuilt into a new packet on retransmission.
enum SegmentKind {
/// I2NP message small enough for a single datagram, serialized with `serialize_short()`.
UnFragmented {
message: Vec<u8>,
},
/// First chunk of a fragmented I2NP message, together with the message's header fields.
FirstFragment {
/// First chunk of the message payload.
fragment: Vec<u8>,
/// Message expiration in seconds (`Message::expiration.as_secs()` truncated to `u32`).
expiration: u32,
message_type: MessageType,
message_id: u32,
},
/// A subsequent chunk of a fragmented I2NP message.
FollowOnFragment {
fragment: Vec<u8>,
/// Index of this chunk within the message; the first fragment has index 0.
fragment_num: u8,
/// `true` for the final chunk of the message.
last: bool,
message_id: u32,
},
/// Peer test block, optionally carrying a router info in the same packet.
PeerTest {
peer_test_block: PeerTestBlock,
router_info: Option<Vec<u8>>,
},
/// Relay block, optionally carrying a router info in the same packet.
Relay {
relay_block: RelayBlock,
router_info: Option<Vec<u8>>,
},
/// Standalone router info, queued ahead of a peer test/relay block that didn't fit in the
/// same datagram with it.
RouterInfo {
router_info: Vec<u8>,
},
/// Path validation block; the only segment kind with its own destination address.
PathValidation {
path_validation_block: PathValidationBlock,
},
}
impl SegmentKind {
    /// Destination address carried by the segment itself.
    ///
    /// Only `PathValidation` segments can provide one (taken from their block); every other
    /// kind yields `None`. NOTE(review): presumably `None` means "send to the session's
    /// current remote address" — confirm with the caller of `drain()`/`drain_expired()`.
    fn address(&self) -> Option<SocketAddr> {
        if let SegmentKind::PathValidation {
            path_validation_block: block,
        } = self
        {
            block.address()
        } else {
            None
        }
    }
}
impl<'a> From<&'a SegmentKind> for MessageKind<'a> {
fn from(value: &'a SegmentKind) -> Self {
match value {
SegmentKind::UnFragmented { message } => Self::UnFragmented { message },
SegmentKind::FirstFragment {
fragment,
expiration,
message_type,
message_id,
} => Self::FirstFragment {
fragment,
expiration: *expiration,
message_type: *message_type,
message_id: *message_id,
},
SegmentKind::FollowOnFragment {
fragment,
fragment_num,
last,
message_id,
} => Self::FollowOnFragment {
fragment,
fragment_num: *fragment_num,
last: *last,
message_id: *message_id,
},
SegmentKind::PeerTest {
peer_test_block,
router_info,
} => Self::PeerTest {
peer_test_block,
router_info: router_info.as_deref(),
},
SegmentKind::Relay {
relay_block,
router_info,
} => Self::Relay {
relay_block,
router_info: router_info.as_deref(),
},
SegmentKind::RouterInfo { router_info } => Self::RouterInfo { router_info },
SegmentKind::PathValidation {
path_validation_block,
} => Self::PathValidation {
path_validation_block,
},
}
}
}
/// An in-flight segment awaiting acknowledgement.
struct Segment<R: Runtime> {
/// How many times the segment has been transmitted (1 after the first transmission).
num_sent: usize,
/// The segment's payload.
segment: SegmentKind,
/// Time of the first transmission; retransmissions carry this value over unchanged, so
/// expiry (`rto * num_sent`) is measured from the original send.
sent: R::Instant,
}
/// Throttling state of the session.
enum Throttled {
    /// Not throttled.
    No,

    /// Throttled; `max_payload_size` is the payload size that was in effect before
    /// throttling and is restored when the session is unthrottled.
    Yes { max_payload_size: usize },
}

impl Throttled {
    /// Returns `true` if the session is currently throttled.
    fn is_throttled(&self) -> bool {
        match self {
            Throttled::No => false,
            Throttled::Yes { .. } => true,
        }
    }
}
/// Outbound side of an active SSU2 session: queues outgoing segments, assigns packet numbers,
/// tracks in-flight segments until they're acknowledged, retransmits expired ones, and
/// maintains the RTT/RTO estimate and the in-flight window.
pub struct TransmissionManager<R: Runtime> {
/// Throttling state; holds the pre-throttle maximum payload size while throttled.
throttle: Throttled,
/// Destination connection ID written into every outbound packet.
dst_id: u64,
/// Intro key passed to the packet builder together with `send_key_ctx`.
intro_key: [u8; 32],
/// Packet number of the most recent packet that requested an immediate ACK.
last_immediate_ack: u32,
/// Metrics handle.
metrics: R::MetricsHandle,
/// Maximum payload size used for the datagram-fit checks and passed to the packet builder.
max_payload_size: usize,
/// Segments waiting to be transmitted.
pending: VecDeque<SegmentKind>,
/// Outbound packet number counter, shared through an `Arc`.
pkt_num: Arc<AtomicU32>,
/// Tracks packets received from the remote; used for duplicate detection and ACK blocks.
remote_ack_manager: RemoteAckManager,
/// ID of the remote router; used in log output.
router_id: RouterId,
/// Retransmission timeout estimator.
rto: RetransmissionTimeout,
/// In-flight segments, keyed by the packet number they were last sent with.
segments: BTreeMap<u32, Segment<R>>,
/// Key context used to encrypt outbound packets.
send_key_ctx: KeyContext,
/// Maximum number of in-flight segments.
window_size: usize,
}
/// Anything that can be queued for transmission with `TransmissionManager::schedule()`.
pub enum TransmissionMessage {
/// I2NP message; fragmented if it doesn't fit in a single datagram.
Message(Message),
/// Peer test block on its own.
PeerTest(PeerTestBlock),
/// Relay block on its own.
Relay(RelayBlock),
/// Peer test block plus a serialized router info to send alongside it.
PeerTestWithRouterInfo((PeerTestBlock, Vec<u8>)),
/// Relay block plus a serialized router info to send alongside it.
RelayWithRouterInfo((RelayBlock, Vec<u8>)),
/// Path validation block; queued ahead of everything else.
PathValidation(PathValidationBlock),
}
impl From<Message> for TransmissionMessage {
fn from(value: Message) -> Self {
Self::Message(value)
}
}
impl From<PeerTestBlock> for TransmissionMessage {
fn from(value: PeerTestBlock) -> Self {
Self::PeerTest(value)
}
}
impl From<RelayBlock> for TransmissionMessage {
fn from(value: RelayBlock) -> Self {
Self::Relay(value)
}
}
impl From<(PeerTestBlock, Vec<u8>)> for TransmissionMessage {
fn from(value: (PeerTestBlock, Vec<u8>)) -> Self {
Self::PeerTestWithRouterInfo(value)
}
}
impl From<(RelayBlock, Vec<u8>)> for TransmissionMessage {
fn from(value: (RelayBlock, Vec<u8>)) -> Self {
Self::RelayWithRouterInfo(value)
}
}
impl From<PathValidationBlock> for TransmissionMessage {
fn from(value: PathValidationBlock) -> Self {
Self::PathValidation(value)
}
}
/// Debug output shows the wrapped message/block; bundled router infos are omitted (rendered as
/// a non-exhaustive struct).
impl fmt::Debug for TransmissionMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // (struct name, field name, field value, whether all fields are shown)
        let (name, field_name, value, complete): (&str, &str, &dyn fmt::Debug, bool) = match self
        {
            Self::Message(message) => ("TransmissionMessage::Message", "message", message, true),
            Self::PeerTest(block) => ("TransmissionMessage::PeerTest", "block", block, true),
            Self::Relay(block) => ("TransmissionMessage::Relay", "block", block, true),
            Self::PeerTestWithRouterInfo((block, _)) => (
                "TransmissionMessage::PeerTestWithRouterInfo",
                "block",
                block,
                false,
            ),
            Self::RelayWithRouterInfo((block, _)) => (
                "TransmissionMessage::RelayWithRouterInfo",
                "block",
                block,
                false,
            ),
            Self::PathValidation(block) =>
                ("TransmissionMessage::PathValidation", "block", block, true),
        };

        let mut builder = f.debug_struct(name);
        builder.field(field_name, value);

        if complete {
            builder.finish()
        } else {
            builder.finish_non_exhaustive()
        }
    }
}
impl<R: Runtime> TransmissionManager<R> {
    /// Create a new `TransmissionManager`.
    ///
    /// The manager starts unthrottled, with no RTT sample (so `INITIAL_RTO` is in effect), no
    /// pending or in-flight segments, and a window of `MIN_WINDOW_SIZE` segments.
    ///
    /// `pkt_num` is the outbound packet number counter. It is shared through an `Arc`
    /// (NOTE(review): presumably with the owning session — confirm), and every number handed
    /// out by `next_pkt_num()` is taken from it.
    pub fn new(
        dst_id: u64,
        router_id: RouterId,
        intro_key: [u8; 32],
        send_key_ctx: KeyContext,
        pkt_num: Arc<AtomicU32>,
        metrics: R::MetricsHandle,
        max_payload_size: usize,
    ) -> Self {
        Self {
            throttle: Throttled::No,
            dst_id,
            intro_key,
            last_immediate_ack: 0u32,
            metrics,
            max_payload_size,
            pending: VecDeque::new(),
            pkt_num,
            remote_ack_manager: RemoteAckManager::new(),
            router_id,
            rto: RetransmissionTimeout::Unsampled,
            segments: BTreeMap::new(),
            send_key_ctx,
            window_size: MIN_WINDOW_SIZE,
        }
    }

    /// Allocate the next outbound packet number.
    ///
    /// `fetch_add` is atomic, so values are unique even with `Relaxed` ordering.
    pub fn next_pkt_num(&self) -> u32 {
        self.pkt_num.fetch_add(1u32, Ordering::Relaxed)
    }

    /// Returns `true` if in-flight plus pending segments are below the current window size.
    pub fn has_capacity(&self) -> bool {
        self.segments.len() + self.pending.len() < self.window_size
    }

    /// Smoothed round-trip time, or `INITIAL_RTO` if no RTT sample has been taken yet.
    pub fn round_trip_time(&self) -> Duration {
        match &self.rto {
            RetransmissionTimeout::Unsampled => INITIAL_RTO,
            RetransmissionTimeout::Sampled { rtt, .. } => *rtt,
        }
    }

    /// Returns `true` if a payload of `size` bytes plus `SSU2_OVERHEAD` fits within the
    /// current maximum payload size.
    pub fn fits_in_datagram(&self, size: usize) -> bool {
        size + SSU2_OVERHEAD <= self.max_payload_size
    }

    /// Record that packet `pkt_num` was received from the remote peer, so it is covered by
    /// future ACK blocks.
    pub fn register_remote_pkt(&mut self, pkt_num: u32) {
        self.remote_ack_manager.register_pkt(pkt_num);
    }

    /// Returns `true` if the remote ACK manager has already seen packet `pkt_num`.
    pub fn is_duplicate(&self, pkt_num: u32) -> bool {
        self.remote_ack_manager.is_duplicate(pkt_num)
    }

    /// Queue `message` for transmission.
    ///
    /// * An I2NP message that fits in one datagram is queued as a single unfragmented
    ///   segment. A larger one is split into 1200-byte chunks, queued as one first fragment
    ///   (carrying the message's header fields) followed by follow-on fragments.
    /// * A peer test or relay block bundled with a router info is queued as one segment if
    ///   both fit in a datagram. Otherwise a standalone router info segment is queued,
    ///   followed by the block on its own.
    /// * A path validation block is queued at the front, ahead of everything already pending.
    ///
    /// Nothing is sent here; `drain()` transmits pending segments.
    pub fn schedule(&mut self, message: impl Into<TransmissionMessage>) {
        match message.into() {
            TransmissionMessage::Message(message) => {
                if self.fits_in_datagram(message.serialized_len_short()) {
                    self.pending.push_back(SegmentKind::UnFragmented {
                        message: message.serialize_short(),
                    });
                    return;
                }

                // NOTE(review): the fragment size is fixed at 1200 bytes and does not follow
                // `max_payload_size` (e.g. while throttled) — confirm it always fits.
                let fragments = message.payload.chunks(1200).collect::<Vec<_>>();
                let num_fragments = fragments.len();

                self.metrics.histogram(OUTBOUND_FRAGMENT_COUNT).record(num_fragments as f64);

                for (fragment_num, fragment) in fragments.into_iter().enumerate() {
                    let segment = match fragment_num {
                        0 => SegmentKind::FirstFragment {
                            fragment: fragment.to_vec(),
                            expiration: message.expiration.as_secs() as u32,
                            message_type: message.message_type,
                            message_id: message.message_id,
                        },
                        _ => SegmentKind::FollowOnFragment {
                            fragment: fragment.to_vec(),
                            fragment_num: fragment_num as u8,
                            last: fragment_num == num_fragments - 1,
                            message_id: message.message_id,
                        },
                    };

                    self.pending.push_back(segment);
                }
            }
            TransmissionMessage::PeerTest(peer_test_block) => {
                debug_assert!(self.fits_in_datagram(peer_test_block.serialized_len()));

                self.pending.push_back(SegmentKind::PeerTest {
                    peer_test_block,
                    router_info: None,
                });
            }
            TransmissionMessage::PeerTestWithRouterInfo((peer_test_block, router_info)) => {
                if self.fits_in_datagram(
                    peer_test_block.serialized_len() + router_info.len() + RI_BLOCK_OVERHEAD,
                ) {
                    self.pending.push_back(SegmentKind::PeerTest {
                        peer_test_block,
                        router_info: Some(router_info),
                    });
                } else {
                    tracing::debug!(
                        target: LOG_TARGET,
                        router_id = %self.router_id,
                        "fragmenting peer test with router info into two packets",
                    );

                    self.pending.push_back(SegmentKind::RouterInfo { router_info });
                    self.pending.push_back(SegmentKind::PeerTest {
                        peer_test_block,
                        router_info: None,
                    });
                }
            }
            TransmissionMessage::Relay(relay_block) => {
                debug_assert!(self.fits_in_datagram(relay_block.serialized_len()));

                self.pending.push_back(SegmentKind::Relay {
                    relay_block,
                    router_info: None,
                });
            }
            TransmissionMessage::RelayWithRouterInfo((relay_block, router_info)) => {
                if self.fits_in_datagram(
                    relay_block.serialized_len() + router_info.len() + RI_BLOCK_OVERHEAD,
                ) {
                    self.pending.push_back(SegmentKind::Relay {
                        relay_block,
                        router_info: Some(router_info),
                    });
                } else {
                    tracing::debug!(
                        target: LOG_TARGET,
                        router_id = %self.router_id,
                        "fragmenting relay with router info into two packets",
                    );

                    self.pending.push_back(SegmentKind::RouterInfo { router_info });
                    self.pending.push_back(SegmentKind::Relay {
                        relay_block,
                        router_info: None,
                    });
                }
            }
            TransmissionMessage::PathValidation(block) => {
                debug_assert!(self.fits_in_datagram(block.serialized_len()));

                self.pending.push_front(SegmentKind::PathValidation {
                    path_validation_block: block,
                });
            }
        }
    }

    /// Process an ACK block received from the remote peer.
    ///
    /// `ack_through` and the `num_acks` packets directly below it are acknowledged. Below
    /// those, `ranges` is walked as `(nack_count, ack_count)` pairs: `nack_count` packets are
    /// skipped (they stay in flight and are eventually retransmitted), then `ack_count`
    /// packets are acknowledged. Packet numbers that aren't in flight are ignored.
    ///
    /// The walk uses signed arithmetic and stops as soon as it would go below packet number
    /// 0. Previously it saturated at 0, so a malformed or hostile range walking past the
    /// bottom of the sequence acknowledged packet 0 even when that packet had been NACKed.
    ///
    /// Each acknowledged packet grows the window by one segment unless the session is
    /// throttled. The window is capped at `MAX_WINDOW_SIZE`.
    pub fn register_ack(&mut self, ack_through: u32, num_acks: u8, ranges: &[(u8, u8)]) {
        tracing::trace!(
            target: LOG_TARGET,
            router_id = %self.router_id,
            ?ack_through,
            ?num_acks,
            ?ranges,
            num_segments = ?self.segments.len(),
            "handle ack",
        );

        let highest = i64::from(ack_through);
        let lowest = highest - i64::from(num_acks);

        // Process highest first. The order matters because every acknowledged segment that
        // was sent only once feeds a sample into the (order-dependent) RTT estimator.
        for pkt_num in (lowest..=highest).rev() {
            match u32::try_from(pkt_num) {
                Ok(pkt_num) => self.on_pkt_acked(pkt_num),
                Err(_) => break,
            }
        }

        let mut next_pkt = lowest;

        'ranges: for &(nack, ack) in ranges {
            next_pkt -= i64::from(nack);

            for _ in 0..ack {
                next_pkt -= 1;

                match u32::try_from(next_pkt) {
                    Ok(pkt_num) => self.on_pkt_acked(pkt_num),
                    // every remaining packet number would be negative as well
                    Err(_) => break 'ranges,
                }
            }
        }

        self.window_size = min(self.window_size, MAX_WINDOW_SIZE);
    }

    /// Remove the in-flight segment sent as `pkt_num`, if any, and account for its
    /// acknowledgement. This records the ACK delay metric, feeds the RTT estimator and grows
    /// the window (unless throttled).
    ///
    /// Only segments transmitted exactly once produce an RTT sample (Karn's algorithm). For a
    /// retransmitted segment it is ambiguous which transmission is being acknowledged, and
    /// `sent` still holds the time of the first transmission.
    fn on_pkt_acked(&mut self, pkt_num: u32) {
        if let Some(Segment { num_sent, sent, .. }) = self.segments.remove(&pkt_num) {
            let elapsed = sent.elapsed();

            self.metrics.histogram(ACK_RECEIVE_TIME).record(elapsed.as_millis() as f64);

            if num_sent == 1 {
                self.rto.calculate_rto(elapsed);
            }

            if !self.throttle.is_throttled() {
                self.window_size += 1;
            }
        }
    }

    /// Move as many pending segments into flight as the window allows and build their
    /// packets.
    ///
    /// Returns `None` if nothing is pending. Otherwise returns one `(packet, address)` pair per
    /// transmitted segment, where `address` is the segment's own destination (only path
    /// validation segments have one). If the window is full, the returned vector is empty.
    ///
    /// An immediate ACK is requested on the last packet of a multi-packet burst, and whenever
    /// more than `IMMEDIATE_ACK_INTERVAL` packet numbers have passed since the last packet that
    /// requested one.
    pub fn drain(&mut self) -> Option<Vec<(BytesMut, Option<SocketAddr>)>> {
        if self.pending.is_empty() {
            return None;
        }

        let free_slots = self.window_size.saturating_sub(self.segments.len());
        let pkts_to_send = (0..min(self.pending.len(), free_slots))
            .filter_map(|_| {
                let segment = self.pending.pop_front()?;
                let pkt_num = self.next_pkt_num();

                self.segments.insert(
                    pkt_num,
                    Segment {
                        num_sent: 1usize,
                        sent: R::now(),
                        segment,
                    },
                );

                Some(pkt_num)
            })
            .collect::<Vec<_>>();

        let AckInfo {
            highest_seen,
            num_acks,
            ranges,
        } = self.remote_ack_manager.ack_info();
        let num_pkts = pkts_to_send.len();

        Some(
            pkts_to_send
                .into_iter()
                .enumerate()
                .map(|(i, pkt_num)| {
                    let segment = &self.segments.get(&pkt_num).expect("to exist").segment;
                    let address = segment.address();
                    let last_in_burst = num_pkts > 1 && i == num_pkts - 1;
                    let immediate_ack_threshold =
                        pkt_num.saturating_sub(self.last_immediate_ack) > IMMEDIATE_ACK_INTERVAL;

                    (
                        if last_in_burst || immediate_ack_threshold {
                            self.last_immediate_ack = pkt_num;
                            DataMessageBuilder::default().with_immediate_ack()
                        } else {
                            DataMessageBuilder::default()
                        }
                        .with_max_payload_size(self.max_payload_size)
                        .with_dst_id(self.dst_id)
                        .with_key_context(self.intro_key, &self.send_key_ctx)
                        .with_message(pkt_num, segment.into())
                        .with_ack(highest_seen, num_acks, ranges.as_deref())
                        .build::<R>(),
                        address,
                    )
                })
                .collect(),
        )
    }

    /// Retransmit in-flight segments whose retransmission timer has expired.
    ///
    /// A segment is expired once the time since its *first* transmission exceeds
    /// `rto * num_sent`.
    ///
    /// * Expired segments that would exceed `RESEND_TERMINATION_THRESHOLD` transmissions are
    ///   dropped.
    /// * The rest are retransmitted under fresh packet numbers, limited to the free room in
    ///   the window.
    /// * Expired segments that don't fit in the window keep their current packet number and
    ///   send count, so a later call retries them. Previously they were renumbered and their
    ///   send count was incremented even though no packet was sent, which moved them toward
    ///   the drop threshold without ever retransmitting them.
    ///
    /// If anything is retransmitted, the window is halved (not below `MIN_WINDOW_SIZE`). Each
    /// retransmitted packet requests an immediate ACK.
    ///
    /// Returns `None` if nothing was retransmitted.
    pub fn drain_expired(&mut self) -> Option<Vec<(BytesMut, Option<SocketAddr>)>> {
        let rto = *self.rto;
        let expired = self
            .segments
            .iter()
            .filter_map(|(pkt_num, segment)| {
                (segment.sent.elapsed() > (rto * segment.num_sent as u32)).then_some(*pkt_num)
            })
            .collect::<Vec<_>>();

        if expired.is_empty() {
            return None;
        }

        // Segments that have exhausted their retransmission budget are dropped regardless of
        // the window.
        let mut resendable = Vec::with_capacity(expired.len());

        for old_pkt_num in expired {
            let num_sent = self.segments.get(&old_pkt_num).expect("to exist").num_sent;

            if num_sent + 1 <= RESEND_TERMINATION_THRESHOLD {
                resendable.push(old_pkt_num);
                continue;
            }

            self.segments.remove(&old_pkt_num);

            tracing::debug!(
                target: LOG_TARGET,
                router_id = %self.router_id,
                pkt_num = ?old_pkt_num,
                "packet has been sent over {} times, dropping",
                RESEND_TERMINATION_THRESHOLD,
            );
            self.metrics
                .counter(DROPPED_PKTS)
                .increment_with_label(1, "reason", "rt-limit");
        }

        if resendable.is_empty() {
            return None;
        }

        // Only as many segments as fit in the window are renumbered and resent. Lowest
        // packet numbers (the oldest transmissions) go first.
        let budget = self.window_size.saturating_sub(self.segments.len());
        let pkts_to_resend = resendable
            .into_iter()
            .take(budget)
            .map(|old_pkt_num| {
                let Segment {
                    num_sent,
                    segment,
                    sent,
                } = self.segments.remove(&old_pkt_num).expect("to exist");
                let pkt_num = self.next_pkt_num();

                tracing::trace!(
                    target: LOG_TARGET,
                    router_id = %self.router_id,
                    ?old_pkt_num,
                    new_pkt_num = ?pkt_num,
                    "resend packet",
                );

                // `sent` is deliberately kept: expiry is measured from the first transmission
                self.segments.insert(
                    pkt_num,
                    Segment {
                        num_sent: num_sent + 1,
                        segment,
                        sent,
                    },
                );

                pkt_num
            })
            .collect::<Vec<_>>();

        if pkts_to_resend.is_empty() {
            tracing::trace!(
                target: LOG_TARGET,
                router_id = %self.router_id,
                "one or more packets need to be resent but no window",
            );
            return None;
        }

        self.window_size = max(self.window_size / 2, MIN_WINDOW_SIZE);

        tracing::trace!(
            target: LOG_TARGET,
            router_id = %self.router_id,
            num_pkts = ?pkts_to_resend.len(),
            pkts = ?pkts_to_resend,
            window = ?self.window_size,
            "resend packets",
        );

        self.metrics.counter(RETRANSMISSION_COUNT).increment(pkts_to_resend.len());
        self.last_immediate_ack = *pkts_to_resend.last().expect("to exist");

        let AckInfo {
            highest_seen,
            num_acks,
            ranges,
        } = self.remote_ack_manager.ack_info();

        Some(
            pkts_to_resend
                .into_iter()
                .map(|pkt_num| {
                    let segment = &self.segments.get(&pkt_num).expect("to exist").segment;
                    let address = segment.address();

                    (
                        // same payload-size limit as the original transmission in `drain()`
                        DataMessageBuilder::default()
                            .with_max_payload_size(self.max_payload_size)
                            .with_dst_id(self.dst_id)
                            .with_key_context(self.intro_key, &self.send_key_ctx)
                            .with_message(pkt_num, segment.into())
                            .with_immediate_ack()
                            .with_ack(highest_seen, num_acks, ranges.as_deref())
                            .build::<R>(),
                        address,
                    )
                })
                .collect(),
        )
    }

    /// Build a packet that carries only the current ACK state, under a fresh packet number.
    pub fn build_explicit_ack(&mut self) -> BytesMut {
        let AckInfo {
            highest_seen,
            num_acks,
            ranges,
        } = self.remote_ack_manager.ack_info();

        tracing::trace!(
            target: LOG_TARGET,
            router_id = %self.router_id,
            ?highest_seen,
            ?num_acks,
            ?ranges,
            "send explicit ack",
        );

        DataMessageBuilder::default()
            .with_dst_id(self.dst_id)
            .with_key_context(self.intro_key, &self.send_key_ctx)
            .with_pkt_num(self.next_pkt_num())
            .with_ack(highest_seen, num_acks, ranges.as_deref())
            .build::<R>()
    }

    /// Throttle the session: cap the payload size at `MIN_MTU` and reset the window to
    /// `MIN_WINDOW_SIZE`. While throttled, ACKs no longer grow the window.
    ///
    /// The payload size in effect before throttling is saved for `unthrottle()`. If the
    /// session is already throttled, the originally saved value is kept. Previously a second
    /// call overwrote it with `MIN_MTU`, so unthrottling never restored the real payload size.
    pub fn throttle(&mut self) {
        if !self.throttle.is_throttled() {
            self.throttle = Throttled::Yes {
                max_payload_size: self.max_payload_size,
            };
        }

        self.max_payload_size = MIN_MTU;
        self.window_size = MIN_WINDOW_SIZE;
    }

    /// Lift throttling: restore the payload size saved by `throttle()` (no-op if not
    /// throttled). If `reset_timers` is set, discard the RTT estimate so `INITIAL_RTO` is used
    /// again.
    ///
    /// The window is not restored here; it grows again as further ACKs arrive.
    pub fn unthrottle(&mut self, reset_timers: bool) {
        match core::mem::replace(&mut self.throttle, Throttled::No) {
            Throttled::No => {}
            Throttled::Yes { max_payload_size } => self.max_payload_size = max_payload_size,
        }

        if reset_timers {
            self.rto = RetransmissionTimeout::Unsampled;
        }
    }
}
/// Test-only accessors for internal state.
#[cfg(test)]
impl<R: Runtime> TransmissionManager<R> {
    /// Returns `true` while the manager is throttled.
    pub fn is_throttled(&self) -> bool {
        Throttled::is_throttled(&self.throttle)
    }

    /// Feed `sample` directly into the RTO estimator, as if a segment sent once had been
    /// acknowledged after that long.
    pub fn add_rtt_sample(&mut self, sample: Duration) {
        RetransmissionTimeout::calculate_rto(&mut self.rto, sample)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
runtime::mock::MockRuntime,
transport::ssu2::message::{HeaderKind, HeaderReader},
};
#[tokio::test]
async fn ack_one_packet() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![1, 2, 3],
..Default::default()
});
assert_eq!(mgr.pending.len(), 1);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 1);
mgr.register_ack(1u32, 0u8, &[]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn ack_multiple_packets_last_packet_missing() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 3 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 4);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 4);
assert_eq!(mgr.segments.len(), 4);
mgr.register_ack(4u32, 2u8, &[]);
assert_eq!(mgr.segments.len(), 1);
assert!(mgr.segments.contains_key(&1));
}
#[tokio::test]
async fn ack_multiple_packets_first_packet_missing() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 3 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 4);
assert_eq!(mgr.drain().unwrap().len(), 4);
mgr.register_ack(3u32, 2u8, &[]);
assert_eq!(mgr.segments.len(), 1);
assert!(mgr.segments.contains_key(&4));
}
#[tokio::test]
async fn ack_multiple_packets_middle_packets_nacked() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 3 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 4);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 4);
assert_eq!(mgr.segments.len(), 4);
mgr.register_ack(4u32, 0u8, &[(2, 1)]);
assert_eq!(mgr.segments.len(), 2);
assert!(mgr.segments.contains_key(&3));
assert!(mgr.segments.contains_key(&2));
}
#[tokio::test]
async fn multiple_ranges() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 10 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 11);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 11);
assert_eq!(mgr.segments.len(), 11);
mgr.register_ack(11u32, 2u8, &[(3, 2), (1, 2)]);
assert_eq!(mgr.segments.len(), 4);
assert!((6..=8).all(|i| mgr.segments.contains_key(&i)));
assert!(mgr.segments.contains_key(&3));
}
#[tokio::test]
async fn alternating() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[(1, 1), (1, 1), (1, 1), (1, 1), (1, 0)]);
assert_eq!(mgr.segments.len(), 5);
assert!((1..=9).all(|i| if i % 2 != 0 {
mgr.segments.contains_key(&i)
} else {
!mgr.segments.contains_key(&i)
}));
}
#[tokio::test]
async fn no_ranges() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[]);
assert_eq!(mgr.segments.len(), 9);
assert!((1..=9).all(|i| mgr.segments.contains_key(&i)));
}
#[tokio::test]
async fn highest_pkts_not_received() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(4u32, 0u8, &[(1, 2)]);
assert_eq!(mgr.segments.len(), 7);
assert!((5..=10).all(|i| mgr.segments.contains_key(&i)));
assert!(mgr.segments.contains_key(&3));
}
#[tokio::test]
async fn invalid_nack_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[(2, 0), (2, 0), (2, 0), (2, 0), (1, 0)]);
assert_eq!(mgr.segments.len(), 9);
assert!((1..=9).all(|i| mgr.segments.contains_key(&i)));
assert!(mgr.segments.contains_key(&3));
}
#[tokio::test]
async fn invalid_ack_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[(0, 2), (0, 2), (0, 2), (0, 2), (0, 1)]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn num_acks_out_of_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 128u8, &[]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn nacks_out_of_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[(128u8, 0)]);
assert_eq!(mgr.segments.len(), 9);
assert!((1..=9).all(|i| mgr.segments.contains_key(&i)));
}
#[tokio::test]
async fn acks_out_of_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(10u32, 0u8, &[(0, 128u8), (128u8, 0u8)]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn highest_seen_out_of_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(1337u32, 10u8, &[]);
assert_eq!(mgr.segments.len(), 10);
}
#[tokio::test]
async fn num_ack_out_of_range() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
mgr.register_ack(15u32, 255, &[]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn nothing_to_resend() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
assert!(mgr.drain_expired().is_none());
}
#[tokio::test]
async fn packets_resent() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
assert!(mgr.drain_expired().is_none());
tokio::time::sleep(INITIAL_RTO + Duration::from_millis(10)).await;
let pkt_nums = mgr
.drain_expired()
.unwrap()
.into_iter()
.map(|(mut pkt, _)| {
let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
let _dst_id = reader.dst_id();
match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
HeaderKind::Data { pkt_num, .. } => pkt_num,
_ => panic!("invalid pkt"),
}
})
.collect::<Vec<_>>();
assert!(pkt_nums
.into_iter()
.all(|pkt_num| mgr.segments.get(&pkt_num).unwrap().num_sent == 2));
}
#[tokio::test(start_paused = true)]
async fn some_packets_resent() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 8],
..Default::default()
});
assert_eq!(mgr.pending.len(), 8);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 8);
assert_eq!(mgr.segments.len(), 8);
assert!(mgr.drain_expired().is_none());
tokio::time::sleep(INITIAL_RTO + Duration::from_millis(10)).await;
let pkt_nums = mgr
.drain_expired()
.unwrap()
.into_iter()
.map(|(mut pkt, _)| {
let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
let _dst_id = reader.dst_id();
match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
HeaderKind::Data { pkt_num, .. } => pkt_num,
_ => panic!("invalid pkt"),
}
})
.collect::<Vec<_>>();
assert_eq!(pkt_nums.len(), 8);
assert!(pkt_nums.iter().all(|pkt_num| mgr.segments.get(&pkt_num).unwrap().num_sent == 2));
mgr.register_ack(20, 3, &[(2, 2), (2, 0)]);
tokio::time::sleep(2 * INITIAL_RTO + Duration::from_millis(10)).await;
let pkt_nums = mgr
.drain_expired()
.unwrap()
.into_iter()
.map(|(mut pkt, _)| {
let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
let _dst_id = reader.dst_id();
match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
HeaderKind::Data { pkt_num, .. } => pkt_num,
_ => panic!("invalid pkt"),
}
})
.collect::<Vec<_>>();
assert_eq!(pkt_nums.len(), 6);
assert!(pkt_nums.iter().all(|pkt_num| mgr.segments.get(&pkt_num).unwrap().num_sent == 3));
mgr.register_ack(24, 3, &[]);
tokio::time::sleep(2 * INITIAL_RTO + Duration::from_millis(10)).await;
let pkt_nums = mgr
.drain_expired()
.unwrap()
.into_iter()
.map(|(mut pkt, _)| {
let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
let _dst_id = reader.dst_id();
match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
HeaderKind::Data { pkt_num, .. } => pkt_num,
_ => panic!("invalid pkt"),
}
})
.collect::<Vec<_>>();
assert!(pkt_nums.iter().all(|pkt_num| mgr.segments.get(&pkt_num).unwrap().num_sent == 4));
mgr.register_ack(26, 4, &[]);
assert!(mgr.segments.is_empty());
}
#[tokio::test]
async fn window_size_increases() {
let mut mgr = TransmissionManager::<MockRuntime>::new(
1337u64,
RouterId::random(),
[0xaa; 32],
KeyContext {
k_data: [0xbb; 32],
k_header_2: [0xcc; 32],
},
Arc::new(AtomicU32::new(1u32)),
MockRuntime::register_metrics(Vec::new(), None),
1472,
);
mgr.schedule(Message {
payload: vec![0u8; 1200 * 9 + 512],
..Default::default()
});
assert_eq!(mgr.pending.len(), 10);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.drain().unwrap().len(), 10);
assert_eq!(mgr.segments.len(), 10);
assert_eq!(mgr.window_size, MIN_WINDOW_SIZE);
mgr.register_ack(10, 3, &[(5, 1)]);
assert_eq!(mgr.segments.len(), 5);
assert_eq!(mgr.window_size, MIN_WINDOW_SIZE + 5);
mgr.register_ack(6, 4, &[]);
assert_eq!(mgr.segments.len(), 0);
assert_eq!(mgr.window_size, MIN_WINDOW_SIZE + 10);
}
#[tokio::test(start_paused = true)]
/// Expired (unacknowledged) segments are resent and the send window falls back to
/// `MIN_WINDOW_SIZE`.
async fn window_size_decreases() {
    let mut mgr = TransmissionManager::<MockRuntime>::new(
        1337u64,
        RouterId::random(),
        [0xaa; 32],
        KeyContext {
            k_data: [0xbb; 32],
            k_header_2: [0xcc; 32],
        },
        Arc::new(AtomicU32::new(1u32)),
        MockRuntime::register_metrics(Vec::new(), None),
        1472,
    );

    // Fifteen full 1200-byte chunks plus a 512-byte tail: exactly one minimum window of
    // fragments.
    mgr.schedule(Message {
        payload: vec![0u8; 1200 * 15 + 512],
        ..Default::default()
    });
    assert_eq!(mgr.pending.len(), MIN_WINDOW_SIZE);
    assert_eq!(mgr.segments.len(), 0);
    assert_eq!(mgr.drain().unwrap().len(), MIN_WINDOW_SIZE);
    assert_eq!(mgr.segments.len(), MIN_WINDOW_SIZE);
    assert!(mgr.drain_expired().is_none());

    // Acknowledge the lower half; the window is expected to grow by one per acked segment.
    mgr.register_ack(8, 7, &[]);
    assert_eq!(mgr.segments.len(), 8);
    assert_eq!(mgr.window_size, MIN_WINDOW_SIZE + 8);

    // First expiry: the eight unacknowledged segments are resent and the window drops back
    // to its minimum.
    tokio::time::sleep(INITIAL_RTO + Duration::from_millis(10)).await;
    let pkt_nums = mgr
        .drain_expired()
        .unwrap()
        .into_iter()
        .map(|(mut pkt, _)| {
            let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
            let _dst_id = reader.dst_id();
            match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
                HeaderKind::Data { pkt_num, .. } => pkt_num,
                _ => panic!("invalid pkt"),
            }
        })
        .collect::<Vec<_>>();
    // Check the count first: a per-packet check over an empty list would pass vacuously.
    assert_eq!(pkt_nums.len(), 8);
    for pkt_num in &pkt_nums {
        // `pkt_num` is already a reference into `pkt_nums`; borrowing it again (`&pkt_num`)
        // would ask the map to look up `&&_`, which its key type cannot `Borrow` as.
        let segment = mgr.segments.get(pkt_num).expect("resent segment must still be tracked");
        assert_eq!(segment.num_sent, 2);
    }
    assert_eq!(mgr.window_size, MIN_WINDOW_SIZE);

    // Second expiry: wait twice the initial RTO before the same segments are due again.
    tokio::time::sleep(2 * INITIAL_RTO + Duration::from_millis(10)).await;
    let pkt_nums = mgr
        .drain_expired()
        .unwrap()
        .into_iter()
        .map(|(mut pkt, _)| {
            let mut reader = HeaderReader::new(mgr.intro_key, &mut pkt).unwrap();
            let _dst_id = reader.dst_id();
            match reader.parse(mgr.send_key_ctx.k_header_2).unwrap() {
                HeaderKind::Data { pkt_num, .. } => pkt_num,
                _ => panic!("invalid pkt"),
            }
        })
        .collect::<Vec<_>>();
    assert_eq!(pkt_nums.len(), 8);
    for pkt_num in &pkt_nums {
        let segment = mgr.segments.get(pkt_num).expect("resent segment must still be tracked");
        assert_eq!(segment.num_sent, 3);
    }
    assert_eq!(mgr.window_size, MIN_WINDOW_SIZE);
}
#[tokio::test]
/// Fragments beyond the current window stay queued in `pending` until ACKs open the window.
async fn excess_packets_stay_in_pending() {
    let router_id = RouterId::random();
    let key_ctx = KeyContext {
        k_data: [0xbb; 32],
        k_header_2: [0xcc; 32],
    };
    let pkt_num_counter = Arc::new(AtomicU32::new(1u32));
    let metrics = MockRuntime::register_metrics(Vec::new(), None);
    let mut manager = TransmissionManager::<MockRuntime>::new(
        1337u64,
        router_id,
        [0xaa; 32],
        key_ctx,
        pkt_num_counter,
        metrics,
        1472,
    );

    // 31 full chunks plus a 512-byte tail => two minimum windows' worth of fragments.
    manager.schedule(Message {
        payload: vec![0u8; 1200 * 31 + 512],
        ..Default::default()
    });
    assert_eq!(manager.pending.len(), 2 * MIN_WINDOW_SIZE);
    assert!(manager.segments.is_empty());

    // Only one window's worth is sent; the other half waits in `pending`.
    assert_eq!(manager.drain().unwrap().len(), MIN_WINDOW_SIZE);
    assert_eq!(manager.segments.len(), MIN_WINDOW_SIZE);
    assert_eq!(manager.pending.len(), MIN_WINDOW_SIZE);
    assert!(manager.drain_expired().is_none());
    assert!(!manager.has_capacity());

    // Acknowledging the whole first window doubles it and restores capacity.
    manager.register_ack(16, 15, &[]);
    assert!(manager.segments.is_empty());
    assert_eq!(manager.window_size, 2 * MIN_WINDOW_SIZE);
    assert!(manager.has_capacity());

    // The queued half now goes out; decode each header to make sure it is a data packet.
    let mut pkt_nums = Vec::new();
    for (mut pkt, _) in manager.drain().unwrap() {
        let mut reader = HeaderReader::new(manager.intro_key, &mut pkt).unwrap();
        let _dst_id = reader.dst_id();
        let HeaderKind::Data { pkt_num, .. } =
            reader.parse(manager.send_key_ctx.k_header_2).unwrap()
        else {
            panic!("invalid pkt");
        };
        pkt_nums.push(pkt_num);
    }
    assert_eq!(pkt_nums.len(), 16);
    assert_eq!(manager.segments.len(), 16);
    assert!(manager.pending.is_empty());
    assert!(manager.has_capacity());
}
#[tokio::test]
/// A partial ACK opens only part of the window, so only part of `pending` is sent next.
async fn pending_packets_partially_sent() {
    // 39 full 1200-byte chunks plus a 512-byte tail.
    const TOTAL_FRAGMENTS: usize = 40;

    let mut manager = TransmissionManager::<MockRuntime>::new(
        1337u64,
        RouterId::random(),
        [0xaa; 32],
        KeyContext {
            k_data: [0xbb; 32],
            k_header_2: [0xcc; 32],
        },
        Arc::new(AtomicU32::new(1u32)),
        MockRuntime::register_metrics(Vec::new(), None),
        1472,
    );
    manager.schedule(Message {
        payload: vec![0u8; 1200 * 39 + 512],
        ..Default::default()
    });
    assert_eq!(manager.pending.len(), TOTAL_FRAGMENTS);
    assert!(manager.segments.is_empty());

    // First window goes out, the remainder stays queued.
    assert_eq!(manager.drain().unwrap().len(), MIN_WINDOW_SIZE);
    assert_eq!(manager.segments.len(), MIN_WINDOW_SIZE);
    assert_eq!(manager.pending.len(), TOTAL_FRAGMENTS - MIN_WINDOW_SIZE);
    assert!(manager.drain_expired().is_none());
    assert!(!manager.has_capacity());

    // Six segments get acknowledged, so the window grows by six.
    // NOTE(review): `has_capacity()` still reports false here even though fewer segments are
    // in flight than the window allows; presumably it also accounts for `pending` — confirm.
    manager.register_ack(16, 5, &[]);
    assert!(!manager.segments.is_empty());
    assert_eq!(manager.window_size, MIN_WINDOW_SIZE + 6);
    assert!(!manager.pending.is_empty());
    assert!(!manager.has_capacity());

    let drained_pkt_nums: Vec<_> = manager
        .drain()
        .unwrap()
        .into_iter()
        .map(|(mut pkt, _)| {
            let mut reader = HeaderReader::new(manager.intro_key, &mut pkt).unwrap();
            let _dst_id = reader.dst_id();
            let HeaderKind::Data { pkt_num, .. } =
                reader.parse(manager.send_key_ctx.k_header_2).unwrap()
            else {
                panic!("invalid pkt");
            };
            pkt_num
        })
        .collect();

    // Ten segments were still in flight, so twelve more fill the enlarged window.
    assert_eq!(drained_pkt_nums.len(), 12);
    assert_eq!(manager.segments.len(), MIN_WINDOW_SIZE + 6);
    assert!(!manager.pending.is_empty());
    assert!(!manager.has_capacity());
}
#[tokio::test]
/// A peer-test block and a router info share one packet only while they fit into the MTU;
/// one byte more splits them into two packets.
async fn peer_test_with_router_info() {
    let mut manager = TransmissionManager::<MockRuntime>::new(
        1337u64,
        RouterId::random(),
        [0xaa; 32],
        KeyContext {
            k_data: [0xbb; 32],
            k_header_2: [0xcc; 32],
        },
        Arc::new(AtomicU32::new(1u32)),
        MockRuntime::register_metrics(Vec::new(), None),
        1200,
    );
    let alice_request = || PeerTestBlock::AliceRequest {
        message: vec![0xaa; 20],
        signature: vec![0xbb; 64],
    };

    // Largest router info that still fits alongside the peer-test block in a 1200-byte packet.
    let first_block = alice_request();
    let max_size = 1200 - SSU2_OVERHEAD - RI_BLOCK_OVERHEAD - first_block.serialized_len();

    manager.schedule((first_block, vec![0xaa; max_size]));
    assert_eq!(manager.pending.len(), 1);
    assert_eq!(manager.drain().unwrap().len(), 1);

    // One byte over the limit forces two packets.
    manager.schedule((alice_request(), vec![0xaa; max_size + 1]));
    assert_eq!(manager.pending.len(), 2);
    assert_eq!(manager.drain().unwrap().len(), 2);
}
}