#![cfg_attr(docsrs, feature(doc_cfg))]
use hyper::server::conn::AddrIncoming;
use hyper::server::{Builder, Server};
#[cfg(feature = "tls-openssl")]
use openssl::pkey::PKey;
#[cfg(feature = "tls-openssl")]
use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef};
#[cfg(feature = "tls-openssl")]
use openssl::x509::X509;
#[cfg(feature = "tls-rustls")]
use rustls::ServerConfig;
#[cfg(feature = "tls-rustls")]
use rustls_pemfile;
use std::net::SocketAddr;
use std::path::Path;
#[cfg(any(feature = "tls-rustls", feature = "tls-openssl"))]
use tls_listener::hyper::WrappedAccept;
#[cfg(feature = "tls-rustls")]
use tokio_rustls::rustls::{Certificate, PrivateKey};
#[cfg(feature = "tls-rustls")]
use tokio_rustls::TlsAcceptor;
pub use hyper;
#[cfg(feature = "tls-openssl")]
pub use openssl;
#[cfg(feature = "tls-rustls")]
pub use rustls;
#[cfg(any(feature = "tls-rustls", feature = "tls-openssl"))]
pub use tls_listener;
/// Selects which HTTP protocol versions a server built by this file should speak.
///
/// The value drives both the ALPN identifiers advertised during the TLS
/// handshake and the `http1_only` / `http2_only` switches applied to the hyper
/// server builder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocols {
/// Every protocol enabled through the `hyper-h1` / `hyper-h2` features.
ALL,
/// HTTP/1 only: advertises `http/1.1` and `http/1.0` via ALPN.
#[cfg(feature = "hyper-h1")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper-h1")))]
HTTP1,
/// HTTP/2 only: advertises `h2` via ALPN.
#[cfg(feature = "hyper-h2")]
#[cfg_attr(docsrs, doc(cfg(feature = "hyper-h2")))]
HTTP2,
}
/// Boxed, type-erased error returned by the fallible functions in this file.
pub type Error = Box<dyn std::error::Error>;
/// Builds a rustls [`ServerConfig`] from PEM data supplied by two readers.
///
/// * `cert` - reader yielding the PEM-encoded certificates.
/// * `key` - reader yielding PEM data; the first PKCS#8 private key found is used.
/// * `protocols` - selects which ALPN identifiers are advertised.
///
/// # Errors
///
/// Returns an error when reading or parsing either input fails, when `key`
/// contains no PKCS#8 private key, or when rustls rejects the certificate/key
/// pair.
#[cfg(feature = "tls-rustls")]
fn rustls_server_config_from_readers<R: std::io::Read>(
    cert: R,
    key: R,
    protocols: Protocols,
) -> Result<ServerConfig, Error> {
    use std::io::{self, BufReader};

    let certs: Vec<Certificate> = rustls_pemfile::certs(&mut BufReader::new(cert))?
        .into_iter()
        .map(Certificate)
        .collect();

    // Only the first key is used. An input without any PKCS#8 key is reported
    // as an error instead of panicking on an out-of-bounds `remove(0)`.
    let key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key))?
        .into_iter()
        .next()
        .map(PrivateKey)
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "no PKCS#8 private key found in key data",
            )
        })?;

    let mut config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

    // Exactly one `ALL` arm survives conditional compilation, depending on
    // which hyper protocol features are enabled.
    config.alpn_protocols = match protocols {
        #[cfg(all(feature = "hyper-h1", feature = "hyper-h2"))]
        Protocols::ALL => vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()],
        #[cfg(all(feature = "hyper-h1", not(feature = "hyper-h2")))]
        Protocols::ALL => vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()],
        #[cfg(all(not(feature = "hyper-h1"), feature = "hyper-h2"))]
        Protocols::ALL => vec![b"h2".to_vec()],
        #[cfg(feature = "hyper-h1")]
        Protocols::HTTP1 => vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()],
        #[cfg(feature = "hyper-h2")]
        Protocols::HTTP2 => vec![b"h2".to_vec()],
    };
    Ok(config)
}
/// Builds a rustls [`ServerConfig`] from a PEM certificate file and a PEM key file.
///
/// The certificate file is opened first, then the key file; both handles are
/// passed on to `rustls_server_config_from_readers` together with `protocols`.
///
/// # Errors
///
/// Returns an error when either file cannot be opened or when building the
/// configuration from their contents fails.
#[cfg(feature = "tls-rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
pub fn rustls_server_config_from_pem_files<P: AsRef<Path>, Q: AsRef<Path>>(
    cert_file: P,
    key_file: Q,
    protocols: Protocols,
) -> Result<ServerConfig, Error> {
    let cert_reader = std::fs::File::open(cert_file)?;
    let key_reader = std::fs::File::open(key_file)?;
    rustls_server_config_from_readers(cert_reader, key_reader, protocols)
}
/// Builds a rustls [`ServerConfig`] from in-memory PEM data.
///
/// `cert` and `key` are byte slices holding the PEM text; they are read
/// directly as `std::io::Read` sources.
///
/// # Errors
///
/// Returns whatever error `rustls_server_config_from_readers` reports for the
/// given data.
#[cfg(feature = "tls-rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
pub fn rustls_server_config_from_pem_data<'a>(
    cert: &'a [u8],
    key: &'a [u8],
    protocols: Protocols,
) -> Result<ServerConfig, Error> {
    // Byte slices implement `Read`, so no intermediate cursor is required.
    rustls_server_config_from_readers::<&'a [u8]>(cert, key, protocols)
}
/// Configures ALPN on an OpenSSL context builder for the requested protocols.
///
/// The protocol list is written in ALPN wire format (each identifier prefixed
/// by its length byte). It is installed both as the advertised list and as the
/// server list consulted by the selection callback.
///
/// # Errors
///
/// Returns an error when OpenSSL refuses the protocol list.
#[cfg(feature = "tls-openssl")]
fn ssl_context_set_alpns(
    builder: &mut SslContextBuilder,
    protocols: Protocols,
) -> Result<(), Error> {
    // Exactly one `ALL` arm survives conditional compilation.
    let wire_list: &'static [u8] = match protocols {
        #[cfg(all(feature = "hyper-h1", feature = "hyper-h2"))]
        Protocols::ALL => b"\x02h2\x08http/1.1\x08http/1.0",
        #[cfg(all(feature = "hyper-h1", not(feature = "hyper-h2")))]
        Protocols::ALL => b"\x08http/1.1\x08http/1.0",
        #[cfg(all(not(feature = "hyper-h1"), feature = "hyper-h2"))]
        Protocols::ALL => b"\x02h2",
        #[cfg(feature = "hyper-h1")]
        Protocols::HTTP1 => b"\x08http/1.1\x08http/1.0",
        #[cfg(feature = "hyper-h2")]
        Protocols::HTTP2 => b"\x02h2",
    };

    builder.set_alpn_protos(wire_list)?;
    builder.set_alpn_select_callback(move |_: &mut SslRef, offered: &[u8]| {
        // No protocol in common: answer with NOACK.
        match openssl::ssl::select_next_proto(wire_list, offered) {
            Some(chosen) => Ok(chosen),
            None => Err(openssl::ssl::AlpnError::NOACK),
        }
    });
    Ok(())
}
/// Creates an OpenSSL [`SslContextBuilder`] from a PEM certificate chain file
/// and a PEM private key file.
///
/// The returned builder has the certificate chain, the private key and the
/// ALPN settings for `protocols` applied; callers may customise it further
/// before calling `build()`.
///
/// # Errors
///
/// Returns an error when the context builder cannot be created, when either
/// file cannot be loaded, or when the ALPN configuration is rejected.
#[cfg(feature = "tls-openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-openssl")))]
pub fn ssl_context_builder_from_pem_files<P: AsRef<Path>, Q: AsRef<Path>>(
    cert_file: P,
    key_file: Q,
    protocols: Protocols,
) -> Result<SslContextBuilder, Error> {
    // Propagate a builder-creation failure instead of panicking: this function
    // already reports every other OpenSSL error through `Result`.
    let mut builder = SslContext::builder(SslMethod::tls_server())?;
    builder.set_certificate_chain_file(cert_file)?;
    builder.set_private_key_file(key_file, SslFiletype::PEM)?;
    ssl_context_set_alpns(&mut builder, protocols)?;
    Ok(builder)
}
/// Creates an OpenSSL [`SslContextBuilder`] from in-memory PEM data.
///
/// The first certificate in `cert` is installed as the leaf certificate; every
/// following certificate is added as an extra chain certificate. `key` must
/// hold a PEM-encoded private key.
///
/// # Errors
///
/// Returns an error when the context builder cannot be created, when `cert`
/// cannot be parsed or contains no certificate, when `key` cannot be parsed,
/// or when OpenSSL rejects any of the supplied material or the ALPN settings.
#[cfg(feature = "tls-openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-openssl")))]
pub fn ssl_context_builder_from_pem_data<'a>(
    cert: &'a [u8],
    key: &'a [u8],
    protocols: Protocols,
) -> Result<SslContextBuilder, Error> {
    // Propagate a builder-creation failure instead of panicking: this function
    // already reports every other OpenSSL error through `Result`.
    let mut builder = SslContext::builder(SslMethod::tls_server())?;

    let mut chain = X509::stack_from_pem(cert)?.into_iter();
    let leaf = chain.next().ok_or("no leaf certificate")?;
    builder.set_certificate(leaf.as_ref())?;
    // Whatever follows the leaf is treated as intermediate certificates.
    chain.try_for_each(|extra| builder.add_extra_chain_cert(extra))?;

    builder.set_private_key(PKey::private_key_from_pem(key)?.as_ref())?;
    ssl_context_set_alpns(&mut builder, protocols)?;
    Ok(builder)
}
/// TLS listener type used when the `tls-rustls` feature is enabled: hyper's
/// [`AddrIncoming`] wrapped in `WrappedAccept`, with `tokio_rustls::TlsAcceptor`
/// as the TLS acceptor.
#[cfg(feature = "tls-rustls")]
pub type TlsListener = tls_listener::TlsListener<WrappedAccept<AddrIncoming>, TlsAcceptor>;
/// TLS listener type used when the `tls-openssl` feature is enabled: hyper's
/// [`AddrIncoming`] wrapped in `WrappedAccept`, with an OpenSSL [`SslContext`]
/// as the TLS acceptor.
// NOTE(review): excluded under `docsrs`, presumably so documentation builds
// with every feature enabled do not see two `TlsListener` definitions — confirm.
#[cfg(all(not(docsrs), feature = "tls-openssl"))]
pub type TlsListener = tls_listener::TlsListener<WrappedAccept<AddrIncoming>, SslContext>;
/// Creates a [`TlsListener`] bound to `addr`, configured from a PEM
/// certificate file and a PEM key file.
///
/// The TLS acceptor is built first (with whichever TLS backend feature is
/// enabled); the socket is bound only afterwards.
///
/// # Errors
///
/// Returns an error when the TLS configuration cannot be built from the files
/// or when binding `addr` fails.
pub fn listener_from_pem_files<P: AsRef<Path>, Q: AsRef<Path>>(
    cert_file: P,
    key_file: Q,
    protocols: Protocols,
    addr: &SocketAddr,
) -> Result<TlsListener, Error> {
    #[cfg(feature = "tls-rustls")]
    let acceptor = {
        let server_config = rustls_server_config_from_pem_files(cert_file, key_file, protocols)?;
        TlsAcceptor::from(std::sync::Arc::new(server_config))
    };
    #[cfg(feature = "tls-openssl")]
    let acceptor = ssl_context_builder_from_pem_files(cert_file, key_file, protocols)?.build();

    let incoming = AddrIncoming::bind(addr)?;
    Ok(TlsListener::new_hyper(acceptor, incoming))
}
/// Creates a [`TlsListener`] bound to `addr`, configured from in-memory PEM
/// certificate and key data.
///
/// The TLS acceptor is built first (with whichever TLS backend feature is
/// enabled); the socket is bound only afterwards.
///
/// # Errors
///
/// Returns an error when the TLS configuration cannot be built from the data
/// or when binding `addr` fails.
pub fn listener_from_pem_data<'a>(
    cert: &'a [u8],
    key: &'a [u8],
    protocols: Protocols,
    addr: &SocketAddr,
) -> Result<TlsListener, Error> {
    #[cfg(feature = "tls-rustls")]
    let acceptor = {
        use std::sync::Arc;
        let config = rustls_server_config_from_pem_data(cert, key, protocols)?;
        TlsAcceptor::from(Arc::new(config))
    };
    #[cfg(feature = "tls-openssl")]
    let acceptor = {
        let builder = ssl_context_builder_from_pem_data(cert, key, protocols)?;
        builder.build()
    };
    // `addr` is already a reference; pass it straight through rather than
    // borrowing it a second time, matching `listener_from_pem_files`.
    Ok(TlsListener::new_hyper(acceptor, AddrIncoming::bind(addr)?))
}
/// Creates a hyper server [`Builder`] that serves over TLS on `addr`, using a
/// PEM certificate file and a PEM key file.
///
/// For `Protocols::HTTP1` / `Protocols::HTTP2` the builder is additionally
/// restricted to that protocol; for `Protocols::ALL` it is returned unchanged.
///
/// # Errors
///
/// Returns an error when the TLS listener cannot be created.
pub fn hyper_from_pem_files<P: AsRef<Path>, Q: AsRef<Path>>(
    cert_file: P,
    key_file: Q,
    protocols: Protocols,
    addr: &SocketAddr,
) -> Result<Builder<TlsListener>, Error> {
    let server = Server::builder(listener_from_pem_files(
        cert_file, key_file, protocols, addr,
    )?);
    let server = match protocols {
        Protocols::ALL => server,
        #[cfg(feature = "hyper-h1")]
        Protocols::HTTP1 => server.http1_only(true),
        #[cfg(feature = "hyper-h2")]
        Protocols::HTTP2 => server.http2_only(true),
    };
    Ok(server)
}
/// Creates a hyper server [`Builder`] that serves over TLS on `addr`, using
/// in-memory PEM certificate and key data.
///
/// For `Protocols::HTTP1` / `Protocols::HTTP2` the builder is additionally
/// restricted to that protocol; for `Protocols::ALL` it is returned unchanged.
///
/// # Errors
///
/// Returns an error when the TLS listener cannot be created.
pub fn hyper_from_pem_data<'a>(
    cert: &'a [u8],
    key: &'a [u8],
    protocols: Protocols,
    addr: &SocketAddr,
) -> Result<Builder<TlsListener>, Error> {
    let server = Server::builder(listener_from_pem_data(cert, key, protocols, addr)?);
    let server = match protocols {
        Protocols::ALL => server,
        #[cfg(feature = "hyper-h1")]
        Protocols::HTTP1 => server.http1_only(true),
        #[cfg(feature = "hyper-h2")]
        Protocols::HTTP2 => server.http2_only(true),
    };
    Ok(server)
}
#[cfg(test)]
mod tests {
    /// Checks that the ALPN list placed in the rustls configuration matches
    /// the requested `Protocols` value under each feature combination.
    #[test]
    #[cfg(feature = "tls-rustls")]
    fn test_rustls_server_config_from_readers() {
        use super::*;

        const CERT: &[u8] = include_bytes!("../data/cert.pem");
        const KEY: &[u8] = include_bytes!("../data/key.pem");

        // Expected list for `Protocols::ALL`; exactly one definition survives
        // conditional compilation.
        #[cfg(all(feature = "hyper-h1", feature = "hyper-h2"))]
        let expected_all: Vec<Vec<u8>> =
            vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
        #[cfg(all(not(feature = "hyper-h1"), feature = "hyper-h2"))]
        let expected_all: Vec<Vec<u8>> = vec![b"h2".to_vec()];
        #[cfg(all(feature = "hyper-h1", not(feature = "hyper-h2")))]
        let expected_all: Vec<Vec<u8>> = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()];

        let all = rustls_server_config_from_readers(CERT, KEY, Protocols::ALL).unwrap();
        assert_eq!(all.alpn_protocols, expected_all);

        #[cfg(feature = "hyper-h1")]
        {
            let expected_h1: Vec<Vec<u8>> = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()];
            let h1_only = rustls_server_config_from_readers(CERT, KEY, Protocols::HTTP1).unwrap();
            assert_eq!(h1_only.alpn_protocols, expected_h1);
        }

        #[cfg(feature = "hyper-h2")]
        {
            let expected_h2: Vec<Vec<u8>> = vec![b"h2".to_vec()];
            let h2_only = rustls_server_config_from_readers(CERT, KEY, Protocols::HTTP2).unwrap();
            assert_eq!(h2_only.alpn_protocols, expected_h2);
        }
    }
}