use super::*;
use crate::{
packet::number::{PacketNumber, PacketNumberSpace},
varint::VarInt,
};
use bolero::{check, generator::*};
use s2n_codec::{testing::encode, DecoderBuffer};
/// Property test: two packet numbers are drawn from the same packet number
/// space. Whenever the first one can be truncated relative to the second
/// (treated as the largest acknowledged packet number), encoding and then
/// decoding must reproduce the original packet number.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5))]
fn round_trip() {
    // Pick the space first so both generated packet numbers share it.
    let generator = gen_packet_number_space()
        .and_then(|space| (gen_packet_number(space), gen_packet_number(space)));

    check!()
        .with_generator(generator)
        .cloned()
        .for_each(|(original, largest_acked)| {
            let encoded = encode_packet_number(original, largest_acked);
            if let Some((tag, payload)) = encoded {
                let decoded = decode_packet_number(tag, payload, largest_acked).unwrap();
                assert_eq!(decoded, original);
            }
        });
}
/// Generator that yields one of the three packet number spaces
/// (`Initial`, `Handshake`, `ApplicationData`) with a uniform index.
fn gen_packet_number_space() -> impl ValueGenerator<Output = PacketNumberSpace> {
    let to_space = |index: u8| match index {
        0 => PacketNumberSpace::Initial,
        1 => PacketNumberSpace::Handshake,
        2 => PacketNumberSpace::ApplicationData,
        other => unreachable!("invalid space id {:?}", other),
    };
    (0u8..=2).map_gen(to_space)
}
/// Generator for packet numbers in `space`.
///
/// Arbitrary integers are fed to `VarInt::new`. When it returns an error (the
/// value does not fit a `VarInt`), only the low 32 bits of the value are used,
/// so every sample still maps to a valid packet number.
fn gen_packet_number(space: PacketNumberSpace) -> impl ValueGenerator<Output = PacketNumber> {
    gen().map(move |raw| {
        let value = VarInt::new(raw).unwrap_or_else(|_| VarInt::from_u32(raw as u32));
        space.new_packet_number(value)
    })
}
/// Truncates `packet_number` relative to `largest_acked_packet_number` and
/// encodes the result.
///
/// Returns the packet-tag mask for the truncated length together with the
/// encoded bytes. Returns `None` when `truncate` returns `None`.
fn encode_packet_number(
    packet_number: PacketNumber,
    largest_acked_packet_number: PacketNumber,
) -> Option<(u8, Vec<u8>)> {
    packet_number
        .truncate(largest_acked_packet_number)
        .map(|truncated| {
            let tag_mask = truncated.len().into_packet_tag_mask();
            let encoded = encode(&truncated).unwrap();
            (tag_mask, encoded)
        })
}
/// Decodes `packet_bytes` as a truncated packet number whose length is taken
/// from `packet_tag`, then expands it against `largest_acked_packet_number`.
///
/// Each intermediate step is cross-checked with assertions: the tag mask, the
/// byte length, the decoded length, re-encoding, and re-truncation.
///
/// Returns an error string if the bytes fail to decode or if the expanded
/// packet number cannot be truncated again.
fn decode_packet_number(
    packet_tag: u8,
    packet_bytes: Vec<u8>,
    largest_acked_packet_number: PacketNumber,
) -> Result<PacketNumber, String> {
    let space = largest_acked_packet_number.space();
    let len = space.new_packet_number_len(packet_tag);

    // The length derived from the tag must map back to that same tag and
    // describe exactly the number of bytes supplied.
    assert_eq!(len.into_packet_tag_mask(), packet_tag);
    assert_eq!(len.bytesize(), packet_bytes.len());

    let buffer = DecoderBuffer::new(&packet_bytes);
    let (truncated, _) = match len.decode_truncated_packet_number(buffer) {
        Ok(decoded) => decoded,
        Err(err) => return Err(err.to_string()),
    };

    // Decoding must be lossless: same length, and re-encoding yields the input bytes.
    assert_eq!(truncated.len(), len);
    assert_eq!(packet_bytes, encode(&truncated).unwrap());

    let expanded = truncated.expand(largest_acked_packet_number);
    let retruncated = expanded
        .truncate(largest_acked_packet_number)
        .ok_or_else(|| "Could not truncate packet number".to_string())?;
    assert_eq!(retruncated, truncated);

    Ok(expanded)
}
/// Builds a packet number with the given `value` in the `Initial` space.
fn new(value: VarInt) -> PacketNumber {
    let space = PacketNumberSpace::Initial;
    space.new_packet_number(value)
}
/// Reference packet number decoder, transcribed from the sample decoding
/// pseudocode in RFC 9000, Appendix A.3.
///
/// The pseudocode uses unbounded integers. Here, when `expected_pn - pn_hwin`
/// underflows or `expected_pn + pn_hwin` overflows, the corresponding
/// comparison is treated as `false`.
///
/// * `largest_pn` - largest successfully processed packet number.
/// * `truncated_pn` - the truncated value as received.
/// * `pn_nbits` - width of the truncated value in bits.
fn rfc_decoder(largest_pn: u64, truncated_pn: u64, pn_nbits: usize) -> u64 {
    let expected_pn = largest_pn + 1;
    let pn_win: u64 = 1 << pn_nbits;
    let pn_hwin = pn_win / 2;
    let pn_mask = pn_win - 1;

    // Splice the truncated bits into the window that contains the expected value.
    let candidate_pn = (expected_pn & !pn_mask) | truncated_pn;

    // More than half a window below the expected value: move up one window,
    // unless that would reach the 2^62 packet number limit.
    let below_window = match (expected_pn.checked_sub(pn_hwin), (1u64 << 62).checked_sub(pn_win)) {
        (Some(lower_bound), Some(limit)) => candidate_pn <= lower_bound && candidate_pn < limit,
        _ => false,
    };
    if below_window {
        return candidate_pn + pn_win;
    }

    // More than half a window above the expected value: move down one window
    // when that does not go below zero.
    let above_window = match expected_pn.checked_add(pn_hwin) {
        Some(upper_bound) => candidate_pn > upper_bound && candidate_pn >= pn_win,
        None => false,
    };
    if above_window {
        candidate_pn - pn_win
    } else {
        candidate_pn
    }
}
/// Property test: for any pair of packet numbers in the `Initial` space,
/// expanding the truncated form of one against the other returns the original
/// value, whenever truncation is possible.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5))]
fn truncate_expand_test() {
    check!()
        .with_type()
        .cloned()
        .for_each(|(largest_raw, expected_raw)| {
            let (largest, expected) = (new(largest_raw), new(expected_raw));
            // `truncate` returns an Option; pairs it rejects are simply skipped.
            if let Some(truncated) = expected.truncate(largest) {
                let expanded = truncated.expand(largest);
                assert_eq!(expected, expanded);
            }
        });
}
/// Differential property test: `TruncatedPacketNumber::expand` must agree with
/// the RFC 9000 reference decoder (`rfc_decoder`) for arbitrary inputs.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5))]
fn rfc_differential_test() {
    check!()
        .with_type()
        .cloned()
        .for_each(|(largest_raw, truncated_raw)| {
            let largest = new(largest_raw);
            let truncated = TruncatedPacketNumber {
                space: largest.space(),
                value: truncated_raw,
            };

            // The reference decoder can produce values above VarInt::MAX, so
            // clamp it to the largest representable packet number before comparing.
            let expected = rfc_decoder(largest.as_u64(), truncated.into_u64(), truncated.bitsize())
                .min(VarInt::MAX.as_u64());
            let actual = truncated.expand(largest).as_u64();

            assert_eq!(
                actual,
                expected,
                "diff: {:?}",
                actual
                    .checked_sub(expected)
                    .or_else(|| expected.checked_sub(actual))
            );
        });
}
/// Fixed examples of 16-bit truncated packet numbers being expanded against a
/// largest acknowledged packet number in the `Initial` space.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5))]
fn example_test() {
    macro_rules! example {
        ($largest:expr, $truncated:expr, $expected:expr) => {{
            let largest = new(VarInt::from_u32($largest));
            let truncated = TruncatedPacketNumber::new($truncated, PacketNumberSpace::Initial);
            let expected = new(VarInt::from_u32($expected));
            assert_eq!(truncated.expand(largest), expected);
        }};
    }

    // The example published in RFC 9000, Appendix A.3.
    example!(0xa82f30ea, 0x9b32u16, 0xa82f9b32);

    // Boundary case: the candidate lands exactly on expected_pn + pn_hwin.
    // The comparison is strict, so the candidate must not be moved down a window.
    example!(0xa82e1b31, 0x9b32u16, 0xa82e9b32);

    // Candidate more than half a window below the expected value: wraps up.
    example!(0xa82ffff0, 0x0002u16, 0xa8300002);

    // Candidate more than half a window above the expected value: wraps down.
    example!(0xa8300010, 0xfff0u16, 0xa82ffff0);
}
/// Snapshots the in-memory size of each packet number type, so that any
/// layout change shows up as a snapshot diff.
#[test]
// NOTE(review): presumably ignored under Miri because insta's snapshot file
// access is not available with Miri's isolation; confirm.
#[cfg_attr(miri, ignore)]
fn size_of_snapshots() {
    use core::mem::size_of;
    use insta::assert_debug_snapshot;

    assert_debug_snapshot!("PacketNumber", size_of::<PacketNumber>());
    assert_debug_snapshot!("PacketNumberLen", size_of::<PacketNumberLen>());
    assert_debug_snapshot!("PacketNumberSpace", size_of::<PacketNumberSpace>());
    assert_debug_snapshot!("ProtectedPacketNumber", size_of::<ProtectedPacketNumber>());
    assert_debug_snapshot!("TruncatedPacketNumber", size_of::<TruncatedPacketNumber>());
}