use crate::{App, app::AppEnv, error::Error, headers::HeaderValue};
use hyper_util::{rt::TokioIo, server::graceful::GracefulShutdown};
use std::{
fmt,
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
time::sleep,
};
use tokio_rustls::{
TlsAcceptor,
rustls::{
RootCertStore, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
server::WebPkiClientVerifier,
},
};
use crate::tls::https_redirect::HttpsRedirectionMiddleware;
#[cfg(any(
all(feature = "http1", feature = "http2"),
all(feature = "http2", not(feature = "http1"))
))]
use hyper::server::conn::http2;
#[cfg(any(
all(feature = "http1", feature = "http2"),
all(feature = "http2", not(feature = "http1"))
))]
use hyper_util::rt::TokioExecutor;
#[cfg(all(feature = "http1", not(feature = "http2")))]
use hyper::server::conn::http1;
/// How [`TlsConfig::with_dev_cert`] obtains a development certificate when none exists yet.
///
/// Only has an effect in debug builds; in release builds `with_dev_cert` leaves the
/// configuration unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg(feature = "dev-cert")]
pub enum DevCertMode {
    /// Generate a self-signed development certificate without prompting.
    Auto,
    /// Ask interactively before generating the certificate.
    Ask,
}
pub(super) mod https_redirect;
/// File name of the PEM certificate (chain) looked up inside a PEM directory.
const CERT_FILE_NAME: &str = "cert.pem";
/// File name of the PEM private key looked up inside a PEM directory.
const KEY_FILE_NAME: &str = "key.pem";
/// Default port of the plain-HTTP listener used for HTTPS redirection.
const DEFAULT_PORT: u16 = 7879;
/// Default HSTS `max-age`: 30 days expressed in seconds.
const DEFAULT_MAX_AGE: u64 = 60 * 60 * 24 * 30;
/// TLS settings from which the server's rustls configuration is built.
///
/// By default the certificate and private key are `cert.pem` and `key.pem` in the current
/// working directory, client authentication is off, HTTPS redirection is disabled and the
/// default HSTS settings are used.
#[derive(Debug)]
#[cfg_attr(feature = "config", derive(serde::Deserialize), serde(default))]
pub struct TlsConfig {
    /// Path of the PEM file holding the server certificate chain.
    pub cert: PathBuf,
    /// Path of the PEM file holding the server private key.
    pub key: PathBuf,
    /// Settings of the plain-HTTP listener that redirects to HTTPS.
    pub https_redirection_config: RedirectionConfig,
    /// HTTP Strict Transport Security settings.
    pub(super) hsts_config: HstsConfig,
    /// Client certificate verification mode.
    client_auth: ClientAuth,
}
/// Settings for redirecting plain-HTTP traffic to HTTPS.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "config", derive(serde::Deserialize), serde(default))]
pub struct RedirectionConfig {
    /// Whether HTTPS redirection is turned on (presumably checked by the server before
    /// spawning the redirection listener — confirm at the call site).
    pub enabled: bool,
    /// Port of the plain-HTTP redirection listener.
    pub http_port: u16,
}
/// HTTP Strict Transport Security (HSTS) settings.
///
/// Defaults: 30-day `max-age`, no `includeSubDomains`, no `preload`, no excluded hosts.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
#[cfg_attr(feature = "config", serde(default))]
pub struct HstsConfig {
    /// Whether the `preload` directive is emitted.
    preload: bool,
    /// Whether the `includeSubDomains` directive is emitted.
    include_sub_domains: bool,
    /// Value of the `max-age` directive; only whole seconds are emitted.
    /// Deserialized from an integer number of seconds.
    #[cfg_attr(feature = "config", serde(deserialize_with = "deser_duration_secs"))]
    max_age: Duration,
    /// Normalized host names excluded from HSTS.
    ///
    /// Values read from configuration go through the same normalization as
    /// `with_exclude_hosts`, so config-file and programmatic hosts compare alike.
    #[cfg_attr(feature = "config", serde(deserialize_with = "deser_exclude_hosts"))]
    exclude_hosts: Vec<String>,
}

