use crate::error::{ProxyError, Result};
use crate::tls_intercept::cert_cache::CertCache;
use rustls::server::ServerConfig;
use std::sync::Arc;
pub fn build_server_config(cert_cache: Arc<CertCache>) -> Result<Arc<ServerConfig>> {
let mut config =
ServerConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
.with_safe_default_protocol_versions()
.map_err(|e| ProxyError::Config(format!("tls_intercept TLS config error: {}", e)))?
.with_no_client_auth()
.with_cert_resolver(cert_cache);
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
#[cfg(test)]
// Test code: panicking on an unexpected failure is exactly how a test should fail.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::tls_intercept::ca::EphemeralCa;
    use rustls::pki_types::{CertificateDer, ServerName};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Hostname the test clients send via SNI and verify the server certificate against.
    const SERVER_NAME: &str = "api.example.com";

    /// Application data echoed through the TLS session in the success test.
    const PAYLOAD: &[u8] = b"hello";

    #[test]
    fn alpn_is_h1_only() {
        let ca = Arc::new(EphemeralCa::generate().unwrap());
        let cache = Arc::new(CertCache::new(ca));
        let config = build_server_config(cache).unwrap();
        assert_eq!(config.alpn_protocols, vec![b"http/1.1".to_vec()]);
    }

    /// Builds a `ring`-backed client config whose only trust anchors are `roots`.
    fn client_config_with_roots(roots: rustls::RootCertStore) -> Arc<rustls::ClientConfig> {
        Arc::new(
            rustls::ClientConfig::builder_with_provider(Arc::new(
                rustls::crypto::ring::default_provider(),
            ))
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(roots)
            .with_no_client_auth(),
        )
    }

    /// Client config that trusts exactly the PEM-encoded CA certificate `ca_pem`.
    fn client_config_trusting(ca_pem: &str) -> Arc<rustls::ClientConfig> {
        use rustls::pki_types::pem::PemObject;
        let mut roots = rustls::RootCertStore::empty();
        let cert = CertificateDer::from_pem_slice(ca_pem.as_bytes()).unwrap();
        roots.add(cert).unwrap();
        client_config_with_roots(roots)
    }

    /// Client config with an empty trust store, so no server certificate can validate.
    fn client_config_empty_trust() -> Arc<rustls::ClientConfig> {
        client_config_with_roots(rustls::RootCertStore::empty())
    }

    /// Runs one handshake between an intercepting server backed by `ca` and a client using
    /// `client_config`, and asserts that both the client and the server side fail.
    async fn assert_handshake_rejected(
        ca: Arc<EphemeralCa>,
        client_config: Arc<rustls::ClientConfig>,
    ) {
        let cache = Arc::new(CertCache::new(ca));
        let server_config = build_server_config(cache).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
            assert!(acceptor.accept(stream).await.is_err());
        });

        let connector = tokio_rustls::TlsConnector::from(client_config);
        let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from(SERVER_NAME).unwrap();
        assert!(connector.connect(server_name, tcp).await.is_err());
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_succeeds_when_client_trusts_ephemeral_ca() {
        let ca = Arc::new(EphemeralCa::generate().unwrap());
        let cache = Arc::new(CertCache::new(Arc::clone(&ca)));
        let server_config = build_server_config(Arc::clone(&cache)).unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Echo server. It reads exactly PAYLOAD.len() bytes: a single `read` may legally
        // return fewer, which would leave the client's `read_exact` below waiting forever.
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
            let mut tls = acceptor.accept(stream).await.unwrap();
            let mut buf = vec![0u8; PAYLOAD.len()];
            tls.read_exact(&mut buf).await.unwrap();
            tls.write_all(&buf).await.unwrap();
            tls.flush().await.unwrap();
        });

        let connector = tokio_rustls::TlsConnector::from(client_config_trusting(ca.cert_pem()));
        let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from(SERVER_NAME).unwrap();
        let mut tls = connector.connect(server_name, tcp).await.unwrap();
        tls.write_all(PAYLOAD).await.unwrap();
        tls.flush().await.unwrap();

        let mut buf = vec![0u8; PAYLOAD.len()];
        tls.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, PAYLOAD);
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_fails_when_client_pins_other_cert() {
        // The client trusts only an unrelated CA. This assumes each `EphemeralCa::generate()`
        // call yields an independent key pair, so the server's leaf cannot chain to it.
        let ca = Arc::new(EphemeralCa::generate().unwrap());
        let other_ca = EphemeralCa::generate().unwrap();
        assert_handshake_rejected(ca, client_config_trusting(other_ca.cert_pem())).await;
    }

    #[tokio::test]
    async fn handshake_fails_when_client_trusts_no_roots() {
        let ca = Arc::new(EphemeralCa::generate().unwrap());
        assert_handshake_rejected(ca, client_config_empty_trust()).await;
    }
}