wslink-rs 0.0.2

A wslink-compatible WebSocket RPC server runtime for Rust using MessagePack transport.
Documentation
use std::collections::HashMap;

/// Width in bytes of one chunk-header field; every header field is a `u32`.
const UINT32_LEN: usize = std::mem::size_of::<u32>();
/// Width in bytes of the whole chunk header, which holds three `u32` fields
/// (message id, chunk offset, total message size).
const HEADER_LENGTH: usize = 3 * UINT32_LEN;

fn decode_header(header: &[u8]) -> anyhow::Result<(u32, u32, u32)> {
    if header.len() != UINT32_LEN * 3 {
        return Err(anyhow::anyhow!(
            "invalid chunk header length: expected {}, got {}",
            HEADER_LENGTH,
            header.len()
        ));
    }

    let mut values = [0; 3];

    header.chunks_exact(4).enumerate().for_each(|(idx, chunk)| {
        // bytes in little byteorder
        values[idx] = u32::from_le_bytes(chunk.try_into().unwrap());
    });

    Ok((values[0], values[1], values[2]))
}

fn encode_header(id: u32, offset: u32, total_size: u32) -> [u8; HEADER_LENGTH] {
    let mut bytes = [0; HEADER_LENGTH];

    let values = [id, offset, total_size];

    bytes
        .chunks_exact_mut(4)
        .enumerate()
        .for_each(|(i, chunk)| chunk.copy_from_slice(&values[i].to_le_bytes()));

    bytes
}

/// Reassembly state for one message whose chunks are still arriving.
#[derive(Default)]
struct PendingMessage {
    /// Number of distinct bytes that have been recorded as received; bytes
    /// covered by more than one chunk are counted once.
    received_size: usize,
    /// Buffer the chunk payloads are copied into (sized by the caller).
    content: Vec<u8>,
    /// Received `[start, end)` byte ranges, kept sorted, non-overlapping and
    /// with touching ranges coalesced.
    received_ranges: Vec<(usize, usize)>,
}

impl PendingMessage {
    /// Records that bytes `[start, end)` have been received.
    ///
    /// The range is coalesced with every stored range it overlaps or touches,
    /// and `received_size` only grows by the bytes that were not already
    /// covered, so duplicate or overlapping chunks are not double-counted.
    /// Empty or inverted ranges (`start >= end`) are ignored.
    fn add_received_range(&mut self, start: usize, end: usize) {
        if end <= start {
            return;
        }

        // Index of the first stored range that ends at or after `start`;
        // every range before it lies strictly to the left and stays untouched.
        let first = self
            .received_ranges
            .iter()
            .position(|&(_, range_end)| range_end >= start)
            .unwrap_or(self.received_ranges.len());

        let mut lo = start;
        let mut hi = end;
        let mut newly_covered = end - start;
        let mut last = first;

        for &(range_start, range_end) in &self.received_ranges[first..] {
            if range_start > hi {
                break;
            }

            // Remove the part of [start, end) this stored range already covered.
            let overlap_lo = start.max(range_start);
            let overlap_hi = end.min(range_end);
            if overlap_hi > overlap_lo {
                newly_covered -= overlap_hi - overlap_lo;
            }

            lo = lo.min(range_start);
            hi = hi.max(range_end);
            last += 1;
        }

        // Replace every absorbed range with the single coalesced one.
        self.received_ranges
            .splice(first..last, std::iter::once((lo, hi)));
        self.received_size += newly_covered;
    }
}

/// Reassembles messages that were split into header-prefixed chunks.
///
/// Partially received messages are tracked per message id, so chunks of
/// different messages may arrive interleaved.
pub(crate) struct UnChunker {
    /// Largest `total_size` (in bytes) a chunk header may declare.
    max_message_size: usize,
    /// Messages still being reassembled, keyed by message id.
    pending_message: HashMap<u32, PendingMessage>,
}

