use std::net::SocketAddr;
use std::sync::Arc;
use axum::{
body::Body,
extract::{ConnectInfo, State},
http::{Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use ipnet::IpNet;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use super::events::WebhookEvent;
use super::logging::WebhookLogger;
/// Configuration for the webhook server: request filtering, logging, and the
/// size of the event channel handed to the consumer.
#[derive(Clone)]
pub struct WebhookServerConfig {
/// Networks a request's source IP must fall within. When empty, the
/// security middleware skips the IP check entirely.
pub allowed_ips: Vec<IpNet>,
/// Name of the request header that carries the shared secret. The header
/// check runs only when both this and `auth_header_value` are `Some`.
pub auth_header_name: Option<String>,
/// Expected value of the authentication header. Redacted in `Debug` output.
pub auth_header_value: Option<String>,
/// When `true`, the webhook handler emits a `tracing` info record for each
/// received event.
pub enable_logging: bool,
/// Optional logger whose `log_received` is awaited for every event; a
/// failure there is only warned about and does not reject the request.
pub db_logger: Option<Arc<dyn WebhookLogger>>,
/// Capacity passed to `tokio::sync::mpsc::channel` when the server is built.
pub channel_buffer_size: usize,
}
/// Hand-written `Debug` so the authentication secret and the logger object
/// never end up in log output.
impl std::fmt::Debug for WebhookServerConfig {
    /// Formats every field, substituting placeholders for the secret header
    /// value and for the logger trait object.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map over the options instead of printing a fixed string so the
        // output still shows whether a secret / logger is configured at all.
        let redacted_secret = self.auth_header_value.as_ref().map(|_| "[REDACTED]");
        let logger_marker = self.db_logger.as_ref().map(|_| "[logger]");

        f.debug_struct("WebhookServerConfig")
            .field("allowed_ips", &self.allowed_ips)
            .field("auth_header_name", &self.auth_header_name)
            .field("auth_header_value", &redacted_secret)
            .field("enable_logging", &self.enable_logging)
            .field("db_logger", &logger_marker)
            .field("channel_buffer_size", &self.channel_buffer_size)
            .finish()
    }
}
impl Default for WebhookServerConfig {
fn default() -> Self {
Self {
allowed_ips: Vec::new(),
auth_header_name: None,
auth_header_value: None,
enable_logging: true,
db_logger: None,
channel_buffer_size: 1000,
}
}
}
/// Consuming builder methods; each one returns the updated configuration so
/// calls can be chained.
impl WebhookServerConfig {
    /// Creates a configuration holding the `Default` values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the list of networks allowed to deliver webhooks.
    pub fn with_allowed_ips(self, ips: Vec<IpNet>) -> Self {
        Self {
            allowed_ips: ips,
            ..self
        }
    }

    /// Sets both the name and the expected value of the authentication header.
    pub fn with_auth_header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            auth_header_name: Some(name.into()),
            auth_header_value: Some(value.into()),
            ..self
        }
    }

    /// Installs the logger object that is notified of every received event.
    pub fn with_logger(self, logger: Arc<dyn WebhookLogger>) -> Self {
        Self {
            db_logger: Some(logger),
            ..self
        }
    }

    /// Turns the per-event `tracing` info record on or off.
    pub fn with_stdout_logging(self, enabled: bool) -> Self {
        Self {
            enable_logging: enabled,
            ..self
        }
    }

    /// Sets the capacity used for the event channel.
    pub fn with_channel_buffer(self, size: usize) -> Self {
        Self {
            channel_buffer_size: size,
            ..self
        }
    }
}
/// State made available to the request handler and the security middleware
/// through axum's `State` extractor (wrapped in an `Arc` by `router`).
#[derive(Clone)]
struct ServerState {
// Filtering and logging settings consulted on every request.
config: WebhookServerConfig,
// Sending half of the channel on which parsed events are delivered.
event_sender: mpsc::Sender<WebhookEvent>,
}
/// HTTP server that accepts webhook POSTs and forwards them as
/// `WebhookEvent`s over an mpsc channel.
pub struct WebhookServer {
// Configuration plus the event sender; moved into the router when it is built.
state: ServerState,
}
impl WebhookServer {
    /// Creates a server with the default configuration.
    ///
    /// Returns the server together with the receiving half of the event
    /// channel; events are delivered to that receiver.
    pub fn new() -> (Self, mpsc::Receiver<WebhookEvent>) {
        Self::with_config(WebhookServerConfig::default())
    }