/// Deserializes a list of host names, normalizing each one with `normalize_host`.
#[cfg(feature = "config")]
fn deser_exclude_hosts<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Vec<String>, D::Error> {
    use serde::Deserialize;
    let hosts = Vec::<String>::deserialize(d)?;
    Ok(hosts.iter().map(|host| normalize_host(host)).collect())
}
/// Pre-rendered HSTS header value plus the hosts it is configured to exclude.
#[derive(Clone, Debug)]
pub struct HstsHeader {
    /// The rendered directive string, e.g. `max-age=2592000`.
    pub(super) inner: HeaderValue,
    /// Host names excluded from HSTS (presumably hosts for which the header is not sent —
    /// confirm in the redirection middleware).
    pub(super) exclude_hosts: Vec<String>,
}
/// Client certificate authentication mode used when building the rustls server config.
#[derive(PartialEq, Debug)]
enum ClientAuth {
    /// No client certificate verifier is installed.
    None,
    /// Certificates presented by clients are verified against the trust anchors in the PEM
    /// file at this path; clients without a certificate are still accepted.
    Optional(PathBuf),
    /// Clients must present a certificate that verifies against the trust anchors in the
    /// PEM file at this path.
    Required(PathBuf),
}
#[cfg(feature = "config")]
impl<'de> serde::Deserialize<'de> for ClientAuth {
    /// Reads a `{ type, path }` map: `type` is `"None"`, `"Optional"` or `"Required"`, and
    /// `path` (the trust-anchor PEM file) is mandatory for the latter two.
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // The struct name appears in serde's derived error messages; keep it as `Helper`.
        #[derive(serde::Deserialize)]
        struct Helper {
            r#type: String,
            path: Option<PathBuf>,
        }
        const VARIANTS: &[&str] = &["None", "Optional", "Required"];

        let Helper { r#type: kind, path } = Helper::deserialize(d)?;
        let make: fn(PathBuf) -> ClientAuth = match kind.as_str() {
            "None" => return Ok(ClientAuth::None),
            "Optional" => ClientAuth::Optional,
            "Required" => ClientAuth::Required,
            other => return Err(serde::de::Error::unknown_variant(other, VARIANTS)),
        };
        path.map(make)
            .ok_or_else(|| serde::de::Error::missing_field("path"))
    }
}
/// Deserializes a [`Duration`] given as an integer number of seconds.
#[cfg(feature = "config")]
fn deser_duration_secs<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
    use serde::Deserialize;
    u64::deserialize(d).map(Duration::from_secs)
}
impl Default for RedirectionConfig {
    /// Redirection switched off, listener port set to `DEFAULT_PORT`.
    fn default() -> Self {
        let http_port = DEFAULT_PORT;
        Self {
            http_port,
            enabled: false,
        }
    }
}
impl Default for HstsConfig {
    /// A 30-day `max-age` without `includeSubDomains`, without `preload` and with no
    /// excluded hosts.
    fn default() -> Self {
        Self {
            max_age: Duration::from_secs(DEFAULT_MAX_AGE),
            include_sub_domains: false,
            preload: false,
            exclude_hosts: Vec::new(),
        }
    }
}
impl Default for TlsConfig {
fn default() -> Self {
let path = std::env::current_dir().unwrap_or_default();
let cert = path.join(CERT_FILE_NAME);
let key = path.join(KEY_FILE_NAME);
Self {
https_redirection_config: RedirectionConfig::default(),
client_auth: ClientAuth::None,
hsts_config: HstsConfig::default(),
key,
cert,
}
}
}
impl fmt::Display for HstsConfig {
    /// Renders the directive list, e.g. `max-age=2592000; includeSubDomains; preload`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let max_age = self.max_age.as_secs();
        let sub_domains = if self.include_sub_domains {
            "; includeSubDomains"
        } else {
            ""
        };
        let preload = if self.preload { "; preload" } else { "" };
        write!(f, "max-age={max_age}{sub_domains}{preload}")
    }
}
impl HstsConfig {
    /// Turns the `preload` directive on.
    pub fn with_preload(self) -> Self {
        Self {
            preload: true,
            ..self
        }
    }

