openwire 0.1.0

OkHttp-inspired async HTTP client for Rust built on hyper and tower
Documentation
use bytes::{Bytes, BytesMut};

use openwire_core::websocket::{close_code_is_valid, EngineFrame, WebSocketEngineError};

use super::codec::{DecodedFrame, Opcode};

/// Rebuilds complete messages from a stream of decoded frames.
///
/// Frames are handed in one at a time through `feed`; fragments of a message
/// are buffered here until the frame with `fin` set arrives.
pub(crate) struct ReassemblyState {
    /// Largest payload size, in bytes, that is accepted; anything bigger is
    /// rejected with `PayloadTooLarge`.
    max_message_size: usize,
    /// The fragmented message currently being collected, or `None` when no
    /// fragmented message is pending.
    in_progress: Option<InProgress>,
}

/// A fragmented message whose final frame has not been received yet.
struct InProgress {
    /// Opcode of the first fragment; it selects how the completed payload is
    /// delivered once the last fragment arrives.
    opcode: Opcode,
    /// Payload bytes accumulated from the fragments seen so far.
    buffer: BytesMut,
}

impl ReassemblyState {
    /// Creates a reassembler with no message in flight.
    ///
    /// `max_message_size` is the largest payload size, in bytes, that will be
    /// accepted for a message.
    pub(crate) fn new(max_message_size: usize) -> Self {
        // A fresh reassembler has nothing buffered.
        let in_progress = None;
        Self {
            max_message_size,
            in_progress,
        }
    }

    /// Feeds one decoded frame into the reassembler.
    ///
    /// Returns `Ok(Some(_))` when the frame produces something deliverable (a
    /// control frame, an unfragmented data frame, or the last fragment of a
    /// message) and `Ok(None)` when the frame was only buffered.
    ///
    /// Control frames are parsed before the reassembly state is consulted, so
    /// they are accepted even while a fragmented message is pending.
    ///
    /// # Errors
    ///
    /// * `InvalidFrame` for a continuation frame with nothing to continue, or
    ///   for a new text/binary frame arriving while a fragmented message is
    ///   still being collected.
    /// * `PayloadTooLarge` when the first fragment alone exceeds
    ///   `max_message_size`.
    /// * Any error reported by `parse_control`, `append_continuation` or
    ///   `deliver_single`.
    //
    // NOTE(review): `fin` and the payload length of control frames are not
    // inspected here — presumably the codec enforces those rules; confirm.
    pub(crate) fn feed(
        &mut self,
        frame: DecodedFrame,
    ) -> Result<Option<EngineFrame>, WebSocketEngineError> {
        if frame.opcode.is_control() {
            return self.parse_control(frame).map(Some);
        }

        let opcode = frame.opcode;
        let reassembling = self.in_progress.is_some();
        match opcode {
            Opcode::Continuation if reassembling => self.append_continuation(frame),
            Opcode::Continuation => Err(WebSocketEngineError::InvalidFrame(
                "orphan continuation frame".into(),
            )),
            // A new data frame may not start until the pending message is finished.
            Opcode::Text | Opcode::Binary if reassembling => {
                Err(WebSocketEngineError::InvalidFrame(
                    "nested data frame while reassembling".into(),
                ))
            }
            // Unfragmented message: nothing to buffer, deliver right away.
            Opcode::Text | Opcode::Binary if frame.fin => {
                self.deliver_single(opcode, frame.payload)
            }
            // First fragment: validate its size, then start buffering.
            Opcode::Text | Opcode::Binary => {
                Self::check_message_size(self.max_message_size, frame.payload.len())?;
                let buffer = BytesMut::from(&frame.payload[..]);
                self.in_progress = Some(InProgress { opcode, buffer });
                Ok(None)
            }
            Opcode::Close | Opcode::Ping | Opcode::Pong => {
                unreachable!("control frames handled above")
            }
        }
    }

    /// Checks that no fragmented message is left unfinished.
    ///
    /// # Errors
    ///
    /// Returns `InvalidFrame` when a fragmented message is still being
    /// collected.
    pub(crate) fn end_of_stream(&self) -> Result<(), WebSocketEngineError> {
        match self.in_progress {
            None => Ok(()),
            Some(_) => Err(WebSocketEngineError::InvalidFrame(
                "incomplete fragmented message".into(),
            )),
        }
    }

