1use crate::{frame::Tag, inet::ExplicitCongestionNotification, number::CheckedSub, varint::VarInt};
5use core::{
6 convert::TryInto,
7 ops::{RangeInclusive, SubAssign},
8};
9use s2n_codec::{
10 decoder_parameterized_value, decoder_value, DecoderBuffer, DecoderError, Encoder, EncoderValue,
11};
12
/// Expands to the inclusive range of frame tags that identify an ACK frame:
/// `0x02` (plain ACK) through `0x03` (ACK carrying ECN counts).
// NOTE(review): presumably used as a match pattern by the frame dispatcher — confirm.
macro_rules! ack_tag {
    () => {
        0x02u8..=0x03u8
    };
}

/// Tag written by `Ack::tag` when the frame carries no ECN counts.
const ACK_TAG: u8 = 0x02;
/// Tag written by `Ack::tag` when ECN counts are present; also the tag that makes
/// the decoder read trailing ECN counts.
const ACK_W_ECN_TAG: u8 = 0x03;
28
/// An ACK frame, generic over how its acknowledged ranges are represented.
///
/// The type parameter shares its name with the `AckRanges` trait but is not
/// bounded here; methods that need range access require `A: AckRanges`.
#[derive(Clone, PartialEq, Eq)]
pub struct Ack<AckRanges> {
    /// Raw ACK Delay field value. `Ack::ack_delay()` interprets it directly as a
    /// number of microseconds.
    // NOTE(review): no ack_delay_exponent scaling is visible in this file — confirm
    // whether callers apply it.
    pub ack_delay: VarInt,

    /// The acknowledged packet-number ranges.
    pub ack_ranges: AckRanges,

