use crate::config::WebhookSourceConfig;
use async_trait::async_trait;
use axum::{Router, extract::State, http::StatusCode, routing::post};
use faucet_core::FaucetError;
use serde_json::Value;
use std::sync::Arc;
use subtle::ConstantTimeEq;
use tokio::sync::{Mutex, Notify};
/// Shared state handed to every request served by the webhook router.
struct AppState {
    /// When `Some`, requests must present this token in the
    /// `Authorization` header (see `token_matches`).
    auth_token: Option<String>,
    /// Optional upper bound on how many payloads are buffered.
    max_payloads: Option<usize>,
    /// Payloads accepted so far, in the order the handler appended them.
    records: Mutex<Vec<Value>>,
    /// Signalled by the handler once the payload cap has been reached.
    done: Notify,
}
impl WebhookSource {
    /// Builds a fresh, empty [`AppState`] seeded from this source's config
    /// (payload cap and optional auth token).
    fn new_state(&self) -> Arc<AppState> {
        let state = AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: self.config.max_payloads,
            done: Notify::new(),
            auth_token: self.config.auth_token.clone(),
        };
        Arc::new(state)
    }

    /// Builds the router that accepts `POST` requests on `path`, with the
    /// configured request-body size limit applied.
    fn build_router(&self, path: &str, state: Arc<AppState>) -> Router {
        let body_limit = axum::extract::DefaultBodyLimit::max(self.config.max_body_bytes);
        Router::new()
            .route(path, post(webhook_handler))
            .layer(body_limit)
            .with_state(state)
    }
}
/// Source connector that collects payloads `POST`ed to an HTTP server it
/// binds for the duration of each fetch.
pub struct WebhookSource {
    /// Connector configuration: listen address, path, limits and auth.
    config: WebhookSourceConfig,
}
impl WebhookSource {
    /// Creates a webhook source from its configuration.
    pub fn new(config: WebhookSourceConfig) -> Self {
        Self { config }
    }

    /// Serves `POST` requests on `config.path` at `config.listen_addr` and
    /// returns the collected payloads.
    ///
    /// Collection stops at the first of:
    /// * `config.timeout_secs` elapsing,
    /// * the handler signalling that `config.max_payloads` was reached,
    /// * the server future finishing (an error is propagated).
    ///
    /// # Errors
    ///
    /// Returns [`FaucetError::Config`] if the listen address cannot be bound
    /// or the server terminates with an error.
    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
        let state = self.new_state();
        let app = self.build_router(&self.config.path, Arc::clone(&state));
        let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
            .await
            .map_err(|e| {
                FaucetError::Config(format!(
                    "failed to bind to {}: {e}",
                    self.config.listen_addr
                ))
            })?;
        tracing::info!(
            addr = %self.config.listen_addr,
            path = %self.config.path,
            "webhook server listening"
        );
        let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
        let done_notified = state.done.notified();
        tokio::select! {
            result = axum::serve(listener, app).into_future() => {
                if let Err(e) = result {
                    return Err(FaucetError::Config(format!("webhook server error: {e}")));
                }
            }
            () = timeout => {
                tracing::info!("webhook timeout reached");
            }
            () = done_notified => {
                tracing::info!("max payloads reached");
            }
        }
        // Move the buffered payloads out instead of deep-cloning every JSON
        // value; this function does not read the state again.
        let records = std::mem::take(&mut *state.records.lock().await);
        tracing::info!(records = records.len(), "webhook fetch complete");
        Ok(records)
    }
}
/// Checks an `Authorization` header value against the expected token.
///
/// Two forms are accepted:
/// * the bare token, e.g. `s3cret`;
/// * the token after a `Bearer` scheme and a single space, e.g.
///   `Bearer s3cret`.
///
/// The scheme name is compared case-insensitively, because RFC 7235 §2.1
/// defines authentication schemes as case-insensitive tokens. Token bytes
/// are compared with constant-time equality. Both forms are always
/// evaluated and combined with `|` rather than `||`, so the result does not
/// short-circuit on the first comparison.
///
/// Returns `false` when no header value was provided.
fn token_matches(provided: Option<&str>, expected: &str) -> bool {
    let Some(p) = provided else {
        return false;
    };
    let exp = expected.as_bytes();
    let raw = bool::from(p.as_bytes().ct_eq(exp));
    let stripped = p
        .split_once(' ')
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map(|(_, token)| bool::from(token.as_bytes().ct_eq(exp)))
        .unwrap_or(false);
    raw | stripped
}
/// Outcome of evaluating one incoming payload against the optional cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PayloadDecision {
    /// Whether the payload should be appended to the buffer.
    accept: bool,
    /// Whether completion should be signalled after handling this payload.
    done: bool,
}

