#![allow(dead_code, unused_results, clippy::unwrap_used, clippy::expect_used)]
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use http_body_util::Full;
use hyper::body::Bytes;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
/// Plain HTTP/1.1 server bound to an ephemeral port on `127.0.0.1`, for tests.
///
/// Every request is answered by the synchronous handler passed to
/// [`TestServer::start`]. Dropping the value stops the accept loop and aborts the
/// tasks serving connections that were already accepted, so keep-alive connections
/// do not outlive the server.
pub struct TestServer {
    /// Address the listener is bound to (`127.0.0.1:<ephemeral port>`).
    addr: SocketAddr,
    /// Signals the background accept loop to stop; taken in `Drop`.
    shutdown: Option<oneshot::Sender<()>>,
}

impl TestServer {
    /// Binds `127.0.0.1:0`, spawns the accept loop on the current Tokio runtime and
    /// returns once the listener is bound.
    ///
    /// Each request is passed to `handler`, and its return value is sent back as the
    /// response. The handler runs synchronously on the connection's task, so it should
    /// not block for long.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound or its local address cannot be read.
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn start<F>(handler: F) -> Self
    where
        F: Fn(Request<hyper::body::Incoming>) -> Response<Full<Bytes>> + Send + Sync + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handler = Arc::new(handler);
        let (tx, mut rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            // Connection tasks are owned by this set rather than detached, so they are
            // aborted when the loop below exits and the set is dropped.
            let mut connections: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        // Accept errors are ignored; the loop keeps listening.
                        if let Ok((stream, _)) = accept_result {
                            let handler = Arc::clone(&handler);
                            let io = TokioIo::new(stream);
                            connections.spawn(async move {
                                // The handler is synchronous, so call it directly and hand
                                // hyper an already-computed response.
                                let service = service_fn(move |req| {
                                    let response = handler(req);
                                    async move { Ok::<_, Infallible>(response) }
                                });
                                // Connection-level errors (client hang-ups etc.) are
                                // irrelevant to tests.
                                let _ = http1::Builder::new().serve_connection(io, service).await;
                            });
                        }
                    }
                    // Reap finished connection tasks so the set does not grow unbounded.
                    Some(_) = connections.join_next() => {}
                    // Resolves on the explicit send in `Drop`, or if the sender is dropped.
                    _ = &mut rx => break,
                }
            }
        });
        Self { addr, shutdown: Some(tx) }
    }

    /// Returns the absolute `http://` URL for `path` on this server.
    ///
    /// `path` is appended verbatim, so it should normally start with `/`.
    pub fn url(&self, path: &str) -> String {
        format!("http://{}{path}", self.addr)
    }

    /// Returns the socket address the server is listening on.
    pub const fn addr(&self) -> SocketAddr {
        self.addr
    }
}

