// soth-mitm 0.3.3
//
// Rust intercepting proxy crate with deterministic handler/event contracts for SOTH.
// Documentation
use super::close_codes::CloseReasonCode;
use super::downstream_tls::accept_downstream_tls;
use super::event_emitters::{
    emit_stream_closed, emit_tls_event, emit_tls_event_with_cache,
    emit_tls_event_with_negotiated_alpn,
};
use super::flow_hooks::FlowHooks;
use super::flow_intercept_http1::relay_http1_mitm_loop;
use super::flow_intercept_tls_failure::fail_tls_and_close;
use super::http2_stream_relay::relay_http2_connection;
use super::http2_stream_relay_http1::relay_http2_downstream_to_http1_upstream;
use super::http_body_relay::write_proxy_response;
use super::io_timeouts::write_all_with_idle_timeout;
use super::route_planner_model::{RouteBinding, RouteConnectIntent, UpstreamRequestTargetMode};
use super::route_planner_transport::connect_via_route;
use super::runtime_governor;
use super::tls_diagnostics::TlsDiagnostics;
use super::tls_learning::TlsLearningGuardrails;
use super::tls_profile_mapping::map_upstream_sni_mode;
use super::BufferedConn;
use crate::engine::MitmEngine;
use crate::observe::{EventConsumer, EventType, FlowContext};
use crate::policy::PolicyEngine;
use crate::protocol::{protocol_from_negotiated_alpn, ApplicationProtocol};
use crate::tls::{
    resolve_upstream_server_name, MitmCertificateStore,
    UpstreamClientAuthMode as TlsUpstreamClientAuthMode, UpstreamTlsConfigCache,
};
use crate::types::ProcessInfo;
use std::io;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;

