pub mod pattern_trie;
use std::{
fmt::{self, Debug, Write},
rc::Rc,
str::from_utf8,
time::Instant,
};
use regex::bytes::Regex;
use sozu_command::{
logging::CachedTags,
proto::command::{
HeaderPosition, HstsConfig, PathRule as CommandPathRule, PathRuleKind, RedirectPolicy,
RedirectScheme, RulePosition,
},
response::HttpFrontend,
state::ClusterId,
};
use crate::metrics::names;
use crate::{
protocol::{http::editor::HeaderEditMode, http::parser::Method},
router::pattern_trie::{TrieMatches, TrieNode, TrieSubMatch},
sozu_command::logging::ansi_palette,
};
/// Expands to the `ROUTER` prefix used by this module's log lines, wrapped in
/// the first two values returned by `ansi_palette()`.
macro_rules! log_module_context {
    () => {{
        // Only the opening and reset entries of the palette are needed here.
        let (open, reset, ..) = ansi_palette();
        format!("{open}ROUTER{reset}\t >>>")
    }};
}
/// Errors returned while building, updating or querying a [`Router`].
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum RouterError {
    /// The frontend's path rule could not be converted (unknown rule kind or a
    /// regex that fails to compile); carries the rule's string form.
    #[error("Could not parse rule from frontend path {0:?}")]
    InvalidPathRule(String),
    /// The frontend hostname could not be parsed into a [`DomainRule`].
    #[error("parsing hostname {hostname} failed")]
    InvalidDomain { hostname: String },
    /// The host rewrite template is malformed or references a capture slot
    /// the rule cannot provide.
    #[error("Could not parse host rewrite {0:?}")]
    InvalidHostRewrite(String),
    /// The path rewrite template is malformed or references a capture slot
    /// the rule cannot provide.
    #[error("Could not parse path rewrite {0:?}")]
    InvalidPathRewrite(String),
    /// The rule was not inserted (duplicate, or a hostname unusable as a trie
    /// key); carries the frontend's `Debug` representation.
    #[error("Could not add route {0}")]
    AddRoute(String),
    /// The rule was not removed (no matching pre/post rule, or a hostname
    /// unusable as a trie key); carries the frontend's `Debug` representation.
    #[error("Could not remove route {0}")]
    RemoveRoute(String),
    /// No rule in any tier matched the request.
    #[error("no route for {method} {host} {path}")]
    RouteNotFound {
        /// Requested hostname.
        host: String,
        /// Requested path.
        path: String,
        /// Requested HTTP method.
        method: Method,
    },
}
/// Request router mapping (hostname, path, method) to a [`Route`].
///
/// Rules live in three tiers that `Router::lookup` consults in order:
/// `pre`, then the hostname trie, then `post`.
pub struct Router {
    /// Rules checked before the trie, in insertion order (first match wins).
    pre: Vec<(DomainRule, PathRule, MethodRule, Route)>,
    /// Trie keyed by IDNA-ASCII hostname; each entry holds that hostname's
    /// (path, method, route) rules.
    pub tree: TrieNode<Vec<(PathRule, MethodRule, Route)>>,
    /// Rules checked after the trie, in insertion order (first match wins).
    post: Vec<(DomainRule, PathRule, MethodRule, Route)>,
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
impl Router {
pub fn new() -> Router {
Router {
pre: Vec::new(),
tree: TrieNode::root(),
post: Vec::new(),
}
}
pub fn lookup(
&self,
hostname: &str,
path: &str,
method: &Method,
) -> Result<RouteResult, RouterError> {
let hostname_b = hostname.as_bytes();
let path_b = path.as_bytes();
for (domain_rule, path_rule, method_rule, route) in &self.pre {
if domain_rule.matches(hostname_b)
&& path_rule.matches(path_b) != PathRuleResult::None
&& method_rule.matches(method) != MethodRuleResult::None
{
return Ok(RouteResult::new_no_trie(
hostname_b,
domain_rule,
path_b,
path_rule,
route,
));
}
}
let trie_path: TrieMatches<'_, '_> = Vec::with_capacity(16);
if let Some(((_, path_rules), trie_matches)) =
self.tree.lookup_with_path(hostname_b, true, trie_path)
{
let mut prefix_length = 0;
let mut matched: Option<(&PathRule, &Route)> = None;
for (rule, method_rule, route) in path_rules {
match rule.matches(path_b) {
PathRuleResult::Regex | PathRuleResult::Equals => {
match method_rule.matches(method) {
MethodRuleResult::Equals => {
return Ok(RouteResult::new_with_trie(
hostname_b,
trie_matches,
path_b,
rule,
route,
));
}
MethodRuleResult::All => {
prefix_length = path_b.len();
matched = Some((rule, route));
}
MethodRuleResult::None => {}
}
}
PathRuleResult::Prefix(size) => {
if size >= prefix_length {
match method_rule.matches(method) {
MethodRuleResult::Equals => {
debug_assert!(
size >= prefix_length,
"longest-prefix selection must never shrink the match length",
);
prefix_length = size;
matched = Some((rule, route));
}
MethodRuleResult::All => {
debug_assert!(
size >= prefix_length,
"longest-prefix selection must never shrink the match length",
);
prefix_length = size;
matched = Some((rule, route));
}
MethodRuleResult::None => {}
}
}
}
PathRuleResult::None => {}
}
}
if let Some((path_rule, route)) = matched {
return Ok(RouteResult::new_with_trie(
hostname_b,
trie_matches,
path_b,
path_rule,
route,
));
}
}
for (domain_rule, path_rule, method_rule, route) in self.post.iter() {
if domain_rule.matches(hostname_b)
&& path_rule.matches(path_b) != PathRuleResult::None
&& method_rule.matches(method) != MethodRuleResult::None
{
return Ok(RouteResult::new_no_trie(
hostname_b,
domain_rule,
path_b,
path_rule,
route,
));
}
}
Err(RouterError::RouteNotFound {
host: hostname.to_owned(),
path: path.to_owned(),
method: method.to_owned(),
})
}
pub fn add_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
self.add_http_front_with_hsts_origin(front, HstsOrigin::Explicit)
}
pub fn add_http_front_with_hsts_origin(
&mut self,
front: &HttpFrontend,
hsts_origin: HstsOrigin,
) -> Result<(), RouterError> {
let path_rule = PathRule::from_config(front.path.clone())
.ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
let method_rule = MethodRule::new(front.method.clone());
let has_policy = front.redirect.is_some()
|| front.redirect_scheme.is_some()
|| front.redirect_template.is_some()
|| front.rewrite_host.is_some()
|| front.rewrite_path.is_some()
|| front.rewrite_port.is_some()
|| front.required_auth.unwrap_or(false)
|| !front.headers.is_empty()
|| front.hsts.is_some();
let domain =
front
.hostname
.parse::<DomainRule>()
.map_err(|_| RouterError::InvalidDomain {
hostname: front.hostname.clone(),
})?;
let route = if has_policy {
let redirect = front
.redirect
.and_then(|r| RedirectPolicy::try_from(r).ok())
.unwrap_or(RedirectPolicy::Forward);
let redirect_scheme = front
.redirect_scheme
.and_then(|s| RedirectScheme::try_from(s).ok())
.unwrap_or(RedirectScheme::UseSame);
let frontend = Frontend::new(
&domain,
&path_rule,
front,
redirect,
redirect_scheme,
front.redirect_template.clone(),
front.rewrite_host.clone(),
front.rewrite_path.clone(),
front.rewrite_port.and_then(|p| u16::try_from(p).ok()),
&front.headers,
front.required_auth.unwrap_or(false),
hsts_origin,
)?;
Route::Frontend(Rc::new(frontend))
} else {
match &front.cluster_id {
Some(cluster_id) => Route::ClusterId(cluster_id.clone()),
None => Route::Deny,
}
};
let success = match front.position {
RulePosition::Pre => self.add_pre_rule(&domain, &path_rule, &method_rule, &route),
RulePosition::Post => self.add_post_rule(&domain, &path_rule, &method_rule, &route),
RulePosition::Tree => {
self.add_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule, &route)
}
};
if !success {
return Err(RouterError::AddRoute(format!("{front:?}")));
}
Ok(())
}
pub fn remove_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
let path_rule = PathRule::from_config(front.path.clone())
.ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
let method_rule = MethodRule::new(front.method.clone());
let remove_success = match front.position {
RulePosition::Pre => {
let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
RouterError::InvalidDomain {
hostname: front.hostname.clone(),
}
})?;
self.remove_pre_rule(&domain, &path_rule, &method_rule)
}
RulePosition::Post => {
let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
RouterError::InvalidDomain {
hostname: front.hostname.clone(),
}
})?;
self.remove_post_rule(&domain, &path_rule, &method_rule)
}
RulePosition::Tree => {
self.remove_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule)
}
};
if !remove_success {
return Err(RouterError::RemoveRoute(format!("{front:?}")));
}
Ok(())
}
pub fn add_tree_rule(
&mut self,
hostname: &[u8],
path: &PathRule,
method: &MethodRule,
cluster: &Route,
) -> bool {
let hostname = match from_utf8(hostname) {
Err(_) => return false,
Ok(h) => h,
};
match ::idna::domain_to_ascii(hostname) {
Ok(hostname) => {
let mut empty = true;
if let Some((_, paths)) = self.tree.domain_lookup_mut(hostname.as_bytes(), false) {
empty = false;
let before = paths.len();
if !paths.iter().any(|(p, m, _)| p == path && m == method) {
paths.push((path.to_owned(), method.to_owned(), cluster.to_owned()));
debug_assert_eq!(
paths.len(),
before + 1,
"appending a tree rule must grow the leaf's rule list by exactly one",
);
debug_assert!(
paths.iter().any(|(p, m, _)| p == path && m == method),
"the freshly appended (path, method) rule must be present after insert",
);
return true;
}
}
if empty {
let inserted_host = hostname.clone().into_bytes();
self.tree.domain_insert(
hostname.into_bytes(),
vec![(path.to_owned(), method.to_owned(), cluster.to_owned())],
);
debug_assert!(
self.tree
.domain_lookup_mut(&inserted_host, false)
.is_some_and(|(_, paths)| paths
.iter()
.any(|(p, m, _)| p == path && m == method)),
"a freshly inserted tree domain must resolve to its inserted rule",
);
return true;
}
false
}
Err(_) => false,
}
}
pub fn remove_tree_rule(
&mut self,
hostname: &[u8],
path: &PathRule,
method: &MethodRule,
) -> bool {
let hostname = match from_utf8(hostname) {
Err(_) => return false,
Ok(h) => h,
};
match ::idna::domain_to_ascii(hostname) {
Ok(hostname) => {
let should_delete = {
let paths_opt = self.tree.domain_lookup_mut(hostname.as_bytes(), false);
if let Some((_, paths)) = paths_opt {
paths.retain(|(p, m, _)| p != path || m != method);
debug_assert!(
!paths.iter().any(|(p, m, _)| p == path && m == method),
"remove must evict every matching (path, method) rule from the leaf",
);
}
paths_opt
.as_ref()
.map(|(_, paths)| paths.is_empty())
.unwrap_or(false)
};
if should_delete {
let removed_host = hostname.clone().into_bytes();
self.tree.domain_remove(&hostname.into_bytes());
debug_assert!(
self.tree.domain_lookup_mut(&removed_host, false).is_none(),
"a domain whose last rule was removed must be unreachable",
);
}
true
}
Err(_) => false,
}
}
pub fn refresh_inheriting_hsts(&mut self, new_hsts: Option<&HstsConfig>) -> usize {
let mut refreshed = 0usize;
let new_edit = build_listener_hsts_edit(new_hsts);
let new_edit_ref = new_edit.as_ref();
let promote_lightweight = new_edit_ref.is_some();
let mut visit = |route: &mut Route| match route {
Route::Frontend(rc) => {
if rc.inherits_listener_hsts {
let new_frontend = rebuild_with_listener_hsts(rc, new_edit_ref);
*rc = Rc::new(new_frontend);
refreshed += 1;
}
}
Route::ClusterId(id) => {
if promote_lightweight {
let promoted = rebuild_with_listener_hsts(
&Frontend::minimal_forward(id.clone()),
new_edit_ref,
);
*route = Route::Frontend(Rc::new(promoted));
refreshed += 1;
}
}
Route::Deny => {
if promote_lightweight {
let promoted =
rebuild_with_listener_hsts(&Frontend::minimal_deny(), new_edit_ref);
*route = Route::Frontend(Rc::new(promoted));
refreshed += 1;
}
}
};
for (_, _, _, route) in self.pre.iter_mut() {
visit(route);
}
self.tree.for_each_value_mut(&mut |paths| {
for (_, _, route) in paths.iter_mut() {
visit(route);
}
});
for (_, _, _, route) in self.post.iter_mut() {
visit(route);
}
refreshed
}
pub fn add_pre_rule(
&mut self,
domain: &DomainRule,
path: &PathRule,
method: &MethodRule,
cluster_id: &Route,
) -> bool {
let before = self.pre.len();
if !self
.pre
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method)
{
self.pre.push((
domain.to_owned(),
path.to_owned(),
method.to_owned(),
cluster_id.to_owned(),
));
debug_assert_eq!(
self.pre.len(),
before + 1,
"adding a unique pre-rule must push exactly one entry",
);
debug_assert!(
self.pre
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method),
"the freshly added pre-rule must be present",
);
true
} else {
debug_assert_eq!(
self.pre.len(),
before,
"a duplicate pre-rule must not change the list length",
);
false
}
}
pub fn add_post_rule(
&mut self,
domain: &DomainRule,
path: &PathRule,
method: &MethodRule,
cluster_id: &Route,
) -> bool {
let before = self.post.len();
if !self
.post
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method)
{
self.post.push((
domain.to_owned(),
path.to_owned(),
method.to_owned(),
cluster_id.to_owned(),
));
debug_assert_eq!(
self.post.len(),
before + 1,
"adding a unique post-rule must push exactly one entry",
);
debug_assert!(
self.post
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method),
"the freshly added post-rule must be present",
);
true
} else {
debug_assert_eq!(
self.post.len(),
before,
"a duplicate post-rule must not change the list length",
);
false
}
}
pub fn remove_pre_rule(
&mut self,
domain: &DomainRule,
path: &PathRule,
method: &MethodRule,
) -> bool {
let before = self.pre.len();
match self
.pre
.iter()
.position(|(d, p, m, _)| d == domain && p == path && m == method)
{
None => {
debug_assert_eq!(
self.pre.len(),
before,
"a no-op pre-rule removal must not change the list length",
);
false
}
Some(index) => {
debug_assert!(index < self.pre.len(), "found index must be in bounds");
self.pre.remove(index);
debug_assert_eq!(
self.pre.len() + 1,
before,
"removing a pre-rule must drop exactly one entry",
);
debug_assert!(
!self
.pre
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method),
"the removed pre-rule must no longer be present",
);
true
}
}
}
pub fn remove_post_rule(
&mut self,
domain: &DomainRule,
path: &PathRule,
method: &MethodRule,
) -> bool {
let before = self.post.len();
match self
.post
.iter()
.position(|(d, p, m, _)| d == domain && p == path && m == method)
{
None => {
debug_assert_eq!(
self.post.len(),
before,
"a no-op post-rule removal must not change the list length",
);
false
}
Some(index) => {
debug_assert!(index < self.post.len(), "found index must be in bounds");
self.post.remove(index);
debug_assert_eq!(
self.post.len() + 1,
before,
"removing a post-rule must drop exactly one entry",
);
debug_assert!(
!self
.post
.iter()
.any(|(d, p, m, _)| d == domain && p == path && m == method),
"the removed post-rule must no longer be present",
);
true
}
}
}
pub fn has_hostname(&self, hostname: &str) -> bool {
let hostname_b = hostname.as_bytes();
for (domain_rule, _, _, _) in &self.pre {
if domain_rule.matches(hostname_b) {
return true;
}
}
if let Ok(ascii_hostname) = ::idna::domain_to_ascii(hostname) {
if self
.tree
.domain_lookup(ascii_hostname.as_bytes(), false)
.is_some()
{
return true;
}
}
for (domain_rule, _, _, _) in &self.post {
if domain_rule.matches(hostname_b) {
return true;
}
}
false
}
}
/// How a rule's hostname is matched (used by the `pre`/`post` tiers and to
/// compute host captures). Built by `DomainRule::from_str`.
#[derive(Clone, Debug)]
pub enum DomainRule {
    /// `"*"`: matches every hostname.
    Any,
    /// A plain hostname in IDNA ASCII form, compared byte-for-byte.
    Exact(String),
    /// A hostname beginning with `*` (e.g. `*.example.com`) in IDNA ASCII form,
    /// leading `*` included; matches exactly one extra leftmost label.
    Wildcard(String),
    /// A hostname with `/`-delimited regex fragments, compiled by
    /// `convert_regex_domain_rule` into an anchored byte regex.
    Regex(Regex),
}
/// Converts a hostname containing `/`-delimited regex fragments into an
/// anchored regular-expression source string.
///
/// The hostname is read label by label (labels are separated by `.`):
/// - a label starting with `/` is a regex fragment running up to the next `/`,
///   copied verbatim (it may itself contain dots);
/// - any other label runs up to the next `.` (or the end) and is copied as-is;
/// - each separating dot is emitted as `\.`.
///
/// The result is wrapped in `\A` … `\z`, so the regex must match the whole
/// hostname. For example `/[a-z]+/.example.com` becomes
/// `\A[a-z]+\.example\.com\z`.
///
/// Returns `None` in any of these cases:
/// - the input is empty;
/// - a regex fragment is not closed;
/// - a closed fragment is followed by something other than `.` or the end;
/// - the hostname ends with `.` (empty final label).
fn convert_regex_domain_rule(hostname: &str) -> Option<String> {
    let s = hostname.as_bytes();
    let mut result = String::from("\\A");
    let mut index = 0;
    loop {
        // Empty input, or a trailing `.`, leaves no label to read: reject it
        // instead of indexing past the end of the input.
        let first = *s.get(index)?;
        if first == b'/' {
            let close = index + 1 + s[index + 1..].iter().position(|&b| b == b'/')?;
            result.push_str(std::str::from_utf8(&s[index + 1..close]).ok()?);
            index = close + 1;
        } else {
            let end = s[index..]
                .iter()
                .position(|&b| b == b'.')
                .map_or(s.len(), |offset| index + offset);
            result.push_str(std::str::from_utf8(&s[index..end]).ok()?);
            index = end;
        }

        match s.get(index) {
            None => {
                result.push_str("\\z");
                return Some(result);
            }
            Some(b'.') => {
                result.push_str("\\.");
                index += 1;
            }
            Some(_) => return None,
        }
    }
}
impl DomainRule {
    /// Reports whether `hostname` (raw request bytes) satisfies this rule.
    ///
    /// Regex evaluation time is recorded in the
    /// `names::event_loop::REGEX_MATCHING_TIME` metric.
    pub fn matches(&self, hostname: &[u8]) -> bool {
        match self {
            DomainRule::Any => true,
            DomainRule::Exact(expected) => hostname == expected.as_bytes(),
            DomainRule::Wildcard(pattern) => {
                debug_assert_eq!(
                    pattern.as_bytes().first(),
                    Some(&b'*'),
                    "a Wildcard rule must retain its leading '*'",
                );
                // Everything after the leading `*`, e.g. `.example.com`.
                let suffix = &pattern.as_bytes()[1..];
                let matched = match hostname.strip_suffix(suffix) {
                    // Exactly one non-empty label may stand in for the `*`.
                    Some(label) => !label.is_empty() && !label.contains(&b'.'),
                    None => false,
                };
                debug_assert!(
                    !matched || hostname.len() > suffix.len(),
                    "a wildcard match requires a non-empty leftmost label before the suffix",
                );
                matched
            }
            DomainRule::Regex(regex) => {
                let started = Instant::now();
                let hit = regex.is_match(hostname);
                let finished = Instant::now();
                time!(
                    names::event_loop::REGEX_MATCHING_TIME,
                    (finished - started).as_millis()
                );
                hit
            }
        }
    }
}
impl std::cmp::PartialEq for DomainRule {
    /// Same variant and same source text. Regex rules compare by pattern
    /// string, since compiled regexes have no structural equality.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (DomainRule::Any, DomainRule::Any) => true,
            (DomainRule::Exact(a), DomainRule::Exact(b))
            | (DomainRule::Wildcard(a), DomainRule::Wildcard(b)) => a == b,
            (DomainRule::Regex(a), DomainRule::Regex(b)) => a.as_str() == b.as_str(),
            _ => false,
        }
    }
}
impl std::str::FromStr for DomainRule {
    type Err = ();