/// Decides what to do with a payload that arrives while `current_len`
/// payloads are already buffered.
///
/// Without a cap every payload is accepted and collection never completes.
/// With a cap of `max`:
/// * a payload is accepted only while fewer than `max` are buffered;
/// * completion is signalled when the buffer is already full, or when this
///   payload fills the last free slot.
fn decide_payload(current_len: usize, max_payloads: Option<usize>) -> PayloadDecision {
    let Some(max) = max_payloads else {
        return PayloadDecision {
            accept: true,
            done: false,
        };
    };
    let accept = current_len < max;
    // `||` short-circuits, so `current_len + 1` is only computed when
    // `current_len < max`. That rules out overflow.
    let done = !accept || current_len + 1 >= max;
    PayloadDecision { accept, done }
}
async fn webhook_handler(
State(state): State<Arc<AppState>>,
headers: axum::http::HeaderMap,
body: axum::body::Bytes,
) -> StatusCode {
if let Some(expected) = &state.auth_token {
let provided = headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok());
if !token_matches(provided, expected) {
return StatusCode::UNAUTHORIZED;
}
}
let value = match serde_json::from_slice::<Value>(&body) {
Ok(v) => v,
Err(_) => {
match String::from_utf8(body.to_vec()) {
Ok(s) => Value::String(s),
Err(_) => return StatusCode::BAD_REQUEST,
}
}
};
let mut records = state.records.lock().await;
let decision = decide_payload(records.len(), state.max_payloads);
if decision.accept {
records.push(value);
}
if decision.done {
state.done.notify_one();
}
StatusCode::OK
}
#[async_trait]
impl faucet_core::Source for WebhookSource {
    /// Collects webhook payloads after substituting `context` into the
    /// configured path.
    ///
    /// With an empty context this defers to [`WebhookSource::fetch_all`].
    /// Otherwise it serves the resolved path. Collection stops on timeout,
    /// on reaching the payload cap, or when the server future finishes.
    ///
    /// # Errors
    ///
    /// Returns [`FaucetError::Config`] if the listen address cannot be bound
    /// or the server terminates with an error.
    async fn fetch_with_context(
        &self,
        context: &std::collections::HashMap<String, serde_json::Value>,
    ) -> Result<Vec<Value>, FaucetError> {
        if context.is_empty() {
            return WebhookSource::fetch_all(self).await;
        }
        let resolved_path = faucet_core::util::substitute_context(&self.config.path, context);
        let state = self.new_state();
        let app = self.build_router(&resolved_path, Arc::clone(&state));
        let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
            .await
            .map_err(|e| {
                FaucetError::Config(format!(
                    "failed to bind to {}: {e}",
                    self.config.listen_addr
                ))
            })?;
        tracing::info!(
            addr = %self.config.listen_addr,
            path = %resolved_path,
            "webhook server listening (with context)"
        );
        let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
        let done_notified = state.done.notified();
        tokio::select! {
            result = axum::serve(listener, app).into_future() => {
                if let Err(e) = result {
                    return Err(FaucetError::Config(format!("webhook server error: {e}")));
                }
            }
            () = timeout => {
                tracing::info!("webhook timeout reached");
            }
            () = done_notified => {
                tracing::info!("max payloads reached");
            }
        }
        // Move the buffered payloads out instead of deep-cloning every JSON
        // value; this function does not read the state again.
        let records = std::mem::take(&mut *state.records.lock().await);
        tracing::info!(
            records = records.len(),
            "webhook fetch complete (with context)"
        );
        Ok(records)
    }

    /// Returns the JSON schema generated for [`WebhookSourceConfig`].
    fn config_schema(&self) -> serde_json::Value {
        serde_json::to_value(faucet_core::schema_for!(WebhookSourceConfig))
            .expect("schema serialization")
    }

