rs-netty 1.1.0

A Tokio-native typed TCP/UDP pipeline framework inspired by Netty.
Documentation
use bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::{
    codec::{Decoder, Encoder},
    Error, Result,
};

/// Endianness applied when a numeric length field is read or written.
///
/// The default is [`ByteOrder::BigEndian`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ByteOrder {
    /// The most significant byte comes first.
    #[default]
    BigEndian,
    /// The least significant byte comes first.
    LittleEndian,
}

/// Decoder for frames that carry their payload length in the byte stream.
///
/// This follows the same shape as Netty's length-field based framing: the
/// length field can appear after an offset, can be adjusted before calculating
/// the frame end, and leading bytes can be stripped from the returned frame.
///
/// Supported length field widths are 1, 2, 3, 4, and 8 bytes. The width is not
/// checked at construction time; an unsupported width is reported as an error
/// when a frame is decoded or encoded.
pub struct LengthFieldBasedFrameDecoder {
    /// Largest accepted frame end (length field end plus adjusted length);
    /// frames that would end beyond this are rejected while decoding.
    max_frame_length: usize,
    /// Number of bytes that precede the length field within a frame.
    length_field_offset: usize,
    /// Width of the length field in bytes.
    length_field_length: usize,
    /// Signed correction added to the decoded length before the frame end is
    /// computed.
    length_adjustment: isize,
    /// Number of leading bytes dropped from each complete frame before it is
    /// returned.
    initial_bytes_to_strip: usize,
    /// Byte order used to read and write the length field.
    byte_order: ByteOrder,
}

impl LengthFieldBasedFrameDecoder {
    /// Creates a length-field decoder with no adjustment.
    ///
    /// The returned frame strips the bytes up to and including the length
    /// field, so the decoded item is the payload by default.
    pub fn new(
        max_frame_length: usize,
        length_field_offset: usize,
        length_field_length: usize,
    ) -> Self {
        Self {
            max_frame_length,
            length_field_offset,
            length_field_length,
            length_adjustment: 0,
            initial_bytes_to_strip: length_field_offset + length_field_length,
            byte_order: ByteOrder::BigEndian,
        }
    }

    /// Builds a decoder with a caller-chosen length adjustment and strip count.
    ///
    /// While decoding, `length_adjustment` (which may be negative) is added to
    /// the value read from the length field before the frame end is computed.
    /// Once a complete frame is available, its first `initial_bytes_to_strip`
    /// bytes are removed and the remainder is returned. The length field is
    /// interpreted as big-endian.
    pub fn with_adjustment(
        max_frame_length: usize,
        length_field_offset: usize,
        length_field_length: usize,
        length_adjustment: isize,
        initial_bytes_to_strip: usize,
    ) -> Self {
        let byte_order = ByteOrder::BigEndian;

        Self {
            byte_order,
            initial_bytes_to_strip,
            length_adjustment,
            length_field_length,
            length_field_offset,
            max_frame_length,
        }
    }

    /// Sets the byte order used to read and write the length field.
    pub fn byte_order(mut self, byte_order: ByteOrder) -> Self {
        self.byte_order = byte_order;
        self
    }

    fn read_frame_length(&self, src: &BytesMut) -> Result<usize> {
        let start = self.length_field_offset;
        let end = start + self.length_field_length;
        let bytes = &src[start..end];

        let len = match (self.byte_order, self.length_field_length) {
            (ByteOrder::BigEndian, 1) | (ByteOrder::LittleEndian, 1) => bytes[0] as u64,
            (ByteOrder::BigEndian, 2) => u16::from_be_bytes([bytes[0], bytes[1]]) as u64,
            (ByteOrder::LittleEndian, 2) => u16::from_le_bytes([bytes[0], bytes[1]]) as u64,
            (ByteOrder::BigEndian, 3) => {
                ((bytes[0] as u64) << 16) | ((bytes[1] as u64) << 8) | bytes[2] as u64
            }
            (ByteOrder::LittleEndian, 3) => {
                ((bytes[2] as u64) << 16) | ((bytes[1] as u64) << 8) | bytes[0] as u64
            }
            (ByteOrder::BigEndian, 4) => {
                u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u64
            }
            (ByteOrder::LittleEndian, 4) => {
                u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u64
            }
            (ByteOrder::BigEndian, 8) => u64::from_be_bytes([
                bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
            ]),
            (ByteOrder::LittleEndian, 8) => u64::from_le_bytes([
                bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
            ]),
            _ => {
                return Err(Error::Decode(format!(
                    "unsupported length field length: {}",
                    self.length_field_length
                )))
            }
        };

        usize::try_from(len).map_err(|err| Error::Decode(err.to_string()))
    }

    fn write_frame_length(&self, item_len: usize, dst: &mut BytesMut) -> Result<()> {
        if self.length_field_offset != 0 || self.length_adjustment != 0 {
            return Err(Error::Encode(
                "LengthFieldBasedFrameDecoder encoder supports zero offset and zero adjustment only".to_string(),
            ));
        }

        write_length(item_len, self.length_field_length, self.byte_order, dst)
    }
}

impl Default for LengthFieldBasedFrameDecoder {
    /// Builds a decoder for frames of at most 8 MiB that start with a 4-byte
    /// big-endian length field at offset zero.
    fn default() -> Self {
        const MAX_FRAME_LENGTH: usize = 8 * 1024 * 1024;
        const LENGTH_FIELD_OFFSET: usize = 0;
        const LENGTH_FIELD_LENGTH: usize = 4;

        Self::new(MAX_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH)
    }
}

