Skip to main content

edgeguard/
cp.rs

//! Managed-mode client: talk to a remote control plane.
//!
//! Off by default. When `[control_plane]` is enabled and configured, the edge:
//!   * **pulls its policy** (conditional `GET`, ETag/`304`) and hot-reloads it through the same
//!     `build_runtime` + arc-swap path a local file edit uses;
//!   * **reports usage** (requests + ingress/egress bytes) as periodic deltas;
//!   * **polls its quota verdict** and publishes it to shared state the proxy enforces;
//!   * **forwards CSP reports** it receives to the control plane.
//!
//! This is a generic "pull config / report usage to a URL" client — it carries no control-plane
//! logic; it just speaks the control plane's edge HTTP API with a per-tenant bearer token. Built
//! on the same `reqwest` + rustls stack as the JWKS fetcher (`auth.rs`).
12
13use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::{Context, Result};
18use arc_swap::ArcSwap;
19use bytes::Bytes;
20use serde::{Deserialize, Serialize};
21use tokio::sync::watch;
22use tracing::{info, warn};
23
24use crate::config::{Config, ControlPlaneCfg};
25use crate::metrics::Metrics;
26use crate::proxy::Runtime;
27
28/// A usage delta reported to the control plane (matches its `/v3/edge/{id}/usage` wire shape).
29#[derive(Debug, Clone, Copy, Default, Serialize)]
30pub struct UsageDelta {
31    pub requests: u64,
32    pub ingress_bytes: u64,
33    pub egress_bytes: u64,
34}
35
36/// The subset of the control plane's `PolicyDocument` the edge needs.
37#[derive(Debug, Deserialize)]
38struct PolicyResp {
39    etag: String,
40    body: String,
41}
42
43/// The subset of the control plane's `QuotaStatus` the edge needs to enforce a hard stop.
44#[derive(Debug, Deserialize)]
45struct QuotaResp {
46    over_quota: bool,
47    #[serde(default)]
48    reset_epoch: i64,
49}
50
/// Shared, hot-reload-surviving quota verdict the proxy enforces. The [`quota_loop`] writes it from
/// the control plane's verdict; the proxy reads it per request. It lives in `AppState` (not the
/// hot-swappable `Runtime`) so a policy reload never resets the enforcement state.
///
/// The two fields are separate atomics, not a locked pair. [`QuotaState::apply`] writes the reset
/// hint *before* publishing the flag (with `Release`). [`QuotaState::blocked`] reads the flag with
/// `Acquire`. Together these ensure that a reader which observes a freshly-set block also observes
/// the reset hint that came with it, rather than a stale `0`.
#[derive(Debug, Default)]
pub struct QuotaState {
    /// `true` while the edge is over its quota — the proxy returns `429`.
    pub over_quota: AtomicBool,
    /// Unix second the quota resets (period rollover); `0` = unknown. The `Retry-After` hint.
    pub reset_epoch: AtomicI64,
}

impl QuotaState {
    /// Whether the proxy should currently hard-stop the edge's traffic.
    ///
    /// The `Acquire` load pairs with the `Release` store in [`QuotaState::apply`]. Once this
    /// returns `true`, [`QuotaState::reset_epoch`] returns the hint written with that verdict
    /// (or a newer one).
    pub fn blocked(&self) -> bool {
        self.over_quota.load(Ordering::Acquire)
    }

    /// The reset-epoch hint last reported by the control plane (`0` if none yet).
    pub fn reset_epoch(&self) -> i64 {
        self.reset_epoch.load(Ordering::Relaxed)
    }

