// async_std_utp/packet.rs

1#![allow(dead_code)]
2
3use crate::bit_iterator::BitIterator;
4use crate::error::ParseError;
5use crate::time::{Delay, Timestamp};
6use std::fmt;
7
/// Size in bytes of the fixed packet header: type/version (1), extension (1),
/// connection id (2), timestamp (4), timestamp difference (4), window size (4),
/// sequence number (2) and acknowledgment number (2).
pub const HEADER_SIZE: usize = 1 + 1 + 2 + 4 + 4 + 4 + 2 + 2;
9
/// Reads the bytes `$src[$start..=$end]` into a `$t` without reordering them, so the value
/// keeps the buffer's network (big-endian) byte order in memory; `<$t>::from_be` on the result
/// yields the numeric value.
///
/// Panics if the range length differs from `size_of::<$t>()`.
macro_rules! u8_to_unsigned_be {
    ($src:ident, $start:expr, $end:expr, $t:ty) => {{
        // `from_ne_bytes` preserves the byte order on every host. The former shift/fold
        // implementation was equivalent to `from_le_bytes`, which only matched the raw
        // in-memory layout on little-endian targets.
        let mut bytes = [0u8; ::std::mem::size_of::<$t>()];
        bytes.copy_from_slice(&$src[$start..=$end]);
        <$t>::from_ne_bytes(bytes)
    }};
}
15
/// Generates `pub fn $name(&self) -> $t` returning the big-endian header field `$name`
/// converted to host order with `$m::from_be`.
///
/// The invoking type must be a tuple struct whose `.0` is a `Vec<u8>` holding at least a full
/// `PacketHeader`. Such a buffer is only guaranteed 1-byte alignment, while `PacketHeader`
/// needs 4-byte alignment (it contains `u32` fields). Forming a `&PacketHeader` into it would
/// be undefined behavior whenever the allocation is misaligned, so the header is copied out
/// with an unaligned read instead.
macro_rules! make_getter {
    ($name:ident, $t:ty, $m:ident) => {
        pub fn $name(&self) -> $t {
            debug_assert!(self.0.len() >= ::std::mem::size_of::<PacketHeader>());
            // SAFETY: the buffer is expected to hold a full header (asserted in debug builds);
            // every field of the `#[repr(C)]` `PacketHeader` is a plain integer, so any byte
            // pattern is a valid value, and `read_unaligned` has no alignment requirement.
            let header: PacketHeader =
                unsafe { ::std::ptr::read_unaligned(self.0.as_ptr() as *const PacketHeader) };
            $m::from_be(header.$name)
        }
    };
}
24
/// Generates `pub fn $fn_name(&mut self, new: $t)` storing `new` big-endian into the header
/// field `$field`.
///
/// The invoking type must be a tuple struct whose `.0` is a `Vec<u8>` holding at least a full
/// `PacketHeader`. The buffer is only guaranteed 1-byte alignment, so instead of forming a
/// (potentially misaligned, hence UB) `&mut PacketHeader`, the header is read, patched and
/// written back with unaligned accesses.
macro_rules! make_setter {
    ($fn_name:ident, $field:ident, $t: ty) => {
        pub fn $fn_name(&mut self, new: $t) {
            debug_assert!(self.0.len() >= ::std::mem::size_of::<PacketHeader>());
            let raw = new.to_be();
            let ptr = self.0.as_mut_ptr() as *mut PacketHeader;
            // SAFETY: the buffer is expected to hold a full header (asserted in debug builds);
            // `PacketHeader` is `#[repr(C)]` with only integer fields, so any byte pattern is
            // valid; the unaligned read/write functions have no alignment requirement.
            unsafe {
                let mut header = ::std::ptr::read_unaligned(ptr);
                header.$field = raw;
                ::std::ptr::write_unaligned(ptr, header);
            }
        }
    };
}
33
/// Fallible conversion into `Self`.
///
/// This crate-local trait predates the standard `std::convert::TryFrom` (tracked in
/// rust-lang/rust#33417, stable since Rust 1.34). It is kept because other code in the crate
/// names it directly. Unlike the standard trait, its error type is called `Err`.
pub trait TryFrom<T>: Sized {
    /// Error returned when the conversion fails.
    type Err;

