use std::net::SocketAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::certificate::{Certificate, CertificateVerifier};
use crate::config::{LookupFileFn, LookupHashDirFn, SslConfig, TlsConfigBuilder};
use crate::stream::{CloneableStream, TlsStream};
use crate::Result;
use futures_util::{Future, TryFuture};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto;
use hyper_util::server::graceful::GracefulShutdown;
use hyper_util::service::TowerToHyperService;
use tokio::net::TcpListener;
use warp::{Filter, Reply};
/// Starts configuring an OpenSSL-backed TLS server for the warp `filter`.
///
/// The returned [`OpensslServer`] holds a fresh `TlsConfigBuilder::new()`; supply the
/// key, certificate and other TLS options through its builder methods, then start it
/// with `bind` or `bind_with_graceful_shutdown`.
pub fn serve<F>(filter: F) -> OpensslServer<F> {
    let tls = TlsConfigBuilder::new();
    OpensslServer { filter, tls }
}
/// Security profile applied to the TLS acceptor.
///
/// NOTE(review): the variant names suggest Mozilla's "modern" / "intermediate"
/// server-side TLS recommendations (plus their v5 revisions); the actual mapping is
/// done by `TlsConfigBuilder::tls_level` in the config module — confirm there.
#[derive(Clone, Debug)]
pub enum TlsLevel {
    /// Presumably Mozilla's "modern" profile.
    MozillaModern,
    /// Presumably revision 5 of Mozilla's "modern" profile.
    MozillaModernV5,
    /// Presumably Mozilla's "intermediate" profile.
    MozillaIntermediate,
    /// Presumably revision 5 of Mozilla's "intermediate" profile.
    MozillaIntermediateV5,
}
/// A warp filter paired with the TLS settings it will be served with.
///
/// Created by [`serve`], configured through its builder methods, and started with
/// `bind` or `bind_with_graceful_shutdown`.
#[derive(Debug)]
pub struct OpensslServer<F> {
    /// The warp filter that handles every request.
    filter: F,
    /// TLS settings accumulated so far; built into an `SslConfig` when binding.
    tls: TlsConfigBuilder,
}
impl<F> OpensslServer<F>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
{
    /// Sets the server's private key.
    ///
    /// The bytes are passed unchanged to `TlsConfigBuilder::key`, which decides the
    /// accepted encoding (presumably PEM — confirm in the config module).
    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.key(key.as_ref()))
    }

    /// Selects the TLS security profile; see [`TlsLevel`].
    pub fn tls_level(self, tls_level: TlsLevel) -> Self {
        self.with_tls(|tls| tls.tls_level(tls_level))
    }

    /// Sets the server's certificate.
    ///
    /// The bytes are passed unchanged to `TlsConfigBuilder::cert` (presumably a PEM
    /// certificate or chain — confirm in the config module).
    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.cert(cert.as_ref()))
    }

    /// Registers a file lookup via `TlsConfigBuilder::add_file_lookup`.
    ///
    /// NOTE(review): presumably used to locate certificates during verification.
    pub fn add_file_lookup(self, lookup: LookupFileFn) -> Self {
        self.with_tls(|tls| tls.add_file_lookup(lookup))
    }

    /// Registers a hashed-directory lookup via `TlsConfigBuilder::add_hash_dir_lookup`.
    pub fn add_hash_dir_lookup(self, lookup: LookupHashDirFn) -> Self {
        self.with_tls(|tls| tls.add_hash_dir_lookup(lookup))
    }

    /// Enables optional client authentication.
    ///
    /// Client certificates are checked against `trust_anchor` and
    /// `certificate_verifier`; the exact policy is defined by
    /// `TlsConfigBuilder::client_auth_optional`.
    pub fn client_auth_optional(
        self,
        trust_anchor: impl AsRef<[u8]>,
        certificate_verifier: Arc<dyn CertificateVerifier>,
    ) -> Self {
        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref(), certificate_verifier))
    }

    /// Enables mandatory client authentication.
    ///
    /// Client certificates are checked against `trust_anchor` and
    /// `certificate_verifier`; the exact policy is defined by
    /// `TlsConfigBuilder::client_auth_required`.
    pub fn client_auth_required(
        self,
        trust_anchor: impl AsRef<[u8]>,
        certificate_verifier: Arc<dyn CertificateVerifier>,
    ) -> Self {
        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref(), certificate_verifier))
    }

    /// Forwards to `TlsConfigBuilder::disable_partial_chain_verification`.
    pub fn disable_partial_chain_verification(self) -> Self {
        self.with_tls(|tls| tls.disable_partial_chain_verification())
    }

    /// Applies `func` to the TLS builder and keeps the filter unchanged.
    fn with_tls<Func>(self, func: Func) -> Self
    where
        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
    {
        let OpensslServer { filter, tls } = self;
        let tls = func(tls);
        OpensslServer { filter, tls }
    }

    /// Builds the `SslConfig` and binds a Tokio listener to `addr`.
    ///
    /// Returns the address actually bound (from `local_addr`, so a port of 0 resolves
    /// to the assigned port), the listener, the TLS config and the filter.
    ///
    /// # Errors
    ///
    /// Fails if the TLS configuration cannot be built or the socket cannot be bound.
    ///
    /// # Panics
    ///
    /// `tokio::net::TcpListener::from_std` panics outside a Tokio runtime with I/O
    /// enabled, so this must be called from within such a runtime.
    fn build_server(
        self,
        addr: impl Into<SocketAddr>,
    ) -> Result<(SocketAddr, TcpListener, SslConfig, F)> {
        let ssl_config = self.tls.build()?;
        let addr = addr.into();
        let std_listener = std::net::TcpListener::bind(addr)?;
        // Tokio requires the std socket to be non-blocking before adopting it.
        std_listener.set_nonblocking(true)?;
        let listener = TcpListener::from_std(std_listener)?;
        let local_addr = listener.local_addr()?;
        Ok((local_addr, listener, ssl_config, self.filter))
    }

    /// Binds to `addr` and returns the bound address plus a future running the server.
    ///
    /// The returned future never completes. Each accepted connection is served on its
    /// own spawned task via `serve_connection`. Per-connection errors are logged and do
    /// not stop the server. Accept errors are logged and retried, with a backoff when
    /// they are not tied to a single connection.
    ///
    /// # Errors
    ///
    /// Fails if the TLS configuration cannot be built or the socket cannot be bound.
    ///
    /// # Panics
    ///
    /// Must be called inside a Tokio runtime with I/O enabled.
    pub fn bind(
        self,
        addr: impl Into<SocketAddr>,
    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
        let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
        let ssl_config = Arc::new(ssl_config);
        let srv = async move {
            let builder = auto::Builder::new(TokioExecutor::new());
            loop {
                let tcp_stream = accept_connection(&listener).await;
                let ssl_config = Arc::clone(&ssl_config);
                let filter = filter.clone();
                let builder = builder.clone();
                tokio::spawn(async move {
                    if let Err(e) =
                        serve_connection(tcp_stream, &ssl_config, filter, &builder).await
                    {
                        tracing::error!("connection error: {}", e);
                    }
                });
            }
        };
        Ok((addr, srv))
    }

    /// Like [`bind`](Self::bind), but stops when `signal` completes.
    ///
    /// The server shuts down in three steps:
    /// 1. Once `signal` resolves, no further connections are accepted.
    /// 2. The listening socket is closed, so new clients are refused instead of waiting
    ///    in the backlog.
    /// 3. The returned future then waits, via hyper_util's `GracefulShutdown`, until
    ///    every in-flight connection has finished.
    ///
    /// HTTP/1.1 upgrades, such as those used by `warp::ws`, are supported.
    ///
    /// # Errors
    ///
    /// Fails if the TLS configuration cannot be built or the socket cannot be bound.
    ///
    /// # Panics
    ///
    /// Must be called inside a Tokio runtime with I/O enabled.
    pub fn bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr>,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
        let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
        let ssl_config = Arc::new(ssl_config);
        let srv = async move {
            let builder = auto::Builder::new(TokioExecutor::new());
            let graceful = GracefulShutdown::new();
            let mut signal = std::pin::pin!(signal);
            loop {
                tokio::select! {
                    // `accept_connection` is cancel-safe, so losing this race to the
                    // signal never drops an already-accepted connection.
                    tcp_stream = accept_connection(&listener) => {
                        let ssl_config = Arc::clone(&ssl_config);
                        let filter = filter.clone();
                        let builder = builder.clone();
                        let watcher = graceful.watcher();
                        tokio::spawn(async move {
                            let tls_stream = match TlsStream::new(tcp_stream, &ssl_config) {
                                Ok(s) => s,
                                Err(e) => {
                                    tracing::error!("TLS stream creation error: {}", e);
                                    return;
                                }
                            };
                            let svc = CertInjectorService {
                                inner: warp::service(filter),
                                stream: tls_stream.stream(),
                            };
                            // Upgrades must be enabled explicitly for `warp::ws` to work.
                            let conn = builder.serve_connection_with_upgrades(
                                TokioIo::new(tls_stream),
                                TowerToHyperService::new(svc),
                            );
                            let conn = watcher.watch(conn.into_owned());
                            if let Err(e) = conn.await {
                                tracing::error!("connection error: {}", e);
                            }
                        });
                    }
                    _ = &mut signal => {
                        break;
                    }
                }
            }
            // Close the listening socket before draining. Otherwise the OS keeps
            // queueing connections that will never be accepted until shutdown ends.
            drop(listener);
            graceful.shutdown().await;
        };
        Ok((addr, srv))
    }
}

