1pub mod access;
9pub mod acme;
10pub mod auth;
11pub mod config;
12pub mod cors;
13pub mod cp;
14pub mod doctor;
15pub mod generate;
16pub mod limiter;
17pub mod metrics;
18pub mod proxy;
19pub mod reload;
20pub mod scaffold;
21pub mod supervisor;
22pub mod tls;
23pub mod waf;
24
25use std::num::NonZeroU32;
26use std::sync::Arc;
27
28use anyhow::{Context, Result};
29use arc_swap::ArcSwap;
30use axum::{
31 extract::DefaultBodyLimit,
32 routing::{any, get, post},
33 Router,
34};
35use governor::{Quota, RateLimiter};
36use hyper_util::client::legacy::Client;
37use hyper_util::rt::TokioExecutor;
38
39use crate::auth::AuthEngine;
40use crate::config::{parse_duration, parse_rate, parse_size, Config};
41use crate::metrics::Metrics;
42use crate::proxy::{
43 csp_report, metrics_handler, ready, AppState, RouteLimiter, Runtime, StrLimiter,
44};
45
46pub use crate::auth::hash_password;
47
48fn quota(rate: &str, burst: u32) -> Result<Quota> {
52 let (count, period) = parse_rate(rate)?;
53 anyhow::ensure!(count > 0, "rate count must be > 0 (got \"{rate}\")");
54 anyhow::ensure!(burst > 0, "burst must be > 0 (rate \"{rate}\")");
55 let per_cell = period / count;
57 let burst = NonZeroU32::new(burst).unwrap();
58 Ok(Quota::with_period(per_cell)
59 .context("rate too high for a usable replenish interval")?
60 .allow_burst(burst))
61}
62
63pub fn build_runtime(cfg: Arc<Config>) -> Result<Runtime> {
69 let rl = &cfg.ratelimit;
70
71 let store_mode = crate::limiter::StoreMode::parse(&rl.store)?;
76 let use_distributed = rl.enabled && store_mode.is_distributed();
77
78 let distributed = if use_distributed {
79 Some(crate::limiter::DistributedLimiter::build(rl, store_mode)?)
80 } else {
81 None
82 };
83
84 let build_local = rl.enabled && !use_distributed;
86
87 let ip_limiter = if build_local {
88 Some(Arc::new(RateLimiter::keyed(quota(&rl.rate, rl.burst)?)))
89 } else {
90 None
91 };
92
93 let mut route_limiters = Vec::new();
94 if build_local {
95 for route in &rl.routes {
96 anyhow::ensure!(
97 !route.path.is_empty(),
98 "ratelimit.routes[].path must not be empty"
99 );
100 route_limiters.push(RouteLimiter {
101 prefix: route.path.clone(),
102 limiter: Arc::new(RateLimiter::keyed(quota(&route.rate, route.burst)?)),
103 });
104 }
105 }
106
107 let key_limiter: Option<Arc<StrLimiter>> = if build_local && rl.per_key.enabled {
108 Some(Arc::new(RateLimiter::keyed(quota(
109 &rl.per_key.rate,
110 rl.per_key.burst,
111 )?)))
112 } else {
113 None
114 };
115
116 let mut upstream_routes = Vec::with_capacity(cfg.upstreams.len());
119 for route in &cfg.upstreams {
120 anyhow::ensure!(!route.path.is_empty(), "upstreams[].path must not be empty");
121 anyhow::ensure!(
125 route.path.starts_with('/'),
126 "upstreams[].path must start with '/' (got {:?})",
127 route.path
128 );
129 anyhow::ensure!(
130 !route.target.is_empty(),
131 "upstreams[].target must not be empty (path {:?})",
132 route.path
133 );
134 let base = route.target.trim_end_matches('/').to_string();
135 upstream_routes.push((route.path.clone(), std::sync::Arc::new(base)));
136 }
137
138 let auth = AuthEngine::build(&cfg.auth)?;
139 let waf = crate::waf::WafEngine::build(&cfg.waf)?;
142 let cors = crate::cors::CorsPolicy::build(&cfg.cors)?;
145 let access = crate::access::AccessPolicy::build(&cfg.access)?;
147
148 let max_body = parse_size(&cfg.validation.max_body)?;
149 let max_response_body = parse_size(&cfg.validation.max_response_body)?;
150 let max_header_bytes = parse_size(&cfg.validation.max_header_bytes)?;
151 let upstream_timeout = parse_duration(&cfg.validation.upstream_timeout)?;
153 let upstream_timeout = (!upstream_timeout.is_zero()).then_some(upstream_timeout);
154
155 Ok(Runtime {
156 upstream_base: Arc::new(cfg.upstream_base()),
157 upstream_routes,
158 auth,
159 waf,
160 cors,
161 access,
162 distributed,
163 ip_limiter,
164 route_limiters,
165 key_limiter,
166 max_body,
167 max_response_body,
168 max_header_bytes,
169 upstream_timeout,
170 stream_passthrough: cfg.validation.stream_passthrough,
171 websocket_passthrough: cfg.validation.websocket_passthrough,
172 cfg,
173 })
174}
175
176pub fn build_state(cfg: Arc<Config>) -> Result<AppState> {
179 let cp = crate::cp::CpClient::from_cfg(&cfg.control_plane)?;
181 let runtime = build_runtime(cfg)?;
182 let client =
183 Client::builder(TokioExecutor::new()).build_http::<http_body_util::Full<bytes::Bytes>>();
184 Ok(AppState {
185 client,
186 metrics: Arc::new(Metrics::new()),
187 runtime: Arc::new(ArcSwap::from_pointee(runtime)),
188 cp,
189 quota: Arc::new(crate::cp::QuotaState::default()),
190 })
191}
192
193pub fn build_router(state: AppState) -> Router {
200 let router = public_routes()
201 .merge(admin_routes())
202 .layer(DefaultBodyLimit::disable());
203 maybe_compress(router, &state).with_state(state)
204}
205
206fn maybe_compress(router: Router<AppState>, state: &AppState) -> Router<AppState> {
211 use tower_http::compression::predicate::{DefaultPredicate, NotForContentType, Predicate};
212 use tower_http::compression::CompressionLayer;
213
214 if !state.runtime.load().cfg.validation.compress_responses {
215 return router;
216 }
217 let predicate = DefaultPredicate::new().and(NotForContentType::const_new("text/event-stream"));
218 router.layer(CompressionLayer::new().compress_when(predicate))
219}
220
221pub fn build_public_router(state: AppState) -> Router {
225 let router = public_routes().layer(DefaultBodyLimit::disable());
226 maybe_compress(router, &state).with_state(state)
227}
228
229pub fn build_admin_router(state: AppState) -> Router {
234 admin_routes().with_state(state)
235}
236
237fn public_routes() -> Router<AppState> {
240 Router::new()
241 .route(
242 "/__edgeguard/csp-report",
243 post(csp_report).layer(DefaultBodyLimit::max(64 * 1024)),
244 )
245 .fallback(any(proxy::handle))
246}
247
248fn admin_routes() -> Router<AppState> {
250 Router::new()
251 .route("/__edgeguard/health", get(|| async { "ok" }))
252 .route("/__edgeguard/ready", get(ready))
253 .route("/__edgeguard/metrics", get(metrics_handler))
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RateLimitCfg, RouteRateLimit, UpstreamRoute};

    /// Default config with rate limiting explicitly enabled and the given
    /// global rate and burst.
    fn cfg_with_ratelimit(rate: &str, burst: u32) -> Config {
        let ratelimit = RateLimitCfg {
            enabled: true,
            rate: rate.into(),
            burst,
            ..Default::default()
        };
        Config {
            ratelimit,
            ..Default::default()
        }
    }

    /// Per-route limit on `/api/` with burst 5 and the given rate.
    fn api_route_limit(rate: &str) -> RouteRateLimit {
        RouteRateLimit {
            path: "/api/".into(),
            rate: rate.into(),
            burst: 5,
        }
    }

    /// Default config with a single upstream route.
    fn cfg_with_upstream(path: &str, target: &str) -> Config {
        Config {
            upstreams: vec![UpstreamRoute {
                path: path.into(),
                target: target.into(),
            }],
            ..Default::default()
        }
    }

    #[test]
    fn build_state_rejects_zero_rate() {
        let cfg = Arc::new(cfg_with_ratelimit("0/min", 20));
        assert!(build_state(cfg).is_err());
    }

    #[test]
    fn build_state_rejects_zero_burst() {
        let cfg = Arc::new(cfg_with_ratelimit("60/min", 0));
        assert!(build_state(cfg).is_err());
    }

    #[test]
    fn build_runtime_builds_route_and_key_limiters() {
        // Starts from `Config::default()` so the default rate-limit settings
        // stay in effect for everything not overridden here.
        let mut cfg = Config::default();
        cfg.ratelimit.routes = vec![api_route_limit("10/sec")];
        cfg.ratelimit.per_key = PerKeyRateLimit {
            enabled: true,
            rate: "1000/hour".into(),
            burst: 100,
        };

        let rt = build_runtime(Arc::new(cfg)).unwrap();

        assert_eq!(rt.route_limiters.len(), 1);
        assert_eq!(rt.route_limiters[0].prefix, "/api/");
        assert!(rt.key_limiter.is_some());
    }

    #[test]
    fn build_runtime_rejects_bad_route_rate() {
        let mut cfg = Config::default();
        cfg.ratelimit.routes = vec![api_route_limit("0/sec")];

        assert!(build_runtime(Arc::new(cfg)).is_err());
    }

    #[test]
    fn build_runtime_validates_upstream_route_paths() {
        // A path without a leading slash is rejected.
        let bad = cfg_with_upstream("api/", "http://api:4000");
        assert!(build_runtime(Arc::new(bad)).is_err());

        // A valid route is accepted and its trailing slash is trimmed.
        let ok = cfg_with_upstream("/api/", "http://api:4000/");
        let rt = build_runtime(Arc::new(ok)).unwrap();
        assert_eq!(rt.pick_upstream("/api/x"), "http://api:4000");
        assert_eq!(rt.pick_upstream("/other"), rt.upstream_base.as_str());
    }
}