1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::{Context, Result};
18use arc_swap::ArcSwap;
19use bytes::Bytes;
20use serde::{Deserialize, Serialize};
21use tokio::sync::watch;
22use tracing::{info, warn};
23
24use crate::config::{Config, ControlPlaneCfg};
25use crate::metrics::Metrics;
26use crate::proxy::Runtime;
27
28#[derive(Debug, Clone, Copy, Default, Serialize)]
30pub struct UsageDelta {
31 pub requests: u64,
32 pub ingress_bytes: u64,
33 pub egress_bytes: u64,
34}
35
36#[derive(Debug, Deserialize)]
38struct PolicyResp {
39 etag: String,
40 body: String,
41}
42
43#[derive(Debug, Deserialize)]
45struct QuotaResp {
46 over_quota: bool,
47 #[serde(default)]
48 reset_epoch: i64,
49}
50
/// Latest quota verdict, shared between the quota poller and its readers.
///
/// Both fields are atomics so the verdict can be read without locking. The
/// default state is "not over quota" with a reset epoch of `0`.
#[derive(Debug, Default)]
pub struct QuotaState {
    /// `true` while the last verdict received said the tenant is over quota.
    pub over_quota: AtomicBool,
    /// Reset timestamp that accompanied the last verdict (presumably Unix
    /// seconds — confirm against the control-plane API).
    pub reset_epoch: AtomicI64,
}

impl QuotaState {
    /// Returns `true` when the most recent verdict was "over quota".
    ///
    /// The load uses `Acquire` so that it pairs with the `Release` store in
    /// `apply`: a caller that sees the flag written by some `apply` call and
    /// then calls [`QuotaState::reset_epoch`] reads the epoch from that call
    /// (or a later one), never an earlier one.
    pub fn blocked(&self) -> bool {
        self.over_quota.load(Ordering::Acquire)
    }

    /// Returns the reset timestamp that came with the most recent verdict.
    pub fn reset_epoch(&self) -> i64 {
        self.reset_epoch.load(Ordering::Relaxed)
    }

