1use crate::config::AccessControlConfig;
11use crate::headers::CompiledHeaderConfig;
12use crate::shadow::ShadowMirrorConfig;
13use ahash::RandomState;
14use regex::Regex;
15use std::collections::HashMap;
16use tracing::{debug, warn};
17use unicase::Ascii;
18
/// Runtime configuration for one virtual host, as consumed by [`VhostMatcher`].
#[derive(Debug, Clone)]
pub struct SiteConfig {
    /// Name this site answers to. [`VhostMatcher::new`] treats a value containing
    /// `*` as a wildcard pattern, `_` or `default` as the fallback site, and
    /// anything else as an exact, case-insensitive hostname.
    pub hostname: String,
    /// Upstream addresses; the YAML conversion renders each as a `host:port` string.
    pub upstreams: Vec<String>,
    /// Whether TLS is configured; the YAML conversion sets this when a `tls` section exists.
    pub tls_enabled: bool,
    /// TLS certificate path (copied from the YAML `tls.cert_path`), if TLS is configured.
    pub tls_cert: Option<String>,
    /// TLS private-key path (copied from the YAML `tls.key_path`), if TLS is configured.
    pub tls_key: Option<String>,
    /// Optional per-site WAF threshold; `None` when not configured.
    /// NOTE(review): presumably overrides a global WAF threshold — confirm in the WAF code.
    pub waf_threshold: Option<u8>,
    /// Whether the WAF is enabled for this site (`Default` and the YAML conversion both default to `true`).
    pub waf_enabled: bool,
    /// Optional per-site access-control rules.
    pub access_control: Option<AccessControlConfig>,
    /// Optional header rules, already compiled from their YAML definition.
    pub headers: Option<CompiledHeaderConfig>,
    /// Optional shadow-mirroring configuration.
    pub shadow_mirror: Option<ShadowMirrorConfig>,
}
43
44impl Default for SiteConfig {
45 fn default() -> Self {
46 Self {
47 hostname: String::new(),
48 upstreams: Vec::new(),
49 tls_enabled: false,
50 tls_cert: None,
51 tls_key: None,
52 waf_threshold: None,
53 waf_enabled: true,
54 access_control: None,
55 headers: None,
56 shadow_mirror: None,
57 }
58 }
59}
60
61impl From<crate::config::SiteYamlConfig> for SiteConfig {
62 fn from(yaml: crate::config::SiteYamlConfig) -> Self {
63 Self {
64 hostname: yaml.hostname,
65 upstreams: yaml
66 .upstreams
67 .iter()
68 .map(|u| format!("{}:{}", u.host, u.port))
69 .collect(),
70 tls_enabled: yaml.tls.is_some(),
71 tls_cert: yaml.tls.as_ref().map(|t| t.cert_path.clone()),
72 tls_key: yaml.tls.as_ref().map(|t| t.key_path.clone()),
73 waf_threshold: yaml.waf.as_ref().and_then(|w| w.threshold),
74 waf_enabled: yaml.waf.as_ref().map(|w| w.enabled).unwrap_or(true),
75 access_control: yaml.access_control,
76 headers: yaml.headers.as_ref().map(|headers| headers.compile()),
77 shadow_mirror: yaml.shadow_mirror,
78 }
79 }
80}
81
/// A compiled wildcard hostname pattern (e.g. `*.example.com`) and the site it routes to.
#[derive(Debug)]
struct WildcardPattern {
    /// Lowercased source pattern; used in debug logging and for specificity ordering.
    pattern: String,
    /// Anchored regex produced by `VhostMatcher::wildcard_to_regex`.
    regex: Regex,
    /// Index into `VhostMatcher::sites` of the site this pattern selects.
    site_index: usize,
}
92
/// Resolves an HTTP `Host` header to one of the configured [`SiteConfig`]s.
///
/// `match_host` tries, in order: exact hostnames, then wildcard patterns, then
/// the default site (a site configured as `_` or `default`), if any.
#[derive(Debug)]
pub struct VhostMatcher {
    /// Lowercased exact hostnames (compared ASCII-case-insensitively) mapped to indices into `sites`.
    exact_matches: HashMap<Ascii<String>, usize, RandomState>,
    /// Wildcard patterns, sorted so patterns with more `.`-separated segments come first.
    wildcard_patterns: Vec<WildcardPattern>,
    /// All configured sites in configuration order; the other fields index into this.
    sites: Vec<SiteConfig>,
    /// Index of the fallback site used when nothing else matches.
    default_site: Option<usize>,
}
114
115impl VhostMatcher {
116 const MAX_WILDCARDS: usize = 3;
118 const MAX_HOSTNAME_LEN: usize = 253;
120
121 pub fn new(sites: Vec<SiteConfig>) -> Result<Self, VhostError> {
129 let mut exact_matches = HashMap::with_capacity_and_hasher(sites.len(), RandomState::new());
131 let mut wildcard_patterns = Vec::with_capacity(sites.len() / 4); let mut default_site = None;
133
134 for (index, site) in sites.iter().enumerate() {
135 if site.hostname.len() > Self::MAX_HOSTNAME_LEN {
137 return Err(VhostError::HostnameTooLong {
138 hostname: site.hostname.clone(),
139 max_len: Self::MAX_HOSTNAME_LEN,
140 });
141 }
142
143 let normalized = site.hostname.to_lowercase();
145
146 if normalized.contains('*') {
148 let wildcard_count = normalized.matches('*').count();
150 if wildcard_count > Self::MAX_WILDCARDS {
151 return Err(VhostError::TooManyWildcards {
152 pattern: site.hostname.clone(),
153 count: wildcard_count,
154 max: Self::MAX_WILDCARDS,
155 });
156 }
157
158 let regex_pattern = Self::wildcard_to_regex(&normalized);
160 let regex = Regex::new(®ex_pattern).map_err(|e| VhostError::InvalidPattern {
161 pattern: site.hostname.clone(),
162 reason: e.to_string(),
163 })?;
164
165 wildcard_patterns.push(WildcardPattern {
166 pattern: normalized,
167 regex,
168 site_index: index,
169 });
170 } else if normalized == "_" || normalized == "default" {
171 default_site = Some(index);
173 } else {
174 exact_matches.insert(Ascii::new(normalized), index);
176 }
177 }
178
179 wildcard_patterns.sort_by(|a, b| {
181 let a_segments = a.pattern.matches('.').count();
183 let b_segments = b.pattern.matches('.').count();
184 b_segments.cmp(&a_segments)
185 });
186
187 Ok(Self {
188 exact_matches,
189 wildcard_patterns,
190 sites,
191 default_site,
192 })
193 }
194
195 pub fn empty() -> Self {
197 Self {
198 exact_matches: HashMap::with_hasher(RandomState::new()),
199 wildcard_patterns: Vec::new(),
200 sites: Vec::new(),
201 default_site: None,
202 }
203 }
204
205 fn wildcard_to_regex(pattern: &str) -> String {
207 let mut regex = String::from("^");
208 for ch in pattern.chars() {
209 match ch {
210 '*' => regex.push_str("[a-z0-9-]*"),
211 '.' => regex.push_str("\\."),
212 '-' => regex.push('-'),
213 c if c.is_ascii_alphanumeric() => regex.push(c),
214 _ => regex.push_str(®ex::escape(&ch.to_string())),
215 }
216 }
217 regex.push('$');
218 regex
219 }
220
221 pub fn sanitize_host(host: &str) -> Result<String, VhostError> {
229 if host.contains('\0') {
231 return Err(VhostError::InvalidHost {
232 host: host.to_string(),
233 reason: "contains null byte".to_string(),
234 });
235 }
236
237 if !host.chars().all(|c| c.is_ascii() && !c.is_control()) {
239 return Err(VhostError::InvalidHost {
240 host: host.to_string(),
241 reason: "contains invalid characters".to_string(),
242 });
243 }
244
245 let hostname = host.split(':').next().unwrap_or(host);
247
248 if !hostname.is_empty() && !Self::is_valid_hostname(hostname) {
250 return Err(VhostError::InvalidHost {
251 host: host.to_string(),
252 reason: "invalid hostname characters".to_string(),
253 });
254 }
255
256 Ok(hostname.to_lowercase())
257 }
258
259 fn is_valid_hostname(hostname: &str) -> bool {
261 hostname
262 .chars()
263 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.')
264 && !hostname.starts_with('-')
265 && !hostname.ends_with('-')
266 }
267
268 #[inline]
279 pub fn match_host(&self, host: &str) -> Option<&SiteConfig> {
280 let hostname = match Self::sanitize_host(host) {
282 Ok(h) => h,
283 Err(e) => {
284 warn!("Invalid host header: {}", e);
285 return self.default_site.map(|i| &self.sites[i]);
286 }
287 };
288
289 if let Some(&index) = self.exact_matches.get(&Ascii::new(hostname.clone())) {
291 debug!("Exact match for host '{}' -> site {}", hostname, index);
292 return Some(&self.sites[index]);
293 }
294
295 for pattern in &self.wildcard_patterns {
297 if pattern.regex.is_match(&hostname) {
298 debug!(
299 "Wildcard match for host '{}' -> pattern '{}' -> site {}",
300 hostname, pattern.pattern, pattern.site_index
301 );
302 return Some(&self.sites[pattern.site_index]);
303 }
304 }
305
306 if let Some(index) = self.default_site {
308 debug!("Using default site for host '{}'", hostname);
309 return Some(&self.sites[index]);
310 }
311
312 debug!("No match found for host '{}'", hostname);
313 None
314 }
315
316 pub fn sites(&self) -> &[SiteConfig] {
318 &self.sites
319 }
320
321 pub fn site_count(&self) -> usize {
323 self.sites.len()
324 }
325}
326
/// Errors produced while building a [`VhostMatcher`] or sanitizing a `Host` header.
#[derive(Debug, thiserror::Error)]
pub enum VhostError {
    /// A configured site hostname is longer than `max_len` bytes.
    #[error("hostname '{hostname}' exceeds maximum length of {max_len}")]
    HostnameTooLong { hostname: String, max_len: usize },