/// Pause after an accept error that is not tied to a single connection (e.g.
/// running out of file descriptors), so the accept loop does not spin.
const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);

/// Accepts the next TCP connection from `listener`, retrying until one succeeds.
///
/// Every accept error is logged. Errors that only affect the pending connection are
/// retried immediately. Any other error, typically resource exhaustion such as
/// `EMFILE`, is likely to recur at once, so the loop sleeps for
/// `ACCEPT_ERROR_BACKOFF` first instead of busy-looping.
///
/// `TCP_NODELAY` is enabled on the returned stream on a best-effort basis; a failure
/// is only logged.
///
/// The future only awaits `TcpListener::accept` (cancel-safe) and a sleep, so
/// dropping it, e.g. inside `tokio::select!`, never loses an accepted connection.
async fn accept_connection(listener: &TcpListener) -> tokio::net::TcpStream {
    loop {
        match listener.accept().await {
            Ok((tcp_stream, remote_addr)) => {
                if let Err(e) = tcp_stream.set_nodelay(true) {
                    tracing::warn!("set_nodelay failed for {}: {}", remote_addr, e);
                }
                return tcp_stream;
            }
            Err(e) => {
                tracing::error!("accept error: {}", e);
                if !is_per_connection_error(&e) {
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                }
            }
        }
    }
}

