use std::net::SocketAddr;
use std::time::Instant;
use tokio::sync::{mpsc, watch};
use crate::error::GatewayError;
use crate::router::build_router;
/// Application state handed to `build_router` by `GatewayServer::serve`.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Moment the state was created; `serve` sets it with `Instant::now()`
    /// (presumably used for uptime reporting — confirm in the handlers).
    pub started_at: Instant,
    /// Sending half of the webhook channel supplied to `GatewayServer::new`.
    pub webhook_tx: mpsc::Sender<String>,
}
/// Builder-style configuration for the HTTP gateway; call `serve` to run it.
pub struct GatewayServer {
    /// Socket address the listener binds to (parsed in `new`).
    addr: SocketAddr,
    /// Receiver whose value turning `true` triggers graceful shutdown.
    shutdown_rx: watch::Receiver<bool>,
    /// Webhook channel placed into `AppState` when serving.
    webhook_tx: mpsc::Sender<String>,
    /// Optional bearer token; `None` means no auth is enforced by the gateway.
    auth_token: Option<String>,
    /// Rate limit forwarded to `build_router` (default 120).
    rate_limit: u32,
    /// Maximum request body size forwarded to `build_router` (default 1_048_576).
    max_body_size: usize,
    /// Prometheus registry plus the route path it is served on, if enabled.
    #[cfg(feature = "prometheus")]
    metrics_registry: Option<(
        std::sync::Arc<prometheus_client::registry::Registry>,
        String,
    )>,
}
impl GatewayServer {
    /// Creates a gateway server for `bind:port` with default settings: no
    /// bearer auth, a rate limit of 120 and a body cap of 1_048_576.
    ///
    /// `bind` may be an IPv4 address (`127.0.0.1`), a bracketed IPv6 address
    /// (`[::1]`) or a bare IPv6 address (`::1`). If it cannot be parsed, a
    /// warning is logged and the server falls back to `127.0.0.1:port`. A
    /// warning is also logged when the resulting address is unspecified
    /// (`0.0.0.0` or `::`), i.e. it listens on all interfaces.
    #[must_use]
    pub fn new(
        bind: &str,
        port: u16,
        webhook_tx: mpsc::Sender<String>,
        shutdown_rx: watch::Receiver<bool>,
    ) -> Self {
        let addr = Self::parse_bind_addr(bind, port).unwrap_or_else(|e| {
            tracing::warn!("invalid bind '{bind}': {e}, falling back to 127.0.0.1:{port}");
            SocketAddr::from(([127, 0, 0, 1], port))
        });
        // Check the parsed address rather than the raw string so `::` (and any
        // other spelling of "all interfaces") is flagged too.
        if addr.ip().is_unspecified() {
            tracing::warn!(
                "gateway binding to {} — ensure this is intended for production",
                addr.ip()
            );
        }
        Self {
            addr,
            auth_token: None,
            rate_limit: 120,
            max_body_size: 1_048_576,
            webhook_tx,
            shutdown_rx,
            #[cfg(feature = "prometheus")]
            metrics_registry: None,
        }
    }

    /// Combines `bind` and `port` into a socket address.
    ///
    /// First tries the `bind:port` form, which covers IPv4 and bracketed IPv6.
    /// If that fails, it accepts a bare IP address such as `::1`, which
    /// `SocketAddr` parsing rejects without brackets. On failure, the error
    /// from the first attempt is returned.
    fn parse_bind_addr(bind: &str, port: u16) -> Result<SocketAddr, std::net::AddrParseError> {
        format!("{bind}:{port}").parse::<SocketAddr>().or_else(|err| {
            bind.parse::<std::net::IpAddr>()
                .map(|ip| SocketAddr::new(ip, port))
                .map_err(|_| err)
        })
    }

    /// Sets the bearer token passed to `build_router`; `None` disables auth.
    #[must_use]
    pub fn with_auth(mut self, token: Option<String>) -> Self {
        self.auth_token = token;
        self
    }

    /// Sets the rate limit passed to `build_router` (unit defined there —
    /// presumably requests per window; confirm in `crate::router`).
    #[must_use]
    pub fn with_rate_limit(mut self, limit: u32) -> Self {
        self.rate_limit = limit;
        self
    }

    /// Sets the maximum request body size passed to `build_router`
    /// (presumably bytes, given the 1_048_576 default — confirm in `crate::router`).
    #[must_use]
    pub fn with_max_body_size(mut self, size: usize) -> Self {
        self.max_body_size = size;
        self
    }

    /// Exposes `registry` at `path` through `crate::handlers::metrics_handler`.
    #[cfg(feature = "prometheus")]
    #[must_use]
    pub fn with_metrics_registry(
        mut self,
        registry: std::sync::Arc<prometheus_client::registry::Registry>,
        path: impl Into<String>,
    ) -> Self {
        self.metrics_registry = Some((registry, path.into()));
        self
    }