    /// Returns the connector name, `"webhook"`.
    fn connector_name(&self) -> &'static str {
        "webhook"
    }

    /// Pre-flight check: tries to bind `listen_addr` and releases the socket
    /// immediately.
    ///
    /// Produces a single `"io"` probe. The probe passes if the bind succeeds.
    /// It fails, with a "not bindable" hint, if the bind errors.
    async fn check(
        &self,
        _ctx: &faucet_core::check::CheckContext,
    ) -> Result<faucet_core::check::CheckReport, FaucetError> {
        use faucet_core::check::{CheckReport, Probe};
        let start = std::time::Instant::now();
        match tokio::net::TcpListener::bind(&self.config.listen_addr).await {
            Ok(listener) => {
                drop(listener);
                Ok(CheckReport::single(Probe::pass("io", start.elapsed())))
            }
            Err(e) => Ok(CheckReport::single(Probe::fail_hint(
                "io",
                start.elapsed(),
                e.to_string(),
                format!("{} is not bindable", self.config.listen_addr),
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn token_matches_accepts_raw_and_bearer() {
        assert!(token_matches(
            Some("sekret-token-value"),
            "sekret-token-value"
        ));
        assert!(token_matches(
            Some("Bearer sekret-token-value"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn token_matches_rejects_wrong_and_missing() {
        assert!(!token_matches(
            Some("wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(
            Some("Bearer wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(None, "sekret-token-value"));
        // A strict prefix of the expected token must not match.
        assert!(!token_matches(
            Some("sekret-token-valu"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn decide_payload_no_cap_always_accepts() {
        for len in [0usize, 1, 100, 10_000] {
            assert_eq!(
                decide_payload(len, None),
                PayloadDecision {
                    accept: true,
                    done: false
                }
            );
        }
    }

    #[test]
    fn decide_payload_accepts_until_cap_then_drops() {
        let max = Some(2);
        assert_eq!(
            decide_payload(0, max),
            PayloadDecision {
                accept: true,
                done: false
            }
        );
        assert_eq!(
            decide_payload(1, max),
            PayloadDecision {
                accept: true,
                done: true
            }
        );
        assert_eq!(
            decide_payload(2, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
        assert_eq!(
            decide_payload(3, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
    }

    /// Feeds more payloads than the cap, one after another, through
    /// `decide_payload`. Concurrent arrivals are covered by
    /// `handler_never_exceeds_cap_under_concurrent_posts`.
    #[test]
    fn cap_invariant_never_exceeded_under_sequential_arrivals() {
        let max = 2usize;
        let mut records: Vec<Value> = Vec::new();
        for i in 0..5 {
            let decision = decide_payload(records.len(), Some(max));
            if decision.accept {
                records.push(json!({ "id": i }));
            }
        }
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn handler_never_exceeds_cap_under_concurrent_posts() {
        let max = 3usize;
        let state = Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: Some(max),
            done: Notify::new(),
            auth_token: None,
        });
        let mut handles = Vec::new();
        for i in 0..50 {
            let st = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                let body = axum::body::Bytes::from(format!("{{\"id\":{i}}}"));
                webhook_handler(State(st), axum::http::HeaderMap::new(), body).await
            }));
        }
        // Dropped payloads past the cap still get 200 OK.
        for h in handles {
            assert_eq!(h.await.unwrap(), StatusCode::OK);
        }
        let records = state.records.lock().await;
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn webhook_collects_payloads() {
        let config = WebhookSourceConfig::new()
            .listen_addr("127.0.0.1:0")
            .max_payloads(2)
            .timeout_secs(5);
        let state = Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: config.max_payloads,
            done: Notify::new(),
            auth_token: config.auth_token.clone(),
        });
        let server_state = Arc::clone(&state);
        let app = Router::new()
            .route(&config.path, post(webhook_handler))
            .with_state(Arc::clone(&state));
        let listener = tokio::net::TcpListener::bind(&config.listen_addr)
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = tokio::spawn(async move {
            let done_notified = server_state.done.notified();
            tokio::select! {
                result = axum::serve(listener, app).into_future() => {
                    if let Err(e) = result {
                        panic!("server error: {e}");
                    }
                }
                () = done_notified => {}
            }
        });
        let client = reqwest::Client::new();
        // Post to the same path the route was registered on, rather than a
        // hard-coded one.
        let url = format!("http://{addr}{}", config.path);
        let resp1 = client
            .post(&url)
            .json(&json!({"event": "created", "id": 1}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);
        let resp2 = client
            .post(&url)
            .json(&json!({"event": "updated", "id": 2}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp2.status(), 200);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        server_handle.abort();
        let records = state.records.lock().await;
        assert_eq!(records.len(), 2);
        assert_eq!(records[0]["event"], "created");
        assert_eq!(records[1]["event"], "updated");
    }

    #[tokio::test]
    async fn check_passes_when_port_is_bindable() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};
        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr("127.0.0.1:0"));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Pass),
            "expected Pass, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 0);
    }

    #[tokio::test]
    async fn check_fails_when_port_is_already_bound() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};
        // Keep this listener alive so the port stays occupied during the check.
        let held = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = held.local_addr().unwrap();
        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr(addr.to_string()));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Fail { .. }),
            "expected Fail, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 1);
        assert!(
            report.probes[0]
                .hint
                .as_deref()
                .unwrap()
                .contains("not bindable")
        );
    }

    #[tokio::test]
    async fn webhook_handles_non_json_body() {
        let state = Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: Some(1),
            done: Notify::new(),
            auth_token: None,
        });
        let server_state = Arc::clone(&state);
        let app = Router::new()
            .route("/webhook", post(webhook_handler))
            .with_state(Arc::clone(&state));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = tokio::spawn(async move {
            let done_notified = server_state.done.notified();
            tokio::select! {
                result = axum::serve(listener, app).into_future() => {
                    if let Err(e) = result {
                        panic!("server error: {e}");
                    }
                }
                () = done_notified => {}
            }
        });
        let client = reqwest::Client::new();
        let resp = client
            .post(format!("http://{addr}/webhook"))
            .body("plain text body")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        server_handle.abort();
        // Non-JSON UTF-8 bodies are stored verbatim as a JSON string.
        let records = state.records.lock().await;
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], Value::String("plain text body".into()));
    }
}