use crate::{frame::Tag, inet::ExplicitCongestionNotification, number::CheckedSub, varint::VarInt};
use core::{
convert::TryInto,
ops::{RangeInclusive, SubAssign},
};
use s2n_codec::{
decoder_parameterized_value, decoder_value, DecoderBuffer, DecoderError, Encoder, EncoderValue,
};
/// Expands to the inclusive range of frame-type bytes that decode as an ACK frame:
/// `0x02` (ACK) and `0x03` (ACK carrying ECN counts).
macro_rules! ack_tag {
() => {
0x02u8..=0x03u8
};
}
/// Frame type for an ACK frame without trailing ECN counts.
const ACK_TAG: u8 = 0x02;
/// Frame type for an ACK frame followed by ECN counts.
const ACK_W_ECN_TAG: u8 = 0x03;
/// A QUIC ACK frame (frame type `0x02`, or `0x03` when ECN counts are present).
///
/// Generic over the range container, so the same type serves both decoded frames
/// (`Ack<AckRangesDecoder>`) and frames built for encoding (any `AckRanges` implementor).
#[derive(Clone, PartialEq, Eq)]
pub struct Ack<AckRanges> {
/// ACK Delay field, stored exactly as decoded from / encoded to the wire.
pub ack_delay: VarInt,
/// The acknowledged packet-number ranges.
pub ack_ranges: AckRanges,
/// ECN counts; `Some` selects the `0x03` frame type when encoding.
pub ecn_counts: Option<EcnCounts>,
}
impl<AckRanges> Ack<AckRanges> {
    /// Returns the frame-type byte for this frame: `ACK_W_ECN_TAG` (`0x03`) when ECN counts
    /// are attached, otherwise `ACK_TAG` (`0x02`).
    #[inline]
    pub fn tag(&self) -> u8 {
        match self.ecn_counts {
            Some(_) => ACK_W_ECN_TAG,
            None => ACK_TAG,
        }
    }
}
impl<A: AckRanges> Ack<A> {
    /// Returns the raw ACK Delay field interpreted as a number of microseconds.
    ///
    /// NOTE(review): no `ack_delay_exponent` scaling is applied here; confirm callers
    /// account for it where required.
    #[inline]
    pub fn ack_delay(&self) -> core::time::Duration {
        let micros = self.ack_delay.as_u64();
        core::time::Duration::from_micros(micros)
    }

    /// Returns an iterator over the acknowledged packet-number ranges.
    #[inline]
    pub fn ack_ranges(&self) -> A::Iter {
        <A as AckRanges>::ack_ranges(&self.ack_ranges)
    }

