// soth-mitm 0.3.3
//
// Rust intercepting proxy crate with deterministic handler/event contracts for SOTH.
// (Header text from the crate's documentation page.)
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::{DashMap, DashSet};
use lru::LruCache;
use tokio::sync::Mutex;

use crate::handler::InterceptHandler;
use crate::metrics::ProxyMetricsStore;
use crate::process::{PlatformProcessAttributor, ProcessLookupService};
use crate::runtime::connection_id::connection_id_for_flow_id;
use crate::runtime::flow_dispatch::FlowDispatchers;
use crate::runtime::handler_guard::HandlerCallbackGuard;
use crate::types::{ConnectionMeta, FlowId};

/// Shared per-flow bookkeeping used by the stale-flow reaper
/// (`schedule_stale_flow_reap` / `reap_stale_flows`) and by `finalize_flow`.
///
/// All shared state is held through `Arc`s, so a single context can be handed
/// to multiple tasks.
#[derive(Debug)]
pub(super) struct FlowStateContext<H: InterceptHandler> {
    /// Metrics sink; stale-flow reaps and closed-flow tombstone evictions are
    /// recorded here.
    pub(super) metrics_store: Arc<ProxyMetricsStore>,
    /// Bounded LRU of finalized flow ids ("tombstones"). `finalize_flow` checks
    /// and updates it while holding this lock, so a flow is finalized only once
    /// while its tombstone is retained.
    pub(super) closed_flow_ids: Arc<Mutex<LruCache<FlowId, ()>>>,
    /// Flow ids currently tombstoned in `closed_flow_ids`: inserted when a flow
    /// is finalized, removed when its tombstone is evicted from the LRU.
    /// Presumably lets other code ask "is this flow closed?" without taking the
    /// async mutex — confirm against callers.
    pub(super) closed_flow_live: Arc<DashSet<FlowId>>,
    /// Per-flow dispatchers; closed and drained when a flow is finalized.
    pub(super) flow_dispatchers: Arc<FlowDispatchers<H>>,
    /// Per-flow `u64` counter (presumably the next stream sequence number —
    /// confirm); removed when a flow is finalized.
    pub(super) stream_sequences: Arc<DashMap<FlowId, u64>>,
    /// Connection metadata recorded for each flow; removed when a flow is
    /// finalized.
    pub(super) connection_meta_by_flow: Arc<DashMap<FlowId, Arc<ConnectionMeta>>>,
    /// Flows flagged (by name, as having seen response activity); cleared when a
    /// flow is finalized.
    pub(super) response_activity_flows: Arc<DashSet<FlowId>>,
    /// Last-activity timestamp per flow. The reaper treats a flow whose
    /// timestamp is at least the stale TTL old as stale; the entry is removed
    /// when the flow is finalized.
    pub(super) flow_last_touched: Arc<DashMap<FlowId, Instant>>,
    /// Used as a set of flow ids (presumably flows whose TLS was intercepted);
    /// cleared when a flow is finalized.
    pub(super) tls_intercepted_flow_ids: Arc<DashMap<FlowId, ()>>,
    /// Optional process-attribution service; a finalized flow's connection id
    /// is removed from it.
    pub(super) process_lookup: Option<Arc<ProcessLookupService<PlatformProcessAttributor>>>,
    /// User handler; receives `on_stream_end` and `on_connection_close` when a
    /// flow is finalized.
    pub(super) handler: Arc<H>,
    /// Guard through which the `on_stream_end` callback is run.
    pub(super) callback_guard: Arc<HandlerCallbackGuard>,
    /// Abort handles for spawned connection tasks. When a flow is finalized
    /// (e.g. after the reaper detects it as stale), `finalize_flow` removes the
    /// handle and aborts the task — which drops the `FlowRuntimeGuard` and
    /// releases the semaphore permit back to the pool.
    pub(super) task_abort_handles: Arc<DashMap<FlowId, tokio::task::AbortHandle>>,
}