    /// Parses a frontend hostname into a rule, in this order:
    /// - `"*"` becomes `Any`;
    /// - anything containing `/` becomes `Regex`;
    /// - a leading `*` becomes `Wildcard`;
    /// - anything else becomes `Exact`.
    ///
    /// Fails with `()` on any of these:
    /// - a malformed regex hostname;
    /// - a `*` anywhere but the first character;
    /// - an IDNA conversion failure.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "*" {
            return Ok(DomainRule::Any);
        }
        if s.contains('/') {
            let source = convert_regex_domain_rule(s).ok_or(())?;
            return regex::bytes::Regex::new(&source)
                .map(DomainRule::Regex)
                .map_err(|_| ());
        }
        // A `*` is only accepted as the very first character.
        let wildcard = s.starts_with('*');
        if !wildcard && s.contains('*') {
            return Err(());
        }
        let ascii = ::idna::domain_to_ascii(s).map_err(|_| ())?;
        Ok(if wildcard {
            DomainRule::Wildcard(ascii)
        } else {
            DomainRule::Exact(ascii)
        })
    }
}
/// How a rule's request path is matched.
#[derive(Clone, Debug)]
pub enum PathRule {
    /// Matches any path starting with this string.
    Prefix(String),
    /// Matches when `Regex::is_match` accepts the path; this module adds no
    /// anchors to the configured pattern.
    Regex(Regex),
    /// Matches only this exact path.
    Equals(String),
}
/// Outcome of `PathRule::matches`.
#[derive(PartialEq, Eq)]
pub enum PathRuleResult {
    /// A `Regex` rule matched.
    Regex,
    /// A `Prefix` rule matched; carries the prefix length in bytes, which the
    /// trie tier uses to prefer the longest prefix.
    Prefix(usize),
    /// An `Equals` rule matched.
    Equals,
    /// The rule did not match.
    None,
}
impl PathRule {
    /// Matches `path` (raw request bytes) against this rule.
    ///
    /// Regex evaluation time is recorded in the
    /// `names::event_loop::REGEX_MATCHING_TIME` metric.
    pub fn matches(&self, path: &[u8]) -> PathRuleResult {
        match self {
            PathRule::Equals(expected) => {
                if path == expected.as_bytes() {
                    PathRuleResult::Equals
                } else {
                    PathRuleResult::None
                }
            }
            PathRule::Prefix(prefix) => {
                if !path.starts_with(prefix.as_bytes()) {
                    return PathRuleResult::None;
                }
                debug_assert!(
                    prefix.len() <= path.len(),
                    "a matching prefix cannot be longer than the path it matched",
                );
                PathRuleResult::Prefix(prefix.len())
            }
            PathRule::Regex(regex) => {
                let started = Instant::now();
                let hit = regex.is_match(path);
                let finished = Instant::now();
                time!(
                    names::event_loop::REGEX_MATCHING_TIME,
                    (finished - started).as_millis()
                );
                if hit {
                    PathRuleResult::Regex
                } else {
                    PathRuleResult::None
                }
            }
        }
    }