    /// ECN counts; when `Some`, the frame is encoded with `ACK_W_ECN_TAG`.
    pub ecn_counts: Option<EcnCounts>,
}
88
89impl<AckRanges> Ack<AckRanges> {
90 #[inline]
91 pub fn tag(&self) -> u8 {
92 if self.ecn_counts.is_some() {
93 ACK_W_ECN_TAG
94 } else {
95 ACK_TAG
96 }
97 }
98}
99
100impl<A: AckRanges> Ack<A> {
101 #[inline]
102 pub fn ack_delay(&self) -> core::time::Duration {
103 core::time::Duration::from_micros(self.ack_delay.as_u64())
104 }
105
106 #[inline]
107 pub fn ack_ranges(&self) -> A::Iter {
108 self.ack_ranges.ack_ranges()
109 }
110
111 #[inline]
112 pub fn largest_acknowledged(&self) -> VarInt {
113 self.ack_ranges.largest_acknowledged()
114 }
115}
116
117impl<A: core::fmt::Debug> core::fmt::Debug for Ack<A> {
118 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
119 f.debug_struct("Ack")
120 .field("ack_delay", &self.ack_delay)
121 .field("ack_ranges", &self.ack_ranges)
122 .field("ecn_counts", &self.ecn_counts)
123 .finish()
124 }
125}
126
127decoder_parameterized_value!(
128 impl<'a> Ack<AckRangesDecoder<'a>> {
129 fn decode(tag: Tag, buffer: Buffer) -> Result<Self> {
130 let (largest_acknowledged, buffer) = buffer.decode()?;
131 let (ack_delay, buffer) = buffer.decode()?;
132 let (ack_ranges, buffer) = buffer.decode_parameterized(largest_acknowledged)?;
133
134 let (ecn_counts, buffer) = if tag == ACK_W_ECN_TAG {
135 let (ecn_counts, buffer) = buffer.decode()?;
136 (Some(ecn_counts), buffer)
137 } else {
138 (None, buffer)
139 };
140
141 let frame = Ack {
142 ack_delay,
143 ack_ranges,
144 ecn_counts,
145 };
146
147 Ok((frame, buffer))
148 }
149 }
150);
151
152impl<A: AckRanges> EncoderValue for Ack<A> {
153 #[inline]
154 fn encode<E: Encoder>(&self, buffer: &mut E) {
155 buffer.encode(&self.tag());
156
157 let mut iter = self.ack_ranges.ack_ranges();
158
159 let first_ack_range = iter.next().expect("at least one ack range is required");
160 let (mut smallest, largest_acknowledged) = first_ack_range.into_inner();
161 let first_ack_range = largest_acknowledged - smallest;
162
163 let ack_range_count: VarInt = iter
164 .len()
165 .try_into()
166 .expect("ack range count cannot exceed VarInt::MAX");
167
168 buffer.encode(&largest_acknowledged);
169 buffer.encode(&self.ack_delay);
170 buffer.encode(&ack_range_count);
171 buffer.encode(&first_ack_range);
172
173 for range in iter {
174 smallest = encode_ack_range(range, smallest, buffer);
175 }
176
177 if let Some(ecn_counts) = self.ecn_counts.as_ref() {
178 buffer.encode(ecn_counts);
179 }
180 }
181}
182
183pub trait AckRanges {
191 type Iter: Iterator<Item = RangeInclusive<VarInt>> + ExactSizeIterator;
192
193 fn ack_ranges(&self) -> Self::Iter;
194
195 #[inline]
196 fn largest_acknowledged(&self) -> VarInt {
197 *self
198 .ack_ranges()
199 .next()
200 .expect("at least one ack range is required")
201 .end()
202 }
203}
204
/// ACK ranges borrowed from a decoded frame.
///
/// The raw range bytes are kept so the ranges can be re-walked through
/// `AckRanges::ack_ranges` on demand.
#[derive(Clone, Copy)]
pub struct AckRangesDecoder<'a> {
    /// Largest Acknowledged field from the frame; the end of the first range.
    largest_acknowledged: VarInt,
    /// Total number of ranges, including the First ACK Range (the decoder adds one
    /// to the encoded ACK Range Count).
    ack_range_count: VarInt,
    /// The bytes holding the First ACK Range and the subsequent Gap/ACK Range pairs.
    range_buffer: DecoderBuffer<'a>,
}
211
212impl<'a> AckRanges for AckRangesDecoder<'a> {
213 type Iter = AckRangesIter<'a>;
214
215 #[inline]
216 fn ack_ranges(&self) -> Self::Iter {
217 AckRangesIter {
218 largest_acknowledged: self.largest_acknowledged,
219 ack_range_count: self.ack_range_count,
220 range_buffer: self.range_buffer,
221 }
222 }
223
224 #[inline]
225 fn largest_acknowledged(&self) -> VarInt {
226 self.largest_acknowledged
227 }
228}
229
230impl PartialEq for AckRangesDecoder<'_> {
231 #[inline]
232 fn eq(&self, other: &Self) -> bool {
233 self.ack_ranges().eq(other.ack_ranges())
234 }
235}
236
237impl core::fmt::Debug for AckRangesDecoder<'_> {
238 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
239 core::fmt::Debug::fmt(&self.ack_ranges(), f)
240 }
241}
242
decoder_parameterized_value!(
    // Decodes the ACK Range Count field and the range fields that follow it.
    // `largest_acknowledged` is the value already read from the frame; it anchors
    // the first range.
    //
    // Every range is walked once here, so malformed input (truncated fields, or
    // ranges that would go below zero) is rejected at decode time. Only the
    // validated bytes are kept for later iteration.
    impl<'a> AckRangesDecoder<'a> {
        fn decode(largest_acknowledged: VarInt, buffer: Buffer) -> Result<AckRangesDecoder<'a>> {
            let (mut ack_range_count, buffer) = buffer.decode::<VarInt>()?;

            // The encoded count excludes the First ACK Range; add one for the total.
            ack_range_count = ack_range_count
                .checked_add(VarInt::from_u8(1))
                .ok_or(ACK_RANGE_DECODING_ERROR)?;

            // Walk a peeked view of the buffer; `buffer` itself is consumed below.
            let mut iter = AckRangesIter {
                ack_range_count,
                range_buffer: buffer.peek(),
                largest_acknowledged,
            };

            for _ in 0..*ack_range_count {
                // Any range that fails to decode invalidates the whole frame.
                iter.next().ok_or(ACK_RANGE_DECODING_ERROR)?;
            }

            // Bytes used by the ranges = original length - what the peek left over.
            let peek_len = iter.range_buffer.len();
            let buffer_len = buffer.len();
            debug_assert!(
                buffer_len >= peek_len,
                "peeked buffer should never consume more than actual buffer"
            );
            // Split off exactly the validated range bytes; the rest follows the ranges.
            let (range_buffer, remaining) = buffer.decode_slice(buffer_len - peek_len)?;

            // NOTE(review): presumably the macro expands this body for several buffer
            // types, making this conversion a no-op for some of them — hence the allow.
            #[allow(clippy::useless_conversion)]
            let range_buffer = range_buffer.into();

            let ack_ranges = AckRangesDecoder {
                largest_acknowledged,
                ack_range_count,
                range_buffer,
            };

            Ok((ack_ranges, remaining))
        }
    }
);
325
326#[inline]
343fn encode_ack_range<E: Encoder>(
344 range: RangeInclusive<VarInt>,
345 smallest: VarInt,
346 buffer: &mut E,
347) -> VarInt {
348 let (start, end) = range.into_inner();
349 let gap = smallest - end - 2;
350 let ack_range = end - start;
351
352 buffer.encode(&gap);
353 buffer.encode(&ack_range);
354
355 start
356}
357
/// Iterator that decodes ACK ranges lazily from a buffer.
///
/// Ranges are yielded from the highest packet numbers to the lowest.
#[derive(Clone, Copy)]
pub struct AckRangesIter<'a> {
    /// End (largest packet number) of the next range to be yielded.
    largest_acknowledged: VarInt,
    /// Number of ranges still to be yielded.
    ack_range_count: VarInt,
    /// Unconsumed bytes, positioned at the next ACK Range field.
    range_buffer: DecoderBuffer<'a>,
}
364
365impl Iterator for AckRangesIter<'_> {
366 type Item = RangeInclusive<VarInt>;
367
368 #[inline]
369 fn next(&mut self) -> Option<Self::Item> {
370 self.ack_range_count = self.ack_range_count.checked_sub(VarInt::from_u8(1))?;
371
372 let largest_acknowledged = self.largest_acknowledged;
373 let (ack_range, buffer) = self.range_buffer.decode::<VarInt>().ok()?;
374
375 let start = largest_acknowledged.checked_sub(ack_range)?;
376 let end = largest_acknowledged;
377
378 self.range_buffer = if self.ack_range_count != VarInt::from_u8(0) {
380 let (gap, buffer) = buffer.decode::<VarInt>().ok()?;
381 self.largest_acknowledged = largest_acknowledged
382 .checked_sub(ack_range)?
383 .checked_sub(gap)?
384 .checked_sub(VarInt::from_u8(2))?;
385 buffer
386 } else {
387 buffer
388 };
389
390 Some(start..=end)
391 }
392
393 #[inline]
394 fn size_hint(&self) -> (usize, Option<usize>) {
395 let ack_range_count = *self.ack_range_count as usize;
396 (ack_range_count, Some(ack_range_count))
397 }
398}
399
400impl ExactSizeIterator for AckRangesIter<'_> {}
401
402impl core::fmt::Debug for AckRangesIter<'_> {
403 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
404 f.debug_list().entries(*self).finish()
405 }
406}
407
/// Error returned by `AckRangesDecoder` decoding in two cases:
///
/// - the ACK Range Count overflows when the First ACK Range is added, or
/// - any range fails to decode or underflows.
const ACK_RANGE_DECODING_ERROR: DecoderError =
    DecoderError::InvariantViolation("invalid ACK ranges");