    /// Attempts the conversion, returning `Self::Err` on failure.
    fn try_from(value: T) -> Result<Self, Self::Err>;
}
41
42#[derive(PartialEq, Eq, Debug)]
43pub enum PacketType {
44    Data,  // packet carries a data payload
45    Fin,   // signals the end of a connection
46    State, // signals acknowledgment of a packet
47    Reset, // forcibly terminates a connection
48    Syn,   // initiates a new connection with a peer
49}
50
51impl TryFrom<u8> for PacketType {
52    type Err = ParseError;
53    fn try_from(original: u8) -> Result<Self, Self::Err> {
54        match original {
55            0 => Ok(PacketType::Data),
56            1 => Ok(PacketType::Fin),
57            2 => Ok(PacketType::State),
58            3 => Ok(PacketType::Reset),
59            4 => Ok(PacketType::Syn),
60            n => Err(ParseError::InvalidPacketType(n)),
61        }
62    }
63}
64
65impl From<PacketType> for u8 {
66    fn from(original: PacketType) -> u8 {
67        match original {
68            PacketType::Data => 0,
69            PacketType::Fin => 1,
70            PacketType::State => 2,
71            PacketType::Reset => 3,
72            PacketType::Syn => 4,
73        }
74    }
75}
76
/// Type tag of a header extension (header byte 1, or an extension's "next" byte).
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum ExtensionType {
    /// Tag `0`: no (further) extension.
    None,
    /// Tag `1`: selective acknowledgment.
    SelectiveAck,
    /// Any other tag; the raw value is kept so it converts back unchanged.
    Unknown(u8),
}

impl From<u8> for ExtensionType {
    fn from(tag: u8) -> Self {
        if tag == 0 {
            ExtensionType::None
        } else if tag == 1 {
            ExtensionType::SelectiveAck
        } else {
            ExtensionType::Unknown(tag)
        }
    }
}

impl From<ExtensionType> for u8 {
    fn from(ty: ExtensionType) -> u8 {
        match ty {
            ExtensionType::Unknown(tag) => tag,
            ExtensionType::SelectiveAck => 1,
            ExtensionType::None => 0,
        }
    }
}
103
104#[derive(Clone)]
105pub struct Extension<'a> {
106    ty: ExtensionType,
107    pub data: &'a [u8],
108}
109
110impl<'a> Extension<'a> {
111    pub fn len(&self) -> usize {
112        self.data.len()
113    }
114
115    pub fn get_type(&self) -> ExtensionType {
116        self.ty
117    }
118
119    pub fn iter(&self) -> BitIterator<'_> {
120        BitIterator::from_bytes(self.data)
121    }
122}
123
/// Fixed packet header, laid out exactly like its 20 wire bytes.
///
/// `#[repr(C)]` with these field widths gives offsets 0, 1, 2, 4, 8, 12, 16, 18 and no
/// padding, which the byte-level casts in this module rely on. Multi-byte fields hold their
/// wire (big-endian) bytes; convert with `from_be`/`to_be` when reading or writing them.
#[repr(C)]
struct PacketHeader {
    /// Packet type in the high nibble, protocol version in the low nibble.
    type_ver: u8,
    /// Type of the first extension (`0` means none).
    extension: u8,
    connection_id: u16,
    // Both timestamps are in microseconds
    timestamp: u32,
    timestamp_difference: u32,
    wnd_size: u32,
    seq_nr: u16,
    ack_nr: u16,
}

