use crate::mitm::{LogStorage, LoggingPolicy, PiiRedactor, RequestMetadata};
use bytes::Bytes;
use h2::RecvStream;
use http::{Request, Response};
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
use tracing::{debug, error, info, warn};
#[derive(Debug, Clone)]
pub struct Http2Config {
pub initial_window_size: u32,
pub initial_connection_window_size: u32,
pub max_concurrent_streams: u32,
pub max_frame_size: u32,
pub enable_server_push: bool,
}
impl Default for Http2Config {
fn default() -> Self {
Self {
initial_window_size: 65535, initial_connection_window_size: 1_048_576, max_concurrent_streams: 100,
max_frame_size: 16384, enable_server_push: false, }
}
}
pub async fn handle_http2_mitm(
client_tls: ServerTlsStream<TcpStream>,
upstream_tls: ClientTlsStream<TcpStream>,
target_host: String,
target_port: u16,
logging_policy: Arc<LoggingPolicy>,
log_storage: Option<Arc<LogStorage>>,
config: Http2Config,
) -> Result<(), H2Error> {
debug!(
target_host = %target_host,
target_port = target_port,
"Starting HTTP/2 MITM connection"
);
let mut client_h2 = h2::server::Builder::new()
.initial_window_size(config.initial_window_size)
.initial_connection_window_size(config.initial_connection_window_size)
.max_concurrent_streams(config.max_concurrent_streams)
.max_frame_size(config.max_frame_size)
.handshake(client_tls)
.await
.map_err(H2Error::ClientHandshakeFailed)?;
debug!("HTTP/2 client handshake complete");
let (upstream_send_request, connection) = h2::client::Builder::new()
.initial_window_size(config.initial_window_size)
.initial_connection_window_size(config.initial_connection_window_size)
.max_frame_size(config.max_frame_size)
.handshake(upstream_tls)
.await
.map_err(H2Error::UpstreamHandshakeFailed)?;
debug!("HTTP/2 upstream handshake complete");
tokio::spawn(async move {
if let Err(e) = connection.await {
error!(error = ?e, "HTTP/2 upstream connection error");
}
});
let mut stream_count = 0u64;
while let Some(result) = client_h2.accept().await {
match result {
Ok((request, respond)) => {
stream_count += 1;
let stream_id = stream_count;
debug!(stream_id, "Accepted HTTP/2 stream");
let upstream_clone = upstream_send_request.clone();
let target_host_clone = target_host.clone();
let logging_policy_clone = Arc::clone(&logging_policy);
let log_storage_clone = log_storage.clone();
tokio::spawn(async move {
let start = Instant::now();
match proxy_h2_stream(
request,
respond,
upstream_clone,
target_host_clone,
target_port,
logging_policy_clone,
log_storage_clone,
)
.await
{
Ok(()) => {
debug!(
stream_id,
duration_ms = start.elapsed().as_millis(),
"HTTP/2 stream completed"
);
}
Err(e) => {
error!(stream_id, error = ?e, "HTTP/2 stream error");
}
}
});
}
Err(e) => {
if let Some(reason) = e.reason() {
if reason == h2::Reason::NO_ERROR {
info!("HTTP/2 graceful shutdown (NO_ERROR)");
break;
} else {
warn!(reason = ?reason, "HTTP/2 stream error, continuing");
continue;
}
} else {
error!(error = ?e, "HTTP/2 accept error without reason");
return Err(H2Error::AcceptFailed(e.to_string()));
}
}
}
}
info!(
target_host = %target_host,
stream_count,
"HTTP/2 MITM connection closed"
);
Ok(())
}
async fn proxy_h2_stream(
request: Request<RecvStream>,
mut client_respond: h2::server::SendResponse<Bytes>,
mut upstream: h2::client::SendRequest<Bytes>,
target_host: String,
target_port: u16,
logging_policy: Arc<LoggingPolicy>,
log_storage: Option<Arc<LogStorage>>,
) -> Result<(), H2Error> {
let request_start = Instant::now();
let (parts, mut recv_body) = request.into_parts();
let method = parts.method.to_string();
let path = parts
.uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/")
.to_string();
debug!(method = %method, path = %path, "Proxying HTTP/2 request");
let (response_future, mut send_body) = upstream
.send_request(Request::from_parts(parts, ()), false)
.map_err(|e| H2Error::UpstreamRequestFailed(e.to_string()))?;
tokio::spawn(async move {
let mut request_size = 0u64;
while let Some(chunk_result) = recv_body.data().await {
match chunk_result {
Ok(data) => {
request_size += data.len() as u64;
send_body.reserve_capacity(data.len());
if let Err(e) = send_body.send_data(data.clone(), false) {
error!(error = ?e, "Failed to send request body chunk");
break;
}
if let Err(e) = recv_body.flow_control().release_capacity(data.len()) {
error!(error = ?e, "Failed to release client flow control");
break;
}
}
Err(e) => {
error!(error = ?e, "Request body read error");
break;
}
}
}
let _ = send_body.send_data(Bytes::new(), true);
debug!(request_size, "Request body forwarded");
});
let response = response_future
.await
.map_err(|e| H2Error::UpstreamResponseFailed(e.to_string()))?;
let (parts, mut upstream_body) = response.into_parts();
let status = parts.status.as_u16();
debug!(status, "Received HTTP/2 response");
let mut client_body = client_respond
.send_response(Response::from_parts(parts, ()), false)
.map_err(|e| H2Error::ClientResponseFailed(e.to_string()))?;
let mut response_size = 0u64;
while let Some(chunk_result) = upstream_body.data().await {
match chunk_result {
Ok(data) => {
response_size += data.len() as u64;
client_body.reserve_capacity(data.len());
if let Err(e) = client_body.send_data(data.clone(), false) {
error!(error = ?e, "Failed to send response body chunk");
break;
}
if let Err(e) = upstream_body.flow_control().release_capacity(data.len()) {
error!(error = ?e, "Failed to release upstream flow control");
break;
}
}
Err(e) => {
error!(error = ?e, "Response body read error");
break;
}
}
}
let _ = client_body.send_data(Bytes::new(), true);
let duration_ms = request_start.elapsed().as_millis() as u64;
debug!(response_size, duration_ms, "Response body forwarded");
if let Some(storage) = log_storage {
let mut metadata = RequestMetadata {
timestamp: chrono::Utc::now().timestamp(),
host: target_host,
port: target_port,
method,
path: path.clone(),
http_version: "HTTP/2".to_string(),
status_code: Some(status),
request_size: 0, response_size: response_size as usize,
duration_ms,
tls_version: Some("TLS 1.3".to_string()),
mitm_applied: true,
bypass_reason: None,
};
if logging_policy.enable_pii_redaction {
metadata.path = PiiRedactor::redact(&metadata.path);
}
let storage_clone = Arc::clone(&storage);
tokio::spawn(async move {
if let Err(e) = storage_clone.log_request(&metadata).await {
warn!(error = %e, "Failed to log HTTP/2 request");
}
});
}
Ok(())
}
#[derive(Debug, Error)]
pub enum H2Error {
#[error("Client handshake failed: {0}")]
ClientHandshakeFailed(#[source] h2::Error),
#[error("Upstream handshake failed: {0}")]
UpstreamHandshakeFailed(#[source] h2::Error),
#[error("Accept stream failed: {0}")]
AcceptFailed(String),
#[error("Upstream request failed: {0}")]
UpstreamRequestFailed(String),
#[error("Upstream response failed: {0}")]
UpstreamResponseFailed(String),
#[error("Client response failed: {0}")]
ClientResponseFailed(String),
#[error("Stream reset: {0:?}")]
StreamReset(h2::Reason),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_http2_config_defaults() {
let config = Http2Config::default();
assert_eq!(config.initial_window_size, 65535);
assert_eq!(config.initial_connection_window_size, 1_048_576);
assert_eq!(config.max_concurrent_streams, 100);
assert_eq!(config.max_frame_size, 16384);
assert_eq!(config.enable_server_push, false);
}
#[test]
fn test_http2_config_custom() {
let config = Http2Config {
initial_window_size: 131072, max_concurrent_streams: 200,
..Default::default()
};
assert_eq!(config.initial_window_size, 131072);
assert_eq!(config.max_concurrent_streams, 200);
}
}