    /// Appends a continuation frame to the message being collected.
    ///
    /// Returns `Ok(None)` while more fragments are expected, or the finished
    /// message once a frame with `fin` set has been appended.
    ///
    /// # Errors
    ///
    /// Returns `PayloadTooLarge` when the combined length would exceed
    /// `max_message_size`; the size is checked before any bytes are copied.
    /// If the combined length overflows `usize`, `received` is reported as
    /// `usize::MAX`. Errors from `deliver_single` are passed through.
    ///
    /// # Panics
    ///
    /// Panics if no message is in progress; the caller must check first.
    fn append_continuation(
        &mut self,
        frame: DecodedFrame,
    ) -> Result<Option<EngineFrame>, WebSocketEngineError> {
        let limit = self.max_message_size;
        let pending = self
            .in_progress
            .as_mut()
            .expect("checked by caller before calling append");

        match pending.buffer.len().checked_add(frame.payload.len()) {
            Some(total) => Self::check_message_size(limit, total)?,
            None => {
                return Err(WebSocketEngineError::PayloadTooLarge {
                    limit,
                    received: usize::MAX,
                });
            }
        }
        pending.buffer.extend_from_slice(&frame.payload);

        if !frame.fin {
            return Ok(None);
        }

        // Last fragment: take the collected bytes and reset the state before
        // delivering, so a delivery error does not leave a stale message behind.
        let opcode = pending.opcode;
        let assembled = std::mem::take(&mut pending.buffer).freeze();
        self.in_progress = None;
        self.deliver_single(opcode, assembled)
    }

    /// Verifies that `received` bytes fit within `max_message_size`.
    ///
    /// A size equal to the limit is accepted.
    ///
    /// # Errors
    ///
    /// Returns `PayloadTooLarge` carrying both values when `received` is
    /// greater than the limit.
    fn check_message_size(
        max_message_size: usize,
        received: usize,
    ) -> Result<(), WebSocketEngineError> {
        if received <= max_message_size {
            return Ok(());
        }
        Err(WebSocketEngineError::PayloadTooLarge {
            limit: max_message_size,
            received,
        })
    }

    /// Turns a complete text or binary payload into an `EngineFrame`.
    ///
    /// # Errors
    ///
    /// * `PayloadTooLarge` when the payload exceeds `max_message_size`.
    /// * `InvalidUtf8` when a text payload is not valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics when `opcode` is neither `Text` nor `Binary`.
    fn deliver_single(
        &self,
        opcode: Opcode,
        payload: Bytes,
    ) -> Result<Option<EngineFrame>, WebSocketEngineError> {
        Self::check_message_size(self.max_message_size, payload.len())?;

        let message = match opcode {
            Opcode::Binary => EngineFrame::Binary(payload),
            Opcode::Text => match std::str::from_utf8(&payload) {
                Ok(text) => EngineFrame::Text(text.to_owned()),
                Err(_) => return Err(WebSocketEngineError::InvalidUtf8),
            },
            _ => unreachable!("deliver_single only handles Text/Binary"),
        };
        Ok(Some(message))
    }