    /// A wildcard site pattern contains more `*` characters than allowed.
    #[error("pattern '{pattern}' has {count} wildcards, max is {max}")]
    TooManyWildcards {
        pattern: String,
        count: usize,
        max: usize,
    },

    /// A wildcard site pattern could not be compiled into a regex.
    #[error("invalid pattern '{pattern}': {reason}")]
    InvalidPattern { pattern: String, reason: String },

    /// A `Host` header value was rejected during sanitization.
    /// `host` holds the raw, unescaped header value; escape it before logging.
    #[error("invalid host header '{host}': {reason}")]
    InvalidHost { host: String, reason: String },
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal site: the given hostname, one local upstream, defaults elsewhere.
    fn make_site(hostname: &str) -> SiteConfig {
        SiteConfig {
            hostname: hostname.to_string(),
            upstreams: vec!["127.0.0.1:8080".to_string()],
            ..Default::default()
        }
    }

    /// Resolves `host` and returns the configured hostname of the chosen site.
    fn matched(matcher: &VhostMatcher, host: &str) -> Option<String> {
        matcher.match_host(host).map(|site| site.hostname.clone())
    }

    #[test]
    fn test_exact_match() {
        let sites = vec![make_site("example.com"), make_site("api.example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();

        assert_eq!(matched(&matcher, "example.com").as_deref(), Some("example.com"));
        assert_eq!(
            matched(&matcher, "api.example.com").as_deref(),
            Some("api.example.com")
        );
        assert!(matcher.match_host("other.com").is_none());
        assert_eq!(matcher.site_count(), 2);
        assert_eq!(matcher.sites().len(), 2);
    }

    #[test]
    fn test_case_insensitive() {
        let sites = vec![make_site("Example.COM")];
        let matcher = VhostMatcher::new(sites).unwrap();

        for host in &["example.com", "EXAMPLE.COM", "Example.Com"] {
            assert_eq!(matched(&matcher, host).as_deref(), Some("Example.COM"));
        }
    }

    #[test]
    fn test_wildcard_match() {
        let sites = vec![make_site("*.example.com"), make_site("example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();

        // An exact entry wins over a wildcard even when it is listed after it.
        assert_eq!(matched(&matcher, "example.com").as_deref(), Some("example.com"));
        assert_eq!(
            matched(&matcher, "api.example.com").as_deref(),
            Some("*.example.com")
        );
        assert_eq!(
            matched(&matcher, "www.example.com").as_deref(),
            Some("*.example.com")
        );
        // `*` never spans a dot, so deeper subdomains do not match.
        assert!(matcher.match_host("a.b.example.com").is_none());
        assert!(matcher.match_host("other.com").is_none());
    }

    #[test]
    fn test_port_stripping() {
        let sites = vec![make_site("example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();

        assert_eq!(matched(&matcher, "example.com:8080").as_deref(), Some("example.com"));
        assert_eq!(matched(&matcher, "example.com:443").as_deref(), Some("example.com"));
    }

    #[test]
    fn test_default_site() {
        let sites = vec![make_site("example.com"), make_site("_")];
        let matcher = VhostMatcher::new(sites).unwrap();

        assert_eq!(matched(&matcher, "example.com").as_deref(), Some("example.com"));
        assert_eq!(matched(&matcher, "unknown.com").as_deref(), Some("_"));
        // Hosts that fail sanitization are routed to the default site as well.
        assert_eq!(matched(&matcher, "exa\0mple.com").as_deref(), Some("_"));
    }

    #[test]
    fn test_default_keyword() {
        let sites = vec![make_site("example.com"), make_site("default")];
        let matcher = VhostMatcher::new(sites).unwrap();

        assert_eq!(matched(&matcher, "unknown.com").as_deref(), Some("default"));
    }

    #[test]
    fn test_invalid_host_without_default() {
        let matcher = VhostMatcher::new(vec![make_site("example.com")]).unwrap();

        assert!(matcher.match_host("exa\0mple.com").is_none());
    }

    #[test]
    fn test_empty_matcher() {
        let matcher = VhostMatcher::empty();

        assert_eq!(matcher.site_count(), 0);
        assert!(matcher.sites().is_empty());
        assert!(matcher.match_host("example.com").is_none());
    }

    #[test]
    fn test_sanitize_null_byte() {
        let result = VhostMatcher::sanitize_host("example\0.com");
        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_non_ascii() {
        let result = VhostMatcher::sanitize_host("δΎγ.com");
        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_strips_port_and_lowercases() {
        assert_eq!(
            VhostMatcher::sanitize_host("Example.COM:443").unwrap(),
            "example.com"
        );
    }

    #[test]
    fn test_sanitize_rejects_bad_hostnames() {
        for host in &["-bad.com", "example.com-", "exa mple.com", "exa_mple.com"] {
            assert!(
                matches!(
                    VhostMatcher::sanitize_host(host),
                    Err(VhostError::InvalidHost { .. })
                ),
                "expected {:?} to be rejected",
                host
            );
        }
    }

    #[test]
    fn test_too_many_wildcards() {
        let result = VhostMatcher::new(vec![make_site("*.*.*.*")]);
        assert!(matches!(
            result,
            Err(VhostError::TooManyWildcards { count: 4, max: 3, .. })
        ));
        // Exactly at the limit is accepted.
        assert!(VhostMatcher::new(vec![make_site("*.*.*")]).is_ok());
    }

    #[test]
    fn test_hostname_too_long() {
        let result = VhostMatcher::new(vec![make_site(&"a".repeat(300))]);
        assert!(matches!(
            result,
            Err(VhostError::HostnameTooLong { max_len: 253, .. })
        ));
        // Exactly at the limit is accepted.
        assert!(VhostMatcher::new(vec![make_site(&"a".repeat(253))]).is_ok());
    }

    #[test]
    fn test_wildcard_specificity() {
        let sites = vec![make_site("*.example.com"), make_site("*.api.example.com")];
        let matcher = VhostMatcher::new(sites).unwrap();

        let site = matcher.match_host("v1.api.example.com").unwrap();
        assert_eq!(site.hostname, "*.api.example.com");
    }
}