impl Decoder for LengthFieldBasedFrameDecoder {
    type Item = Bytes;

    #[inline]
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
        let field_end = self.length_field_offset + self.length_field_length;
        if src.len() < field_end {
            return Ok(None);
        }

        let frame_length = self.read_frame_length(src)?;
        let adjusted = frame_length
            .checked_add_signed(self.length_adjustment)
            .ok_or_else(|| Error::Decode("negative adjusted frame length".to_string()))?;
        let frame_end = field_end
            .checked_add(adjusted)
            .ok_or_else(|| Error::Decode("frame length overflow".to_string()))?;

        if frame_end > self.max_frame_length {
            return Err(Error::FrameTooLarge {
                current: frame_end,
                max: self.max_frame_length,
            });
        }

        if src.len() < frame_end {
            return Ok(None);
        }

        if self.initial_bytes_to_strip > frame_end {
            return Err(Error::Decode(format!(
                "initial_bytes_to_strip={} exceeds frame length={frame_end}",
                self.initial_bytes_to_strip
            )));
        }

        let mut frame = src.split_to(frame_end);
        frame.advance(self.initial_bytes_to_strip);
        Ok(Some(frame.freeze()))
    }
}

impl Encoder<Bytes> for LengthFieldBasedFrameDecoder {
    /// Appends `item` to `dst`, preceded by a length field holding its size.
    ///
    /// If the length field cannot be written, the error is returned and
    /// the payload is not appended.
    #[inline]
    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<()> {
        let header = self.write_frame_length(item.len(), dst);
        if header.is_ok() {
            dst.extend_from_slice(item.as_ref());
        }
        header
    }
}

/// Outbound stage that prepends a binary length field.
///
/// Supported length field widths are 1, 2, 3, 4, and 8 bytes.
pub struct LengthFieldPrepender {
    /// Width of the prepended length field in bytes.
    length_field_length: usize,
    /// Byte order used to write the length field.
    byte_order: ByteOrder,
}

impl LengthFieldPrepender {
    /// Builds a prepender that writes a `length_field_length`-byte length
    /// field in big-endian order.
    ///
    /// The width is not validated here; an unsupported width is reported as
    /// an error when a frame is encoded.
    pub fn new(length_field_length: usize) -> Self {
        let byte_order = ByteOrder::BigEndian;

        Self {
            byte_order,
            length_field_length,
        }
    }

    /// Returns the prepender reconfigured to write its length field in
    /// `byte_order`.
    pub fn byte_order(self, byte_order: ByteOrder) -> Self {
        Self { byte_order, ..self }
    }
}

impl Encoder<Bytes> for LengthFieldPrepender {
    /// Appends `item` to `dst`, preceded by a length field holding its size.
    ///
    /// If the length cannot be encoded, the error is returned and the payload
    /// is not appended.
    #[inline]
    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<()> {
        let Self {
            length_field_length,
            byte_order,
        } = *self;

        write_length(item.len(), length_field_length, byte_order, dst)?;
        dst.extend_from_slice(item.as_ref());
        Ok(())
    }
}

fn write_length(
    len: usize,
    length_field_length: usize,
    byte_order: ByteOrder,
    dst: &mut BytesMut,
) -> Result<()> {
    match (byte_order, length_field_length) {
        (ByteOrder::BigEndian, 1) | (ByteOrder::LittleEndian, 1) => {
            let len = u8::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u8(len);
        }
        (ByteOrder::BigEndian, 2) => {
            let len = u16::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u16(len);
        }
        (ByteOrder::LittleEndian, 2) => {
            let len = u16::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u16_le(len);
        }
        (ByteOrder::BigEndian, 3) => {
            if len > 0x00ff_ffff {
                return Err(Error::Encode(format!("length {len} exceeds 24-bit field")));
            }
            dst.put_u8(((len >> 16) & 0xff) as u8);
            dst.put_u8(((len >> 8) & 0xff) as u8);
            dst.put_u8((len & 0xff) as u8);
        }
        (ByteOrder::LittleEndian, 3) => {
            if len > 0x00ff_ffff {
                return Err(Error::Encode(format!("length {len} exceeds 24-bit field")));
            }
            dst.put_u8((len & 0xff) as u8);
            dst.put_u8(((len >> 8) & 0xff) as u8);
            dst.put_u8(((len >> 16) & 0xff) as u8);
        }
        (ByteOrder::BigEndian, 4) => {
            let len = u32::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u32(len);
        }
        (ByteOrder::LittleEndian, 4) => {
            let len = u32::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u32_le(len);
        }
        (ByteOrder::BigEndian, 8) => {
            let len = u64::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u64(len);
        }
        (ByteOrder::LittleEndian, 8) => {
            let len = u64::try_from(len).map_err(|err| Error::Encode(err.to_string()))?;
            dst.put_u64_le(len);
        }
        _ => {
            return Err(Error::Encode(format!(
                "unsupported length field length: {length_field_length}"
            )))
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one payload through the default codec: the encoder must
    /// emit a 4-byte big-endian length prefix, and the decoder must return
    /// the bare payload and consume the whole buffer.
    #[test]
    fn decodes_and_encodes_length_prefixed_frame() {
        let payload = Bytes::from_static(b"ping");
        let mut wire = BytesMut::new();
        let mut codec = LengthFieldBasedFrameDecoder::default();

        codec.encode(payload.clone(), &mut wire).expect("encode");
        assert_eq!(&wire[..], b"\0\0\0\x04ping");

        let decoded = codec.decode(&mut wire).expect("decode").expect("frame");
        assert_eq!(decoded, payload);
        assert!(wire.is_empty());
    }
}