use super::close_codes::CloseReasonCode;
use super::event_emitters::emit_stream_closed;
use super::flow_hooks::FlowHooks;
use super::http2_relay_support::{
configure_h2_client, configure_h2_server, h2_error_to_io, is_benign_h2_stream_io_error,
is_h2_nonfatal_stream_error,
};
use super::http2_stream_relay_stream::relay_http2_stream;
use super::io_timeouts::{is_idle_watchdog_timeout, is_stream_stage_timeout};
use super::runtime_governor;
use crate::engine::MitmEngine;
use crate::observe::{EventConsumer, FlowContext};
use crate::policy::PolicyEngine;
use crate::protocol::ApplicationProtocol;
use crate::types::ProcessInfo;
use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::io::{AsyncRead, AsyncWrite};
static H2_RELAY_DEBUG_ENABLED: OnceLock<bool> = OnceLock::new();
fn h2_relay_debug_enabled() -> bool {
*H2_RELAY_DEBUG_ENABLED.get_or_init(|| {
std::env::var("SOTH_MITM_H2_RELAY_DEBUG")
.ok()
.map(|value| matches!(value.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
.unwrap_or(false)
})
}
pub(crate) fn h2_relay_debug(message: impl AsRef<str>) {
if h2_relay_debug_enabled() {
tracing::debug!("{}", message.as_ref());
}
}
#[derive(Clone)]
pub(crate) struct H2ByteCounters {
pub(crate) request_bytes: Arc<AtomicU64>,
pub(crate) response_bytes: Arc<AtomicU64>,
}
pub(crate) async fn relay_http2_connection<P, S, D, U>(
engine: Arc<MitmEngine<P, S>>,
runtime_governor: Arc<runtime_governor::RuntimeGovernor>,
flow_hooks: Arc<dyn FlowHooks>,
tunnel_context: FlowContext,
process_info: Option<ProcessInfo>,
downstream_tls: D,
upstream_tls: U,
max_header_list_size: u32,
) -> io::Result<()>
where
P: PolicyEngine + Send + Sync + 'static,
S: EventConsumer + Send + Sync + 'static,
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
U: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut downstream_builder = h2::server::Builder::new();
configure_h2_server(&mut downstream_builder, max_header_list_size);
let mut downstream_connection = match downstream_builder.handshake(downstream_tls).await {
Ok(connection) => connection,
Err(error) => {
emit_stream_closed(
&engine,
FlowContext {
protocol: ApplicationProtocol::Http2,
..tunnel_context
},
CloseReasonCode::MitmHttpError,
Some(format!("downstream HTTP/2 handshake failed: {error}")),
None,
None,
);
return Ok(());
}
};
let mut upstream_builder = h2::client::Builder::new();
configure_h2_client(&mut upstream_builder, max_header_list_size);
let (upstream_sender, upstream_connection) =
match upstream_builder.handshake(upstream_tls).await {
Ok(connection_parts) => connection_parts,
Err(error) => {
emit_stream_closed(
&engine,
FlowContext {
protocol: ApplicationProtocol::Http2,
..tunnel_context
},
CloseReasonCode::MitmHttpError,
Some(format!("upstream HTTP/2 handshake failed: {error}")),
None,
None,
);
return Ok(());
}
};
let upstream_connection_task = tokio::spawn(upstream_connection);
let http2_context = FlowContext {
protocol: ApplicationProtocol::Http2,
..tunnel_context.clone()
};
let byte_counters = H2ByteCounters {
request_bytes: Arc::new(AtomicU64::new(0)),
response_bytes: Arc::new(AtomicU64::new(0)),
};
let mut stream_tasks = tokio::task::JoinSet::new();
let mut first_error: Option<io::Error> = None;
while let Some(next_stream) = downstream_connection.accept().await {
match next_stream {
Ok((request, respond)) => {
let stream_engine = Arc::clone(&engine);
let stream_runtime_governor = Arc::clone(&runtime_governor);
let stream_context = FlowContext {
flow_id: stream_engine.allocate_flow_id(),
..http2_context.clone()
};
let stream_upstream_sender = upstream_sender.clone();
let stream_byte_counters = byte_counters.clone();
let stream_flow_hooks = Arc::clone(&flow_hooks);
let stream_process_info = process_info.clone();
stream_tasks.spawn(async move {
stream_flow_hooks
.on_connection_open(stream_context.clone(), stream_process_info)
.await;
let stream_end_context = stream_context.clone();
let result = relay_http2_stream(
stream_engine,
stream_runtime_governor,
Arc::clone(&stream_flow_hooks),
stream_context,
stream_upstream_sender,
request,
respond,
max_header_list_size,
stream_byte_counters,
)
.await;
if let Err(ref error) = result {
h2_relay_debug(format!(
"[h2-relay:stream] flow_id={} host={} error={} benign={}",
stream_end_context.flow_id,
stream_end_context.server_host,
error,
is_benign_h2_stream_io_error(error),
));
stream_flow_hooks.on_stream_end(stream_end_context).await;
}
result
});
}
Err(error) => {
if !is_h2_nonfatal_stream_error(&error) && first_error.is_none() {
first_error = Some(h2_error_to_io("downstream HTTP/2 accept failed", error));
}
break;
}
}
}
while let Some(task_result) = stream_tasks.join_next().await {
match task_result {
Ok(Ok(())) => {}
Ok(Err(error)) => {
if !is_benign_h2_stream_io_error(&error) && first_error.is_none() {
first_error = Some(error);
}
}
Err(join_error) => {
if first_error.is_none() {
first_error = Some(io::Error::other(format!(
"HTTP/2 stream task join failed: {join_error}"
)));
}
}
}
}
drop(upstream_sender);
match upstream_connection_task.await {
Ok(Ok(())) => {}
Ok(Err(error)) => {
if !is_h2_nonfatal_stream_error(&error) && first_error.is_none() {
first_error = Some(h2_error_to_io("upstream HTTP/2 driver failed", error));
}
}
Err(join_error) => {
if first_error.is_none() {
first_error = Some(io::Error::other(format!(
"HTTP/2 upstream task join failed: {join_error}"
)));
}
}
}
let bytes_from_client = byte_counters.request_bytes.load(Ordering::Relaxed);
let bytes_from_server = byte_counters.response_bytes.load(Ordering::Relaxed);
if let Some(error) = first_error {
let close_reason = if is_stream_stage_timeout(&error) {
CloseReasonCode::StreamStageTimeout
} else if is_idle_watchdog_timeout(&error) {
CloseReasonCode::IdleWatchdogTimeout
} else {
CloseReasonCode::MitmHttpError
};
emit_stream_closed(
&engine,
http2_context,
close_reason,
Some(error.to_string()),
Some(bytes_from_client),
Some(bytes_from_server),
);
} else {
emit_stream_closed(
&engine,
http2_context,
CloseReasonCode::MitmHttpCompleted,
None,
Some(bytes_from_client),
Some(bytes_from_server),
);
}
Ok(())
}