Skip to main content

raknet_rust/protocol/
datagram.rs

1use bytes::{Buf, BufMut};
2
3use crate::error::{DecodeError, EncodeError};
4
5use super::ack::AckNackPayload;
6use super::codec::RaknetCodec;
7use super::constants::{DatagramFlags, RAKNET_DATAGRAM_HEADER_SIZE};
8use super::frame::Frame;
9use super::sequence24::Sequence24;
10
/// Classification of a datagram, derived from its leading flags byte.
///
/// Kept private to this module: it lets the encoder and decoder agree on which
/// payload shape a given set of flags implies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DatagramKind {
    /// Data datagram: a sequence number followed by frames.
    Data,
    /// Datagram carrying an ACK payload.
    Ack,
    /// Datagram carrying a NACK payload.
    Nack,
}
17
18#[derive(Debug, Clone)]
19pub enum DatagramPayload {
20    Frames(Vec<Frame>),
21    Ack(AckNackPayload),
22    Nack(AckNackPayload),
23}
24
25impl DatagramPayload {
26    fn kind(&self) -> DatagramKind {
27        match self {
28            Self::Frames(_) => DatagramKind::Data,
29            Self::Ack(_) => DatagramKind::Ack,
30            Self::Nack(_) => DatagramKind::Nack,
31        }
32    }
33}
34
35#[derive(Debug, Clone)]
36pub struct DatagramHeader {
37    pub flags: DatagramFlags,
38    pub sequence: Sequence24,
39}
40
41impl RaknetCodec for DatagramHeader {
42    fn encode_raknet(&self, dst: &mut impl BufMut) -> Result<(), EncodeError> {
43        self.flags.bits().encode_raknet(dst)?;
44        self.sequence.encode_raknet(dst)
45    }
46
47    fn decode_raknet(src: &mut impl Buf) -> Result<Self, DecodeError> {
48        if src.remaining() < 4 {
49            return Err(DecodeError::UnexpectedEof);
50        }
51
52        let raw_flags = u8::decode_raknet(src)?;
53        let (flags, kind) = decode_datagram_flags(raw_flags)?;
54        if kind != DatagramKind::Data {
55            return Err(DecodeError::InvalidDatagramFlags(raw_flags));
56        }
57        let sequence = Sequence24::decode_raknet(src)?;
58
59        Ok(Self { flags, sequence })
60    }
61}
62
63#[derive(Debug, Clone)]
64pub struct Datagram {
65    pub header: DatagramHeader,
66    pub payload: DatagramPayload,
67}
68
69impl Datagram {
70    pub fn encoded_size(&self) -> usize {
71        match &self.payload {
72            DatagramPayload::Frames(frames) => {
73                RAKNET_DATAGRAM_HEADER_SIZE + frames.iter().map(Frame::encoded_size).sum::<usize>()
74            }
75            DatagramPayload::Ack(payload) | DatagramPayload::Nack(payload) => {
76                1 + payload.encoded_size()
77            }
78        }
79    }
80
81    pub fn encode(&self, dst: &mut impl BufMut) -> Result<(), EncodeError> {
82        validate_flags_for_payload(self.header.flags, &self.payload)?;
83
84        match &self.payload {
85            DatagramPayload::Frames(frames) => {
86                self.header.encode_raknet(dst)?;
87                for frame in frames {
88                    frame.encode_raknet(dst)?;
89                }
90            }
91            DatagramPayload::Ack(payload) | DatagramPayload::Nack(payload) => {
92                self.header.flags.bits().encode_raknet(dst)?;
93                payload.encode_raknet(dst)?;
94            }
95        }
96        Ok(())
97    }
98
99    pub fn decode(src: &mut impl Buf) -> Result<Self, DecodeError> {
100        if !src.has_remaining() {
101            return Err(DecodeError::UnexpectedEof);
102        }
103
104        let raw_flags = src.get_u8();
105        let (flags, kind) = decode_datagram_flags(raw_flags)?;
106
107        match kind {
108            DatagramKind::Ack => Ok(Self {
109                header: DatagramHeader {
110                    flags,
111                    sequence: Sequence24::new(0),
112                },
113                payload: DatagramPayload::Ack(AckNackPayload::decode_raknet(src)?),
114            }),
115            DatagramKind::Nack => Ok(Self {
116                header: DatagramHeader {
117                    flags,
118                    sequence: Sequence24::new(0),
119                },
120                payload: DatagramPayload::Nack(AckNackPayload::decode_raknet(src)?),
121            }),
122            DatagramKind::Data => {
123                let sequence = Sequence24::decode_raknet(src)?;
124                let header = DatagramHeader { flags, sequence };
125
126                let mut frames = Vec::new();
127                while src.has_remaining() {
128                    frames.push(Frame::decode_raknet(src)?);
129                }
130
131                Ok(Self {
132                    header,
133                    payload: DatagramPayload::Frames(frames),
134                })
135            }
136        }
137    }
138}
139
140fn decode_datagram_flags(raw_flags: u8) -> Result<(DatagramFlags, DatagramKind), DecodeError> {
141    let Some(flags) = DatagramFlags::from_bits(raw_flags) else {
142        return Err(DecodeError::InvalidDatagramFlags(raw_flags));
143    };
144
145    if !flags.contains(DatagramFlags::VALID) {
146        return Err(DecodeError::InvalidDatagramFlags(raw_flags));
147    }
148
149    let has_ack = flags.contains(DatagramFlags::ACK);
150    let has_nack = flags.contains(DatagramFlags::NACK);
151    if has_ack && has_nack {
152        return Err(DecodeError::InvalidDatagramFlags(raw_flags));
153    }
154
155    if has_ack || has_nack {
156        let control_extras = DatagramFlags::PACKET_PAIR
157            | DatagramFlags::CONTINUOUS_SEND
158            | DatagramFlags::HAS_B_AND_AS;
159        if flags.intersects(control_extras) {
160            return Err(DecodeError::InvalidDatagramFlags(raw_flags));
161        }
162
163        return Ok((
164            flags,
165            if has_ack {
166                DatagramKind::Ack
167            } else {
168                DatagramKind::Nack
169            },
170        ));
171    }
172
173    Ok((flags, DatagramKind::Data))
174}
175
176fn validate_flags_for_payload(
177    flags: DatagramFlags,
178    payload: &DatagramPayload,
179) -> Result<(), EncodeError> {
180    let raw_flags = flags.bits();
181    let (_, decoded_kind) = decode_datagram_flags(raw_flags)
182        .map_err(|_| EncodeError::InvalidDatagramFlags(raw_flags))?;
183
184    if decoded_kind != payload.kind() {
185        return Err(EncodeError::InvalidDatagramFlags(raw_flags));
186    }
187
188    Ok(())
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ack::SequenceRange;

    /// ACK/NACK payload holding a single range that covers only sequence number 7.
    fn sample_ack_payload() -> AckNackPayload {
        AckNackPayload {
            ranges: vec![SequenceRange {
                start: Sequence24::new(7),
                end: Sequence24::new(7),
            }],
        }
    }

    #[test]
    fn decode_rejects_unknown_datagram_bits() {
        let mut src = &b"\x83\0\0\0"[..];
        let err = Datagram::decode(&mut src).expect_err("unknown bit must be rejected");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0x83)));
    }

    #[test]
    fn decode_rejects_ack_without_valid_flag() {
        let mut src = &b"\x40\0\0"[..];
        let err = Datagram::decode(&mut src).expect_err("ack without valid bit must be rejected");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0x40)));
    }

    #[test]
    fn decode_rejects_control_with_data_only_bits() {
        let mut src = &b"\xC8\0\0"[..];
        let err =
            Datagram::decode(&mut src).expect_err("control datagram must not carry data-only bits");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0xC8)));
    }

    /// ACK and NACK are mutually exclusive; a byte setting both must be rejected.
    #[test]
    fn decode_rejects_ack_and_nack_together() {
        let raw = (DatagramFlags::VALID | DatagramFlags::ACK | DatagramFlags::NACK).bits();
        let bytes = [raw, 0, 0];
        let mut src = &bytes[..];
        let err = Datagram::decode(&mut src)
            .expect_err("ack and nack bits must be mutually exclusive");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(got) if got == raw));
    }

    #[test]
    fn decode_accepts_valid_ack_flags() {
        let payload = sample_ack_payload();
        let mut encoded = Vec::new();
        payload
            .encode_raknet(&mut encoded)
            .expect("ack payload encode");

        let mut src = vec![0xC0];
        src.extend(encoded);
        let decoded = Datagram::decode(&mut src.as_slice()).expect("ack datagram should decode");

        match decoded.payload {
            DatagramPayload::Ack(decoded_payload) => assert_eq!(decoded_payload, payload),
            other => panic!("unexpected payload: {other:?}"),
        }
    }

    /// A standalone header decode only accepts data datagrams.
    #[test]
    fn header_decode_rejects_control_flags() {
        let mut src = &b"\xC0\0\0\0"[..];
        let err = DatagramHeader::decode_raknet(&mut src)
            .expect_err("ack flags must not decode as a data header");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0xC0)));
    }

    /// A data header needs the flags byte plus a full 24-bit sequence number.
    #[test]
    fn header_decode_requires_full_header() {
        let mut src = &b"\x80\0\0"[..];
        let err = DatagramHeader::decode_raknet(&mut src)
            .expect_err("truncated header must be rejected");
        assert!(matches!(err, DecodeError::UnexpectedEof));
    }

    /// ACK and NACK datagrams survive encode -> decode.
    /// `encoded_size` must agree with the number of bytes actually written.
    #[test]
    fn control_datagrams_round_trip() {
        let cases = [
            (
                DatagramFlags::VALID | DatagramFlags::ACK,
                DatagramPayload::Ack(sample_ack_payload()),
            ),
            (
                DatagramFlags::VALID | DatagramFlags::NACK,
                DatagramPayload::Nack(sample_ack_payload()),
            ),
        ];

        for (flags, payload) in cases {
            let expected_bits = flags.bits();
            let datagram = Datagram {
                header: DatagramHeader {
                    flags,
                    sequence: Sequence24::new(0),
                },
                payload,
            };

            let mut out = Vec::new();
            datagram.encode(&mut out).expect("control datagram encode");
            assert_eq!(out.first(), Some(&expected_bits));
            assert_eq!(out.len(), datagram.encoded_size());

            let decoded =
                Datagram::decode(&mut out.as_slice()).expect("control datagram should decode");
            assert_eq!(decoded.header.flags.bits(), expected_bits);
            match (&datagram.payload, &decoded.payload) {
                (DatagramPayload::Ack(sent), DatagramPayload::Ack(received))
                | (DatagramPayload::Nack(sent), DatagramPayload::Nack(received)) => {
                    assert_eq!(sent, received)
                }
                (sent, received) => panic!("payload kind changed: {sent:?} -> {received:?}"),
            }
        }
    }

    #[test]
    fn encode_rejects_payload_and_flag_mismatch() {
        let datagram = Datagram {
            header: DatagramHeader {
                flags: DatagramFlags::VALID | DatagramFlags::ACK,
                sequence: Sequence24::new(12),
            },
            payload: DatagramPayload::Frames(Vec::new()),
        };

        let mut out = Vec::new();
        let err = datagram
            .encode(&mut out)
            .expect_err("invalid payload/flags mismatch must fail");
        assert!(matches!(err, EncodeError::InvalidDatagramFlags(0xC0)));
    }

    /// A control payload paired with data-only flags is rejected before anything is written.
    #[test]
    fn encode_rejects_control_payload_with_data_flags() {
        let raw = DatagramFlags::VALID.bits();
        let datagram = Datagram {
            header: DatagramHeader {
                flags: DatagramFlags::VALID,
                sequence: Sequence24::new(0),
            },
            payload: DatagramPayload::Nack(sample_ack_payload()),
        };

        let mut out = Vec::new();
        let err = datagram
            .encode(&mut out)
            .expect_err("nack payload with data flags must fail");
        assert!(matches!(err, EncodeError::InvalidDatagramFlags(got) if got == raw));
        assert!(out.is_empty(), "nothing may be written when validation fails");
    }
}