    /// Turns the `preload` directive off.
    pub fn without_preload(self) -> Self {
        Self {
            preload: false,
            ..self
        }
    }

    /// Turns the `includeSubDomains` directive on.
    pub fn with_sub_domains(self) -> Self {
        Self {
            include_sub_domains: true,
            ..self
        }
    }

    /// Turns the `includeSubDomains` directive off.
    pub fn without_sub_domains(self) -> Self {
        Self {
            include_sub_domains: false,
            ..self
        }
    }

    /// Sets the `max-age` directive (emitted in whole seconds).
    pub fn with_max_age(self, max_age: Duration) -> Self {
        Self { max_age, ..self }
    }

    /// Replaces the excluded hosts with `hosts`, each normalized via `normalize_host`.
    pub fn with_exclude_hosts<I, S>(self, hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let exclude_hosts = hosts
            .into_iter()
            .map(|host| normalize_host(host.as_ref()))
            .collect();
        Self {
            exclude_hosts,
            ..self
        }
    }
}
impl HstsHeader {
    /// Renders `config` into a ready-to-send HSTS header value and keeps its excluded hosts.
    ///
    /// # Panics
    ///
    /// Never in practice: the rendered text only contains ASCII digits and fixed directive
    /// names, which are always valid header-value bytes.
    #[inline]
    pub(super) fn new(config: HstsConfig) -> Self {
        // Reuse `HstsConfig`'s `Display` impl so the directive syntax lives in exactly one
        // place and the header can never drift from the config's textual form.
        let inner = HeaderValue::from_str(&config.to_string()).expect("valid HSTS header");
        Self {
            exclude_hosts: config.exclude_hosts,
            inner,
        }
    }

    /// Returns a clone of the pre-rendered header value.
    #[inline]
    pub(super) fn value(&self) -> HeaderValue {
        self.inner.clone()
    }
}
impl TlsConfig {
    /// Creates a configuration with default settings (same as `TlsConfig::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a configuration reading `cert.pem` and `key.pem` from the directory `path`.
    ///
    /// Client auth, HTTPS redirection and HSTS use their defaults.
    pub fn from_pem(path: impl AsRef<Path>) -> Self {
        let path = path.as_ref();
        let cert = path.join(CERT_FILE_NAME);
        let key = path.join(KEY_FILE_NAME);
        Self {
            https_redirection_config: RedirectionConfig::default(),
            client_auth: ClientAuth::None,
            hsts_config: HstsConfig::default(),
            key,
            cert,
        }
    }

    /// Creates a configuration from explicit certificate and private key file paths.
    pub fn from_pem_files(cert_file_path: &str, key_file_path: &str) -> Self {
        Self {
            key: key_file_path.into(),
            cert: cert_file_path.into(),
            client_auth: ClientAuth::None,
            https_redirection_config: RedirectionConfig::default(),
            hsts_config: HstsConfig::default(),
        }
    }

    /// Points the key and certificate at `key.pem` / `cert.pem` inside the directory `path`.
    pub fn set_pem(mut self, path: impl AsRef<Path>) -> Self {
        let path = path.as_ref();
        self.key = path.join(KEY_FILE_NAME);
        self.cert = path.join(CERT_FILE_NAME);
        self
    }

    /// Sets the path of the PEM certificate (chain) file.
    pub fn with_cert_path(mut self, path: impl AsRef<Path>) -> Self {
        self.cert = path.as_ref().into();
        self
    }

    /// Sets the path of the PEM private key file.
    pub fn with_key_path(mut self, path: impl AsRef<Path>) -> Self {
        self.key = path.as_ref().into();
        self
    }

    /// Verifies client certificates against the trust anchors in the PEM file at `path`,
    /// while still accepting clients that present no certificate.
    pub fn with_optional_client_auth(mut self, path: impl AsRef<Path>) -> Self {
        self.client_auth = ClientAuth::Optional(path.as_ref().into());
        self
    }