    /// Returns the largest acknowledged packet number as reported by the range container.
    #[inline]
    pub fn largest_acknowledged(&self) -> VarInt {
        <A as AckRanges>::largest_acknowledged(&self.ack_ranges)
    }
}
impl<A: core::fmt::Debug> core::fmt::Debug for Ack<A> {
    /// Formats as `Ack { ack_delay, ack_ranges, ecn_counts }` using each field's `Debug` impl.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let Self {
            ack_delay,
            ack_ranges,
            ecn_counts,
        } = self;

        let mut builder = f.debug_struct("Ack");
        builder.field("ack_delay", ack_delay);
        builder.field("ack_ranges", ack_ranges);
        builder.field("ecn_counts", ecn_counts);
        builder.finish()
    }
}
// Decodes the body of an ACK frame. The frame-type byte is supplied as `tag` (presumably
// already consumed by the caller) and decides whether ECN counts follow the ranges.
decoder_parameterized_value!(
impl<'a> Ack<AckRangesDecoder<'a>> {
fn decode(tag: Tag, buffer: Buffer) -> Result<Self> {
// Wire order: Largest Acknowledged, ACK Delay, then the ACK range section.
let (largest_acknowledged, buffer) = buffer.decode()?;
let (ack_delay, buffer) = buffer.decode()?;
// The range decoder needs Largest Acknowledged to anchor the first range.
let (ack_ranges, buffer) = buffer.decode_parameterized(largest_acknowledged)?;
// Only the 0x03 frame type carries ECN counts after the ranges.
let (ecn_counts, buffer) = if tag == ACK_W_ECN_TAG {
let (ecn_counts, buffer) = buffer.decode()?;
(Some(ecn_counts), buffer)
} else {
(None, buffer)
};
let frame = Ack {
ack_delay,
ack_ranges,
ecn_counts,
};
Ok((frame, buffer))
}
}
);
impl<A: AckRanges> EncoderValue for Ack<A> {
    /// Serializes the frame in wire order: frame type, Largest Acknowledged, ACK Delay,
    /// ACK Range Count, First ACK Range, one (Gap, ACK Range Length) pair per additional
    /// range, and finally the ECN counts when present.
    ///
    /// # Panics
    ///
    /// Panics if `ack_ranges` yields no ranges, or if the number of additional ranges does
    /// not fit in a `VarInt`.
    #[inline]
    fn encode<E: Encoder>(&self, buffer: &mut E) {
        buffer.encode(&self.tag());

        let mut ranges = self.ack_ranges.ack_ranges();

        // The first (highest) range is special: its end is Largest Acknowledged and its
        // length is written as the First ACK Range field.
        let (mut previous_smallest, largest_acknowledged) = ranges
            .next()
            .expect("at least one ack range is required")
            .into_inner();
        let first_range_length = largest_acknowledged - previous_smallest;

        // ACK Range Count excludes the first range: it is whatever the iterator has left.
        let remaining_ranges: VarInt = ranges
            .len()
            .try_into()
            .expect("ack range count cannot exceed VarInt::MAX");

        buffer.encode(&largest_acknowledged);
        buffer.encode(&self.ack_delay);
        buffer.encode(&remaining_ranges);
        buffer.encode(&first_range_length);

        // Each later range is expressed relative to the smallest packet of the one before it.
        for range in ranges {
            previous_smallest = encode_ack_range(range, previous_smallest, buffer);
        }

        if let Some(ecn_counts) = &self.ecn_counts {
            buffer.encode(ecn_counts);
        }
    }
}
/// A collection of acknowledged packet-number ranges that can be written into an ACK frame.
///
/// The encoder treats the first yielded range as the one containing the largest acknowledged
/// packet, and encodes each later range relative to the previous range's smallest packet
/// number as `smallest - end - 2`. Ranges are therefore expected in descending order with at
/// least one unacknowledged packet between neighbours.
pub trait AckRanges {
    /// Iterator over the ranges; `ExactSizeIterator` lets the encoder compute ACK Range Count.
    type Iter: Iterator<Item = RangeInclusive<VarInt>> + ExactSizeIterator;

    /// Returns a fresh iterator over the ranges, highest first.
    fn ack_ranges(&self) -> Self::Iter;

    /// Returns the end of the first range yielded by [`AckRanges::ack_ranges`].
    ///
    /// # Panics
    ///
    /// Panics if there are no ranges.
    #[inline]
    fn largest_acknowledged(&self) -> VarInt {
        let first_range = self
            .ack_ranges()
            .next()
            .expect("at least one ack range is required");
        *first_range.end()
    }
}
/// ACK ranges borrowed from a decoded frame's bytes.
///
/// Only produced after every range/gap pair has been validated; ranges are not stored but
/// re-parsed from `range_buffer` each time `ack_ranges()` is called.
#[derive(Clone, Copy)]
pub struct AckRangesDecoder<'a> {
/// Largest acknowledged packet number (end of the first range).
largest_acknowledged: VarInt,
/// Total number of ranges, including the first (the on-wire count plus one).
ack_range_count: VarInt,
/// The First ACK Range value followed by the Gap / ACK Range Length pairs.
range_buffer: DecoderBuffer<'a>,
}
impl<'a> AckRanges for AckRangesDecoder<'a> {
    type Iter = AckRangesIter<'a>;

