use super::close_codes::CloseReasonCode;
use super::downstream_tls::accept_downstream_tls;
use super::event_emitters::{
emit_stream_closed, emit_tls_event, emit_tls_event_with_cache,
emit_tls_event_with_negotiated_alpn,
};
use super::flow_hooks::FlowHooks;
use super::flow_intercept_http1::relay_http1_mitm_loop;
use super::flow_intercept_tls_failure::fail_tls_and_close;
use super::http2_stream_relay::relay_http2_connection;
use super::http2_stream_relay_http1::relay_http2_downstream_to_http1_upstream;
use super::http_body_relay::write_proxy_response;
use super::io_timeouts::write_all_with_idle_timeout;
use super::route_planner_model::{RouteBinding, RouteConnectIntent, UpstreamRequestTargetMode};
use super::route_planner_transport::connect_via_route;
use super::runtime_governor;
use super::tls_diagnostics::TlsDiagnostics;
use super::tls_learning::TlsLearningGuardrails;
use super::tls_profile_mapping::map_upstream_sni_mode;
use super::BufferedConn;
use crate::engine::MitmEngine;
use crate::observe::{EventConsumer, EventType, FlowContext};
use crate::policy::PolicyEngine;
use crate::protocol::{protocol_from_negotiated_alpn, ApplicationProtocol};
use crate::tls::{
resolve_upstream_server_name, MitmCertificateStore,
UpstreamClientAuthMode as TlsUpstreamClientAuthMode, UpstreamTlsConfigCache,
};
use crate::types::ProcessInfo;
use std::io;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
/// Intercepts (MITMs) a tunnelled TLS connection from the client (`downstream`) and
/// relays HTTP traffic between it and the target server.
///
/// Steps, in order:
/// 1. Opens the upstream TCP connection via `route` (`RouteConnectIntent::TargetTunnel`).
///    On failure it writes `502 Bad Gateway` to the client, emits a stream-closed event
///    with `CloseReasonCode::UpstreamConnectFailed`, and returns `Ok(())`.
/// 2. Writes `HTTP/1.1 200 Connection Established` to the client. This is the reply to
///    an HTTP CONNECT; presumably the caller has already consumed the CONNECT request
///    (confirm).
/// 3. Resolves the upstream TLS server name. It then obtains a leaf server config for
///    `server_host` from `cert_store` and accepts TLS from the client.
/// 4. Gets an upstream client config from `upstream_tls_cache` and performs the
///    upstream TLS handshake.
/// 5. Dispatches on the negotiated protocols:
///    - HTTP/2 downstream goes to `relay_http2_connection` if upstream also negotiated
///      HTTP/2, otherwise to `relay_http2_downstream_to_http1_upstream`.
///    - Everything else goes to `relay_http1_mitm_loop` with origin-form request
///      targets.
///
/// Every failure in steps 3 and 4 is handed to `fail_tls_and_close`, and its result is
/// returned. TLS handshake-started and handshake-succeeded events are emitted for both
/// the downstream and upstream sides.
///
/// Per-flow overrides taken from `policy_override_state`:
/// - `disable_h2` turns HTTP/2 off for this flow.
/// - `skip_upstream_verify` skips upstream certificate verification, in addition to
///   the engine-wide `upstream_tls_insecure_skip_verify`.
/// - `strict_header_mode` is forwarded to the two relays that talk HTTP/1 upstream.
///
/// Other forwarded parameters:
/// - `max_http_head_bytes` goes to those same two relays.
/// - `process_info` is only forwarded on the HTTP/2-downstream paths.
///
/// # Errors
///
/// Propagates I/O errors from writing the 502 or 200 response to the client. Otherwise
/// returns whatever the selected relay or `fail_tls_and_close` returns.
// The function takes all per-flow dependencies explicitly (13 parameters), which
// exceeds clippy's default `too_many_arguments` limit.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn intercept_http_connection<P, S>(
engine: Arc<MitmEngine<P, S>>,
cert_store: Arc<MitmCertificateStore>,
runtime_governor: Arc<runtime_governor::RuntimeGovernor>,
tls_diagnostics: Arc<TlsDiagnostics>,
tls_learning: Arc<TlsLearningGuardrails>,
flow_hooks: Arc<dyn FlowHooks>,
upstream_tls_cache: Arc<UpstreamTlsConfigCache>,
tunnel_context: FlowContext,
process_info: Option<ProcessInfo>,
route: RouteBinding,
policy_override_state: crate::policy::PolicyOverrideState,
mut downstream: TcpStream,
max_http_head_bytes: usize,
) -> io::Result<()>
where
P: PolicyEngine + Send + Sync + 'static,
S: EventConsumer + Send + Sync + 'static,
{
// Effective settings for this flow: the engine-wide config combined with the
// per-flow policy overrides.
let http2_enabled_for_flow = engine.config.http2_enabled && !policy_override_state.disable_h2;
let skip_upstream_verify_for_flow = engine.config.upstream_tls_insecure_skip_verify
|| policy_override_state.skip_upstream_verify;
// Connect upstream before acknowledging the tunnel, so that a connect failure can
// still be reported to the client as a 502 instead of a 200.
let upstream_tcp = match connect_via_route(&route, RouteConnectIntent::TargetTunnel).await {
Ok(stream) => stream,
Err(error) => {
let detail = format!(
"upstream_connect_failed[{}]: {error}",
route.route_mode_label()
);
// NOTE(review): if this write fails, `?` returns before emit_stream_closed
// runs, so no close event is emitted here. Confirm the caller covers that case.
write_proxy_response(&mut downstream, "502 Bad Gateway", &detail).await?;
emit_stream_closed(
&engine,
tunnel_context,
CloseReasonCode::UpstreamConnectFailed,
Some(error.to_string()),
None,
None,
);
return Ok(());
}
};
write_all_with_idle_timeout(
&mut downstream,
b"HTTP/1.1 200 Connection Established\r\n\r\n",
"mitm_connect_established_write",
)
.await?;
// No ALPN has been negotiated yet, so handshake-phase events are tagged as HTTP/1.
let handshake_context = FlowContext {
protocol: ApplicationProtocol::Http1,
..tunnel_context.clone()
};
// Resolve the server name to present upstream before doing any downstream TLS work.
// A resolution failure is reported against the "upstream" side.
let upstream_sni_mode = map_upstream_sni_mode(engine.config.upstream_sni_mode);
let server_name =
match resolve_upstream_server_name(&handshake_context.server_host, upstream_sni_mode) {
Ok(value) => value,
Err(error) => {
let detail = format!("invalid server name for upstream TLS: {error}");
return fail_tls_and_close(
&engine,
&flow_hooks,
&tls_diagnostics,
&tls_learning,
handshake_context.clone(),
tunnel_context,
"upstream",
detail,
)
.await;
}
};
// Leaf server config issued for server_host and presented to the client. Its
// cache_status is reported in the downstream handshake-started event below.
let issued_server_config = match cert_store
.server_config_for_host_with_http2(&handshake_context.server_host, http2_enabled_for_flow)
{
Ok(config) => config,
Err(error) => {
return fail_tls_and_close(
&engine,
&flow_hooks,
&tls_diagnostics,
&tls_learning,
handshake_context.clone(),
tunnel_context,
"downstream",
format!("downstream leaf issuance error: {error}"),
)
.await;
}
};
emit_tls_event_with_cache(
&engine,
EventType::TlsHandshakeStarted,
handshake_context.clone(),
"downstream",
issued_server_config.cache_status.as_str(),
);
// Terminate the client's TLS. The raw `downstream` TcpStream is consumed here.
let downstream_tls = match accept_downstream_tls(
engine.config.downstream_tls_backend,
downstream,
&issued_server_config,
http2_enabled_for_flow,
)
.await
{
Ok(stream) => stream,
Err(error) => {
return fail_tls_and_close(
&engine,
&flow_hooks,
&tls_diagnostics,
&tls_learning,
handshake_context.clone(),
tunnel_context,
"downstream",
format!("downstream handshake failed: {error}"),
)
.await;
}
};
// Derive the client-side protocol from the ALPN the client negotiated.
// http2_enabled_for_flow is passed through; presumably HTTP/2 is only reported
// when it is enabled (see protocol_from_negotiated_alpn).
let downstream_alpn = downstream_tls.negotiated_alpn();
let downstream_protocol =
protocol_from_negotiated_alpn(downstream_alpn.as_deref(), http2_enabled_for_flow);
let downstream_context = FlowContext {
protocol: downstream_protocol,
..tunnel_context.clone()
};
emit_tls_event_with_negotiated_alpn(
&engine,
EventType::TlsHandshakeSucceeded,
downstream_context.clone(),
"downstream",
downstream_alpn.as_deref(),
);
// Offer h2 upstream only when the client itself negotiated h2. This function has no
// HTTP/1-downstream -> HTTP/2-upstream relay path.
let should_offer_http2_upstream =
http2_enabled_for_flow && downstream_protocol == ApplicationProtocol::Http2;
// Upstream client config for this verify mode, h2 offer and host, obtained through
// the shared cache's get_or_build.
let client_config = match upstream_tls_cache.get_or_build(
skip_upstream_verify_for_flow,
should_offer_http2_upstream,
&handshake_context.server_host,
) {
Ok(value) => value,
Err(error) => {
let detail = format!("upstream TLS config build failed: {error}");
return fail_tls_and_close(
&engine,
&flow_hooks,
&tls_diagnostics,
&tls_learning,
downstream_context.clone(),
tunnel_context,
"upstream",
detail,
)
.await;
}
};
let connector = TlsConnector::from(client_config);
// Tag the upstream handshake-started event with the protocol about to be offered.
let upstream_start_context = FlowContext {
protocol: if should_offer_http2_upstream {
ApplicationProtocol::Http2
} else {
ApplicationProtocol::Http1
},
..tunnel_context.clone()
};
emit_tls_event(
&engine,
EventType::TlsHandshakeStarted,
upstream_start_context,
"upstream",
);
// NOTE: an upstream handshake failure is reported with downstream_context (tagged
// with the downstream protocol), not with upstream_start_context.
let upstream_tls = match connector.connect(server_name.clone(), upstream_tcp).await {
Ok(stream) => stream,
Err(error) => {
return fail_tls_and_close(
&engine,
&flow_hooks,
&tls_diagnostics,
&tls_learning,
downstream_context.clone(),
tunnel_context,
"upstream",
format!("upstream handshake failed: {error}"),
)
.await;
}
};
// ALPN negotiated with the origin. `get_ref().1` is the rustls client connection
// wrapped by the tokio-rustls stream.
let upstream_alpn = upstream_tls
.get_ref()
.1
.alpn_protocol()
.map(ToOwned::to_owned);
let upstream_protocol =
protocol_from_negotiated_alpn(upstream_alpn.as_deref(), should_offer_http2_upstream);
let upstream_context = FlowContext {
protocol: upstream_protocol,
..tunnel_context.clone()
};
emit_tls_event_with_negotiated_alpn(
&engine,
EventType::TlsHandshakeSucceeded,
upstream_context.clone(),
"upstream",
upstream_alpn.as_deref(),
);
// HTTP/2 client: relay natively if the origin also negotiated h2. Otherwise
// translate between HTTP/2 downstream and HTTP/1 upstream.
if downstream_protocol == ApplicationProtocol::Http2 {
let max_header_list_size = engine.config.http2_max_header_list_size;
return if upstream_protocol == ApplicationProtocol::Http2 {
relay_http2_connection(
engine,
Arc::clone(&runtime_governor),
flow_hooks,
tunnel_context,
process_info,
downstream_tls,
upstream_tls,
max_header_list_size,
)
.await
} else {
// route, connector and server_name are forwarded too. Presumably the relay
// uses them to open additional upstream connections (verify in the relay).
relay_http2_downstream_to_http1_upstream(
engine,
Arc::clone(&runtime_governor),
flow_hooks,
tunnel_context,
process_info,
downstream_tls,
upstream_tls,
route,
connector,
server_name,
max_http_head_bytes,
max_header_list_size,
policy_override_state.strict_header_mode,
)
.await
};
}
// HTTP/1 client: buffered relay loop that sends origin-form request targets upstream.
let downstream_conn = BufferedConn::new(downstream_tls);
let upstream_conn = BufferedConn::new(upstream_tls);
relay_http1_mitm_loop(
engine,
runtime_governor,
flow_hooks,
tunnel_context,
UpstreamRequestTargetMode::OriginForm,
downstream_conn,
upstream_conn,
max_http_head_bytes,
policy_override_state.strict_header_mode,
)
.await
}
/// Loads the upstream client-auth certificate and private key as raw PEM bytes.
///
/// Returns `Ok(None)` unless *both* `cert_path` and `key_path` are set. When both are
/// set, both files are read. The key file is read even if the certificate read failed.
///
/// # Errors
///
/// If either read fails, the outcome depends on `auth_mode`:
/// - `IfRequested` or `Never`: the failure is logged at debug level and `Ok(None)` is
///   returned.
/// - Any other mode: the failure message is returned as `Err`. If both reads fail, the
///   certificate's error is reported.
pub(crate) fn load_upstream_client_auth_pem(
    cert_path: &Option<String>,
    key_path: &Option<String>,
    auth_mode: TlsUpstreamClientAuthMode,
) -> Result<Option<(Vec<u8>, Vec<u8>)>, String> {
    // Both paths must be configured; otherwise client auth material is simply absent.
    let (cert_path, key_path) = match (cert_path.as_ref(), key_path.as_ref()) {
        (Some(cert), Some(key)) => (cert, key),
        _ => return Ok(None),
    };

    let cert_read = std::fs::read(cert_path)
        .map_err(|error| format!("read cert path {cert_path} failed: {error}"));
    let key_read = std::fs::read(key_path)
        .map_err(|error| format!("read key path {key_path} failed: {error}"));

    // Success short-circuits. Otherwise keep the first failure, preferring the
    // certificate's error over the key's.
    let e = match (cert_read, key_read) {
        (Ok(cert_pem), Ok(key_pem)) => return Ok(Some((cert_pem, key_pem))),
        (Err(e), _) | (Ok(_), Err(e)) => e,
    };

    let client_auth_optional = auth_mode == TlsUpstreamClientAuthMode::IfRequested
        || auth_mode == TlsUpstreamClientAuthMode::Never;
    if !client_auth_optional {
        return Err(e);
    }
    tracing::debug!("upstream client auth PEM not loaded (non-required): {e}");
    Ok(None)
}