sim-lib-server 0.1.0-rc.1

SIM workspace package for the sim-lib-server crate.
Documentation
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
use std::io::{ErrorKind, Read, Write};
use std::sync::atomic::{AtomicUsize, Ordering};

use sim_kernel::{Error, Result};

use crate::transport::MAX_TRANSPORT_FRAME_BYTES;

use super::core::WsMessage;

pub(crate) fn read_ws_message<R: Read>(reader: &mut R) -> Result<Option<WsMessage>> {
    let mut first = [0u8; 2];
    match read_exact_or_eof(reader, &mut first)? {
        HeadRead::Eof => return Ok(None),
        HeadRead::Filled => {}
    }
    let fin = (first[0] & 0x80) != 0;
    let opcode = first[0] & 0x0f;
    if !fin {
        return Err(Error::HostError(
            "fragmented websocket frames are not supported".to_owned(),
        ));
    }
    let masked = (first[1] & 0x80) != 0;
    let mut payload_len = u64::from(first[1] & 0x7f);
    if payload_len == 126 {
        let mut buf = [0u8; 2];
        reader.read_exact(&mut buf).map_err(io_to_host)?;
        payload_len = u64::from(u16::from_be_bytes(buf));
    } else if payload_len == 127 {
        let mut buf = [0u8; 8];
        reader.read_exact(&mut buf).map_err(io_to_host)?;
        payload_len = u64::from_be_bytes(buf);
    }
    let payload_len = usize::try_from(payload_len)
        .map_err(|_| Error::HostError("websocket payload length overflow".to_owned()))?;
    if payload_len > MAX_TRANSPORT_FRAME_BYTES {
        return Err(Error::HostError(
            "websocket payload exceeds size limit".to_owned(),
        ));
    }
    let mut mask = [0u8; 4];
    if masked {
        reader.read_exact(&mut mask).map_err(io_to_host)?;
    }
    let mut payload = vec![0u8; payload_len];
    reader.read_exact(&mut payload).map_err(io_to_host)?;
    if masked {
        for (index, byte) in payload.iter_mut().enumerate() {
            *byte ^= mask[index % 4];
        }
    }
    match opcode {
        0x2 => Ok(Some(WsMessage::Binary(payload))),
        0x8 => Ok(Some(WsMessage::Close)),
        other => Err(Error::HostError(format!(
            "unsupported websocket opcode {other}"
        ))),
    }
}

/// Writes `payload` as a single, unfragmented binary frame (opcode `0x2`).
///
/// When `masked` is true, the frame carries a masking key and the payload is
/// XOR-masked with it before being written.
pub(crate) fn write_ws_binary<W: Write>(
    writer: &mut W,
    payload: &[u8],
    masked: bool,
) -> Result<()> {
    const OPCODE_BINARY: u8 = 0x2;
    write_ws_frame(writer, OPCODE_BINARY, payload, masked)
}

/// Writes a close frame (opcode `0x8`) with an empty body, so no status code is sent.
pub(crate) fn write_ws_close<W: Write>(writer: &mut W, masked: bool) -> Result<()> {
    const OPCODE_CLOSE: u8 = 0x8;
    let empty_body: &[u8] = &[];
    write_ws_frame(writer, OPCODE_CLOSE, empty_body, masked)
}

fn write_ws_frame<W: Write>(
    writer: &mut W,
    opcode: u8,
    payload: &[u8],
    masked: bool,
) -> Result<()> {
    if payload.len() > MAX_TRANSPORT_FRAME_BYTES {
        return Err(Error::HostError(
            "websocket payload exceeds size limit".to_owned(),
        ));
    }
    let mut header = vec![0x80 | opcode];
    let mask_bit = if masked { 0x80 } else { 0x00 };
    match payload.len() {
        len @ 0..=125 => header.push(mask_bit | len as u8),
        len @ 126..=65535 => {
            header.push(mask_bit | 126);
            header.extend_from_slice(&(len as u16).to_be_bytes());
        }
        len => {
            header.push(mask_bit | 127);
            header.extend_from_slice(&(len as u64).to_be_bytes());
        }
    }
    writer.write_all(&header).map_err(io_to_host)?;
    if masked {
        let mask = [0x13, 0x37, 0x42, 0x99];
        writer.write_all(&mask).map_err(io_to_host)?;
        let masked_payload = payload
            .iter()
            .enumerate()
            .map(|(index, byte)| byte ^ mask[index % 4])
            .collect::<Vec<_>>();
        writer.write_all(&masked_payload).map_err(io_to_host)?;
    } else {
        writer.write_all(payload).map_err(io_to_host)?;
    }
    writer.flush().map_err(io_to_host)
}

/// Successful outcome of `read_exact_or_eof`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HeadRead {
    /// The reader reported end-of-stream before any byte was read.
    Eof,
    /// Every byte of the requested buffer was filled.
    Filled,
}

fn read_exact_or_eof<R: Read>(reader: &mut R, mut buffer: &mut [u8]) -> Result<HeadRead> {
    let mut read_any = false;
    while !buffer.is_empty() {
        match reader.read(buffer) {
            Ok(0) if !read_any => return Ok(HeadRead::Eof),
            Ok(0) => return Err(Error::HostError("truncated websocket frame".to_owned())),
            Ok(read) => {
                read_any = true;
                let (_, rest) = buffer.split_at_mut(read);
                buffer = rest;
            }
            Err(error) if error.kind() == ErrorKind::Interrupted => {}
            Err(error) => return Err(io_to_host(error)),
        }
    }
    Ok(HeadRead::Filled)
}

fn io_to_host(error: std::io::Error) -> Error {
    Error::HostError(format!("io {:?}: {}", error.kind(), error))
}