/// Returns `true` for accept errors that concern only the single pending connection
/// and can therefore be retried immediately.
fn is_per_connection_error(e: &std::io::Error) -> bool {
    matches!(
        e.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}
/// Serves one accepted TCP connection over TLS until it closes.
///
/// The connection is set up in three steps:
/// 1. `tcp_stream` is wrapped in a [`TlsStream`] built from `ssl_config`.
/// 2. `filter` is wrapped in [`CertInjectorService`], so each request carries the
///    peer certificate (when one is available) as an extension.
/// 3. The connection is driven with the hyper_util auto builder.
///
/// HTTP/1.1 upgrades are enabled, so upgrade-based filters such as `warp::ws` work.
///
/// # Errors
///
/// Returns the error from `TlsStream::new`, or the error reported by the hyper
/// connection.
async fn serve_connection<F>(
    tcp_stream: tokio::net::TcpStream,
    ssl_config: &SslConfig,
    filter: F,
    builder: &auto::Builder<TokioExecutor>,
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
{
    let tls_stream = TlsStream::new(tcp_stream, ssl_config)?;
    let svc = CertInjectorService {
        inner: warp::service(filter),
        // Shared handle to the stream, so the service can read the peer certificate
        // while hyper owns the I/O object.
        stream: tls_stream.stream(),
    };
    // Plain `serve_connection` would reject HTTP/1.1 upgrades (e.g. WebSockets).
    builder
        .serve_connection_with_upgrades(TokioIo::new(tls_stream), TowerToHyperService::new(svc))
        .await
}
/// Tower middleware that attaches the TLS peer certificate to each request.
#[derive(Clone)]
struct CertInjectorService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Shared handle to the connection's TLS stream, locked per request to read the
    /// peer certificate.
    stream: CloneableStream,
}
/// Forwards every request to the inner service.
///
/// Before forwarding, the peer certificate is inserted as a [`Certificate`] request
/// extension when it can be obtained and converted.
impl<S, B> tower_service::Service<http::Request<B>> for CertInjectorService<S>
where
    S: tower_service::Service<http::Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
        // Readiness is entirely the inner service's concern.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
        // The guard is dropped at the end of the `Ok` arm, before any conversion work.
        let peer_x509 = match self.stream.lock() {
            Ok(guard) => guard.ssl().peer_certificate(),
            // A poisoned lock is treated as "no certificate", not as a request error.
            Err(_) => None,
        };
        // Conversion failures are also treated as "no certificate".
        let certificate: Option<Certificate> = peer_x509.and_then(|x509| x509.try_into().ok());
        if let Some(cert) = certificate {
            req.extensions_mut().insert(cert);
        }
        self.inner.call(req)
    }
}
/// Opaque `Debug` output: only the type name is printed.
///
/// This impl places no `Debug` bound on `S`, so the service is debuggable whatever it
/// wraps.
impl<S> std::fmt::Debug for CertInjectorService<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CertInjectorService")
    }
}