1pub mod acme;
9pub mod auth;
10pub mod config;
11pub mod generate;
12pub mod limiter;
13pub mod metrics;
14pub mod proxy;
15pub mod reload;
16pub mod supervisor;
17pub mod tls;
18pub mod waf;
19
20use std::num::NonZeroU32;
21use std::sync::Arc;
22
23use anyhow::{Context, Result};
24use arc_swap::ArcSwap;
25use axum::{
26 extract::DefaultBodyLimit,
27 routing::{any, get, post},
28 Router,
29};
30use governor::{Quota, RateLimiter};
31use hyper_util::client::legacy::Client;
32use hyper_util::rt::TokioExecutor;
33
34use crate::auth::AuthEngine;
35use crate::config::{parse_duration, parse_rate, parse_size, Config};
36use crate::metrics::Metrics;
37use crate::proxy::{
38 csp_report, metrics_handler, ready, AppState, RouteLimiter, Runtime, StrLimiter,
39};
40
41pub use crate::auth::hash_password;
42
43fn quota(rate: &str, burst: u32) -> Result<Quota> {
47 let (count, period) = parse_rate(rate)?;
48 anyhow::ensure!(count > 0, "rate count must be > 0 (got \"{rate}\")");
49 anyhow::ensure!(burst > 0, "burst must be > 0 (rate \"{rate}\")");
50 let per_cell = period / count;
52 let burst = NonZeroU32::new(burst).unwrap();
53 Ok(Quota::with_period(per_cell)
54 .context("rate too high for a usable replenish interval")?
55 .allow_burst(burst))
56}
57
58pub fn build_runtime(cfg: Arc<Config>) -> Result<Runtime> {
64 let rl = &cfg.ratelimit;
65
66 let store_mode = crate::limiter::StoreMode::parse(&rl.store)?;
71 let use_distributed = rl.enabled && store_mode.is_distributed();
72
73 let distributed = if use_distributed {
74 Some(crate::limiter::DistributedLimiter::build(rl, store_mode)?)
75 } else {
76 None
77 };
78
79 let build_local = rl.enabled && !use_distributed;
81
82 let ip_limiter = if build_local {
83 Some(Arc::new(RateLimiter::keyed(quota(&rl.rate, rl.burst)?)))
84 } else {
85 None
86 };
87
88 let mut route_limiters = Vec::new();
89 if build_local {
90 for route in &rl.routes {
91 anyhow::ensure!(
92 !route.path.is_empty(),
93 "ratelimit.routes[].path must not be empty"
94 );
95 route_limiters.push(RouteLimiter {
96 prefix: route.path.clone(),
97 limiter: Arc::new(RateLimiter::keyed(quota(&route.rate, route.burst)?)),
98 });
99 }
100 }
101
102 let key_limiter: Option<Arc<StrLimiter>> = if build_local && rl.per_key.enabled {
103 Some(Arc::new(RateLimiter::keyed(quota(
104 &rl.per_key.rate,
105 rl.per_key.burst,
106 )?)))
107 } else {
108 None
109 };
110
111 let auth = AuthEngine::build(&cfg.auth)?;
112 let waf = crate::waf::WafEngine::build(&cfg.waf)?;
115
116 let max_body = parse_size(&cfg.validation.max_body)?;
117 let max_response_body = parse_size(&cfg.validation.max_response_body)?;
118 let max_header_bytes = parse_size(&cfg.validation.max_header_bytes)?;
119 let upstream_timeout = parse_duration(&cfg.validation.upstream_timeout)?;
121 let upstream_timeout = (!upstream_timeout.is_zero()).then_some(upstream_timeout);
122
123 Ok(Runtime {
124 upstream_base: Arc::new(cfg.upstream_base()),
125 auth,
126 waf,
127 distributed,
128 ip_limiter,
129 route_limiters,
130 key_limiter,
131 max_body,
132 max_response_body,
133 max_header_bytes,
134 upstream_timeout,
135 cfg,
136 })
137}
138
139pub fn build_state(cfg: Arc<Config>) -> Result<AppState> {
142 let runtime = build_runtime(cfg)?;
143 let client =
144 Client::builder(TokioExecutor::new()).build_http::<http_body_util::Full<bytes::Bytes>>();
145 Ok(AppState {
146 client,
147 metrics: Arc::new(Metrics::new()),
148 runtime: Arc::new(ArcSwap::from_pointee(runtime)),
149 })
150}
151
152pub fn build_router(state: AppState) -> Router {
159 public_routes()
160 .merge(admin_routes())
161 .layer(DefaultBodyLimit::disable())
162 .with_state(state)
163}
164
165pub fn build_public_router(state: AppState) -> Router {
169 public_routes()
170 .layer(DefaultBodyLimit::disable())
171 .with_state(state)
172}
173
174pub fn build_admin_router(state: AppState) -> Router {
179 admin_routes().with_state(state)
180}
181
182fn public_routes() -> Router<AppState> {
185 Router::new()
186 .route(
187 "/__edgeguard/csp-report",
188 post(csp_report).layer(DefaultBodyLimit::max(64 * 1024)),
189 )
190 .fallback(any(proxy::handle))
191}
192
193fn admin_routes() -> Router<AppState> {
195 Router::new()
196 .route("/__edgeguard/health", get(|| async { "ok" }))
197 .route("/__edgeguard/ready", get(ready))
198 .route("/__edgeguard/metrics", get(metrics_handler))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RateLimitCfg, RouteRateLimit};

    /// Config with rate limiting enabled and the given global `rate` and `burst`.
    /// Every other setting comes from `Default`.
    fn cfg_with_ratelimit(rate: &str, burst: u32) -> Config {
        Config {
            ratelimit: RateLimitCfg {
                enabled: true,
                rate: rate.into(),
                burst,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Default config with rate limiting explicitly enabled.
    ///
    /// `build_runtime` only builds local limiters when `ratelimit.enabled` is true.
    /// Enabling it here keeps the limiter tests from silently depending on
    /// `Config::default()`.
    /// NOTE(review): still assumes the default `ratelimit.store` selects the local
    /// (non-distributed) backend — confirm against `config`.
    fn cfg_with_limiting_enabled() -> Config {
        let mut cfg = Config::default();
        cfg.ratelimit.enabled = true;
        cfg
    }

    /// Shorthand for a single per-route rate-limit rule.
    fn route(path: &str, rate: &str, burst: u32) -> RouteRateLimit {
        RouteRateLimit {
            path: path.into(),
            rate: rate.into(),
            burst,
        }
    }

    #[test]
    fn build_state_rejects_zero_rate() {
        assert!(build_state(Arc::new(cfg_with_ratelimit("0/min", 20))).is_err());
    }

    #[test]
    fn build_state_rejects_zero_burst() {
        assert!(build_state(Arc::new(cfg_with_ratelimit("60/min", 0))).is_err());
    }

    #[test]
    fn build_runtime_builds_route_and_key_limiters() {
        let mut cfg = cfg_with_limiting_enabled();
        cfg.ratelimit.routes = vec![route("/api/", "10/sec", 5)];
        cfg.ratelimit.per_key = PerKeyRateLimit {
            enabled: true,
            rate: "1000/hour".into(),
            burst: 100,
        };
        let rt = build_runtime(Arc::new(cfg)).unwrap();
        assert_eq!(rt.route_limiters.len(), 1);
        assert_eq!(rt.route_limiters[0].prefix, "/api/");
        assert!(rt.key_limiter.is_some());
    }

    #[test]
    fn build_runtime_rejects_bad_route_rate() {
        let mut cfg = cfg_with_limiting_enabled();
        cfg.ratelimit.routes = vec![route("/api/", "0/sec", 5)];
        assert!(build_runtime(Arc::new(cfg)).is_err());
    }

    #[test]
    fn build_runtime_rejects_empty_route_path() {
        let mut cfg = cfg_with_limiting_enabled();
        cfg.ratelimit.routes = vec![route("", "10/sec", 5)];
        assert!(build_runtime(Arc::new(cfg)).is_err());
    }
}
259}