openwire 0.1.1

OkHttp-inspired async HTTP client for Rust built on hyper and tower
Documentation
use bytes::{Buf, BufMut, BytesMut};

use openwire_core::websocket::WebSocketEngineError;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum Opcode {
    Continuation = 0x0,
    Text = 0x1,
    Binary = 0x2,
    Close = 0x8,
    Ping = 0x9,
    Pong = 0xA,
}

impl Opcode {
    pub(crate) fn from_byte(b: u8) -> Option<Self> {
        match b & 0x0F {
            0x0 => Some(Self::Continuation),
            0x1 => Some(Self::Text),
            0x2 => Some(Self::Binary),
            0x8 => Some(Self::Close),
            0x9 => Some(Self::Ping),
            0xA => Some(Self::Pong),
            _ => None,
        }
    }

    pub(crate) fn is_control(self) -> bool {
        matches!(self, Self::Close | Self::Ping | Self::Pong)
    }
}

#[derive(Clone, Copy, Debug)]
pub(crate) struct FrameHeader {
    pub fin: bool,
    pub opcode: Opcode,
}

#[derive(Debug)]
pub(crate) struct DecodedFrame {
    pub fin: bool,
    pub opcode: Opcode,
    pub payload: bytes::Bytes,
}

pub(crate) fn encode_frame(
    out: &mut BytesMut,
    header: FrameHeader,
    payload: &[u8],
    mask: Option<[u8; 4]>,
) {
    let b0 = if header.fin { 0x80 } else { 0x00 } | (header.opcode as u8);
    out.put_u8(b0);

    let mask_bit = if mask.is_some() { 0x80 } else { 0x00 };
    let len = payload.len();
    if len <= 125 {
        out.put_u8(mask_bit | len as u8);
    } else if len <= u16::MAX as usize {
        out.put_u8(mask_bit | 126);
        out.put_u16(len as u16);
    } else {
        out.put_u8(mask_bit | 127);
        out.put_u64(len as u64);
    }

    if let Some(key) = mask {
        out.put_slice(&key);
        let start = out.len();
        out.put_slice(payload);
        super::mask::mask_in_place(&mut out[start..], key);
    } else {
        out.put_slice(payload);
    }
}

pub(crate) fn decode_frame(
    buf: &mut BytesMut,
    max_frame_size: usize,
) -> Result<Option<DecodedFrame>, WebSocketEngineError> {
    if buf.len() < 2 {
        return Ok(None);
    }

    let b0 = buf[0];
    let b1 = buf[1];
    let fin = (b0 & 0x80) != 0;
    let rsv = b0 & 0x70;
    if rsv != 0 {
        return Err(WebSocketEngineError::InvalidFrame(
            "reserved bits set".into(),
        ));
    }
    let opcode = Opcode::from_byte(b0)
        .ok_or_else(|| WebSocketEngineError::InvalidFrame("unknown opcode".into()))?;

    let mask_bit = (b1 & 0x80) != 0;
    if mask_bit {
        return Err(WebSocketEngineError::InvalidFrame(
            "server frames must not be masked".into(),
        ));
    }

    if opcode.is_control() && !fin {
        return Err(WebSocketEngineError::InvalidFrame(
            "fragmented control frame".into(),
        ));
    }

    let len_field = b1 & 0x7F;
    let (header_len, payload_len) = match len_field {
        0..=125 => (2usize, len_field as usize),
        126 => {
            if buf.len() < 4 {
                return Ok(None);
            }
            let payload_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
            if payload_len <= 125 {
                return Err(WebSocketEngineError::InvalidFrame(
                    "non-minimal payload length encoding".into(),
                ));
            }
            (4, payload_len)
        }
        127 => {
            if buf.len() < 10 {
                return Ok(None);
            }
            let mut v = [0u8; 8];
            v.copy_from_slice(&buf[2..10]);
            let payload_len = u64::from_be_bytes(v);
            if payload_len & (1 << 63) != 0 {
                return Err(WebSocketEngineError::InvalidFrame(
                    "invalid 64-bit payload length".into(),
                ));
            }
            if payload_len <= u16::MAX as u64 {
                return Err(WebSocketEngineError::InvalidFrame(
                    "non-minimal payload length encoding".into(),
                ));
            }
            let payload_len = usize::try_from(payload_len).map_err(|_| {
                WebSocketEngineError::PayloadTooLarge {
                    limit: max_frame_size,
                    received: usize::MAX,
                }
            })?;
            (10, payload_len)
        }
        _ => unreachable!("len_field is at most 7 bits"),
    };

    if opcode.is_control() && payload_len > 125 {
        return Err(WebSocketEngineError::InvalidFrame(
            "control frame > 125 bytes".into(),
        ));
    }
    if payload_len > max_frame_size {
        return Err(WebSocketEngineError::PayloadTooLarge {
            limit: max_frame_size,
            received: payload_len,
        });
    }

    if buf.len() < header_len + payload_len {
        return Ok(None);
    }
    buf.advance(header_len);
    let payload = buf.split_to(payload_len).freeze();

    Ok(Some(DecodedFrame {
        fin,
        opcode,
        payload,
    }))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn opcode_round_trip() {
        for op in [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ] {
            assert_eq!(Opcode::from_byte(op as u8), Some(op));
        }
    }
}

#[cfg(test)]
mod encode_tests {
    use super::*;
    use bytes::BytesMut;

