1use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{Context, Result};
17use arc_swap::ArcSwap;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use tokio::sync::watch;
21use tracing::{info, warn};
22
23use crate::config::{Config, ControlPlaneCfg};
24use crate::metrics::Metrics;
25use crate::proxy::Runtime;
26
27#[derive(Debug, Clone, Copy, Default, Serialize)]
29pub struct UsageDelta {
30 pub requests: u64,
31 pub ingress_bytes: u64,
32 pub egress_bytes: u64,
33}
34
35#[derive(Debug, Deserialize)]
37struct PolicyResp {
38 etag: String,
39 body: String,
40}
41
/// Outcome of a conditional policy pull (`CpClient::pull_policy`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullResult {
    /// The control plane answered `304 Not Modified`: the ETag we sent is current.
    NotModified,
    /// A policy document together with the ETag identifying its version.
    Policy { body: String, etag: String },
}
49
50pub struct CpClient {
52 http: reqwest::Client,
53 edge_base: String,
55 token: String,
56}
57
58impl CpClient {
59 pub fn from_cfg(cfg: &ControlPlaneCfg) -> Result<Option<Arc<CpClient>>> {
62 if !cfg.enabled {
63 return Ok(None);
64 }
65 anyhow::ensure!(
66 !cfg.url.is_empty(),
67 "control_plane.url is required when enabled"
68 );
69 anyhow::ensure!(
70 !cfg.tenant_id.is_empty(),
71 "control_plane.tenant_id is required when enabled"
72 );
73 anyhow::ensure!(
74 !cfg.edge_token.is_empty(),
75 "control_plane.edge_token (or EDGEGUARD_CP_EDGE_TOKEN) is required when enabled"
76 );
77 let http = reqwest::Client::builder()
78 .timeout(Duration::from_secs(10))
79 .build()
80 .context("building control-plane HTTP client")?;
81 let edge_base = format!(
82 "{}/v3/edge/{}",
83 cfg.url.trim_end_matches('/'),
84 cfg.tenant_id
85 );
86 Ok(Some(Arc::new(CpClient {
87 http,
88 edge_base,
89 token: cfg.edge_token.clone(),
90 })))
91 }
92
93 pub async fn pull_policy(&self, etag: Option<&str>) -> Result<PullResult> {
95 let mut req = self
96 .http
97 .get(format!("{}/policy", self.edge_base))
98 .bearer_auth(&self.token);
99 if let Some(e) = etag {
100 req = req.header(reqwest::header::IF_NONE_MATCH, e);
101 }
102 let resp = req.send().await.context("pulling policy")?;
103 match resp.status() {
104 reqwest::StatusCode::NOT_MODIFIED => Ok(PullResult::NotModified),
105 s if s.is_success() => {
106 let doc: PolicyResp = resp.json().await.context("parsing policy document")?;
107 Ok(PullResult::Policy {
108 body: doc.body,
109 etag: doc.etag,
110 })
111 }
112 s => anyhow::bail!("control plane returned {s} for policy pull"),
113 }
114 }
115
116 pub async fn report_usage(&self, delta: &UsageDelta) -> Result<()> {
118 self.http
119 .post(format!("{}/usage", self.edge_base))
120 .bearer_auth(&self.token)
121 .json(delta)
122 .send()
123 .await
124 .context("reporting usage")?
125 .error_for_status()
126 .context("control plane rejected usage report")?;
127 Ok(())
128 }
129
130 pub async fn forward_csp(&self, raw: &Bytes) {
132 let res = self
133 .http
134 .post(format!("{}/csp-report", self.edge_base))
135 .bearer_auth(&self.token)
136 .header(reqwest::header::CONTENT_TYPE, "application/json")
137 .body(raw.clone())
138 .send()
139 .await;
140 if let Err(e) = res {
141 warn!(error = %e, "forwarding CSP report to control plane failed");
142 }
143 }
144}
145
146async fn sleep_or_shutdown(rx: &mut watch::Receiver<bool>, dur: Duration) -> bool {
148 tokio::select! {
149 _ = tokio::time::sleep(dur) => *rx.borrow(),
150 _ = rx.changed() => true,
151 }
152}
153
154pub async fn poll_loop(
157 client: Arc<CpClient>,
158 base: Arc<Config>,
159 runtime: Arc<ArcSwap<Runtime>>,
160 interval: Duration,
161 mut shutdown: watch::Receiver<bool>,
162) {
163 let mut etag: Option<String> = None;
164 info!(?interval, "control-plane policy poller started");
165 loop {
166 match client.pull_policy(etag.as_deref()).await {
167 Ok(PullResult::NotModified) => {}
168 Ok(PullResult::Policy { body, etag: new }) => {
169 match apply_policy(&base, &body, &runtime) {
170 Ok(()) => {
171 etag = Some(new);
172 info!("applied policy from control plane");
173 }
174 Err(e) => warn!(
175 error = format!("{e:#}"),
176 "rejected control-plane policy; keeping current"
177 ),
178 }
179 }
180 Err(e) => warn!(
181 error = format!("{e:#}"),
182 "policy pull failed; keeping current"
183 ),
184 }
185 if sleep_or_shutdown(&mut shutdown, interval).await {
186 break;
187 }
188 }
189}
190
191fn apply_policy(base: &Config, body: &str, runtime: &ArcSwap<Runtime>) -> Result<()> {
193 let merged = base.with_policy_from(body)?;
194 let rt = crate::build_runtime(Arc::new(merged))?;
195 runtime.store(Arc::new(rt));
196 Ok(())
197}
198
199pub async fn report_loop(
202 client: Arc<CpClient>,
203 metrics: Arc<Metrics>,
204 interval: Duration,
205 mut shutdown: watch::Receiver<bool>,
206) {
207 info!(?interval, "control-plane usage reporter started");
208 loop {
209 if sleep_or_shutdown(&mut shutdown, interval).await {
210 break;
211 }
212 let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
213 if requests == 0 && ingress_bytes == 0 && egress_bytes == 0 {
214 continue;
215 }
216 let delta = UsageDelta {
217 requests,
218 ingress_bytes,
219 egress_bytes,
220 };
221 if let Err(e) = client.report_usage(&delta).await {
222 warn!(
223 error = format!("{e:#}"),
224 "usage report failed; will retry next period"
225 );
226 metrics.restore_usage(requests, ingress_bytes, egress_bytes);
227 }
228 }
229 let (requests, ingress_bytes, egress_bytes) = metrics.drain_usage();
231 if requests > 0 || ingress_bytes > 0 || egress_bytes > 0 {
232 let delta = UsageDelta { requests, ingress_bytes, egress_bytes };
233 if let Err(e) = client.report_usage(&delta).await {
234 warn!(error = format!("{e:#}"), "final usage report on shutdown failed");
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ControlPlaneCfg;
    use std::net::SocketAddr;
    use std::sync::Mutex as StdMutex;

    use axum::{
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
        Json, Router,
    };

    /// ETag the stub control plane attaches to its only policy version.
    const ETAG: &str = "\"abc123\"";

    /// State shared by the stub handlers: the last usage body POSTed to it.
    #[derive(Clone, Default)]
    struct Stub {
        last_usage: Arc<StdMutex<Option<serde_json::Value>>>,
    }

    /// Stub `GET /v3/edge/t1/policy`: 304 when the caller already holds
    /// `ETAG`, otherwise the full policy document.
    async fn policy(headers: HeaderMap) -> axum::response::Response {
        if headers
            .get(axum::http::header::IF_NONE_MATCH)
            .and_then(|v| v.to_str().ok())
            == Some(ETAG)
        {
            return StatusCode::NOT_MODIFIED.into_response();
        }
        (
            [(axum::http::header::ETAG, ETAG)],
            Json(serde_json::json!({
                "version": 1, "etag": ETAG, "format": "toml",
                "body": "[auth]\nmode = \"none\"\n", "updated_at": 0
            })),
        )
            .into_response()
    }

    /// Stub `POST /v3/edge/t1/usage`: records the JSON body (or `None` if it
    /// does not parse) and answers 202.
    async fn usage(State(s): State<Stub>, body: axum::body::Bytes) -> StatusCode {
        *s.last_usage.lock().unwrap() = serde_json::from_slice(&body).ok();
        StatusCode::ACCEPTED
    }

    /// Starts the stub control plane on an ephemeral localhost port.
    async fn spawn_stub() -> (SocketAddr, Stub) {
        let stub = Stub::default();
        let app = Router::new()
            .route("/v3/edge/t1/policy", get(policy))
            .route("/v3/edge/t1/usage", post(usage))
            .with_state(stub.clone());
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });
        (addr, stub)
    }

    /// Builds an enabled client for `tenant` pointed at the stub at `addr`.
    fn client_for(addr: SocketAddr, tenant: &str) -> Arc<CpClient> {
        CpClient::from_cfg(&ControlPlaneCfg {
            enabled: true,
            url: format!("http://{addr}"),
            tenant_id: tenant.into(),
            edge_token: "tok".into(),
            ..Default::default()
        })
        .unwrap()
        .unwrap()
    }

    /// Client for the tenant (`t1`) the stub actually serves.
    fn client(addr: SocketAddr) -> Arc<CpClient> {
        client_for(addr, "t1")
    }

    #[test]
    fn disabled_or_incomplete_config() {
        assert!(CpClient::from_cfg(&ControlPlaneCfg::default())
            .unwrap()
            .is_none());

        let complete = || ControlPlaneCfg {
            enabled: true,
            url: "http://x".into(),
            tenant_id: "t1".into(),
            edge_token: "tok".into(),
            ..Default::default()
        };
        // Each required field, when empty, must be rejected with an error naming it.
        let cases = [
            (
                ControlPlaneCfg {
                    url: Default::default(),
                    ..complete()
                },
                "control_plane.url",
            ),
            (
                ControlPlaneCfg {
                    tenant_id: Default::default(),
                    ..complete()
                },
                "control_plane.tenant_id",
            ),
            (
                ControlPlaneCfg {
                    edge_token: Default::default(),
                    ..complete()
                },
                "control_plane.edge_token",
            ),
        ];
        for (cfg, key) in &cases {
            let err = CpClient::from_cfg(cfg)
                .err()
                .expect("incomplete config must be rejected");
            assert!(format!("{err:#}").contains(key), "error should mention {key}");
        }
    }

    #[tokio::test]
    async fn policy_pull_conditional() {
        let (addr, _) = spawn_stub().await;
        let c = client(addr);
        match c.pull_policy(None).await.unwrap() {
            PullResult::Policy { body, etag } => {
                assert!(body.contains("mode = \"none\""));
                assert_eq!(etag, ETAG);
            }
            PullResult::NotModified => panic!("expected a policy"),
        }
        assert!(matches!(
            c.pull_policy(Some(ETAG)).await.unwrap(),
            PullResult::NotModified
        ));
    }

    #[tokio::test]
    async fn usage_report_posts_delta() {
        let (addr, stub) = spawn_stub().await;
        let c = client(addr);
        c.report_usage(&UsageDelta {
            requests: 3,
            ingress_bytes: 100,
            egress_bytes: 250,
        })
        .await
        .unwrap();
        let got = stub.last_usage.lock().unwrap().clone().unwrap();
        assert_eq!(got["requests"], 3);
        assert_eq!(got["ingress_bytes"], 100);
        assert_eq!(got["egress_bytes"], 250);
    }

    #[tokio::test]
    async fn unknown_tenant_is_an_error() {
        // The stub only routes tenant `t1`, so everything else is a 404.
        let (addr, _) = spawn_stub().await;
        let c = client_for(addr, "nope");
        assert!(c.pull_policy(None).await.is_err());
        assert!(c.report_usage(&UsageDelta::default()).await.is_err());
    }

    #[tokio::test]
    async fn sleep_or_shutdown_signals() {
        let (tx, mut rx) = watch::channel(false);
        assert!(!sleep_or_shutdown(&mut rx, Duration::from_millis(1)).await);
        tx.send(true).unwrap();
        assert!(sleep_or_shutdown(&mut rx, Duration::from_secs(60)).await);

        // A dropped sender can never signal again: treated as shutdown.
        let (tx, mut rx) = watch::channel(false);
        drop(tx);
        assert!(sleep_or_shutdown(&mut rx, Duration::from_secs(60)).await);
    }
}