Skip to main content

edgeguard/
lib.rs

1//! EdgeGuard library surface.
2//!
3//! The `edgeguard` binary (`src/main.rs`) is a thin CLI on top of this crate. Exposing the
4//! pipeline as a library lets integration tests drive the *same* `build_state` /
5//! `build_router` entry points the binary uses, so tests exercise the real request path
6//! rather than a reimplementation of it.
7
8pub mod access;
9pub mod acme;
10pub mod auth;
11pub mod config;
12pub mod cors;
13pub mod cp;
14pub mod doctor;
15pub mod generate;
16pub mod limiter;
17pub mod metrics;
18pub mod proxy;
19pub mod reload;
20pub mod scaffold;
21pub mod supervisor;
22pub mod tls;
23pub mod waf;
24
25use std::num::NonZeroU32;
26use std::sync::Arc;
27
28use anyhow::{Context, Result};
29use arc_swap::ArcSwap;
30use axum::{
31    extract::DefaultBodyLimit,
32    routing::{any, get, post},
33    Router,
34};
35use governor::{Quota, RateLimiter};
36use hyper_util::client::legacy::Client;
37use hyper_util::rt::TokioExecutor;
38
39use crate::auth::AuthEngine;
40use crate::config::{parse_duration, parse_rate, parse_size, Config};
41use crate::metrics::Metrics;
42use crate::proxy::{
43    csp_report, metrics_handler, ready, AppState, RouteLimiter, Runtime, StrLimiter,
44};
45
46pub use crate::auth::hash_password;
47
48/// Translate a `rate`/`burst` policy into a GCRA [`Quota`]. Rejects degenerate input (a `0`
49/// rate or burst) rather than silently coercing it to `1/1`, which would mask the operator's
50/// mistake. Shared by the global, per-route, and per-key limiters.
51fn quota(rate: &str, burst: u32) -> Result<Quota> {
52    let (count, period) = parse_rate(rate)?;
53    anyhow::ensure!(count > 0, "rate count must be > 0 (got \"{rate}\")");
54    anyhow::ensure!(burst > 0, "burst must be > 0 (rate \"{rate}\")");
55    // One cell replenishes every (period / count); burst is the bucket depth.
56    let per_cell = period / count;
57    let burst = NonZeroU32::new(burst).unwrap();
58    Ok(Quota::with_period(per_cell)
59        .context("rate too high for a usable replenish interval")?
60        .allow_burst(burst))
61}
62
63/// Build the hot-swappable [`Runtime`] from a fully-resolved [`Config`]: the rate limiters
64/// (global per-IP, per-route, per-key), the auth engine, and the parsed size/timeout limits.
65/// Errors if any size/rate/auth setting is invalid, so a bad config fails fast — at startup
66/// or on reload — rather than per-request. The HTTP client and metric registry live outside
67/// the runtime (in [`AppState`]) so a reload preserves the connection pool and counters.
68pub fn build_runtime(cfg: Arc<Config>) -> Result<Runtime> {
69    let rl = &cfg.ratelimit;
70
71    // Pick the limiter backend. `local` keeps the in-process `governor` limiters below; a
72    // distributed store (`memory`/`redis`) builds a shared-store limiter instead, so the two
73    // are mutually exclusive. An unknown store value fails here rather than silently disabling
74    // limiting.
75    let store_mode = crate::limiter::StoreMode::parse(&rl.store)?;
76    let use_distributed = rl.enabled && store_mode.is_distributed();
77
78    let distributed = if use_distributed {
79        Some(crate::limiter::DistributedLimiter::build(rl, store_mode)?)
80    } else {
81        None
82    };
83
84    // The local `governor` limiters are built only when not using a shared store.
85    let build_local = rl.enabled && !use_distributed;
86
87    let ip_limiter = if build_local {
88        Some(Arc::new(RateLimiter::keyed(quota(&rl.rate, rl.burst)?)))
89    } else {
90        None
91    };
92
93    let mut route_limiters = Vec::new();
94    if build_local {
95        for route in &rl.routes {
96            anyhow::ensure!(
97                !route.path.is_empty(),
98                "ratelimit.routes[].path must not be empty"
99            );
100            route_limiters.push(RouteLimiter {
101                prefix: route.path.clone(),
102                limiter: Arc::new(RateLimiter::keyed(quota(&route.rate, route.burst)?)),
103            });
104        }
105    }
106
107    let key_limiter: Option<Arc<StrLimiter>> = if build_local && rl.per_key.enabled {
108        Some(Arc::new(RateLimiter::keyed(quota(
109            &rl.per_key.rate,
110            rl.per_key.burst,
111        )?)))
112    } else {
113        None
114    };
115
116    // Per-path upstream overrides ([[upstreams]]): normalize each target like `upstream_base`
117    // (trim a trailing '/'). A bad/empty entry fails here so it surfaces at startup/reload.
118    let mut upstream_routes = Vec::with_capacity(cfg.upstreams.len());
119    for route in &cfg.upstreams {
120        anyhow::ensure!(!route.path.is_empty(), "upstreams[].path must not be empty");
121        // The path is a URL path prefix matched against request paths (which start with '/'), so a
122        // value like "api/" could never match — reject it at startup instead of silently routing
123        // everything to the default upstream.
124        anyhow::ensure!(
125            route.path.starts_with('/'),
126            "upstreams[].path must start with '/' (got {:?})",
127            route.path
128        );
129        anyhow::ensure!(
130            !route.target.is_empty(),
131            "upstreams[].target must not be empty (path {:?})",
132            route.path
133        );
134        let base = route.target.trim_end_matches('/').to_string();
135        upstream_routes.push((route.path.clone(), std::sync::Arc::new(base)));
136    }
137
138    let auth = AuthEngine::build(&cfg.auth)?;
139    // Compile the WAF here too, so a bad custom pattern fails fast at startup/reload rather
140    // than per-request (and a broken hot-reload keeps the previous policy).
141    let waf = crate::waf::WafEngine::build(&cfg.waf)?;
142    // Compile the CORS policy (None when disabled). An incoherent policy — credentialed
143    // wildcard, enabled-but-no-origins — fails here, so it's caught at startup/reload.
144    let cors = crate::cors::CorsPolicy::build(&cfg.cors)?;
145    // Compile the IP allow/deny lists (None when both empty). A bad CIDR fails here.
146    let access = crate::access::AccessPolicy::build(&cfg.access)?;
147
148    let max_body = parse_size(&cfg.validation.max_body)?;
149    let max_response_body = parse_size(&cfg.validation.max_response_body)?;
150    let max_header_bytes = parse_size(&cfg.validation.max_header_bytes)?;
151    // A zero duration ("0") means "no timeout".
152    let upstream_timeout = parse_duration(&cfg.validation.upstream_timeout)?;
153    let upstream_timeout = (!upstream_timeout.is_zero()).then_some(upstream_timeout);
154
155    Ok(Runtime {
156        upstream_base: Arc::new(cfg.upstream_base()),
157        upstream_routes,
158        auth,
159        waf,
160        cors,
161        access,
162        distributed,
163        ip_limiter,
164        route_limiters,
165        key_limiter,
166        max_body,
167        max_response_body,
168        max_header_bytes,
169        upstream_timeout,
170        stream_passthrough: cfg.validation.stream_passthrough,
171        websocket_passthrough: cfg.validation.websocket_passthrough,
172        cfg,
173    })
174}
175
176/// Build the shared [`AppState`]: a fresh [`Runtime`] wrapped in an [`ArcSwap`] for
177/// hot-reload, the upstream HTTP client, and the metric registry.
178pub fn build_state(cfg: Arc<Config>) -> Result<AppState> {
179    // Build the managed-mode client (if `[control_plane]` is enabled) before `cfg` is consumed.
180    let cp = crate::cp::CpClient::from_cfg(&cfg.control_plane)?;
181    let runtime = build_runtime(cfg)?;
182    let client =
183        Client::builder(TokioExecutor::new()).build_http::<http_body_util::Full<bytes::Bytes>>();
184    Ok(AppState {
185        client,
186        metrics: Arc::new(Metrics::new()),
187        runtime: Arc::new(ArcSwap::from_pointee(runtime)),
188        cp,
189        quota: Arc::new(crate::cp::QuotaState::default()),
190    })
191}
192
193/// Build the combined axum [`Router`]: the internal `/__edgeguard/*` endpoints (health,
194/// readiness, Prometheus metrics, CSP report sink) plus the catch-all proxy handler, all on one
195/// listener. This is the default (single-port) topology; for the public/private split see
196/// [`build_public_router`] / [`build_admin_router`]. Body limits are enforced inside the proxy
197/// handler, so the default layer is disabled there; the CSP sink keeps a small explicit cap
198/// since it parses the body.
199pub fn build_router(state: AppState) -> Router {
200    let router = public_routes()
201        .merge(admin_routes())
202        .layer(DefaultBodyLimit::disable());
203    maybe_compress(router, &state).with_state(state)
204}
205
206/// Optionally wrap the router in a gzip [`CompressionLayer`] when `validation.compress_responses`
207/// is set. Compression is a listener-level concern (not hot-reloadable), so it reads the *initial*
208/// config from `state`. The predicate excludes `text/event-stream` so SSE streaming is never held
209/// back by the compressor (on top of the default skip-small / skip-already-compressed rules).
210fn maybe_compress(router: Router<AppState>, state: &AppState) -> Router<AppState> {
211    use tower_http::compression::predicate::{DefaultPredicate, NotForContentType, Predicate};
212    use tower_http::compression::CompressionLayer;
213
214    if !state.runtime.load().cfg.validation.compress_responses {
215        return router;
216    }
217    let predicate = DefaultPredicate::new().and(NotForContentType::const_new("text/event-stream"));
218    router.layer(CompressionLayer::new().compress_when(predicate))
219}
220
221/// The **public** router (used in public/private split mode): the catch-all proxy plus the
222/// browser-facing CSP report sink. The ops endpoints (health/readiness/metrics) are *not* here
223/// — they live on the private [`build_admin_router`] listener, so they aren't exposed publicly.
224pub fn build_public_router(state: AppState) -> Router {
225    let router = public_routes().layer(DefaultBodyLimit::disable());
226    maybe_compress(router, &state).with_state(state)
227}
228
229/// The **private/admin** router (used in public/private split mode): the internal ops endpoints
230/// (health, readiness, metrics). It has no proxy fallback, so an unknown path returns `404`
231/// rather than being forwarded upstream. Shares the same [`AppState`] as the public router, so
232/// `/__edgeguard/metrics` reports the live proxy counters.
233pub fn build_admin_router(state: AppState) -> Router {
234    admin_routes().with_state(state)
235}
236
237/// Public-surface routes: the proxy fallback and the CSP report sink (which browsers POST to
238/// from the public web, so it stays on the public listener).
239fn public_routes() -> Router<AppState> {
240    Router::new()
241        .route(
242            "/__edgeguard/csp-report",
243            post(csp_report).layer(DefaultBodyLimit::max(64 * 1024)),
244        )
245        .fallback(any(proxy::handle))
246}
247
248/// Internal ops routes: liveness, readiness, and the Prometheus metrics scrape.
249fn admin_routes() -> Router<AppState> {
250    Router::new()
251        .route("/__edgeguard/health", get(|| async { "ok" }))
252        .route("/__edgeguard/ready", get(ready))
253        .route("/__edgeguard/metrics", get(metrics_handler))
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PerKeyRateLimit, RateLimitCfg, RouteRateLimit, UpstreamRoute};

    /// A default config whose global limiter is enabled at the given `rate`/`burst`.
    fn cfg_with_ratelimit(rate: &str, burst: u32) -> Config {
        let ratelimit = RateLimitCfg {
            enabled: true,
            rate: rate.into(),
            burst,
            ..Default::default()
        };
        Config {
            ratelimit,
            ..Default::default()
        }
    }

    /// `Config::default()` plus a single per-route limiter on `/api/` (burst 5) at `rate`.
    fn cfg_with_api_route(rate: &'static str) -> Config {
        let mut cfg = Config::default();
        cfg.ratelimit.routes = vec![RouteRateLimit {
            path: "/api/".into(),
            rate: rate.into(),
            burst: 5,
        }];
        cfg
    }

    /// A default config with exactly one `[[upstreams]]` entry.
    fn cfg_with_upstream(path: &'static str, target: &'static str) -> Config {
        Config {
            upstreams: vec![UpstreamRoute {
                path: path.into(),
                target: target.into(),
            }],
            ..Default::default()
        }
    }

    #[test]
    fn build_state_rejects_zero_rate() {
        // "0/min" is a misconfiguration, not a synonym for "1/min". Validation fails before the
        // HTTP client is built, so this test needs no async runtime.
        let result = build_state(Arc::new(cfg_with_ratelimit("0/min", 20)));
        assert!(result.is_err());
    }

    #[test]
    fn build_state_rejects_zero_burst() {
        let result = build_state(Arc::new(cfg_with_ratelimit("60/min", 0)));
        assert!(result.is_err());
    }

    #[test]
    fn build_runtime_builds_route_and_key_limiters() {
        let mut cfg = cfg_with_api_route("10/sec");
        cfg.ratelimit.per_key = PerKeyRateLimit {
            enabled: true,
            rate: "1000/hour".into(),
            burst: 100,
        };

        let rt = build_runtime(Arc::new(cfg)).expect("valid route + per-key limiter config");
        assert_eq!(rt.route_limiters.len(), 1);
        assert_eq!(rt.route_limiters[0].prefix, "/api/");
        assert!(rt.key_limiter.is_some());
    }

    #[test]
    fn build_runtime_rejects_bad_route_rate() {
        let cfg = cfg_with_api_route("0/sec");
        assert!(build_runtime(Arc::new(cfg)).is_err());
    }

    #[test]
    fn build_runtime_validates_upstream_route_paths() {
        // Without a leading '/', "api/" could never match a request path.
        let missing_slash = cfg_with_upstream("api/", "http://api:4000");
        assert!(build_runtime(Arc::new(missing_slash)).is_err());

        // A well-formed route compiles, and its target loses the trailing slash.
        let well_formed = cfg_with_upstream("/api/", "http://api:4000/");
        let rt = build_runtime(Arc::new(well_formed)).expect("well-formed upstream route");
        assert_eq!(rt.pick_upstream("/api/x"), "http://api:4000");
        assert_eq!(rt.pick_upstream("/other"), rt.upstream_base.as_str());
    }
}