/// Intercepts one tunnelled connection: connects upstream, terminates TLS on both
/// sides, and hands the two TLS streams to the relay matching the negotiated protocols.
///
/// Sequence, as implemented below:
/// 1. Open the upstream TCP connection via `connect_via_route`. On failure a
///    `502 Bad Gateway` is written to `downstream`, a stream-closed event with
///    `CloseReasonCode::UpstreamConnectFailed` is emitted, and `Ok(())` is returned.
/// 2. Write `HTTP/1.1 200 Connection Established` to the still-plaintext `downstream`.
/// 3. Resolve the upstream TLS server name and obtain a server config for the host
///    from `cert_store`.
/// 4. Accept the downstream TLS handshake, then build the upstream client config and
///    perform the upstream TLS handshake.
/// 5. Dispatch to the HTTP/2 relay, the HTTP/2-downstream-to-HTTP/1-upstream relay, or
///    the HTTP/1 relay loop, depending on the protocol derived from each side's ALPN.
///
/// Any failure in steps 3-4 returns the result of `fail_tls_and_close`.
///
/// # Errors
///
/// Propagates the `io::Error` from writing the 502 response or the
/// `200 Connection Established` line, and otherwise returns whatever
/// `fail_tls_and_close` or the selected relay function returns.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn intercept_http_connection<P, S>(
    engine: Arc<MitmEngine<P, S>>,
    cert_store: Arc<MitmCertificateStore>,
    runtime_governor: Arc<runtime_governor::RuntimeGovernor>,
    tls_diagnostics: Arc<TlsDiagnostics>,
    tls_learning: Arc<TlsLearningGuardrails>,
    flow_hooks: Arc<dyn FlowHooks>,
    upstream_tls_cache: Arc<UpstreamTlsConfigCache>,
    tunnel_context: FlowContext,
    process_info: Option<ProcessInfo>,
    route: RouteBinding,
    policy_override_state: crate::policy::PolicyOverrideState,
    mut downstream: TcpStream,
    max_http_head_bytes: usize,
) -> io::Result<()>
where
    P: PolicyEngine + Send + Sync + 'static,
    S: EventConsumer + Send + Sync + 'static,
{
    // Per-flow switches: the policy override can only turn HTTP/2 off, and can only
    // turn upstream verification skipping on, relative to the engine configuration.
    let http2_enabled_for_flow = engine.config.http2_enabled && !policy_override_state.disable_h2;
    let skip_upstream_verify_for_flow = engine.config.upstream_tls_insecure_skip_verify
        || policy_override_state.skip_upstream_verify;
    // The upstream TCP connection is established before the client is told the
    // tunnel is ready, so a connect failure can still be reported as a 502.
    let upstream_tcp = match connect_via_route(&route, RouteConnectIntent::TargetTunnel).await {
        Ok(stream) => {
            super::socket_hardening::apply_upstream_socket_hardening(&stream);
            stream
        }
        Err(error) => {
            let detail = format!(
                "upstream_connect_failed[{}]: {error}",
                route.route_mode_label()
            );
            // NOTE(review): if this write fails, `?` returns before
            // `emit_stream_closed` below is reached — confirm the caller accounts
            // for the missing close event.
            write_proxy_response(&mut downstream, "502 Bad Gateway", &detail).await?;
            emit_stream_closed(
                &engine,
                tunnel_context,
                CloseReasonCode::UpstreamConnectFailed,
                Some(error.to_string()),
                None,
                None,
            );
            return Ok(());
        }
    };

    // Acknowledge the tunnel in plaintext; everything after this point on
    // `downstream` is the TLS handshake and encrypted traffic.
    write_all_with_idle_timeout(
        &mut downstream,
        b"HTTP/1.1 200 Connection Established\r\n\r\n",
        "mitm_connect_established_write",
    )
    .await?;

    // Context used for events raised before any ALPN result exists; it is tagged
    // as HTTP/1 and otherwise copies the tunnel context.
    let handshake_context = FlowContext {
        protocol: ApplicationProtocol::Http1,
        ..tunnel_context.clone()
    };
    let upstream_sni_mode = map_upstream_sni_mode(engine.config.upstream_sni_mode);
    // Resolved up front so an unusable server name fails the flow before any
    // TLS handshake is attempted.
    let server_name =
        match resolve_upstream_server_name(&handshake_context.server_host, upstream_sni_mode) {
            Ok(value) => value,
            Err(error) => {
                let detail = format!("invalid server name for upstream TLS: {error}");
                return fail_tls_and_close(
                    &engine,
                    &flow_hooks,
                    &tls_diagnostics,
                    &tls_learning,
                    handshake_context.clone(),
                    tunnel_context,
                    "upstream",
                    detail,
                )
                .await;
            }
        };
    // Server-side TLS config for the requested host, built with or without
    // HTTP/2 according to the per-flow switch.
    let issued_server_config = match cert_store
        .server_config_for_host_with_http2(&handshake_context.server_host, http2_enabled_for_flow)
    {
        Ok(config) => config,
        Err(error) => {
            return fail_tls_and_close(
                &engine,
                &flow_hooks,
                &tls_diagnostics,
                &tls_learning,
                handshake_context.clone(),
                tunnel_context,
                "downstream",
                format!("downstream leaf issuance error: {error}"),
            )
            .await;
        }
    };
    // The handshake-started event carries the cache status reported by the
    // issued server config.
    emit_tls_event_with_cache(
        &engine,
        EventType::TlsHandshakeStarted,
        handshake_context.clone(),
        "downstream",
        issued_server_config.cache_status.as_str(),
    );
    // `downstream` is moved into the TLS acceptor here; the plain TcpStream is
    // not used again.
    let (downstream_tls, client_fingerprint) = match accept_downstream_tls(
        engine.config.downstream_tls_backend,
        downstream,
        &issued_server_config,
        http2_enabled_for_flow,
    )
    .await
    {
        Ok(result) => result,
        Err(error) => {
            return fail_tls_and_close(
                &engine,
                &flow_hooks,
                &tls_diagnostics,
                &tls_learning,
                handshake_context.clone(),
                tunnel_context,
                "downstream",
                format!("downstream handshake failed: {error}"),
            )
            .await;
        }
    };

    // Store the JA4 fingerprint in the connection metadata if available.
    if let Some(ref fp) = client_fingerprint {
        flow_hooks
            .update_connection_fingerprint(tunnel_context.flow_id, fp)
            .await;
    }

    // From here on, downstream events use the protocol derived from the
    // downstream ALPN result instead of the HTTP/1 placeholder.
    let downstream_alpn = downstream_tls.negotiated_alpn();
    let downstream_protocol =
        protocol_from_negotiated_alpn(downstream_alpn.as_deref(), http2_enabled_for_flow);
    let downstream_context = FlowContext {
        protocol: downstream_protocol,
        ..tunnel_context.clone()
    };
    emit_tls_event_with_negotiated_alpn(
        &engine,
        EventType::TlsHandshakeSucceeded,
        downstream_context.clone(),
        "downstream",
        downstream_alpn.as_deref(),
    );

    // HTTP/2 is only requested for the upstream config when the downstream side
    // ended up on HTTP/2 and HTTP/2 is enabled for this flow.
    let should_offer_http2_upstream =
        http2_enabled_for_flow && downstream_protocol == ApplicationProtocol::Http2;
    let client_config = match upstream_tls_cache.get_or_build(
        skip_upstream_verify_for_flow,
        should_offer_http2_upstream,
        &handshake_context.server_host,
    ) {
        Ok(value) => value,
        Err(error) => {
            let detail = format!("upstream TLS config build failed: {error}");
            // Upstream failures after the downstream handshake report with the
            // downstream context rather than the pre-ALPN handshake context.
            return fail_tls_and_close(
                &engine,
                &flow_hooks,
                &tls_diagnostics,
                &tls_learning,
                downstream_context.clone(),
                tunnel_context,
                "upstream",
                detail,
            )
            .await;
        }
    };
    let connector = TlsConnector::from(client_config);
    // The upstream handshake-started event is tagged with the protocol being
    // offered, since the upstream ALPN result is not known yet.
    let upstream_start_context = FlowContext {
        protocol: if should_offer_http2_upstream {
            ApplicationProtocol::Http2
        } else {
            ApplicationProtocol::Http1
        },
        ..tunnel_context.clone()
    };
    emit_tls_event(
        &engine,
        EventType::TlsHandshakeStarted,
        upstream_start_context,
        "upstream",
    );
    // `server_name` is cloned because it is passed again to the
    // HTTP/2-to-HTTP/1 relay further down.
    let upstream_tls = match connector.connect(server_name.clone(), upstream_tcp).await {
        Ok(stream) => stream,
        Err(error) => {
            return fail_tls_and_close(
                &engine,
                &flow_hooks,
                &tls_diagnostics,
                &tls_learning,
                downstream_context.clone(),
                tunnel_context,
                "upstream",
                format!("upstream handshake failed: {error}"),
            )
            .await;
        }
    };
    // Copy the negotiated ALPN bytes out of the upstream TLS session state.
    let upstream_alpn = upstream_tls
        .get_ref()
        .1
        .alpn_protocol()
        .map(ToOwned::to_owned);
    let upstream_protocol =
        protocol_from_negotiated_alpn(upstream_alpn.as_deref(), should_offer_http2_upstream);
    let upstream_context = FlowContext {
        protocol: upstream_protocol,
        ..tunnel_context.clone()
    };
    emit_tls_event_with_negotiated_alpn(
        &engine,
        EventType::TlsHandshakeSucceeded,
        upstream_context.clone(),
        "upstream",
        upstream_alpn.as_deref(),
    );

    // Relay selection: an HTTP/2 downstream pairs with either an HTTP/2 or an
    // HTTP/1 upstream; any other downstream protocol falls through to the
    // HTTP/1 relay loop at the end.
    if downstream_protocol == ApplicationProtocol::Http2 {
        let max_header_list_size = engine.config.http2_max_header_list_size;
        return if upstream_protocol == ApplicationProtocol::Http2 {
            relay_http2_connection(
                engine,
                Arc::clone(&runtime_governor),
                flow_hooks,
                tunnel_context,
                process_info,
                downstream_tls,
                upstream_tls,
                max_header_list_size,
            )
            .await
        } else {
            // This relay additionally receives the route, connector and server
            // name — presumably so it can open further upstream connections;
            // TODO confirm against the relay implementation.
            relay_http2_downstream_to_http1_upstream(
                engine,
                Arc::clone(&runtime_governor),
                flow_hooks,
                tunnel_context,
                process_info,
                downstream_tls,
                upstream_tls,
                route,
                connector,
                server_name,
                max_http_head_bytes,
                max_header_list_size,
                policy_override_state.strict_header_mode,
            )
            .await
        };
    }

    // HTTP/1 path: wrap both TLS streams in buffered connections and run the
    // HTTP/1 relay loop with origin-form request targets.
    let downstream_conn = BufferedConn::new(downstream_tls);
    let upstream_conn = BufferedConn::new(upstream_tls);
    relay_http1_mitm_loop(
        engine,
        runtime_governor,
        flow_hooks,
        tunnel_context,
        UpstreamRequestTargetMode::OriginForm,
        downstream_conn,
        upstream_conn,
        max_http_head_bytes,
        policy_override_state.strict_header_mode,
    )
    .await
}

