soth-mitm 0.3.0

Rust intercepting proxy crate with deterministic handler/event contracts for SOTH.
Documentation
use std::io;

/// Value passed to `max_concurrent_streams` on both the server and client builders.
pub(crate) const H2_MAX_CONCURRENT_STREAMS: u32 = 128;
/// Per-stream flow-control window handed to `initial_window_size` (1 MiB).
pub(crate) const H2_INITIAL_WINDOW_SIZE: u32 = 1024 * 1024;
/// Connection-level flow-control window handed to `initial_connection_window_size` (4 MiB).
pub(crate) const H2_INITIAL_CONNECTION_WINDOW_SIZE: u32 = 4 * 1024 * 1024;
/// Value handed to `max_send_buffer_size` on both builders (128 KiB).
pub(crate) const H2_MAX_SEND_BUFFER_SIZE: usize = 128 * 1024;
/// 128 KiB limit; not referenced in this part of the file.
/// NOTE(review): presumably caps the size of a single forwarded body chunk — confirm at call sites.
pub(crate) const H2_FORWARD_CHUNK_LIMIT: usize = 128 * 1024;

/// What `detect_grpc_request` learned about a request it classified as gRPC.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct GrpcRequestObservation {
    /// Request path including any query string; `"/"` when the URI carries none.
    pub(crate) path: String,
    /// First path segment, set only when the path has the `/<service>/<method>` shape.
    pub(crate) service: Option<String>,
    /// Second path segment, set only when the path has the `/<service>/<method>` shape.
    pub(crate) method: Option<String>,
    /// Which signal matched: `"content_type_and_path"`, `"content_type"` or `"path_pattern"`.
    pub(crate) detection_mode: &'static str,
    /// The `content-type` header value, when present and convertible to a string.
    pub(crate) content_type: Option<String>,
}

/// Applies the relay's HTTP/2 settings to a server-side `h2` builder.
///
/// `max_header_list_size` comes from the caller; every other setting is taken from the
/// `H2_*` constants of this module.
pub(crate) fn configure_h2_server(builder: &mut h2::server::Builder, max_header_list_size: u32) {
    // Each setter hands back the builder, so the settings can be applied as one chain.
    builder
        .max_header_list_size(max_header_list_size)
        .max_concurrent_streams(H2_MAX_CONCURRENT_STREAMS)
        .initial_window_size(H2_INITIAL_WINDOW_SIZE)
        .initial_connection_window_size(H2_INITIAL_CONNECTION_WINDOW_SIZE)
        .max_send_buffer_size(H2_MAX_SEND_BUFFER_SIZE);
}

/// Applies the relay's HTTP/2 settings to a client-side `h2` builder.
///
/// Uses the same values as the server-side configuration: the caller-provided
/// `max_header_list_size` plus the `H2_*` constants of this module.
pub(crate) fn configure_h2_client(builder: &mut h2::client::Builder, max_header_list_size: u32) {
    // Each setter hands back the builder, so the settings can be applied as one chain.
    builder
        .max_header_list_size(max_header_list_size)
        .max_concurrent_streams(H2_MAX_CONCURRENT_STREAMS)
        .initial_window_size(H2_INITIAL_WINDOW_SIZE)
        .initial_connection_window_size(H2_INITIAL_CONNECTION_WINDOW_SIZE)
        .max_send_buffer_size(H2_MAX_SEND_BUFFER_SIZE);
}

pub(crate) fn is_h2_transport_close_error(error: &h2::Error) -> bool {
    if let Some(io_error) = error.get_io() {
        matches!(
            io_error.kind(),
            io::ErrorKind::UnexpectedEof
                | io::ErrorKind::BrokenPipe
                | io::ErrorKind::ConnectionReset
                | io::ErrorKind::ConnectionAborted
        )
    } else {
        error.is_go_away() && error.is_remote() && error.reason() == Some(h2::Reason::NO_ERROR)
    }
}

/// Reports whether `error` can be treated as a non-fatal end of a stream.
///
/// Transport-close errors always qualify. Beyond those, only errors that originate from the
/// remote peer are considered: a GOAWAY with `NO_ERROR`, or a stream reset whose reason is
/// `NO_ERROR`, `CANCEL`, `REFUSED_STREAM` or `STREAM_CLOSED`.
pub(crate) fn is_h2_nonfatal_stream_error(error: &h2::Error) -> bool {
    if is_h2_transport_close_error(error) {
        return true;
    }
    // Anything that did not come from the peer is never classified as non-fatal here.
    if !error.is_remote() {
        return false;
    }

    let reason = error.reason();
    if error.is_go_away() {
        reason == Some(h2::Reason::NO_ERROR)
    } else if error.is_reset() {
        matches!(
            reason,
            Some(
                h2::Reason::NO_ERROR
                    | h2::Reason::CANCEL
                    | h2::Reason::REFUSED_STREAM
                    | h2::Reason::STREAM_CLOSED
            )
        )
    } else {
        false
    }
}

