use std::net::SocketAddr;
use axum::Router;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use crate::error::ChannelError;
/// Network settings for [`WebhookServer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebhookServerConfig {
    /// Socket address the server binds when started; a port of 0 lets the OS choose one.
    pub addr: SocketAddr,
}
/// Axum-based webhook HTTP server.
///
/// Route fragments are queued with `add_routes`, merged and served by `start`, can be moved
/// to another listener with `install_listener`, and are stopped via `shutdown` or
/// `begin_shutdown`.
pub struct WebhookServer {
// Bind configuration; `addr` is what `current_addr` reports and `install_listener` rewrites.
config: WebhookServerConfig,
// Fragments queued by `add_routes`, merged into one router by `start`.
routes: Vec<Router>,
// Router built by the last `start`; cloned out by `merged_router_clone` (e.g. to re-serve it
// on a new listener).
merged_router: Option<Router>,
// Sender that triggers graceful shutdown of the running server task; `None` before `start`
// or after being taken by `begin_shutdown` / `install_listener`.
shutdown_tx: Option<oneshot::Sender<()>>,
// Join handle of the running server task; `None` under the same conditions as `shutdown_tx`.
handle: Option<JoinHandle<()>>,
}
impl WebhookServer {
pub fn new(config: WebhookServerConfig) -> Self {
Self {
config,
routes: Vec::new(),
merged_router: None,
shutdown_tx: None,
handle: None,
}
}
pub fn add_routes(&mut self, router: Router) {
self.routes.push(router);
}
pub async fn start(&mut self) -> Result<(), ChannelError> {
let mut app = Router::new();
for fragment in self.routes.drain(..) {
app = app.merge(fragment);
}
self.merged_router = Some(app.clone());
self.bind_and_spawn(app).await
}
async fn bind_and_spawn(&mut self, app: Router) -> Result<(), ChannelError> {
let listener = tokio::net::TcpListener::bind(self.config.addr)
.await
.map_err(|e| ChannelError::StartupFailed {
name: "webhook_server".to_string(),
reason: format!("Failed to bind to {}: {}", self.config.addr, e),
})?;
tracing::debug!("Webhook server listening on {}", self.config.addr);
let (shutdown_tx, shutdown_rx) = oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let handle = tokio::spawn(async move {
if let Err(e) = axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
tracing::debug!("Webhook server shutting down");
})
.await
{
tracing::error!("Webhook server error: {}", e);
}
});
self.handle = Some(handle);
Ok(())
}
pub fn merged_router_clone(&self) -> Option<Router> {
self.merged_router.clone()
}
pub fn install_listener(
&mut self,
new_addr: SocketAddr,
listener: tokio::net::TcpListener,
app: Router,
) -> (Option<oneshot::Sender<()>>, Option<JoinHandle<()>>) {
let old_shutdown_tx = self.shutdown_tx.take();
let old_handle = self.handle.take();
self.config.addr = new_addr;
let (shutdown_tx, shutdown_rx) = oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let handle = tokio::spawn(async move {
if let Err(e) = axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
tracing::debug!("Webhook server shutting down");
})
.await
{
tracing::error!("Webhook server error: {}", e);
}
});
self.handle = Some(handle);
tracing::debug!("Webhook server listening on {}", new_addr);
(old_shutdown_tx, old_handle)
}
pub fn current_addr(&self) -> SocketAddr {
self.config.addr
}
pub fn begin_shutdown(&mut self) -> (Option<oneshot::Sender<()>>, Option<JoinHandle<()>>) {
(self.shutdown_tx.take(), self.handle.take())
}
pub async fn shutdown(&mut self) {
let (shutdown_tx, handle) = self.begin_shutdown();
if let Some(tx) = shutdown_tx {
let _ = tx.send(());
}
if let Some(handle) = handle {
let _ = handle.await;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Json;
    use serde_json::json;
    use std::net::{Ipv4Addr, TcpListener as StdTcpListener};

    /// Router fragment answering `GET /health` with `{"status": "ok"}`.
    fn health_router() -> axum::Router {
        axum::Router::new().route(
            "/health",
            axum::routing::get(|| async { Json(json!({"status": "ok"})) }),
        )
    }

    /// Picks `count` free loopback ports that are distinct from each other.
    ///
    /// All probe listeners stay open until every port has been read, so the OS cannot hand
    /// out the same port twice. They are closed on return, so another process could still
    /// take a port before the test binds it.
    fn reserve_ports(count: usize) -> Vec<u16> {
        let probes: Vec<StdTcpListener> = (0..count)
            .map(|_| StdTcpListener::bind("127.0.0.1:0").expect("Failed to find available port"))
            .collect();
        probes
            .iter()
            .map(|probe| {
                probe
                    .local_addr()
                    .expect("Failed to get local addr")
                    .port()
            })
            .collect()
    }

    /// Loopback IPv4 socket address for `port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
    }

    #[tokio::test]
    async fn test_restart_with_addr_rebinds_listener() {
        let ports = reserve_ports(2);
        let (port1, port2) = (ports[0], ports[1]);
        assert_ne!(port1, port2, "Should have different ports");
        assert_ne!(port1, 0, "Port 1 should be non-zero");
        assert_ne!(port2, 0, "Port 2 should be non-zero");
        let addr1 = loopback(port1);
        let addr2 = loopback(port2);

        let mut server = WebhookServer::new(WebhookServerConfig { addr: addr1 });
        server.add_routes(health_router());
        server.start().await.expect("Failed to start server");
        assert_eq!(
            server.current_addr(),
            addr1,
            "Server should be bound to initial address"
        );

        let client = reqwest::Client::new();
        let response = client
            .get(format!("http://{}/health", addr1))
            .send()
            .await
            .expect("Failed to send request to first server");
        assert_eq!(
            response.status(),
            200,
            "First server should respond to health check"
        );

        let app = server
            .merged_router_clone()
            .expect("Router should exist after start()");
        let listener = tokio::net::TcpListener::bind(addr2)
            .await
            .expect("Failed to bind to new addr");
        let (old_tx, old_handle) = server.install_listener(addr2, listener, app);
        if let Some(tx) = old_tx {
            let _ = tx.send(());
        }
        if let Some(handle) = old_handle {
            let _ = handle.await;
        }
        assert_eq!(
            server.current_addr(),
            addr2,
            "Server address should be updated after restart"
        );
        assert_ne!(
            addr1, addr2,
            "Address should change after restart_with_addr"
        );

        let response = client
            .get(format!("http://{}/health", addr2))
            .send()
            .await
            .expect("Failed to send request to restarted server");
        assert_eq!(
            response.status(),
            200,
            "Restarted server should respond to health check on new address"
        );

        // Acceptable outcomes: nothing answers in time (outer Err) or the request fails
        // (inner Err). Only a successful response is a failure.
        let old_result = tokio::time::timeout(
            std::time::Duration::from_millis(200),
            client.get(format!("http://{}/health", addr1)).send(),
        )
        .await;
        assert!(
            !matches!(old_result, Ok(Ok(_))),
            "Old address should not respond after server restarts"
        );
        server.shutdown().await;
    }

    #[tokio::test]
    async fn test_begin_shutdown_takes_handles_for_lock_free_shutdown() {
        let mut server = WebhookServer::new(WebhookServerConfig { addr: loopback(0) });
        server.add_routes(health_router());
        server.start().await.expect("Failed to start server");

        let (shutdown_tx, handle) = server.begin_shutdown();
        assert!(shutdown_tx.is_some(), "shutdown sender should be available");
        assert!(handle.is_some(), "server handle should be available");

        let (shutdown_tx2, handle2) = server.begin_shutdown();
        assert!(shutdown_tx2.is_none(), "shutdown sender should be consumed");
        assert!(handle2.is_none(), "server handle should be consumed");

        if let Some(tx) = shutdown_tx {
            let _ = tx.send(());
        }
        if let Some(handle) = handle {
            let _ = handle.await;
        }
    }

    #[tokio::test]
    async fn test_restart_with_addr_rollback_on_bind_failure() {
        let addr1 = loopback(reserve_ports(1)[0]);
        let mut server = WebhookServer::new(WebhookServerConfig { addr: addr1 });
        server.add_routes(health_router());
        server.start().await.expect("Failed to start server");

        let client = reqwest::Client::new();
        let response = client
            .get(format!("http://{}/health", addr1))
            .send()
            .await
            .expect("Failed to send request");
        assert_eq!(response.status(), 200, "Server should be listening");

        let app = server
            .merged_router_clone()
            .expect("Router should exist after start()");
        // The running server owns addr1, so binding it again fails regardless of privileges.
        // A "privileged" port would not work here: root and Docker can bind it.
        let result = tokio::net::TcpListener::bind(addr1).await;
        assert!(result.is_err(), "Bind to an address already in use should fail");
        drop(app);

        let response = client
            .get(format!("http://{}/health", addr1))
            .send()
            .await
            .expect("Failed to send request to old address");
        assert_eq!(
            response.status(),
            200,
            "Old listener should still be running after failed restart"
        );
        assert_eq!(
            server.current_addr(),
            addr1,
            "Server address should be restored after failed restart"
        );
        server.shutdown().await;
    }
}