use crate::error::ProxyError;
use crate::upstream::UpstreamStream;
use base64::Engine as _;
use rcgen::{CertificateParams, KeyPair};
use rustls::ClientConfig;
use rustls::RootCertStore;
use rustls::ServerConfig;
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector};
/// Certificate-authority material used to mint per-host leaf certificates for interception.
pub struct CaConfig {
    /// DER encoding of the CA certificate exactly as it appeared in the PEM input.
    /// This (not `signing_cert`) is what gets appended to each leaf's certificate chain.
    pub cert_der: Vec<u8>,
    /// The CA's private key, used to sign every leaf certificate.
    key: KeyPair,
    /// rcgen view of the CA, rebuilt from the parsed CA parameters and self-signed with `key`.
    /// rcgen's `signed_by` expects the issuer as a `Certificate`, so one is reconstructed here;
    /// its own signature bytes differ from the original, which is why `cert_der` is kept too.
    signing_cert: rcgen::Certificate,
}

impl CaConfig {
    /// Loads a CA from a PEM-encoded certificate and a PEM-encoded private key.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Mitm` if the certificate PEM cannot be decoded or parsed as a CA
    /// certificate, if the key PEM cannot be loaded, or if rcgen fails to rebuild the signing
    /// certificate from the parsed parameters.
    pub fn from_pem(cert_pem: &str, key_pem: &str) -> Result<Self, ProxyError> {
        let cert_der = pem_to_der(cert_pem, "CERTIFICATE")?;
        let key = KeyPair::from_pem(key_pem)
            .map_err(|e| ProxyError::Mitm(format!("CA key load: {e}")))?;
        let signing_cert = CertificateParams::from_ca_cert_pem(cert_pem)
            .map_err(|e| ProxyError::Mitm(format!("CA cert parse: {e}")))?
            .self_signed(&key)
            .map_err(|e| ProxyError::Mitm(format!("CA signing cert init: {e}")))?;
        Ok(Self {
            signing_cert,
            key,
            cert_der,
        })
    }
}
/// Extracts the first PEM block labelled `label` from `pem` and base64-decodes its body.
///
/// The header and footer lines are matched after trimming surrounding whitespace. Each body
/// line is also trimmed before concatenation, so indented or trailing-space PEM still decodes.
///
/// # Errors
///
/// Returns `ProxyError::Mitm` in each of these cases:
/// - no `-----BEGIN {label}-----` line exists;
/// - the block has no matching `-----END {label}-----` line;
/// - the body is empty;
/// - the body is not valid standard base64.
///
/// A missing or empty block is reported as an error rather than decoding to an empty `Vec`,
/// so callers never continue with an empty certificate.
fn pem_to_der(pem: &str, label: &str) -> Result<Vec<u8>, ProxyError> {
    let header = format!("-----BEGIN {label}-----");
    let footer = format!("-----END {label}-----");
    let mut lines = pem.lines().map(str::trim);
    // `any` stops right after the header line, leaving the iterator positioned on the body.
    if !lines.any(|line| line.starts_with(&header)) {
        return Err(ProxyError::Mitm(format!("PEM block not found ({label})")));
    }
    let mut b64 = String::new();
    loop {
        match lines.next() {
            None => {
                return Err(ProxyError::Mitm(format!(
                    "PEM block not terminated ({label})"
                )));
            }
            Some(line) if line.starts_with(&footer) => break,
            Some(line) => b64.push_str(line),
        }
    }
    if b64.is_empty() {
        return Err(ProxyError::Mitm(format!("PEM block is empty ({label})")));
    }
    base64::engine::general_purpose::STANDARD
        .decode(&b64)
        .map_err(|e| ProxyError::Mitm(format!("PEM base64 decode ({label}): {e}")))
}
/// Builds a rustls server configuration presenting a freshly generated leaf certificate for `host`.
///
/// A new key pair is generated on every call. The leaf lists `host` as its subject alternative
/// name and is signed by `ca`. The chain sent to the client is `[leaf, CA certificate]`.
///
/// # Errors
///
/// Returns `ProxyError::Mitm` if any of these steps fails:
/// - key generation;
/// - building the certificate parameters for `host`;
/// - signing the leaf;
/// - assembling the rustls `ServerConfig`.
fn server_config_for(host: &str, ca: &CaConfig) -> Result<Arc<ServerConfig>, ProxyError> {
    let leaf_key = KeyPair::generate()
        .map_err(|e| ProxyError::Mitm(format!("domain key gen: {e}")))?;
    let leaf_cert = CertificateParams::new(vec![host.to_string()])
        .map_err(|e| ProxyError::Mitm(format!("domain params: {e}")))?
        .signed_by(&leaf_key, &ca.signing_cert, &ca.key)
        .map_err(|e| ProxyError::Mitm(format!("domain cert sign: {e}")))?;
    let leaf_der = CertificateDer::from(leaf_cert.der().to_vec());
    let ca_der = CertificateDer::from(ca.cert_der.clone());
    let pkcs8 = PrivatePkcs8KeyDer::from(leaf_key.serialize_der());
    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![leaf_der, ca_der], pkcs8.into())
        .map(Arc::new)
        .map_err(|e| ProxyError::Mitm(format!("ServerConfig: {e}")))
}
/// Builds the rustls client configuration used for the upstream (proxy-to-origin) TLS leg.
///
/// Trust anchors come from `rustls_native_certs::load_native_certs`. Certificates that fail
/// to load, or that webpki cannot parse, are logged and skipped instead of failing the whole
/// configuration. Native stores commonly contain a few such entries, and rejecting all of
/// them because of one would disable interception entirely.
///
/// NOTE(review): the native store is loaded on every call. If this becomes a hot path, consider
/// caching the resulting config.
///
/// # Errors
///
/// Returns `ProxyError::Mitm` if no usable root certificate remains.
fn client_config() -> Result<Arc<ClientConfig>, ProxyError> {
    let native = rustls_native_certs::load_native_certs();
    for error in native.errors {
        log::warn!("proxy mitm: failed to load one native root certificate: {error}");
    }
    let mut roots = RootCertStore::empty();
    // `add_parsable_certificates` skips unparsable entries instead of erroring on the first one.
    let (_added, ignored) = roots.add_parsable_certificates(native.certs);
    if ignored > 0 {
        log::warn!("proxy mitm: ignored {ignored} unparsable native root certificate(s)");
    }
    if roots.is_empty() {
        return Err(ProxyError::Mitm(
            "no native root certificates available for upstream TLS validation".to_string(),
        ));
    }
    Ok(Arc::new(
        ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth(),
    ))
}
/// TLS stream between the proxy and the original client; the proxy plays the TLS server role.
pub type ClientSide = tokio_rustls::server::TlsStream<TcpStream>;
/// TLS stream between the proxy and the upstream origin; the proxy plays the TLS client role.
pub type ServerSide = tokio_rustls::client::TlsStream<UpstreamStream>;
/// Performs both TLS handshakes needed to intercept a connection to `host`.
///
/// The client-facing side is terminated with a leaf certificate for `host`, minted on the fly
/// and signed by `ca`. The upstream side is a regular TLS client connection to `host`,
/// validated against the native root store. The two handshakes run concurrently; if either
/// fails, the other future is dropped.
///
/// # Errors
///
/// Returns `ProxyError::Mitm` in any of these cases:
/// - `host` is not a valid TLS server name;
/// - certificate or config construction fails;
/// - either handshake fails (the message names the failing side).
pub async fn intercept(
    client: TcpStream,
    host: &str,
    upstream: UpstreamStream,
    ca: &CaConfig,
) -> Result<(ClientSide, ServerSide), ProxyError> {
    // Validate the name first so a bad host fails before paying for leaf key generation.
    let server_name: ServerName<'static> = host
        .to_owned()
        .try_into()
        .map_err(|_| ProxyError::Mitm(format!("invalid server name: {host}")))?;
    let acceptor = TlsAcceptor::from(server_config_for(host, ca)?);
    let connector = TlsConnector::from(client_config()?);
    let client_handshake = async {
        acceptor
            .accept(client)
            .await
            .map_err(|e| ProxyError::Mitm(format!("client-side TLS handshake: {e}")))
    };
    let upstream_handshake = async {
        connector
            .connect(server_name, upstream)
            .await
            .map_err(|e| ProxyError::Mitm(format!("upstream TLS handshake with {host}: {e}")))
    };
    tokio::try_join!(client_handshake, upstream_handshake)
}