/// Spawns a background stale-flow sweep when at least `stale_reap_interval`
/// has passed since the previously scheduled sweep; otherwise returns at once.
///
/// The read-compare-update of `last_stale_reap_at` happens while its mutex is
/// held, so concurrent callers cannot both claim the same interval. The sweep
/// itself runs on a detached `tokio::spawn` task; this function only awaits
/// the lock.
///
/// * `flow_state` – shared flow bookkeeping handed to the sweep.
/// * `stale_flow_ttl` – idle time after which a flow counts as stale.
/// * `stale_reap_interval` – minimum spacing between scheduled sweeps.
/// * `stale_reap_max_batch` – maximum flows finalized per sweep (at least 1).
/// * `last_stale_reap_at` – time the most recent sweep was scheduled.
pub(super) async fn schedule_stale_flow_reap<H: InterceptHandler>(
    flow_state: Arc<FlowStateContext<H>>,
    stale_flow_ttl: Duration,
    stale_reap_interval: Duration,
    stale_reap_max_batch: usize,
    last_stale_reap_at: Arc<Mutex<Instant>>,
) {
    let should_reap = {
        let mut last = last_stale_reap_at.lock().await;
        // Read the clock only once the lock is held: a timestamp taken before a
        // contended `lock().await` can predate the value another caller just
        // stored, and writing it back would move `last` backwards.
        let now = Instant::now();
        if now.saturating_duration_since(*last) < stale_reap_interval {
            false
        } else {
            *last = now;
            true
        }
    };
    if !should_reap {
        return;
    }

    tokio::spawn(async move {
        reap_stale_flows(flow_state, stale_flow_ttl, stale_reap_max_batch).await;
    });
}

/// Finalizes up to `stale_reap_max_batch` flows (minimum 1) whose
/// last-touched timestamp is at least `stale_flow_ttl` old.
///
/// Candidate ids are snapshotted first, so no map guard is held across the
/// `.await` in `finalize_flow`. Each candidate is then re-checked, because
/// flows keep being touched and closed while the sweep runs. Only flows that
/// are still stale and not already finalized are logged at `warn`, counted via
/// `record_stale_flow_reap`, and passed to `finalize_flow`.
async fn reap_stale_flows<H: InterceptHandler>(
    flow_state: Arc<FlowStateContext<H>>,
    stale_flow_ttl: Duration,
    stale_reap_max_batch: usize,
) {
    let now = Instant::now();
    let stale_flow_ids: Vec<FlowId> = flow_state
        .flow_last_touched
        .iter()
        .filter_map(|entry| {
            if now.saturating_duration_since(*entry.value()) >= stale_flow_ttl {
                Some(*entry.key())
            } else {
                None
            }
        })
        .take(stale_reap_max_batch.max(1))
        .collect();

    for flow_id in stale_flow_ids {
        // The flow may have been touched again since the snapshot was taken.
        if let Some(last_touched) = flow_state.flow_last_touched.get(&flow_id) {
            if now.saturating_duration_since(*last_touched) < stale_flow_ttl {
                continue;
            }
        }
        // Already finalized: either the normal path closed it after the
        // snapshot (finalize_flow marks `closed_flow_live` before it removes
        // the timestamp), or the flow was touched again after it closed.
        // finalize_flow would return early without removing the timestamp, so
        // do not log or count a bogus stale reap. Drop the leftover entry so it
        // cannot reappear in every later batch and crowd out real stale flows.
        if flow_state.closed_flow_live.contains(&flow_id) {
            flow_state.flow_last_touched.remove(&flow_id);
            continue;
        }
        tracing::warn!(
            flow_id = flow_id.as_u64(),
            "reaping stale flow state without explicit stream_end"
        );
        flow_state.metrics_store.record_stale_flow_reap();
        finalize_flow(flow_id, Arc::clone(&flow_state)).await;
    }
}