    /// Requires every client to present a certificate that verifies against the trust
    /// anchors in the PEM file at `path`.
    pub fn with_required_client_auth(mut self, path: impl AsRef<Path>) -> Self {
        self.client_auth = ClientAuth::Required(path.as_ref().into());
        self
    }

    /// Enables the plain-HTTP listener that redirects to HTTPS.
    pub fn with_https_redirection(mut self) -> Self {
        self.https_redirection_config.enabled = true;
        self
    }

    /// Sets the port of the plain-HTTP redirection listener.
    pub fn with_http_port(mut self, port: u16) -> Self {
        self.https_redirection_config.http_port = port;
        self
    }

    /// Adjusts HSTS: `config` receives the current HSTS settings and returns the new ones.
    pub fn with_hsts<T>(mut self, config: T) -> Self
    where
        T: FnOnce(HstsConfig) -> HstsConfig,
    {
        self.hsts_config = config(self.hsts_config);
        self
    }

    /// Replaces the HSTS settings with `hsts_config`.
    pub fn set_hsts(mut self, hsts_config: HstsConfig) -> Self {
        self.hsts_config = hsts_config;
        self
    }

    /// Switches to a locally generated development certificate (debug builds only).
    ///
    /// If a development certificate already exists it is used directly. Otherwise one is
    /// generated either unconditionally (`DevCertMode::Auto`) or after asking
    /// (`DevCertMode::Ask`). Any failure is logged (with the `tracing` feature) and the
    /// configuration is returned unchanged. In release builds this is a no-op.
    #[cfg(feature = "dev-cert")]
    pub fn with_dev_cert(self, _mode: DevCertMode) -> Self {
        // Release builds never switch to development certificates.
        #[cfg(not(debug_assertions))]
        {
            return self;
        }
        #[cfg(debug_assertions)]
        {
            use volga_dev_cert::{DEV_CERT_NAMES, ask_generate, dev_cert_exists, generate};
            // A previously generated certificate is reused without prompting.
            if dev_cert_exists() {
                return self.use_dev_cert();
            }
            // Generates a certificate for `DEV_CERT_NAMES` and points `tls` at it; on failure
            // the error is logged and `tls` is returned untouched.
            #[inline]
            fn generate_impl(tls: TlsConfig) -> TlsConfig {
                if let Err(_err) = generate(
                    DEV_CERT_NAMES
                        .iter()
                        .map(|n| n.to_string())
                        .collect::<Vec<_>>(),
                ) {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Failed to generate self-signed TLS certificates: {_err:#}");
                    return tls;
                }
                tls.use_dev_cert()
            }
            match _mode {
                DevCertMode::Auto => generate_impl(self),
                DevCertMode::Ask => match ask_generate() {
                    Ok(true) => generate_impl(self),
                    // The user declined: keep the configured certificate paths.
                    Ok(false) => self,
                    Err(_err) => {
                        #[cfg(feature = "tracing")]
                        tracing::error!("Failed to ask for certificate generation: {_err:#}");
                        self
                    }
                },
            }
        }
    }

    /// Points the certificate and key paths at the development certificate files reported
    /// by `volga_dev_cert`.
    #[inline]
    #[cfg(feature = "dev-cert")]
    pub(super) fn use_dev_cert(self) -> Self {
        self.with_cert_path(volga_dev_cert::get_cert_path())
            .with_key_path(volga_dev_cert::get_signing_key_path())
    }

    /// Builds the rustls [`ServerConfig`] from these settings.
    ///
    /// ALPN advertises `h2` (only with the `http2` feature) followed by `http/1.1`.
    ///
    /// # Errors
    ///
    /// Fails when the certificate or key file cannot be read or parsed, when a trust-anchor
    /// file yields no usable certificate, when the client verifier cannot be built, or when
    /// rustls rejects the certificate/key pair.
    pub(super) fn build(self) -> Result<ServerConfig, Error> {
        let certs = Self::load_cert_file(&self.cert)?;
        let key = Self::load_key_file(&self.key)?;
        let builder = match self.client_auth {
            ClientAuth::None => ServerConfig::builder().with_no_client_auth(),
            ClientAuth::Optional(trust_anchor) => {
                // `.into()` wraps the root store in the `Arc` the verifier builder takes.
                let verifier =
                    WebPkiClientVerifier::builder(Self::read_trust_anchor(trust_anchor)?.into())
                        // Clients without a certificate are still let through.
                        .allow_unauthenticated()
                        .build()
                        .map_err(Error::from)?;
                ServerConfig::builder().with_client_cert_verifier(verifier)
            }
            ClientAuth::Required(trust_anchor) => {
                let verifier =
                    WebPkiClientVerifier::builder(Self::read_trust_anchor(trust_anchor)?.into())
                        .build()
                        .map_err(Error::from)?;
                ServerConfig::builder().with_client_cert_verifier(verifier)
            }
        };
        let mut config = builder.with_single_cert(certs, key).map_err(Error::from)?;
        // Listed in server preference order; `h2` is only offered when HTTP/2 is compiled in.
        config.alpn_protocols = vec![
            #[cfg(feature = "http2")]
            b"h2".into(),
            b"http/1.1".into(),
        ];
        Ok(config)
    }

    /// Reads every PEM certificate in the file at `path`, failing on the first bad entry.
    #[inline]
    fn load_cert_file<'a>(path: impl AsRef<Path>) -> Result<Vec<CertificateDer<'a>>, Error> {
        CertificateDer::pem_file_iter(path)
            .map_err(Error::from)?
            .collect::<Result<Vec<_>, _>>()
            .map_err(Error::from)
    }

    /// Reads the PEM private key from the file at `path`.
    #[inline]
    fn load_key_file<'a>(path: impl AsRef<Path>) -> Result<PrivateKeyDer<'a>, Error> {
        PrivateKeyDer::from_pem_file(path).map_err(Error::from)
    }

    /// Loads the certificates in the PEM file at `path` into a root store.
    ///
    /// Certificates that fail to parse are skipped; an error is returned only if none
    /// could be added.
    fn read_trust_anchor(path: impl AsRef<Path>) -> Result<RootCertStore, Error> {
        let trust_anchors = Self::load_cert_file(path)?;
        let mut store = RootCertStore::empty();
        let (added, _skipped) = store.add_parsable_certificates(trust_anchors);
        if added == 0 {
            return Err(Error::server_error(
                "TLS config error: certificate parse error",
            ));
        }
        Ok(store)
    }
}
impl From<tokio_rustls::rustls::Error> for Error {
#[inline]
fn from(err: tokio_rustls::rustls::Error) -> Self {
Self::server_error(format!("TLS config error: {err}"))
}
}
impl From<tokio_rustls::rustls::pki_types::pem::Error> for Error {
fn from(err: tokio_rustls::rustls::pki_types::pem::Error) -> Self {
Self::server_error(format!("TLS config error: {err}"))
}
}
impl From<tokio_rustls::rustls::server::VerifierBuilderError> for Error {
#[inline]
fn from(err: tokio_rustls::rustls::server::VerifierBuilderError) -> Self {
Self::server_error(format!("TLS config error: {err}"))
}
}
impl AppEnv {
    /// Returns a clone of the stored TLS acceptor, or `None` if none was set.
    #[inline]
    pub(super) fn acceptor(&self) -> Option<TlsAcceptor> {
        self.acceptor.as_ref().cloned()
    }
}
impl App {
    /// Enables TLS and lets `config` customize it.
    ///
    /// `config` receives the current TLS configuration, or `TlsConfig::default()` when TLS
    /// has not been configured yet.
    pub fn with_tls<T>(mut self, config: T) -> Self
    where
        T: FnOnce(TlsConfig) -> TlsConfig,
    {
        let tls = self.tls_config.unwrap_or_default();
        self.tls_config = Some(config(tls));
        self
    }

