ymmp/
packet.rs

1use std::convert::{TryFrom, TryInto};
2use std::string::FromUtf8Error;
3
4use crate::header::{Header, HeaderOctets, HEADER_LENGTH};
5use crate::{Octet, Octets};
6
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9    #[error("Too short packet found.")]
10    PacketTooShort,
11}
12
13#[derive(Debug, PartialEq)]
14pub struct Packet {
15    header: Header,
16    message: Vec<Octet>,
17}
18
19impl Packet {
20    pub fn new(message: Vec<Octet>) -> Self {
21        Self {
22            header: Header::default(),
23            message,
24        }
25    }
26
27    pub fn from_raw_parts(header: &HeaderOctets, message: &Octets) -> Self {
28        Self {
29            header: Header::from(header),
30            message: message.to_vec(),
31        }
32    }
33
34    pub fn to_octets_vec(&self) -> Vec<Octet> {
35        let mut vec: Vec<u8> = self.header.to_octets_vec();
36
37        vec.extend_from_slice(&(self.message.len() as u64).to_le_bytes());
38        vec.extend_from_slice(self.message.as_slice());
39
40        vec
41    }
42
43    pub fn message(&self) -> &[Octet] {
44        &self.message
45    }
46}
47
48impl Default for Packet {
49    fn default() -> Self {
50        Self::new(Vec::new())
51    }
52}
53
54impl TryFrom<&Octets> for Packet {
55    type Error = Error;
56
57    fn try_from(value: &Octets) -> Result<Self, Self::Error> {
58        if value.len() < HEADER_LENGTH + 8 {
59            Err(Error::PacketTooShort)
60        } else {
61            let mut header = HeaderOctets::default();
62            header.copy_from_slice(&value[0..HEADER_LENGTH]);
63
64            let offset = HEADER_LENGTH + 8;
65            let mut length: [Octet; 8] = [0; 8];
66            length.copy_from_slice(&value[HEADER_LENGTH..offset]);
67
68            let length = u64::from_le_bytes(length);
69            let rest = &value[offset..];
70
71            // TODO: Support Fragmentation
72            assert_eq!(length as usize, rest.len());
73
74            Ok(Packet::from_raw_parts(&header, rest))
75        }
76    }
77}
78
79impl TryInto<String> for Packet {
80    type Error = FromUtf8Error;
81
82    fn try_into(self) -> Result<String, Self::Error> {
83        String::from_utf8(self.message)
84    }
85}
86
#[cfg(test)]
mod tests {
    use crate::packet::{Error, Packet};
    use crate::Octets;
    use std::convert::TryFrom;

    /// A default packet serializes to its header followed by an 8-octet length field.
    #[test]
    fn default() {
        let packet = Packet::default();
        let header_bytes = packet.header.to_octets_vec();
        let header_len = header_bytes.len();
        let bytes = packet.to_octets_vec();

        assert_eq!(header_len + 8, bytes.len());
        assert_eq!(&header_bytes, &bytes[..header_len])
    }

    /// The message octets are appended after the header and the length field.
    #[test]
    fn with_message() {
        let message = vec![b'f', b'o', b'o'];
        let packet = Packet::new(message.clone());
        let header_bytes = packet.header.to_octets_vec();
        let header_len = header_bytes.len();
        let bytes = packet.to_octets_vec();

        assert_eq!(header_len + 8 + message.len(), bytes.len());
        assert_eq!(&header_bytes, &bytes[..header_len]);
        assert_eq!(&message, &bytes[header_len + 8..]);
    }

    /// A header plus a zero length field decodes to a packet with an empty message.
    #[test]
    fn try_from_header_only() {
        let octets = vec![
            b'Y', b'M', b'M', b'P', b'v', b'0', b'.', b'1', 0, 0, 0, 0, 0, 0, 0, 0,
        ];

        let octets: &Octets = octets.as_slice();
        let packet = Packet::try_from(octets).expect("failed");

        assert_eq!(0, packet.message.len())
    }

    /// Input that cannot even hold the header and length field is rejected with an error.
    #[test]
    fn try_from_too_short() {
        let octets: &Octets = b"YMMP";

        assert!(matches!(
            Packet::try_from(octets),
            Err(Error::PacketTooShort)
        ));
    }

    /// Serializing and then decoding a packet preserves its message.
    #[test]
    fn round_trip_message() {
        let packet = Packet::new(b"hello".to_vec());
        let bytes = packet.to_octets_vec();
        let decoded = Packet::try_from(bytes.as_slice()).expect("round trip failed");

        assert_eq!(packet.message(), decoded.message());
    }
}
128}