Skip to main content

faucet_source_webhook/
stream.rs

//! Webhook source stream executor.
2
3use crate::config::WebhookSourceConfig;
4use async_trait::async_trait;
5use axum::{Router, extract::State, http::StatusCode, routing::post};
6use faucet_core::FaucetError;
7use serde_json::Value;
8use std::sync::Arc;
9use subtle::ConstantTimeEq;
10use tokio::sync::{Mutex, Notify};
11
12/// Shared state for the webhook HTTP handler.
13struct AppState {
14    records: Mutex<Vec<Value>>,
15    max_payloads: Option<usize>,
16    done: Notify,
17    /// Optional shared secret required in the `Authorization` header.
18    auth_token: Option<String>,
19}
20
21impl WebhookSource {
22    fn new_state(&self) -> Arc<AppState> {
23        Arc::new(AppState {
24            records: Mutex::new(Vec::new()),
25            max_payloads: self.config.max_payloads,
26            done: Notify::new(),
27            auth_token: self.config.auth_token.clone(),
28        })
29    }
30
31    fn build_router(&self, path: &str, state: Arc<AppState>) -> Router {
32        Router::new()
33            .route(path, post(webhook_handler))
34            // Bound request body size so a single huge POST can't exhaust
35            // memory (#78/#26).
36            .layer(axum::extract::DefaultBodyLimit::max(
37                self.config.max_body_bytes,
38            ))
39            .with_state(state)
40    }
41}
42
43/// A webhook receiver source that starts a temporary HTTP server and
44/// collects incoming POST payloads as records.
45pub struct WebhookSource {
46    config: WebhookSourceConfig,
47}
48
49impl WebhookSource {
50    /// Create a new webhook source from the given configuration.
51    pub fn new(config: WebhookSourceConfig) -> Self {
52        Self { config }
53    }
54
55    /// Start the webhook server, collect payloads, and return them.
56    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
57        let state = self.new_state();
58        let app = self.build_router(&self.config.path, Arc::clone(&state));
59
60        let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
61            .await
62            .map_err(|e| {
63                FaucetError::Config(format!(
64                    "failed to bind to {}: {e}",
65                    self.config.listen_addr
66                ))
67            })?;
68
69        tracing::info!(
70            addr = %self.config.listen_addr,
71            path = %self.config.path,
72            "webhook server listening"
73        );
74
75        let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
76        let done_notified = state.done.notified();
77
78        tokio::select! {
79            result = axum::serve(listener, app).into_future() => {
80                if let Err(e) = result {
81                    return Err(FaucetError::Config(format!("webhook server error: {e}")));
82                }
83            }
84            () = timeout => {
85                tracing::info!("webhook timeout reached");
86            }
87            () = done_notified => {
88                tracing::info!("max payloads reached");
89            }
90        }
91
92        let records = state.records.lock().await.clone();
93        tracing::info!(records = records.len(), "webhook fetch complete");
94        Ok(records)
95    }
96}
97
98/// Constant-time check of an `Authorization` header value against the shared
99/// secret. Accepts either the raw token or a `Bearer <token>` form, matching the
100/// original behaviour. The comparison uses [`subtle::ConstantTimeEq`] and a
101/// non-short-circuiting `|` so neither *which* form matched nor *where* the first
102/// byte differs is observable via response timing — closing a side-channel that
103/// could otherwise leak the secret byte-by-byte. (Length difference short-circuits;
104/// the secret's length is not itself sensitive.)
105fn token_matches(provided: Option<&str>, expected: &str) -> bool {
106    let Some(p) = provided else {
107        return false;
108    };
109    let exp = expected.as_bytes();
110    let raw = bool::from(p.as_bytes().ct_eq(exp));
111    let stripped = p
112        .strip_prefix("Bearer ")
113        .map(|s| bool::from(s.as_bytes().ct_eq(exp)))
114        .unwrap_or(false);
115    raw | stripped
116}
117
/// Outcome of evaluating one incoming payload against the stored-record count
/// and the configured cap.
///
/// Kept as plain data returned by a pure function, so the cap invariant can
/// be unit-tested without spinning up concurrent HTTP traffic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct PayloadDecision {
    /// `true` if the payload just received should be pushed into the records.
    accept: bool,
    /// `true` if the cap is now satisfied and the `done` notifier should fire
    /// so the receive loop can stop.
    done: bool,
}
129
130/// Decide whether to store a payload and whether the cap is now satisfied.
131///
132/// `current_len` is the number of records already stored (observed under the
133/// records lock). With no cap, every payload is accepted and the loop never
134/// self-terminates. With a cap, the cap is **exact**: once `current_len` has
135/// reached `max`, further concurrently-accepted POSTs are dropped (so the Vec
136/// never exceeds `max`) while still notifying `done` so a slightly-late request
137/// doesn't wedge the receive loop (#146 LOW).
138fn decide_payload(current_len: usize, max_payloads: Option<usize>) -> PayloadDecision {
139    match max_payloads {
140        None => PayloadDecision {
141            accept: true,
142            done: false,
143        },
144        Some(max) => {
145            if current_len >= max {
146                // Already at (or somehow past) the cap — drop this payload and
147                // signal completion. Pushing here is what let concurrent
148                // in-flight POSTs overflow the Vec past `max`.
149                PayloadDecision {
150                    accept: false,
151                    done: true,
152                }
153            } else {
154                // Room for this one. Accept it; we're done once it fills the
155                // last slot.
156                PayloadDecision {
157                    accept: true,
158                    done: current_len + 1 >= max,
159                }
160            }
161        }
162    }
163}
164
165/// Axum handler for incoming webhook POST requests.
166async fn webhook_handler(
167    State(state): State<Arc<AppState>>,
168    headers: axum::http::HeaderMap,
169    body: axum::body::Bytes,
170) -> StatusCode {
171    // Optional shared-secret check: accept either the raw token or
172    // `Bearer <token>` in the Authorization header (#78/#26).
173    if let Some(expected) = &state.auth_token {
174        let provided = headers
175            .get(axum::http::header::AUTHORIZATION)
176            .and_then(|v| v.to_str().ok());
177        if !token_matches(provided, expected) {
178            return StatusCode::UNAUTHORIZED;
179        }
180    }
181
182    let value = match serde_json::from_slice::<Value>(&body) {
183        Ok(v) => v,
184        Err(_) => {
185            // If the body is not valid JSON, wrap it as a string.
186            match String::from_utf8(body.to_vec()) {
187                Ok(s) => Value::String(s),
188                Err(_) => return StatusCode::BAD_REQUEST,
189            }
190        }
191    };
192
193    let mut records = state.records.lock().await;
194    // Decide under the lock so the cap is exact: concurrent in-flight POSTs
195    // that have already been accepted can't push the Vec past `max_payloads`
196    // — once the cap is reached we drop the surplus payload instead of
197    // storing it (#146 LOW).
198    let decision = decide_payload(records.len(), state.max_payloads);
199    if decision.accept {
200        records.push(value);
201    }
202    if decision.done {
203        state.done.notify_one();
204    }
205
206    StatusCode::OK
207}
208
209#[async_trait]
210impl faucet_core::Source for WebhookSource {
211    async fn fetch_with_context(
212        &self,
213        context: &std::collections::HashMap<String, serde_json::Value>,
214    ) -> Result<Vec<Value>, FaucetError> {
215        if context.is_empty() {
216            return WebhookSource::fetch_all(self).await;
217        }
218
219        // Substitute context into the webhook path.
220        let resolved_path = faucet_core::util::substitute_context(&self.config.path, context);
221
222        let state = self.new_state();
223        let app = self.build_router(&resolved_path, Arc::clone(&state));
224
225        let listener = tokio::net::TcpListener::bind(&self.config.listen_addr)
226            .await
227            .map_err(|e| {
228                FaucetError::Config(format!(
229                    "failed to bind to {}: {e}",
230                    self.config.listen_addr
231                ))
232            })?;
233
234        tracing::info!(
235            addr = %self.config.listen_addr,
236            path = %resolved_path,
237            "webhook server listening (with context)"
238        );
239
240        let timeout = tokio::time::sleep(std::time::Duration::from_secs(self.config.timeout_secs));
241        let done_notified = state.done.notified();
242
243        tokio::select! {
244            result = axum::serve(listener, app).into_future() => {
245                if let Err(e) = result {
246                    return Err(FaucetError::Config(format!("webhook server error: {e}")));
247                }
248            }
249            () = timeout => {
250                tracing::info!("webhook timeout reached");
251            }
252            () = done_notified => {
253                tracing::info!("max payloads reached");
254            }
255        }
256
257        let records = state.records.lock().await.clone();
258        tracing::info!(
259            records = records.len(),
260            "webhook fetch complete (with context)"
261        );
262        Ok(records)
263    }
264
265    fn config_schema(&self) -> serde_json::Value {
266        serde_json::to_value(faucet_core::schema_for!(WebhookSourceConfig))
267            .expect("schema serialization")
268    }
269
270    fn connector_name(&self) -> &'static str {
271        "webhook"
272    }
273
274    /// Preflight probe that does **not** start the receive loop.
275    ///
276    /// The default `Source::check` would call `stream_pages`, which boots the
277    /// HTTP server and blocks for the whole receive window waiting for inbound
278    /// POSTs — useless as a fast preflight. Instead we just verify the
279    /// configured `listen_addr` is bindable: bind a `tokio::net::TcpListener`
280    /// to it and immediately drop it. Success means the port is free; a bind
281    /// error (port in use, permission denied, bad address) fails the probe.
282    async fn check(
283        &self,
284        _ctx: &faucet_core::check::CheckContext,
285    ) -> Result<faucet_core::check::CheckReport, FaucetError> {
286        use faucet_core::check::{CheckReport, Probe};
287
288        let start = std::time::Instant::now();
289        match tokio::net::TcpListener::bind(&self.config.listen_addr).await {
290            Ok(listener) => {
291                // Drop the listener immediately so we don't hold the port.
292                drop(listener);
293                Ok(CheckReport::single(Probe::pass("io", start.elapsed())))
294            }
295            Err(e) => Ok(CheckReport::single(Probe::fail_hint(
296                "io",
297                start.elapsed(),
298                e.to_string(),
299                format!("{} is not bindable", self.config.listen_addr),
300            ))),
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;

    /// How long a test waits for the receive loop to stop by itself once the
    /// payload cap has been reached. Generous so slow CI doesn't flake; a
    /// healthy run finishes in milliseconds.
    const SHUTDOWN_GRACE: Duration = Duration::from_secs(5);

    /// Serve `app` on `listener` until `state.done` is notified, mirroring
    /// the receive loop in `WebhookSource::fetch_all`. The returned task only
    /// completes on its own if the handler actually signals `done`.
    fn spawn_until_done(
        listener: tokio::net::TcpListener,
        app: Router,
        state: Arc<AppState>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let done_notified = state.done.notified();
            tokio::select! {
                result = axum::serve(listener, app).into_future() => {
                    if let Err(e) = result {
                        panic!("server error: {e}");
                    }
                }
                () = done_notified => {}
            }
        })
    }

    /// Wait for a task from `spawn_until_done` to finish.
    ///
    /// This replaces a fixed sleep followed by `abort()`, which would pass
    /// even if `done` never fired.
    async fn await_shutdown(server: tokio::task::JoinHandle<()>) {
        tokio::time::timeout(SHUTDOWN_GRACE, server)
            .await
            .expect("receive loop did not stop after max_payloads was reached")
            .expect("server task panicked");
    }

    #[test]
    fn token_matches_accepts_raw_and_bearer() {
        assert!(token_matches(
            Some("sekret-token-value"),
            "sekret-token-value"
        ));
        assert!(token_matches(
            Some("Bearer sekret-token-value"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn token_matches_rejects_wrong_and_missing() {
        assert!(!token_matches(
            Some("wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(
            Some("Bearer wrong-token-value"),
            "sekret-token-value"
        ));
        assert!(!token_matches(None, "sekret-token-value"));
        // Differing length must reject (not panic).
        assert!(!token_matches(
            Some("sekret-token-valu"),
            "sekret-token-value"
        ));
    }

    #[test]
    fn decide_payload_no_cap_always_accepts() {
        for len in [0usize, 1, 100, 10_000] {
            assert_eq!(
                decide_payload(len, None),
                PayloadDecision {
                    accept: true,
                    done: false
                }
            );
        }
    }

    #[test]
    fn decide_payload_accepts_until_cap_then_drops() {
        let max = Some(2);
        // 0 stored: accept, not yet done.
        assert_eq!(
            decide_payload(0, max),
            PayloadDecision {
                accept: true,
                done: false
            }
        );
        // 1 stored: accept the last slot, now done.
        assert_eq!(
            decide_payload(1, max),
            PayloadDecision {
                accept: true,
                done: true
            }
        );
        // 2 stored (at cap): drop, signal done.
        assert_eq!(
            decide_payload(2, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
        // 3 stored (somehow past cap): still drop, still done.
        assert_eq!(
            decide_payload(3, max),
            PayloadDecision {
                accept: false,
                done: true
            }
        );
    }

    #[test]
    fn cap_invariant_never_exceeded_under_concurrent_arrivals() {
        // Simulate many in-flight POSTs all racing the cap: the same record
        // count can be observed by several requests before any of them push.
        // Applying `decide_payload` per request must still bound the Vec at
        // `max`, dropping the surplus rather than overflowing (#146 LOW).
        let max = 2usize;
        let mut records: Vec<Value> = Vec::new();
        // 5 concurrent arrivals; each decides against the current length and
        // only pushes if accepted (mirroring the handler under the records
        // lock).
        for i in 0..5 {
            let decision = decide_payload(records.len(), Some(max));
            if decision.accept {
                records.push(json!({ "id": i }));
            }
        }
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn handler_never_exceeds_cap_under_concurrent_posts() {
        // Drive the real handler with many concurrent requests against one
        // shared AppState. The records Vec must never exceed `max_payloads`,
        // even though several requests are accepted in-flight before any of
        // them observe the cap being hit (#146 LOW).
        let max = 3usize;
        let state = Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: Some(max),
            done: Notify::new(),
            auth_token: None,
        });

        let mut handles = Vec::new();
        for i in 0..50 {
            let st = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                let body = axum::body::Bytes::from(format!("{{\"id\":{i}}}"));
                webhook_handler(State(st), axum::http::HeaderMap::new(), body).await
            }));
        }
        for h in handles {
            // Every request returns 200 (accepted or gracefully dropped — we
            // don't surface a different status for drops to keep clients happy).
            assert_eq!(h.await.unwrap(), StatusCode::OK);
        }

        // Reaching the cap must leave a `done` permit for the receive loop:
        // `notify_one` stores one even though nothing was waiting yet.
        tokio::time::timeout(SHUTDOWN_GRACE, state.done.notified())
            .await
            .expect("done must be signalled once max_payloads is reached");

        let records = state.records.lock().await;
        assert_eq!(
            records.len(),
            max,
            "Vec must never exceed max_payloads, got {}",
            records.len()
        );
    }

    #[tokio::test]
    async fn webhook_collects_payloads() {
        // Use port 0 to get a random available port.
        let source = WebhookSource::new(
            WebhookSourceConfig::new()
                .listen_addr("127.0.0.1:0")
                .max_payloads(2)
                .timeout_secs(5),
        );

        // Use the production state/router builders so the body-limit layer
        // is exercised too.
        let state = source.new_state();
        let app = source.build_router(&source.config.path, Arc::clone(&state));

        let listener = tokio::net::TcpListener::bind(&source.config.listen_addr)
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = spawn_until_done(listener, app, Arc::clone(&state));

        let client = reqwest::Client::new();
        let url = format!("http://{addr}/webhook");

        // Send two payloads.
        let resp1 = client
            .post(&url)
            .json(&json!({"event": "created", "id": 1}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp1.status(), 200);

        let resp2 = client
            .post(&url)
            .json(&json!({"event": "updated", "id": 2}))
            .send()
            .await
            .unwrap();
        assert_eq!(resp2.status(), 200);

        // Hitting max_payloads must stop the receive loop on its own.
        await_shutdown(server_handle).await;

        let records = state.records.lock().await;
        assert_eq!(records.len(), 2);
        assert_eq!(records[0]["event"], "created");
        assert_eq!(records[1]["event"], "updated");
    }

    #[tokio::test]
    async fn check_passes_when_port_is_bindable() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};

        // Port 0 = let the OS pick a free port, so the bind always succeeds.
        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr("127.0.0.1:0"));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Pass),
            "expected Pass, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 0);
    }

    #[tokio::test]
    async fn check_fails_when_port_is_already_bound() {
        use faucet_core::Source;
        use faucet_core::check::{CheckContext, ProbeStatus};

        // Hold a real listener, then point the source at the same address so
        // the probe's bind collides.
        let held = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = held.local_addr().unwrap();

        let source = WebhookSource::new(WebhookSourceConfig::new().listen_addr(addr.to_string()));
        let report = source.check(&CheckContext::default()).await.unwrap();
        assert_eq!(report.probes.len(), 1);
        assert_eq!(report.probes[0].name, "io");
        assert!(
            matches!(report.probes[0].status, ProbeStatus::Fail { .. }),
            "expected Fail, got {:?}",
            report.probes[0].status
        );
        assert_eq!(report.failed_count(), 1);
        assert!(
            report.probes[0]
                .hint
                .as_deref()
                .unwrap()
                .contains("not bindable")
        );
    }

    #[tokio::test]
    async fn webhook_handles_non_json_body() {
        let state = Arc::new(AppState {
            records: Mutex::new(Vec::new()),
            max_payloads: Some(1),
            done: Notify::new(),
            auth_token: None,
        });

        let app = Router::new()
            .route("/webhook", post(webhook_handler))
            .with_state(Arc::clone(&state));

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = spawn_until_done(listener, app, Arc::clone(&state));

        let client = reqwest::Client::new();
        let resp = client
            .post(format!("http://{addr}/webhook"))
            .body("plain text body")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);

        // The single payload fills the cap, so the receive loop must stop.
        await_shutdown(server_handle).await;

        let records = state.records.lock().await;
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], Value::String("plain text body".into()));
    }
}