    /// Builds a rule from its command-protocol form.
    ///
    /// Returns `None` for an unknown `kind` value or a regex pattern that
    /// fails to compile.
    pub fn from_config(rule: CommandPathRule) -> Option<Self> {
        let kind = PathRuleKind::try_from(rule.kind).ok()?;
        match kind {
            PathRuleKind::Prefix => Some(PathRule::Prefix(rule.value)),
            PathRuleKind::Equals => Some(PathRule::Equals(rule.value)),
            PathRuleKind::Regex => Regex::new(&rule.value).ok().map(PathRule::Regex),
        }
    }
}
impl std::cmp::PartialEq for PathRule {
    /// Same variant and same source text. Regex rules compare by pattern
    /// string, not by the set of paths they accept.
    ///
    /// The router relies on this equality to reject duplicate rules on
    /// insertion and to find the rule to evict on removal.
    fn eq(&self, other: &Self) -> bool {
        // Every variant needs its own arm: a missing arm silently makes two
        // identical rules unequal, so they could be duplicated and never removed.
        match (self, other) {
            (PathRule::Prefix(s1), PathRule::Prefix(s2)) => s1 == s2,
            (PathRule::Equals(s1), PathRule::Equals(s2)) => s1 == s2,
            (PathRule::Regex(r1), PathRule::Regex(r2)) => r1.as_str() == r2.as_str(),
            _ => false,
        }
    }
}
/// Optional HTTP method restriction of a rule.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MethodRule {
    /// `None` accepts every method; `Some(m)` accepts only `m`.
    pub inner: Option<Method>,
}
/// Outcome of `MethodRule::matches`.
#[derive(PartialEq, Eq)]
pub enum MethodRuleResult {
    /// The rule has no method restriction.
    All,
    /// The rule's method equals the request method.
    Equals,
    /// The rule's method differs from the request method.
    None,
}
impl MethodRule {
    /// Builds a rule from an optional method name; `None` accepts any method.
    pub fn new(method: Option<String>) -> Self {
        let inner = method.map(|name| Method::new(name.as_bytes()));
        Self { inner }
    }

    /// Compares the request `method` against this rule.
    pub fn matches(&self, method: &Method) -> MethodRuleResult {
        let Some(expected) = &self.inner else {
            return MethodRuleResult::All;
        };
        if method == expected {
            MethodRuleResult::Equals
        } else {
            MethodRuleResult::None
        }
    }
}
/// What a matched rule resolves to.
///
/// The derived `PartialOrd`/`Ord` follow variant declaration order, so
/// reordering the variants changes how routes sort.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Route {
    /// Refuse the request.
    Deny,
    /// Forward to this cluster with no extra policy.
    ClusterId(ClusterId),
    /// A frontend carrying redirect/rewrite/header/auth/HSTS policy.
    Frontend(Rc<Frontend>),
}
/// Turns the listener-level HSTS configuration into a response header edit.
///
/// Returns `None` when:
/// - there is no configuration;
/// - `enabled` is not `Some(true)`;
/// - `render_hsts` cannot render it (no `max_age`).
///
/// `force_replace_backend == Some(true)` selects `HeaderEditMode::Set`; any
/// other value selects `HeaderEditMode::SetIfAbsent`.
fn build_listener_hsts_edit(new_hsts: Option<&HstsConfig>) -> Option<HeaderEdit> {
    let cfg = new_hsts.filter(|cfg| matches!(cfg.enabled, Some(true)))?;
    let value = render_hsts(cfg)?;
    let mode = match cfg.force_replace_backend {
        Some(true) => HeaderEditMode::Set,
        _ => HeaderEditMode::SetIfAbsent,
    };
    Some(HeaderEdit {
        key: Rc::from(&b"strict-transport-security"[..]),
        val: Rc::from(value.into_bytes()),
        mode,
    })
}
fn rebuild_with_listener_hsts(frontend: &Frontend, new_edit: Option<&HeaderEdit>) -> Frontend {
let mut headers_response: Vec<HeaderEdit> = frontend
.headers_response
.iter()
.filter(|edit| !edit.key.eq_ignore_ascii_case(b"strict-transport-security"))
.cloned()
.collect();
if let Some(edit) = new_edit {
headers_response.push(edit.clone());
}
Frontend {
headers_response: headers_response.into(),
..frontend.clone()
}
}
/// Renders an HSTS configuration as a `Strict-Transport-Security` header value,
/// e.g. `max-age=31536000; includeSubDomains; preload`.
///
/// Returns `None` when `max_age` is unset. The optional directives are only
/// emitted when their flag is `Some(true)`. `cfg.enabled` is not consulted.
pub fn render_hsts(cfg: &HstsConfig) -> Option<String> {
    cfg.max_age.map(|max_age| {
        let mut value = format!("max-age={max_age}");
        let directives = [
            (cfg.include_subdomains, "; includeSubDomains"),
            (cfg.preload, "; preload"),
        ];
        for (flag, directive) in directives {
            if matches!(flag, Some(true)) {
                value.push_str(directive);
            }
        }
        value
    })
}
/// A header modification applied to a request or response.
#[derive(Clone, PartialEq, Eq)]
pub struct HeaderEdit {
    /// Header name bytes.
    pub key: Rc<[u8]>,
    /// Header value bytes.
    pub val: Rc<[u8]>,
    /// How the edit is applied. In this module, configured headers use
    /// `Append` and HSTS uses `Set` or `SetIfAbsent`.
    pub mode: HeaderEditMode,
}
impl Debug for HeaderEdit {
    /// Formats as `(key, value, mode)`, showing key and value as lossily
    /// decoded UTF-8 strings instead of byte arrays.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let key = String::from_utf8_lossy(&self.key);
        let val = String::from_utf8_lossy(&self.val);
        write!(f, "({key:?}, {val:?}, {:?})", self.mode)
    }
}
/// One piece of a parsed rewrite template.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RewritePart {
    /// Literal text, copied unchanged.
    String(String),
    /// `$HOST[n]`: host capture `n`.
    Host(usize),
    /// `$PATH[n]`: path capture `n`.
    Path(usize),
}

