// edgeguard/cp.rs
//! Managed-mode client: talk to a remote control plane.
//!
//! Off by default. When `[control_plane]` is enabled and configured, the edge:
//!   * **pulls its policy** (conditional `GET`, ETag/`304`) and hot-reloads it through the same
//!     `build_runtime` + arc-swap path a local file edit uses;
//!   * **reports usage** (requests + ingress/egress bytes) as periodic deltas;
//!   * **forwards CSP reports** it receives to the control plane.
//!
//! This is a generic "pull config / report usage to a URL" client — it carries no control-plane
//! logic; it just speaks the control plane's edge HTTP API with a per-tenant bearer token. Built
//! on the same `reqwest` + rustls stack as the JWKS fetcher (`auth.rs`).

13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{Context, Result};
17use arc_swap::ArcSwap;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use tokio::sync::watch;
21use tracing::{info, warn};
22
23use crate::config::{Config, ControlPlaneCfg};
24use crate::metrics::Metrics;
25use crate::proxy::Runtime;
26
27/// A usage delta reported to the control plane (matches its `/v3/edge/{id}/usage` wire shape).
28#[derive(Debug, Clone, Copy, Default, Serialize)]
29pub struct UsageDelta {
30    pub requests: u64,
31    pub ingress_bytes: u64,
32    pub egress_bytes: u64,
33}
34
35/// The subset of the control plane's `PolicyDocument` the edge needs.
36#[derive(Debug, Deserialize)]
37struct PolicyResp {
38    etag: String,
39    body: String,
40}
41
/// Outcome of a conditional policy pull.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, compare and assert on pull outcomes
/// instead of hand-matching every variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullResult {
    /// The edge's ETag still matched (`304 Not Modified`) — nothing changed.
    NotModified,
    /// A new policy: its opaque TOML body and the new ETag.
    Policy { body: String, etag: String },
}