    /// Enables TLS with `config`, replacing any previously configured TLS settings.
    pub fn set_tls(mut self, config: TlsConfig) -> Self {
        self.tls_config = Some(config);
        self
    }

    /// Customizes the HSTS settings of the configured TLS setup.
    ///
    /// `config` receives the *current* HSTS settings (consistent with
    /// `TlsConfig::with_hsts`), so HSTS options set earlier, e.g. inside `with_tls`, are
    /// preserved instead of being silently reset to the defaults.
    ///
    /// Does nothing when TLS has not been configured.
    pub fn with_hsts<T>(mut self, config: T) -> Self
    where
        T: FnOnce(HstsConfig) -> HstsConfig,
    {
        self.tls_config = self.tls_config.map(|tls| tls.with_hsts(config));
        self
    }

    /// Replaces the HSTS settings of the configured TLS setup.
    ///
    /// Does nothing when TLS has not been configured.
    pub fn set_hsts(mut self, hsts_config: HstsConfig) -> Self {
        self.tls_config = self.tls_config.map(|config| config.set_hsts(hsts_config));
        self
    }

    /// Spawns a plain-HTTP listener on `socket.ip():http_port` serving
    /// `HttpsRedirectionMiddleware` that targets the HTTPS port `socket.port()`.
    ///
    /// Accepting stops once `shutdown_tx` is closed (all its receivers dropped). In-flight
    /// connections then get up to `GRACEFUL_SHUTDOWN_TIMEOUT` seconds to finish. Transient
    /// accept errors are logged and do not stop the listener.
    pub(super) fn run_https_redirection_middleware(
        socket: SocketAddr,
        http_port: u16,
        shutdown_tx: Arc<watch::Sender<()>>,
    ) {
        tokio::spawn(async move {
            let https_port = socket.port();
            let socket = SocketAddr::new(socket.ip(), http_port);
            let tcp_listener = match TcpListener::bind(socket).await {
                Ok(listener) => listener,
                Err(_err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("unable to start HTTPS redirection listener: {_err:#}");
                    return;
                }
            };
            #[cfg(feature = "tracing")]
            tracing::info!("listening on: http://{socket}");
            let graceful_shutdown = GracefulShutdown::new();
            loop {
                // Match the accept result explicitly: a refutable `Ok(..)` pattern in
                // `select!` would disable the accept branch on the first error and leave the
                // loop waiting only for shutdown, so the redirect listener would stop serving.
                let (stream, _) = tokio::select! {
                    _ = shutdown_tx.closed() => break,
                    accepted = tcp_listener.accept() => match accepted {
                        Ok(connection) => connection,
                        Err(_err) => {
                            #[cfg(feature = "tracing")]
                            tracing::error!("failed to accept HTTPS redirection connection: {_err:#}");
                            // Brief back-off so persistent errors (e.g. fd exhaustion) do not spin.
                            sleep(Duration::from_millis(100)).await;
                            continue;
                        }
                    },
                };
                Self::serve_http_redirection(https_port, stream, &graceful_shutdown);
            }
            tokio::select! {
                _ = sleep(Duration::from_secs(super::app::GRACEFUL_SHUTDOWN_TIMEOUT)) => (),
                _ = graceful_shutdown.shutdown() => {
                    #[cfg(feature = "tracing")]
                    tracing::info!("shutting down HTTPS redirection...");
                },
            }
        });
    }