    fn parse_control(&self, frame: DecodedFrame) -> Result<EngineFrame, WebSocketEngineError> {
        match frame.opcode {
            Opcode::Ping => Ok(EngineFrame::Ping(frame.payload)),
            Opcode::Pong => Ok(EngineFrame::Pong(frame.payload)),
            Opcode::Close => {
                if frame.payload.is_empty() {
                    Ok(EngineFrame::Close {
                        code: 1005,
                        reason: String::new(),
                    })
                } else if frame.payload.len() == 1 {
                    Err(WebSocketEngineError::InvalidFrame(
                        "close payload of length 1".into(),
                    ))
                } else {
                    let code = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);
                    if !close_code_is_valid(code) {
                        return Err(WebSocketEngineError::InvalidCloseCode(code));
                    }
                    let reason = std::str::from_utf8(&frame.payload[2..])
                        .map_err(|_| WebSocketEngineError::InvalidUtf8)?
                        .to_string();
                    Ok(EngineFrame::Close { code, reason })
                }
            }
            _ => unreachable!("parse_control only handles control opcodes"),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds a `DecodedFrame` for tests, copying `payload` into owned bytes.
    fn frame(fin: bool, opcode: Opcode, payload: &[u8]) -> DecodedFrame {
        let payload = Bytes::copy_from_slice(payload);
        DecodedFrame {
            fin,
            opcode,
            payload,
        }
    }

    /// Two text fragments are joined into a single text message.
    #[test]
    fn assembles_fragmented_text() {
        let mut reassembler = ReassemblyState::new(1024);

        // The opening fragment is only buffered.
        let opening = frame(false, Opcode::Text, b"He");
        let buffered = reassembler.feed(opening).unwrap();
        assert!(buffered.is_none());

        // The final fragment completes the message.
        let closing = frame(true, Opcode::Continuation, b"llo");
        let delivered = reassembler.feed(closing).unwrap().unwrap();
        match delivered {
            EngineFrame::Text(text) => assert_eq!(text, "Hello"),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// A text frame whose payload is not UTF-8 is rejected.
    #[test]
    fn rejects_invalid_utf8() {
        let mut reassembler = ReassemblyState::new(1024);
        let malformed = frame(true, Opcode::Text, &[0xff, 0xfe]);

        let err = reassembler.feed(malformed).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidUtf8));
    }

    /// An unfragmented message larger than the limit is rejected.
    #[test]
    fn rejects_message_over_limit_in_single_frame() {
        let mut reassembler = ReassemblyState::new(2);
        let oversized = frame(true, Opcode::Binary, &[1, 2, 3]);

        let err = reassembler.feed(oversized).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::PayloadTooLarge { .. }));
    }

    /// Fragments that fit individually but exceed the limit together are rejected.
    #[test]
    fn rejects_message_over_limit_during_reassembly() {
        let mut reassembler = ReassemblyState::new(4);

        // Three bytes fit within the four-byte limit.
        let opening = frame(false, Opcode::Binary, &[1, 2, 3]);
        reassembler.feed(opening).unwrap();

        // Two more bytes push the total to five.
        let overflowing = frame(true, Opcode::Continuation, &[4, 5]);
        let err = reassembler.feed(overflowing).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::PayloadTooLarge { .. }));
    }

    /// An oversized first fragment is rejected and leaves nothing pending.
    #[test]
    fn rejects_message_over_limit_on_initial_fragment() {
        let mut reassembler = ReassemblyState::new(2);
        let oversized = frame(false, Opcode::Binary, &[1, 2, 3]);

        let err = reassembler.feed(oversized).unwrap_err();
        assert!(matches!(
            err,
            WebSocketEngineError::PayloadTooLarge {
                limit: 2,
                received: 3
            }
        ));

        // The rejected fragment must not have been buffered.
        reassembler.end_of_stream().unwrap();
    }

    /// A close payload is split into its code and reason.
    #[test]
    fn parses_close_payload() {
        let mut reassembler = ReassemblyState::new(1024);
        // 0x03e8 is 1000 in big-endian, followed by the reason text.
        let payload = b"\x03\xe8bye";

        let delivered = reassembler
            .feed(frame(true, Opcode::Close, payload))
            .unwrap()
            .unwrap();

        match delivered {
            EngineFrame::Close { code, reason } => {
                assert_eq!(code, 1000);
                assert_eq!(reason, "bye");
            }
            other => panic!("expected close, got {other:?}"),
        }
    }

    /// A close frame without a payload is reported with code 1005.
    #[test]
    fn empty_close_payload_yields_1005() {
        let mut reassembler = ReassemblyState::new(1024);
        let bare_close = frame(true, Opcode::Close, &[]);

        let delivered = reassembler.feed(bare_close).unwrap().unwrap();

        match delivered {
            EngineFrame::Close { code, .. } => assert_eq!(code, 1005),
            other => panic!("unexpected {other:?}"),
        }
    }

    /// A continuation frame with no message in progress is rejected.
    #[test]
    fn rejects_orphan_continuation() {
        let mut reassembler = ReassemblyState::new(1024);
        let orphan = frame(true, Opcode::Continuation, b"x");

        let err = reassembler.feed(orphan).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    /// A new data frame arriving mid-message is rejected.
    #[test]
    fn rejects_nested_data_frame() {
        let mut reassembler = ReassemblyState::new(1024);
        let opening = frame(false, Opcode::Text, b"He");
        reassembler.feed(opening).unwrap();

        let intruder = frame(true, Opcode::Binary, b"!");
        let err = reassembler.feed(intruder).unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    /// Ending the stream while a message is unfinished is an error.
    #[test]
    fn rejects_end_of_stream_during_fragmented_message() {
        let mut reassembler = ReassemblyState::new(1024);
        let opening = frame(false, Opcode::Text, b"He");
        let buffered = reassembler.feed(opening).unwrap();
        assert!(buffered.is_none());

        let err = reassembler.end_of_stream().unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidFrame(_)));
    }

    /// Ending the stream with nothing buffered succeeds.
    #[test]
    fn accepts_end_of_stream_without_pending_fragment() {
        ReassemblyState::new(1024).end_of_stream().unwrap();
    }

    /// A close frame carrying a reserved code is rejected.
    #[test]
    fn rejects_invalid_close_code() {
        let mut reassembler = ReassemblyState::new(1024);
        // 1004 is reserved
        let reserved = 1004u16.to_be_bytes();

        let err = reassembler
            .feed(frame(true, Opcode::Close, &reserved))
            .unwrap_err();

        assert!(matches!(err, WebSocketEngineError::InvalidCloseCode(_)));
    }

    /// Close codes 1012, 1013 and 1014 are accepted and passed through unchanged.
    #[test]
    fn accepts_iana_registered_close_codes() {
        for expected in [1012u16, 1013, 1014] {
            let mut reassembler = ReassemblyState::new(1024);
            let close = frame(true, Opcode::Close, &expected.to_be_bytes());

            let delivered = reassembler.feed(close).unwrap().unwrap();

            match delivered {
                EngineFrame::Close { code, reason } => {
                    assert_eq!(code, expected);
                    assert!(reason.is_empty());
                }
                other => panic!("expected close, got {other:?}"),
            }
        }
    }
}