414
415#[cfg(any(test, feature = "generator"))]
446use bolero_generator::prelude::*;
447
/// Per-codepoint ECN counters.
///
/// An ACK frame carries these when it is encoded with `ACK_W_ECN_TAG`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct EcnCounts {
    /// Number of packets counted with the ECT(0) codepoint.
    pub ect_0_count: VarInt,

    /// Number of packets counted with the ECT(1) codepoint.
    pub ect_1_count: VarInt,

    /// Number of packets counted with the CE (Congestion Experienced) codepoint.
    pub ce_count: VarInt,
}
463
464impl EcnCounts {
465 #[inline]
467 pub fn increment(&mut self, ecn: ExplicitCongestionNotification) {
468 match ecn {
469 ExplicitCongestionNotification::Ect0 => {
470 self.ect_0_count = self.ect_0_count.saturating_add(VarInt::from_u8(1))
471 }
472 ExplicitCongestionNotification::Ect1 => {
473 self.ect_1_count = self.ect_1_count.saturating_add(VarInt::from_u8(1))
474 }
475 ExplicitCongestionNotification::Ce => {
476 self.ce_count = self.ce_count.saturating_add(VarInt::from_u8(1))
477 }
478 ExplicitCongestionNotification::NotEct => {}
479 }
480 }
481
482 #[inline]
485 pub fn as_option(&self) -> Option<EcnCounts> {
486 if *self == Default::default() {
487 return None;
488 }
489
490 Some(*self)
491 }
492
493 #[must_use]
495 #[inline]
496 pub fn max(self, other: Self) -> Self {
497 EcnCounts {
498 ect_0_count: self.ect_0_count.max(other.ect_0_count),
499 ect_1_count: self.ect_1_count.max(other.ect_1_count),
500 ce_count: self.ce_count.max(other.ce_count),
501 }
502 }
503}
504
505impl SubAssign for EcnCounts {
506 #[inline]
507 fn sub_assign(&mut self, rhs: Self) {
508 self.ect_0_count = self.ect_0_count.saturating_sub(rhs.ect_0_count);
509 self.ect_1_count = self.ect_1_count.saturating_sub(rhs.ect_1_count);
510 self.ce_count = self.ce_count.saturating_sub(rhs.ce_count);
511 }
512}
513
514impl CheckedSub for EcnCounts {
515 type Output = EcnCounts;
516
517 #[inline]
518 fn checked_sub(self, rhs: Self) -> Option<Self::Output> {
519 let ect_0_count = self.ect_0_count.checked_sub(rhs.ect_0_count)?;
520 let ect_1_count = self.ect_1_count.checked_sub(rhs.ect_1_count)?;
521 let ce_count = self.ce_count.checked_sub(rhs.ce_count)?;
522
523 Some(EcnCounts {
524 ect_0_count,
525 ect_1_count,
526 ce_count,
527 })
528 }
529}
530
531decoder_value!(
532 impl<'a> EcnCounts {
533 fn decode(buffer: Buffer) -> Result<Self> {
534 let (ect_0_count, buffer) = buffer.decode()?;
535 let (ect_1_count, buffer) = buffer.decode()?;
536 let (ce_count, buffer) = buffer.decode()?;
537
538 let ecn_counts = Self {
539 ect_0_count,
540 ect_1_count,
541 ce_count,
542 };
543
544 Ok((ecn_counts, buffer))
545 }
546 }
547);
548
549impl EncoderValue for EcnCounts {
550 #[inline]
551 fn encode<E: Encoder>(&self, buffer: &mut E) {
552 buffer.encode(&self.ect_0_count);
553 buffer.encode(&self.ect_1_count);
554 buffer.encode(&self.ce_count);
555 }
556}
557
#[cfg(test)]
mod tests {
    use crate::{frame::ack::EcnCounts, inet::ExplicitCongestionNotification};

    /// Builds default counts with a single packet of the given marking recorded.
    fn single(ecn: ExplicitCongestionNotification) -> EcnCounts {
        let mut counts = EcnCounts::default();
        counts.increment(ecn);
        counts
    }

    #[test]
    fn as_option() {
        // All-zero counts collapse to None.
        assert_eq!(None, EcnCounts::default().as_option());

        // Any single non-zero counter is enough to produce Some.
        assert!(single(ExplicitCongestionNotification::Ect0).as_option().is_some());
        assert!(single(ExplicitCongestionNotification::Ect1).as_option().is_some());
        assert!(single(ExplicitCongestionNotification::Ce).as_option().is_some());
    }
}