mod common;
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use common::*;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, StatusCode};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use palisade::{Config, TlsConfig, UpstreamConfig, handle_request};
use tokio::net::TcpListener;
#[tokio::test]
async fn tls_origination_forwards_to_https_upstream() {
init_tracing();
let (cert_pem, key_pem) = generate_test_cert();
let (addr, _shutdown) = start_tls_backend(
&cert_pem,
&key_pem,
StatusCode::OK,
"text/plain",
"tls-hello",
)
.await;
let config = Arc::new(
Config {
upstreams: vec![UpstreamConfig {
address: format!("https://localhost:{}", addr.port()),
weight: 1,
}],
..Default::default()
}
.into_runtime()
.expect("test config"),
);
let client = test_https_client(&cert_pem);
let req = Request::builder()
.method(Method::GET)
.uri(format!("https://localhost:{}/", addr.port()))
.body(http_body_util::Empty::<Bytes>::new())
.unwrap();
let balancer = test_balancer(&config);
let resp = handle_request(req, client, config, balancer, test_addr(), None)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = collect_body(resp.into_body()).await;
assert_eq!(body, Bytes::from("tls-hello"));
}
#[tokio::test]
async fn tls_termination_acceptor_loads_valid_certs() {
let (cert_pem, key_pem) = generate_test_cert();
let cert_path = write_temp_file("cert", &cert_pem);
let key_path = write_temp_file("key", &key_pem);
let tls_config = TlsConfig {
cert_path: cert_path.to_str().unwrap().into(),
key_path: key_path.to_str().unwrap().into(),
};
let result = palisade::tls::build_tls_acceptor(&tls_config);
assert!(result.is_ok(), "should build TLS acceptor from valid certs");
std::fs::remove_file(&cert_path).ok();
std::fs::remove_file(&key_path).ok();
}
#[tokio::test]
async fn tls_termination_rejects_missing_cert_file() {
let tls_config = TlsConfig {
cert_path: "/nonexistent/cert.pem".into(),
key_path: "/nonexistent/key.pem".into(),
};
let result = palisade::tls::build_tls_acceptor(&tls_config);
assert!(result.is_err(), "should fail with missing cert file");
}
#[tokio::test]
async fn tls_termination_serves_https_connection() {
init_tracing();
let (cert_pem, key_pem) = generate_test_cert();
let cert_path = write_temp_file("e2e-cert", &cert_pem);
let key_path = write_temp_file("e2e-key", &key_pem);
let tls_config = TlsConfig {
cert_path: cert_path.to_str().unwrap().into(),
key_path: key_path.to_str().unwrap().into(),
};
let tls_acceptor = palisade::tls::build_tls_acceptor(&tls_config).unwrap();
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
.await
.unwrap();
let proxy_addr = listener.local_addr().unwrap();
let (backend_addr, _backend_shutdown) =
start_backend(StatusCode::OK, "text/plain", "tls-termination-ok").await;
let config = Arc::new(
Config {
upstreams: vec![UpstreamConfig {
address: format!("http://{backend_addr}"),
weight: 1,
}],
..Default::default()
}
.into_runtime()
.expect("test config"),
);
let client = test_client();
let balancer = test_balancer(&config);
tokio::spawn(async move {
let (stream, client_addr) = listener.accept().await.unwrap();
let tls_stream = tls_acceptor.accept(stream).await.unwrap();
let config = Arc::clone(&config);
let service = service_fn(move |req: Request<Incoming>| {
let client = client.clone();
let config = Arc::clone(&config);
let balancer = balancer.clone();
async move {
let resp = handle_request(req, client, config, balancer, client_addr, None)
.await
.unwrap_or_else(|e| {
e.into_response().map(|b| {
b.map_err(|never| -> Box<dyn std::error::Error + Send + Sync> {
match never {}
})
.boxed()
})
});
Ok::<_, std::convert::Infallible>(resp)
}
});
let _ = http1::Builder::new()
.serve_connection(TokioIo::new(tls_stream), service)
.await;
});
use rustls::pki_types::CertificateDer;
use rustls::pki_types::pem::PemObject;
let cert_der: Vec<CertificateDer<'static>> =
CertificateDer::pem_slice_iter(cert_pem.as_bytes())
.collect::<std::result::Result<Vec<_>, _>>()
.unwrap();
let mut root_store = rustls::RootCertStore::empty();
for cert in &cert_der {
root_store.add(cert.clone()).unwrap();
}
let client_tls = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(client_tls)
.https_or_http()
.enable_http1()
.build();
let https_client: hyper_util::client::legacy::Client<_, http_body_util::Empty<Bytes>> =
Client::builder(TokioExecutor::new()).build(connector);
let resp = https_client
.get(
format!("https://localhost:{}/", proxy_addr.port())
.parse()
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = resp.collect().await.unwrap().to_bytes();
assert_eq!(body, Bytes::from("tls-termination-ok"));
std::fs::remove_file(&cert_path).ok();
std::fs::remove_file(&key_path).ok();
}