    /// Creates a server from `config` and the receiver for its events.
    ///
    /// The event channel is created with `config.channel_buffer_size` slots.
    /// A size of zero is raised to one, because `tokio::sync::mpsc::channel`
    /// panics when asked for a zero-capacity channel.
    pub fn with_config(config: WebhookServerConfig) -> (Self, mpsc::Receiver<WebhookEvent>) {
        let capacity = config.channel_buffer_size.max(1);
        let (event_sender, receiver) = mpsc::channel(capacity);
        let state = ServerState {
            config,
            event_sender,
        };
        (Self { state }, receiver)
    }

    /// Builds the axum router.
    ///
    /// Routes: `POST /webhooks/payrix` and `GET /health`. The security
    /// middleware is layered over both routes. The handler and the middleware
    /// extract `ConnectInfo<SocketAddr>`, so the router has to be served with
    /// connect info (as `run` and `run_with_shutdown` do).
    pub fn router(self) -> Router {
        let state = Arc::new(self.state);
        Router::new()
            .route("/webhooks/payrix", post(handle_webhook))
            .route("/health", get(health_check))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                security_middleware,
            ))
            .with_state(state)
    }

    /// Binds `addr` and serves requests until the server stops.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from binding the listener or from serving.
    pub async fn run(self, addr: SocketAddr) -> Result<(), std::io::Error> {
        let router = self.router().into_make_service_with_connect_info::<SocketAddr>();
        info!("Starting webhook server on {}", addr);
        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, router).await
    }

    /// Binds `addr` and serves requests until `shutdown_signal` completes,
    /// at which point axum's graceful shutdown is triggered.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from binding the listener or from serving.
    pub async fn run_with_shutdown<F>(
        self,
        addr: SocketAddr,
        shutdown_signal: F,
    ) -> Result<(), std::io::Error>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        let router = self.router().into_make_service_with_connect_info::<SocketAddr>();
        info!("Starting webhook server on {} (with graceful shutdown)", addr);
        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, router)
            .with_graceful_shutdown(shutdown_signal)
            .await
    }
}
impl Default for WebhookServer {
    /// Builds a server with the default configuration.
    ///
    /// The receiving half of the event channel is dropped here, so a server
    /// obtained this way has no consumer for its events. Use
    /// `WebhookServer::new` when the events need to be read.
    fn default() -> Self {
        let (server, _unused_receiver) = Self::new();
        server
    }
}
/// Returns the string stored under `key` in `payload`, or `None` when the key
/// is missing or its value is not a JSON string.
fn payload_str<'a>(payload: &'a serde_json::Value, key: &str) -> Option<&'a str> {
    payload.get(key).and_then(serde_json::Value::as_str)
}

/// Handles `POST /webhooks/payrix`.
///
/// Extracts `event`, `resourceType` and `resourceId` (falling back to `id`)
/// from the JSON body, substituting `"unknown"` for anything missing or not a
/// string. The event data is the `resource` member when present, otherwise
/// the whole payload. The resulting `WebhookEvent` is optionally logged and
/// then sent on the event channel.
///
/// Responds `200 OK` once the event is queued, or `500` when the channel send
/// fails. A failure of the configured logger is only warned about.
async fn handle_webhook(
    State(state): State<Arc<ServerState>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    Json(payload): Json<serde_json::Value>,
) -> impl IntoResponse {
    let source_ip = addr.ip();

    let event_type = payload_str(&payload, "event")
        .unwrap_or("unknown")
        .to_string();
    let resource_type = payload_str(&payload, "resourceType")
        .unwrap_or("unknown")
        .to_string();
    // Fall back to `id` whenever `resourceId` is unusable, i.e. missing *or*
    // present with a non-string value such as `null`.
    let resource_id = payload_str(&payload, "resourceId")
        .or_else(|| payload_str(&payload, "id"))
        .unwrap_or("unknown")
        .to_string();
    let data = payload
        .get("resource")
        .cloned()
        .unwrap_or_else(|| payload.clone());

    // `event_type` is not needed again, so it is moved rather than cloned.
    let event = WebhookEvent::new(event_type, resource_type, resource_id, data, source_ip);

    if state.config.enable_logging {
        info!(
            event_type = %event.event_type,
            resource_id = %event.resource_id,
            source_ip = %source_ip,
            "Received webhook event"
        );
    }

    // Best effort: a logging failure must not cause the webhook to be rejected.
    if let Some(logger) = &state.config.db_logger {
        if let Err(e) = logger.log_received(&event).await {
            warn!("Failed to log webhook event: {}", e);
        }
    }

    // NOTE(review): `send` waits for free capacity when the channel is full,
    // so a slow consumer delays the HTTP response — confirm this is intended.
    if let Err(e) = state.event_sender.send(event).await {
        warn!("Failed to send webhook event to channel: {}", e);
        return (StatusCode::INTERNAL_SERVER_ERROR, "Event processing failed");
    }

    (StatusCode::OK, "OK")
}
/// Handles `GET /health`: always answers `200 OK` with the body `OK`.
async fn health_check() -> impl IntoResponse {
    let status = StatusCode::OK;
    (status, "OK")
}
/// Compares two byte strings without stopping at the first differing byte.
///
/// Only a length mismatch returns early; for equal lengths every byte pair is
/// examined, so the time taken does not reveal how long a matching prefix is.
fn auth_values_match(provided: &[u8], expected: &[u8]) -> bool {
    if provided.len() != expected.len() {
        return false;
    }
    provided
        .iter()
        .zip(expected)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}

