rustls_session/
rustls_session.rs

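//! Run with `cargo run --all-features --example rustls_session`.
//!
//! To connect through a browser, navigate to <https://localhost:3000>. The
//! certificate in `examples/self-signed-certs` is self-signed, so the browser
//! will show a security warning that has to be accepted first.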
4
5use axum::{middleware::AddExtension, routing::get, Extension, Router};
6use futures_util::future::BoxFuture;
7use hyper_server::{
8    accept::Accept,
9    tls_rustls::{RustlsAcceptor, RustlsConfig},
10};
11use std::{io, net::SocketAddr, sync::Arc};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::server::TlsStream;
14use tower::Layer;
15
16#[tokio::main]
17async fn main() {
18    let app = Router::new().route("/", get(handler));
19
20    let config = RustlsConfig::from_pem_file(
21        "examples/self-signed-certs/cert.pem",
22        "examples/self-signed-certs/key.pem",
23    )
24    .await
25    .unwrap();
26
27    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
28
29    println!("listening on {}", addr);
30
31    let acceptor = CustomAcceptor::new(RustlsAcceptor::new(config));
32    let server = hyper_server::bind(addr).acceptor(acceptor);
33
34    server.serve(app.into_make_service()).await.unwrap();
35}
36
37async fn handler(tls_data: Extension<TlsData>) -> String {
38    format!("{:?}", tls_data)
39}
40
/// Per-connection TLS details handed to request handlers through an axum
/// `Extension` layer.
#[derive(Clone, Debug)]
struct TlsData {
    /// Server name (SNI) reported by the rustls connection, if the client sent one.
    // Presumably underscored to keep `dead_code` quiet: the field is only read
    // through the derived `Debug` impl, which that lint does not count as a use.
    _hostname: Option<Arc<str>>,
}
45
46#[derive(Debug, Clone)]
47struct CustomAcceptor {
48    inner: RustlsAcceptor,
49}
50
51impl CustomAcceptor {
52    fn new(inner: RustlsAcceptor) -> Self {
53        Self { inner }
54    }
55}
56
57impl<I, S> Accept<I, S> for CustomAcceptor
58where
59    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60    S: Send + 'static,
61{
62    type Stream = TlsStream<I>;
63    type Service = AddExtension<S, TlsData>;
64    type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
65
66    fn accept(&self, stream: I, service: S) -> Self::Future {
67        let acceptor = self.inner.clone();
68
69        Box::pin(async move {
70            let (stream, service) = acceptor.accept(stream, service).await?;
71            let server_conn = stream.get_ref().1;
72            let sni_hostname = TlsData {
73                _hostname: server_conn.server_name().map(From::from),
74            };
75            let service = Extension(sni_hostname).layer(service);
76
77            Ok((stream, service))
78        })
79    }
80}