impl RewritePart {
    /// Text this part contributes to a rewrite. A capture index beyond the
    /// supplied captures contributes the empty string.
    fn resolve<'a>(&'a self, host_captures: &[&'a str], path_captures: &[&'a str]) -> &'a str {
        match self {
            RewritePart::String(literal) => literal.as_str(),
            RewritePart::Host(index) => host_captures.get(*index).copied().unwrap_or(""),
            RewritePart::Path(index) => path_captures.get(*index).copied().unwrap_or(""),
        }
    }
}

/// A parsed host/path rewrite template.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RewriteParts(Vec<RewritePart>);

impl RewriteParts {
    /// Parses a rewrite template such as `$HOST[1].cdn.example.com`.
    ///
    /// Literal text is kept as-is. `$HOST[n]` and `$PATH[n]` reference
    /// capture `n` of the hostname and path. `n` must be a non-empty decimal
    /// number below `host_cap_cap` / `path_cap_cap`. There is no escape for a
    /// literal `$`.
    ///
    /// `used_index_host` / `used_index_path` are raised to one past the
    /// highest capture index referenced. They are never lowered, so the same
    /// counters can be threaded through several templates.
    ///
    /// Returns `None` in any of these cases:
    /// - a `$` not followed by `HOST[` or `PATH[`;
    /// - a missing or overflowing index;
    /// - a missing `]`;
    /// - an index outside its cap.
    pub fn parse(
        template: &str,
        host_cap_cap: usize,
        path_cap_cap: usize,
        used_index_host: &mut usize,
        used_index_path: &mut usize,
    ) -> Option<Self> {
        let bytes = template.as_bytes();
        let mut parts = Vec::new();
        let mut pos = 0;
        while pos < bytes.len() {
            if bytes[pos] != b'$' {
                // Literal run up to the next `$` or the end of the template.
                let run_end = bytes[pos..]
                    .iter()
                    .position(|&b| b == b'$')
                    .map_or(bytes.len(), |offset| pos + offset);
                parts.push(RewritePart::String(template[pos..run_end].to_owned()));
                pos = run_end;
                continue;
            }

            let tail = &bytes[pos..];
            let is_host = if tail.starts_with(b"$HOST[") {
                true
            } else if tail.starts_with(b"$PATH[") {
                false
            } else {
                return None;
            };
            // Both `$HOST[` and `$PATH[` are six bytes long.
            pos += 6;

            let digit_count = bytes[pos..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digit_count == 0 {
                return None;
            }
            let mut index = 0usize;
            for &digit in &bytes[pos..pos + digit_count] {
                index = index
                    .checked_mul(10)?
                    .checked_add(usize::from(digit - b'0'))?;
            }
            pos += digit_count;

            if bytes.get(pos) != Some(&b']') {
                return None;
            }
            pos += 1;

            let (cap, used) = if is_host {
                (host_cap_cap, &mut *used_index_host)
            } else {
                (path_cap_cap, &mut *used_index_path)
            };
            if index >= cap {
                return None;
            }
            *used = (*used).max(index + 1);
            parts.push(if is_host {
                RewritePart::Host(index)
            } else {
                RewritePart::Path(index)
            });
        }

        debug_assert!(
            parts.iter().all(|part| match part {
                RewritePart::Host(idx) => *idx < host_cap_cap,
                RewritePart::Path(idx) => *idx < path_cap_cap,
                RewritePart::String(_) => true,
            }),
            "a parsed rewrite template must only reference captures within the rule's caps",
        );
        debug_assert!(
            *used_index_host <= host_cap_cap && *used_index_path <= path_cap_cap,
            "the highest referenced capture index cannot exceed the cap",
        );
        Some(Self(parts))
    }

    /// Renders the template with the given captures. Missing captures render
    /// as the empty string.
    pub fn run(&self, host_captures: &[&str], path_captures: &[&str]) -> String {
        // Size the output exactly so the writing pass never reallocates.
        let total_len: usize = self
            .0
            .iter()
            .map(|part| part.resolve(host_captures, path_captures).len())
            .sum();
        let mut output = String::with_capacity(total_len);
        for part in &self.0 {
            // Writing into a `String` cannot fail.
            let _ = output.write_str(part.resolve(host_captures, path_captures));
        }
        debug_assert_eq!(
            output.len(),
            total_len,
            "rewrite output length must equal the pre-computed one-pass capacity",
        );
        output
    }
}
/// Pre-processed routing policy attached to a rule (see `Route::Frontend`).
#[derive(Debug, Clone)]
pub struct Frontend {
    /// Target cluster, if any.
    pub cluster_id: Option<ClusterId>,
    /// Redirect policy (`Forward`, `Unauthorized`, …).
    pub redirect: RedirectPolicy,
    /// Scheme to use when redirecting.
    pub redirect_scheme: RedirectScheme,
    /// Redirect template; `None` when unset or configured as an empty string.
    pub redirect_template: Option<String>,
    /// Number of host capture slots to extract at lookup time. It is 0 when
    /// no rewrite template references a host capture.
    pub capture_cap_host: usize,
    /// Same as `capture_cap_host`, for path captures.
    pub capture_cap_path: usize,
    /// Parsed host rewrite template.
    pub rewrite_host: Option<RewriteParts>,
    /// Parsed path rewrite template.
    pub rewrite_path: Option<RewriteParts>,
    /// Port override.
    pub rewrite_port: Option<u16>,
    /// Header edits applied to requests.
    pub headers_request: Rc<[HeaderEdit]>,
    /// Header edits applied to responses, including any HSTS edit.
    pub headers_response: Rc<[HeaderEdit]>,
    /// Copied from the frontend's `required_auth` setting.
    pub required_auth: bool,
    /// Cached tags from the frontend configuration.
    pub tags: Option<Rc<CachedTags>>,
    /// When `true`, `Router::refresh_inheriting_hsts` rebuilds this frontend's
    /// HSTS edit whenever the listener HSTS configuration changes.
    pub inherits_listener_hsts: bool,
}
/// Where a frontend's HSTS configuration came from.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum HstsOrigin {
    /// Set on the frontend itself; listener HSTS refreshes leave it alone.
    Explicit,
    /// Filled in from the listener default. A frontend built this way (with an
    /// HSTS config present) follows later listener HSTS refreshes.
    InheritedFromListenerDefault,
}
impl PartialEq for Frontend {
    /// Compares the routing policy.
    ///
    /// These fields are not compared: `capture_cap_host`, `capture_cap_path`,
    /// `tags` and `inherits_listener_hsts`.
    fn eq(&self, other: &Self) -> bool {
        let same_routing = (
            &self.cluster_id,
            &self.redirect,
            &self.redirect_scheme,
            &self.redirect_template,
            &self.rewrite_host,
        ) == (
            &other.cluster_id,
            &other.redirect,
            &other.redirect_scheme,
            &other.redirect_template,
            &other.rewrite_host,
        );
        same_routing
            && (
                &self.rewrite_path,
                &self.rewrite_port,
                &self.headers_request,
                &self.headers_response,
                &self.required_auth,
            ) == (
                &other.rewrite_path,
                &other.rewrite_port,
                &other.headers_request,
                &other.headers_response,
                &other.required_auth,
            )
    }
}
// Marker impl: equality is the field-wise comparison of `PartialEq for Frontend`.
impl Eq for Frontend {}
impl std::hash::Hash for Frontend {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.cluster_id.hash(state);
(self.redirect as i32).hash(state);
(self.redirect_scheme as i32).hash(state);
self.redirect_template.hash(state);
self.required_auth.hash(state);
}
}
impl PartialOrd for Frontend {
    /// Always `Some`: delegates to the total order of `Ord for Frontend`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Frontend {
    /// Lexicographic order on these fields, in sequence:
    /// 1. cluster id;
    /// 2. redirect policy;
    /// 3. redirect scheme (both enums compared by discriminant);
    /// 4. redirect template;
    /// 5. `required_auth`.
    ///
    /// NOTE(review): rewrites, port and header edits are compared by
    /// `PartialEq` but ignored here, so two unequal frontends can compare as
    /// `Ordering::Equal`. The `Ord` contract asks for consistency with `Eq`;
    /// confirm no ordered collection depends on it.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = (
            &self.cluster_id,
            self.redirect as i32,
            self.redirect_scheme as i32,
            &self.redirect_template,
            self.required_auth,
        );
        let rhs = (
            &other.cluster_id,
            other.redirect as i32,
            other.redirect_scheme as i32,
            &other.redirect_template,
            other.required_auth,
        );
        lhs.cmp(&rhs)
    }
}
impl Frontend {
    /// Builds a [`Frontend`] from a command-protocol frontend and its decoded policy.
    ///
    /// `domain_rule` and `path_rule` are the already-parsed rules for `front`.
    /// They decide which capture slots the rewrite templates may reference.
    /// From `front` itself only `hsts`, `cluster_id` and `tags` are read; the
    /// other policy values arrive as explicit arguments.
    ///
    /// Empty `redirect_template`, `rewrite_host` and `rewrite_path` strings are
    /// treated as unset.
    ///
    /// Some frontends resolve to a denial: `RedirectPolicy::Unauthorized`, or
    /// `Forward` without a cluster (this case logs a warning). These return
    /// early and keep only the HSTS response edit (when enabled and
    /// renderable). They drop rewrites, the port override, the redirect
    /// template and the configured headers.
    ///
    /// # Errors
    ///
    /// [`RouterError::InvalidHostRewrite`] / [`RouterError::InvalidPathRewrite`]
    /// when a rewrite template is malformed or references a capture slot the
    /// domain/path rule cannot provide.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        domain_rule: &DomainRule,
        path_rule: &PathRule,
        front: &HttpFrontend,
        redirect: RedirectPolicy,
        redirect_scheme: RedirectScheme,
        redirect_template: Option<String>,
        rewrite_host: Option<String>,
        rewrite_path: Option<String>,
        rewrite_port: Option<u16>,
        headers: &[sozu_command::proto::command::Header],
        required_auth: bool,
        hsts_origin: HstsOrigin,
    ) -> Result<Self, RouterError> {
        let hsts = front.hsts.as_ref();
        // Only frontends whose HSTS came from the listener default, and that
        // actually carry an HSTS config, follow later listener refreshes.
        let inherits_listener_hsts =
            matches!(hsts_origin, HstsOrigin::InheritedFromListenerDefault) && hsts.is_some();
        let cluster_id = front.cluster_id.clone();
        let tags = front
            .tags
            .clone()
            .map(|tags| Rc::new(CachedTags::new(tags)));
        // Empty strings mean "not configured".
        let redirect_template = redirect_template.filter(|s| !s.is_empty());
        let rewrite_host = rewrite_host.filter(|s| !s.is_empty());
        let rewrite_path = rewrite_path.filter(|s| !s.is_empty());
        // Decide whether this frontend can only ever deny requests.
        let deny = match (&cluster_id, redirect) {
            (_, RedirectPolicy::Unauthorized) => true,
            (None, RedirectPolicy::Forward) => {
                warn!(
                    "{} Frontend[domain: {:?}, path: {:?}]: forward on clusterless frontends are unauthorized",
                    log_module_context!(),
                    domain_rule,
                    path_rule,
                );
                true
            }
            _ => false,
        };
        if deny {
            // A denying frontend still emits HSTS when it is enabled and renderable.
            let mut deny_headers_response: Vec<HeaderEdit> = Vec::new();
            if let Some(cfg) = hsts
                && matches!(cfg.enabled, Some(true))
                && let Some(rendered) = render_hsts(cfg)
            {
                let mode = if matches!(cfg.force_replace_backend, Some(true)) {
                    HeaderEditMode::Set
                } else {
                    HeaderEditMode::SetIfAbsent
                };
                deny_headers_response.push(HeaderEdit {
                    key: Rc::from(&b"strict-transport-security"[..]),
                    val: rendered.into_bytes().into(),
                    mode,
                });
                crate::incr!(names::http::HSTS_FRONTEND_ADDED);
            }
            return Ok(Self {
                cluster_id,
                redirect: RedirectPolicy::Unauthorized,
                redirect_scheme,
                redirect_template: None,
                capture_cap_host: 0,
                capture_cap_path: 0,
                rewrite_host: None,
                rewrite_path: None,
                rewrite_port: None,
                headers_request: Rc::new([]),
                headers_response: deny_headers_response.into(),
                required_auth,
                tags,
                inherits_listener_hsts,
            });
        }
        // Capture slots available to rewrite templates. Slot 0 is always the
        // whole host/path. A wildcard adds the label matched by `*`, and a
        // prefix adds the path tail after it. A regex exposes all of its groups
        // (`captures_len` includes group 0).
        let mut capture_cap_host = match domain_rule {
            DomainRule::Any => 1,
            DomainRule::Exact(_) => 1,
            DomainRule::Wildcard(_) => 2,
            DomainRule::Regex(regex) => regex.captures_len(),
        };
        let mut capture_cap_path = match path_rule {
            PathRule::Equals(_) => 1,
            PathRule::Prefix(_) => 2,
            PathRule::Regex(regex) => regex.captures_len(),
        };
        // Both templates share the same "one past highest slot used" counters.
        let mut used_capture_host = 0usize;
        let mut used_capture_path = 0usize;
        let rewrite_host_parts = if let Some(p) = rewrite_host {
            Some(
                RewriteParts::parse(
                    &p,
                    capture_cap_host,
                    capture_cap_path,
                    &mut used_capture_host,
                    &mut used_capture_path,
                )
                .ok_or(RouterError::InvalidHostRewrite(p))?,
            )
        } else {
            None
        };
        let rewrite_path_parts = if let Some(p) = rewrite_path {
            Some(
                RewriteParts::parse(
                    &p,
                    capture_cap_host,
                    capture_cap_path,
                    &mut used_capture_host,
                    &mut used_capture_path,
                )
                .ok_or(RouterError::InvalidPathRewrite(p))?,
            )
        } else {
            None
        };
        // No template references these captures: a zero cap lets lookup skip
        // extracting them entirely.
        if used_capture_host == 0 {
            capture_cap_host = 0;
        }
        if used_capture_path == 0 {
            capture_cap_path = 0;
        }
        // Configured headers are appended to the request, the response, or both.
        let mut headers_request = Vec::new();
        let mut headers_response = Vec::new();
        for header in headers {
            let edit = HeaderEdit {
                key: header.key.as_bytes().into(),
                val: header.val.as_bytes().into(),
                mode: HeaderEditMode::Append,
            };
            match header.position() {
                HeaderPosition::Request => headers_request.push(edit),
                HeaderPosition::Response => headers_response.push(edit),
                HeaderPosition::Both => {
                    headers_request.push(edit.clone());
                    headers_response.push(edit);
                }
                HeaderPosition::Unspecified => {
                    warn!(
                        "{} dropping Header {{ key: {:?}, val: {:?} }} with HEADER_POSITION_UNSPECIFIED",
                        log_module_context!(),
                        header.key,
                        header.val,
                    );
                }
            }
        }
        // HSTS goes last in the response edits; an enabled config that cannot
        // be rendered is reported rather than silently ignored.
        if let Some(cfg) = hsts
            && matches!(cfg.enabled, Some(true))
        {
            if let Some(rendered) = render_hsts(cfg) {
                let mode = if matches!(cfg.force_replace_backend, Some(true)) {
                    HeaderEditMode::Set
                } else {
                    HeaderEditMode::SetIfAbsent
                };
                headers_response.push(HeaderEdit {
                    key: Rc::from(&b"strict-transport-security"[..]),
                    val: rendered.into_bytes().into(),
                    mode,
                });
                crate::incr!(names::http::HSTS_FRONTEND_ADDED);
            } else {
                warn!(
                    "{} HSTS enabled = true on frontend {:?} but render_hsts \
                     returned None (max_age missing). Frontend will not emit \
                     Strict-Transport-Security; the config layer that built \
                     this HstsConfig must substitute DEFAULT_HSTS_MAX_AGE.",
                    log_module_context!(),
                    cluster_id,
                );
                crate::incr!(names::http::HSTS_UNRENDERED);
            }
        }
        Ok(Frontend {
            cluster_id,
            redirect,
            redirect_scheme,
            redirect_template,
            capture_cap_host,
            capture_cap_path,
            rewrite_host: rewrite_host_parts,
            rewrite_path: rewrite_path_parts,
            rewrite_port,
            headers_request: headers_request.into(),
            headers_response: headers_response.into(),
            required_auth,
            tags,
            inherits_listener_hsts,
        })
    }
