zeph_gateway/server.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::SocketAddr;
5use std::time::Instant;
6
7use tokio::sync::{mpsc, watch};
8
9use crate::error::GatewayError;
10use crate::router::build_router;
11
12/// Shared state threaded through every axum handler.
13///
14/// Cloned cheaply for each request because all fields are either `Clone` or
15/// wrapped in `Arc`-backed primitives.
16#[derive(Clone)]
17pub(crate) struct AppState {
18 /// Channel used to forward sanitised webhook messages to the agent.
19 pub webhook_tx: mpsc::Sender<String>,
20 /// Monotonic timestamp recorded when the server started, used by `/health`.
21 pub started_at: Instant,
22}
23
24/// HTTP gateway server with bearer-auth, rate limiting, and body-size enforcement.
25///
26/// Build the server with [`GatewayServer::new`], apply optional configuration via
27/// the builder methods, then drive it with [`GatewayServer::serve`].
28///
29/// # Defaults
30///
31/// | Setting | Default |
32/// |---|---|
33/// | Bearer auth | disabled (open) |
34/// | Rate limit | 120 requests / 60 s per IP |
35/// | Max body size | 1 MiB (1 048 576 bytes) |
36///
37/// # Example
38///
39/// ```no_run
40/// use tokio::sync::{mpsc, watch};
41/// use zeph_gateway::GatewayServer;
42///
43/// #[tokio::main]
44/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
45/// let (tx, _rx) = mpsc::channel::<String>(64);
46/// let (_stx, srx) = watch::channel(false);
47///
48/// GatewayServer::new("127.0.0.1", 9000, tx, srx)
49/// .with_auth(Some("hunter2".into()))
50/// .with_rate_limit(30)
51/// .with_max_body_size(512 * 1024)
52/// .serve()
53/// .await?;
54///
55/// Ok(())
56/// }
57/// ```
58pub struct GatewayServer {
59 addr: SocketAddr,
60 auth_token: Option<String>,
61 rate_limit: u32,
62 max_body_size: usize,
63 webhook_tx: mpsc::Sender<String>,
64 shutdown_rx: watch::Receiver<bool>,
65 /// Prometheus metrics registry and endpoint path (feature-gated).
66 #[cfg(feature = "prometheus")]
67 metrics_registry: Option<(
68 std::sync::Arc<prometheus_client::registry::Registry>,
69 String,
70 )>,
71}
72
73impl GatewayServer {
74 /// Create a new gateway server.
75 ///
76 /// `bind` is parsed as an IP address string (e.g. `"127.0.0.1"` or `"0.0.0.0"`).
77 /// If parsing fails, the server falls back to `127.0.0.1:<port>` and emits a warning.
78 ///
79 /// `webhook_tx` receives every valid, sanitised webhook message as a formatted
80 /// `"[sender@channel] body"` string.
81 ///
82 /// `shutdown_rx` is a [`watch::Receiver<bool>`] that signals graceful shutdown
83 /// when its value transitions to `true`. Sending `true` causes the server to
84 /// stop accepting new connections and drain in-flight requests.
85 ///
86 /// # Panics
87 ///
88 /// Does not panic. Invalid `bind` values fall back to `127.0.0.1` with a log warning.
89 #[must_use]
90 pub fn new(
91 bind: &str,
92 port: u16,
93 webhook_tx: mpsc::Sender<String>,
94 shutdown_rx: watch::Receiver<bool>,
95 ) -> Self {
96 let addr: SocketAddr = format!("{bind}:{port}").parse().unwrap_or_else(|e| {
97 tracing::warn!("invalid bind '{bind}': {e}, falling back to 127.0.0.1:{port}");
98 SocketAddr::from(([127, 0, 0, 1], port))
99 });
100
101 if bind == "0.0.0.0" {
102 tracing::warn!("gateway binding to 0.0.0.0 — ensure this is intended for production");
103 }
104
105 Self {
106 addr,
107 auth_token: None,
108 rate_limit: 120,
109 max_body_size: 1_048_576,
110 webhook_tx,
111 shutdown_rx,
112 #[cfg(feature = "prometheus")]
113 metrics_registry: None,
114 }
115 }
116
117 /// Set the bearer token required on `POST /webhook` requests.
118 ///
119 /// When `token` is `Some`, every request to `/webhook` must carry an
120 /// `Authorization: Bearer <token>` header. The comparison is performed
121 /// in constant time (BLAKE3 + `subtle::ConstantTimeEq`) to prevent
122 /// timing-oracle attacks.
123 ///
124 /// When `token` is `None`, bearer authentication is disabled and a warning
125 /// is logged at startup.
126 ///
127 /// # Example
128 ///
129 /// ```
130 /// use tokio::sync::{mpsc, watch};
131 /// use zeph_gateway::GatewayServer;
132 ///
133 /// let (tx, _rx) = mpsc::channel::<String>(1);
134 /// let (_stx, srx) = watch::channel(false);
135 ///
136 /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
137 /// .with_auth(Some("super-secret".into()));
138 /// ```
139 #[must_use]
140 pub fn with_auth(mut self, token: Option<String>) -> Self {
141 self.auth_token = token;
142 self
143 }
144
145 /// Set the per-IP rate limit for `POST /webhook`.
146 ///
147 /// `limit` is the maximum number of requests allowed per remote IP in a
148 /// 60-second fixed window. Setting `limit` to `0` disables rate limiting.
149 ///
150 /// # Example
151 ///
152 /// ```
153 /// use tokio::sync::{mpsc, watch};
154 /// use zeph_gateway::GatewayServer;
155 ///
156 /// let (tx, _rx) = mpsc::channel::<String>(1);
157 /// let (_stx, srx) = watch::channel(false);
158 ///
159 /// // Allow at most 30 webhook posts per minute per IP.
160 /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
161 /// .with_rate_limit(30);
162 /// ```
163 #[must_use]
164 pub fn with_rate_limit(mut self, limit: u32) -> Self {
165 self.rate_limit = limit;
166 self
167 }
168
169 /// Set the maximum allowed request body size in bytes.
170 ///
171 /// Requests whose body exceeds this size are rejected with `413 Content Too Large`
172 /// before any handler is invoked. The default is 1 MiB (1 048 576 bytes).
173 ///
174 /// # Example
175 ///
176 /// ```
177 /// use tokio::sync::{mpsc, watch};
178 /// use zeph_gateway::GatewayServer;
179 ///
180 /// let (tx, _rx) = mpsc::channel::<String>(1);
181 /// let (_stx, srx) = watch::channel(false);
182 ///
183 /// // Restrict bodies to 64 KiB.
184 /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
185 /// .with_max_body_size(64 * 1024);
186 /// ```
187 #[must_use]
188 pub fn with_max_body_size(mut self, size: usize) -> Self {
189 self.max_body_size = size;
190 self
191 }
192
193 /// Attach a Prometheus metrics registry to the gateway.
194 ///
195 /// When set, the server mounts an additional route at `path` that returns the registry
196 /// contents encoded as `OpenMetrics` 1.0.0 text. The endpoint is unauthenticated and
197 /// bypasses rate limiting.
198 ///
199 /// Requires the `prometheus` feature.
200 ///
201 /// # Example
202 ///
203 /// ```no_run
204 /// # #[cfg(feature = "prometheus")]
205 /// # {
206 /// use std::sync::Arc;
207 /// use prometheus_client::registry::Registry;
208 /// use tokio::sync::{mpsc, watch};
209 /// use zeph_gateway::GatewayServer;
210 ///
211 /// let (tx, _rx) = mpsc::channel::<String>(1);
212 /// let (_stx, srx) = watch::channel(false);
213 /// let registry = Arc::new(Registry::default());
214 ///
215 /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
216 /// .with_metrics_registry(registry, "/metrics");
217 /// # }
218 /// ```
219 #[cfg(feature = "prometheus")]
220 #[must_use]
221 pub fn with_metrics_registry(
222 mut self,
223 registry: std::sync::Arc<prometheus_client::registry::Registry>,
224 path: impl Into<String>,
225 ) -> Self {
226 self.metrics_registry = Some((registry, path.into()));
227 self
228 }
229
230 /// Start the HTTP gateway server and block until shutdown is signalled.
231 ///
232 /// Binds a TCP listener on the configured address, installs middleware
233 /// (body-size limit → auth → rate limiting), and serves requests until
234 /// the [`watch::Receiver`] supplied to [`GatewayServer::new`] transitions
235 /// to `true`.
236 ///
237 /// # Errors
238 ///
239 /// - [`GatewayError::Bind`] — the listener could not be bound (port in use,
240 /// permission denied, etc.).
241 /// - [`GatewayError::Server`] — the server encountered a fatal I/O error
242 /// after binding.
243 pub async fn serve(self) -> Result<(), GatewayError> {
244 let state = AppState {
245 webhook_tx: self.webhook_tx,
246 started_at: Instant::now(),
247 };
248
249 if self.auth_token.is_none() {
250 tracing::warn!(
251 "gateway running without bearer auth — ensure firewall or upstream proxy enforces access control"
252 );
253 }
254
255 let router = build_router(
256 state,
257 self.auth_token.as_deref(),
258 self.rate_limit,
259 self.max_body_size,
260 );
261
262 #[cfg(feature = "prometheus")]
263 let router = if let Some((registry, path)) = self.metrics_registry {
264 let metrics_route = axum::Router::new()
265 .route(&path, axum::routing::get(crate::handlers::metrics_handler))
266 .with_state(registry);
267 router.merge(metrics_route)
268 } else {
269 router
270 };
271
272 let listener = tokio::net::TcpListener::bind(self.addr)
273 .await
274 .map_err(|e| GatewayError::Bind(self.addr.to_string(), e))?;
275 tracing::info!("gateway listening on {}", self.addr);
276
277 let mut shutdown_rx = self.shutdown_rx;
278 axum::serve(
279 listener,
280 router.into_make_service_with_connect_info::<SocketAddr>(),
281 )
282 .with_graceful_shutdown(async move {
283 while !*shutdown_rx.borrow_and_update() {
284 if shutdown_rx.changed().await.is_err() {
285 std::future::pending::<()>().await;
286 }
287 }
288 tracing::info!("gateway shutting down");
289 })
290 .await
291 .map_err(|e| GatewayError::Server(format!("{e}")))?;
292
293 Ok(())
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// `GET /metrics` on the merged router answers 200 with an OpenMetrics
    /// content type and a body terminated by the `# EOF` marker.
    #[cfg(feature = "prometheus")]
    #[tokio::test]
    async fn test_metrics_endpoint_returns_openmetrics() {
        use axum::body::Body;
        use http_body_util::BodyExt;
        use prometheus_client::registry::Registry;
        use tower::ServiceExt;

        let registry = std::sync::Arc::new(Registry::default());

        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 19999, tx, srx)
            .with_metrics_registry(std::sync::Arc::clone(&registry), "/metrics");

        // Build the router directly without binding a port
        let state = AppState {
            webhook_tx: server.webhook_tx,
            started_at: Instant::now(),
        };
        let router = crate::router::build_router(
            state,
            server.auth_token.as_deref(),
            server.rate_limit,
            server.max_body_size,
        );
        let metrics_route = axum::Router::new()
            .route(
                "/metrics",
                axum::routing::get(crate::handlers::metrics_handler),
            )
            .with_state(registry);
        let router = router.merge(metrics_route);

        let req = axum::http::Request::builder()
            .method("GET")
            .uri("/metrics")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(req).await.unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/openmetrics-text"),
            "unexpected content-type: {ct}"
        );

        let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(body.ends_with("# EOF\n"), "missing EOF marker in:\n{body}");
    }

    /// Each builder method stores the value it was given.
    #[test]
    fn server_builder_chain() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 8090, tx, srx)
            .with_auth(Some("token".into()))
            .with_rate_limit(60)
            .with_max_body_size(512);

        assert_eq!(server.rate_limit, 60);
        assert_eq!(server.max_body_size, 512);
        // Compare the stored value, not just its presence.
        assert_eq!(server.auth_token.as_deref(), Some("token"));
    }

    /// An unparsable `bind` keeps the requested port and lands on loopback.
    #[test]
    fn server_invalid_bind_fallback() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("not_an_ip", 9999, tx, srx);
        assert_eq!(server.addr.port(), 9999);
        // The fallback address is 127.0.0.1, so the IP must be loopback too.
        assert!(server.addr.ip().is_loopback());
    }
}