raknet-rust 0.2.0

Asynchronous, high-performance RakNet transport library for Rust.
Documentation
use bytes::{Buf, BufMut};

use crate::error::{DecodeError, EncodeError};

use super::ack::AckNackPayload;
use super::codec::RaknetCodec;
use super::constants::{DatagramFlags, RAKNET_DATAGRAM_HEADER_SIZE};
use super::frame::Frame;
use super::sequence24::Sequence24;

/// Classification of a datagram, as produced by `decode_datagram_flags`
/// from the flag byte and by `DatagramPayload::kind` from a payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DatagramKind {
    /// Neither the ACK nor the NACK flag is set.
    Data,
    /// The ACK flag is set (and NACK is not).
    Ack,
    /// The NACK flag is set (and ACK is not).
    Nack,
}

/// Body of a [`Datagram`]; the variant must agree with the header flags
/// (checked by `validate_flags_for_payload` when encoding).
#[derive(Debug, Clone)]
pub enum DatagramPayload {
    /// Zero or more frames carried by a data datagram.
    Frames(Vec<Frame>),
    /// Payload of a datagram whose ACK flag is set.
    Ack(AckNackPayload),
    /// Payload of a datagram whose NACK flag is set.
    Nack(AckNackPayload),
}

impl DatagramPayload {
    fn kind(&self) -> DatagramKind {
        match self {
            Self::Frames(_) => DatagramKind::Data,
            Self::Ack(_) => DatagramKind::Ack,
            Self::Nack(_) => DatagramKind::Nack,
        }
    }
}

/// Leading fields of a datagram: the flag byte followed by a sequence number.
#[derive(Debug, Clone)]
pub struct DatagramHeader {
    /// Flag bits; encoded as a single byte via `bits()`.
    pub flags: DatagramFlags,
    /// Datagram sequence number. For ACK/NACK datagrams it is not written by
    /// `Datagram::encode`, and `Datagram::decode` sets it to `Sequence24::new(0)`.
    pub sequence: Sequence24,
}

impl RaknetCodec for DatagramHeader {
    /// Writes the flag byte followed by the sequence number.
    ///
    /// # Errors
    /// Propagates any error returned while encoding either field.
    fn encode_raknet(&self, dst: &mut impl BufMut) -> Result<(), EncodeError> {
        let flag_byte = self.flags.bits();
        flag_byte.encode_raknet(dst)?;
        self.sequence.encode_raknet(dst)?;
        Ok(())
    }

    /// Reads a data-datagram header (flag byte, then sequence number).
    ///
    /// # Errors
    /// * `DecodeError::UnexpectedEof` when fewer than 4 bytes remain.
    /// * `DecodeError::InvalidDatagramFlags` when the flag byte is rejected by
    ///   `decode_datagram_flags` or describes an ACK/NACK datagram.
    /// * Any error produced while decoding the individual fields.
    fn decode_raknet(src: &mut impl Buf) -> Result<Self, DecodeError> {
        // Guard up front so a short buffer is reported before anything is consumed.
        if src.remaining() < 4 {
            return Err(DecodeError::UnexpectedEof);
        }

        let raw_flags = u8::decode_raknet(src)?;
        match decode_datagram_flags(raw_flags)? {
            (flags, DatagramKind::Data) => {
                let sequence = Sequence24::decode_raknet(src)?;
                Ok(Self { flags, sequence })
            }
            // Control datagrams are not decoded through this header path.
            (_, DatagramKind::Ack | DatagramKind::Nack) => {
                Err(DecodeError::InvalidDatagramFlags(raw_flags))
            }
        }
    }
}

/// A complete datagram: header plus either frames or an ACK/NACK payload.
#[derive(Debug, Clone)]
pub struct Datagram {
    /// Flags and sequence number; see [`DatagramHeader`].
    pub header: DatagramHeader,
    /// Body of the datagram; its variant must match the header flags for
    /// `encode` to succeed.
    pub payload: DatagramPayload,
}

impl Datagram {
    pub fn encoded_size(&self) -> usize {
        match &self.payload {
            DatagramPayload::Frames(frames) => {
                RAKNET_DATAGRAM_HEADER_SIZE + frames.iter().map(Frame::encoded_size).sum::<usize>()
            }
            DatagramPayload::Ack(payload) | DatagramPayload::Nack(payload) => {
                1 + payload.encoded_size()
            }
        }
    }

    pub fn encode(&self, dst: &mut impl BufMut) -> Result<(), EncodeError> {
        validate_flags_for_payload(self.header.flags, &self.payload)?;

        match &self.payload {
            DatagramPayload::Frames(frames) => {
                self.header.encode_raknet(dst)?;
                for frame in frames {
                    frame.encode_raknet(dst)?;
                }
            }
            DatagramPayload::Ack(payload) | DatagramPayload::Nack(payload) => {
                self.header.flags.bits().encode_raknet(dst)?;
                payload.encode_raknet(dst)?;
            }
        }
        Ok(())
    }

    pub fn decode(src: &mut impl Buf) -> Result<Self, DecodeError> {
        if !src.has_remaining() {
            return Err(DecodeError::UnexpectedEof);
        }

        let raw_flags = src.get_u8();
        let (flags, kind) = decode_datagram_flags(raw_flags)?;

        match kind {
            DatagramKind::Ack => Ok(Self {
                header: DatagramHeader {
                    flags,
                    sequence: Sequence24::new(0),
                },
                payload: DatagramPayload::Ack(AckNackPayload::decode_raknet(src)?),
            }),
            DatagramKind::Nack => Ok(Self {
                header: DatagramHeader {
                    flags,
                    sequence: Sequence24::new(0),
                },
                payload: DatagramPayload::Nack(AckNackPayload::decode_raknet(src)?),
            }),
            DatagramKind::Data => {
                let sequence = Sequence24::decode_raknet(src)?;
                let header = DatagramHeader { flags, sequence };

                let mut frames = Vec::new();
                while src.has_remaining() {
                    frames.push(Frame::decode_raknet(src)?);
                }

                Ok(Self {
                    header,
                    payload: DatagramPayload::Frames(frames),
                })
            }
        }
    }
}