pub(crate) fn minimal_forward(cluster_id: ClusterId) -> Self {
Self {
cluster_id: Some(cluster_id),
redirect: RedirectPolicy::Forward,
redirect_scheme: RedirectScheme::UseSame,
redirect_template: None,
capture_cap_host: 0,
capture_cap_path: 0,
rewrite_host: None,
rewrite_path: None,
rewrite_port: None,
headers_request: Rc::new([]),
headers_response: Rc::new([]),
required_auth: false,
tags: None,
inherits_listener_hsts: true,
}
}
pub(crate) fn minimal_deny() -> Self {
Self {
cluster_id: None,
redirect: RedirectPolicy::Unauthorized,
redirect_scheme: RedirectScheme::UseSame,
redirect_template: None,
capture_cap_host: 0,
capture_cap_path: 0,
rewrite_host: None,
rewrite_path: None,
rewrite_port: None,
headers_request: Rc::new([]),
headers_response: Rc::new([]),
required_auth: false,
tags: None,
inherits_listener_hsts: true,
}
}
}
/// Outcome of a successful `Router::lookup`: where to send the request and
/// which policy to apply.
#[derive(Debug, Clone, PartialEq)]
pub struct RouteResult {
    /// Target cluster. It may be set even when the request is denied.
    pub cluster_id: Option<ClusterId>,
    /// Redirect policy (`Forward`, `Unauthorized`, …).
    pub redirect: RedirectPolicy,
    /// Scheme to use when redirecting.
    pub redirect_scheme: RedirectScheme,
    /// Redirect template copied from the frontend.
    pub redirect_template: Option<String>,
    /// Output of the frontend's host rewrite template, if any.
    pub rewritten_host: Option<String>,
    /// Output of the frontend's path rewrite template, if any.
    pub rewritten_path: Option<String>,
    /// Port override copied from the frontend.
    pub rewritten_port: Option<u16>,
    /// Header edits to apply to the request.
    pub headers_request: Rc<[HeaderEdit]>,
    /// Header edits to apply to the response.
    pub headers_response: Rc<[HeaderEdit]>,
    /// Copied from the frontend's `required_auth`.
    pub required_auth: bool,
    /// Tags copied from the frontend.
    pub tags: Option<Rc<CachedTags>>,
}
impl RouteResult {
pub fn deny(cluster_id: Option<ClusterId>) -> Self {
Self {
cluster_id,
redirect: RedirectPolicy::Unauthorized,
redirect_scheme: RedirectScheme::UseSame,
redirect_template: None,
rewritten_host: None,
rewritten_path: None,
rewritten_port: None,
headers_request: Rc::new([]),
headers_response: Rc::new([]),
required_auth: false,
tags: None,
}
}
pub fn forward(cluster_id: ClusterId) -> Self {
Self {
cluster_id: Some(cluster_id),
redirect: RedirectPolicy::Forward,
redirect_scheme: RedirectScheme::UseSame,
redirect_template: None,
rewritten_host: None,
rewritten_path: None,
rewritten_port: None,
headers_request: Rc::new([]),
headers_response: Rc::new([]),
required_auth: false,
tags: None,
}
}
fn from_frontend(
frontend: &Frontend,
captures_host: Vec<&str>,
path: &[u8],
path_rule: &PathRule,
) -> Self {
if frontend.redirect == RedirectPolicy::Unauthorized {
return Self {
cluster_id: frontend.cluster_id.clone(),
redirect: RedirectPolicy::Unauthorized,
redirect_scheme: frontend.redirect_scheme,
redirect_template: frontend.redirect_template.clone(),
rewritten_host: None,
rewritten_path: None,
rewritten_port: None,
headers_request: Rc::new([]),
headers_response: frontend.headers_response.clone(),
required_auth: frontend.required_auth,
tags: frontend.tags.clone(),
};
}
let mut captures_path: Vec<&str> = Vec::with_capacity(frontend.capture_cap_path);
if frontend.capture_cap_path > 0 {
captures_path.push(from_utf8(path).unwrap_or_default());
match path_rule {
PathRule::Prefix(prefix) => {
let tail_start = prefix.len().min(path.len());
captures_path.push(from_utf8(&path[tail_start..]).unwrap_or_default());
}
PathRule::Regex(regex) => {
if let Some(caps) = regex.captures(path) {
captures_path.extend(caps.iter().skip(1).map(|c| {
c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
.unwrap_or("")
}));
}
}
PathRule::Equals(_) => {}
}
}
Self {
cluster_id: frontend.cluster_id.clone(),
redirect: frontend.redirect,
redirect_scheme: frontend.redirect_scheme,
redirect_template: frontend.redirect_template.clone(),
rewritten_host: frontend
.rewrite_host
.as_ref()
.map(|rewrite| rewrite.run(&captures_host, &captures_path)),
rewritten_path: frontend
.rewrite_path
.as_ref()
.map(|rewrite| rewrite.run(&captures_host, &captures_path)),
rewritten_port: frontend.rewrite_port,
headers_request: frontend.headers_request.clone(),
headers_response: frontend.headers_response.clone(),
required_auth: frontend.required_auth,
tags: frontend.tags.clone(),
}
}
fn new_no_trie<'a>(
domain: &'a [u8],
domain_rule: &DomainRule,
path: &'a [u8],
path_rule: &PathRule,
route: &Route,
) -> Self {
let frontend = match route {
Route::Frontend(f) => f.clone(),
Route::ClusterId(id) => return Self::forward(id.clone()),
Route::Deny => return Self::deny(None),
};
let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
if frontend.capture_cap_host > 0 {
captures_host.push(from_utf8(domain).unwrap_or_default());
match domain_rule {
DomainRule::Wildcard(suffix) => {
let head_end = domain.len().saturating_sub(suffix.len().saturating_sub(1));
captures_host.push(from_utf8(&domain[..head_end]).unwrap_or_default());
}
DomainRule::Regex(regex) => {
if let Some(caps) = regex.captures(domain) {
captures_host.extend(caps.iter().skip(1).map(|c| {
c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
.unwrap_or("")
}));
}
}
DomainRule::Any | DomainRule::Exact(_) => {}
}
}
Self::from_frontend(&frontend, captures_host, path, path_rule)
}
fn new_with_trie<'a, 'b>(
domain: &'a [u8],
domain_submatches: TrieMatches<'a, 'b>,
path: &'a [u8],
path_rule: &PathRule,
route: &Route,
) -> Self {
let frontend = match route {
Route::Frontend(f) => f.clone(),
Route::ClusterId(id) => return Self::forward(id.clone()),
Route::Deny => return Self::deny(None),
};
let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
if frontend.capture_cap_host > 0 {
captures_host.push(from_utf8(domain).unwrap_or_default());
for submatch in &domain_submatches {
match submatch {
TrieSubMatch::Wildcard(part) => {
captures_host.push(from_utf8(part).unwrap_or_default());
}
TrieSubMatch::Regexp(part, regex) => {
if let Some(caps) = regex.captures(part) {
captures_host.extend(caps.iter().skip(1).map(|c| {
c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
.unwrap_or("")
}));
}
}
}
}
}
Self::from_frontend(&frontend, captures_host, path, path_rule)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Operator-supplied `x-cache: hit` response edit, appended rather than replaced.
    fn x_cache_edit() -> HeaderEdit {
        HeaderEdit {
            key: Rc::from(&b"x-cache"[..]),
            val: Rc::from(&b"hit"[..]),
            mode: HeaderEditMode::Append,
        }
    }

    /// `strict-transport-security` response edit carrying `val`, in `SetIfAbsent` mode.
    fn sts_edit(val: &[u8]) -> HeaderEdit {
        HeaderEdit {
            key: Rc::from(&b"strict-transport-security"[..]),
            val: Rc::from(val),
            mode: HeaderEditMode::SetIfAbsent,
        }
    }

    /// Forwarding frontend for `cluster_id` with no captures, rewrites, request edits,
    /// auth requirement or tags. Only the response edits and the HSTS-inheritance flag vary.
    fn forward_frontend(
        cluster_id: &str,
        headers_response: Vec<HeaderEdit>,
        inherits_listener_hsts: bool,
    ) -> Frontend {
        Frontend {
            cluster_id: Some(cluster_id.to_owned()),
            redirect: RedirectPolicy::Forward,
            redirect_scheme: RedirectScheme::UseSame,
            redirect_template: None,
            capture_cap_host: 0,
            capture_cap_path: 0,
            rewrite_host: None,
            rewrite_path: None,
            rewrite_port: None,
            headers_request: Rc::new([]),
            headers_response: Rc::from(headers_response),
            required_auth: false,
            tags: None,
            inherits_listener_hsts,
        }
    }

    #[test]
    fn render_hsts_max_age_only() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(render_hsts(&cfg), Some("max-age=31536000".to_owned()));
    }

    #[test]
    fn render_hsts_with_include_subdomains() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(
            render_hsts(&cfg),
            Some("max-age=31536000; includeSubDomains".to_owned())
        );
    }

    #[test]
    fn render_hsts_with_preload_only() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: Some(63_072_000),
            include_subdomains: None,
            preload: Some(true),
            force_replace_backend: None,
        };
        assert_eq!(
            render_hsts(&cfg),
            Some("max-age=63072000; preload".to_owned())
        );
    }

    #[test]
    fn render_hsts_full() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: Some(true),
            preload: Some(true),
            force_replace_backend: None,
        };
        assert_eq!(
            render_hsts(&cfg),
            Some("max-age=31536000; includeSubDomains; preload".to_owned())
        );
    }

    #[test]
    fn render_hsts_kill_switch_max_age_zero() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: Some(0),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(
            render_hsts(&cfg),
            Some("max-age=0; includeSubDomains".to_owned())
        );
    }

    #[test]
    fn render_hsts_omitted_when_max_age_missing() {
        let cfg = HstsConfig {
            enabled: Some(true),
            max_age: None,
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(render_hsts(&cfg), None);
    }

    #[test]
    fn rebuild_with_listener_hsts_replaces_existing_entry() {
        let frontend = forward_frontend(
            "api",
            vec![x_cache_edit(), sts_edit(b"max-age=31536000")],
            true,
        );
        let new_hsts = HstsConfig {
            enabled: Some(true),
            max_age: Some(63_072_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        let new_edit = build_listener_hsts_edit(Some(&new_hsts));
        let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
        let response: Vec<_> = rebuilt.headers_response.iter().collect();
        assert_eq!(response.len(), 2, "x-cache + new STS, no leftover STS");
        assert_eq!(&*response[0].key, b"x-cache");
        assert_eq!(&*response[1].key, b"strict-transport-security");
        assert_eq!(
            &*response[1].val,
            b"max-age=63072000; includeSubDomains".as_slice()
        );
        assert!(rebuilt.inherits_listener_hsts);
    }

    #[test]
    fn rebuild_with_listener_hsts_strips_when_none() {
        let frontend = forward_frontend(
            "api",
            vec![x_cache_edit(), sts_edit(b"max-age=31536000")],
            true,
        );
        let new_edit = build_listener_hsts_edit(None);
        let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
        let response: Vec<_> = rebuilt.headers_response.iter().collect();
        assert_eq!(response.len(), 1);
        assert_eq!(&*response[0].key, b"x-cache");
    }

    #[test]
    fn rebuild_with_listener_hsts_disabled_strips() {
        let frontend = forward_frontend("api", vec![sts_edit(b"max-age=31536000")], true);
        let new_hsts = HstsConfig {
            enabled: Some(false),
            max_age: None,
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        let new_edit = build_listener_hsts_edit(Some(&new_hsts));
        let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
        assert_eq!(rebuilt.headers_response.len(), 0);
    }

    #[test]
    fn refresh_inheriting_hsts_skips_explicit_overrides() {
        let mut router = Router::new();
        let inheriting = forward_frontend("api", vec![sts_edit(b"max-age=31536000")], true);
        let explicit = forward_frontend("legacy", vec![sts_edit(b"max-age=300")], false);
        router.pre.push((
            DomainRule::Any,
            PathRule::Prefix("/api".to_owned()),
            MethodRule::new(None),
            Route::Frontend(Rc::new(inheriting)),
        ));
        router.post.push((
            DomainRule::Any,
            PathRule::Prefix("/legacy".to_owned()),
            MethodRule::new(None),
            Route::Frontend(Rc::new(explicit)),
        ));
        let new_hsts = HstsConfig {
            enabled: Some(true),
            max_age: Some(63_072_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        let count = router.refresh_inheriting_hsts(Some(&new_hsts));
        assert_eq!(count, 1, "only the inheriting frontend should refresh");
        if let Route::Frontend(rc) = &router.pre[0].3 {
            let response: Vec<_> = rc.headers_response.iter().collect();
            assert_eq!(
                &*response.last().unwrap().val,
                b"max-age=63072000; includeSubDomains".as_slice(),
                "inheriting frontend's STS must reflect the new listener default"
            );
        } else {
            panic!("pre[0] should be Route::Frontend");
        }
        if let Route::Frontend(rc) = &router.post[0].3 {
            let response: Vec<_> = rc.headers_response.iter().collect();
            assert_eq!(
                &*response.last().unwrap().val,
                b"max-age=300".as_slice(),
                "explicit override must be preserved unchanged"
            );
        } else {
            panic!("post[0] should be Route::Frontend");
        }
    }

    #[test]
    fn refresh_inheriting_hsts_promotes_clusterid_on_enable() {
        let mut router = Router::new();
        router.post.push((
            DomainRule::Any,
            PathRule::Prefix("/".to_owned()),
            MethodRule::new(None),
            Route::ClusterId("api".to_owned()),
        ));
        let new_hsts = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        let count = router.refresh_inheriting_hsts(Some(&new_hsts));
        assert_eq!(count, 1, "the ClusterId entry must be promoted + counted");
        let Route::Frontend(rc) = &router.post[0].3 else {
            panic!("post[0] should now be Route::Frontend, not the original Route::ClusterId");
        };
        assert_eq!(rc.cluster_id.as_deref(), Some("api"));
        assert_eq!(
            rc.redirect,
            RedirectPolicy::Forward,
            "promoted entry must keep Forward semantics so lookup yields the same backend"
        );
        assert!(
            rc.inherits_listener_hsts,
            "promoted entry must mark itself inheriting so the next patch refreshes it"
        );
        let response: Vec<_> = rc.headers_response.iter().collect();
        assert_eq!(
            response.len(),
            1,
            "promoted entry carries exactly one STS edit, no operator headers"
        );
        assert_eq!(&*response[0].key, b"strict-transport-security");
        assert_eq!(
            &*response[0].val,
            b"max-age=31536000; includeSubDomains".as_slice()
        );
    }

    #[test]
    fn refresh_inheriting_hsts_promotes_deny_on_enable() {
        let mut router = Router::new();
        router.post.push((
            DomainRule::Any,
            PathRule::Prefix("/forbidden".to_owned()),
            MethodRule::new(None),
            Route::Deny,
        ));
        let new_hsts = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        let count = router.refresh_inheriting_hsts(Some(&new_hsts));
        assert_eq!(count, 1);
        let Route::Frontend(rc) = &router.post[0].3 else {
            panic!("post[0] should now be Route::Frontend, not the original Route::Deny");
        };
        assert_eq!(rc.cluster_id, None, "promoted Deny stays clusterless");
        assert_eq!(
            rc.redirect,
            RedirectPolicy::Unauthorized,
            "promoted Deny must keep Unauthorized so lookup yields a 401"
        );
        assert!(rc.inherits_listener_hsts);
        let response: Vec<_> = rc.headers_response.iter().collect();
        assert_eq!(response.len(), 1);
        assert_eq!(&*response[0].key, b"strict-transport-security");
        assert_eq!(&*response[0].val, b"max-age=31536000".as_slice());
    }

    #[test]
    fn refresh_inheriting_hsts_skips_lightweight_on_disable() {
        let make_router = || {
            let mut router = Router::new();
            router.pre.push((
                DomainRule::Any,
                PathRule::Prefix("/".to_owned()),
                MethodRule::new(None),
                Route::ClusterId("api".to_owned()),
            ));
            router.post.push((
                DomainRule::Any,
                PathRule::Prefix("/forbidden".to_owned()),
                MethodRule::new(None),
                Route::Deny,
            ));
            router
        };
        for (label, hsts) in [
            ("none", None),
            (
                "disabled",
                Some(HstsConfig {
                    enabled: Some(false),
                    max_age: None,
                    include_subdomains: None,
                    preload: None,
                    force_replace_backend: None,
                }),
            ),
            (
                "enabled-without-max-age",
                Some(HstsConfig {
                    enabled: Some(true),
                    max_age: None,
                    include_subdomains: None,
                    preload: None,
                    force_replace_backend: None,
                }),
            ),
        ] {
            let mut router = make_router();
            let count = router.refresh_inheriting_hsts(hsts.as_ref());
            assert_eq!(count, 0, "no promotion expected for {label}");
            assert!(
                matches!(router.pre[0].3, Route::ClusterId(_)),
                "{label}: ClusterId must stay lightweight"
            );
            assert!(
                matches!(router.post[0].3, Route::Deny),
                "{label}: Deny must stay lightweight"
            );
        }
    }

    #[test]
    fn refresh_inheriting_hsts_promoted_entry_refreshes_on_subsequent_patches() {
        let mut router = Router::new();
        router.post.push((
            DomainRule::Any,
            PathRule::Prefix("/".to_owned()),
            MethodRule::new(None),
            Route::ClusterId("api".to_owned()),
        ));
        let first_patch = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(router.refresh_inheriting_hsts(Some(&first_patch)), 1);
        let second_patch = HstsConfig {
            enabled: Some(true),
            max_age: Some(63_072_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(
            router.refresh_inheriting_hsts(Some(&second_patch)),
            1,
            "the previously promoted entry must be re-counted via the path-1 branch"
        );
        let Route::Frontend(rc) = &router.post[0].3 else {
            panic!("post[0] should still be Route::Frontend after the second patch");
        };
        let response: Vec<_> = rc.headers_response.iter().collect();
        assert_eq!(
            response.len(),
            1,
            "second patch must REPLACE the existing STS edit, not append a duplicate"
        );
        assert_eq!(
            &*response[0].val,
            b"max-age=63072000; includeSubDomains".as_slice()
        );
    }

    #[test]
    fn refresh_inheriting_hsts_promoted_entry_loses_hsts_on_disable_patch() {
        let mut router = Router::new();
        router.pre.push((
            DomainRule::Any,
            PathRule::Prefix("/".to_owned()),
            MethodRule::new(None),
            Route::ClusterId("api".to_owned()),
        ));
        let enable = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(router.refresh_inheriting_hsts(Some(&enable)), 1);
        let disable = HstsConfig {
            enabled: Some(false),
            max_age: None,
            include_subdomains: None,
            preload: None,
            force_replace_backend: None,
        };
        assert_eq!(
            router.refresh_inheriting_hsts(Some(&disable)),
            1,
            "the promoted entry must still be touched on disable to strip its STS edit"
        );
        let Route::Frontend(rc) = &router.pre[0].3 else {
            panic!("pre[0] should still be Route::Frontend (no demotion)");
        };
        assert_eq!(rc.cluster_id.as_deref(), Some("api"));
        assert_eq!(
            rc.headers_response.len(),
            0,
            "disable patch must strip the STS edit from the promoted entry"
        );
    }

    #[test]
    fn refresh_inheriting_hsts_promotes_clusterid_in_trie_on_enable() {
        let mut router = Router::new();
        let path_rule = PathRule::Prefix("/".to_owned());
        let method_rule = MethodRule::new(None);
        assert!(router.add_tree_rule(
            b"example.com",
            &path_rule,
            &method_rule,
            &Route::ClusterId("api".to_owned()),
        ));
        let new_hsts = HstsConfig {
            enabled: Some(true),
            max_age: Some(31_536_000),
            include_subdomains: Some(true),
            preload: None,
            force_replace_backend: None,
        };
        let count = router.refresh_inheriting_hsts(Some(&new_hsts));
        assert_eq!(
            count, 1,
            "trie-resident ClusterId must be promoted + counted"
        );
        let (_, paths) = router
            .tree
            .domain_lookup_mut(b"example.com", false)
            .expect("trie leaf still present after refresh");
        assert_eq!(paths.len(), 1);
        let Route::Frontend(rc) = &paths[0].2 else {
            panic!("trie leaf should now be Route::Frontend, not Route::ClusterId");
        };
        assert_eq!(rc.cluster_id.as_deref(), Some("api"));
        assert_eq!(rc.redirect, RedirectPolicy::Forward);
        assert!(rc.inherits_listener_hsts);
        let response: Vec<_> = rc.headers_response.iter().collect();
        assert_eq!(response.len(), 1);
        assert_eq!(&*response[0].key, b"strict-transport-security");
        assert_eq!(
            &*response[0].val,
            b"max-age=31536000; includeSubDomains".as_slice()
        );
    }

    #[test]
    fn convert_regex() {
        assert_eq!(
            convert_regex_domain_rule("www.example.com")
                .unwrap()
                .as_str(),
            "\\Awww\\.example\\.com\\z"
        );
        assert_eq!(
            convert_regex_domain_rule("*.example.com").unwrap().as_str(),
            "\\A*\\.example\\.com\\z"
        );
        assert_eq!(
            convert_regex_domain_rule("test.*.example.com")
                .unwrap()
                .as_str(),
            "\\Atest\\.*\\.example\\.com\\z"
        );
        assert_eq!(
            convert_regex_domain_rule("css./cdn[a-z0-9]+/.example.com")
                .unwrap()
                .as_str(),
            "\\Acss\\.cdn[a-z0-9]+\\.example\\.com\\z"
        );
        assert_eq!(
            convert_regex_domain_rule("css./cdn[a-z0-9]+.example.com"),
            None
        );
        assert_eq!(
            convert_regex_domain_rule("css./cdn[a-z0-9]+/a.example.com"),
            None
        );
    }

    #[test]
    fn regex_domain_rule_rejects_suffix_and_prefix() {
        let rule: DomainRule = "/example\\.com/".parse().unwrap();
        assert!(rule.matches(b"example.com"));
        assert!(!rule.matches(b"attacker.example.com"));
        assert!(!rule.matches(b"example.com.evil.org"));
        assert!(!rule.matches(b"prefixexample.com"));
        assert!(!rule.matches(b"example.commercial"));
    }

    #[test]
    fn regex_domain_rule_multi_segment_segments_are_isolated() {
        let pattern = convert_regex_domain_rule("/seg1/.foo./seg2/.com")
            .expect("multi-segment regex hostname must compile");
        assert_eq!(pattern.as_str(), "\\Aseg1\\.foo\\.seg2\\.com\\z");
    }

    #[test]
    fn parse_domain_rule() {
        assert_eq!("*".parse::<DomainRule>().unwrap(), DomainRule::Any);
        assert_eq!(
            "www.example.com".parse::<DomainRule>().unwrap(),
            DomainRule::Exact("www.example.com".to_string())
        );
        assert_eq!(
            "*.example.com".parse::<DomainRule>().unwrap(),
            DomainRule::Wildcard("*.example.com".to_string())
        );
        assert_eq!("test.*.example.com".parse::<DomainRule>(), Err(()));
        assert_eq!(
            "/cdn[0-9]+/.example.com".parse::<DomainRule>().unwrap(),
            DomainRule::Regex(Regex::new("\\Acdn[0-9]+\\.example\\.com\\z").unwrap())
        );
    }

    #[test]
    fn match_domain_rule() {
        assert!(DomainRule::Any.matches("www.example.com".as_bytes()));
        assert!(
            DomainRule::Exact("www.example.com".to_string()).matches("www.example.com".as_bytes())
        );
        assert!(
            DomainRule::Wildcard("*.example.com".to_string()).matches("www.example.com".as_bytes())
        );
        assert!(
            !DomainRule::Wildcard("*.example.com".to_string())
                .matches("test.www.example.com".as_bytes())
        );
        assert!(
            "/cdn[0-9]+/.example.com"
                .parse::<DomainRule>()
                .unwrap()
                .matches("cdn1.example.com".as_bytes())
        );
        assert!(
            !"/cdn[0-9]+/.example.com"
                .parse::<DomainRule>()
                .unwrap()
                .matches("www.example.com".as_bytes())
        );
        assert!(
            !"/cdn[0-9]+/.example.com"
                .parse::<DomainRule>()
                .unwrap()
                .matches("cdn10.exampleAcom".as_bytes())
        );
    }

    #[test]
    fn match_domain_rule_wildcard_short_hostname_does_not_panic() {
        let rule = DomainRule::Wildcard("*.foo.example.com".to_string());
        assert!(!rule.matches(b""));
        assert!(!rule.matches(b"a.b"));
        assert!(!rule.matches(b"x"));
        assert!(!rule.matches(b".foo.example.com"));
        assert!(!rule.matches(b"y.x.foo.example.com"));
        assert!(rule.matches(b"x.foo.example.com"));
    }

    #[test]
    fn router_lookup_wildcard_pre_rule_short_hostname_does_not_panic() {
        let mut router = Router::new();
        assert!(router.add_pre_rule(
            &"*.foo.example.com".parse::<DomainRule>().unwrap(),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("wildcard".to_string()),
        ));
        let method = Method::new(&b"GET"[..]);
        assert!(router.lookup("", "/", &method).is_err());
        assert!(router.lookup("x", "/", &method).is_err());
        assert!(router.lookup("a.b", "/", &method).is_err());
        assert!(router.lookup(".foo.example.com", "/", &method).is_err());
        assert_eq!(
            router.lookup("x.foo.example.com", "/", &method),
            Ok(RouteResult::forward("wildcard".to_string()))
        );
    }

    #[test]
    fn match_path_rule() {
        assert!(PathRule::Prefix("".to_string()).matches("/".as_bytes()) != PathRuleResult::None);
        assert!(
            PathRule::Prefix("".to_string()).matches("/hello".as_bytes()) != PathRuleResult::None
        );
        assert!(
            PathRule::Prefix("/hello".to_string()).matches("/hello".as_bytes())
                != PathRuleResult::None
        );
        assert!(
            PathRule::Prefix("/hello".to_string()).matches("/hello/world".as_bytes())
                != PathRuleResult::None
        );
        assert!(
            PathRule::Prefix("/hello".to_string()).matches("/".as_bytes()) == PathRuleResult::None
        );
    }

    #[test]
    fn multiple_children_on_a_wildcard() {
        let mut router = Router::new();
        assert!(router.add_tree_rule(
            b"*.sozu.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("base".to_string())
        ));
        assert_eq!(
            router.lookup("www.sozu.io", "/api", &Method::Get),
            Ok(RouteResult::forward("base".to_string()))
        );
        assert!(router.add_tree_rule(
            b"*.sozu.io",
            &PathRule::Prefix("/api".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("api".to_string())
        ));
        assert_eq!(
            router.lookup("www.sozu.io", "/ap", &Method::Get),
            Ok(RouteResult::forward("base".to_string()))
        );
        assert_eq!(
            router.lookup("www.sozu.io", "/api", &Method::Get),
            Ok(RouteResult::forward("api".to_string()))
        );
    }

    #[test]
    fn multiple_children_including_one_with_wildcard() {
        let mut router = Router::new();
        assert!(router.add_tree_rule(
            b"*.sozu.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("base".to_string())
        ));
        assert_eq!(
            router.lookup("www.sozu.io", "/api", &Method::Get),
            Ok(RouteResult::forward("base".to_string()))
        );
        assert!(router.add_tree_rule(
            b"api.sozu.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("api".to_string())
        ));
        assert_eq!(
            router.lookup("www.sozu.io", "/api", &Method::Get),
            Ok(RouteResult::forward("base".to_string()))
        );
        assert_eq!(
            router.lookup("api.sozu.io", "/api", &Method::Get),
            Ok(RouteResult::forward("api".to_string()))
        );
    }

    #[test]
    fn router_insert_remove_through_regex() {
        let mut router = Router::new();
        assert!(router.add_tree_rule(
            b"www./.*/.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("base".to_string())
        ));
        assert!(router.add_tree_rule(
            b"www.doc./.*/.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("doc".to_string())
        ));
        assert_eq!(
            router.lookup("www.sozu.io", "/", &Method::Get),
            Ok(RouteResult::forward("base".to_string()))
        );
        assert_eq!(
            router.lookup("www.doc.sozu.io", "/", &Method::Get),
            Ok(RouteResult::forward("doc".to_string()))
        );
        assert!(router.remove_tree_rule(
            b"www./.*/.io",
            &PathRule::Prefix("".to_string()),
            &MethodRule::new(Some("GET".to_string()))
        ));
        assert!(router.lookup("www.sozu.io", "/", &Method::Get).is_err());
        assert_eq!(
            router.lookup("www.doc.sozu.io", "/", &Method::Get),
            Ok(RouteResult::forward("doc".to_string()))
        );
    }

    #[test]
    fn match_router() {
        let mut router = Router::new();
        assert!(router.add_pre_rule(
            &"*".parse::<DomainRule>().unwrap(),
            &PathRule::Prefix("/.well-known/acme-challenge".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("acme".to_string())
        ));
        assert!(router.add_tree_rule(
            "www.example.com".as_bytes(),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("example".to_string())
        ));
        assert!(router.add_tree_rule(
            "*.test.example.com".as_bytes(),
            &PathRule::Regex(Regex::new("/hello[A-Z]+/").unwrap()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("examplewildcard".to_string())
        ));
        assert!(router.add_tree_rule(
            "/test[0-9]/.example.com".as_bytes(),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("exampleregex".to_string())
        ));
        assert_eq!(
            router.lookup("www.example.com", "/helloA", &Method::new(&b"GET"[..])),
            Ok(RouteResult::forward("example".to_string()))
        );
        assert_eq!(
            router.lookup(
                "www.example.com",
                "/.well-known/acme-challenge",
                &Method::new(&b"GET"[..])
            ),
            Ok(RouteResult::forward("acme".to_string()))
        );
        assert!(
            router
                .lookup("www.test.example.com", "/", &Method::new(&b"GET"[..]))
                .is_err()
        );
        assert_eq!(
            router.lookup(
                "www.test.example.com",
                "/helloAB/",
                &Method::new(&b"GET"[..])
            ),
            Ok(RouteResult::forward("examplewildcard".to_string()))
        );
        assert_eq!(
            router.lookup("test1.example.com", "/helloAB/", &Method::new(&b"GET"[..])),
            Ok(RouteResult::forward("exampleregex".to_string()))
        );
    }

    #[test]
    fn has_hostname_checks_tree_pre_and_post() {
        let mut router = Router::new();
        assert!(!router.has_hostname("www.example.com"));
        assert!(router.add_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("cluster1".to_string())
        ));
        assert!(router.has_hostname("www.example.com"));
        assert!(!router.has_hostname("api.example.com"));
        assert!(router.remove_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string()))
        ));
        assert!(!router.has_hostname("www.example.com"));
        assert!(router.add_pre_rule(
            &DomainRule::Exact("api.example.com".to_string()),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(None),
            &Route::ClusterId("cluster2".to_string())
        ));
        assert!(router.has_hostname("api.example.com"));
        assert!(!router.has_hostname("www.example.com"));
        assert!(router.add_post_rule(
            &DomainRule::Exact("cdn.example.com".to_string()),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(None),
            &Route::ClusterId("cluster3".to_string())
        ));
        assert!(router.has_hostname("cdn.example.com"));
        assert!(router.remove_pre_rule(
            &DomainRule::Exact("api.example.com".to_string()),
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(None),
        ));
        assert!(!router.has_hostname("api.example.com"));
        assert!(router.has_hostname("cdn.example.com"));
    }

    #[test]
    fn has_hostname_false_after_last_route_removed() {
        let mut router = Router::new();
        assert!(router.add_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("cluster1".to_string())
        ));
        assert!(router.add_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/api".to_string()),
            &MethodRule::new(Some("GET".to_string())),
            &Route::ClusterId("cluster2".to_string())
        ));
        assert!(router.has_hostname("www.example.com"));
        assert!(router.remove_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/".to_string()),
            &MethodRule::new(Some("GET".to_string()))
        ));
        assert!(router.has_hostname("www.example.com"));
        assert!(router.remove_tree_rule(
            b"www.example.com",
            &PathRule::Prefix("/api".to_string()),
            &MethodRule::new(Some("GET".to_string()))
        ));
        assert!(!router.has_hostname("www.example.com"));
    }
}