use super::Response;
use super::request::RequestHead;
use super::router::{FrozenRouter, Router};
use super::{BufferConfig, Request};
impl std::fmt::Debug for HostRouter {
    /// Writes a compact summary of the router: how many hosts are registered,
    /// whether a default router is set, and the buffer configuration.
    /// The per-host routers themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let host_count = self.hosts.len();
        let has_default = self.default.is_some();

        let mut summary = f.debug_struct("HostRouter");
        summary.field("host_count", &host_count);
        summary.field("has_default", &has_default);
        summary.field("buffers", &self.buffers);
        summary.finish()
    }
}
/// Mutable builder that associates host names with routers, plus an optional
/// default router and a buffer configuration. `freeze` converts it into a
/// `FrozenHostRouter`.
#[derive(Default)]
pub struct HostRouter {
/// Registered `(host, router)` pairs; `add` stores each host ASCII-lowercased.
hosts: Vec<(Box<str>, Router)>,
/// Router stored by `set_default`, if any.
default: Option<Router>,
/// Buffer settings, adjusted through the builder-style size methods.
buffers: BufferConfig,
}
impl HostRouter {
    /// Creates an empty host router: no hosts, no default router, and the
    /// default buffer configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the router with its buffer configuration replaced by the result
    /// of `BufferConfig::with_max_request_body(bytes)`.
    #[must_use]
    pub fn max_request_body(mut self, bytes: usize) -> Self {
        self.buffers = self.buffers.with_max_request_body(bytes);
        self
    }

    /// Returns the router with its buffer configuration replaced by the result
    /// of `BufferConfig::with_sse_buffer_size(size)`.
    #[must_use]
    pub fn sse_buffer_size(mut self, size: usize) -> Self {
        self.buffers = self.buffers.with_sse_buffer_size(size);
        self
    }

    /// Returns the router with its buffer configuration replaced by the result
    /// of `BufferConfig::with_ws_buffer_size(size)`.
    ///
    /// Only compiled when the `ws` feature is enabled.
    #[cfg(feature = "ws")]
    #[must_use]
    pub fn ws_buffer_size(mut self, size: usize) -> Self {
        self.buffers = self.buffers.with_ws_buffer_size(size);
        self
    }

    /// Returns a copy of the current buffer configuration.
    pub(super) fn buffer_config(&self) -> BufferConfig {
        self.buffers
    }

    /// Registers `router` for `host` and returns `self` for chaining.
    ///
    /// The host is stored ASCII-lowercased. Registering a host that is already
    /// present (compared after lowercasing) replaces the earlier router, the
    /// same way `set_default` overwrites a previous default.
    pub fn add(&mut self, host: &str, router: Router) -> &mut Self {
        let normalized: Box<str> = host.to_ascii_lowercase().into_boxed_str();
        // The frozen table is binary-searched, and a binary search may return
        // any one of several equal keys. Keeping at most one entry per host
        // makes the router that serves a host well-defined.
        match self
            .hosts
            .iter()
            .position(|(existing, _)| *existing == normalized)
        {
            Some(index) => self.hosts[index].1 = router,
            None => self.hosts.push((normalized, router)),
        }
        self
    }

    /// Sets the default router, replacing any previous one, and returns
    /// `self` for chaining.
    pub fn set_default(&mut self, router: Router) -> &mut Self {
        self.default = Some(router);
        self
    }

    /// Consumes the builder and produces a `FrozenHostRouter`.
    ///
    /// Every router is converted with `Router::freeze`, and the host entries
    /// are sorted by host name so the frozen router can binary-search them.
    pub(super) fn freeze(self) -> FrozenHostRouter {
        let mut hosts: Vec<(Box<str>, FrozenRouter)> = self
            .hosts
            .into_iter()
            .map(|(host, router)| (host, router.freeze()))
            .collect();
        // `add` keeps host keys unique, so an unstable sort gives the same
        // order as a stable one.
        hosts.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        let default = self.default.map(Router::freeze);
        FrozenHostRouter {
            hosts: hosts.into_boxed_slice(),
            default,
        }
    }
}
/// Read-only host table produced by `HostRouter::freeze`.
pub(super) struct FrozenHostRouter {
/// `(host, router)` pairs, sorted by host so lookups can binary-search.
hosts: Box<[(Box<str>, FrozenRouter)]>,
/// Fallback router returned when no host entry matches.
default: Option<FrozenRouter>,
}
/// Returns `true` when `host` contains no path separators (`/`, `\`) and no
/// ASCII control characters.
///
/// Control characters are 0x00–0x1F and DEL (0x7F). An empty string is
/// considered valid. Bytes outside the ASCII range are not inspected.
fn is_valid_host(host: &str) -> bool {
    host.bytes()
        .all(|b| !(b.is_ascii_control() || matches!(b, b'/' | b'\\')))
}
/// Removes a trailing `:port` from a host header value.
///
/// A value that starts with `[` is cut just after the first `]`; if there is
/// no `]`, it is returned unchanged. Any other value is cut at its last `:`,
/// or returned unchanged when it has none.
fn strip_host_port(host: &str) -> &str {
    if host.starts_with('[') {
        // Bracketed literal: keep everything up to and including the bracket.
        return match host.find(']') {
            Some(close) => &host[..=close],
            None => host,
        };
    }
    match host.rsplit_once(':') {
        Some((name, _port)) => name,
        None => host,
    }
}
/// Returns `host` in ASCII lowercase.
///
/// The input is borrowed as-is when it has no ASCII uppercase bytes; a new
/// string is allocated only when at least one byte must change.
fn lowercase_hostname(host: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    let needs_folding = host.as_bytes().iter().any(u8::is_ascii_uppercase);
    if needs_folding {
        Cow::Owned(host.to_ascii_lowercase())
    } else {
        Cow::Borrowed(host)
    }
}
impl FrozenHostRouter {
    /// Looks up the router for a raw `Host` header value.
    ///
    /// Returns `Err` with a 400 response when `is_valid_host` rejects the
    /// header. Otherwise the port is stripped, the name is lowercased, and the
    /// sorted host table is binary-searched; on a miss the default router (if
    /// any) is returned, so `Ok(None)` means neither matched.
    fn resolve_host(&self, host_header: &str) -> Result<Option<&FrozenRouter>, Response> {
        if !is_valid_host(host_header) {
            return Err(Response::text_raw(400, "bad request"));
        }

        let lookup = lowercase_hostname(strip_host_port(host_header));
        let key: &str = &lookup;
        let matched = match self.hosts.binary_search_by(|(h, _)| (**h).cmp(key)) {
            Ok(index) => Some(&self.hosts[index].1),
            Err(_) => None,
        };
        Ok(matched.or(self.default.as_ref()))
    }

    /// Resolves the router for `req` from its `host` header; a missing header
    /// is treated as an empty host.
    pub(super) fn resolve(&self, req: &Request) -> Result<Option<&FrozenRouter>, Response> {
        let host_header = req.header("host").unwrap_or("");
        self.resolve_host(host_header)
    }

    /// Resolves the router from a parsed request head's `host` header; a
    /// missing header is treated as an empty host.
    pub(super) fn resolve_from_head(
        &self,
        head: &RequestHead<'_>,
    ) -> Result<Option<&FrozenRouter>, Response> {
        let host_header = head.header("host").unwrap_or("");
        self.resolve_host(host_header)
    }
}