1use crate::config::WebhookSourceConfig;
4use async_trait::async_trait;
5use axum::{Router, extract::State, http::StatusCode, routing::post};
6use faucet_core::FaucetError;
7use serde_json::Value;
8use std::sync::Arc;
9use subtle::ConstantTimeEq;
10use tokio::sync::{Mutex, Notify};
11
12struct AppState {
14 records: Mutex<Vec<Value>>,
15 max_payloads: Option<usize>,
16 done: Notify,
17 auth_token: Option<String>,
19}
20
21impl WebhookSource {
22 fn new_state(&self) -> Arc<AppState> {
23 Arc::new(AppState {
24 records: Mutex::new(Vec::new()),
25 max_payloads: self.config.max_payloads,
26 done: Notify::new(),
27 auth_token: self.config.auth_token.clone(),
28 })
29 }
30
31 fn build_router(&self, path: &str, state: Arc<AppState>) -> Router {
32 Router::new()
33 .route(path, post(webhook_handler))
34 .layer(axum::extract::DefaultBodyLimit::max(
37 self.config.max_body_bytes,
38 ))
39 .with_state(state)
40 }
41}
42
43pub struct WebhookSource {
46 config: WebhookSourceConfig,
47}
48
49impl WebhookSource {
50 pub fn new(config: WebhookSourceConfig) -> Self {
52 Self { config }
53 }
54
55 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
57 let state = self.new_state();
58 let app = self.build_router(&self.config.path, Arc::clone(&state));
59
60 let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
61 .await
62 .map_err(|e| {
63 FaucetError::Config(format!(
64 "failed to bind to {}: {e}",
65 self.config.listen_addr
66 ))
67 })?;
68
69 tracing::info!(
70 addr = %self.config.listen_addr,
71 path = %self.config.path,
72 "webhook server listening"
73 );
74
75 let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
76 let done_notified = state.done.notified();
77
78 tokio::select! {
79 result = axum::serve(listener, app).into_future() => {
80 if let Err(e) = result {
81 return Err(FaucetError::Config(format!("webhook server error: {e}")));
82 }
83 }
84 () = timeout => {
85 tracing::info!("webhook timeout reached");
86 }
87 () = done_notified => {
88 tracing::info!("max payloads reached");
89 }
90 }
91
92 let records = state.records.lock().await.clone();
93 tracing::info!(records = records.len(), "webhook fetch complete");
94 Ok(records)
95 }
96}
97
98fn token_matches(provided: Option<&str>, expected: &str) -> bool {
106 let Some(p) = provided else {
107 return false;
108 };
109 let exp = expected.as_bytes();
110 let raw = bool::from(p.as_bytes().ct_eq(exp));
111 let stripped = p
112 .strip_prefix("Bearer ")
113 .map(|s| bool::from(s.as_bytes().ct_eq(exp)))
114 .unwrap_or(false);
115 raw | stripped
116}
117
/// Outcome of offering one incoming payload to the collector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PayloadDecision {
    /// Whether the payload should be stored.
    accept: bool,
    /// Whether the payload cap is (now) reached, i.e. the collector is finished.
    done: bool,
}