    /// Serves one accepted connection with `HttpsRedirectionMiddleware` on a spawned task.
    ///
    /// The connection is registered with `graceful_shutdown`, so a later `shutdown()` waits
    /// for it. The protocol builder (HTTP/1 or HTTP/2) is chosen by the enabled features.
    #[inline]
    fn serve_http_redirection(
        https_port: u16,
        stream: TcpStream,
        graceful_shutdown: &GracefulShutdown,
    ) {
        let io = TokioIo::new(stream);
        #[cfg(all(feature = "http1", not(feature = "http2")))]
        let connection_builder = http1::Builder::new();
        #[cfg(any(
            all(feature = "http1", feature = "http2"),
            all(feature = "http2", not(feature = "http1"))
        ))]
        let connection_builder = http2::Builder::new(TokioExecutor::new());
        let connection =
            connection_builder.serve_connection(io, HttpsRedirectionMiddleware::new(https_port));
        let connection = graceful_shutdown.watch(connection);
        tokio::spawn(async move {
            if let Err(_err) = connection.await {
                #[cfg(feature = "tracing")]
                tracing::error!("error serving connection: {_err:#}");
            }
        });
    }
}
/// Canonicalizes a host name for HSTS exclusion matching.
///
/// Surrounding whitespace is removed, a default HTTPS port suffix (`:443`) is dropped,
/// trailing dots (fully-qualified form) are stripped and the result is ASCII-lowercased.
/// Any other explicit port (e.g. `:80`) is kept.
#[inline]
fn normalize_host(host: &str) -> String {
    // Trim once up front so the "port stripped" and "no port" paths see the same input;
    // otherwise hosts without `:443` would keep their surrounding whitespace.
    let host = host.trim();
    let host = host.strip_suffix(":443").unwrap_or(host);
    host.trim_end_matches('.').to_ascii_lowercase()
}
#[cfg(test)]
mod tests {
    use super::{
        CERT_FILE_NAME, ClientAuth, DEFAULT_MAX_AGE, DEFAULT_PORT, HstsConfig, KEY_FILE_NAME,
        RedirectionConfig, TlsConfig,
    };
    use crate::{App, tls::HstsHeader};
    use std::path::{Path, PathBuf};
    use std::time::Duration;

