pub mod acme;
pub mod auth;
pub mod config;
pub mod generate;
pub mod limiter;
pub mod metrics;
pub mod proxy;
pub mod reload;
pub mod supervisor;
pub mod tls;
pub mod waf;
use std::num::NonZeroU32;
use std::sync::Arc;
use anyhow::{Context, Result};
use arc_swap::ArcSwap;
use axum::{
extract::DefaultBodyLimit,
routing::{any, get, post},
Router,
};
use governor::{Quota, RateLimiter};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use crate::auth::AuthEngine;
use crate::config::{parse_duration, parse_rate, parse_size, Config};
use crate::metrics::Metrics;
use crate::proxy::{
csp_report, metrics_handler, ready, AppState, RouteLimiter, Runtime, StrLimiter,
};
pub use crate::auth::hash_password;
fn quota(rate: &str, burst: u32) -> Result<Quota> {
let (count, period) = parse_rate(rate)?;
anyhow::ensure!(count > 0, "rate count must be > 0 (got \"{rate}\")");
anyhow::ensure!(burst > 0, "burst must be > 0 (rate \"{rate}\")");
let per_cell = period / count;
let burst = NonZeroU32::new(burst).unwrap();
Ok(Quota::with_period(per_cell)
.context("rate too high for a usable replenish interval")?
.allow_burst(burst))
}
pub fn build_runtime(cfg: Arc<Config>) -> Result<Runtime> {
let rl = &cfg.ratelimit;
let store_mode = crate::limiter::StoreMode::parse(&rl.store)?;
let use_distributed = rl.enabled && store_mode.is_distributed();
let distributed = if use_distributed {
Some(crate::limiter::DistributedLimiter::build(rl, store_mode)?)
} else {
None
};
let build_local = rl.enabled && !use_distributed;
let ip_limiter = if build_local {
Some(Arc::new(RateLimiter::keyed(quota(&rl.rate, rl.burst)?)))
} else {
None
};
let mut route_limiters = Vec::new();
if build_local {
for route in &rl.routes {
anyhow::ensure!(
!route.path.is_empty(),
"ratelimit.routes[].path must not be empty"
);
route_limiters.push(RouteLimiter {
prefix: route.path.clone(),
limiter: Arc::new(RateLimiter::keyed(quota(&route.rate, route.burst)?)),
});
}
}
let key_limiter: Option<Arc<StrLimiter>> = if build_local && rl.per_key.enabled {
Some(Arc::new(RateLimiter::keyed(quota(
&rl.per_key.rate,
rl.per_key.burst,
)?)))
} else {
None
};
let auth = AuthEngine::build(&cfg.auth)?;
let waf = crate::waf::WafEngine::build(&cfg.waf)?;
let max_body = parse_size(&cfg.validation.max_body)?;
let max_response_body = parse_size(&cfg.validation.max_response_body)?;
let max_header_bytes = parse_size(&cfg.validation.max_header_bytes)?;
let upstream_timeout = parse_duration(&cfg.validation.upstream_timeout)?;
let upstream_timeout = (!upstream_timeout.is_zero()).then_some(upstream_timeout);
Ok(Runtime {
upstream_base: Arc::new(cfg.upstream_base()),
auth,
waf,
distributed,
ip_limiter,
route_limiters,
key_limiter,
max_body,
max_response_body,
max_header_bytes,
upstream_timeout,
cfg,
})
}
/// Builds the shared [`AppState`]: an HTTP client on the Tokio executor, a
/// fresh [`Metrics`] instance, and the [`Runtime`] derived from `cfg` held in
/// an [`ArcSwap`].
///
/// NOTE(review): the `ArcSwap` wrapper presumably lets the `reload` module
/// replace the runtime in place — confirm there.
///
/// # Errors
///
/// Propagates any error from [`build_runtime`]; nothing else here is fallible.
pub fn build_state(cfg: Arc<Config>) -> Result<AppState> {
    let initial = build_runtime(cfg)?;

    let executor = TokioExecutor::new();
    let client = Client::builder(executor).build_http::<http_body_util::Full<bytes::Bytes>>();
    let metrics = Arc::new(Metrics::new());
    let runtime = Arc::new(ArcSwap::from_pointee(initial));

    Ok(AppState {
        client,
        metrics,
        runtime,
    })
}
/// Builds a single router that serves both the public and the admin routes.
///
/// axum's default request-body limit is disabled for the merged router.
/// NOTE(review): presumably the proxy handler enforces its own body limit and
/// the per-route cap on the CSP report endpoint still takes precedence —
/// confirm both.
pub fn build_router(state: AppState) -> Router {
    let combined = public_routes().merge(admin_routes());
    let unlimited = combined.layer(DefaultBodyLimit::disable());
    unlimited.with_state(state)
}
/// Builds a router that serves only the public routes (CSP report endpoint and
/// the proxy fallback), with axum's default request-body limit disabled.
pub fn build_public_router(state: AppState) -> Router {
    let unlimited = public_routes().layer(DefaultBodyLimit::disable());
    unlimited.with_state(state)
}
/// Builds a router that serves only the admin routes (health, readiness,
/// metrics). No body-limit layer is applied here.
pub fn build_admin_router(state: AppState) -> Router {
    let admin = admin_routes();
    admin.with_state(state)
}
/// Public route table: the CSP report endpoint, with every other request
/// falling through to the proxy handler.
fn public_routes() -> Router<AppState> {
    /// Body cap applied to CSP report uploads (64 KiB).
    const CSP_REPORT_MAX_BYTES: usize = 64 * 1024;

    let csp_endpoint = post(csp_report).layer(DefaultBodyLimit::max(CSP_REPORT_MAX_BYTES));

    Router::new()
        .route("/__edgeguard/csp-report", csp_endpoint)
        .fallback(any(proxy::handle))
}
/// Admin route table: a static health probe, the readiness check, and the
/// metrics endpoint, all under `/__edgeguard/`.
fn admin_routes() -> Router<AppState> {
    /// Health probe: always answers with the literal body `ok`.
    async fn health() -> &'static str {
        "ok"
    }

    Router::new()
        .route("/__edgeguard/health", get(health))
        .route("/__edgeguard/ready", get(ready))
        .route("/__edgeguard/metrics", get(metrics_handler))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RateLimitCfg, RouteRateLimit};

    /// Default config with the global limiter explicitly enabled and the given
    /// global rate / burst.
    fn cfg_with_ratelimit(rate: &str, burst: u32) -> Config {
        Config {
            ratelimit: RateLimitCfg {
                enabled: true,
                rate: rate.into(),
                burst,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Default config with rate limiting explicitly enabled and exactly one
    /// per-route rule. Enabling it here keeps the route tests from depending on
    /// whatever `Config::default()` happens to choose for `ratelimit.enabled`.
    fn cfg_with_route(path: &str, rate: &str, burst: u32) -> Config {
        let mut cfg = Config::default();
        cfg.ratelimit.enabled = true;
        cfg.ratelimit.routes = vec![RouteRateLimit {
            path: path.into(),
            rate: rate.into(),
            burst,
        }];
        cfg
    }

    #[test]
    fn build_state_rejects_zero_rate() {
        assert!(build_state(Arc::new(cfg_with_ratelimit("0/min", 20))).is_err());
    }

    #[test]
    fn build_state_rejects_zero_burst() {
        assert!(build_state(Arc::new(cfg_with_ratelimit("60/min", 0))).is_err());
    }

    #[test]
    fn build_runtime_builds_route_and_key_limiters() {
        let mut cfg = cfg_with_route("/api/", "10/sec", 5);
        cfg.ratelimit.per_key = PerKeyRateLimit {
            enabled: true,
            rate: "1000/hour".into(),
            burst: 100,
        };
        let rt = build_runtime(Arc::new(cfg)).unwrap();
        assert_eq!(rt.route_limiters.len(), 1);
        assert_eq!(rt.route_limiters[0].prefix, "/api/");
        assert!(rt.key_limiter.is_some());
    }

    #[test]
    fn build_runtime_rejects_bad_route_rate() {
        assert!(build_runtime(Arc::new(cfg_with_route("/api/", "0/sec", 5))).is_err());
    }

    #[test]
    fn build_runtime_rejects_empty_route_path() {
        // The rate and burst are valid, so the empty path is the only thing
        // that can fail; the message check pins the failure to that rule.
        let err = build_runtime(Arc::new(cfg_with_route("", "10/sec", 5)))
            .err()
            .expect("an empty route path must be rejected");
        assert!(err.to_string().contains("path must not be empty"));
    }
}