impl UnChunker {
    /// Creates an empty reassembler with a message size limit of 512 bytes.
    pub(crate) fn new() -> Self {
        Self {
            max_message_size: 512,
            pending_message: HashMap::default(),
        }
    }

    /// Sets the largest `total_size` (in bytes) a chunk header may declare.
    ///
    /// Every chunk processed afterwards is validated against the new limit.
    pub(crate) fn set_max_message_size(&mut self, size: usize) {
        self.max_message_size = size;
    }

    /// Feeds one chunk (header followed by payload) into the reassembler.
    ///
    /// Returns `Ok(Some(message))` when this chunk completes its message, and
    /// `Ok(None)` while more bytes of that message are still missing.
    ///
    /// # Errors
    ///
    /// Fails if the chunk is shorter than the header, declares a total size
    /// above the configured limit, places its payload outside the message
    /// bounds, or declares a total size that differs from the one given by
    /// earlier chunks of the same message (the partial message is then
    /// discarded).
    pub(crate) fn process_chunk(&mut self, chunk: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
        if chunk.len() < HEADER_LENGTH {
            return Err(anyhow::anyhow!(
                "chunk shorter than header: {} < {}",
                chunk.len(),
                HEADER_LENGTH
            ));
        }

        let (header, chunk_content) = chunk.split_at(HEADER_LENGTH);

        let (id, offset, total_size) = decode_header(header)?;

        let offset = offset as usize;
        let total_size = total_size as usize;

        // Reject oversized declarations before any buffer is allocated for them.
        if total_size > self.max_message_size {
            return Err(anyhow::anyhow!(
                r"Total size for message {} exceeds the allocation limit allowed.
Maximum size = {},
Received size = {}.",
                id,
                self.max_message_size,
                total_size
            ));
        }

        if offset > total_size {
            return Err(anyhow::anyhow!(
                "chunk offset out of bounds for message {}: offset={}, total_size={}",
                id,
                offset,
                total_size
            ));
        }

        let end = offset
            .checked_add(chunk_content.len())
            .ok_or_else(|| anyhow::anyhow!("chunk size overflow for message {}", id))?;

        if end > total_size {
            return Err(anyhow::anyhow!(
                "chunk exceeds message bounds for message {}: offset={}, chunk_size={}, total_size={}",
                id,
                offset,
                chunk_content.len(),
                total_size
            ));
        }

        // The first chunk seen for an id allocates the full reassembly buffer.
        let pending = self
            .pending_message
            .entry(id)
            .or_insert_with(|| PendingMessage {
                received_size: 0,
                content: vec![0; total_size],
                received_ranges: Vec::new(),
            });

        if pending.content.len() != total_size {
            self.pending_message.remove(&id);

            return Err(anyhow::anyhow!(
                "Total size in chunk header for message {} does not match total size declared by previous chunk.",
                id
            ));
        }

        pending.content[offset..end].copy_from_slice(chunk_content);
        pending.add_received_range(offset, end);

        if pending.received_size != total_size {
            return Ok(None);
        }

        Ok(self
            .pending_message
            .remove(&id)
            .map(|message| message.content))
    }
}

pub(crate) fn generate_chunks(message: &[u8], max_size: usize) -> Vec<Vec<u8>> {
    use rand;

    let total_len = message.len();
    let total_len_u32 = u32::try_from(total_len).expect("message length does not fit in chunk header");

    let max_content_size = max_size.saturating_sub(HEADER_LENGTH).max(1);

    let id = rand::random();

    if total_len == 0 {
        return vec![encode_header(id, 0, 0).to_vec()];
    }

    let mut chunks = Vec::new();

    let mut offset = 0;
    while offset < total_len {
        let end = (offset + max_content_size).min(total_len);
        let header = encode_header(id, offset as u32, total_len_u32);

        let mut bytes = Vec::with_capacity(HEADER_LENGTH + (end - offset));

        bytes.extend_from_slice(&header);
        bytes.extend_from_slice(&message[offset..end]);

        offset = end;

        chunks.push(bytes);
    }

    chunks
}