/// Picks the reason code to use for a downstream reset: the reason carried by `error`, or
/// `CANCEL` when the error carries none.
pub(crate) fn h2_reason_for_downstream_reset(error: &h2::Error) -> h2::Reason {
    match error.reason() {
        Some(reason) => reason,
        None => h2::Reason::CANCEL,
    }
}

/// Reports whether an `io::Error` seen on an HTTP/2 stream is an expected, harmless
/// termination rather than a real failure.
///
/// An error is benign when either:
/// * its kind is `BrokenPipe`, `ConnectionReset`, `ConnectionAborted` or `UnexpectedEof`, or
/// * its rendered message contains a marker for a peer-sent stream reset with
///   `CANCEL` / `REFUSED_STREAM` / `STREAM_CLOSED` / `NO_ERROR`, or a peer-sent connection
///   error with `NO_ERROR`.
///
/// The markers are matched in two spellings: the reason-code constant name and the
/// human-readable description. Errors produced by `h2_error_to_io` embed the `Display`
/// output of the `h2::Error`; matching only the constant names would miss them if `Display`
/// renders the description.
/// NOTE(review): the description strings mirror `h2::Reason::description()` — confirm against
/// the pinned `h2` version.
pub(crate) fn is_benign_h2_stream_io_error(error: &io::Error) -> bool {
    const BENIGN_MARKERS: [&str; 10] = [
        // Constant-name spelling.
        "stream error received: CANCEL",
        "stream error received: REFUSED_STREAM",
        "stream error received: STREAM_CLOSED",
        "stream error received: NO_ERROR",
        "connection error received: NO_ERROR",
        // Description spelling of the same reason codes, in the same order.
        "stream error received: stream no longer needed",
        "stream error received: refused stream before processing any application logic",
        "stream error received: received frame when stream half-closed",
        "stream error received: not a result of an error",
        "connection error received: not a result of an error",
    ];

    if matches!(
        error.kind(),
        io::ErrorKind::BrokenPipe
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::UnexpectedEof
    ) {
        return true;
    }

    let text = error.to_string();
    BENIGN_MARKERS.iter().any(|marker| text.contains(marker))
}

/// Wraps an `h2::Error` in an `io::Error` whose message is `"<context>: <error>"`.
pub(crate) fn h2_error_to_io(context: &str, error: h2::Error) -> io::Error {
    let message = format!("{context}: {error}");
    io::Error::other(message)
}

/// Checks that the estimated header list size of a request head fits in
/// `max_header_list_size`.
///
/// The estimate covers every regular header plus the `:method`, `:scheme`, `:authority`
/// (only when the URI has one) and `:path` pseudo-headers. A missing scheme is counted as
/// `"https"` and a missing path as `"/"`.
///
/// # Errors
/// Returns an `InvalidData` error when the estimate exceeds the limit.
pub(crate) fn enforce_h2_request_header_limit(
    parts: &http::request::Parts,
    max_header_list_size: u32,
) -> io::Result<()> {
    let uri = &parts.uri;
    let scheme = uri.scheme_str().unwrap_or("https");
    let path = uri.path_and_query().map_or("/", |target| target.as_str());
    let authority_size = uri
        .authority()
        .map_or(0, |authority| header_field_size(":authority", authority.as_str()));

    let pseudo_header_size = header_field_size(":method", parts.method.as_str())
        + header_field_size(":scheme", scheme)
        + authority_size
        + header_field_size(":path", path);
    let header_list_size = estimate_header_map_size(&parts.headers) + pseudo_header_size;

    enforce_h2_header_limit("request", header_list_size, max_header_list_size)
}

/// Checks that the estimated header list size of a response head fits in
/// `max_header_list_size`.
///
/// The estimate covers every regular header plus the `:status` pseudo-header.
///
/// # Errors
/// Returns an `InvalidData` error when the estimate exceeds the limit.
pub(crate) fn enforce_h2_response_header_limit(
    parts: &http::response::Parts,
    max_header_list_size: u32,
) -> io::Result<()> {
    let status_size = header_field_size(":status", parts.status.as_str());
    let header_list_size = estimate_header_map_size(&parts.headers) + status_size;
    enforce_h2_header_limit("response", header_list_size, max_header_list_size)
}

/// Decides whether a request head looks like gRPC and, if so, describes why.
///
/// Two independent signals are evaluated:
/// * the `content-type` header starts with `application/grpc` (case-insensitive, parameters
///   after `;` ignored);
/// * the path has exactly the `/<service>/<method>` shape and the service contains a `.`.
///
/// Returns `None` when neither signal matches. Otherwise `detection_mode` records which
/// signal(s) matched, and `service` / `method` are filled whenever the path has the
/// two-segment shape, even if the path signal itself did not match.
pub(crate) fn detect_grpc_request(parts: &http::request::Parts) -> Option<GrpcRequestObservation> {
    let path = parts
        .uri
        .path_and_query()
        .map_or("/", |target| target.as_str())
        .to_string();
    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|raw| raw.to_str().ok())
        .map(str::to_owned);
    let service_method = grpc_service_method_from_path(&path);

    let content_type_hit = content_type
        .as_deref()
        .is_some_and(is_grpc_content_type);
    let path_hit = service_method
        .as_ref()
        .is_some_and(|(service, method)| is_likely_grpc_path_pattern(service, method));

    let detection_mode = if content_type_hit && path_hit {
        "content_type_and_path"
    } else if content_type_hit {
        "content_type"
    } else if path_hit {
        "path_pattern"
    } else {
        return None;
    };

    // Split the optional pair into two optionals of its components.
    let (service, method) = service_method.unzip();

    Some(GrpcRequestObservation {
        path,
        service,
        method,
        detection_mode,
        content_type,
    })
}