/// Validates a raw flag byte and classifies the datagram it introduces.
///
/// Returns the parsed flags together with the datagram kind.
///
/// # Errors
/// `DecodeError::InvalidDatagramFlags(raw_flags)` when:
/// * `DatagramFlags::from_bits` rejects the byte,
/// * the `VALID` flag is missing,
/// * both `ACK` and `NACK` are set, or
/// * an ACK/NACK byte also sets `PACKET_PAIR`, `CONTINUOUS_SEND`, or
///   `HAS_B_AND_AS`.
fn decode_datagram_flags(raw_flags: u8) -> Result<(DatagramFlags, DatagramKind), DecodeError> {
    let reject = || DecodeError::InvalidDatagramFlags(raw_flags);

    let flags = DatagramFlags::from_bits(raw_flags).ok_or_else(reject)?;
    if !flags.contains(DatagramFlags::VALID) {
        return Err(reject());
    }

    let ack_set = flags.contains(DatagramFlags::ACK);
    let nack_set = flags.contains(DatagramFlags::NACK);
    let control_kind = match (ack_set, nack_set) {
        (true, true) => return Err(reject()),
        (false, false) => return Ok((flags, DatagramKind::Data)),
        (true, false) => DatagramKind::Ack,
        (false, true) => DatagramKind::Nack,
    };

    // These bits are only accepted on data datagrams.
    let data_only_bits =
        DatagramFlags::PACKET_PAIR | DatagramFlags::CONTINUOUS_SEND | DatagramFlags::HAS_B_AND_AS;
    if flags.intersects(data_only_bits) {
        Err(reject())
    } else {
        Ok((flags, control_kind))
    }
}

/// Checks that `flags` are valid and describe the same datagram kind as
/// `payload`.
///
/// # Errors
/// `EncodeError::InvalidDatagramFlags` carrying the raw flag byte when the
/// flags are rejected by `decode_datagram_flags` or when the kind they
/// describe differs from the payload's kind.
fn validate_flags_for_payload(
    flags: DatagramFlags,
    payload: &DatagramPayload,
) -> Result<(), EncodeError> {
    let raw_flags = flags.bits();

    match decode_datagram_flags(raw_flags) {
        Ok((_, flag_kind)) if flag_kind == payload.kind() => Ok(()),
        // Either the flags are invalid or they describe a different kind.
        _ => Err(EncodeError::InvalidDatagramFlags(raw_flags)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ack::SequenceRange;

    /// Builds an ACK/NACK payload holding a single range from 7 to 7.
    fn sample_ack_payload() -> AckNackPayload {
        let only_range = SequenceRange {
            start: Sequence24::new(7),
            end: Sequence24::new(7),
        };
        AckNackPayload {
            ranges: vec![only_range],
        }
    }

    /// Decodes `bytes` as a datagram and returns the error, panicking with
    /// `why` if decoding unexpectedly succeeds.
    fn decode_failure(bytes: &[u8], why: &str) -> DecodeError {
        let mut cursor = bytes;
        Datagram::decode(&mut cursor).expect_err(why)
    }

    #[test]
    fn decode_rejects_unknown_datagram_bits() {
        let err = decode_failure(b"\x83\0\0\0", "unknown bit must be rejected");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0x83)));
    }

    #[test]
    fn decode_rejects_ack_without_valid_flag() {
        let err = decode_failure(b"\x40\0\0", "ack without valid bit must be rejected");
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0x40)));
    }

    #[test]
    fn decode_rejects_control_with_data_only_bits() {
        let err = decode_failure(
            b"\xC8\0\0",
            "control datagram must not carry data-only bits",
        );
        assert!(matches!(err, DecodeError::InvalidDatagramFlags(0xC8)));
    }

    #[test]
    fn decode_accepts_valid_ack_flags() {
        let payload = sample_ack_payload();

        // Flag byte first, then the encoded ranges appended behind it.
        let mut wire: Vec<u8> = vec![0xC0];
        payload
            .encode_raknet(&mut wire)
            .expect("ack payload encode");

        let mut cursor = wire.as_slice();
        let decoded = Datagram::decode(&mut cursor).expect("ack datagram should decode");

        match decoded.payload {
            DatagramPayload::Ack(decoded_payload) => assert_eq!(decoded_payload, payload),
            other => panic!("unexpected payload: {other:?}"),
        }
    }

    #[test]
    fn encode_rejects_payload_and_flag_mismatch() {
        // ACK flags paired with a frame payload must not be encodable.
        let header = DatagramHeader {
            flags: DatagramFlags::VALID | DatagramFlags::ACK,
            sequence: Sequence24::new(12),
        };
        let mismatched = Datagram {
            header,
            payload: DatagramPayload::Frames(Vec::new()),
        };

        let mut sink: Vec<u8> = Vec::new();
        let err = mismatched
            .encode(&mut sink)
            .expect_err("invalid payload/flags mismatch must fail");
        assert!(matches!(err, EncodeError::InvalidDatagramFlags(0xC0)));
    }
}