use crate::access::{PolicyBlock, PolicyRule, Predicate};
use crate::config::{
BasicAuthConfig, Config, HeaderOpConfig, PolicyRuleDef, VHostConfig,
};
use crate::handler::Handler;
use crate::handler::status::{
LbPoolEntry, ServerSummary, SharedLbRegistry,
};
use crate::headers::{HeaderOp, HeaderRules, Template};
use crate::metrics::{HandlerKind, Metrics};
use anyhow::bail;
use hyper::Request;
use hyper::header::HeaderName;
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
/// Result of a successful [`Router::route`] call: the handler to invoke plus
/// the per-location settings copied from the location that was selected.
pub struct Route {
/// Handler of the selected location.
pub handler: Arc<dyn Handler>,
/// The selected location's `path`, i.e. the prefix the request path
/// started with.
pub matched_prefix: String,
/// Name of the virtual host the request was resolved to.
pub vhost_name: Arc<str>,
/// Kind of handler, derived from the location's handler configuration.
pub handler_kind: HandlerKind,
/// Access policy block of the location, if one is configured.
pub policy: Option<Arc<PolicyBlock>>,
/// Basic-auth configuration of the location, if one is configured.
pub basic_auth: Option<Arc<BasicAuthConfig>>,
/// Request/response header operations; `None` when the location defines
/// neither request nor response header operations.
pub header_rules: Option<Arc<HeaderRules>>,
/// Rate-limit rules attached to the location (may be empty).
pub rate_limits: Vec<Arc<crate::rate_limit::RateLimitRule>>,
/// Optional request-body limit copied from the location config
/// (presumably in bytes — confirm against the config documentation).
pub max_request_body: Option<u64>,
/// Compiled cache policy of the location, if caching is configured.
pub cache_policy: Option<Arc<crate::cache::CachePolicy>>,
}
/// A fully built virtual host: its display name and its locations in
/// configuration order.
struct VHost {
/// Primary name of the vhost (taken from the vhost config's `name.value`).
name: Arc<str>,
/// Locations in the order they appear in the configuration; that order is
/// the tie-breaker when several locations share the same prefix length.
locations: Vec<Location>,
}
/// One routable location inside a [`VHost`], with everything needed to serve
/// a request that lands on it.
struct Location {
/// Path prefix; a request is a candidate when its path starts with this.
path: String,
/// Handler that serves requests routed to this location.
handler: Arc<dyn Handler>,
/// Kind of `handler`, derived from the handler configuration variant.
handler_kind: HandlerKind,
/// Optional access policy with named policies already inlined.
policy: Option<Arc<PolicyBlock>>,
/// Optional basic-auth configuration.
basic_auth: Option<Arc<BasicAuthConfig>>,
/// Optional request/response header operations.
header_rules: Option<Arc<HeaderRules>>,
/// Rate-limit rules attached to this location.
rate_limits: Vec<Arc<crate::rate_limit::RateLimitRule>>,
/// Optional request-body limit (presumably bytes — confirm in config docs).
max_request_body: Option<u64>,
/// Optional extra matcher; when present the request must satisfy it for
/// the location to be selected.
matcher: Option<Arc<crate::matcher::Matcher>>,
/// Optional rewrite; when its pattern matches the request path the URI is
/// rewritten and location selection is repeated.
rewrite: Option<Arc<Rewrite>>,
/// Optional compiled cache policy.
cache_policy: Option<Arc<crate::cache::CachePolicy>>,
}
/// A compiled URI rewrite rule attached to a location.
struct Rewrite {
/// Pattern tested against the request path (the query string is not part
/// of the matched text).
from: Regex,
/// Replacement handed to `Regex::replace`; the text after the first `?`
/// in the result, if any, becomes the new query string.
to: String,
}
/// Maximum number of location-selection passes `Router::route` performs for a
/// single request. Each applied rewrite consumes one pass; when the budget is
/// exhausted the request is logged as a rewrite cycle and left unrouted.
const MAX_REWRITES: usize = 10;
/// Host-name lookup table for a single listener.
#[derive(Default)]
struct VhostTable {
/// Exact host names (primary names and aliases) mapped to their vhost.
literals: HashMap<String, Arc<VHost>>,
/// Regex host names, tried in insertion order after `literals` misses.
/// Each pattern is compiled anchored as `^(?:…)$`.
patterns: Vec<(Regex, Arc<VHost>)>,
/// Vhost used when nothing else matches; `None` when the listener rejects
/// unknown hosts or has no vhosts at all.
default: Option<Arc<VHost>>,
}
/// Maps an incoming request to a [`Route`] using per-listener vhost tables.
pub struct Router {
/// One lookup table per non-proxy listener, keyed by the listener's
/// `local_name()`.
tables: HashMap<String, VhostTable>,
/// Named policies from the server config with nested `apply` references
/// already expanded into flat rule lists.
named_policies: HashMap<String, Vec<PolicyRule>>,
}
impl Router {
pub fn new(
config: &Config,
metrics: &Arc<Metrics>,
summary: &Arc<ServerSummary>,
cert_state: Option<&crate::cert::state::SharedCertState>,
) -> anyhow::Result<Self> {
let named_policies = resolve_named_policies(&config.server.policies)?;
let lb_registry: SharedLbRegistry =
Arc::new(arc_swap::ArcSwap::from_pointee(Vec::new()));
let mut lb_pools: Vec<LbPoolEntry> = Vec::new();
let mut built: Vec<Arc<VHost>> = Vec::with_capacity(config.vhosts.len());
let mut by_handle: HashMap<&str, usize> = HashMap::new();
for (i, vcfg) in config.vhosts.iter().enumerate() {
let vhost = Arc::new(build_vhost(
vcfg,
metrics,
summary,
cert_state,
&named_policies,
&lb_registry,
&mut lb_pools,
)?);
built.push(vhost);
by_handle.insert(vcfg.handle(), i);
}
lb_registry.store(Arc::new(lb_pools));
let mut tables: HashMap<String, VhostTable> = HashMap::new();
for l in config.listeners.iter().filter(|l| l.proxy.is_none()) {
let indices: Vec<usize> = if l.vhosts.is_empty() {
config
.vhosts
.iter()
.enumerate()
.filter(|(_, v)| !v.explicit_only)
.map(|(i, _)| i)
.collect()
} else {
l.vhosts
.iter()
.filter_map(|h| by_handle.get(h.as_str()).copied())
.collect()
};
let mut table = VhostTable::default();
for &i in &indices {
let vcfg = &config.vhosts[i];
let vhost = &built[i];
let all_names =
std::iter::once(&vcfg.name).chain(vcfg.aliases.iter());
for n in all_names {
if n.regex {
let re = Regex::new(&format!("^(?:{})$", n.value))
.expect("regex validated at config load");
table.patterns.push((re, vhost.clone()));
} else {
table.literals.insert(n.value.clone(), vhost.clone());
}
}
}
table.default = if l.reject_unknown_host {
None
} else {
indices.first().map(|&i| built[i].clone())
};
tables.insert(l.local_name(), table);
}
Ok(Self {
tables,
named_policies,
})
}
pub fn route<B>(
&self,
req: &mut Request<B>,
listener_bind: &str,
) -> Option<Route> {
let host = req
.headers()
.get("host")
.and_then(|v| v.to_str().ok())
.map(|h| strip_port(h).to_owned());
let vhost = self.resolve_vhost(host.as_deref(), listener_bind)?;
for _ in 0..MAX_REWRITES {
let chosen = pick_location(&vhost, req);
let loc = chosen?;
if let Some(rw) = &loc.rewrite
&& apply_rewrite(req, rw)
{
continue;
}
return Some(Route {
handler: loc.handler.clone(),
matched_prefix: loc.path.clone(),
vhost_name: vhost.name.clone(),
handler_kind: loc.handler_kind,
policy: loc.policy.clone(),
basic_auth: loc.basic_auth.clone(),
header_rules: loc.header_rules.clone(),
rate_limits: loc.rate_limits.clone(),
max_request_body: loc.max_request_body,
cache_policy: loc.cache_policy.clone(),
});
}
tracing::warn!(
uri = %req.uri(),
"rewrite cycle: hit MAX_REWRITES={} without settling on a \
non-rewriting location; treating as 404",
MAX_REWRITES,
);
None
}
fn resolve_vhost(
&self,
host: Option<&str>,
listener_bind: &str,
) -> Option<Arc<VHost>> {
let table = self.tables.get(listener_bind).or_else(|| {
(self.tables.len() == 1)
.then(|| self.tables.values().next())
.flatten()
})?;
if let Some(host) = host {
if let Some(vhost) = table.literals.get(host) {
return Some(vhost.clone());
}
for (re, vhost) in &table.patterns {
if re.is_match(host) {
return Some(vhost.clone());
}
}
}
table.default.clone()
}
pub fn resolve_block(
&self,
defs: &[PolicyRuleDef],
tcp_only: bool,
) -> anyhow::Result<PolicyBlock> {
let rules = inline_rules(defs, &self.named_policies, tcp_only)?;
Ok(PolicyBlock::new(rules))
}
}
/// Expands every named policy in `defs` into a flat list of rules, following
/// `apply` references between policies.
///
/// # Errors
///
/// Fails on a circular reference or on a reference to an undefined policy.
fn resolve_named_policies(
    defs: &HashMap<String, Vec<PolicyRuleDef>>,
) -> anyhow::Result<HashMap<String, Vec<PolicyRule>>> {
    let mut expanded = HashMap::new();
    defs.keys().try_for_each(|policy| {
        // Every top-level policy starts with an empty resolution stack;
        // results are memoised in `expanded` across policies.
        let mut stack = Vec::new();
        resolve_one(policy, defs, &mut expanded, &mut stack).map(drop)
    })?;
    Ok(expanded)
}
fn resolve_one(
name: &str,
defs: &HashMap<String, Vec<PolicyRuleDef>>,
resolved: &mut HashMap<String, Vec<PolicyRule>>,
visiting: &mut Vec<String>,
) -> anyhow::Result<Vec<PolicyRule>> {
if let Some(rules) = resolved.get(name) {
return Ok(rules.clone());
}
if visiting.iter().any(|v| v == name) {
bail!(
"circular reference in policy '{name}' (chain: {})",
visiting.join(" → ")
);
}
let rule_defs = defs
.get(name)
.ok_or_else(|| anyhow::anyhow!("undefined policy '{name}'"))?;
visiting.push(name.to_string());
let rules = resolve_rule_defs(rule_defs, defs, resolved, visiting)?;
visiting.pop();
resolved.insert(name.to_string(), rules.clone());
Ok(rules)
}
/// Flattens `rule_defs` into concrete rules: plain rules are copied and each
/// `apply` is replaced by the rules of the policy it names, in place.
///
/// # Errors
///
/// Propagates any error from resolving a referenced policy.
fn resolve_rule_defs(
    rule_defs: &[PolicyRuleDef],
    raw_defs: &HashMap<String, Vec<PolicyRuleDef>>,
    resolved: &mut HashMap<String, Vec<PolicyRule>>,
    visiting: &mut Vec<String>,
) -> anyhow::Result<Vec<PolicyRule>> {
    let mut flattened = Vec::new();
    for def in rule_defs {
        match def {
            PolicyRuleDef::Apply { name } => {
                let referenced =
                    resolve_one(name, raw_defs, resolved, visiting)?;
                flattened.extend(referenced);
            }
            PolicyRuleDef::Rule { predicate, action } => {
                let rule = PolicyRule {
                    predicate: predicate.clone(),
                    action: action.clone(),
                };
                flattened.push(rule);
            }
        }
    }
    Ok(flattened)
}
fn inline_rules(
defs: &[PolicyRuleDef],
named_policies: &HashMap<String, Vec<PolicyRule>>,
tcp_only: bool,
) -> anyhow::Result<Vec<PolicyRule>> {
let mut result = Vec::new();
for def in defs {
match def {
PolicyRuleDef::Rule { predicate, action } => {
check_tcp_predicate(predicate, tcp_only)?;
result.push(PolicyRule {
predicate: predicate.clone(),
action: action.clone(),
});
}
PolicyRuleDef::Apply { name } => {
let rules =
named_policies.get(name.as_str()).ok_or_else(|| {
anyhow::anyhow!("undefined policy '{name}'")
})?;
if tcp_only {
check_tcp_block_rules(rules, name)?;
}
result.extend_from_slice(rules);
}
}
}
Ok(result)
}
/// Rejects a predicate that needs authentication when it is used in a
/// TCP-only (stream listener) context. Does nothing when `tcp_only` is false
/// or there is no predicate.
fn check_tcp_predicate(
    predicate: &Option<Predicate>,
    tcp_only: bool,
) -> anyhow::Result<()> {
    // `needs_auth` is only consulted in the stream-listener context.
    if tcp_only && matches!(predicate, Some(p) if p.needs_auth()) {
        bail!(
            "policy used in a stream listener context contains \
             identity predicates, which require HTTP authentication"
        );
    }
    Ok(())
}
/// Rejects the named policy `name` for stream-listener use when any of its
/// `rules` has a predicate that needs authentication.
fn check_tcp_block_rules(
    rules: &[PolicyRule],
    name: &str,
) -> anyhow::Result<()> {
    let has_identity_predicate = rules
        .iter()
        .any(|rule| matches!(&rule.predicate, Some(p) if p.needs_auth()));
    if has_identity_predicate {
        bail!(
            "policy '{name}' contains identity predicates and \
             cannot be used in a stream listener policy block"
        );
    }
    Ok(())
}
fn pick_location<'a, B>(
vhost: &'a VHost,
req: &Request<B>,
) -> Option<&'a Location> {
let path = req.uri().path();
let mut candidates: Vec<(usize, &Location)> = vhost
.locations
.iter()
.enumerate()
.filter(|(_, loc)| path.starts_with(loc.path.as_str()))
.collect();
candidates.sort_by(|a, b| {
b.1.path.len()
.cmp(&a.1.path.len())
.then(a.0.cmp(&b.0))
});
for (_, loc) in candidates {
if let Some(m) = &loc.matcher
&& !m.matches(req)
{
continue;
}
return Some(loc);
}
None
}
/// Applies the rewrite `rw` to the URI of `req`.
///
/// The pattern is tested against the request path only. On a match the first
/// occurrence is replaced; text after the first `?` in the result becomes
/// the new query string and an empty resulting path is normalised to `/`.
/// The previous query string is not carried over: the whole path-and-query
/// of the URI is replaced by the rewritten value.
///
/// Returns `true` when the URI was replaced, `false` when the pattern did
/// not match or the rewritten value could not be turned into a valid URI
/// (in which case a warning is logged and `req` is left untouched).
fn apply_rewrite<B>(req: &mut Request<B>, rw: &Rewrite) -> bool {
    let current = req.uri().path();
    if !rw.from.is_match(current) {
        return false;
    }

    let rewritten = rw.from.replace(current, rw.to.as_str());
    let (path_part, query_part) = match rewritten.split_once('?') {
        Some((before, after)) => (before, Some(after)),
        None => (&*rewritten, None),
    };
    let path_part = if path_part.is_empty() { "/" } else { path_part };
    let new_pq = match query_part {
        Some(query) => format!("{path_part}?{query}"),
        None => path_part.to_owned(),
    };

    // Keep scheme and authority of the current URI; only swap path + query.
    let mut parts = req.uri().clone().into_parts();
    parts.path_and_query = match new_pq.parse() {
        Ok(pq) => Some(pq),
        Err(e) => {
            tracing::warn!(
                error = %e,
                rewrite_to = %new_pq,
                "rewrite produced an invalid URI; ignoring this rewrite",
            );
            return false;
        }
    };

    let assembled = hyper::Uri::from_parts(parts);
    match assembled {
        Err(e) => {
            tracing::warn!(
                error = %e,
                "rewrite produced an unassemblable URI; ignoring",
            );
            false
        }
        Ok(uri) => {
            *req.uri_mut() = uri;
            true
        }
    }
}
/// Returns the host part of a `Host` header value with any `:port` removed.
///
/// A bracketed literal such as `[::1]:443` is returned with its brackets
/// (`[::1]`). Anything else is cut at the first `:`.
fn strip_port(host: &str) -> &str {
    let bracket_end = host
        .strip_prefix('[')
        .and_then(|inner| inner.find(']'));
    match bracket_end {
        // `inner` starts one byte after the `[`, so shift the index back
        // into `host` and keep the closing bracket.
        Some(idx) => &host[..=idx + 1],
        None => host.split_once(':').map_or(host, |(name, _port)| name),
    }
}
#[allow(clippy::too_many_arguments)]
fn build_vhost(
vcfg: &VHostConfig,
metrics: &Arc<Metrics>,
summary: &Arc<ServerSummary>,
cert_state: Option<&crate::cert::state::SharedCertState>,
named_policies: &HashMap<String, Vec<PolicyRule>>,
lb_registry: &SharedLbRegistry,
lb_pools: &mut Vec<LbPoolEntry>,
) -> anyhow::Result<VHost> {
let mut locations = Vec::with_capacity(vcfg.locations.len());
for loc in &vcfg.locations {
let (handler, pool) = crate::handler::build_handler(
&loc.handler,
metrics,
summary,
cert_state,
lb_registry,
)?;
if let Some(pool) = pool {
lb_pools.push(LbPoolEntry {
label: format!("{} {}", vcfg.name.value, loc.path),
pool,
});
}
let header_rules = if loc.request_headers.is_empty()
&& loc.response_headers.is_empty()
{
None
} else {
let req = loc
.request_headers
.iter()
.map(op_from_config)
.collect::<anyhow::Result<Vec<_>>>()?;
let resp = loc
.response_headers
.iter()
.map(op_from_config)
.collect::<anyhow::Result<Vec<_>>>()?;
Some(Arc::new(HeaderRules::new(req, resp)))
};
let policy = if let Some(defs) = &loc.policy {
let rules = inline_rules(defs, named_policies, false)?;
Some(Arc::new(PolicyBlock::new(rules)))
} else {
None
};
let rate_limits = loc
.rate_limits
.iter()
.map(rate_limit_rule_from_config)
.collect::<anyhow::Result<Vec<_>>>()?;
let matcher = loc
.matcher
.as_ref()
.map(matcher_from_config)
.transpose()?
.map(Arc::new);
let rewrite = loc
.rewrite
.as_ref()
.map(rewrite_from_config)
.transpose()?
.map(Arc::new);
let cache_policy = loc.cache.as_ref().map(|c| {
Arc::new(crate::cache::CachePolicy::compile(c))
});
locations.push(Location {
path: loc.path.clone(),
handler,
handler_kind: handler_kind(&loc.handler),
policy,
basic_auth: loc.auth.as_ref().map(|a| Arc::new(a.clone())),
header_rules,
rate_limits,
max_request_body: loc.max_request_body,
matcher,
rewrite,
cache_policy,
});
}
Ok(VHost {
name: Arc::from(vcfg.name.value.as_str()),
locations,
})
}
/// Maps a handler configuration variant to the matching [`HandlerKind`].
///
/// The match is deliberately exhaustive so that adding a handler variant
/// forces this mapping to be updated.
fn handler_kind(h: &crate::config::HandlerConfig) -> HandlerKind {
    use crate::config::HandlerConfig as Cfg;
    match h {
        Cfg::Static { .. } => HandlerKind::Static,
        Cfg::Proxy { .. } => HandlerKind::Proxy,
        Cfg::Redirect { .. } => HandlerKind::Redirect,
        Cfg::Respond { .. } => HandlerKind::Respond,
        Cfg::FastCgi { .. } => HandlerKind::FastCgi,
        Cfg::Scgi { .. } => HandlerKind::Scgi,
        Cfg::Cgi { .. } => HandlerKind::Cgi,
        Cfg::Status => HandlerKind::Status,
        Cfg::AuthRequest => HandlerKind::AuthRequest,
    }
}
/// Builds a shared rate-limit rule from its configuration.
///
/// # Errors
///
/// Fails when the rule is keyed on a header whose name is not a valid HTTP
/// header name.
fn rate_limit_rule_from_config(
    cfg: &crate::config::RateLimitConfig,
) -> anyhow::Result<Arc<crate::rate_limit::RateLimitRule>> {
    use crate::config::RateLimitKeyConfig;
    let key = match &cfg.key {
        RateLimitKeyConfig::Header(name) => {
            match HeaderName::from_bytes(name.as_bytes()) {
                Ok(header) => crate::rate_limit::RateLimitKey::Header(header),
                Err(e) => {
                    return Err(anyhow::anyhow!(
                        "rate-limit invalid header name {name:?}: {e}"
                    ));
                }
            }
        }
        RateLimitKeyConfig::User => crate::rate_limit::RateLimitKey::User,
        RateLimitKeyConfig::ClientIp => {
            crate::rate_limit::RateLimitKey::ClientIp
        }
    };
    let rule = crate::rate_limit::RateLimitRule::new(
        cfg.name.clone(),
        cfg.rate_per_sec,
        cfg.burst,
        key,
    );
    Ok(Arc::new(rule))
}
/// Compiles a matcher configuration into a runtime matcher.
///
/// # Errors
///
/// Fails when any of the configured predicates cannot be compiled.
fn matcher_from_config(
    cfg: &crate::config::MatcherConfig,
) -> anyhow::Result<crate::matcher::Matcher> {
    compile_predicates(&cfg.predicates)
        .map(|predicates| crate::matcher::Matcher { predicates })
}
fn compile_predicates(
cfgs: &[crate::config::MatchPredicateConfig],
) -> anyhow::Result<Vec<crate::matcher::MatchPredicate>> {
use crate::config::MatchPredicateConfig;
use crate::matcher::{HeaderMatch, MatchPredicate};
let mut out = Vec::with_capacity(cfgs.len());
for p in cfgs {
match p {
MatchPredicateConfig::Method(methods) => {
let parsed = methods
.iter()
.map(|m| {
hyper::Method::from_bytes(m.as_bytes())
.map_err(|e| {
anyhow::anyhow!(
"matcher invalid method {m:?}: {e}"
)
})
})
.collect::<anyhow::Result<Vec<_>>>()?;
out.push(MatchPredicate::Method(parsed));
}
MatchPredicateConfig::Header { name, values } => {
let h = HeaderName::from_bytes(name.as_bytes())
.map_err(|e| {
anyhow::anyhow!(
"matcher invalid header name {name:?}: {e}"
)
})?;
let mut compiled = Vec::with_capacity(values.len());
for v in values {
if let Some(re) = v.strip_prefix('~') {
compiled.push(HeaderMatch::Regex(
Regex::new(re).map_err(|e| {
anyhow::anyhow!(
"matcher invalid regex {re:?}: {e}"
)
})?,
));
} else {
compiled.push(HeaderMatch::Exact(v.clone()));
}
}
out.push(MatchPredicate::Header {
name: h,
values: compiled,
});
}
MatchPredicateConfig::HeaderAbsent { name } => {
let h = HeaderName::from_bytes(name.as_bytes())
.map_err(|e| {
anyhow::anyhow!(
"matcher invalid header name {name:?}: {e}"
)
})?;
out.push(MatchPredicate::HeaderAbsent { name: h });
}
MatchPredicateConfig::Query { name, values } => {
out.push(MatchPredicate::Query {
name: name.clone(),
values: values.clone(),
});
}
MatchPredicateConfig::Path(patterns) => {
let compiled = patterns
.iter()
.map(|p| {
Regex::new(p).map_err(|e| {
anyhow::anyhow!(
"matcher invalid path regex {p:?}: {e}"
)
})
})
.collect::<anyhow::Result<Vec<_>>>()?;
out.push(MatchPredicate::Path(compiled));
}
MatchPredicateConfig::Not(inner) => {
let inner_compiled = compile_predicates(inner)?;
out.push(MatchPredicate::Not(inner_compiled));
}
}
}
Ok(out)
}
/// Compiles a rewrite configuration: `from` becomes a regex, `to` is kept
/// verbatim as the replacement string.
///
/// # Errors
///
/// Fails when `from` is not a valid regular expression.
fn rewrite_from_config(
    cfg: &crate::config::RewriteConfig,
) -> anyhow::Result<Rewrite> {
    match Regex::new(&cfg.from) {
        Ok(from) => Ok(Rewrite {
            from,
            to: cfg.to.clone(),
        }),
        Err(e) => Err(anyhow::anyhow!("rewrite invalid `from` regex: {e}")),
    }
}
impl Router {
pub fn all_rate_limit_rules(
&self,
) -> Vec<Arc<crate::rate_limit::RateLimitRule>> {
let mut out: Vec<Arc<crate::rate_limit::RateLimitRule>> =
Vec::new();
let mut seen = std::collections::HashSet::new();
let push_loc = |loc: &Location,
seen: &mut std::collections::HashSet<usize>,
out: &mut Vec<_>| {
for r in &loc.rate_limits {
let id = Arc::as_ptr(r) as usize;
if seen.insert(id) {
out.push(r.clone());
}
}
};
let mut seen_vhost = std::collections::HashSet::new();
for table in self.tables.values() {
let vhosts = table
.literals
.values()
.chain(table.patterns.iter().map(|(_, v)| v))
.chain(table.default.iter());
for v in vhosts {
if !seen_vhost.insert(Arc::as_ptr(v)) {
continue;
}
for loc in &v.locations {
push_loc(loc, &mut seen, &mut out);
}
}
}
out
}
pub fn any_cache_enabled(&self) -> bool {
self.tables.values().any(|table| {
table
.literals
.values()
.chain(table.patterns.iter().map(|(_, v)| v))
.chain(table.default.iter())
.any(|v| {
v.locations
.iter()
.any(|loc| loc.cache_policy.is_some())
})
})
}
}
fn op_from_config(cfg: &HeaderOpConfig) -> anyhow::Result<HeaderOp> {
use crate::config::HeaderOpConfig as C;
Ok(match cfg {
C::Set { name, value } => HeaderOp::Set {
name: HeaderName::from_bytes(name.as_bytes())
.map_err(|_| anyhow::anyhow!("invalid header name '{name}'"))?,
template: Template::parse(value),
},
C::Add { name, value } => HeaderOp::Add {
name: HeaderName::from_bytes(name.as_bytes())
.map_err(|_| anyhow::anyhow!("invalid header name '{name}'"))?,
template: Template::parse(value),
},
C::Remove { name } => HeaderOp::Remove {
name: HeaderName::from_bytes(name.as_bytes())
.map_err(|_| anyhow::anyhow!("invalid header name '{name}'"))?,
},
})
}
#[cfg(test)]
mod tests;