    /// Publish a new verdict from the control plane.
    fn apply(&self, over_quota: bool, reset_epoch: i64) {
        // Hint first, flag last: the Release store on the flag orders the hint write before it,
        // so the proxy never sees a new block paired with a missing `Retry-After` value.
        self.reset_epoch.store(reset_epoch, Ordering::Relaxed);
        self.over_quota.store(over_quota, Ordering::Release);
    }
}
78
/// Outcome of a conditional policy pull (`CpClient::pull_policy`).
///
/// Derives the standard comparison/debug traits so callers and tests can log, clone and
/// compare outcomes directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullResult {
    /// The edge's ETag still matched (HTTP `304`) — nothing changed.
    NotModified,
    /// A new policy: its opaque TOML body and the new ETag.
    Policy { body: String, etag: String },
}
86
87/// Outbound client to a control plane's per-tenant edge API.
88pub struct CpClient {
89    http: reqwest::Client,
90    /// `{base}/v3/edge/{tenant}` prefix, already trimmed.
91    edge_base: String,
92    token: String,
93}
94
95impl CpClient {
96    /// Build the client if managed mode is enabled and configured; otherwise `None`. Fails fast on
97    /// an enabled-but-incomplete config so a misconfigured edge doesn't silently run unmanaged.
98    pub fn from_cfg(cfg: &ControlPlaneCfg) -> Result<Option<Arc<CpClient>>> {
99        if !cfg.enabled {
100            return Ok(None);
101        }
102        anyhow::ensure!(
103            !cfg.url.is_empty(),
104            "control_plane.url is required when enabled"
105        );
106        anyhow::ensure!(
107            !cfg.tenant_id.is_empty(),
108            "control_plane.tenant_id is required when enabled"
109        );
110        anyhow::ensure!(
111            !cfg.edge_token.is_empty(),
112            "control_plane.edge_token (or EDGEGUARD_CP_EDGE_TOKEN) is required when enabled"
113        );
114        let http = reqwest::Client::builder()
115            .timeout(Duration::from_secs(10))
116            .build()
117            .context("building control-plane HTTP client")?;
118        let edge_base = format!(
119            "{}/v3/edge/{}",
120            cfg.url.trim_end_matches('/'),
121            cfg.tenant_id
122        );
123        Ok(Some(Arc::new(CpClient {
124            http,
125            edge_base,
126            token: cfg.edge_token.clone(),
127        })))
128    }
129
130    /// Conditional policy pull. `200` → `Policy`; `304` → `NotModified`; other statuses → `Err`.
131    pub async fn pull_policy(&self, etag: Option<&str>) -> Result<PullResult> {
132        let mut req = self
133            .http
134            .get(format!("{}/policy", self.edge_base))
135            .bearer_auth(&self.token);
136        if let Some(e) = etag {
137            req = req.header(reqwest::header::IF_NONE_MATCH, e);
138        }
139        let resp = req.send().await.context("pulling policy")?;
140        match resp.status() {
141            reqwest::StatusCode::NOT_MODIFIED => Ok(PullResult::NotModified),
142            s if s.is_success() => {
143                let doc: PolicyResp = resp.json().await.context("parsing policy document")?;
144                Ok(PullResult::Policy {
145                    body: doc.body,
146                    etag: doc.etag,
147                })
148            }
149            s => anyhow::bail!("control plane returned {s} for policy pull"),
150        }
151    }
152
153    /// Report a usage delta.
154    pub async fn report_usage(&self, delta: &UsageDelta) -> Result<()> {
155        self.http
156            .post(format!("{}/usage", self.edge_base))
157            .bearer_auth(&self.token)
158            .json(delta)
159            .send()
160            .await
161            .context("reporting usage")?
162            .error_for_status()
163            .context("control plane rejected usage report")?;
164        Ok(())
165    }
166
167    /// Pull the tenant's current quota verdict (`over_quota` + `reset_epoch`). Any non-success
168    /// status is an error so the caller keeps the last verdict rather than acting on a partial read.
169    pub async fn pull_quota(&self) -> Result<(bool, i64)> {
170        let resp = self
171            .http
172            .get(format!("{}/quota", self.edge_base))
173            .bearer_auth(&self.token)
174            .send()
175            .await
176            .context("pulling quota")?
177            .error_for_status()
178            .context("control plane rejected quota poll")?;
179        let q: QuotaResp = resp.json().await.context("parsing quota verdict")?;
180        Ok((q.over_quota, q.reset_epoch))
181    }
182
183    /// Forward a raw CSP report body (best-effort; errors are logged, never surfaced).
184    pub async fn forward_csp(&self, raw: &Bytes) {
185        let res = self
186            .http
187            .post(format!("{}/csp-report", self.edge_base))
188            .bearer_auth(&self.token)
189            .header(reqwest::header::CONTENT_TYPE, "application/json")
190            .body(raw.clone())
191            .send()
192            .await;
193        if let Err(e) = res {
194            warn!(error = %e, "forwarding CSP report to control plane failed");
195        }
196    }
197}
198
199/// Sleep for `dur`, returning early (`true`) if shutdown is signalled.
200async fn sleep_or_shutdown(rx: &mut watch::Receiver<bool>, dur: Duration) -> bool {
201    tokio::select! {
202        _ = tokio::time::sleep(dur) => *rx.borrow(),
203        _ = rx.changed() => true,
204    }
205}
206
207/// Background loop: poll the control plane for policy and hot-reload it through `build_runtime` +
208/// the arc-swap, exactly like a local file edit. A parse/build failure keeps the current policy.
209pub async fn poll_loop(
210    client: Arc<CpClient>,
211    base: Arc<Config>,
212    runtime: Arc<ArcSwap<Runtime>>,
213    interval: Duration,
214    mut shutdown: watch::Receiver<bool>,
215) {
216    let mut etag: Option<String> = None;
217    info!(?interval, "control-plane policy poller started");
218    loop {
219        match client.pull_policy(etag.as_deref()).await {
220            Ok(PullResult::NotModified) => {}
221            Ok(PullResult::Policy { body, etag: new }) => {
222                match apply_policy(&base, &body, &runtime) {
223                    Ok(()) => {
224                        etag = Some(new);
225                        info!("applied policy from control plane");
226                    }
227                    Err(e) => warn!(
228                        error = format!("{e:#}"),
229                        "rejected control-plane policy; keeping current"
230                    ),
231                }
232            }
233            Err(e) => warn!(
234                error = format!("{e:#}"),
235                "policy pull failed; keeping current"
236            ),
237        }
238        if sleep_or_shutdown(&mut shutdown, interval).await {
239            break;
240        }
241    }
242}
243
244/// Overlay a pushed policy onto the local base config, rebuild the runtime, and swap it in.
245fn apply_policy(base: &Config, body: &str, runtime: &ArcSwap<Runtime>) -> Result<()> {
246    let merged = base.with_policy_from(body)?;
247    let rt = crate::build_runtime(Arc::new(merged))?;
248    runtime.store(Arc::new(rt));
249    Ok(())
250}
251
252/// Background loop: flush the usage accumulator to the control plane each period. On a failed
253/// report the drained delta is added back so billable usage isn't lost.
254pub async fn report_loop(
255    client: Arc<CpClient>,
256    metrics: Arc<Metrics>,
257    interval: Duration,
258    mut shutdown: watch::Receiver<bool>,
259) {
260    info!(?interval, "control-plane usage reporter started");
261    loop {
262        if sleep_or_shutdown(&mut shutdown, interval).await {
263            break;
264        }
265        let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
266        if requests == 0 && ingress_bytes == 0 && egress_bytes == 0 {
267            continue;
268        }
269        let delta = UsageDelta {
270            requests,
271            ingress_bytes,
272            egress_bytes,
273        };
274        if let Err(e) = client.report_usage(&delta).await {
275            warn!(
276                error = format!("{e:#}"),
277                "usage report failed; will retry next period"
278            );
279            metrics.restore_usage(requests, ingress_bytes, egress_bytes);
280        }
281    }
282    // Best-effort final flush on graceful shutdown so billable usage isn't lost.
283    let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
284    if requests > 0 || ingress_bytes > 0 || egress_bytes > 0 {
285        let delta = UsageDelta {
286            requests,
287            ingress_bytes,
288            egress_bytes,
289        };
290        if let Err(e) = client.report_usage(&delta).await {
291            warn!(
292                error = format!("{e:#}"),
293                "final usage report on shutdown failed"
294            );
295        }
296    }
297}
298
299/// Background loop: poll the control plane for the tenant's quota verdict and publish it to the
300/// shared [`QuotaState`] the proxy enforces. A failed poll keeps the last verdict (fail-static), so
301/// a control-plane blip neither suddenly blocks nor suddenly unblocks the edge.
302pub async fn quota_loop(
303    client: Arc<CpClient>,
304    quota: Arc<QuotaState>,
305    interval: Duration,
306    mut shutdown: watch::Receiver<bool>,
307) {
308    info!(?interval, "control-plane quota poller started");
309    loop {
310        match client.pull_quota().await {
311            Ok((over_quota, reset_epoch)) => {
312                quota.apply(over_quota, reset_epoch);
313            }
314            Err(e) => warn!(
315                error = format!("{e:#}"),
316                "quota poll failed; keeping last verdict"
317            ),
318        }
319        if sleep_or_shutdown(&mut shutdown, interval).await {
320            break;
321        }
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ControlPlaneCfg;
    use std::net::SocketAddr;
    use std::sync::Mutex as StdMutex;

    use axum::{
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
        Json, Router,
    };

    /// ETag the stub control plane serves (quoted, as it appears on the wire).
    const ETAG: &str = "\"abc123\"";
    /// Reset epoch the stub's quota endpoint reports.
    const RESET_EPOCH: i64 = 1_782_864_000;

    /// State shared with the stub server: the last usage body it received, parsed as JSON.
    #[derive(Clone, Default)]
    struct Stub {
        last_usage: Arc<StdMutex<Option<serde_json::Value>>>,
    }

    /// `GET …/policy`: `304` when `If-None-Match` carries [`ETAG`], otherwise a full document.
    async fn policy(headers: HeaderMap) -> axum::response::Response {
        let if_none_match = headers
            .get(axum::http::header::IF_NONE_MATCH)
            .and_then(|v| v.to_str().ok());
        if if_none_match == Some(ETAG) {
            return StatusCode::NOT_MODIFIED.into_response();
        }
        let doc = serde_json::json!({
            "version": 1, "etag": ETAG, "format": "toml",
            "body": "[auth]\nmode = \"none\"\n", "updated_at": 0
        });
        ([(axum::http::header::ETAG, ETAG)], Json(doc)).into_response()
    }

    /// `POST …/usage`: remember the parsed body and answer `202`.
    async fn usage(State(stub): State<Stub>, body: axum::body::Bytes) -> StatusCode {
        let parsed: Option<serde_json::Value> = serde_json::from_slice(&body).ok();
        *stub.last_usage.lock().unwrap() = parsed;
        StatusCode::ACCEPTED
    }

    /// `GET …/quota`: a trimmed `QuotaStatus` — the edge only reads these two fields.
    async fn quota() -> axum::response::Response {
        let status = serde_json::json!({ "over_quota": true, "reset_epoch": RESET_EPOCH });
        Json(status).into_response()
    }

    /// Serve the stub control plane for tenant `t1` on an ephemeral localhost port.
    async fn spawn_stub() -> (SocketAddr, Stub) {
        let stub = Stub::default();
        let router = Router::new()
            .route("/v3/edge/t1/policy", get(policy))
            .route("/v3/edge/t1/usage", post(usage))
            .route("/v3/edge/t1/quota", get(quota))
            .with_state(stub.clone());
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = axum::serve(listener, router).await;
        });
        (addr, stub)
    }

    /// A fully configured client pointed at the stub.
    fn client(addr: SocketAddr) -> Arc<CpClient> {
        let cfg = ControlPlaneCfg {
            enabled: true,
            url: format!("http://{addr}"),
            tenant_id: "t1".into(),
            edge_token: "tok".into(),
            ..Default::default()
        };
        CpClient::from_cfg(&cfg)
            .expect("complete config must build")
            .expect("enabled config must yield a client")
    }

    #[test]
    fn disabled_or_incomplete_config() {
        // Disabled: no client at all.
        let disabled = CpClient::from_cfg(&ControlPlaneCfg::default()).unwrap();
        assert!(disabled.is_none());
        // Enabled without a token: a hard error, never a silently unmanaged edge.
        let no_token = ControlPlaneCfg {
            enabled: true,
            url: "http://x".into(),
            tenant_id: "t1".into(),
            ..Default::default()
        };
        assert!(CpClient::from_cfg(&no_token).is_err());
    }

    #[tokio::test]
    async fn policy_pull_conditional() {
        let (addr, _) = spawn_stub().await;
        let cp = client(addr);
        // Unconditional pull: the full document and its ETag.
        match cp.pull_policy(None).await.unwrap() {
            PullResult::Policy { body, etag } => {
                assert!(body.contains("mode = \"none\""));
                assert_eq!(etag, ETAG);
            }
            PullResult::NotModified => panic!("expected a policy"),
        }
        // Conditional pull with the current ETag: 304 → NotModified.
        let again = cp.pull_policy(Some(ETAG)).await.unwrap();
        assert!(matches!(again, PullResult::NotModified));
    }

    #[tokio::test]
    async fn usage_report_posts_delta() {
        let (addr, stub) = spawn_stub().await;
        let cp = client(addr);
        let delta = UsageDelta {
            requests: 3,
            ingress_bytes: 100,
            egress_bytes: 250,
        };
        cp.report_usage(&delta).await.unwrap();
        let seen = stub
            .last_usage
            .lock()
            .unwrap()
            .clone()
            .expect("stub should have parsed the usage body");
        assert_eq!(seen["requests"], 3);
        assert_eq!(seen["ingress_bytes"], 100);
        assert_eq!(seen["egress_bytes"], 250);
    }

    #[tokio::test]
    async fn quota_pull_returns_verdict() {
        let (addr, _) = spawn_stub().await;
        let cp = client(addr);
        let (over, reset) = cp.pull_quota().await.unwrap();
        assert!(over);
        assert_eq!(reset, RESET_EPOCH);
    }

    #[tokio::test]
    async fn quota_loop_publishes_to_shared_state() {
        let (addr, _) = spawn_stub().await;
        let cp = client(addr);
        let state = Arc::new(QuotaState::default());
        assert!(!state.blocked(), "starts permissive");

        let (stop_tx, stop_rx) = watch::channel(false);
        let poller = tokio::spawn(quota_loop(
            cp,
            Arc::clone(&state),
            Duration::from_millis(50),
            stop_rx,
        ));

        // Wait (bounded: at most 50 × 10 ms) for the first verdict to land.
        let mut attempts = 0;
        while !state.blocked() && attempts < 50 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            attempts += 1;
        }
        assert!(
            state.blocked(),
            "verdict from the control plane should publish"
        );
        assert_eq!(state.reset_epoch(), RESET_EPOCH);

        let _ = stop_tx.send(true);
        let _ = poller.await;
    }
}