// Compile-time guard: the unsafe casts between `PacketHeader` and `[u8; HEADER_SIZE]` assume
// the struct is exactly 20 bytes. This line stops compiling if the layout ever changes.
const _PACKET_HEADER_IS_20_BYTES: [(); 20] = [(); std::mem::size_of::<PacketHeader>()];
136
137impl PacketHeader {
138    /// Sets the type of packet to the specified type.
139    pub fn set_type(&mut self, t: PacketType) {
140        let version = 0x0F & self.type_ver;
141        self.type_ver = u8::from(t) << 4 | version;
142    }
143
144    /// Returns the packet's type.
145    pub fn get_type(&self) -> PacketType {
146        PacketType::try_from(self.type_ver >> 4).unwrap()
147    }
148
149    /// Returns the packet's version.
150    pub fn get_version(&self) -> u8 {
151        self.type_ver & 0x0F
152    }
153
154    /// Returns the type of the first extension
155    pub fn get_extension_type(&self) -> ExtensionType {
156        self.extension.into()
157    }
158}
159
160impl AsRef<[u8]> for PacketHeader {
161    /// Returns the packet header as a slice of bytes.
162    fn as_ref(&self) -> &[u8] {
163        unsafe { &*(self as *const PacketHeader as *const [u8; HEADER_SIZE]) }
164    }
165}
166
167impl<'a> TryFrom<&'a [u8]> for PacketHeader {
168    type Err = ParseError;
169    /// Reads a byte buffer and returns the corresponding packet header.
170    /// It assumes the fields are in network (big-endian) byte order,
171    /// preserving it.
172    fn try_from(buf: &[u8]) -> Result<Self, Self::Err> {
173        // Check length
174        if buf.len() < HEADER_SIZE {
175            return Err(ParseError::InvalidPacketLength);
176        }
177
178        // Check version
179        if buf[0] & 0x0F != 1 {
180            return Err(ParseError::UnsupportedVersion);
181        }
182
183        // Check packet type
184        if let Err(e) = PacketType::try_from(buf[0] >> 4) {
185            return Err(e);
186        }
187
188        Ok(PacketHeader {
189            type_ver: buf[0],
190            extension: buf[1],
191            connection_id: u8_to_unsigned_be!(buf, 2, 3, u16),
192            timestamp: u8_to_unsigned_be!(buf, 4, 7, u32),
193            timestamp_difference: u8_to_unsigned_be!(buf, 8, 11, u32),
194            wnd_size: u8_to_unsigned_be!(buf, 12, 15, u32),
195            seq_nr: u8_to_unsigned_be!(buf, 16, 17, u16),
196            ack_nr: u8_to_unsigned_be!(buf, 18, 19, u16),
197        })
198    }
199}
200
201impl Default for PacketHeader {
202    fn default() -> PacketHeader {
203        PacketHeader {
204            type_ver: u8::from(PacketType::Data) << 4 | 1,
205            extension: 0,
206            connection_id: 0,
207            timestamp: 0,
208            timestamp_difference: 0,
209            wnd_size: 0,
210            seq_nr: 0,
211            ack_nr: 0,
212        }
213    }
214}
215
/// An encoded packet: the fixed header, any extensions, then the payload, stored as the
/// exact bytes that are sent or were received.
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrows the packet's complete encoded form (header, extensions and payload).
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
223
224impl Packet {
225    /// Constructs a new, empty packet.
226    pub fn new() -> Packet {
227        Packet(PacketHeader::default().as_ref().to_owned())
228    }
229
230    /// Constructs a new data packet with the given payload.
231    pub fn with_payload(payload: &[u8]) -> Packet {
232        let mut inner = Vec::with_capacity(HEADER_SIZE + payload.len());
233        let mut header = PacketHeader::default();
234        header.set_type(PacketType::Data);
235        // inner.copy_from_slice(header.as_ref());
236        // inner.copy_from_slice(payload);
237        inner.extend_from_slice(header.as_ref());
238        inner.extend_from_slice(payload);
239
240        Packet(inner)
241    }
242
243    #[inline]
244    pub fn set_type(&mut self, t: PacketType) {
245        let header = unsafe { &mut *(self.0.as_mut_ptr() as *mut PacketHeader) };
246        header.set_type(t);
247    }
248
249    #[inline]
250    pub fn get_type(&self) -> PacketType {
251        let header = unsafe { &*(self.0.as_ptr() as *const PacketHeader) };
252        header.get_type()
253    }
254
255    pub fn get_version(&self) -> u8 {
256        let header = unsafe { &*(self.0.as_ptr() as *const PacketHeader) };
257        header.get_version()
258    }
259
260    pub fn get_extension_type(&self) -> ExtensionType {
261        let header = unsafe { &*(self.0.as_ptr() as *const PacketHeader) };
262        header.get_extension_type()
263    }
264
265    pub fn extensions(&self) -> ExtensionIterator<'_> {
266        ExtensionIterator::new(self)
267    }
268
269    pub fn payload(&self) -> &[u8] {
270        let mut index = HEADER_SIZE;
271        let mut extension_type = ExtensionType::from(self.0[1]);
272
273        // Consume known extensions and skip over unknown ones
274        while index < self.0.len() && extension_type != ExtensionType::None {
275            let len = self.0[index + 1] as usize;
276
277            // Assume extension is valid because the bytes come from a (valid) Packet
278            // ...
279
280            extension_type = ExtensionType::from(self.0[index]);
281            index += len + 2;
282        }
283
284        &self.0[index..]
285    }
286
287    pub fn timestamp(&self) -> Timestamp {
288        let header = unsafe { &*(self.0.as_ptr() as *const PacketHeader) };
289        u32::from_be(header.timestamp).into()
290    }
291
292    pub fn set_timestamp(&mut self, timestamp: Timestamp) {
293        let header = unsafe { &mut *(self.0.as_mut_ptr() as *mut PacketHeader) };
294        header.timestamp = u32::from(timestamp).to_be();
295    }
296
297    pub fn timestamp_difference(&self) -> Delay {
298        let header = unsafe { &*(self.0.as_ptr() as *const PacketHeader) };
299        u32::from_be(header.timestamp_difference).into()
300    }
301
302    pub fn set_timestamp_difference(&mut self, delay: Delay) {
303        let header = unsafe { &mut *(self.0.as_mut_ptr() as *mut PacketHeader) };
304        header.timestamp_difference = u32::from(delay).to_be();
305    }
306
307    make_getter!(seq_nr, u16, u16);
308    make_getter!(ack_nr, u16, u16);
309    make_getter!(connection_id, u16, u16);
310    make_getter!(wnd_size, u32, u32);
311
312    make_setter!(set_seq_nr, seq_nr, u16);
313    make_setter!(set_ack_nr, ack_nr, u16);
314    make_setter!(set_connection_id, connection_id, u16);
315    make_setter!(set_wnd_size, wnd_size, u32);
316
317    /// Sets Selective ACK field in packet header and adds appropriate data.
318    ///
319    /// The length of the SACK extension is expressed in bytes, which
320    /// must be a multiple of 4 and at least 4.
321    pub fn set_sack(&mut self, bv: Vec<u8>) {
322        // The length of the SACK extension is expressed in bytes, which
323        // must be a multiple of 4 and at least 4.
324        assert!(bv.len() >= 4);
325        assert_eq!(bv.len() % 4, 0);
326
327        let mut index = HEADER_SIZE;
328        let mut extension_type = ExtensionType::from(self.0[1]);
329
330        // Set extension type in header if none is used, otherwise find and update the
331        // "next extension type" marker in the last extension before payload
332        if extension_type == ExtensionType::None {
333            self.0[1] = ExtensionType::SelectiveAck.into();
334        } else {
335            // Skip over all extensions until last, then modify its "next extension type" field and
336            // add the new extension after it.
337
338            // Consume known extensions and skip over unknown ones
339            while index < self.0.len() && extension_type != ExtensionType::None {
340                let len = self.0[index + 1] as usize;
341                // No validity checks needed
342                // ...
343
344                extension_type = ExtensionType::from(self.0[index]);
345
346                // Arrived at last extension
347                if extension_type == ExtensionType::None {
348                    // Mark existence of an additional extension
349                    self.0[index] = ExtensionType::SelectiveAck.into();
350                }
351                index += len + 2;
352            }
353        }
354
355        // Insert the new extension into the packet's data.
356        // The way this is currently done is potentially slower than the alternative of resizing the
357        // underlying Vec, moving the payload forward and then writing the extension in the "new"
358        // place before the payload.
359
360        // Set the type of the following (non-existent) extension
361        self.0.insert(index, ExtensionType::None.into());
362        // Set this extension's length
363        self.0.insert(index + 1, bv.len() as u8);
364        // Write this extension's data
365        for (i, &value) in bv.iter().enumerate() {
366            self.0.insert(index + 2 + i, value);
367        }
368    }
369
370    pub fn len(&self) -> usize {
371        self.0.len()
372    }
373}
374
375impl<'a> TryFrom<&'a [u8]> for Packet {
376    type Err = ParseError;
377
378    /// Decodes a byte slice and construct the equivalent Packet.
379    ///
380    /// Note that this method makes no attempt to guess the payload size, saving
381    /// all except the initial 20 bytes corresponding to the header as payload.
382    /// It's the caller's responsibility to use an appropriately sized buffer.
383    fn try_from(buf: &[u8]) -> Result<Self, Self::Err> {
384        PacketHeader::try_from(buf)
385            .and(check_extensions(buf))
386            .and(Ok(Packet(buf.to_owned())))
387    }
388}
389
390impl Clone for Packet {
391    fn clone(&self) -> Packet {
392        Packet(self.0.clone())
393    }
394}
395
396impl fmt::Debug for Packet {
397    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398        f.debug_struct("Packet")
399            .field("type", &self.get_type())
400            .field("version", &self.get_version())
401            .field("extension", &self.get_extension_type())
402            .field("connection_id", &self.connection_id())
403            .field("timestamp", &self.timestamp())
404            .field("timestamp_difference", &self.timestamp_difference())
405            .field("wnd_size", &self.wnd_size())
406            .field("seq_nr", &self.seq_nr())
407            .field("ack_nr", &self.ack_nr())
408            .finish()
409    }
410}
411
412pub struct ExtensionIterator<'a> {
413    raw_bytes: &'a [u8],
414    next_extension: ExtensionType,
415    index: usize,
416}
417
418impl<'a> ExtensionIterator<'a> {
419    fn new(packet: &'a Packet) -> Self {
420        ExtensionIterator {
421            raw_bytes: packet.as_ref(),
422            next_extension: ExtensionType::from(packet.as_ref()[1]),
423            index: HEADER_SIZE,
424        }
425    }
426}
427
428impl<'a> Iterator for ExtensionIterator<'a> {
429    type Item = Extension<'a>;
430
431    fn next(&mut self) -> Option<Self::Item> {
432        if self.next_extension == ExtensionType::None {
433            None
434        } else if self.index < self.raw_bytes.len() {
435            let len = self.raw_bytes[self.index + 1] as usize;
436            let extension_start = self.index + 2;
437            let extension_end = extension_start + len;
438
439            // Assume extension is valid because the bytes come from a (valid) Packet
440            let extension = Extension {
441                ty: self.next_extension,
442                data: &self.raw_bytes[extension_start..extension_end],
443            };
444
445            self.next_extension = self.raw_bytes[self.index].into();
446            self.index += len + 2;
447
448            Some(extension)
449        } else {
450            None
451        }
452    }
453}
454
455/// Validate correctness of packet extensions, if any, in byte slice
456fn check_extensions(data: &[u8]) -> Result<(), ParseError> {
457    if data.len() < HEADER_SIZE {
458        return Err(ParseError::InvalidPacketLength);
459    }
460
461    let mut index = HEADER_SIZE;
462    let mut extension_type = ExtensionType::from(data[1]);
463
464    if data.len() == HEADER_SIZE && extension_type != ExtensionType::None {
465        return Err(ParseError::InvalidExtensionLength);
466    }
467
468    // Consume known extensions and skip over unknown ones
469    while index < data.len() && extension_type != ExtensionType::None {
470        if data.len() < index + 2 {
471            return Err(ParseError::InvalidPacketLength);
472        }
473        let len = data[index + 1] as usize;
474        let extension_start = index + 2;
475        let extension_end = extension_start + len;
476
477        // Check validity of extension length:
478        // - non-zero,
479        // - multiple of 4,
480        // - does not exceed packet length
481        if len == 0 || len % 4 != 0 || extension_end > data.len() {
482            return Err(ParseError::InvalidExtensionLength);
483        }
484
485        extension_type = ExtensionType::from(data[index]);
486        index += len + 2;
487    }
488    // Check for pending extensions (early exit of previous loop)
489    if extension_type != ExtensionType::None {
490        return Err(ParseError::InvalidPacketLength);
491    }
492
493    Ok(())
494}
495
#[cfg(test)]
mod tests {
    use crate::packet::PacketType::{Data, State};
    use crate::packet::*;
    use crate::packet::{check_extensions, PacketHeader};
    use crate::time::*;
    use quickcheck::{QuickCheck, TestResult};