    /// Returns a new iterator that re-parses the validated range bytes from the beginning.
    #[inline]
    fn ack_ranges(&self) -> Self::Iter {
        let Self {
            largest_acknowledged,
            ack_range_count,
            range_buffer,
        } = *self;

        AckRangesIter {
            largest_acknowledged,
            ack_range_count,
            range_buffer,
        }
    }

    /// Returns the stored value directly rather than decoding the first range.
    #[inline]
    fn largest_acknowledged(&self) -> VarInt {
        self.largest_acknowledged
    }
}
impl PartialEq for AckRangesDecoder<'_> {
    /// Two decoders are equal when they yield the same sequence of decoded ranges; the raw
    /// bytes backing them are not compared.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.ack_ranges(), other.ack_ranges());
        Iterator::eq(lhs, rhs)
    }
}
impl core::fmt::Debug for AckRangesDecoder<'_> {
    /// Formats the decoded ranges as a list, e.g. `[10..=12, 3..=5]`.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_list().entries(self.ack_ranges()).finish()
    }
}
// Decodes the ACK range section: ACK Range Count, First ACK Range, then Gap / ACK Range Length
// pairs. `largest_acknowledged` comes from the enclosing frame and anchors the first range.
decoder_parameterized_value!(
impl<'a> AckRangesDecoder<'a> {
fn decode(largest_acknowledged: VarInt, buffer: Buffer) -> Result<AckRangesDecoder<'a>> {
// The on-wire count excludes the first range; add one to get the total number of ranges.
let (mut ack_range_count, buffer) = buffer.decode::<VarInt>()?;
ack_range_count = ack_range_count
.checked_add(VarInt::from_u8(1))
.ok_or(ACK_RANGE_DECODING_ERROR)?;
// Walk every range over a peeked view of the buffer so the whole section is validated
// (decodable and non-underflowing) before anything is committed.
let mut iter = AckRangesIter {
ack_range_count,
range_buffer: buffer.peek(),
largest_acknowledged,
};
for _ in 0..*ack_range_count {
iter.next().ok_or(ACK_RANGE_DECODING_ERROR)?;
}
// The bytes the validation pass consumed from the peeked view are exactly the range
// section; split them off so later iteration can re-parse them.
let peek_len = iter.range_buffer.len();
let buffer_len = buffer.len();
debug_assert!(
buffer_len >= peek_len,
"peeked buffer should never consume more than actual buffer"
);
let (range_buffer, remaining) = buffer.decode_slice(buffer_len - peek_len)?;
// NOTE(review): presumably `.into()` is a no-op for some `Buffer` types this macro expands
// for and a real conversion to `DecoderBuffer<'a>` for others, hence the lint allow.
#[allow(clippy::useless_conversion)]
let range_buffer = range_buffer.into();
let ack_ranges = AckRangesDecoder {
largest_acknowledged,
ack_range_count,
range_buffer,
};
Ok((ack_ranges, remaining))
}
}
);
/// Writes one additional ACK range as a (Gap, ACK Range Length) pair and returns the range's
/// smallest packet number, which becomes the reference point for the next range.
///
/// `smallest` is the smallest packet number of the previously encoded (higher) range. The gap
/// is written as `smallest - end - 2`, which mirrors the decoder: it recovers the next range's
/// end as `previous_end - previous_length - gap - 2`.
#[inline]
fn encode_ack_range<E: Encoder>(
    range: RangeInclusive<VarInt>,
    smallest: VarInt,
    buffer: &mut E,
) -> VarInt {
    let range_start = *range.start();
    let range_end = *range.end();

    let gap = smallest - range_end - 2;
    let range_length = range_end - range_start;

    buffer.encode(&gap);
    buffer.encode(&range_length);

    range_start
}
/// Iterator that lazily decodes ACK ranges from the range section bytes, yielding ranges
/// from the highest packet numbers downward.
#[derive(Clone, Copy)]
pub struct AckRangesIter<'a> {
/// End (largest packet number) of the next range to yield.
largest_acknowledged: VarInt,
/// Number of ranges not yet yielded.
ack_range_count: VarInt,
/// Unread bytes: the next ACK Range Length, then any remaining Gap / length pairs.
range_buffer: DecoderBuffer<'a>,
}
impl Iterator for AckRangesIter<'_> {
type Item = RangeInclusive<VarInt>;
/// Decodes and returns the next range, or `None` when exhausted or when the bytes are
/// malformed (truncated, or a subtraction would underflow).
///
/// On a malformed input the remaining count has already been decremented, so the iterator's
/// state is not meaningful afterwards.
#[inline]
fn next(&mut self) -> Option<Self::Item> {
// Stop once every range has been yielded.
self.ack_range_count = self.ack_range_count.checked_sub(VarInt::from_u8(1))?;
let largest_acknowledged = self.largest_acknowledged;
// ACK Range Length: number of packets below the range's largest packet.
let (ack_range, buffer) = self.range_buffer.decode::<VarInt>().ok()?;
let start = largest_acknowledged.checked_sub(ack_range)?;
let end = largest_acknowledged;
// A Gap follows only when another range remains; it positions the next range's end
// at `start - gap - 2`.
self.range_buffer = if self.ack_range_count != VarInt::from_u8(0) {
let (gap, buffer) = buffer.decode::<VarInt>().ok()?;
self.largest_acknowledged = largest_acknowledged
.checked_sub(ack_range)?
.checked_sub(gap)?
.checked_sub(VarInt::from_u8(2))?;
buffer
} else {
buffer
};
Some(start..=end)
}
/// Reports the remaining range count as an exact size.
///
/// NOTE(review): `as usize` would truncate on targets where `usize` is narrower than the
/// count; decoder-produced iterators are presumably bounded by the buffer length.
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let ack_range_count = *self.ack_range_count as usize;
(ack_range_count, Some(ack_range_count))
}
}
// `size_hint` reports an exact count, which `AckRanges::Iter` requires via `ExactSizeIterator`.
impl ExactSizeIterator for AckRangesIter<'_> {}
impl core::fmt::Debug for AckRangesIter<'_> {
    /// Formats the ranges still remaining as a list. A copy of the iterator is consumed, so
    /// `self` is not advanced.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let remaining: Self = *self;
        let mut list = f.debug_list();
        list.entries(remaining);
        list.finish()
    }
}
/// Returned when the ACK range section is invalid: the range count overflows when the first
/// range is added, or a range/gap cannot be decoded or would underflow.
const ACK_RANGE_DECODING_ERROR: DecoderError =
DecoderError::InvariantViolation("invalid ACK ranges");
#[cfg(any(test, feature = "generator"))]
use bolero_generator::prelude::*;
/// Per-codepoint ECN packet counts carried at the end of an ACK frame of type `0x03`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct EcnCounts {
/// Count of packets marked ECT(0).
pub ect_0_count: VarInt,
/// Count of packets marked ECT(1).
pub ect_1_count: VarInt,
/// Count of packets marked CE (Congestion Experienced).
pub ce_count: VarInt,
}
impl EcnCounts {
    /// Adds one to the counter matching `ecn`, saturating at the `VarInt` maximum.
    /// `NotEct` does not change any counter.
    #[inline]
    pub fn increment(&mut self, ecn: ExplicitCongestionNotification) {
        let counter = match ecn {
            ExplicitCongestionNotification::Ect0 => &mut self.ect_0_count,
            ExplicitCongestionNotification::Ect1 => &mut self.ect_1_count,
            ExplicitCongestionNotification::Ce => &mut self.ce_count,
            ExplicitCongestionNotification::NotEct => return,
        };
        *counter = counter.saturating_add(VarInt::from_u8(1));
    }