50/// Outbound client to a control plane's per-tenant edge API.
51pub struct CpClient {
52    http: reqwest::Client,
53    /// `{base}/v3/edge/{tenant}` prefix, already trimmed.
54    edge_base: String,
55    token: String,
56}
57
58impl CpClient {
59    /// Build the client if managed mode is enabled and configured; otherwise `None`. Fails fast on
60    /// an enabled-but-incomplete config so a misconfigured edge doesn't silently run unmanaged.
61    pub fn from_cfg(cfg: &ControlPlaneCfg) -> Result<Option<Arc<CpClient>>> {
62        if !cfg.enabled {
63            return Ok(None);
64        }
65        anyhow::ensure!(
66            !cfg.url.is_empty(),
67            "control_plane.url is required when enabled"
68        );
69        anyhow::ensure!(
70            !cfg.tenant_id.is_empty(),
71            "control_plane.tenant_id is required when enabled"
72        );
73        anyhow::ensure!(
74            !cfg.edge_token.is_empty(),
75            "control_plane.edge_token (or EDGEGUARD_CP_EDGE_TOKEN) is required when enabled"
76        );
77        let http = reqwest::Client::builder()
78            .timeout(Duration::from_secs(10))
79            .build()
80            .context("building control-plane HTTP client")?;
81        let edge_base = format!(
82            "{}/v3/edge/{}",
83            cfg.url.trim_end_matches('/'),
84            cfg.tenant_id
85        );
86        Ok(Some(Arc::new(CpClient {
87            http,
88            edge_base,
89            token: cfg.edge_token.clone(),
90        })))
91    }
92
93    /// Conditional policy pull. `200` → `Policy`; `304` → `NotModified`; other statuses → `Err`.
94    pub async fn pull_policy(&self, etag: Option<&str>) -> Result<PullResult> {
95        let mut req = self
96            .http
97            .get(format!("{}/policy", self.edge_base))
98            .bearer_auth(&self.token);
99        if let Some(e) = etag {
100            req = req.header(reqwest::header::IF_NONE_MATCH, e);
101        }
102        let resp = req.send().await.context("pulling policy")?;
103        match resp.status() {
104            reqwest::StatusCode::NOT_MODIFIED => Ok(PullResult::NotModified),
105            s if s.is_success() => {
106                let doc: PolicyResp = resp.json().await.context("parsing policy document")?;
107                Ok(PullResult::Policy {
108                    body: doc.body,
109                    etag: doc.etag,
110                })
111            }
112            s => anyhow::bail!("control plane returned {s} for policy pull"),
113        }
114    }
115
116    /// Report a usage delta.
117    pub async fn report_usage(&self, delta: &UsageDelta) -> Result<()> {
118        self.http
119            .post(format!("{}/usage", self.edge_base))
120            .bearer_auth(&self.token)
121            .json(delta)
122            .send()
123            .await
124            .context("reporting usage")?
125            .error_for_status()
126            .context("control plane rejected usage report")?;
127        Ok(())
128    }
129
130    /// Forward a raw CSP report body (best-effort; errors are logged, never surfaced).
131    pub async fn forward_csp(&self, raw: &Bytes) {
132        let res = self
133            .http
134            .post(format!("{}/csp-report", self.edge_base))
135            .bearer_auth(&self.token)
136            .header(reqwest::header::CONTENT_TYPE, "application/json")
137            .body(raw.clone())
138            .send()
139            .await;
140        if let Err(e) = res {
141            warn!(error = %e, "forwarding CSP report to control plane failed");
142        }
143    }
144}
145
146/// Sleep for `dur`, returning early (`true`) if shutdown is signalled.
147async fn sleep_or_shutdown(rx: &mut watch::Receiver<bool>, dur: Duration) -> bool {
148    tokio::select! {
149        _ = tokio::time::sleep(dur) => *rx.borrow(),
150        _ = rx.changed() => true,
151    }
152}
153
154/// Background loop: poll the control plane for policy and hot-reload it through `build_runtime` +
155/// the arc-swap, exactly like a local file edit. A parse/build failure keeps the current policy.
156pub async fn poll_loop(
157    client: Arc<CpClient>,
158    base: Arc<Config>,
159    runtime: Arc<ArcSwap<Runtime>>,
160    interval: Duration,
161    mut shutdown: watch::Receiver<bool>,
162) {
163    let mut etag: Option<String> = None;
164    info!(?interval, "control-plane policy poller started");
165    loop {
166        match client.pull_policy(etag.as_deref()).await {
167            Ok(PullResult::NotModified) => {}
168            Ok(PullResult::Policy { body, etag: new }) => {
169                match apply_policy(&base, &body, &runtime) {
170                    Ok(()) => {
171                        etag = Some(new);
172                        info!("applied policy from control plane");
173                    }
174                    Err(e) => warn!(
175                        error = format!("{e:#}"),
176                        "rejected control-plane policy; keeping current"
177                    ),
178                }
179            }
180            Err(e) => warn!(
181                error = format!("{e:#}"),
182                "policy pull failed; keeping current"
183            ),
184        }
185        if sleep_or_shutdown(&mut shutdown, interval).await {
186            break;
187        }
188    }
189}
190
191/// Overlay a pushed policy onto the local base config, rebuild the runtime, and swap it in.
192fn apply_policy(base: &Config, body: &str, runtime: &ArcSwap<Runtime>) -> Result<()> {
193    let merged = base.with_policy_from(body)?;
194    let rt = crate::build_runtime(Arc::new(merged))?;
195    runtime.store(Arc::new(rt));
196    Ok(())
197}
198
199/// Background loop: flush the usage accumulator to the control plane each period. On a failed
200/// report the drained delta is added back so billable usage isn't lost.
201pub async fn report_loop(
202    client: Arc<CpClient>,
203    metrics: Arc<Metrics>,
204    interval: Duration,
205    mut shutdown: watch::Receiver<bool>,
206) {
207    info!(?interval, "control-plane usage reporter started");
208    loop {
209        if sleep_or_shutdown(&mut shutdown, interval).await {
210            break;
211        }
212        let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
213        if requests == 0 && ingress_bytes == 0 && egress_bytes == 0 {
214            continue;
215        }
216        let delta = UsageDelta {
217            requests,
218            ingress_bytes,
219            egress_bytes,
220        };
221        if let Err(e) = client.report_usage(&delta).await {
222            warn!(
223                error = format!("{e:#}"),
224                "usage report failed; will retry next period"
225            );
226            metrics.restore_usage(requests, ingress_bytes, egress_bytes);
227        }
228    }
229    // Best-effort final flush on graceful shutdown so billable usage isn't lost.
230    let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
231    if requests > 0 || ingress_bytes > 0 || egress_bytes > 0 {
232        let delta = UsageDelta { requests, ingress_bytes, egress_bytes };
233        if let Err(e) = client.report_usage(&delta).await {
234            warn!(error = format!("{e:#}"), "final usage report on shutdown failed");
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ControlPlaneCfg;
    use std::net::SocketAddr;
    use std::sync::Mutex as StdMutex;

    use axum::{
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
        Json, Router,
    };

    /// ETag the stub control plane serves for tenant `t1`'s policy.
    const ETAG: &str = "\"abc123\"";

    /// Shared state of the stub control plane: the last usage body it accepted.
    #[derive(Clone, Default)]
    struct Stub {
        last_usage: Arc<StdMutex<Option<serde_json::Value>>>,
    }

    /// Stub `GET .../policy`: `304` for a matching `If-None-Match`, otherwise the full document.
    async fn policy(headers: HeaderMap) -> axum::response::Response {
        // Conditional: a matching If-None-Match gets a 304.
        if headers
            .get(axum::http::header::IF_NONE_MATCH)
            .and_then(|v| v.to_str().ok())
            == Some(ETAG)
        {
            return StatusCode::NOT_MODIFIED.into_response();
        }
        (
            [(axum::http::header::ETAG, ETAG)],
            Json(serde_json::json!({
                "version": 1, "etag": ETAG, "format": "toml",
                "body": "[auth]\nmode = \"none\"\n", "updated_at": 0
            })),
        )
            .into_response()
    }

    /// Stub `POST .../usage`: records the JSON body and answers `202 Accepted`.
    async fn usage(State(s): State<Stub>, body: axum::body::Bytes) -> StatusCode {
        *s.last_usage.lock().unwrap() = serde_json::from_slice(&body).ok();
        StatusCode::ACCEPTED
    }

    /// Start the stub on an ephemeral port. Only tenant `t1` is routed; requests for any other
    /// tenant fall through to axum's default `404`.
    async fn spawn_stub() -> (SocketAddr, Stub) {
        let stub = Stub::default();
        let app = Router::new()
            .route("/v3/edge/t1/policy", get(policy))
            .route("/v3/edge/t1/usage", post(usage))
            .with_state(stub.clone());
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });
        (addr, stub)
    }

    /// Complete, enabled managed-mode config pointing at the stub for `tenant`.
    fn cfg_for(addr: SocketAddr, tenant: &str) -> ControlPlaneCfg {
        ControlPlaneCfg {
            enabled: true,
            url: format!("http://{addr}"),
            tenant_id: tenant.into(),
            edge_token: "tok".into(),
            ..Default::default()
        }
    }

    /// Client for the routed tenant `t1`.
    fn client(addr: SocketAddr) -> Arc<CpClient> {
        CpClient::from_cfg(&cfg_for(addr, "t1")).unwrap().unwrap()
    }

    #[test]
    fn disabled_or_incomplete_config() {
        // Disabled -> no client.
        assert!(CpClient::from_cfg(&ControlPlaneCfg::default())
            .unwrap()
            .is_none());
        // Enabled but missing a token -> hard error (don't silently run unmanaged).
        assert!(CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: "http://x".into(),
            tenant_id: "t1".into(),
            ..Default::default()
        })
        .is_err());
        // Enabled but missing the URL -> hard error.
        assert!(CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: String::new(),
            tenant_id: "t1".into(),
            edge_token: "tok".into(),
            ..Default::default()
        })
        .is_err());
        // Enabled but missing the tenant -> hard error.
        assert!(CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: "http://x".into(),
            tenant_id: "".into(),
            edge_token: "tok".into(),
            ..Default::default()
        })
        .is_err());
    }

    #[tokio::test]
    async fn policy_pull_conditional() {
        let (addr, _) = spawn_stub().await;
        let c = client(addr);
        // First pull (no ETag) returns the policy + its ETag.
        match c.pull_policy(None).await.unwrap() {
            PullResult::Policy { body, etag } => {
                assert!(body.contains("mode = \"none\""));
                assert_eq!(etag, ETAG);
            }
            _ => panic!("expected a policy"),
        }
        // Re-pull with the ETag -> 304 NotModified.
        assert!(matches!(
            c.pull_policy(Some(ETAG)).await.unwrap(),
            PullResult::NotModified
        ));
        // A stale ETag doesn't match, so the full policy comes back again.
        assert!(matches!(
            c.pull_policy(Some("\"stale\"")).await.unwrap(),
            PullResult::Policy { .. }
        ));
    }

    #[tokio::test]
    async fn usage_report_posts_delta() {
        let (addr, stub) = spawn_stub().await;
        let c = client(addr);
        c.report_usage(&UsageDelta {
            requests: 3,
            ingress_bytes: 100,
            egress_bytes: 250,
        })
        .await
        .unwrap();
        let got = stub.last_usage.lock().unwrap().clone().unwrap();
        assert_eq!(got["requests"], 3);
        assert_eq!(got["ingress_bytes"], 100);
        assert_eq!(got["egress_bytes"], 250);
    }

    #[tokio::test]
    async fn non_success_statuses_are_errors() {
        let (addr, _) = spawn_stub().await;
        // The stub only routes tenant `t1`, so every call for another tenant gets a 404.
        let c = CpClient::from_cfg(&cfg_for(addr, "nope")).unwrap().unwrap();
        assert!(c.pull_policy(None).await.is_err());
        assert!(c
            .report_usage(&UsageDelta {
                requests: 1,
                ..Default::default()
            })
            .await
            .is_err());
    }
}