use async_trait::async_trait;
use serde::Serialize;
use super::{SubscriptionError, transport::TransportAdapter, types::SubscriptionEvent};
/// Configuration for delivering subscription events to an HTTP webhook endpoint.
///
/// Build with [`WebhookTransportConfig::new`] and the `with_*` builder methods.
///
/// `Debug` is implemented by hand: the HMAC secret is shown only as
/// `<redacted>` and only header *names* are printed, because header values
/// frequently carry credentials (e.g. `Authorization`) and configs end up in logs.
#[derive(Clone)]
pub struct WebhookTransportConfig {
    /// Target URL that events are POSTed to.
    pub url: String,
    /// Optional shared secret; when set, requests are signed with HMAC-SHA256.
    pub secret: Option<String>,
    /// HTTP client timeout in milliseconds (default `30_000`).
    pub timeout_ms: u64,
    /// Upper bound on delivery attempts (default `3`).
    pub max_retries: u32,
    /// Delay before the first retry in milliseconds (default `1000`).
    pub retry_delay_ms: u64,
    /// Extra headers added to every delivery request.
    pub headers: std::collections::HashMap<String, String>,
}
impl std::fmt::Debug for WebhookTransportConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sorted so the output is deterministic regardless of HashMap ordering.
        let mut header_names: Vec<&str> = self.headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();
        f.debug_struct("WebhookTransportConfig")
            .field("url", &self.url)
            .field("secret", &self.secret.as_ref().map(|_| "<redacted>"))
            .field("timeout_ms", &self.timeout_ms)
            .field("max_retries", &self.max_retries)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("header_names", &header_names)
            .finish()
    }
}
impl WebhookTransportConfig {
    /// Creates a config for `url` with no secret, a 30 s timeout, 3 attempts,
    /// a 1 s initial retry delay and no extra headers.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            secret: None,
            timeout_ms: 30_000,
            max_retries: 3,
            retry_delay_ms: 1000,
            headers: std::collections::HashMap::new(),
        }
    }
    /// Sets the shared secret used to sign request bodies.
    #[must_use]
    pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
        self.secret = Some(secret.into());
        self
    }
    /// Sets the HTTP client timeout in milliseconds.
    #[must_use]
    pub const fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }
    /// Sets the maximum number of delivery attempts.
    #[must_use]
    pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }
    /// Sets the initial retry delay in milliseconds.
    #[must_use]
    pub const fn with_retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }
    /// Adds an extra request header; a later call with the same (case-sensitive)
    /// name replaces the earlier value.
    #[must_use]
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }
}
/// JSON body POSTed to a webhook for one subscription event.
///
/// Field declaration order is the serialized key order, so keep it stable.
#[derive(Debug, Clone, Serialize)]
pub struct WebhookPayload {
    /// Identifier of the originating event.
    pub event_id: String,
    /// Name of the subscription the event is delivered under.
    pub subscription_name: String,
    /// Entity type the event concerns.
    pub entity_type: String,
    /// Identifier of the affected entity.
    pub entity_id: String,
    /// Operation name, rendered from the event operation's `Debug` output.
    pub operation: String,
    /// Current entity data.
    pub data: serde_json::Value,
    /// Previous entity data; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub old_data: Option<serde_json::Value>,
    /// Event timestamp as an RFC 3339 string.
    pub timestamp: String,
    /// Sequence number copied from the event.
    pub sequence_number: u64,
}
impl WebhookPayload {
    /// Builds the wire payload for `event` as delivered under `subscription_name`.
    ///
    /// `operation` uses the operation's `Debug` formatting and `timestamp` is
    /// rendered with `to_rfc3339()`; all other fields are copied verbatim.
    #[must_use]
    pub fn from_event(event: &SubscriptionEvent, subscription_name: &str) -> Self {
        let operation = format!("{:?}", event.operation);
        let timestamp = event.timestamp.to_rfc3339();
        Self {
            event_id: event.event_id.clone(),
            subscription_name: subscription_name.to_owned(),
            entity_type: event.entity_type.clone(),
            entity_id: event.entity_id.clone(),
            operation,
            data: event.data.clone(),
            old_data: event.old_data.clone(),
            timestamp,
            sequence_number: event.sequence_number,
        }
    }
}
/// Transport adapter that POSTs subscription events as JSON to a webhook URL.
pub struct WebhookAdapter {
    config: WebhookTransportConfig,
    client: reqwest::Client,
}
impl WebhookAdapter {
    /// Creates an adapter for `config`, validating its URL and building the HTTP client.
    ///
    /// The client re-validates every redirect target with [`validate_webhook_url`].
    /// The default reqwest policy follows redirects blindly, which would let an
    /// allowed public endpoint bounce deliveries and health checks to loopback or
    /// private hosts and bypass the SSRF check.
    ///
    /// NOTE(review): hostnames are not resolved during validation, so a DNS name
    /// that resolves to a private address is still reachable.
    ///
    /// # Errors
    /// Returns [`SubscriptionError::Internal`] if the URL is rejected or the HTTP
    /// client cannot be built.
    pub fn new(config: WebhookTransportConfig) -> Result<Self, SubscriptionError> {
        validate_webhook_url(&config.url)?;
        let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
            // Redirect cap, roughly matching reqwest's default limit of 10.
            const MAX_REDIRECTS: usize = 10;
            if attempt.previous().len() > MAX_REDIRECTS {
                return attempt.error("too many redirects");
            }
            if validate_webhook_url(attempt.url().as_str()).is_ok() {
                attempt.follow()
            } else {
                // Build the message before `attempt` is consumed by `error`.
                let reason =
                    format!("redirect to {} blocked by SSRF protection", attempt.url());
                attempt.error(reason)
            }
        });
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_millis(config.timeout_ms))
            .redirect(redirect_policy)
            .build()
            .map_err(|e| SubscriptionError::Internal(format!("HTTP client init failed: {e}")))?;
        Ok(Self { config, client })
    }
    /// Returns the hex-encoded HMAC-SHA256 of `payload` keyed with the configured
    /// secret, or `None` when no secret is configured.
    fn compute_signature(&self, payload: &str) -> Option<String> {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let secret = self.config.secret.as_ref()?;
        // HMAC accepts keys of any length, so this constructor cannot fail here.
        #[allow(clippy::expect_used)]
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
            .expect("SHA-256 HMAC accepts any key size");
        mac.update(payload.as_bytes());
        let result = mac.finalize();
        Some(hex::encode(result.into_bytes()))
    }
}
#[async_trait]
impl TransportAdapter for WebhookAdapter {
    /// POSTs `event` as JSON to the configured URL, retrying with exponential backoff.
    ///
    /// Each request carries:
    /// - `X-FraiseQL-Event-Id`;
    /// - `X-FraiseQL-Event-Type` (the subscription name);
    /// - the configured extra headers;
    /// - when a secret is set, `X-FraiseQL-Signature: sha256=<hex HMAC of the body>`.
    ///
    /// Outcomes by response:
    /// - A 2xx status succeeds immediately.
    /// - A 4xx other than 429, or a request that cannot be built, fails immediately.
    /// - A 429, any other non-success status, or a transport error is retried.
    ///
    /// `max_retries` bounds the total number of attempts (at least one is made).
    /// The wait starts at `retry_delay_ms` and doubles after each retry.
    ///
    /// # Errors
    /// Returns [`SubscriptionError::Internal`] if the payload cannot be
    /// serialized, on a non-retryable failure, or when attempts are exhausted.
    async fn deliver(
        &self,
        event: &SubscriptionEvent,
        subscription_name: &str,
    ) -> Result<(), SubscriptionError> {
        let payload = WebhookPayload::from_event(event, subscription_name);
        let payload_json = serde_json::to_string(&payload).map_err(|e| {
            SubscriptionError::Internal(format!("Failed to serialize payload: {e}"))
        })?;
        // The body is identical on every attempt, so sign it once up front.
        let signature_header = self
            .compute_signature(&payload_json)
            .map(|signature| format!("sha256={signature}"));
        let mut attempt = 0;
        let mut delay = self.config.retry_delay_ms;
        loop {
            attempt += 1;
            let mut request = self
                .client
                .post(&self.config.url)
                .header("Content-Type", "application/json")
                .header("X-FraiseQL-Event-Id", &event.event_id)
                .header("X-FraiseQL-Event-Type", subscription_name);
            if let Some(signature) = &signature_header {
                request = request.header("X-FraiseQL-Signature", signature.as_str());
            }
            for (name, value) in &self.config.headers {
                request = request.header(name, value);
            }
            let result = request.body(payload_json.clone()).send().await;
            match result {
                Ok(response) if response.status().is_success() => {
                    tracing::debug!(
                        url = %self.config.url,
                        event_id = %event.event_id,
                        attempt = attempt,
                        "Webhook delivered successfully"
                    );
                    return Ok(());
                },
                Ok(response) => {
                    let status = response.status();
                    tracing::warn!(
                        url = %self.config.url,
                        event_id = %event.event_id,
                        status = %status,
                        attempt = attempt,
                        "Webhook delivery failed with status"
                    );
                    // Client errors will not change on retry, except rate limiting (429).
                    if status.is_client_error() && status.as_u16() != 429 {
                        return Err(SubscriptionError::Internal(format!(
                            "Webhook delivery failed: {status}"
                        )));
                    }
                },
                Err(e) if e.is_builder() => {
                    // A request that cannot be built (e.g. an invalid configured header)
                    // fails identically on every attempt, so don't spend retries on it.
                    tracing::warn!(
                        url = %self.config.url,
                        event_id = %event.event_id,
                        error = %e,
                        attempt = attempt,
                        "Webhook delivery error"
                    );
                    return Err(SubscriptionError::Internal(format!(
                        "Webhook request could not be built: {e}"
                    )));
                },
                Err(e) => {
                    tracing::warn!(
                        url = %self.config.url,
                        event_id = %event.event_id,
                        error = %e,
                        attempt = attempt,
                        "Webhook delivery error"
                    );
                },
            }
            if attempt >= self.config.max_retries {
                return Err(SubscriptionError::Internal(format!(
                    "Webhook delivery failed after {} attempts",
                    attempt
                )));
            }
            tracing::debug!(delay_ms = delay, "Retrying webhook delivery");
            tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
            // Plain `*= 2` overflows u64 (panicking in debug builds) once enough
            // doublings accumulate with a large `max_retries`.
            delay = delay.saturating_mul(2);
        }
    }
    /// Identifier of this transport.
    fn name(&self) -> &'static str {
        "webhook"
    }
    /// Sends a `HEAD` request to the webhook URL.
    ///
    /// The endpoint counts as healthy on any 2xx, or on 405 (the server is up
    /// but does not allow `HEAD`). Transport errors count as unhealthy.
    async fn health_check(&self) -> bool {
        match self.client.head(&self.config.url).send().await {
            Ok(response) => response.status().is_success() || response.status().as_u16() == 405,
            Err(_) => false,
        }
    }
}
impl std::fmt::Debug for WebhookAdapter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WebhookAdapter")
.field("url", &self.config.url)
.field("has_secret", &self.config.secret.is_some())
.finish_non_exhaustive()
}
}
/// Validates a webhook URL before any request is sent to it.
///
/// Rejects the URL when any of the following holds:
/// - it does not parse;
/// - its scheme is not `http`/`https`;
/// - it has no host;
/// - its host is `localhost` or a `*.localhost` name, with or without trailing dots;
/// - its host is a literal IP in a range blocked by `is_webhook_ssrf_blocked_ip`.
///
/// NOTE(review): domain names are not resolved here, so a public name that
/// resolves to a private address passes this check.
///
/// # Errors
/// Returns [`SubscriptionError::Internal`] describing the first failed check.
pub fn validate_webhook_url(url: &str) -> Result<(), SubscriptionError> {
    let parsed = reqwest::Url::parse(url)
        .map_err(|e| SubscriptionError::Internal(format!("Invalid webhook URL '{url}': {e}")))?;
    // Only HTTP(S) endpoints can be delivered to; fail early with a clear reason.
    let scheme = parsed.scheme();
    if scheme != "http" && scheme != "https" {
        return Err(SubscriptionError::Internal(format!(
            "Webhook URL must use http or https, got '{scheme}': {url}"
        )));
    }
    let host_raw = parsed
        .host_str()
        .ok_or_else(|| SubscriptionError::Internal(format!("Webhook URL has no host: {url}")))?;
    // IPv6 literals come back bracketed ("[::1]"); strip the brackets so they parse.
    let host = host_raw
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host_raw);
    // A trailing dot names the same host ("localhost." == "localhost"); without
    // trimming it the loopback check below is trivially bypassed.
    let lower_host = host.trim_end_matches('.').to_ascii_lowercase();
    if lower_host == "localhost" || lower_host.ends_with(".localhost") {
        return Err(SubscriptionError::Internal(format!(
            "Webhook URL targets a loopback host ({host}) — SSRF protection blocked"
        )));
    }
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        if is_webhook_ssrf_blocked_ip(&ip) {
            return Err(SubscriptionError::Internal(format!(
                "Webhook URL targets a private/reserved IP ({ip}) — SSRF protection blocked"
            )));
        }
    }
    Ok(())
}
/// Returns `true` when `ip` lies in a range webhook delivery must never target.
///
/// IPv4 blocked ranges:
/// - 0.0.0.0/8 ("this network"; 0.0.0.0 reaches the local host on common stacks);
/// - 10/8, 172.16/12 and 192.168/16 (private);
/// - 100.64/10 (shared/CGNAT);
/// - 127/8 (loopback);
/// - 169.254/16 (link-local);
/// - 224/4 (multicast) and 240/4 (reserved, including 255.255.255.255).
///
/// IPv6 blocked ranges:
/// - `::/96` (unspecified, loopback and deprecated IPv4-compatible forms such
///   as `::127.0.0.1`);
/// - every IPv4-mapped `::ffff:0:0/96` address (conservatively, regardless of
///   the embedded IPv4);
/// - fc00::/7 (unique local);
/// - fe80::/10 (link-local);
/// - ff00::/8 (multicast).
fn is_webhook_ssrf_blocked_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [a, b, _, _] = v4.octets();
            a == 0
                || a == 10
                || a == 127
                || (a == 100 && (64..=127).contains(&b))
                || (a == 169 && b == 254)
                || (a == 172 && (16..=31).contains(&b))
                || (a == 192 && b == 168)
                || a >= 224
        },
        std::net::IpAddr::V6(v6) => {
            let s = v6.segments();
            // ::/96 — covers ::, ::1 and IPv4-compatible addresses like ::127.0.0.1.
            let ipv4_compatible = s[..6].iter().all(|&seg| seg == 0);
            // ::ffff:0:0/96 — IPv4-mapped addresses.
            let ipv4_mapped = s[..5].iter().all(|&seg| seg == 0) && s[5] == 0xffff;
            v6.is_loopback()
                || v6.is_unspecified()
                || ipv4_compatible
                || ipv4_mapped
                || (s[0] & 0xfe00) == 0xfc00
                || (s[0] & 0xffc0) == 0xfe80
                || (s[0] & 0xff00) == 0xff00
        },
    }
}