/// Returns `true` when the media type of a `content-type` value starts with
/// `application/grpc`, compared ASCII-case-insensitively.
///
/// Only the text before the first `;` is inspected, after trimming surrounding whitespace.
fn is_grpc_content_type(value: &str) -> bool {
    const GRPC_PREFIX: &[u8] = b"application/grpc";

    let media_type = value.split(';').next().unwrap_or("").trim();
    // Compare just the leading bytes; a media type shorter than the prefix cannot match.
    media_type
        .as_bytes()
        .get(..GRPC_PREFIX.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(GRPC_PREFIX))
}

/// Heuristic for a gRPC-looking path: both segments are non-empty and the service segment
/// contains a `.`.
fn is_likely_grpc_path_pattern(service: &str, method: &str) -> bool {
    if service.is_empty() || method.is_empty() {
        return false;
    }
    service.contains('.')
}

/// Splits a path of the exact form `/<service>/<method>` into its two segments.
///
/// Anything from the first `?` onwards is ignored. Returns `None` when the path does not
/// start with `/`, when either segment is empty, or when there are more than two segments.
fn grpc_service_method_from_path(path: &str) -> Option<(String, String)> {
    let without_query = path.split_once('?').map_or(path, |(head, _)| head);
    let segments = without_query.strip_prefix('/')?;
    let (service, method) = segments.split_once('/')?;

    // A further '/' inside `method` means the path has a third segment.
    let well_formed = !service.is_empty() && !method.is_empty() && !method.contains('/');
    well_formed.then(|| (service.to_string(), method.to_string()))
}

/// Compares an observed header list size against the configured maximum.
///
/// `direction` is only used in the error message (callers pass `"request"` or `"response"`).
///
/// # Errors
/// Returns an `InvalidData` error when `observed_size` is strictly greater than the limit;
/// a size equal to the limit is accepted.
fn enforce_h2_header_limit(
    direction: &str,
    observed_size: usize,
    max_header_list_size: u32,
) -> io::Result<()> {
    let limit = max_header_list_size as usize;
    if observed_size <= limit {
        return Ok(());
    }

    let message = format!(
        "HTTP/2 {direction} header list size {observed_size} exceeded configured limit {limit}"
    );
    Err(io::Error::new(io::ErrorKind::InvalidData, message))
}

/// Adds up the estimated size of every name/value entry in `headers`.
fn estimate_header_map_size(headers: &http::HeaderMap) -> usize {
    let mut total = 0;
    for (name, value) in headers.iter() {
        total += header_field_size(name.as_str(), value.as_bytes());
    }
    total
}

/// Estimated size of one header field: name length plus value length plus a fixed 32 bytes.
///
/// The fixed 32-byte per-entry overhead matches the accounting RFC 7540 §6.5.2 specifies for
/// `SETTINGS_MAX_HEADER_LIST_SIZE`.
fn header_field_size(name: &str, value: impl AsRef<[u8]>) -> usize {
    const ENTRY_OVERHEAD: usize = 32;

    let value_len = value.as_ref().len();
    name.len() + value_len + ENTRY_OVERHEAD
}

#[cfg(test)]
mod http2_relay_support_tests {
    use super::detect_grpc_request;

    /// Builds an empty-bodied request and returns only its head.
    fn request_parts(method: &str, uri: &str) -> http::request::Parts {
        let (parts, ()) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .expect("request")
            .into_parts();
        parts
    }

    /// A `/<dotted.service>/<method>` path is enough to be classified as gRPC.
    #[test]
    fn detects_grpc_from_path_pattern_without_content_type() {
        let parts = request_parts("POST", "https://unit.test/greeter.Service/SayHello");

        let observation = detect_grpc_request(&parts).expect("must detect grpc");

        assert_eq!(observation.detection_mode, "path_pattern");
        assert_eq!(observation.service.as_deref(), Some("greeter.Service"));
        assert_eq!(observation.method.as_deref(), Some("SayHello"));
    }

    /// A two-segment REST path whose first segment has no dot must not be classified as gRPC.
    #[test]
    fn does_not_detect_openai_rest_path_as_grpc_without_content_type() {
        let parts = request_parts("GET", "https://api.openai.com/v1/models");

        assert!(
            detect_grpc_request(&parts).is_none(),
            "plain REST paths like /v1/models must not be tagged as grpc"
        );
    }
}