    /// Returns `None` when the counts equal `EcnCounts::default()`, otherwise `Some(*self)`.
    #[inline]
    pub fn as_option(&self) -> Option<EcnCounts> {
        let is_default = *self == Self::default();
        if is_default {
            None
        } else {
            Some(*self)
        }
    }

    /// Returns the field-wise maximum of `self` and `other`.
    #[must_use]
    #[inline]
    pub fn max(self, other: Self) -> Self {
        let Self {
            ect_0_count,
            ect_1_count,
            ce_count,
        } = self;

        EcnCounts {
            ect_0_count: ect_0_count.max(other.ect_0_count),
            ect_1_count: ect_1_count.max(other.ect_1_count),
            ce_count: ce_count.max(other.ce_count),
        }
    }
}
impl SubAssign for EcnCounts {
    /// Subtracts `rhs` field by field with saturating arithmetic, so no counter underflows.
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        let Self {
            ect_0_count,
            ect_1_count,
            ce_count,
        } = rhs;

        self.ect_0_count = self.ect_0_count.saturating_sub(ect_0_count);
        self.ect_1_count = self.ect_1_count.saturating_sub(ect_1_count);
        self.ce_count = self.ce_count.saturating_sub(ce_count);
    }
}
impl CheckedSub for EcnCounts {
    type Output = EcnCounts;