    /// The HSTS `max-age` every freshly created configuration starts with.
    fn default_max_age() -> Duration {
        Duration::from_secs(DEFAULT_MAX_AGE)
    }

    /// Checks that `tls` reads `key.pem`/`cert.pem` from `dir` and has no client auth.
    fn assert_pem_files(tls: &TlsConfig, dir: &Path) {
        assert_eq!(tls.key, dir.join(KEY_FILE_NAME));
        assert_eq!(tls.cert, dir.join(CERT_FILE_NAME));
        assert_eq!(tls.client_auth, ClientAuth::None);
    }

    /// Checks every HSTS setting in one go.
    fn assert_hsts(
        hsts: &HstsConfig,
        excluded_hosts: usize,
        max_age: Duration,
        preload: bool,
        include_sub_domains: bool,
    ) {
        assert_eq!(hsts.exclude_hosts.len(), excluded_hosts);
        assert_eq!(hsts.max_age, max_age);
        assert_eq!(hsts.preload, preload);
        assert_eq!(hsts.include_sub_domains, include_sub_domains);
    }

    /// Checks the redirection switch and that the default port is in use.
    fn assert_redirection(redirection: &RedirectionConfig, enabled: bool) {
        assert_eq!(redirection.enabled, enabled);
        assert_eq!(redirection.http_port, DEFAULT_PORT);
    }