impl Drop for TestServer {
    /// Tells the accept loop to stop. The send result is ignored because the
    /// background task may already have exited.
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown.take() {
            let _ = tx.send(());
        }
    }
}
pub struct TestCerts {
pub ca_cert_pem: Vec<u8>,
pub server_cert_pem: Vec<u8>,
pub server_key_pem: Vec<u8>,
pub client_cert_pem: Vec<u8>,
pub client_key_pem: Vec<u8>,
pub server_config: Arc<rustls::ServerConfig>,
pub mtls_server_config: Arc<rustls::ServerConfig>,
}
impl TestCerts {
pub fn generate() -> Self {
let mut ca_params = rcgen::CertificateParams::new(Vec::<String>::new()).unwrap();
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
ca_params.distinguished_name.push(rcgen::DnType::CommonName, "Test CA");
let ca_key = rcgen::KeyPair::generate().unwrap();
let ca_cert = ca_params.self_signed(&ca_key).unwrap();
let mut server_params =
rcgen::CertificateParams::new(vec!["localhost".to_string(), "127.0.0.1".to_string()])
.unwrap();
server_params.distinguished_name.push(rcgen::DnType::CommonName, "localhost");
let server_key = rcgen::KeyPair::generate().unwrap();
let server_cert = server_params.signed_by(&server_key, &ca_cert, &ca_key).unwrap();
let mut client_params = rcgen::CertificateParams::new(Vec::<String>::new()).unwrap();
client_params.distinguished_name.push(rcgen::DnType::CommonName, "Test Client");
let client_key = rcgen::KeyPair::generate().unwrap();
let client_cert = client_params.signed_by(&client_key, &ca_cert, &ca_key).unwrap();
let ca_cert_pem = ca_cert.pem().into_bytes();
let server_cert_pem = server_cert.pem().into_bytes();
let server_key_pem = server_key.serialize_pem().into_bytes();
let client_cert_pem = client_cert.pem().into_bytes();
let client_key_pem = client_key.serialize_pem().into_bytes();
let server_cert_der = rustls::pki_types::CertificateDer::from(server_cert.der().to_vec());
let server_key_der =
rustls::pki_types::PrivateKeyDer::try_from(server_key.serialize_der()).unwrap();
let server_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![server_cert_der.clone()], server_key_der.clone_key())
.unwrap();
let mut ca_root_store = rustls::RootCertStore::empty();
let ca_cert_der = rustls::pki_types::CertificateDer::from(ca_cert.der().to_vec());
ca_root_store.add(ca_cert_der).unwrap();
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(ca_root_store)).build().unwrap();
let mtls_server_config = rustls::ServerConfig::builder()
.with_client_cert_verifier(client_verifier)
.with_single_cert(vec![server_cert_der], server_key_der.clone_key())
.unwrap();
Self {
ca_cert_pem,
server_cert_pem,
server_key_pem,
client_cert_pem,
client_key_pem,
server_config: Arc::new(server_config),
mtls_server_config: Arc::new(mtls_server_config),
}
}
pub fn write_ca_cert(&self) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
std::io::Write::write_all(&mut f, &self.ca_cert_pem).unwrap();
f
}
pub fn write_client_cert(&self) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
std::io::Write::write_all(&mut f, &self.client_cert_pem).unwrap();
f
}
pub fn write_client_key(&self) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
std::io::Write::write_all(&mut f, &self.client_key_pem).unwrap();
f
}
}
/// HTTPS (HTTP/1.1 over TLS) server bound to an ephemeral port on `127.0.0.1`, for tests.
///
/// TLS is terminated with the [`rustls::ServerConfig`] passed to
/// [`HttpsTestServer::start`], for example [`TestCerts::server_config`] or
/// [`TestCerts::mtls_server_config`]. Dropping the value stops the accept loop and
/// aborts the tasks serving already-accepted connections, including handshakes still
/// in progress.
pub struct HttpsTestServer {
    /// Address the listener is bound to (`127.0.0.1:<ephemeral port>`).
    addr: SocketAddr,
    /// Signals the background accept loop to stop; taken in `Drop`.
    shutdown: Option<oneshot::Sender<()>>,
}

impl HttpsTestServer {
    /// Binds `127.0.0.1:0`, spawns the TLS accept loop on the current Tokio runtime and
    /// returns once the listener is bound.
    ///
    /// Each connection first performs a TLS handshake using `server_config`. If the
    /// handshake fails, the connection is dropped. Otherwise each request is passed to
    /// `handler`, and its return value is sent back as the response. The handler runs
    /// synchronously on the connection's task, so it should not block for long.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound or its local address cannot be read.
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn start<F>(server_config: Arc<rustls::ServerConfig>, handler: F) -> Self
    where
        F: Fn(Request<hyper::body::Incoming>) -> Response<Full<Bytes>> + Send + Sync + 'static,
    {
        let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handler = Arc::new(handler);
        let (tx, mut rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            // Connection tasks are owned by this set rather than detached, so they are
            // aborted when the loop below exits and the set is dropped.
            let mut connections: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        // Accept errors are ignored; the loop keeps listening.
                        if let Ok((stream, _)) = accept_result {
                            let handler = Arc::clone(&handler);
                            let acceptor = tls_acceptor.clone();
                            connections.spawn(async move {
                                // A failed handshake simply drops the connection.
                                let Ok(tls_stream) = acceptor.accept(stream).await else {
                                    return;
                                };
                                // The handler is synchronous, so call it directly and hand
                                // hyper an already-computed response.
                                let service = service_fn(move |req| {
                                    let response = handler(req);
                                    async move { Ok::<_, Infallible>(response) }
                                });
                                // Connection-level errors (client hang-ups etc.) are
                                // irrelevant to tests.
                                let _ = http1::Builder::new()
                                    .serve_connection(TokioIo::new(tls_stream), service)
                                    .await;
                            });
                        }
                    }
                    // Reap finished connection tasks so the set does not grow unbounded.
                    Some(_) = connections.join_next() => {}
                    // Resolves on the explicit send in `Drop`, or if the sender is dropped.
                    _ = &mut rx => break,
                }
            }
        });
        Self { addr, shutdown: Some(tx) }
    }

    /// Returns the absolute `https://` URL for `path` on this server.
    ///
    /// `path` is appended verbatim, so it should normally start with `/`.
    pub fn url(&self, path: &str) -> String {
        format!("https://{}{path}", self.addr)
    }

    /// Returns the socket address the server is listening on.
    pub const fn addr(&self) -> SocketAddr {
        self.addr
    }
}

impl Drop for HttpsTestServer {
    /// Tells the accept loop to stop. The send result is ignored because the
    /// background task may already have exited.
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown.take() {
            let _ = tx.send(());
        }
    }
}