/// Loads the upstream client-auth certificate and private key as raw PEM bytes.
///
/// The bytes are returned unparsed; parsing is deferred to the TLS config cache.
///
/// # Returns
///
/// * `Ok(None)` when either path is unset, regardless of `auth_mode`.
/// * `Ok(Some((cert_pem, key_pem)))` when both files were read.
/// * `Ok(None)` when a read failed and `auth_mode` is `IfRequested` or `Never`;
///   the failure is logged at debug level.
///
/// # Errors
///
/// Returns the formatted read error when a file cannot be read and `auth_mode`
/// is neither `IfRequested` nor `Never` (i.e. client auth is required). If both
/// reads fail, the certificate error is the one reported.
pub(crate) fn load_upstream_client_auth_pem(
    cert_path: &Option<String>,
    key_path: &Option<String>,
    auth_mode: TlsUpstreamClientAuthMode,
) -> Result<Option<(Vec<u8>, Vec<u8>)>, String> {
    // Client-auth material is only considered when both paths are configured.
    let (cert_path, key_path) = match (cert_path, key_path) {
        (Some(cert), Some(key)) => (cert, key),
        _ => return Ok(None),
    };

    // Both files are read, certificate first, before either outcome is inspected.
    let cert_read = std::fs::read(cert_path)
        .map_err(|error| format!("read cert path {cert_path} failed: {error}"));
    let key_read = std::fs::read(key_path)
        .map_err(|error| format!("read key path {key_path} failed: {error}"));

    // Success returns immediately; otherwise keep the first failure, with the
    // certificate taking precedence over the key.
    let first_failure = match (cert_read, key_read) {
        (Ok(cert_pem), Ok(key_pem)) => return Ok(Some((cert_pem, key_pem))),
        (Err(cert_failure), _) => cert_failure,
        (Ok(_), Err(key_failure)) => key_failure,
    };

    // A read failure is tolerated unless client auth is mandatory.
    let tolerated = auth_mode == TlsUpstreamClientAuthMode::IfRequested
        || auth_mode == TlsUpstreamClientAuthMode::Never;

    match first_failure {
        e if tolerated => {
            tracing::debug!("upstream client auth PEM not loaded (non-required): {e}");
            Ok(None)
        }
        e => Err(e),
    }
}