    #[test]
    fn it_creates_new_tls_config() {
        let tls_config = TlsConfig::new();
        let cwd = std::env::current_dir().unwrap_or_default();
        assert_pem_files(&tls_config, &cwd);
        assert_hsts(&tls_config.hsts_config, 0, default_max_age(), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_default_tls_config() {
        let tls_config = TlsConfig::default();
        let cwd = std::env::current_dir().unwrap_or_default();
        assert_pem_files(&tls_config, &cwd);
        assert_hsts(&tls_config.hsts_config, 0, default_max_age(), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_from_pem() {
        let tls_config = TlsConfig::from_pem("tls");
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 0, default_max_age(), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_set_hsts() {
        let tls_config = TlsConfig::from_pem("tls").set_hsts(HstsConfig {
            max_age: Duration::from_secs(1),
            preload: false,
            include_sub_domains: false,
            exclude_hosts: vec!["example.com".to_string()],
        });
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 1, Duration::from_secs(1), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_hsts() {
        let tls_config = TlsConfig::from_pem("tls").with_hsts(|hsts| {
            hsts.with_exclude_hosts(["example.com"])
                .without_preload()
                .without_sub_domains()
                .with_max_age(Duration::from_secs(1))
        });
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 1, Duration::from_secs(1), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_hsts_preload() {
        let tls_config = TlsConfig::from_pem("tls").with_hsts(|h| h.with_preload());
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 0, default_max_age(), true, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_hsts_sub_domains() {
        let tls_config = TlsConfig::from_pem("tls").with_hsts(|h| h.with_sub_domains());
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 0, default_max_age(), false, true);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_hsts_max_age() {
        let tls_config =
            TlsConfig::from_pem("tls").with_hsts(|h| h.with_max_age(Duration::from_secs(5)));
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 0, Duration::from_secs(5), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_tls_config_with_hsts_exclude_hosts() {
        let tls_config =
            TlsConfig::from_pem("tls").with_hsts(|h| h.with_exclude_hosts(["example.com"]));
        assert_pem_files(&tls_config, &PathBuf::from("tls"));
        assert_hsts(&tls_config.hsts_config, 1, default_max_age(), false, false);
        assert_redirection(&tls_config.https_redirection_config, false);
    }

    #[test]
    fn it_creates_default_hsts_config() {
        assert_hsts(&HstsConfig::default(), 0, default_max_age(), false, false);
    }

    #[test]
    fn it_creates_default_redirect_config() {
        assert_redirection(&RedirectionConfig::default(), false);
    }

    #[test]
    fn it_displays_hsts_config() {
        assert_eq!(HstsConfig::default().to_string(), "max-age=2592000");
    }

    #[test]
    fn it_creates_app_with_tls_config_and_hsts_custom_config() {
        let app = App::new()
            .with_tls(|tls| tls.with_https_redirection())
            .with_hsts(|hsts| {
                hsts.with_max_age(Duration::from_secs(1))
                    .without_preload()
                    .without_sub_domains()
                    .with_exclude_hosts(["example.com"])
            });
        let tls_config = app.tls_config.unwrap();
        assert_hsts(&tls_config.hsts_config, 1, Duration::from_secs(1), false, false);
        assert_redirection(&tls_config.https_redirection_config, true);
    }

    #[test]
    fn it_creates_app_with_tls_config_and_sets_hsts_custom_config() {
        let hsts = HstsConfig::default()
            .with_max_age(Duration::from_secs(1))
            .without_preload()
            .without_sub_domains()
            .with_exclude_hosts(["example.com"]);
        let app = App::new()
            .with_tls(|tls| tls.with_https_redirection())
            .set_hsts(hsts);
        let tls_config = app.tls_config.unwrap();
        assert_hsts(&tls_config.hsts_config, 1, Duration::from_secs(1), false, false);
        assert_redirection(&tls_config.https_redirection_config, true);
    }

    #[test]
    fn it_sets_tls_key_and_cert() {
        let app = App::new().with_tls(|tls| {
            tls.with_key_path(KEY_FILE_NAME)
                .with_cert_path(CERT_FILE_NAME)
        });
        let tls_config = app.tls_config.unwrap();
        assert_eq!(tls_config.key, PathBuf::from(KEY_FILE_NAME));
        assert_eq!(tls_config.cert, PathBuf::from(CERT_FILE_NAME));
    }

    #[test]
    fn it_sets_tls_pem() {
        let app = App::new().with_tls(|tls| tls.set_pem("tls"));
        let dir = PathBuf::from("tls");
        let tls_config = app.tls_config.unwrap();
        assert_eq!(tls_config.key, dir.join(KEY_FILE_NAME));
        assert_eq!(tls_config.cert, dir.join(CERT_FILE_NAME));
    }

    #[test]
    fn it_creates_hsts_header() {
        let hsts_config = HstsConfig::default().with_exclude_hosts(["www.example.com"]);
        let hsts_header = HstsHeader::new(hsts_config);
        assert_eq!(hsts_header.exclude_hosts, &["www.example.com"]);
        assert_eq!(hsts_header.inner, "max-age=2592000");
    }

    #[test]
    fn it_normalizes_excluded_hosts() {
        let hsts_config = HstsConfig::default().with_exclude_hosts([
            "www.ExAmplE.com.",
            "www.ExAmplE.net:80",
            "www.ExAmplE.org:443",
        ]);
        assert_eq!(
            hsts_config.exclude_hosts,
            &["www.example.com", "www.example.net:80", "www.example.org"]
        );
    }
}