soth-mitm 0.3.2

Rust intercepting proxy crate with deterministic handler/event contracts for SOTH.
Documentation
use super::io_timeouts::{read_with_websocket_idle_timeout, write_all_with_websocket_idle_timeout};
use super::runtime_governor;
use super::websocket_codec::{
    decode_websocket_header_soketto, websocket_payload_len_within_limit,
    WebSocketHeaderDecodeResult, WebSocketHeaderView,
};
use super::websocket_relay::WS_FRAME_COPY_CHUNK_SIZE;
use std::io;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};

/// Byte source that serves an in-memory `prefix` before reading from `source`.
///
/// `prefix` presumably holds bytes already pulled off the connection before websocket
/// relaying began (e.g. during the upgrade) — confirm against callers.
pub(crate) struct PrefixedReader<R> {
    /// Bytes handed out before any read from `source`.
    prefix: Vec<u8>,
    /// Index of the next unread byte in `prefix`; never exceeds `prefix.len()`.
    prefix_offset: usize,
    /// Underlying stream that is read once `prefix` is exhausted.
    source: R,
}

impl<R> PrefixedReader<R> {
    /// Wraps `source` so that the bytes of `prefix` are yielded first, from the beginning.
    pub(crate) fn new(prefix: Vec<u8>, source: R) -> Self {
        let prefix_offset = 0;
        Self {
            source,
            prefix,
            prefix_offset,
        }
    }
}

impl<R> PrefixedReader<R>
where
    R: AsyncRead + Unpin,
{
    /// Fills `out` completely. Unread prefix bytes are used first, then bytes from `source`.
    ///
    /// Returns `Ok(true)` once `out` is full, which is immediate for an empty `out`. Returns
    /// `Ok(false)` when `source` reports end-of-stream before a single byte was placed into
    /// `out`.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` when `source` ends after `out` was only partially filled.
    /// - Any error returned by `read_with_websocket_idle_timeout`.
    async fn read_exact_or_eof(&mut self, out: &mut [u8]) -> io::Result<bool> {
        // The prefix is never refilled, so a single copy either fills `out` or exhausts it.
        let mut filled = {
            let pending = &self.prefix[self.prefix_offset..];
            let count = pending.len().min(out.len());
            out[..count].copy_from_slice(&pending[..count]);
            count
        };
        self.prefix_offset += filled;

        while filled < out.len() {
            let received = read_with_websocket_idle_timeout(
                &mut self.source,
                &mut out[filled..],
                "websocket_frame_read",
            )
            .await?;
            if received == 0 {
                // EOF is only "clean" if nothing of this request has been delivered yet.
                return match filled {
                    0 => Ok(false),
                    _ => Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "websocket frame ended before expected bytes were read",
                    )),
                };
            }
            filled += received;
        }
        Ok(true)
    }

    /// Like `read_exact_or_eof`, but treats end-of-stream before the first byte as an error.
    ///
    /// `label` names the part of the frame being read and appears in that error's message.
    async fn read_exact_required(&mut self, out: &mut [u8], label: &str) -> io::Result<()> {
        let complete = self.read_exact_or_eof(out).await?;
        if complete {
            Ok(())
        } else {
            Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!("connection closed while reading websocket {label}"),
            ))
        }
    }
}

