Skip to main content

wire/
packet.rs

1//! Packet decoding and section framing.
2
3use crate::error::{DecodeError, EncodeError, LimitKind, SectionFramingError, WireResult};
4use crate::header::{PacketFlags, PacketHeader, HEADER_SIZE, MAGIC, VERSION};
5use crate::limits::Limits;
6
/// Tag byte identifying the kind of a section in a version-0 packet.
///
/// Each discriminant is the exact byte value written to the wire by the section
/// encoder (`tag as u8`); [`SectionTag::parse`] performs the inverse mapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
#[repr(u8)]
pub enum SectionTag {
    /// Wire byte `1`.
    EntityCreate = 1,
    /// Wire byte `2`.
    EntityDestroy = 2,
    /// Wire byte `3`.
    EntityUpdate = 3,
}
16
17impl SectionTag {
18    /// Parses a section tag from a raw byte.
19    pub fn parse(tag: u8) -> Result<Self, DecodeError> {
20        match tag {
21            1 => Ok(Self::EntityCreate),
22            2 => Ok(Self::EntityDestroy),
23            3 => Ok(Self::EntityUpdate),
24            _ => Err(DecodeError::UnknownSectionTag { tag }),
25        }
26    }
27}
28
29/// A section within a wire packet.
30#[derive(Debug, Clone, Copy, PartialEq, Eq)]
31pub struct WireSection<'a> {
32    pub tag: SectionTag,
33    pub body: &'a [u8],
34}
35
36/// A decoded wire packet.
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct WirePacket<'a> {
39    pub header: PacketHeader,
40    pub sections: Vec<WireSection<'a>>,
41}
42
43/// Decodes a wire packet into header + section slices.
44pub fn decode_packet<'a>(buf: &'a [u8], limits: &Limits) -> WireResult<WirePacket<'a>> {
45    if buf.len() < HEADER_SIZE {
46        return Err(DecodeError::PacketTooSmall {
47            actual: buf.len(),
48            required: HEADER_SIZE,
49        });
50    }
51    if buf.len() > limits.max_packet_bytes {
52        return Err(DecodeError::LimitsExceeded {
53            kind: LimitKind::PacketBytes,
54            limit: limits.max_packet_bytes,
55            actual: buf.len(),
56        });
57    }
58
59    let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
60    if magic != MAGIC {
61        return Err(DecodeError::InvalidMagic { found: magic });
62    }
63
64    let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
65    if version != VERSION {
66        return Err(DecodeError::UnsupportedVersion { found: version });
67    }
68
69    let flags_raw = u16::from_le_bytes(buf[6..8].try_into().unwrap());
70    let flags = PacketFlags::from_raw(flags_raw);
71    if !flags.is_valid_v0() {
72        return Err(DecodeError::InvalidFlags { flags: flags_raw });
73    }
74
75    let schema_hash = u64::from_le_bytes(buf[8..16].try_into().unwrap());
76    let tick = u32::from_le_bytes(buf[16..20].try_into().unwrap());
77    let baseline_tick = u32::from_le_bytes(buf[20..24].try_into().unwrap());
78    let payload_len = u32::from_le_bytes(buf[24..28].try_into().unwrap());
79
80    if flags.is_full_snapshot() && baseline_tick != 0 {
81        return Err(DecodeError::InvalidBaselineTick {
82            baseline_tick,
83            flags: flags_raw,
84        });
85    }
86    if flags.is_delta_snapshot() && baseline_tick == 0 {
87        return Err(DecodeError::InvalidBaselineTick {
88            baseline_tick,
89            flags: flags_raw,
90        });
91    }
92
93    let actual_payload_len = buf.len() - HEADER_SIZE;
94    if payload_len as usize != actual_payload_len {
95        return Err(DecodeError::PayloadLengthMismatch {
96            header_len: payload_len,
97            actual_len: actual_payload_len,
98        });
99    }
100
101    let header = PacketHeader {
102        version,
103        flags,
104        schema_hash,
105        tick,
106        baseline_tick,
107        payload_len,
108    };
109
110    let payload = &buf[HEADER_SIZE..];
111    let mut offset = 0usize;
112    let mut sections = Vec::new();
113
114    while offset < payload.len() {
115        if sections.len() >= limits.max_sections {
116            return Err(DecodeError::LimitsExceeded {
117                kind: LimitKind::SectionCount,
118                limit: limits.max_sections,
119                actual: sections.len() + 1,
120            });
121        }
122
123        let tag = payload[offset];
124        offset += 1;
125        let (len, new_offset) = read_varu32(payload, offset)?;
126        offset = new_offset;
127        let len_usize = usize::try_from(len).unwrap();
128
129        if len_usize > limits.max_section_len {
130            return Err(DecodeError::LimitsExceeded {
131                kind: LimitKind::SectionLength,
132                limit: limits.max_section_len,
133                actual: len_usize,
134            });
135        }
136        if offset + len_usize > payload.len() {
137            return Err(DecodeError::SectionFraming(
138                SectionFramingError::Truncated {
139                    needed: offset + len_usize,
140                    available: payload.len(),
141                },
142            ));
143        }
144
145        let tag = SectionTag::parse(tag)?;
146        let body = &payload[offset..offset + len_usize];
147        sections.push(WireSection { tag, body });
148        offset += len_usize;
149    }
150
151    Ok(WirePacket { header, sections })
152}
153
154/// Encodes a packet header into the provided output buffer.
155pub fn encode_header(header: &PacketHeader, out: &mut [u8]) -> Result<usize, EncodeError> {
156    if out.len() < HEADER_SIZE {
157        return Err(EncodeError::BufferTooSmall {
158            needed: HEADER_SIZE,
159            available: out.len(),
160        });
161    }
162
163    out[0..4].copy_from_slice(&MAGIC.to_le_bytes());
164    out[4..6].copy_from_slice(&header.version.to_le_bytes());
165    out[6..8].copy_from_slice(&header.flags.raw().to_le_bytes());
166    out[8..16].copy_from_slice(&header.schema_hash.to_le_bytes());
167    out[16..20].copy_from_slice(&header.tick.to_le_bytes());
168    out[20..24].copy_from_slice(&header.baseline_tick.to_le_bytes());
169    out[24..28].copy_from_slice(&header.payload_len.to_le_bytes());
170
171    Ok(HEADER_SIZE)
172}
173
174/// Encodes a single section into the provided output buffer.
175pub fn encode_section(tag: SectionTag, body: &[u8], out: &mut [u8]) -> Result<usize, EncodeError> {
176    let len_u32 = u32::try_from(body.len())
177        .map_err(|_| EncodeError::LengthOverflow { length: body.len() })?;
178    let len_bytes = varu32_len(len_u32);
179    let needed = 1 + len_bytes + body.len();
180    if out.len() < needed {
181        return Err(EncodeError::BufferTooSmall {
182            needed,
183            available: out.len(),
184        });
185    }
186
187    out[0] = tag as u8;
188    let mut offset = 1;
189    offset += write_varu32(len_u32, &mut out[offset..]);
190    out[offset..offset + body.len()].copy_from_slice(body);
191    Ok(needed)
192}
193
194fn read_varu32(buf: &[u8], mut offset: usize) -> Result<(u32, usize), DecodeError> {
195    let mut value = 0u32;
196    let mut shift = 0u32;
197    for _ in 0..5 {
198        if offset >= buf.len() {
199            return Err(DecodeError::SectionFraming(
200                SectionFramingError::Truncated {
201                    needed: offset + 1,
202                    available: buf.len(),
203                },
204            ));
205        }
206        let byte = buf[offset];
207        offset += 1;
208        value |= u32::from(byte & 0x7F) << shift;
209        if byte & 0x80 == 0 {
210            return Ok((value, offset));
211        }
212        shift += 7;
213    }
214    Err(DecodeError::SectionFraming(
215        SectionFramingError::InvalidVarint,
216    ))
217}
218
/// Writes `value` as an unsigned LEB128 varint at the start of `out`.
///
/// Each byte carries 7 value bits, least-significant group first; every byte
/// except the last has its high bit set. Returns the number of bytes written
/// (1..=5).
///
/// # Panics
///
/// Panics if `out` is too short for the encoding; callers size it beforehand.
fn write_varu32(value: u32, out: &mut [u8]) -> usize {
    let mut remaining = value;
    let mut written = 0;

    // Emit continuation bytes while more than 7 bits remain.
    while remaining >= 0x80 {
        // `as u8` keeps the low 8 bits; OR-ing 0x80 sets bit 7, leaving the low 7 intact.
        out[written] = (remaining as u8) | 0x80;
        remaining >>= 7;
        written += 1;
    }

    // Final group fits in 7 bits, so its high (continuation) bit is clear.
    out[written] = remaining as u8;
    written + 1
}
235
/// Returns how many bytes the LEB128 varint encoding of `value` occupies (1..=5).
fn varu32_len(value: u32) -> usize {
    // Number of significant bits, treating 0 as needing one bit (it still takes a byte).
    let significant_bits = (32 - value.leading_zeros()).max(1);
    // Each byte carries 7 value bits: ceil(significant_bits / 7). The result is at most
    // 5, so the `as` conversion cannot truncate.
    ((significant_bits + 6) / 7) as usize
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a full-snapshot packet (schema hash 0, tick 1) whose payload is `payload`.
    fn full_snapshot_packet(payload: &[u8]) -> Vec<u8> {
        let payload_len = u32::try_from(payload.len()).expect("test payload fits in u32");
        let header = PacketHeader::full_snapshot(0, 1, payload_len);
        let mut buf = vec![0u8; HEADER_SIZE + payload.len()];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..].copy_from_slice(payload);
        buf
    }

    #[test]
    fn encode_header_roundtrip_empty_payload() {
        let header = PacketHeader::full_snapshot(0xABCD, 42, 0);
        let mut buf = [0u8; HEADER_SIZE];
        let written = encode_header(&header, &mut buf).unwrap();
        assert_eq!(written, HEADER_SIZE);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.header, header);
        assert!(packet.sections.is_empty());
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        buf[6..8].copy_from_slice(&PacketFlags::full_snapshot().raw().to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch() {
        let header = PacketHeader::full_snapshot(0, 1, 10);
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch_with_extra_bytes() {
        let header = PacketHeader::full_snapshot(0, 1, 0);
        let mut buf = vec![0u8; HEADER_SIZE + 4];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_full() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::full_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 1,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_delta() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::delta_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 0,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_flags_reserved_bits() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        let flags = PacketFlags::from_raw(0b101).raw(); // reserved bit set
        buf[6..8].copy_from_slice(&flags.to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn decode_rejects_invalid_varint_len() {
        let header = PacketHeader::full_snapshot(0, 1, 6);
        let mut buf = vec![0u8; HEADER_SIZE + 6];
        encode_header(&header, &mut buf).unwrap();
        let payload = &mut buf[HEADER_SIZE..];
        payload[0] = SectionTag::EntityCreate as u8;
        payload[1..6].copy_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF]);
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::InvalidVarint)
        ));
    }

    #[test]
    fn decode_sections() {
        let mut payload = [0u8; 16];
        let body = [1u8, 2, 3];
        let section_len = encode_section(SectionTag::EntityUpdate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.sections.len(), 1);
        assert_eq!(packet.sections[0].tag, SectionTag::EntityUpdate);
        assert_eq!(packet.sections[0].body, &body);
    }

    #[test]
    fn decode_enforces_section_limits() {
        let mut payload = [0u8; 8];
        let body = [0u8; 5];
        let section_len = encode_section(SectionTag::EntityCreate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits {
            max_packet_bytes: 4096,
            max_sections: 1,
            max_section_len: 4,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionLength,
                ..
            }
        ));
    }

    #[test]
    fn decode_rejects_packet_too_small() {
        let buf = vec![0u8; HEADER_SIZE - 1];
        let err = decode_packet(&buf, &Limits::for_testing()).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::PacketTooSmall { actual, required }
                if actual == HEADER_SIZE - 1 && required == HEADER_SIZE
        ));
    }

    #[test]
    fn decode_enforces_packet_byte_limit() {
        let buf = vec![0u8; HEADER_SIZE + 1];
        let limits = Limits {
            max_packet_bytes: HEADER_SIZE,
            max_sections: 1,
            max_section_len: 4,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::PacketBytes,
                limit,
                actual,
            } if limit == HEADER_SIZE && actual == HEADER_SIZE + 1
        ));
    }

    #[test]
    fn decode_rejects_unsupported_version() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.wrapping_add(1).to_le_bytes());
        buf[6..8].copy_from_slice(&PacketFlags::full_snapshot().raw().to_le_bytes());
        let err = decode_packet(&buf, &Limits::for_testing()).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::UnsupportedVersion { found } if found == VERSION.wrapping_add(1)
        ));
    }

    #[test]
    fn decode_rejects_unknown_section_tag() {
        // Tag 0xEE with an empty body.
        let buf = full_snapshot_packet(&[0xEE, 0x00]);
        let err = decode_packet(&buf, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::UnknownSectionTag { tag: 0xEE }));
    }

    #[test]
    fn decode_rejects_truncated_section_body() {
        // Declares a 3-byte body but only 1 byte follows the length prefix.
        let buf = full_snapshot_packet(&[SectionTag::EntityCreate as u8, 0x03, 0xAA]);
        let limits = Limits {
            max_packet_bytes: 4096,
            max_sections: 4,
            max_section_len: 16,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated {
                needed: 5,
                available: 3,
            })
        ));
    }

    #[test]
    fn decode_multiple_sections_in_order_and_enforces_count() {
        let mut payload = [0u8; 16];
        let first = encode_section(SectionTag::EntityCreate, &[7], &mut payload).unwrap();
        let second =
            encode_section(SectionTag::EntityDestroy, &[], &mut payload[first..]).unwrap();
        let buf = full_snapshot_packet(&payload[..first + second]);

        let mut limits = Limits {
            max_packet_bytes: 4096,
            max_sections: 2,
            max_section_len: 16,
        };
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.sections.len(), 2);
        assert_eq!(packet.sections[0].tag, SectionTag::EntityCreate);
        assert_eq!(packet.sections[0].body, &[7u8][..]);
        assert_eq!(packet.sections[1].tag, SectionTag::EntityDestroy);
        assert!(packet.sections[1].body.is_empty());

        limits.max_sections = 1;
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionCount,
                limit: 1,
                actual: 2,
            }
        ));
    }

    #[test]
    fn varint_roundtrip_boundaries() {
        let values = [
            0u32,
            1,
            127,
            128,
            16_383,
            16_384,
            2_097_151,
            2_097_152,
            268_435_455,
            268_435_456,
            u32::MAX,
        ];
        for &value in values.iter() {
            let mut out = [0u8; 5];
            let written = write_varu32(value, &mut out);
            assert_eq!(written, varu32_len(value), "length mismatch for {value}");
            let decoded = read_varu32(&out[..written], 0).unwrap();
            assert_eq!(decoded, (value, written), "roundtrip mismatch for {value}");
        }
    }

    #[test]
    fn read_varu32_reports_truncation() {
        let err = read_varu32(&[], 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated {
                needed: 1,
                available: 0,
            })
        ));

        // Continuation bit set on the only byte: the varint runs off the end.
        let err = read_varu32(&[0x80], 0).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated {
                needed: 2,
                available: 1,
            })
        ));
    }

    #[test]
    fn encode_section_rejects_small_buffer() {
        let mut out = [0u8; 3];
        let err = encode_section(SectionTag::EntityUpdate, &[1, 2, 3], &mut out).unwrap_err();
        assert!(matches!(
            err,
            EncodeError::BufferTooSmall {
                needed: 5,
                available: 3,
            }
        ));
    }
}