    #[test]
    fn test_packet_decode() {
        let buf = [
            0x21, 0x00, 0x41, 0xa8, 0x99, 0x2f, 0xd0, 0x2a, 0x9f, 0x4a, 0x26, 0x21, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x6c, 0x79,
        ];
        let packet = Packet::try_from(&buf);
        assert!(packet.is_ok());
        let packet = packet.unwrap();
        assert_eq!(packet.get_version(), 1);
        assert_eq!(packet.get_extension_type(), ExtensionType::None);
        assert_eq!(packet.get_type(), State);
        assert_eq!(packet.connection_id(), 16808);
        assert_eq!(packet.timestamp(), Timestamp(2570047530));
        assert_eq!(packet.timestamp_difference(), Delay(2672436769));
        assert_eq!(packet.wnd_size(), 2u32.pow(20));
        assert_eq!(packet.seq_nr(), 15090);
        assert_eq!(packet.ack_nr(), 27769);
        assert_eq!(packet.len(), buf.len());
        assert!(packet.payload().is_empty());
    }

    #[test]
    fn test_decode_packet_with_extension() {
        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00,
        ];
        let packet = Packet::try_from(&buf);
        assert!(packet.is_ok());
        let packet = packet.unwrap();
        assert_eq!(packet.get_version(), 1);
        assert_eq!(packet.get_extension_type(), ExtensionType::SelectiveAck);
        assert_eq!(packet.get_type(), State);
        assert_eq!(packet.connection_id(), 16807);
        assert_eq!(packet.timestamp(), Timestamp(0));
        assert_eq!(packet.timestamp_difference(), Delay(0));
        assert_eq!(packet.wnd_size(), 1500);
        assert_eq!(packet.seq_nr(), 43859);
        assert_eq!(packet.ack_nr(), 15093);
        assert_eq!(packet.len(), buf.len());
        assert!(packet.payload().is_empty());
        let extensions: Vec<Extension<'_>> = packet.extensions().collect();
        assert_eq!(extensions.len(), 1);
        assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
        assert_eq!(extensions[0].data, &[0, 0, 0, 0]);
        assert_eq!(extensions[0].len(), extensions[0].data.len());
        assert_eq!(extensions[0].len(), 4);
        // Reversible
        assert_eq!(packet.as_ref(), &buf);
    }

    #[test]
    fn test_packet_decode_with_missing_extension() {
        let buf = [
            0x21, 0x01, 0x41, 0xa8, 0x99, 0x2f, 0xd0, 0x2a, 0x9f, 0x4a, 0x26, 0x21, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x6c, 0x79,
        ];
        let packet = Packet::try_from(&buf);
        assert!(packet.is_err());
    }

    #[test]
    fn test_packet_decode_with_malformed_extension() {
        let buf = [
            0x21, 0x01, 0x41, 0xa8, 0x99, 0x2f, 0xd0, 0x2a, 0x9f, 0x4a, 0x26, 0x21, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x6c, 0x79, 0x00, 0x04, 0x00,
        ];
        let packet = Packet::try_from(&buf);
        assert!(packet.is_err());
    }

    #[test]
    fn test_decode_packet_with_unknown_extensions() {
        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5,
            // SACK extension; its "next extension" byte (0xff) announces an unknown one
            0xff, 0x04, 0x00, 0x00, 0x00, 0x00,
            // Imaginary extension of unknown type 0xff, last in the chain
            0x00, 0x04, 0x00, 0x00, 0x00, 0x00,
        ];
        match Packet::try_from(&buf) {
            Ok(packet) => {
                assert_eq!(packet.get_version(), 1);
                assert_eq!(packet.get_extension_type(), ExtensionType::SelectiveAck);
                assert_eq!(packet.get_type(), State);
                assert_eq!(packet.connection_id(), 16807);
                assert_eq!(packet.timestamp(), Timestamp(0));
                assert_eq!(packet.timestamp_difference(), Delay(0));
                assert_eq!(packet.wnd_size(), 1500);
                assert_eq!(packet.seq_nr(), 43859);
                assert_eq!(packet.ack_nr(), 15093);
                assert!(packet.payload().is_empty());
                // The unknown extension is not discarded: it is yielded after the SACK
                let extensions: Vec<Extension<'_>> = packet.extensions().collect();
                assert_eq!(extensions.len(), 2);
                assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
                assert_eq!(extensions[0].data, &[0, 0, 0, 0]);
                assert_eq!(extensions[0].len(), extensions[0].data.len());
                assert_eq!(extensions[0].len(), 4);
                assert_eq!(extensions[1].ty, ExtensionType::Unknown(0xff));
                assert_eq!(extensions[1].data, &[0, 0, 0, 0]);
            }
            Err(ref e) => panic!("{}", e),
        }
    }

    #[test]
    fn test_packet_set_type() {
        let mut packet = Packet::new();
        packet.set_type(PacketType::Syn);
        assert_eq!(packet.get_type(), PacketType::Syn);
        packet.set_type(PacketType::State);
        assert_eq!(packet.get_type(), PacketType::State);
        packet.set_type(PacketType::Fin);
        assert_eq!(packet.get_type(), PacketType::Fin);
        packet.set_type(PacketType::Reset);
        assert_eq!(packet.get_type(), PacketType::Reset);
        packet.set_type(PacketType::Data);
        assert_eq!(packet.get_type(), PacketType::Data);
    }

    #[test]
    fn test_packet_set_selective_acknowledgment() {
        let mut packet = Packet::new();
        packet.set_sack(vec![1, 2, 3, 4]);

        {
            let extensions: Vec<Extension<'_>> = packet.extensions().collect();
            assert_eq!(extensions.len(), 1);
            assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
            assert_eq!(extensions[0].data, &[1, 2, 3, 4]);
            assert_eq!(extensions[0].len(), extensions[0].data.len());
            assert_eq!(extensions[0].len(), 4);
        }

        // Add a second sack
        packet.set_sack(vec![5, 6, 7, 8, 9, 10, 11, 12]);

        let extensions: Vec<Extension<'_>> = packet.extensions().collect();
        assert_eq!(extensions.len(), 2);
        assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
        assert_eq!(extensions[0].data, &[1, 2, 3, 4]);
        assert_eq!(extensions[0].len(), extensions[0].data.len());
        assert_eq!(extensions[0].len(), 4);
        assert_eq!(extensions[1].ty, ExtensionType::SelectiveAck);
        assert_eq!(extensions[1].data, &[5, 6, 7, 8, 9, 10, 11, 12]);
        assert_eq!(extensions[1].len(), extensions[1].data.len());
        assert_eq!(extensions[1].len(), 8);
    }

    #[test]
    fn test_packet_encode() {
        let payload = b"Hello\n".to_vec();
        let timestamp = Timestamp(15270793);
        let timestamp_diff = Delay(1707040186);
        let (connection_id, seq_nr, ack_nr): (u16, u16, u16) = (16808, 15090, 17096);
        let window_size: u32 = 1048576;
        let mut packet = Packet::with_payload(&payload[..]);
        packet.set_type(Data);
        packet.set_timestamp(timestamp);
        packet.set_timestamp_difference(timestamp_diff);
        packet.set_connection_id(connection_id);
        packet.set_seq_nr(seq_nr);
        packet.set_ack_nr(ack_nr);
        packet.set_wnd_size(window_size);
        let buf = [
            0x01, 0x00, 0x41, 0xa8, 0x00, 0xe9, 0x03, 0x89, 0x65, 0xbf, 0x5d, 0xba, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x42, 0xc8, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x0a,
        ];

        assert_eq!(packet.len(), buf.len());
        assert_eq!(packet.len(), HEADER_SIZE + payload.len());
        assert_eq!(&packet.payload(), &payload.as_slice());
        assert_eq!(packet.get_version(), 1);
        assert_eq!(packet.get_extension_type(), ExtensionType::None);
        assert_eq!(packet.get_type(), Data);
        assert_eq!(packet.connection_id(), connection_id);
        assert_eq!(packet.seq_nr(), seq_nr);
        assert_eq!(packet.ack_nr(), ack_nr);
        assert_eq!(packet.wnd_size(), window_size);
        assert_eq!(packet.timestamp(), timestamp);
        assert_eq!(packet.timestamp_difference(), timestamp_diff);
        assert_eq!(packet.as_ref(), buf);
    }

    #[test]
    fn test_packet_encode_with_payload() {
        let payload = b"Hello\n".to_vec();
        let timestamp = Timestamp(15270793);
        let timestamp_diff = Delay(1707040186);
        let (connection_id, seq_nr, ack_nr): (u16, u16, u16) = (16808, 15090, 17096);
        let window_size: u32 = 1048576;
        let mut packet = Packet::with_payload(&payload[..]);
        packet.set_timestamp(timestamp);
        packet.set_timestamp_difference(timestamp_diff);
        packet.set_connection_id(connection_id);
        packet.set_seq_nr(seq_nr);
        packet.set_ack_nr(ack_nr);
        packet.set_wnd_size(window_size);
        let buf = [
            0x01, 0x00, 0x41, 0xa8, 0x00, 0xe9, 0x03, 0x89, 0x65, 0xbf, 0x5d, 0xba, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x42, 0xc8, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x0a,
        ];

        assert_eq!(packet.len(), buf.len());
        assert_eq!(packet.len(), HEADER_SIZE + payload.len());
        assert_eq!(&packet.payload(), &payload.as_slice());
        assert_eq!(packet.get_version(), 1);
        assert_eq!(packet.get_type(), Data);
        assert_eq!(packet.get_extension_type(), ExtensionType::None);
        assert_eq!(packet.connection_id(), connection_id);
        assert_eq!(packet.seq_nr(), seq_nr);
        assert_eq!(packet.ack_nr(), ack_nr);
        assert_eq!(packet.wnd_size(), window_size);
        assert_eq!(packet.timestamp(), timestamp);
        assert_eq!(packet.timestamp_difference(), timestamp_diff);
        assert_eq!(packet.as_ref(), buf);
    }

    #[test]
    fn test_reversible() {
        let buf = [
            0x01, 0x00, 0x41, 0xa8, 0x00, 0xe9, 0x03, 0x89, 0x65, 0xbf, 0x5d, 0xba, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x42, 0xc8, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x0a,
        ];
        assert_eq!(&Packet::try_from(&buf).unwrap().as_ref(), &buf);
    }

    #[test]
    fn test_decode_evil_sequence() {
        let buf = [
            0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let packet = Packet::try_from(&buf);
        assert!(packet.is_err());
    }

    #[test]
    fn test_decode_empty_packet() {
        let packet = Packet::try_from(&[]);
        assert!(packet.is_err());
    }

    // Use quickcheck to simulate a malicious attacker sending malformed packets
    #[test]
    fn quicktest() {
        fn run(x: Vec<u8>) -> TestResult {
            let packet = Packet::try_from(&x);

            if PacketHeader::try_from(&x)
                .and(check_extensions(&x))
                .is_err()
            {
                TestResult::from_bool(packet.is_err())
            } else if let Ok(packet) = packet {
                TestResult::from_bool(&packet.as_ref() == &x.as_slice())
            } else {
                TestResult::from_bool(false)
            }
        }
        QuickCheck::new()
            .tests(10000)
            .quickcheck(run as fn(Vec<u8>) -> TestResult)
    }

    #[test]
    fn extension_iterator() {
        let buf = [
            0x21, 0x00, 0x41, 0xa8, 0x99, 0x2f, 0xd0, 0x2a, 0x9f, 0x4a, 0x26, 0x21, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x6c, 0x79,
        ];
        let packet = Packet::try_from(&buf).unwrap();
        assert_eq!(packet.extensions().count(), 0);

        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00,
        ];
        let packet = Packet::try_from(&buf).unwrap();
        let extensions: Vec<Extension<'_>> = packet.extensions().collect();
        assert_eq!(extensions.len(), 1);
        assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
        assert_eq!(extensions[0].data, &[0, 0, 0, 0]);
        assert_eq!(extensions[0].len(), extensions[0].data.len());
        assert_eq!(extensions[0].len(), 4);

        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5,
            // SACK extension; its "next extension" byte (0xff) announces an unknown one
            0xff, 0x04, 0x01, 0x02, 0x03, 0x04,
            // Imaginary extension of unknown type 0xff, last in the chain
            0x00, 0x04, 0x05, 0x06, 0x07, 0x08,
        ];

        let packet = Packet::try_from(&buf).unwrap();
        let extensions: Vec<Extension<'_>> = packet.extensions().collect();
        assert_eq!(extensions.len(), 2);
        assert_eq!(extensions[0].ty, ExtensionType::SelectiveAck);
        assert_eq!(extensions[0].data, &[1, 2, 3, 4]);
        assert_eq!(extensions[0].len(), extensions[0].data.len());
        assert_eq!(extensions[0].len(), 4);
        assert_eq!(extensions[1].ty, ExtensionType::Unknown(0xff));
        assert_eq!(extensions[1].data, &[5, 6, 7, 8]);
        assert_eq!(extensions[1].len(), extensions[1].data.len());
        assert_eq!(extensions[1].len(), 4);
    }
}
808
#[cfg(all(feature = "unstable", test))]
mod bench {
    extern crate test;