    /// Binds the listener and serves the gateway until the shutdown channel
    /// carries `true`.
    ///
    /// # Errors
    ///
    /// Returns `GatewayError::Bind` if the address cannot be bound, and
    /// `GatewayError::Server` if the server fails while running.
    pub async fn serve(self) -> Result<(), GatewayError> {
        let state = AppState {
            webhook_tx: self.webhook_tx,
            started_at: Instant::now(),
        };
        if self.auth_token.is_none() {
            tracing::warn!(
                "gateway running without bearer auth — ensure firewall or upstream proxy enforces access control"
            );
        }
        let router = build_router(
            state,
            self.auth_token.as_deref(),
            self.rate_limit,
            self.max_body_size,
        );

        // NOTE(review): this route is merged after `build_router` returns, so
        // any layers that function applies (auth, rate limiting, body limit)
        // presumably do not cover it — confirm that is intended for scraping.
        #[cfg(feature = "prometheus")]
        let router = if let Some((registry, path)) = self.metrics_registry {
            let metrics_route = axum::Router::new()
                .route(&path, axum::routing::get(crate::handlers::metrics_handler))
                .with_state(registry);
            router.merge(metrics_route)
        } else {
            router
        };

        let listener = tokio::net::TcpListener::bind(self.addr)
            .await
            .map_err(|e| GatewayError::Bind(self.addr.to_string(), e))?;
        // Log the address actually bound; it differs from `self.addr` when
        // port 0 asks the OS for an ephemeral port.
        let local_addr = listener.local_addr().unwrap_or(self.addr);
        tracing::info!("gateway listening on {local_addr}");

        axum::serve(
            listener,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .with_graceful_shutdown(Self::wait_for_shutdown(self.shutdown_rx))
        .await
        .map_err(|e| GatewayError::Server(e.to_string()))
    }

    /// Resolves once the shutdown channel holds `true`.
    ///
    /// If the sender is dropped before that happens, this future never
    /// resolves, so the server keeps running instead of shutting down.
    async fn wait_for_shutdown(mut shutdown_rx: watch::Receiver<bool>) {
        while !*shutdown_rx.borrow_and_update() {
            if shutdown_rx.changed().await.is_err() {
                std::future::pending::<()>().await;
            }
        }
        tracing::info!("gateway shutting down");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles the router the same way `serve` does (without binding) and
    /// checks that the route stored by `with_metrics_registry` answers with an
    /// OpenMetrics payload.
    #[cfg(feature = "prometheus")]
    #[tokio::test]
    async fn test_metrics_endpoint_returns_openmetrics() {
        use axum::body::Body;
        use http_body_util::BodyExt;
        use prometheus_client::registry::Registry;
        use tower::ServiceExt;

        let registry = std::sync::Arc::new(Registry::default());
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 19999, tx, srx)
            .with_metrics_registry(std::sync::Arc::clone(&registry), "/metrics");

        // Read back what the builder stored so the test exercises
        // `with_metrics_registry` rather than a hand-copied path/registry.
        let (stored_registry, path) = server
            .metrics_registry
            .expect("with_metrics_registry should store the registry");
        assert!(std::sync::Arc::ptr_eq(&stored_registry, &registry));
        assert_eq!(path, "/metrics");

        let state = AppState {
            webhook_tx: server.webhook_tx,
            started_at: Instant::now(),
        };
        let router = crate::router::build_router(
            state,
            server.auth_token.as_deref(),
            server.rate_limit,
            server.max_body_size,
        );
        let metrics_route = axum::Router::new()
            .route(&path, axum::routing::get(crate::handlers::metrics_handler))
            .with_state(stored_registry);
        let router = router.merge(metrics_route);

        let req = axum::http::Request::builder()
            .method("GET")
            .uri("/metrics")
            .body(Body::empty())
            .unwrap();
        let response = router.oneshot(req).await.unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/openmetrics-text"),
            "unexpected content-type: {ct}"
        );

        let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(body.ends_with("# EOF\n"), "missing EOF marker in:\n{body}");
    }

    #[test]
    fn server_builder_chain() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 8090, tx, srx)
            .with_auth(Some("token".into()))
            .with_rate_limit(60)
            .with_max_body_size(512);
        assert_eq!(server.rate_limit, 60);
        assert_eq!(server.max_body_size, 512);
        assert_eq!(server.auth_token.as_deref(), Some("token"));
    }

    #[test]
    fn server_invalid_bind_fallback() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("not_an_ip", 9999, tx, srx);
        // The fallback must be loopback, not just any address on the right port.
        assert_eq!(server.addr, SocketAddr::from(([127, 0, 0, 1], 9999)));
    }

    #[test]
    fn server_bracketed_ipv6_bind() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("[::1]", 8080, tx, srx);
        assert_eq!(
            server.addr,
            SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080))
        );
    }
}