use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::{Request, Response};
use hyper_util::rt::{TokioExecutor, TokioTimer};
use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder;
use hyper_util::service::TowerToHyperService;
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tower::{Layer, ServiceBuilder};
use tracing::{debug, info, trace, Level};
use postel::{load_certs, load_private_key, serve_http_with_shutdown};
lazy_static::lazy_static! {
static ref HELLO: Bytes = Bytes::from("Hello, World!");
}
/// Request handler that replies to every request with the `HELLO` payload.
///
/// The incoming request is ignored entirely, and the handler never fails
/// (its error type is `Infallible`).
async fn hello(_: Request<Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
    let body = Full::new(HELLO.clone());
    Ok(Response::new(body))
}
/// Tower layer that wraps a service in an [`AddHeaderService`].
struct AddHeaderLayer;

impl<S> Layer<S> for AddHeaderLayer {
    type Service = AddHeaderService<S>;

    /// Wraps `inner` so it is driven through [`AddHeaderService`].
    fn layer(&self, inner: S) -> Self::Service {
        AddHeaderService { inner }
    }
}
/// Middleware service produced by `AddHeaderLayer`.
///
/// Holds the wrapped service; its `tower::Service` implementation delegates
/// to `inner` and then adds a custom header to the response.
#[derive(Clone)]
struct AddHeaderService<S> {
/// The wrapped service that actually handles the request.
inner: S,
}
impl<S, B> tower::Service<Request<B>> for AddHeaderService<S>
where
S: tower::Service<Request<B>, Response = Response<Full<Bytes>>>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
trace!("Adding custom header to response");
let future = self.inner.call(req);
Box::pin(async move {
let mut resp = future.await?;
resp.headers_mut()
.insert("X-Custom-Header", "Hello from middleware!".parse().unwrap());
Ok(resp)
})
}
}
/// Runs an HTTPS server on `127.0.0.1:8443` and shuts it down on Ctrl-C.
///
/// Steps:
/// 1. Initialise tracing at `INFO` level and bind the TCP listener.
/// 2. Configure the hyper-util connection builder for HTTP/1 and HTTP/2.
/// 3. Wrap the `hello` handler in `AddHeaderLayer` and adapt it for hyper.
/// 4. Build the rustls server configuration from the sample certificate/key
///    under `examples/`.
/// 5. Spawn the server, wait for Ctrl-C, send the shutdown signal, and wait
///    for the server task to finish.
///
/// # Errors
///
/// Returns an error if binding the listener, loading the certificate or key,
/// building the TLS configuration, waiting for Ctrl-C, or joining the server
/// task fails.
///
/// # Panics
///
/// Panics if the rustls crypto provider cannot be installed. A server failure
/// panics inside the spawned task and surfaces here as a join error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    tracing_subscriber::fmt().with_max_level(Level::INFO).init();

    let addr = SocketAddr::from(([127, 0, 0, 1], 8443));
    let listener = TcpListener::bind(addr).await?;
    info!("Listening on https://{}", addr);
    let incoming = TcpListenerStream::new(listener);

    // Connection-level tuning for both protocol versions.
    let mut builder = HttpConnectionBuilder::new(TokioExecutor::new());
    builder
        .http1()
        .half_close(true)
        .keep_alive(true)
        .max_buf_size(1024 * 1024)
        .pipeline_flush(true)
        .preserve_header_case(true)
        .title_case_headers(false)
        .http2()
        .timer(TokioTimer::new())
        .initial_stream_window_size(Some(4 * 1024 * 1024))
        .initial_connection_window_size(Some(8 * 1024 * 1024))
        .adaptive_window(true)
        .max_frame_size(Some(1024 * 1024))
        .max_concurrent_streams(Some(1024))
        .max_send_buf_size(4 * 1024 * 1024)
        .enable_connect_protocol()
        .max_header_list_size(64 * 1024)
        .keep_alive_interval(Some(Duration::from_secs(30)))
        .keep_alive_timeout(Duration::from_secs(60));

    // Handler -> middleware -> hyper adapter.
    let svc = tower::service_fn(hello);
    let svc = ServiceBuilder::new().layer(AddHeaderLayer).service(svc);
    let svc = TowerToHyperService::new(svc);

    // A process-wide crypto provider must be installed before the TLS
    // configuration is built.
    rustls::crypto::aws_lc_rs::default_provider()
        .install_default()
        .expect("Failed to install rustls crypto provider");

    let certs = load_certs("examples/sample.pem")?;
    let key = load_private_key("examples/sample.rsa")?;
    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    // Advertise HTTP/2 first, then fall back to HTTP/1.x.
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
    let tls_config = Arc::new(config);

    // Single shutdown channel: the sender stays here and fires on Ctrl-C, the
    // receiver is handed to the server as its shutdown future.
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();

    let server = tokio::spawn(async move {
        info!("Starting HTTPS server...");
        serve_http_with_shutdown(
            svc,
            incoming,
            builder,
            Some(tls_config),
            Some(async {
                // A dropped sender also resolves this future, so either a
                // send or a drop ends the wait.
                shutdown_rx.await.ok();
                info!("Shutdown signal received");
            }),
        )
        .await
        .expect("Server failed unexpectedly");
    });

    tokio::signal::ctrl_c().await?;
    info!("Initiating graceful shutdown");
    // The send fails only if the server task has already exited, in which
    // case there is nothing left to signal.
    let _ = shutdown_tx.send(());
    debug!("Shutdown signal sent");

    server.await?;
    info!("Server has shut down");
    Ok(())
}