/// Reads one websocket frame header from `source`, one byte at a time.
///
/// Each new byte is appended to a buffer. The whole buffer is then offered to
/// `decode_websocket_header_soketto` again, until the decoder reports a complete header. On
/// success, the raw header bytes are returned together with the decoded view.
///
/// Returns `Ok(None)` when `source` reaches end-of-stream before any header byte is read.
///
/// # Errors
///
/// - `InvalidData` when the decoder still wants more input after 14 buffered bytes.
/// - `InvalidData` when the decoded header length disagrees with the number of buffered bytes.
/// - `UnexpectedEof` when the stream ends part-way through a header.
/// - Any error from the decoder or from the underlying read.
/// - Any error from `websocket_payload_len_within_limit`, which is checked against
///   `max_frame_payload_bytes`.
pub(crate) async fn read_websocket_frame_header<R>(
    source: &mut PrefixedReader<R>,
    codec: &soketto::base::Codec,
    max_frame_payload_bytes: usize,
) -> io::Result<Option<(Vec<u8>, WebSocketHeaderView)>>
where
    R: AsyncRead + Unpin,
{
    // RFC 6455 upper bound: 2 fixed bytes + 8 extended-length bytes + 4 masking-key bytes.
    const WS_MAX_FRAME_HEADER_BYTES: usize = 14;
    let mut raw_header = Vec::with_capacity(WS_MAX_FRAME_HEADER_BYTES);

    let decoded = loop {
        match decode_websocket_header_soketto(codec, &raw_header)? {
            WebSocketHeaderDecodeResult::Complete(view) => break view,
            WebSocketHeaderDecodeResult::NeedMore(_) => {}
        }

        if raw_header.len() >= WS_MAX_FRAME_HEADER_BYTES {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("websocket frame header exceeds {WS_MAX_FRAME_HEADER_BYTES} bytes"),
            ));
        }

        // One byte per step: the loop stops as soon as the header decodes, so no bytes
        // beyond the header are taken from `source`.
        let mut byte = [0_u8; 1];
        let got_byte = source.read_exact_or_eof(&mut byte).await?;
        if !got_byte {
            return if raw_header.is_empty() {
                Ok(None)
            } else {
                Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed while reading websocket frame header",
                ))
            };
        }
        raw_header.extend_from_slice(&byte);
    };

    websocket_payload_len_within_limit(decoded.payload_len, max_frame_payload_bytes)?;

    let buffered = raw_header.len();
    if decoded.header_len != buffered {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "websocket frame header offset mismatch: decoded={} buffered={}",
                decoded.header_len, buffered
            ),
        ));
    }
    Ok(Some((raw_header, decoded)))
}

/// Relays exactly `payload_len` websocket payload bytes from `source` to `sink`.
///
/// Returns a capture of at most `max_payload_capture_bytes` of that payload. Bytes are
/// forwarded to `sink` exactly as read. When `masking_key` is `Some`, only the captured copy
/// is unmasked: it is XORed with the 4-byte key, with positions counted from the start of the
/// payload. The payload is moved in chunks of at most `WS_FRAME_COPY_CHUNK_SIZE` bytes. Each
/// chunk holds an in-flight reservation from `runtime_governor` while it is read and written.
///
/// # Errors
///
/// - Any error from `reserve_in_flight_or_error`, presumably when the governor refuses the
///   reservation.
/// - `UnexpectedEof` if `source` ends before the whole payload has arrived.
/// - Any error from the websocket idle-timeout read/write helpers.
pub(crate) async fn relay_websocket_payload<R, W>(
    source: &mut PrefixedReader<R>,
    sink: &mut W,
    runtime_governor: &Arc<runtime_governor::RuntimeGovernor>,
    mut payload_len: u64,
    masking_key: Option<[u8; 4]>,
    max_payload_capture_bytes: usize,
) -> io::Result<bytes::Bytes>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    if payload_len == 0 {
        return Ok(bytes::Bytes::new());
    }

    let mut buffer = [0_u8; WS_FRAME_COPY_CHUNK_SIZE];
    let mut captured: Vec<u8> = Vec::new();
    // Position within the 4-byte masking key of the next payload byte.
    let mut key_phase = 0_usize;

    while payload_len > 0 {
        // Bounded by the buffer length, so the narrowing cast cannot truncate.
        let step = payload_len.min(buffer.len() as u64) as usize;

        // Kept alive until the end of this iteration, covering both the read and the write.
        let _lease =
            runtime_governor.reserve_in_flight_or_error(step, "websocket_payload_write")?;
        source
            .read_exact_required(&mut buffer[..step], "payload")
            .await?;
        write_all_with_websocket_idle_timeout(
            &mut *sink,
            &buffer[..step],
            "websocket_payload_write",
        )
        .await?;

        let capture_len = max_payload_capture_bytes
            .saturating_sub(captured.len())
            .min(step);
        let sample = &buffer[..capture_len];
        if let Some(mask) = masking_key {
            let key_stream = mask.iter().cycle().skip(key_phase);
            captured.extend(sample.iter().zip(key_stream).map(|(byte, key)| byte ^ key));
            // Advance by the whole chunk, not just the captured part, to stay aligned.
            key_phase = (key_phase + step) % 4;
        } else {
            captured.extend_from_slice(sample);
        }

        payload_len -= step as u64;
    }

    Ok(captured.into())
}