    /// Subtracts `rhs` field by field (ECT(0), ECT(1), then CE), returning `None` as soon as
    /// any field's checked subtraction fails.
    #[inline]
    fn checked_sub(self, rhs: Self) -> Option<Self::Output> {
        Some(EcnCounts {
            ect_0_count: self.ect_0_count.checked_sub(rhs.ect_0_count)?,
            ect_1_count: self.ect_1_count.checked_sub(rhs.ect_1_count)?,
            ce_count: self.ce_count.checked_sub(rhs.ce_count)?,
        })
    }
}
// Decodes ECN counts in wire order: ECT(0), ECT(1), then CE, each as a `VarInt`.
decoder_value!(
impl<'a> EcnCounts {
fn decode(buffer: Buffer) -> Result<Self> {
let (ect_0_count, buffer) = buffer.decode()?;
let (ect_1_count, buffer) = buffer.decode()?;
let (ce_count, buffer) = buffer.decode()?;
let ecn_counts = Self {
ect_0_count,
ect_1_count,
ce_count,
};
Ok((ecn_counts, buffer))
}
}
);
impl EncoderValue for EcnCounts {
    /// Writes the counts in wire order: ECT(0), ECT(1), then CE.
    #[inline]
    fn encode<E: Encoder>(&self, buffer: &mut E) {
        for count in [self.ect_0_count, self.ect_1_count, self.ce_count] {
            buffer.encode(&count);
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{frame::ack::EcnCounts, inet::ExplicitCongestionNotification};

    /// `as_option` is `None` for default counts (including after a `NotEct` increment, which
    /// leaves every counter untouched) and `Some` once any ECN-capable codepoint is counted.
    #[test]
    fn as_option() {
        let mut ecn_counts = EcnCounts::default();
        assert_eq!(None, ecn_counts.as_option());

        // `increment` has an explicit no-op arm for NotEct, so the counts stay at default.
        ecn_counts.increment(ExplicitCongestionNotification::NotEct);
        assert_eq!(None, ecn_counts.as_option());

        for ecn in [
            ExplicitCongestionNotification::Ect0,
            ExplicitCongestionNotification::Ect1,
            ExplicitCongestionNotification::Ce,
        ] {
            let mut ecn_counts = EcnCounts::default();
            ecn_counts.increment(ecn);
            assert!(ecn_counts.as_option().is_some());
        }
    }
}