/// Decides what to do with a payload given how many are already stored.
///
/// - With no cap (`None`), every payload is accepted and the collector is
///   never done.
/// - With a cap, payloads are accepted while `current_len` is below it; the
///   payload that fills the last slot is accepted *and* marks the collector
///   done. Once the cap is met, payloads are rejected and `done` stays `true`.
fn decide_payload(current_len: usize, max_payloads: Option<usize>) -> PayloadDecision {
    let Some(cap) = max_payloads else {
        return PayloadDecision {
            accept: true,
            done: false,
        };
    };

    let accept = current_len < cap;
    // `||` short-circuits, so the subtraction only runs when
    // `current_len < cap` and therefore cannot underflow.
    let done = !accept || cap - current_len <= 1;

    PayloadDecision { accept, done }
}
164
165async fn webhook_handler(
167 State(state): State<Arc<AppState>>,
168 headers: axum::http::HeaderMap,
169 body: axum::body::Bytes,
170) -> StatusCode {
171 if let Some(expected) = &state.auth_token {
174 let provided = headers
175 .get(axum::http::header::AUTHORIZATION)
176 .and_then(|v| v.to_str().ok());
177 if !token_matches(provided, expected) {
178 return StatusCode::UNAUTHORIZED;
179 }
180 }
181
182 let value = match serde_json::from_slice::<Value>(&body) {
183 Ok(v) => v,
184 Err(_) => {
185 match String::from_utf8(body.to_vec()) {
187 Ok(s) => Value::String(s),
188 Err(_) => return StatusCode::BAD_REQUEST,
189 }
190 }
191 };
192
193 let mut records = state.records.lock().await;
194 let decision = decide_payload(records.len(), state.max_payloads);
199 if decision.accept {
200 records.push(value);
201 }
202 if decision.done {
203 state.done.notify_one();
204 }
205
206 StatusCode::OK
207}
208
209#[async_trait]
210impl faucet_core::Source for WebhookSource {
211 async fn fetch_with_context(
212 &self,
213 context: &std::collections::HashMap<String, serde_json::Value>,
214 ) -> Result<Vec<Value>, FaucetError> {
215 if context.is_empty() {
216 return WebhookSource::fetch_all(self).await;
217 }
218
219 let resolved_path = faucet_core::util::substitute_context(&self.config.path, context);
221
222 let state = self.new_state();
223 let app = self.build_router(&resolved_path, Arc::clone(&state));
224
225 let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
226 .await
227 .map_err(|e| {
228 FaucetError::Config(format!(
229 "failed to bind to {}: {e}",
230 self.config.listen_addr
231 ))
232 })?;
233
234 tracing::info!(
235 addr = %self.config.listen_addr,
236 path = %resolved_path,
237 "webhook server listening (with context)"
238 );
239
240 let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
241 let done_notified = state.done.notified();
242
243 tokio::select! {
244 result = axum::serve(listener, app).into_future() => {
245 if let Err(e) = result {
246 return Err(FaucetError::Config(format!("webhook server error: {e}")));
247 }
248 }
249 () = timeout => {
250 tracing::info!("webhook timeout reached");
251 }
252 () = done_notified => {
253 tracing::info!("max payloads reached");
254 }
255 }
256
257 let records = state.records.lock().await.clone();
258 tracing::info!(
259 records = records.len(),
260 "webhook fetch complete (with context)"
261 );
262 Ok(records)
263 }
264
265 fn config_schema(&self) -> serde_json::Value {
266 serde_json::to_value(faucet_core::schema_for!(WebhookSourceConfig))
267 .expect("schema serialization")
268 }
269
270 fn connector_name(&self) -> &'static str {
271 "webhook"
272 }
273
274 async fn check(
283 &self,
284 _ctx: &faucet_core::check::CheckContext,
285 ) -> Result<faucet_core::check::CheckReport, FaucetError> {
286 use faucet_core::check::{CheckReport, Probe};
287
288 let start = std::time::Instant::now();
289 match tokio::net::TcpListener::bind(&self.config.listen_addr).await {
290 Ok(listener) => {
291 drop(listener);
293 Ok(CheckReport::single(Probe::pass("io", start.elapsed())))
294 }
295 Err(e) => Ok(CheckReport::single(Probe::fail_hint(
296 "io",
297 start.elapsed(),
298 e.to_string(),
299 format!("{} is not bindable", self.config.listen_addr),
300 ))),
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Upper bound on how long a test waits for the server task to finish.
    const SERVER_SHUTDOWN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    /// Builds handler state directly, without going through `WebhookSource`.
    fn test_state(max_payloads: Option<usize>, auth_token: Option<String>) -> Arc<AppState> {
        Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads,
            done: Notify::new(),
            auth_token,
        })
    }

    /// Serves `webhook_handler` on `path` until `state.done` is signalled.
    fn spawn_server(
        listener: tokio::net::TcpListener,
        path: &str,
        state: Arc<AppState>,
    ) -> tokio::task::JoinHandle<()> {
        let app = Router::new()
            .route(path, post(webhook_handler))
            .with_state(Arc::clone(&state));

        tokio::spawn(async move {
            let done_notified = state.done.notified();
            tokio::select! {
                result = axum::serve(listener, app).into_future() => {
                    if let Err(e) = result {
                        panic!("server error: {e}");
                    }
                }
                () = done_notified => {}
            }
        })
    }

    /// Waits for the server task to stop and surfaces any panic it raised,
    /// instead of sleeping for a fixed time and aborting the task.
    async fn join_server(handle: tokio::task::JoinHandle<()>) {
        tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, handle)
            .await
            .expect("server task should stop once max_payloads is reached")
            .expect("server task panicked");
    }

    #[test]
    fn token_matches_accepts_raw_and_bearer() {
        assert!(token_matches(
            Some("sekret-token-value"),
            "sekret-token-value"
        ));
        assert!(token_matches(
            Some("Bearer sekret-token-value"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn token_matches_rejects_wrong_and_missing() {
        assert!(!token_matches(
            Some("wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(
            Some("Bearer wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(None, "sekret-token-value"));
        // A strict prefix of the expected token must not match.
        assert!(!token_matches(
            Some("sekret-token-valu"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn decide_payload_no_cap_always_accepts() {
        for len in [0usize, 1, 100, 10_000] {
            assert_eq!(
                decide_payload(len, None),
                PayloadDecision {
                    accept: true,
                    done: false
                }
            );
        }
    }

    #[test]
    fn decide_payload_accepts_until_cap_then_drops() {
        let max = Some(2);
        // Room for two more: accept, not done.
        assert_eq!(
            decide_payload(0, max),
            PayloadDecision {
                accept: true,
                done: false
            }
        );
        // This payload fills the last slot: accept and done.
        assert_eq!(
            decide_payload(1, max),
            PayloadDecision {
                accept: true,
                done: true
            }
        );
        // At the cap: drop, still done.
        assert_eq!(
            decide_payload(2, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
        // Beyond the cap: drop, still done.
        assert_eq!(
            decide_payload(3, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
    }

    #[test]
    fn cap_invariant_never_exceeded_under_concurrent_arrivals() {
        // Sequential model of the handler's check-then-push step: each
        // arrival sees the length left by the previous one.
        let max = 2usize;
        let mut records: Vec<Value> = Vec::new();
        for i in 0..5 {
            let decision = decide_payload(records.len(), Some(max));
            if decision.accept {
                records.push(json!({ "id": i }));
            }
        }
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn handler_never_exceeds_cap_under_concurrent_posts() {
        let max = 3usize;
        let state = test_state(Some(max), None);

        let mut handles = Vec::new();
        for i in 0..50 {
            let st = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                let body = axum::body::Bytes::from(format!("{{\"id\":{i}}}"));
                webhook_handler(State(st), axum::http::HeaderMap::new(), body).await
            }));
        }
        for h in handles {
            // Dropped payloads are still acknowledged with 200.
            assert_eq!(h.await.unwrap(), StatusCode::OK);
        }

        let records = state.records.lock().await;
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn webhook_collects_payloads() {
        let config = WebhookSourceConfig::new()
            .listen_addr("127.0.0.1:0")
            .max_payloads(2)
            .timeout_secs(5);

        let state = test_state(config.max_payloads, config.auth_token.clone());

        let listener = tokio::net::TcpListener::bind(&config.listen_addr)
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = spawn_server(listener, &config.path, Arc::clone(&state));

        let client = reqwest::Client::new();
        // Post to the same path the router was built with.
        let url = format!("http://{addr}{}", config.path);

        let resp1 = client
            .post(&url)
            .json(&json!({"event": "created", "id": 1}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);

        let resp2 = client
            .post(&url)
            .json(&json!({"event": "updated", "id": 2}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp2.status(), 200);

        // The second payload reaches the cap, which stops the server task.
        join_server(server_handle).await;

        let records = state.records.lock().await;
        assert_eq!(records.len(), 2);
        assert_eq!(records[0]["event"], "created");
        assert_eq!(records[1]["event"], "updated");
    }

    #[tokio::test]
    async fn check_passes_when_port_is_bindable() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};

        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr("127.0.0.1:0"));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Pass),
            "expected Pass, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 0);
    }

    #[tokio::test]
    async fn check_fails_when_port_is_already_bound() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};

        // Keep this listener alive so the address stays occupied for the probe.
        let held = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = held.local_addr().unwrap();

        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr(addr.to_string()));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Fail { .. }),
            "expected Fail, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 1);
        assert!(
            report.probes[0]
                .hint
                .as_deref()
                .unwrap()
                .contains("not bindable")
        );
    }

    #[tokio::test]
    async fn webhook_handles_non_json_body() {
        let state = test_state(Some(1), None);

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = spawn_server(listener, "/webhook", Arc::clone(&state));

        let client = reqwest::Client::new();
        let resp = client
            .post(format!("http://{addr}/webhook"))
            .body("plain text body")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        // The single payload reaches the cap, which stops the server task.
        join_server(server_handle).await;

        let records = state.records.lock().await;
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], Value::String("plain text body".into()));
    }
}