use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{BodyExt, Empty, Full};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use rcgen::{CertificateParams, KeyPair};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::TcpStream;
use crate::rewrite::{
bad_gateway, bad_request, box_response, build_request, collect_and_rewrite_body,
http1_handshake, replace_tokens_in_bytes, rewrite_headers, try_or_bad_gateway, BoxBodyType,
CONNECT_TIMEOUT,
};
/// State for TLS interception: an in-memory root CA that signs per-host leaf
/// certificates, plus an optional extra CA file trusted for upstream TLS.
pub struct MitmConfig {
    /// CA parameters and signing key used to sign every minted leaf certificate.
    root_issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Optional PEM file whose certificates are added to the upstream trust
    /// store in addition to the bundled web PKI roots.
    pub(crate) optional_extra_ca: Option<std::path::PathBuf>,
}

impl MitmConfig {
    /// Generates a fresh self-signed root CA (CN `keyzero-mitm-ca`).
    ///
    /// Returns the config together with the CA certificate in PEM form, so
    /// that it can be installed as trusted on intercepted clients.
    ///
    /// Also tries to install rustls' `ring` crypto provider as the process
    /// default; if a provider is already installed, that error is ignored.
    ///
    /// # Errors
    /// Returns an error if key generation or self-signing fails.
    pub fn new(
        optional_extra_ca: Option<std::path::PathBuf>,
    ) -> Result<(Self, String), Box<dyn std::error::Error + Send + Sync>> {
        // Best-effort: another component may already have installed a provider.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let mut params = CertificateParams::default();
        params.distinguished_name = rcgen::DistinguishedName::new();
        params.distinguished_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("keyzero-mitm-ca".to_string()),
        );
        params.key_usages = vec![
            rcgen::KeyUsagePurpose::KeyCertSign,
            rcgen::KeyUsagePurpose::CrlSign,
        ];
        params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let signing_key = KeyPair::generate()?;
        let cert = params.self_signed(&signing_key)?;
        let ca_pem = cert.pem();
        // The issuer keeps the CA params + key so leaf certs can be signed later.
        let root_issuer = rcgen::Issuer::new(params, signing_key);
        Ok((
            Self {
                root_issuer,
                optional_extra_ca,
            },
            ca_pem,
        ))
    }

    /// Builds a rustls server config that presents a freshly generated leaf
    /// certificate for `host`, signed by the root CA.
    ///
    /// `host` may be a DNS name, an IP address, or a bracketed IPv6 literal as
    /// produced by `http::uri::Authority::host` (e.g. `[::1]`). The brackets are
    /// stripped so rcgen can recognise the value as an IP address rather than
    /// treating it as a DNS name.
    ///
    /// # Errors
    /// Returns an error if the name is rejected by rcgen, or if key generation,
    /// signing, or rustls config construction fails.
    fn server_config(
        &self,
        host: &str,
    ) -> Result<rustls::ServerConfig, Box<dyn std::error::Error + Send + Sync>> {
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let mut params = CertificateParams::new(vec![host.to_string()])?;
        params.distinguished_name = rcgen::DistinguishedName::new();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, host);
        // Mark the leaf explicitly as a TLS server certificate; some client
        // trust policies expect the serverAuth EKU on server certificates.
        params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
        let key_pair = KeyPair::generate()?;
        let cert = params.signed_by(&key_pair, &self.root_issuer)?;
        let cert_der = CertificateDer::from(cert.der().to_vec());
        let key_der = PrivateKeyDer::Pkcs8(rustls::pki_types::PrivatePkcs8KeyDer::from(
            key_pair.serialize_der(),
        ));
        let config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert_der], key_der)?;
        Ok(config)
    }
}
/// Handles an HTTP `CONNECT` request by intercepting the tunnel instead of
/// relaying raw bytes.
///
/// Immediately answers with an empty `200` response. A background task waits
/// for hyper to complete the upgrade, then does the following:
///
/// 1. Terminates TLS using a leaf certificate minted for the requested host.
/// 2. Serves the decrypted HTTP/1.1 requests on that connection.
/// 3. Hands each request to [`mitm_forward`].
///
/// Returns `400 Bad Request` if the request URI carries no authority. Errors
/// inside the background task (upgrade, certificate generation, TLS accept,
/// serving) are logged and end the tunnel.
pub async fn handle_connect_mitm(
    req: hyper::Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    mitm_config: Arc<MitmConfig>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    let authority = match req.uri().authority() {
        Some(a) => a.clone(),
        None => {
            return Ok(bad_request(
                "CONNECT must include authority (host:port)",
            ));
        }
    };
    let authority_str = authority.to_string();
    let host = authority.host().to_string();

    tokio::spawn(async move {
        // Completes only after the response returned below has been sent.
        let upgraded = match hyper::upgrade::on(req).await {
            Ok(u) => u,
            Err(e) => {
                tracing::error!("CONNECT MITM upgrade error: {}", e);
                return;
            }
        };
        let server_config = match mitm_config.server_config(&host) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!("CONNECT MITM cert gen for {}: {}", host, e);
                return;
            }
        };
        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
        let tls_stream = match acceptor.accept(TokioIo::new(upgraded)).await {
            Ok(s) => s,
            Err(e) => {
                tracing::error!("CONNECT MITM TLS accept for {}: {}", host, e);
                return;
            }
        };

        // `token_map`, `authority_str` and `mitm_config` are already owned by
        // this task; the service closure takes them over and hands each
        // request its own handles.
        let service = service_fn(move |req: Request<hyper::body::Incoming>| {
            let token_map = Arc::clone(&token_map);
            let authority = authority_str.clone();
            let mitm_config = Arc::clone(&mitm_config);
            async move { mitm_forward(req, token_map, &authority, &mitm_config).await }
        });

        if let Err(e) = http1::Builder::new()
            .serve_connection(TokioIo::new(tls_stream), service)
            .await
        {
            tracing::error!("CONNECT MITM serve for {}: {}", host, e);
        }
    });

    Ok(Response::new(
        Empty::<Bytes>::new()
            .map_err(|never: std::convert::Infallible| match never {})
            .boxed(),
    ))
}
async fn mitm_forward(
req: Request<hyper::body::Incoming>,
token_map: Arc<Vec<(String, String)>>,
authority: &str,
config: &MitmConfig,
) -> Result<Response<BoxBodyType>, hyper::Error> {
let (parts, body) = req.into_parts();
let body_bytes = body.collect().await?.to_bytes();
let modified_body = match collect_and_rewrite_body(&body_bytes, &token_map) {
Ok(b) => b,
Err(resp) => return Ok(resp),
};
let mut new_headers = rewrite_headers(&parts.headers, &token_map);
new_headers.insert(
http::header::CONTENT_LENGTH,
http::HeaderValue::from_str(&modified_body.len().to_string()).unwrap(),
);
let path = parts.uri.path();
let query = parts
.uri
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
let uri_str = format!("https://{}{}{}", authority, path, query);
let modified_uri_bytes = replace_tokens_in_bytes(uri_str.as_bytes(), &token_map);
let modified_uri_str = match String::from_utf8(modified_uri_bytes) {
Ok(s) => s,
Err(_) => uri_str.clone(),
};
let uri = match modified_uri_str
.parse::<http::Uri>()
.or_else(|_| uri_str.parse::<http::Uri>())
{
Ok(u) => u,
Err(_) => return Ok(bad_gateway("bad upstream URI")),
};
let new_req = match build_request(
parts.method,
&uri,
new_headers,
Full::new(Bytes::from(modified_body)),
) {
Ok(r) => r,
Err(resp) => return Ok(resp),
};
let (host, port) = authority
.split_once(':')
.map(|(h, p)| (h, p.parse().unwrap_or(443)))
.unwrap_or((authority, 443));
let tcp = match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect((host, port))).await {
Ok(r) => try_or_bad_gateway!(r, "upstream connection failed"),
Err(_) => return Ok(bad_gateway("upstream connect timeout")),
};
let mut root_store =
rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if let Some(path) = &config.optional_extra_ca {
let pem = std::fs::read(path).map_err(|e| {
tracing::error!("MITM extra CA read error: {}", e);
});
if let Ok(pem) = pem {
for cert in rustls_pemfile::certs(&mut std::io::Cursor::new(pem)).flatten() {
let _ = root_store.add(cert);
}
}
}
let client_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
let domain = match host.to_string().try_into() {
Ok(d) => d,
Err(_) => return Ok(bad_gateway("invalid upstream host")),
};
let tls = try_or_bad_gateway!(connector.connect(domain, tcp).await, "upstream TLS failed");
let io = TokioIo::new(tls);
let mut sender = match http1_handshake(io).await {
Ok(s) => s,
Err(resp) => return Ok(resp),
};
let resp = try_or_bad_gateway!(sender.send_request(new_req).await, "upstream request failed");
Ok(box_response(resp))
}