/// Middleware that filters requests before they reach a handler.
///
/// Checks, in order:
/// 1. If `allowed_ips` is non-empty, the peer IP must be inside one of the
///    listed networks, otherwise the response is `403 Forbidden`.
/// 2. If both `auth_header_name` and `auth_header_value` are configured, the
///    named header must be present, convertible with `HeaderValue::to_str`,
///    and equal to the expected value, otherwise the response is
///    `401 Unauthorized`.
///
/// Requests passing both checks are forwarded to `next`.
async fn security_middleware(
    State(state): State<Arc<ServerState>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let source_ip = addr.ip();
    let config = &state.config;

    // An empty allowlist means "no IP restriction".
    if !config.allowed_ips.is_empty() {
        let allowed = config
            .allowed_ips
            .iter()
            .any(|net| net.contains(&source_ip));
        if !allowed {
            warn!(
                source_ip = %source_ip,
                "Webhook request from unauthorized IP"
            );
            return (StatusCode::FORBIDDEN, "IP not allowed").into_response();
        }
    }

    if let (Some(header_name), Some(expected_value)) =
        (&config.auth_header_name, &config.auth_header_value)
    {
        // The header value comes from the network, so compare it against the
        // secret in constant time rather than with `==`.
        let authenticated = request
            .headers()
            .get(header_name)
            .and_then(|value| value.to_str().ok())
            .map_or(false, |actual| {
                auth_values_match(actual.as_bytes(), expected_value.as_bytes())
            });
        if !authenticated {
            warn!(
                source_ip = %source_ip,
                header = %header_name,
                "Webhook request with invalid authentication"
            );
            return (StatusCode::UNAUTHORIZED, "Invalid authentication").into_response();
        }
    }

    debug!(source_ip = %source_ip, "Webhook request passed security checks");
    next.run(request).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder setter should land in the matching configuration field.
    #[test]
    fn test_config_builder() {
        let built = WebhookServerConfig::new()
            .with_auth_header("X-Secret", "value")
            .with_stdout_logging(false)
            .with_channel_buffer(500);

        assert_eq!(built.auth_header_name.as_deref(), Some("X-Secret"));
        assert_eq!(built.auth_header_value.as_deref(), Some("value"));
        assert!(!built.enable_logging);
        assert_eq!(built.channel_buffer_size, 500);
    }

    /// The allowlist setter should keep every network it was given.
    #[test]
    fn test_config_with_ip_allowlist() {
        let networks: Vec<IpNet> = vec![
            "10.0.0.0/8".parse().unwrap(),
            "192.168.0.0/16".parse().unwrap(),
        ];

        let built = WebhookServerConfig::new().with_allowed_ips(networks);

        assert_eq!(built.allowed_ips.len(), 2);
    }

    /// A freshly created server has not queued any events yet.
    #[tokio::test]
    async fn test_server_creation() {
        let (_server, mut events) = WebhookServer::new();
        let first_poll = events.try_recv();
        assert!(first_poll.is_err());
    }
}