    fn encode(opcode: Opcode, payload: &[u8], fin: bool, mask_key: [u8; 4]) -> Vec<u8> {
        let mut out = BytesMut::new();
        encode_frame(
            &mut out,
            FrameHeader { fin, opcode },
            payload,
            Some(mask_key),
        );
        out.to_vec()
    }

    #[test]
    fn encodes_short_text_frame() {
        let bytes = encode(Opcode::Text, b"hello", true, [1, 2, 3, 4]);
        assert_eq!(bytes[0], 0x81);
        assert_eq!(bytes[1], 0x80 | 5);
        assert_eq!(&bytes[2..6], &[1, 2, 3, 4]);

        let mut expected = b"hello".to_vec();
        super::super::mask::mask_in_place(&mut expected, [1, 2, 3, 4]);
        assert_eq!(&bytes[6..], &expected[..]);
    }

    #[test]
    fn uses_16bit_length_for_payload_126() {
        let payload = vec![0u8; 126];
        let bytes = encode(Opcode::Binary, &payload, true, [0; 4]);
        assert_eq!(bytes[1], 0x80 | 126);
        assert_eq!(&bytes[2..4], &[0x00, 0x7E]);
    }

    #[test]
    fn uses_64bit_length_for_payload_65536() {
        let payload = vec![0u8; 65536];
        let bytes = encode(Opcode::Binary, &payload, true, [0; 4]);
        assert_eq!(bytes[1], 0x80 | 127);
        assert_eq!(&bytes[2..10], &[0, 0, 0, 0, 0, 1, 0, 0]);
    }

    #[test]
    fn unmasked_encode_omits_mask_key() {
        let mut out = BytesMut::new();
        encode_frame(
            &mut out,
            FrameHeader {
                fin: true,
                opcode: Opcode::Text,
            },
            b"hi",
            None,
        );
        assert_eq!(&out[..], &[0x81, 0x02, b'h', b'i']);
    }
}

#[cfg(test)]
mod decode_tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn decodes_short_text_frame() {
        let mut buf = BytesMut::from(&[0x81u8, 0x05, b'h', b'e', b'l', b'l', b'o'][..]);
        let frame = decode_frame(&mut buf, 1024).unwrap().expect("frame ready");
        assert!(frame.fin);
        assert_eq!(frame.opcode, Opcode::Text);
        assert_eq!(frame.payload.as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn returns_none_when_partial() {
        let mut buf = BytesMut::from(&[0x81u8][..]);
        assert!(decode_frame(&mut buf, 1024).unwrap().is_none());
    }

    #[test]
    fn rejects_masked_server_frame() {
        let mut buf = BytesMut::from(&[0x81u8, 0x85, 1, 2, 3, 4, b'h', b'e', b'l', b'l', b'o'][..]);
        let err = decode_frame(&mut buf, 1024).unwrap_err();
        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn enforces_max_frame_size() {
        let mut buf = BytesMut::from(&[0x82u8, 126, 0, 200][..]);
        buf.extend_from_slice(&[0u8; 200]);
        let err = decode_frame(&mut buf, 100).unwrap_err();
        assert!(matches!(err, WebSocketEngineError::PayloadTooLarge { .. }));
    }

    #[test]
    fn decodes_16bit_length_for_payload_126() {
        let mut buf = BytesMut::from(&[0x82u8, 126, 0, 126][..]);
        buf.extend_from_slice(&[0u8; 126]);

        let frame = decode_frame(&mut buf, 1024).unwrap().expect("frame ready");

        assert_eq!(frame.opcode, Opcode::Binary);
        assert_eq!(frame.payload.len(), 126);
        assert!(buf.is_empty());
    }

    #[test]
    fn decodes_64bit_length_for_payload_65536() {
        let mut buf = BytesMut::from(&[0x82u8, 127, 0, 0, 0, 0, 0, 1, 0, 0][..]);
        buf.extend_from_slice(&[0u8; 65_536]);

        let frame = decode_frame(&mut buf, 65_536)
            .unwrap()
            .expect("frame ready");

        assert_eq!(frame.opcode, Opcode::Binary);
        assert_eq!(frame.payload.len(), 65_536);
        assert!(buf.is_empty());
    }

    #[test]
    fn rejects_16bit_length_for_short_payloads() {
        let mut buf = BytesMut::from(&[0x81u8, 126, 0, 125][..]);
        buf.extend_from_slice(&[0u8; 125]);

        let err = decode_frame(&mut buf, 1024).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn rejects_64bit_length_for_16bit_payloads() {
        let mut buf = BytesMut::from(&[0x82u8, 127, 0, 0, 0, 0, 0, 0, 0, 126][..]);
        buf.extend_from_slice(&[0u8; 126]);

        let err = decode_frame(&mut buf, 1024).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn rejects_64bit_length_with_high_bit_set() {
        let mut buf = BytesMut::from(&[0x82u8, 127, 0x80, 0, 0, 0, 0, 0, 0, 0][..]);

        let err = decode_frame(&mut buf, usize::MAX).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn rejects_fragmented_control_frame() {
        let mut buf = BytesMut::from(&[0x09u8, 0x00][..]);
        let err = decode_frame(&mut buf, 1024).unwrap_err();
        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn rejects_oversized_control_frame() {
        let mut buf = BytesMut::from(&[0x89u8, 126, 0, 200][..]);
        buf.extend_from_slice(&[0u8; 200]);
        let err = decode_frame(&mut buf, 1024).unwrap_err();
        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    #[test]
    fn rejects_reserved_bits() {
        let mut buf = BytesMut::from(&[0x91u8, 0x00][..]);
        let err = decode_frame(&mut buf, 1024).unwrap_err();
        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }
}