use crate::config::AccessControlConfig;
use crate::headers::CompiledHeaderConfig;
use crate::shadow::ShadowMirrorConfig;
use ahash::RandomState;
use regex::Regex;
use std::collections::HashMap;
use tracing::{debug, warn};
use unicase::Ascii;
/// Runtime configuration for one virtual host, as consumed by `VhostMatcher`.
///
/// Values are normally produced from a `crate::config::SiteYamlConfig` via the
/// `From` impl in this file, or built by hand on top of `SiteConfig::default()`.
#[derive(Debug, Clone)]
pub struct SiteConfig {
/// Name this site is registered under. `VhostMatcher::new` treats a value
/// containing `*` as a wildcard pattern, `_` or `default` as the default
/// site, and anything else as an exact hostname.
pub hostname: String,
/// Upstream addresses; the `From<SiteYamlConfig>` impl formats each as `host:port`.
pub upstreams: Vec<String>,
/// `true` when the YAML site carried a `tls` section.
pub tls_enabled: bool,
/// `cert_path` copied from the YAML `tls` section, if present.
pub tls_cert: Option<String>,
/// `key_path` copied from the YAML `tls` section, if present.
pub tls_key: Option<String>,
/// `threshold` copied from the YAML `waf` section, if one was set.
pub waf_threshold: Option<u8>,
/// `enabled` flag from the YAML `waf` section; `true` when the section is absent.
pub waf_enabled: bool,
/// Access-control settings, passed through unchanged from the YAML site.
pub access_control: Option<AccessControlConfig>,
/// Header configuration, obtained by calling `compile()` on the YAML `headers` section.
pub headers: Option<CompiledHeaderConfig>,
/// Shadow-mirror settings, passed through unchanged from the YAML site.
pub shadow_mirror: Option<ShadowMirrorConfig>,
}
impl Default for SiteConfig {
fn default() -> Self {
Self {
hostname: String::new(),
upstreams: Vec::new(),
tls_enabled: false,
tls_cert: None,
tls_key: None,
waf_threshold: None,
waf_enabled: true,
access_control: None,
headers: None,
shadow_mirror: None,
}
}
}
impl From<crate::config::SiteYamlConfig> for SiteConfig {
fn from(yaml: crate::config::SiteYamlConfig) -> Self {
Self {
hostname: yaml.hostname,
upstreams: yaml
.upstreams
.iter()
.map(|u| format!("{}:{}", u.host, u.port))
.collect(),
tls_enabled: yaml.tls.is_some(),
tls_cert: yaml.tls.as_ref().map(|t| t.cert_path.clone()),
tls_key: yaml.tls.as_ref().map(|t| t.key_path.clone()),
waf_threshold: yaml.waf.as_ref().and_then(|w| w.threshold),
waf_enabled: yaml.waf.as_ref().map(|w| w.enabled).unwrap_or(true),
access_control: yaml.access_control,
headers: yaml.headers.as_ref().map(|headers| headers.compile()),
shadow_mirror: yaml.shadow_mirror,
}
}
}
/// A wildcard hostname pattern compiled for matching.
#[derive(Debug)]
struct WildcardPattern {
/// Lowercased pattern text as configured (still containing `*`); used for
/// specificity ordering and in log output.
pattern: String,
/// Regex built from `pattern` by `VhostMatcher::wildcard_to_regex`.
regex: Regex,
/// Index of the owning site in `VhostMatcher::sites`.
site_index: usize,
}
/// Resolves a request's host to a configured `SiteConfig`.
///
/// `match_host` consults exact hostnames first, then wildcard patterns, then
/// the default site.
#[derive(Debug)]
pub struct VhostMatcher {
/// Lowercased exact hostnames mapped to an index into `sites`.
exact_matches: HashMap<Ascii<String>, usize, RandomState>,
/// Wildcard patterns, sorted by `new` so that patterns containing more `.`
/// characters come first.
wildcard_patterns: Vec<WildcardPattern>,
/// All configured sites, in the order they were passed to `new`.
sites: Vec<SiteConfig>,
/// Index into `sites` of the site whose hostname is `_` or `default`, if any.
default_site: Option<usize>,
}
impl VhostMatcher {
const MAX_WILDCARDS: usize = 3;
const MAX_HOSTNAME_LEN: usize = 253;
pub fn new(sites: Vec<SiteConfig>) -> Result<Self, VhostError> {
let mut exact_matches = HashMap::with_capacity_and_hasher(sites.len(), RandomState::new());
let mut wildcard_patterns = Vec::with_capacity(sites.len() / 4); let mut default_site = None;
for (index, site) in sites.iter().enumerate() {
if site.hostname.len() > Self::MAX_HOSTNAME_LEN {
return Err(VhostError::HostnameTooLong {
hostname: site.hostname.clone(),
max_len: Self::MAX_HOSTNAME_LEN,
});
}
let normalized = site.hostname.to_lowercase();
if normalized.contains('*') {
let wildcard_count = normalized.matches('*').count();
if wildcard_count > Self::MAX_WILDCARDS {
return Err(VhostError::TooManyWildcards {
pattern: site.hostname.clone(),
count: wildcard_count,
max: Self::MAX_WILDCARDS,
});
}
let regex_pattern = Self::wildcard_to_regex(&normalized);
let regex = Regex::new(®ex_pattern).map_err(|e| VhostError::InvalidPattern {
pattern: site.hostname.clone(),
reason: e.to_string(),
})?;
wildcard_patterns.push(WildcardPattern {
pattern: normalized,
regex,
site_index: index,
});
} else if normalized == "_" || normalized == "default" {
default_site = Some(index);
} else {
exact_matches.insert(Ascii::new(normalized), index);
}
}
wildcard_patterns.sort_by(|a, b| {
let a_segments = a.pattern.matches('.').count();
let b_segments = b.pattern.matches('.').count();
b_segments.cmp(&a_segments)
});
Ok(Self {
exact_matches,
wildcard_patterns,
sites,
default_site,
})
}
pub fn empty() -> Self {
Self {
exact_matches: HashMap::with_hasher(RandomState::new()),
wildcard_patterns: Vec::new(),
sites: Vec::new(),
default_site: None,
}
}
fn wildcard_to_regex(pattern: &str) -> String {
let mut regex = String::from("^");
for ch in pattern.chars() {
match ch {
'*' => regex.push_str("[a-z0-9-]*"),
'.' => regex.push_str("\\."),
'-' => regex.push('-'),
c if c.is_ascii_alphanumeric() => regex.push(c),
_ => regex.push_str(®ex::escape(&ch.to_string())),
}
}
regex.push('$');
regex
}
pub fn sanitize_host(host: &str) -> Result<String, VhostError> {
if host.contains('\0') {
return Err(VhostError::InvalidHost {
host: host.to_string(),
reason: "contains null byte".to_string(),
});
}
if !host.chars().all(|c| c.is_ascii() && !c.is_control()) {
return Err(VhostError::InvalidHost {
host: host.to_string(),
reason: "contains invalid characters".to_string(),
});
}
let hostname = host.split(':').next().unwrap_or(host);
if !hostname.is_empty() && !Self::is_valid_hostname(hostname) {
return Err(VhostError::InvalidHost {
host: host.to_string(),
reason: "invalid hostname characters".to_string(),
});
}
Ok(hostname.to_lowercase())
}
fn is_valid_hostname(hostname: &str) -> bool {
hostname
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.')
&& !hostname.starts_with('-')
&& !hostname.ends_with('-')
}
#[inline]
pub fn match_host(&self, host: &str) -> Option<&SiteConfig> {
let hostname = match Self::sanitize_host(host) {
Ok(h) => h,
Err(e) => {
warn!("Invalid host header: {}", e);
return self.default_site.map(|i| &self.sites[i]);
}
};
if let Some(&index) = self.exact_matches.get(&Ascii::new(hostname.clone())) {
debug!("Exact match for host '{}' -> site {}", hostname, index);
return Some(&self.sites[index]);
}
for pattern in &self.wildcard_patterns {
if pattern.regex.is_match(&hostname) {
debug!(
"Wildcard match for host '{}' -> pattern '{}' -> site {}",
hostname, pattern.pattern, pattern.site_index
);
return Some(&self.sites[pattern.site_index]);
}
}
if let Some(index) = self.default_site {
debug!("Using default site for host '{}'", hostname);
return Some(&self.sites[index]);
}
debug!("No match found for host '{}'", hostname);
None
}
pub fn sites(&self) -> &[SiteConfig] {
&self.sites
}
pub fn site_count(&self) -> usize {
self.sites.len()
}
}
/// Errors produced while building a `VhostMatcher` or sanitising a host value.
#[derive(Debug, thiserror::Error)]
pub enum VhostError {
/// A configured hostname is longer than `max_len` bytes.
#[error("hostname '{hostname}' exceeds maximum length of {max_len}")]
HostnameTooLong { hostname: String, max_len: usize },
/// A configured pattern contains `count` `*` characters, more than the `max` allowed.
#[error("pattern '{pattern}' has {count} wildcards, max is {max}")]
TooManyWildcards {
pattern: String,
count: usize,
max: usize,
},
/// The regex generated from a wildcard pattern failed to compile;
/// `reason` carries the regex error text.
#[error("invalid pattern '{pattern}': {reason}")]
InvalidPattern { pattern: String, reason: String },
/// A host value was rejected by `VhostMatcher::sanitize_host`.
#[error("invalid host header '{host}': {reason}")]
InvalidHost { host: String, reason: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal site for `hostname` with a single local upstream.
    fn make_site(hostname: &str) -> SiteConfig {
        SiteConfig {
            hostname: hostname.to_string(),
            upstreams: vec!["127.0.0.1:8080".to_string()],
            ..Default::default()
        }
    }

    /// Returns the configured hostname of the site `host` resolves to, if any.
    fn matched<'a>(matcher: &'a VhostMatcher, host: &str) -> Option<&'a str> {
        matcher.match_host(host).map(|site| site.hostname.as_str())
    }

    #[test]
    fn test_exact_match() {
        let sites = vec![make_site("example.com"), make_site("api.example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();
        assert_eq!(matched(&matcher, "example.com"), Some("example.com"));
        assert_eq!(matched(&matcher, "api.example.com"), Some("api.example.com"));
        assert!(matcher.match_host("other.com").is_none());
    }

    #[test]
    fn test_case_insensitive() {
        let sites = vec![make_site("Example.COM")];
        let matcher = VhostMatcher::new(sites).unwrap();
        assert!(matcher.match_host("example.com").is_some());
        assert!(matcher.match_host("EXAMPLE.COM").is_some());
        assert!(matcher.match_host("Example.Com").is_some());
    }

    #[test]
    fn test_wildcard_match() {
        let sites = vec![make_site("*.example.com"), make_site("example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();
        // The bare domain must hit the exact entry, not the wildcard.
        assert_eq!(matched(&matcher, "example.com"), Some("example.com"));
        assert_eq!(matched(&matcher, "api.example.com"), Some("*.example.com"));
        assert_eq!(matched(&matcher, "www.example.com"), Some("*.example.com"));
        assert!(matcher.match_host("other.com").is_none());
    }

    #[test]
    fn test_port_stripping() {
        let sites = vec![make_site("example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();
        assert!(matcher.match_host("example.com:8080").is_some());
        assert!(matcher.match_host("example.com:443").is_some());
    }

    #[test]
    fn test_default_site() {
        let sites = vec![make_site("example.com"), make_site("_")];
        let matcher = VhostMatcher::new(sites).unwrap();
        assert_eq!(matched(&matcher, "example.com"), Some("example.com"));
        assert_eq!(matched(&matcher, "unknown.com"), Some("_"));
    }

    #[test]
    fn test_invalid_host_falls_back_to_default() {
        let sites = vec![make_site("example.com"), make_site("_")];
        let matcher = VhostMatcher::new(sites).unwrap();
        // A space is not a valid hostname character, so sanitisation fails.
        assert_eq!(matched(&matcher, "exa mple.com"), Some("_"));
    }

    #[test]
    fn test_sanitize_normalizes() {
        let host = VhostMatcher::sanitize_host("Example.COM:8443").unwrap();
        assert_eq!(host, "example.com");
    }

    #[test]
    fn test_sanitize_null_byte() {
        let result = VhostMatcher::sanitize_host("example\0.com");
        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_non_ascii() {
        let result = VhostMatcher::sanitize_host("例え.com");
        assert!(result.is_err());
    }

    #[test]
    fn test_too_many_wildcards() {
        let sites = vec![make_site("*.*.*.*")];
        let result = VhostMatcher::new(sites);
        assert!(result.is_err());
    }

    #[test]
    fn test_hostname_too_long() {
        let long_hostname = "a".repeat(300);
        let sites = vec![make_site(&long_hostname)];
        let result = VhostMatcher::new(sites);
        assert!(result.is_err());
    }

    #[test]
    fn test_wildcard_specificity() {
        let sites = vec![make_site("*.example.com"), make_site("*.api.example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();
        let site = matcher.match_host("v1.api.example.com").unwrap();
        assert_eq!(site.hostname, "*.api.example.com");
    }
}