/// Tears down all per-flow state for `flow_id` and emits the end-of-flow
/// handler callbacks. Repeat calls are no-ops while the flow's tombstone is
/// still in the `closed_flow_ids` LRU.
///
/// Steps, in order:
/// 1. Under the `closed_flow_ids` lock: return early if the flow is already
///    tombstoned. Otherwise record the tombstone, mirror any LRU eviction into
///    `closed_flow_live`, mark the flow in `closed_flow_live`, and remove its
///    entries from the per-flow maps/sets.
/// 2. Remove and trigger the flow's task abort handle, if one is registered.
/// 3. Close and drain the flow's dispatchers.
/// 4. Remove the connection from the process-lookup service, if configured.
/// 5. Run `on_stream_end` through the callback guard, then call
///    `on_connection_close`.
///
/// NOTE(review): once a tombstone is evicted from the LRU the id is forgotten,
/// so a later call for that flow id would run the whole teardown again.
/// NOTE(review): if this is ever called from inside the connection task whose
/// abort handle is registered, step 2 aborts the calling task itself; it would
/// presumably be cancelled at its next yield point, skipping the remaining
/// steps and handler callbacks. Confirm callers only reach this from other
/// tasks, or that the handle is deregistered first.
pub(super) async fn finalize_flow<H: InterceptHandler>(
    flow_id: FlowId,
    flow_state: Arc<FlowStateContext<H>>,
) {
    // Decide and claim finalization atomically with respect to other callers
    // by doing the check and the tombstone insert under the same lock.
    let should_finalize = {
        let mut closed = flow_state.closed_flow_ids.lock().await;
        // NOTE(review): with the `lru` crate, `get` also refreshes the entry's
        // recency (unlike `peek`/`contains`) — confirm that is intended.
        if closed.get(&flow_id).is_some() {
            false
        } else {
            // Inserting may evict the least-recently-used tombstone; keep
            // `closed_flow_live` in step with the LRU's contents.
            if let Some((evicted_flow_id, _)) = closed.push(flow_id, ()) {
                flow_state.closed_flow_live.remove(&evicted_flow_id);
                flow_state.metrics_store.record_closed_flow_id_eviction();
                tracing::debug!(
                    flow_id = flow_id.as_u64(),
                    evicted_flow_id = evicted_flow_id.as_u64(),
                    "closed-flow LRU evicted tombstone entry"
                );
            }
            // Mark closed before dropping the timestamp below, so an observer
            // that no longer sees `flow_last_touched` already sees the flow as
            // closed.
            flow_state.closed_flow_live.insert(flow_id);
            flow_state.stream_sequences.remove(&flow_id);
            flow_state.connection_meta_by_flow.remove(&flow_id);
            flow_state.response_activity_flows.remove(&flow_id);
            flow_state.flow_last_touched.remove(&flow_id);
            flow_state.tls_intercepted_flow_ids.remove(&flow_id);
            true
        }
    };
    if !should_finalize {
        return;
    }

    // Abort the stuck connection task. This drops the FlowRuntimeGuard
    // inside it, which releases the semaphore permit back to the pool.
    if let Some((_, abort_handle)) = flow_state.task_abort_handles.remove(&flow_id) {
        abort_handle.abort();
        tracing::debug!(
            flow_id = flow_id.as_u64(),
            "aborted stuck connection task to release flow permit"
        );
    }

    flow_state.flow_dispatchers.close_and_drain(flow_id).await;

    // Handler callbacks and process lookup are keyed by connection id, which
    // is derived from the flow id.
    let connection_id = connection_id_for_flow_id(flow_id);
    if let Some(lookup) = flow_state.process_lookup.as_ref() {
        lookup.remove_connection(connection_id).await;
    }

    // `on_stream_end` is run through the callback guard; the handler Arc is
    // cloned so the `async move` block owns its own reference.
    let handler_for_end = Arc::clone(&flow_state.handler);
    flow_state
        .callback_guard
        .run_response((), async move {
            handler_for_end.on_stream_end(connection_id).await
        })
        .await;

    flow_state.handler.on_connection_close(connection_id);
}