    use self::test::Bencher;
    // `super::` resolves in every edition; the former `use packet::…` is a 2015-edition path
    // that does not resolve in a crate using `crate::`-style paths.
    use super::{Packet, TryFrom};

    #[bench]
    fn bench_decode(b: &mut Bencher) {
        let buf = [
            0x21, 0x00, 0x41, 0xa8, 0x99, 0x2f, 0xd0, 0x2a, 0x9f, 0x4a, 0x26, 0x21, 0x00, 0x10,
            0x00, 0x00, 0x3a, 0xf2, 0x6c, 0x79, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
            0x09, 0x0a,
        ];
        b.iter(|| {
            let _ = test::black_box(Packet::try_from(&buf));
        });
    }

    #[bench]
    fn bench_encode(b: &mut Bencher) {
        let packet = Packet::with_payload(&[1, 2, 3, 4, 5, 6]);
        b.iter(|| {
            let _ = test::black_box(packet.as_ref());
        });
    }

    #[bench]
    fn bench_extract_payload(b: &mut Bencher) {
        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5, 0xff, 0x04, 0x01, 0x02, 0x03,
            0x04, // First extension
            0x00, 0x04, 0x05, 0x06, 0x07, 0x08, // Second extension, followed by data
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let packet = Packet::try_from(&buf).unwrap();
        b.iter(|| {
            let _ = test::black_box(packet.payload());
        });
    }

    #[bench]
    fn bench_extract_extensions(b: &mut Bencher) {
        let buf = [
            0x21, 0x01, 0x41, 0xa7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x05, 0xdc, 0xab, 0x53, 0x3a, 0xf5, 0xff, 0x04, 0x01, 0x02, 0x03,
            0x04, // First extension
            0x00, 0x04, 0x05, 0x06, 0x07, 0x08, // Second extension, followed by data
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let packet = Packet::try_from(&buf).unwrap();
        b.iter(|| {
            let _ = test::black_box(packet.extensions().count());
        });
    }
}