use crate::prelude::{internals::*, *};
#[cfg(feature = "https")]
use rustls::{
server::{ClientHello, ResolvesServerCert},
sign, ServerConfig,
};
/// A single virtual host: its name(s), optional TLS certificate, file root, extensions,
/// caches, and per-host settings.
#[must_use]
pub struct Host {
    /// Primary name; the key this host is registered under in a [`Collection`].
    pub name: CompactString,
    /// Additional names that resolve to this host when it is inserted into a
    /// [`CollectionBuilder`].
    pub alternative_names: Vec<CompactString>,
    /// Certificate served over TLS. `None` means the host is not secure
    /// (see [`Host::is_secure`]). Behind a lock so it can be swapped at runtime
    /// with [`Host::live_set_certificate`].
    #[cfg(feature = "https")]
    pub certificate: std::sync::RwLock<Option<Arc<sign::CertifiedKey>>>,
    /// Base path of the host's files (e.g. `"."` for [`Collection::simple_non_secure`]).
    pub path: CompactString,
    /// Extensions run for this host.
    pub extensions: Extensions,
    /// File cache; `None` after [`Host::disable_fs_cache`].
    pub file_cache: Option<FileCache>,
    /// Response cache; `None` after [`Host::disable_response_cache`].
    pub response_cache: Option<ResponseCache>,
    /// Per-host limiter. NOTE(review): presumably request rate limiting; defined elsewhere.
    pub limiter: LimitManager,
    /// Per-host `vary` settings. NOTE(review): semantics defined elsewhere.
    pub vary: Vary,
    /// Per-host [`Options`].
    pub options: Options,
    /// Compression settings tuned by the `set_*_level_oneshot` methods.
    pub compression_options_oneshot: comprash::CompressionOptions,
    /// Compression settings tuned by the `set_*_level` methods.
    pub compression_options_cached: comprash::CompressionOptions,
}
impl Host {
    /// Creates a secure host named `host_name`. Its certificate chain is loaded from
    /// `cert_path` and its private key from `private_key_path` (see [`get_certified_key`]).
    ///
    /// # Errors
    ///
    /// If the certificate or key cannot be loaded, returns the [`CertificateError`]. It is
    /// paired with the same host built by [`Host::unsecure`], so the caller can still serve
    /// it without TLS.
    #[cfg(feature = "https")]
    pub fn try_read_fs(
        host_name: impl AsRef<str>,
        cert_path: impl AsRef<str>,
        private_key_path: impl AsRef<str>,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Result<Self, (CertificateError, Self)> {
        match get_certified_key(cert_path, private_key_path) {
            Ok(key) => Ok(Self::new(host_name, key, path, extensions, options)),
            Err(err) => Err((err, Self::unsecure(host_name, path, extensions, options))),
        }
    }
    /// Like [`Host::try_read_fs`], but takes the host's name(s) from the DNS names in the
    /// certificate (see [`Host::new_name_from_cert`]).
    ///
    /// # Errors
    ///
    /// Returns the [`CertificateError`] if the certificate or key cannot be loaded.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`Host::new_name_from_cert`].
    #[cfg(feature = "https")]
    pub fn read_fs_name_from_cert(
        cert_path: impl AsRef<str>,
        private_key_path: impl AsRef<str>,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Result<Self, CertificateError> {
        get_certified_key(cert_path, private_key_path)
            .map(|key| Self::new_name_from_cert(key, path, extensions, options))
    }
    /// Creates a secure host named `name` that serves `key`.
    ///
    /// Apart from the certificate, the host is set up exactly like [`Host::unsecure`]:
    /// no alternative names, fresh file and response caches, and default limiter, vary
    /// and compression settings.
    #[cfg(feature = "https")]
    pub fn new(
        name: impl AsRef<str>,
        key: sign::CertifiedKey,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Self {
        // Delegate to `unsecure` so the secure and non-secure constructors cannot drift apart.
        let mut host = Self::unsecure(name, path, extensions, options);
        host.certificate = std::sync::RwLock::new(Some(Arc::new(key)));
        host
    }
    /// Creates a secure host named after the first DNS name in `key`'s end-entity
    /// certificate. All of the certificate's DNS names, including that first one, become
    /// [`Host::alternative_names`].
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `key` contains no certificate,
    /// - the first certificate cannot be parsed, or
    /// - that certificate has no valid DNS name.
    #[cfg(feature = "https")]
    pub fn new_name_from_cert(
        key: sign::CertifiedKey,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Self {
        let end_entity = key.cert.first().expect("certificate chain is empty");
        let parsed = webpki::EndEntityCert::try_from(end_entity)
            .expect("internal certificate has invalid format?");
        let names: Vec<_> = parsed
            .valid_dns_names()
            .map(|s| s.to_compact_string())
            .collect();
        let name = names.first().expect("cert has no name?").clone();
        let mut me = Self::new(name, key, path, extensions, options);
        me.alternative_names = names;
        me
    }
    /// Creates a host without a certificate, with fresh file and response caches and
    /// default limiter, vary and compression settings.
    pub fn unsecure(
        host_name: impl AsRef<str>,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Self {
        Self {
            name: host_name.as_ref().to_compact_string(),
            alternative_names: Vec::new(),
            #[cfg(feature = "https")]
            certificate: std::sync::RwLock::new(None),
            path: path.as_ref().to_compact_string(),
            extensions,
            file_cache: Some(MokaCache::default()),
            response_cache: Some(MokaCache::default()),
            options,
            limiter: LimitManager::default(),
            vary: Vary::default(),
            compression_options_oneshot: comprash::CompressionOptions::oneshot(),
            compression_options_cached: comprash::CompressionOptions::cached(),
        }
    }
    /// Tries [`Host::try_read_fs`].
    ///
    /// On success, enables the HTTP → HTTPS redirect
    /// ([`Host::with_http_to_https_redirect`]). On failure, logs the error and returns the
    /// non-secure fallback host, so this never fails.
    #[cfg(feature = "https")]
    pub fn http_redirect_or_unsecure(
        host_name: impl AsRef<str>,
        cert_path: impl AsRef<str>,
        private_key_path: impl AsRef<str>,
        path: impl AsRef<str>,
        extensions: Extensions,
        options: Options,
    ) -> Self {
        match Host::try_read_fs(
            host_name,
            cert_path,
            private_key_path,
            path,
            extensions,
            options,
        ) {
            Ok(mut host) => {
                host.with_http_to_https_redirect();
                host
            }
            Err((err, host_without_cert)) => {
                error!(
                    "Failed to get certificate! Not running host on HTTPS. {:?}",
                    err
                );
                host_without_cert
            }
        }
    }
    /// Enables the HTTP → HTTPS redirect on this host's [`Extensions`].
    #[cfg(feature = "https")]
    pub fn with_http_to_https_redirect(&mut self) -> &mut Self {
        self.extensions.with_http_to_https_redirect();
        self
    }
    /// Adds a package extension (id 8) that sets
    /// `strict-transport-security: max-age=63072000; includeSubDomains; preload`
    /// on responses to `https` requests, unless the response already has that header.
    ///
    /// NOTE(review): relies on the incoming request URI carrying a scheme; confirm that
    /// requests reaching extensions have an absolute URI.
    #[cfg(feature = "https")]
    pub fn with_hsts(&mut self) -> &mut Self {
        struct Ext;
        impl extensions::PackageCall for Ext {
            fn call<'a>(
                &'a self,
                response: &'a mut Response<()>,
                request: &'a FatRequest,
                _: &'a Host,
                _: SocketAddr,
            ) -> RetFut<'a, ()> {
                if request.uri().scheme_str() == Some("https") {
                    // 63072000 s = 730 days; `or_insert` keeps a header set by earlier code.
                    response
                        .headers_mut()
                        .entry("strict-transport-security")
                        .or_insert(HeaderValue::from_static(
                            "max-age=63072000; includeSubDomains; preload",
                        ));
                }
                ready(())
            }
        }
        self.extensions
            .add_package(Box::new(Ext), extensions::Id::new(8, "Adding HSTS header"));
        self
    }
    /// Registers `name` as an additional name for this host.
    pub fn add_alternative_name(&mut self, name: impl AsRef<str>) -> &mut Self {
        self.alternative_names
            .push(name.as_ref().to_compact_string());
        self
    }
    /// Calls [`Options::disable_client_cache`] on this host's options.
    pub fn disable_client_cache(&mut self) -> &mut Self {
        self.options.disable_client_cache();
        self
    }
    /// Removes the file cache.
    pub fn disable_fs_cache(&mut self) -> &mut Self {
        self.file_cache = None;
        self
    }
    /// Removes the response cache.
    pub fn disable_response_cache(&mut self) -> &mut Self {
        self.response_cache = None;
        self
    }
    /// Removes both the file cache and the response cache.
    pub fn disable_server_cache(&mut self) -> &mut Self {
        self.disable_fs_cache().disable_response_cache()
    }
    /// Returns `true` if the host currently has a certificate.
    ///
    /// # Panics
    ///
    /// Panics if the certificate lock is poisoned.
    #[cfg(feature = "https")]
    #[inline]
    pub fn is_secure(&self) -> bool {
        self.certificate.read().unwrap().is_some()
    }
    /// Without the `https` feature a host can never be secure.
    #[cfg(not(feature = "https"))]
    #[inline]
    pub(crate) fn is_secure(&self) -> bool {
        false
    }
    /// Sets the Brotli level in [`Host::compression_options_cached`].
    #[cfg(feature = "br")]
    pub fn set_brotli_level(&mut self, level: u32) -> &mut Self {
        self.compression_options_cached.brotli_level = level;
        self
    }
    /// Sets the gzip level in [`Host::compression_options_cached`].
    #[cfg(feature = "gzip")]
    pub fn set_gzip_level(&mut self, level: u32) -> &mut Self {
        self.compression_options_cached.gzip_level = level;
        self
    }
    /// Sets the zstd level in [`Host::compression_options_cached`].
    #[cfg(feature = "zstd")]
    pub fn set_zstd_level(&mut self, level: i32) -> &mut Self {
        self.compression_options_cached.zstd_level = level;
        self
    }
    /// Sets the Brotli level in [`Host::compression_options_oneshot`].
    #[cfg(feature = "br")]
    pub fn set_brotli_level_oneshot(&mut self, level: u32) -> &mut Self {
        self.compression_options_oneshot.brotli_level = level;
        self
    }
    /// Sets the gzip level in [`Host::compression_options_oneshot`].
    #[cfg(feature = "gzip")]
    pub fn set_gzip_level_oneshot(&mut self, level: u32) -> &mut Self {
        self.compression_options_oneshot.gzip_level = level;
        self
    }
    /// Sets the zstd level in [`Host::compression_options_oneshot`].
    #[cfg(feature = "zstd")]
    pub fn set_zstd_level_oneshot(&mut self, level: i32) -> &mut Self {
        self.compression_options_oneshot.zstd_level = level;
        self
    }
    /// Replaces the certificate at runtime.
    ///
    /// Later certificate resolutions through [`Collection`] read the new value. Note that
    /// [`Collection::has_secure`] is computed when the host is inserted and is not updated.
    ///
    /// # Panics
    ///
    /// Panics if the certificate lock is poisoned.
    #[cfg(feature = "https")]
    pub fn live_set_certificate(&self, key: sign::CertifiedKey) {
        *self.certificate.write().unwrap() = Some(Arc::new(key));
    }
}
/// Hand-written `Debug` output: the certificate, extensions and caches are printed as
/// placeholder strings instead of their contents.
impl Debug for Host {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct(utils::ident_str!(Host));
        // NOTE(review): the `(field, replacement)` form of `fmt_fields!` appears to print the
        // given placeholder instead of the field's value; the macro is defined elsewhere.
        utils::fmt_fields!(
            s,
            (self.name),
            (self.alternative_names),
            #[cfg(feature = "https")]
            (self.certificate, &"[internal certificate]".as_clean()),
            (self.path),
            (self.extensions, &"[internal extension data]".as_clean()),
            (self.file_cache, &"[internal cache]".as_clean()),
            (self.response_cache, &"[internal cache]".as_clean()),
            (self.limiter),
            (self.vary),
            (self.options),
            (self.compression_options_oneshot),
            (self.compression_options_cached),
        );
        s.finish()
    }
}
impl Host {
    /// Returns a copy of this host that has no extensions.
    ///
    /// Copied from `self`:
    /// - name and alternative names,
    /// - certificate, if any,
    /// - path, limiter and options,
    /// - both compression settings.
    ///
    /// Reset instead of copied:
    /// - `vary` goes back to its default.
    /// - Both caches are always fresh and enabled, even if they were disabled on `self`.
    ///
    /// NOTE(review): `vary` is presumably reset because its rules come from extensions;
    /// confirm.
    ///
    /// # Panics
    ///
    /// With the `https` feature, panics if the certificate lock is poisoned.
    pub fn clone_without_extensions(&self) -> Self {
        #[cfg(feature = "https")]
        let current_certificate = self.certificate.read().unwrap().clone();
        let fresh_file_cache = Some(MokaCache::default());
        let fresh_response_cache = Some(MokaCache::default());
        Self {
            name: self.name.clone(),
            alternative_names: self.alternative_names.clone(),
            #[cfg(feature = "https")]
            certificate: std::sync::RwLock::new(current_certificate),
            path: self.path.clone(),
            extensions: Extensions::empty(),
            file_cache: fresh_file_cache,
            response_cache: fresh_response_cache,
            limiter: self.limiter.clone(),
            vary: Vary::default(),
            options: self.options.clone(),
            compression_options_oneshot: self.compression_options_oneshot.clone(),
            compression_options_cached: self.compression_options_cached.clone(),
        }
    }
}
/// Per-host settings. `None` values fall back to the defaults returned by the `get_*`
/// methods.
#[derive(Debug, Clone)]
#[must_use]
pub struct Options {
    /// File served for a folder request; defaults to `"index.html"`.
    pub folder_default: Option<CompactString>,
    /// Extension used when none is given; defaults to `"html"`.
    pub extension_default: Option<CompactString>,
    /// Directory with public files; defaults to `"public"`.
    pub public_data_dir: Option<CompactString>,
    /// Directory with error pages (see [`Options::get_errors_dir`]).
    pub errors_dir: Option<CompactString>,
    /// Set by [`Options::disable_client_cache`].
    pub disable_client_cache: bool,
    /// NOTE(review): presumably disables `if-modified-since` handling; not read in this file.
    pub disable_if_modified_since: bool,
    /// Decides, per response status code, whether the response may be cached.
    /// [`Options::new`] sets it to [`default_status_code_cache_filter`].
    pub status_code_cache_filter: fn(StatusCode) -> CacheAction,
    /// Set by [`Options::disable_fs`]. NOTE(review): presumably stops serving from the
    /// filesystem; not read in this file.
    pub disable_fs: bool,
}
impl Options {
    /// Creates options with every value at its default:
    /// - no custom directories or defaults,
    /// - all `disable_*` flags `false`,
    /// - [`default_status_code_cache_filter`] as the status-code cache filter.
    pub fn new() -> Self {
        Self {
            folder_default: None,
            extension_default: None,
            public_data_dir: None,
            errors_dir: None,
            disable_client_cache: false,
            disable_if_modified_since: false,
            status_code_cache_filter: default_status_code_cache_filter,
            disable_fs: false,
        }
    }
    /// Sets [`Options::disable_client_cache`] to `true`.
    pub fn disable_client_cache(&mut self) -> &mut Self {
        self.disable_client_cache = true;
        self
    }
    /// Sets [`Options::disable_fs`] to `true`.
    pub fn disable_fs(&mut self) -> &mut Self {
        self.disable_fs = true;
        self
    }
    /// Overrides the public data directory returned by [`Options::get_public_data_dir`].
    pub fn set_public_data_dir(&mut self, path: impl AsRef<str>) -> &mut Self {
        self.public_data_dir = Some(path.as_ref().to_compact_string());
        self
    }
    /// Overrides the error-page directory returned by [`Options::get_errors_dir`].
    pub fn set_errors_dir(&mut self, path: impl AsRef<str>) -> &mut Self {
        self.errors_dir = Some(path.as_ref().to_compact_string());
        self
    }
    /// File served for folder requests; `"index.html"` unless set.
    #[must_use]
    pub fn get_folder_default(&self) -> &str {
        self.folder_default.as_deref().unwrap_or("index.html")
    }
    /// Default file extension; `"html"` unless set.
    #[must_use]
    pub fn get_extension_default(&self) -> &str {
        self.extension_default.as_deref().unwrap_or("html")
    }
    /// Public data directory; `"public"` unless set with [`Options::set_public_data_dir`].
    #[must_use]
    pub fn get_public_data_dir(&self) -> &str {
        self.public_data_dir.as_deref().unwrap_or("public")
    }
    /// Error-page directory; `"errors"` unless set with [`Options::set_errors_dir`].
    #[must_use]
    pub fn get_errors_dir(&self) -> &str {
        self.errors_dir.as_deref().unwrap_or("errors")
    }
}
impl Default for Options {
fn default() -> Self {
Self::new()
}
}
/// Builder for a [`Collection`]; create one with [`Collection::builder`].
#[derive(Debug)]
#[must_use]
pub struct CollectionBuilder(Collection);
impl CollectionBuilder {
#[inline]
pub fn insert(mut self, host: Host) -> Self {
self.check_secure(&host);
if self.0.first.is_none() {
self.0.first = Some(host.name.clone());
self.0.pre_host_limiter = host.limiter.clone();
}
for alt_name in &host.alternative_names {
self.0
.by_name
.insert(alt_name.clone(), HostValue::Ref(host.name.clone()));
}
self.0
.by_name
.insert(host.name.clone(), HostValue::Host(host));
self
}
#[inline]
pub fn default(mut self, host: Host) -> Self {
assert!(
self.0.default.is_none(),
"Can not set default host multiple times."
);
info!("Set default host {:?}", host.name);
self.0.default = Some(host.name.clone());
self.insert(host)
}
fn check_secure(&mut self, host: &Host) {
if host.is_secure() {
self.0.has_secure = true;
}
}
#[inline]
pub fn set_pre_host_limiter(mut self, limiter: LimitManager) -> Self {
self.0.pre_host_limiter = limiter;
self
}
#[inline]
#[must_use]
pub fn build(self) -> Arc<Collection> {
trace!("Build host collection: {:#?}", self.0);
Arc::new(self.into_inner())
}
#[inline]
pub fn into_inner(self) -> Collection {
self.0
}
}
/// Entry of [`Collection`]'s name table: either the host itself (stored under its primary
/// name) or the primary name an alternative name points to.
#[derive(Debug)]
// NOTE(review): `Host` is far larger than `Ref`; the lint is silenced rather than boxing
// the host, and the rationale is not stated here.
#[allow(clippy::large_enum_variant)]
enum HostValue {
    Host(Host),
    Ref(CompactString),
}
impl HostValue {
    /// Returns the host for a `Host` entry and `None` for a `Ref` entry.
    fn as_host(&self) -> Option<&Host> {
        if let Self::Host(host) = self {
            Some(host)
        } else {
            None
        }
    }
}
/// A set of [`Host`]s looked up by name, for example from SNI or the `host` header.
#[derive(Debug)]
#[must_use]
pub struct Collection {
    /// Primary name of the default host, if one was set.
    default: Option<CompactString>,
    /// Every primary and alternative name; alternative names map to `HostValue::Ref`.
    by_name: HashMap<CompactString, HostValue>,
    /// Primary name of the first inserted host; used as a fallback for loopback names.
    first: Option<CompactString>,
    /// `true` if any host had a certificate when it was inserted.
    has_secure: bool,
    /// Limiter returned by `limiter()`. NOTE(review): presumably applied before the
    /// request's host is resolved.
    pre_host_limiter: LimitManager,
}
impl Collection {
#[inline]
pub fn builder() -> CollectionBuilder {
CollectionBuilder(Self {
default: None,
by_name: HashMap::new(),
first: None,
has_secure: false,
pre_host_limiter: LimitManager::default(),
})
}
#[inline]
pub fn simple_non_secure(default_host_name: impl AsRef<str>, extensions: Extensions) -> Self {
Self::builder()
.default(Host::unsecure(
default_host_name,
".",
extensions,
Options::default(),
))
.into_inner()
}
#[inline]
#[must_use]
pub fn get_default(&self) -> Option<&Host> {
trace!("Getting default {:?}", self.default);
self.default
.as_ref()
.and_then(|default| self.get_host(default))
}
#[inline]
#[must_use]
pub fn get_host(&self, name: &str) -> Option<&Host> {
match self.by_name.get(name) {
Some(v) => match v {
HostValue::Host(h) => Some(h),
HostValue::Ref(r) => Some(
self.by_name
.get(r)
.and_then(HostValue::as_host)
.expect("internal error when resolving host: Ref pointed to Ref"),
),
},
None => None,
}
}
#[inline]
#[must_use]
pub fn get_or_default(&self, name: &str) -> Option<&Host> {
self.get_host(name)
.or_else(|| name.strip_suffix('.').and_then(|name| self.get_host(name)))
.or_else(|| self.get_default())
.or_else(|| {
let base_host = name.split(':').next();
if base_host == Some("localhost")
|| base_host == Some("127.0.0.1")
|| base_host == Some("::1")
|| base_host == Some("[::1]")
{
self.first.as_ref().and_then(|host| self.get_host(host))
} else {
None
}
})
}
#[inline]
#[must_use]
pub fn get_option_or_default(&self, name: Option<&str>) -> Option<&Host> {
match name {
Some(host) => self.get_or_default(host),
None => self.get_default(),
}
}
#[inline]
pub fn get_from_request<'a>(
&'a self,
request: &Request<Body>,
sni_hostname: Option<&str>,
) -> Option<&'a Host> {
fn get_header(headers: &HeaderMap) -> Option<&str> {
headers
.get(header::HOST)
.map(HeaderValue::to_str)
.and_then(Result::ok)
}
let host = sni_hostname.or_else(|| get_header(request.headers()));
self.get_option_or_default(host)
}
#[inline]
#[must_use]
pub fn has_secure(&self) -> bool {
self.has_secure
}
#[cfg(feature = "https")]
#[inline]
#[must_use]
pub fn make_config(self: &Arc<Self>) -> ServerConfig {
encryption::attach_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(self.clone());
config.alpn_protocols = alpn();
config
}
pub(crate) fn limiter(&self) -> &LimitManager {
&self.pre_host_limiter
}
#[inline]
#[allow(clippy::unused_async)] pub async fn clear_response_caches(&self, host_filter: Option<&str>) {
for host in self.by_name.values().filter_map(HostValue::as_host) {
if host_filter.map_or(false, |h| h != host.name) {
continue;
}
if let Some(cache) = &host.response_cache {
cache.cache.invalidate_all();
}
}
}
pub fn clear_page(&self, host: &str, uri: &Uri) -> (bool, bool) {
let key = UriKey::path_and_query(uri);
let mut found = false;
let mut cleared = false;
if host.is_empty() || host == "default" {
if let Some(cache) = self
.get_default()
.as_ref()
.and_then(|h| h.response_cache.as_ref())
{
found = true;
cleared ^= cache.cache.contains_key(&key);
cache.cache.invalidate(&key);
if let UriKey::PathQuery(path_query) = key {
let key = UriKey::Path(path_query.into_path());
cleared |= cache.cache.contains_key(&key);
cache.cache.invalidate(&key);
}
}
} else if let Some(host) = self.get_host(host) {
found = true;
if let Some(cache) = &host.response_cache {
cleared ^= cache.cache.contains_key(&key);
cache.cache.invalidate(&key);
if let UriKey::PathQuery(path_query) = key {
let key = UriKey::Path(path_query.into_path());
cleared |= cache.cache.contains_key(&key);
cache.cache.invalidate(&key);
}
}
}
(found, cleared)
}
#[inline]
#[allow(clippy::unused_async)] pub async fn clear_file_caches(&self, host_filter: Option<&str>) {
for host in self.by_name.values().filter_map(HostValue::as_host) {
if host_filter.map_or(false, |h| h != host.name) {
continue;
}
if let Some(cache) = &host.file_cache {
cache.cache.invalidate_all();
}
}
}
pub fn clear_file(&self, host: &str, path: impl AsRef<str>) -> (bool, bool) {
let path = path.as_ref();
let mut found = false;
let mut cleared = false;
if host.is_empty() || host == "default" {
if let Some(cache) = self
.get_default()
.as_ref()
.and_then(|h| h.file_cache.as_ref())
{
found = true;
cleared |= cache.cache.contains_key(path);
cache.cache.invalidate(path);
}
} else if let Some(host) = self.get_host(host) {
found = true;
if let Some(cache) = &host.file_cache {
cleared |= cache.cache.contains_key(path);
cache.cache.invalidate(path);
}
}
(found, cleared)
}
}
// SAFETY: NOTE(review): no justification is given in this file. These impls assert that a
// `Collection` may be sent to and shared between threads. It owns `Host`s, which hold
// extensions, caches and `RwLock`-guarded certificates. Confirm that every field type, in
// particular whatever `Extensions` stores, is really thread-safe; otherwise remove these
// impls.
unsafe impl Send for Collection {}
unsafe impl Sync for Collection {}
#[cfg(feature = "https")]
impl ResolvesServerCert for Collection {
    /// Returns the certificate of the host matching the client's SNI name.
    ///
    /// The host is looked up like [`Collection::get_option_or_default`]. Returns `None` if
    /// no host matches or the matched host has no certificate.
    ///
    /// # Panics
    ///
    /// Panics if the matched host's certificate lock is poisoned.
    #[inline]
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
        let host = self.get_option_or_default(client_hello.server_name())?;
        let certificate = host.certificate.read().unwrap();
        certificate.as_ref().cloned()
    }
}
/// ALPN protocol identifiers to advertise, most preferred first.
///
/// The list depends on enabled features:
/// - with `http3`: HTTP/3 (`h3`, then drafts 32 down to 29),
/// - with `http2`: HTTP/2 (`h2`),
/// - always: `http/1.1` last.
///
/// Each identifier appears exactly once.
#[must_use]
pub fn alpn() -> Vec<Vec<u8>> {
    vec![
        #[cfg(feature = "http3")]
        b"h3".to_vec(),
        #[cfg(feature = "http3")]
        b"h3-32".to_vec(),
        #[cfg(feature = "http3")]
        b"h3-31".to_vec(),
        #[cfg(feature = "http3")]
        b"h3-30".to_vec(),
        #[cfg(feature = "http3")]
        b"h3-29".to_vec(),
        #[cfg(feature = "http2")]
        b"h2".to_vec(),
        b"http/1.1".to_vec(),
    ]
}
/// Whether something may be cached. For example, this is the result of
/// [`Options::status_code_cache_filter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[must_use]
pub enum CacheAction {
    /// Caching is allowed.
    Cache,
    /// Caching is not allowed.
    Drop,
}
impl CacheAction {
    /// `true` → [`CacheAction::Cache`], `false` → [`CacheAction::Drop`].
    pub fn from_cache(cache: bool) -> Self {
        if cache {
            CacheAction::Cache
        } else {
            CacheAction::Drop
        }
    }
    /// `true` → [`CacheAction::Drop`], `false` → [`CacheAction::Cache`].
    pub fn from_drop(drop: bool) -> Self {
        if drop {
            CacheAction::Drop
        } else {
            CacheAction::Cache
        }
    }
    /// `true` for [`CacheAction::Cache`].
    #[must_use]
    pub fn into_cache(self) -> bool {
        self == CacheAction::Cache
    }
    /// `true` for [`CacheAction::Drop`].
    #[must_use]
    pub fn into_drop(self) -> bool {
        match self {
            CacheAction::Drop => true,
            CacheAction::Cache => false,
        }
    }
}
/// Default value of [`Options::status_code_cache_filter`].
///
/// Drops (refuses to cache):
/// - informational responses (100–199),
/// - `304 Not Modified`,
/// - client errors 400–499, except `404 Not Found` and `410 Gone`.
///
/// Everything else is cacheable, including all other 2xx/3xx responses and 5xx responses.
pub fn default_status_code_cache_filter(code: StatusCode) -> CacheAction {
    let status = code.as_u16();
    let uncacheable = matches!(
        status,
        100..=199 | 304 | 400..=403 | 405..=409 | 411..=499
    );
    CacheAction::from_drop(uncacheable)
}
/// Reasons why loading a certificate and private key (see [`get_certified_key`]) can fail.
#[cfg(feature = "https")]
#[derive(Debug)]
pub enum CertificateError {
    /// An I/O error, e.g. when opening the certificate or key file.
    Io(io::Error),
    /// NOTE(review): not produced in this file; presumably a key in an unexpected format.
    ImproperPrivateKeyFormat,
    /// The certificate file could not be parsed.
    ImproperCertificateFormat,
    /// The key file contains no private key.
    NoKey,
    /// The private key could not be read, or its type is not supported.
    InvalidPrivateKey,
}
/// Lets `?` turn I/O errors into [`CertificateError::Io`].
#[cfg(feature = "https")]
impl From<io::Error> for CertificateError {
    #[inline]
    fn from(error: io::Error) -> Self {
        CertificateError::Io(error)
    }
}
/// Reads a PEM certificate chain from `cert_path` and a PEM private key from
/// `private_key_path`, and combines them into a [`sign::CertifiedKey`].
///
/// # Errors
///
/// - [`CertificateError::Io`] if either file cannot be opened.
/// - [`CertificateError::NoKey`] if the key file contains no private key.
/// - [`CertificateError::InvalidPrivateKey`] if the key cannot be read or its type is not
///   supported.
/// - [`CertificateError::ImproperCertificateFormat`] if the certificate file cannot be
///   parsed or contains no certificate at all.
#[cfg(feature = "https")]
pub fn get_certified_key(
    cert_path: impl AsRef<str>,
    private_key_path: impl AsRef<str>,
) -> Result<sign::CertifiedKey, CertificateError> {
    let mut chain = io::BufReader::new(std::fs::File::open(cert_path.as_ref())?);
    let mut private_key = io::BufReader::new(std::fs::File::open(private_key_path.as_ref())?);

    let private_key = match rustls_pemfile::private_key(&mut private_key) {
        Ok(Some(key)) => key,
        Ok(None) => return Err(CertificateError::NoKey),
        Err(err) => {
            error!("Invalid private key read: {err}");
            return Err(CertificateError::InvalidPrivateKey);
        }
    };
    let key = rustls::crypto::ring::sign::any_supported_type(&private_key)
        .map_err(|_| CertificateError::InvalidPrivateKey)?;

    let certs = rustls_pemfile::certs(&mut chain)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_| CertificateError::ImproperCertificateFormat)?;
    // An empty chain can never complete a TLS handshake. `Host::new_name_from_cert` also
    // needs at least one certificate, so reject it here with a proper error.
    if certs.is_empty() {
        return Err(CertificateError::ImproperCertificateFormat);
    }
    Ok(sign::CertifiedKey::new(certs, key))
}