    /// Publishes a new verdict.
    ///
    /// The epoch is written first and the flag is stored last with `Release`,
    /// so the flag acts as the publication point for the pair. Writing the
    /// flag first would let a reader see "blocked" with the previous epoch.
    fn apply(&self, over_quota: bool, reset_epoch: i64) {
        self.reset_epoch.store(reset_epoch, Ordering::Relaxed);
        self.over_quota.store(over_quota, Ordering::Release);
    }
}
78
/// Outcome of a conditional policy pull (see [`CpClient::pull_policy`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullResult {
    /// The control plane answered `304 Not Modified` for the supplied entity tag.
    NotModified,
    /// A policy document was returned.
    Policy {
        /// Text of the policy document.
        body: String,
        /// Entity tag reported for this document; pass it to the next pull
        /// to make that request conditional.
        etag: String,
    },
}
86
87pub struct CpClient {
89 http: reqwest::Client,
90 edge_base: String,
92 token: String,
93}
94
95impl CpClient {
96 pub fn from_cfg(cfg: &ControlPlaneCfg) -> Result<Option<Arc<CpClient>>> {
99 if !cfg.enabled {
100 return Ok(None);
101 }
102 anyhow::ensure!(
103 !cfg.url.is_empty(),
104 "control_plane.url is required when enabled"
105 );
106 anyhow::ensure!(
107 !cfg.tenant_id.is_empty(),
108 "control_plane.tenant_id is required when enabled"
109 );
110 anyhow::ensure!(
111 !cfg.edge_token.is_empty(),
112 "control_plane.edge_token (or EDGEGUARD_CP_EDGE_TOKEN) is required when enabled"
113 );
114 let http = reqwest::Client::builder()
115 .timeout(Duration::from_secs(10))
116 .build()
117 .context("building control-plane HTTP client")?;
118 let edge_base = format!(
119 "{}/v3/edge/{}",
120 cfg.url.trim_end_matches('/'),
121 cfg.tenant_id
122 );
123 Ok(Some(Arc::new(CpClient {
124 http,
125 edge_base,
126 token: cfg.edge_token.clone(),
127 })))
128 }
129
130 pub async fn pull_policy(&self, etag: Option<&str>) -> Result<PullResult> {
132 let mut req = self
133 .http
134 .get(format!("{}/policy", self.edge_base))
135 .bearer_auth(&self.token);
136 if let Some(e) = etag {
137 req = req.header(reqwest::header::IF_NONE_MATCH, e);
138 }
139 let resp = req.send().await.context("pulling policy")?;
140 match resp.status() {
141 reqwest::StatusCode::NOT_MODIFIED => Ok(PullResult::NotModified),
142 s if s.is_success() => {
143 let doc: PolicyResp = resp.json().await.context("parsing policy document")?;
144 Ok(PullResult::Policy {
145 body: doc.body,
146 etag: doc.etag,
147 })
148 }
149 s => anyhow::bail!("control plane returned {s} for policy pull"),
150 }
151 }
152
153 pub async fn report_usage(&self, delta: &UsageDelta) -> Result<()> {
155 self.http
156 .post(format!("{}/usage", self.edge_base))
157 .bearer_auth(&self.token)
158 .json(delta)
159 .send()
160 .await
161 .context("reporting usage")?
162 .error_for_status()
163 .context("control plane rejected usage report")?;
164 Ok(())
165 }
166
167 pub async fn pull_quota(&self) -> Result<(bool, i64)> {
170 let resp = self
171 .http
172 .get(format!("{}/quota", self.edge_base))
173 .bearer_auth(&self.token)
174 .send()
175 .await
176 .context("pulling quota")?
177 .error_for_status()
178 .context("control plane rejected quota poll")?;
179 let q: QuotaResp = resp.json().await.context("parsing quota verdict")?;
180 Ok((q.over_quota, q.reset_epoch))
181 }
182
183 pub async fn forward_csp(&self, raw: &Bytes) {
185 let res = self
186 .http
187 .post(format!("{}/csp-report", self.edge_base))
188 .bearer_auth(&self.token)
189 .header(reqwest::header::CONTENT_TYPE, "application/json")
190 .body(raw.clone())
191 .send()
192 .await;
193 if let Err(e) = res {
194 warn!(error = %e, "forwarding CSP report to control plane failed");
195 }
196 }
197}
198
199async fn sleep_or_shutdown(rx: &mut watch::Receiver<bool>, dur: Duration) -> bool {
201 tokio::select! {
202 _ = tokio::time::sleep(dur) => *rx.borrow(),
203 _ = rx.changed() => true,
204 }
205}
206
207pub async fn poll_loop(
210 client: Arc<CpClient>,
211 base: Arc<Config>,
212 runtime: Arc<ArcSwap<Runtime>>,
213 interval: Duration,
214 mut shutdown: watch::Receiver<bool>,
215) {
216 let mut etag: Option<String> = None;
217 info!(?interval, "control-plane policy poller started");
218 loop {
219 match client.pull_policy(etag.as_deref()).await {
220 Ok(PullResult::NotModified) => {}
221 Ok(PullResult::Policy { body, etag: new }) => {
222 match apply_policy(&base, &body, &runtime) {
223 Ok(()) => {
224 etag = Some(new);
225 info!("applied policy from control plane");
226 }
227 Err(e) => warn!(
228 error = format!("{e:#}"),
229 "rejected control-plane policy; keeping current"
230 ),
231 }
232 }
233 Err(e) => warn!(
234 error = format!("{e:#}"),
235 "policy pull failed; keeping current"
236 ),
237 }
238 if sleep_or_shutdown(&mut shutdown, interval).await {
239 break;
240 }
241 }
242}
243
244fn apply_policy(base: &Config, body: &str, runtime: &ArcSwap<Runtime>) -> Result<()> {
246 let merged = base.with_policy_from(body)?;
247 let rt = crate::build_runtime(Arc::new(merged))?;
248 runtime.store(Arc::new(rt));
249 Ok(())
250}
251
252pub async fn report_loop(
255 client: Arc<CpClient>,
256 metrics: Arc<Metrics>,
257 interval: Duration,
258 mut shutdown: watch::Receiver<bool>,
259) {
260 info!(?interval, "control-plane usage reporter started");
261 loop {
262 if sleep_or_shutdown(&mut shutdown, interval).await {
263 break;
264 }
265 let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
266 if requests == 0 && ingress_bytes == 0 && egress_bytes == 0 {
267 continue;
268 }
269 let delta = UsageDelta {
270 requests,
271 ingress_bytes,
272 egress_bytes,
273 };
274 if let Err(e) = client.report_usage(&delta).await {
275 warn!(
276 error = format!("{e:#}"),
277 "usage report failed; will retry next period"
278 );
279 metrics.restore_usage(requests, ingress_bytes, egress_bytes);
280 }
281 }
282 let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
284 if requests > 0 || ingress_bytes > 0 || egress_bytes > 0 {
285 let delta = UsageDelta {
286 requests,
287 ingress_bytes,
288 egress_bytes,
289 };
290 if let Err(e) = client.report_usage(&delta).await {
291 warn!(
292 error = format!("{e:#}"),
293 "final usage report on shutdown failed"
294 );
295 }
296 }
297}
298
299pub async fn quota_loop(
303 client: Arc<CpClient>,
304 quota: Arc<QuotaState>,
305 interval: Duration,
306 mut shutdown: watch::Receiver<bool>,
307) {
308 info!(?interval, "control-plane quota poller started");
309 loop {
310 match client.pull_quota().await {
311 Ok((over_quota, reset_epoch)) => {
312 quota.apply(over_quota, reset_epoch);
313 }
314 Err(e) => warn!(
315 error = format!("{e:#}"),
316 "quota poll failed; keeping last verdict"
317 ),
318 }
319 if sleep_or_shutdown(&mut shutdown, interval).await {
320 break;
321 }
322 }
323}
324
// Tests run `CpClient` and `quota_loop` against an in-process axum stub bound
// to an ephemeral localhost port.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ControlPlaneCfg;
    use std::net::SocketAddr;
    use std::sync::Mutex as StdMutex;

    use axum::{
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
        Json, Router,
    };

    /// Entity tag the stub serves and recognises in `If-None-Match`.
    const ETAG: &str = "\"abc123\"";

    /// Upper bound on how long a background loop may take to exit once
    /// shutdown has been signalled. Generous on purpose; it only exists so a
    /// loop that ignores shutdown fails the test instead of hanging it.
    const SHUTDOWN_GRACE: Duration = Duration::from_secs(5);

    /// Shared state of the stub server.
    #[derive(Clone, Default)]
    struct Stub {
        /// Body of the most recent usage POST, if it parsed as JSON.
        last_usage: Arc<StdMutex<Option<serde_json::Value>>>,
    }

    /// Policy endpoint: `304` when the request carries the known entity tag,
    /// otherwise a policy envelope with extra fields the client does not read.
    async fn policy(headers: HeaderMap) -> axum::response::Response {
        if headers
            .get(axum::http::header::IF_NONE_MATCH)
            .and_then(|v| v.to_str().ok())
            == Some(ETAG)
        {
            return StatusCode::NOT_MODIFIED.into_response();
        }
        (
            [(axum::http::header::ETAG, ETAG)],
            Json(serde_json::json!({
                "version": 1, "etag": ETAG, "format": "toml",
                "body": "[auth]\nmode = \"none\"\n", "updated_at": 0
            })),
        )
            .into_response()
    }

    /// Usage endpoint: records the posted JSON and answers `202 Accepted`.
    async fn usage(State(s): State<Stub>, body: axum::body::Bytes) -> StatusCode {
        *s.last_usage.lock().unwrap() = serde_json::from_slice(&body).ok();
        StatusCode::ACCEPTED
    }

    /// Quota endpoint: always reports "over quota" with a fixed reset epoch.
    async fn quota() -> axum::response::Response {
        Json(serde_json::json!({
            "over_quota": true, "reset_epoch": 1_782_864_000_i64
        }))
        .into_response()
    }

    /// Starts the stub on an ephemeral port and returns its address and state.
    async fn spawn_stub() -> (SocketAddr, Stub) {
        let stub = Stub::default();
        let app = Router::new()
            .route("/v3/edge/t1/policy", get(policy))
            .route("/v3/edge/t1/usage", post(usage))
            .route("/v3/edge/t1/quota", get(quota))
            .with_state(stub.clone());
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });
        (addr, stub)
    }

    /// Builds a client for tenant `t1` pointed at the stub.
    fn client(addr: SocketAddr) -> Arc<CpClient> {
        CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: format!("http://{addr}"),
            tenant_id: "t1".into(),
            edge_token: "tok".into(),
            ..Default::default()
        })
        .unwrap()
        .unwrap()
    }

    #[test]
    fn disabled_or_incomplete_config() {
        // Default config: `from_cfg` yields no client and no error.
        assert!(CpClient::from_cfg(&ControlPlaneCfg::default())
            .unwrap()
            .is_none());
        // Enabled, but the edge token is left at its default: rejected.
        assert!(CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: "http://x".into(),
            tenant_id: "t1".into(),
            ..Default::default()
        })
        .is_err());
    }

    #[tokio::test]
    async fn policy_pull_conditional() {
        let (addr, _) = spawn_stub().await;
        let c = client(addr);
        // Unconditional pull returns the document and its entity tag.
        match c.pull_policy(None).await.unwrap() {
            PullResult::Policy { body, etag } => {
                assert!(body.contains("mode = \"none\""));
                assert_eq!(etag, ETAG);
            }
            _ => panic!("expected a policy"),
        }
        // Pull with the current entity tag is answered with 304.
        assert!(matches!(
            c.pull_policy(Some(ETAG)).await.unwrap(),
            PullResult::NotModified
        ));
    }

    #[tokio::test]
    async fn usage_report_posts_delta() {
        let (addr, stub) = spawn_stub().await;
        let c = client(addr);
        c.report_usage(&UsageDelta {
            requests: 3,
            ingress_bytes: 100,
            egress_bytes: 250,
        })
        .await
        .unwrap();
        let got = stub.last_usage.lock().unwrap().clone().unwrap();
        assert_eq!(got["requests"], 3);
        assert_eq!(got["ingress_bytes"], 100);
        assert_eq!(got["egress_bytes"], 250);
    }

    #[tokio::test]
    async fn quota_pull_returns_verdict() {
        let (addr, _) = spawn_stub().await;
        let c = client(addr);
        let (over, reset) = c.pull_quota().await.unwrap();
        assert!(over);
        assert_eq!(reset, 1_782_864_000);
    }

    #[tokio::test]
    async fn quota_loop_publishes_to_shared_state() {
        let (addr, _) = spawn_stub().await;
        let c = client(addr);
        let state = Arc::new(QuotaState::default());
        assert!(!state.blocked(), "starts permissive");

        let (tx, rx) = watch::channel(false);
        let st = state.clone();
        let handle =
            tokio::spawn(async move { quota_loop(c, st, Duration::from_millis(50), rx).await });

        // Poll for up to ~500 ms for the first verdict to land.
        for _ in 0..50 {
            if state.blocked() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert!(
            state.blocked(),
            "verdict from the control plane should publish"
        );
        assert_eq!(state.reset_epoch(), 1_782_864_000);

        // Signal shutdown and require the loop to exit cleanly and promptly.
        // Discarding these results would hide a panic inside the loop, and an
        // unbounded join would hang the suite if the loop ignored shutdown.
        tx.send(true)
            .expect("quota_loop should still hold its shutdown receiver");
        tokio::time::timeout(SHUTDOWN_GRACE, handle)
            .await
            .expect("quota_loop should exit promptly after shutdown")
            .expect("quota_loop task should not panic");
    }
}