Skip to main content

forge_runtime/webhook/
handler.rs

1//! Axum handler for webhook requests with signature validation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use axum::{
7    Json,
8    body::Bytes,
9    extract::{Path, State},
10    http::{HeaderMap, StatusCode},
11    response::IntoResponse,
12};
13use forge_core::CircuitBreakerClient;
14use forge_core::function::JobDispatch;
15use forge_core::webhook::{IdempotencySource, SignatureAlgorithm, WebhookContext};
16use hmac::{Hmac, Mac};
17use serde_json::{Value, json};
18use sha1::Sha1;
19use sha2::{Sha256, Sha512};
20use sqlx::PgPool;
21use tracing::{error, info, warn};
22use uuid::Uuid;
23
24use super::registry::WebhookRegistry;
25
26/// State for webhook handler.
27#[derive(Clone)]
28pub struct WebhookState {
29    registry: Arc<WebhookRegistry>,
30    pool: PgPool,
31    http_client: CircuitBreakerClient,
32    job_dispatcher: Option<Arc<dyn JobDispatch>>,
33}
34
35impl WebhookState {
36    /// Create new webhook state.
37    pub fn new(registry: Arc<WebhookRegistry>, pool: PgPool) -> Self {
38        Self {
39            registry,
40            pool,
41            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
42            job_dispatcher: None,
43        }
44    }
45
46    /// Set job dispatcher.
47    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
48        self.job_dispatcher = Some(dispatcher);
49        self
50    }
51}
52
53/// Handle webhook requests.
54///
55/// This handler:
56/// 1. Looks up webhook by path
57/// 2. Validates signature if configured
58/// 3. Checks idempotency
59/// 4. Executes handler
60/// 5. Records idempotency key
61pub async fn webhook_handler(
62    State(state): State<Arc<WebhookState>>,
63    Path(path): Path<String>,
64    headers: HeaderMap,
65    body: Bytes,
66) -> impl IntoResponse {
67    let full_path = format!("/webhooks/{}", path);
68    let request_id = Uuid::new_v4().to_string();
69
70    // Look up webhook by path
71    let entry = match state.registry.get_by_path(&full_path) {
72        Some(e) => e,
73        None => {
74            warn!(path = %full_path, "Webhook not found");
75            return (
76                StatusCode::NOT_FOUND,
77                Json(json!({"error": "Webhook not found"})),
78            );
79        }
80    };
81
82    let info = &entry.info;
83    info!(
84        webhook = info.name,
85        path = %full_path,
86        request_id = %request_id,
87        "Webhook request received"
88    );
89
90    if info.signature.is_none() && !info.allow_unsigned {
91        warn!(
92            webhook = info.name,
93            "Unsigned webhook rejected (set allow_unsigned to opt in)"
94        );
95        return (
96            StatusCode::UNAUTHORIZED,
97            Json(json!({"error": "Webhook signature is required"})),
98        );
99    }
100
101    // Validate signature if configured
102    if let Some(ref sig_config) = info.signature {
103        // Get signature from header
104        let signature = match headers
105            .get(sig_config.header_name)
106            .and_then(|v| v.to_str().ok())
107        {
108            Some(s) => s,
109            None => {
110                warn!(webhook = info.name, "Missing signature header");
111                return (
112                    StatusCode::UNAUTHORIZED,
113                    Json(json!({"error": "Missing signature"})),
114                );
115            }
116        };
117
118        // Get secret from environment
119        let secret = match std::env::var(sig_config.secret_env) {
120            Ok(s) => s,
121            Err(_) => {
122                error!(
123                    webhook = info.name,
124                    env = sig_config.secret_env,
125                    "Webhook secret not configured"
126                );
127                return (
128                    StatusCode::INTERNAL_SERVER_ERROR,
129                    Json(json!({"error": "Webhook configuration error"})),
130                );
131            }
132        };
133
134        // Validate signature
135        if !validate_signature(sig_config.algorithm, &body, &secret, signature) {
136            warn!(webhook = info.name, "Invalid signature");
137            return (
138                StatusCode::UNAUTHORIZED,
139                Json(json!({"error": "Invalid signature"})),
140            );
141        }
142    }
143
144    // Extract idempotency key if configured
145    let idempotency_key = if let Some(ref idem_config) = info.idempotency {
146        match &idem_config.source {
147            IdempotencySource::Header(header_name) => headers
148                .get(*header_name)
149                .and_then(|v| v.to_str().ok())
150                .map(|s| s.to_string()),
151            IdempotencySource::Body(json_path) => {
152                // Parse body and extract value using JSON path
153                if let Ok(payload) = serde_json::from_slice::<Value>(&body) {
154                    extract_json_path(&payload, json_path)
155                } else {
156                    None
157                }
158            }
159        }
160    } else {
161        None
162    };
163
164    // Atomically claim idempotency key before execution.
165    let mut idempotency_claimed = false;
166    if let Some(ref key) = idempotency_key
167        && let Some(ref idem_config) = info.idempotency
168    {
169        match claim_idempotency(&state.pool, info.name, key, idem_config.ttl).await {
170            Ok(true) => {
171                idempotency_claimed = true;
172            }
173            Ok(false) => {
174                info!(
175                    webhook = info.name,
176                    idempotency_key = %key,
177                    "Request already processed (idempotent)"
178                );
179                return (StatusCode::OK, Json(json!({"status": "already_processed"})));
180            }
181            Err(e) => {
182                // Fail closed: if idempotency is configured but the DB is unavailable,
183                // reject the request rather than processing without replay protection
184                error!(webhook = info.name, error = %e, "Failed to claim idempotency key -- rejecting request");
185                return (
186                    StatusCode::SERVICE_UNAVAILABLE,
187                    Json(json!({"error": "Service temporarily unavailable"})),
188                );
189            }
190        }
191    }
192
193    // Parse payload
194    let payload: Value = match serde_json::from_slice(&body) {
195        Ok(v) => v,
196        Err(e) => {
197            if idempotency_claimed
198                && let Some(ref key) = idempotency_key
199                && let Err(release_err) = release_idempotency(&state.pool, info.name, key).await
200            {
201                warn!(
202                    webhook = info.name,
203                    error = %release_err,
204                    "Failed to release idempotency key after JSON parse failure"
205                );
206            }
207            warn!(webhook = info.name, error = %e, "Invalid JSON payload");
208            return (
209                StatusCode::BAD_REQUEST,
210                Json(json!({"error": "Invalid JSON"})),
211            );
212        }
213    };
214
215    // Build headers map (lowercase keys)
216    let header_map: HashMap<String, String> = headers
217        .iter()
218        .filter_map(|(k, v)| {
219            v.to_str()
220                .ok()
221                .map(|v| (k.as_str().to_lowercase(), v.to_string()))
222        })
223        .collect();
224
225    // Create context
226    let mut ctx = WebhookContext::new(
227        info.name.to_string(),
228        request_id.clone(),
229        header_map,
230        state.pool.clone(),
231        state.http_client.clone(),
232    )
233    .with_idempotency_key(idempotency_key.clone());
234    ctx.set_http_timeout(info.http_timeout);
235
236    if let Some(ref dispatcher) = state.job_dispatcher {
237        ctx = ctx.with_job_dispatch(dispatcher.clone());
238    }
239
240    // Execute handler with timeout
241    let result = tokio::time::timeout(info.timeout, (entry.handler)(&ctx, payload)).await;
242
243    match result {
244        Ok(Ok(webhook_result)) => {
245            let status =
246                StatusCode::from_u16(webhook_result.status_code()).unwrap_or(StatusCode::OK);
247            (status, Json(webhook_result.body()))
248        }
249        Ok(Err(e)) => {
250            if idempotency_claimed
251                && let Some(ref key) = idempotency_key
252                && let Err(release_err) = release_idempotency(&state.pool, info.name, key).await
253            {
254                warn!(
255                    webhook = info.name,
256                    error = %release_err,
257                    "Failed to release idempotency key after handler error"
258                );
259            }
260            error!(webhook = info.name, error = %e, "Webhook handler error");
261            (
262                StatusCode::INTERNAL_SERVER_ERROR,
263                Json(json!({"error": "Internal server error", "request_id": request_id})),
264            )
265        }
266        Err(_) => {
267            if idempotency_claimed
268                && let Some(ref key) = idempotency_key
269                && let Err(release_err) = release_idempotency(&state.pool, info.name, key).await
270            {
271                warn!(
272                    webhook = info.name,
273                    error = %release_err,
274                    "Failed to release idempotency key after timeout"
275                );
276            }
277            error!(
278                webhook = info.name,
279                timeout = ?info.timeout,
280                "Webhook handler timed out"
281            );
282            (
283                StatusCode::GATEWAY_TIMEOUT,
284                Json(json!({"error": "Request timeout"})),
285            )
286        }
287    }
288}
289
290/// Validate HMAC signature.
291fn validate_signature(
292    algorithm: SignatureAlgorithm,
293    body: &[u8],
294    secret: &str,
295    signature: &str,
296) -> bool {
297    // Strip algorithm prefix if present (e.g., "sha256=")
298    let sig_hex = signature
299        .strip_prefix(algorithm.prefix())
300        .unwrap_or(signature);
301
302    // Decode expected signature from hex
303    let expected = match decode_hex(sig_hex) {
304        Some(b) => b,
305        None => return false,
306    };
307
308    match algorithm {
309        SignatureAlgorithm::HmacSha256 => {
310            let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
311                .expect("HMAC can take key of any size");
312            mac.update(body);
313            mac.verify_slice(&expected).is_ok()
314        }
315        SignatureAlgorithm::HmacSha1 => {
316            let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes())
317                .expect("HMAC can take key of any size");
318            mac.update(body);
319            mac.verify_slice(&expected).is_ok()
320        }
321        SignatureAlgorithm::HmacSha512 => {
322            let mut mac = Hmac::<Sha512>::new_from_slice(secret.as_bytes())
323                .expect("HMAC can take key of any size");
324            mac.update(body);
325            mac.verify_slice(&expected).is_ok()
326        }
327    }
328}
329
/// Decode a hex string (upper- or lowercase digits) into bytes.
///
/// Returns `None` if the input has odd length or contains anything other than ASCII hex
/// digits. An empty string decodes to an empty vector. Decoding is strict: a per-pair
/// `u8::from_str_radix` would accept a leading `+` (so `"+f"` would decode to `0x0f`),
/// whereas here every character must be a hex digit.
fn decode_hex(s: &str) -> Option<Vec<u8>> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(digit: u8) -> Option<u8> {
        char::from(digit)
            .to_digit(16)
            .and_then(|value| u8::try_from(value).ok())
    }

    let pairs = s.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| match pair {
            &[high, low] => Some((nibble(high)? << 4) | nibble(low)?),
            _ => None,
        })
        .collect()
}
339
340/// Extract value from JSON using a simple path (e.g., "$.id" or "$.data.id").
341fn extract_json_path(value: &Value, path: &str) -> Option<String> {
342    let path = path.strip_prefix("$.").unwrap_or(path);
343    let parts: Vec<&str> = path.split('.').collect();
344
345    let mut current = value;
346    for part in parts {
347        current = current.get(part)?;
348    }
349
350    match current {
351        Value::String(s) => Some(s.clone()),
352        Value::Number(n) => Some(n.to_string()),
353        _ => Some(current.to_string()),
354    }
355}
356
357/// Atomically claim idempotency key before processing.
358///
359/// Returns:
360/// - `Ok(true)` if this request acquired the claim
361/// - `Ok(false)` if key is already active
362async fn claim_idempotency(
363    pool: &PgPool,
364    webhook_name: &str,
365    key: &str,
366    ttl: std::time::Duration,
367) -> Result<bool, sqlx::Error> {
368    let expires_at =
369        chrono::Utc::now() + chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::hours(24));
370
371    let result = sqlx::query!(
372        r#"
373        INSERT INTO forge_webhook_events (idempotency_key, webhook_name, processed_at, expires_at)
374        VALUES ($1, $2, NOW(), $3)
375        ON CONFLICT (webhook_name, idempotency_key) DO UPDATE
376            SET processed_at = EXCLUDED.processed_at,
377                expires_at = EXCLUDED.expires_at
378        WHERE forge_webhook_events.expires_at < NOW()
379        "#,
380        key,
381        webhook_name,
382        expires_at,
383    )
384    .execute(pool)
385    .await?;
386
387    Ok(result.rows_affected() > 0)
388}
389
390/// Release idempotency key after failure so retries can proceed.
391async fn release_idempotency(
392    pool: &PgPool,
393    webhook_name: &str,
394    key: &str,
395) -> Result<(), sqlx::Error> {
396    sqlx::query!(
397        r#"
398        DELETE FROM forge_webhook_events
399        WHERE webhook_name = $1 AND idempotency_key = $2
400        "#,
401        webhook_name,
402        key,
403    )
404    .execute(pool)
405    .await?;
406
407    Ok(())
408}
409
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Lowercase hex encoding used to build signature fixtures.
    fn encode_hex(bytes: &[u8]) -> String {
        use std::fmt::Write;

        let mut hex = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            let _ = write!(hex, "{byte:02x}");
        }
        hex
    }

    #[test]
    fn test_extract_json_path_simple() {
        let payload = json!({"id": "test-123"});
        let extracted = extract_json_path(&payload, "$.id");
        assert_eq!(extracted.as_deref(), Some("test-123"));
    }

    #[test]
    fn test_extract_json_path_nested() {
        let payload = json!({"data": {"id": "nested-456"}});
        let extracted = extract_json_path(&payload, "$.data.id");
        assert_eq!(extracted.as_deref(), Some("nested-456"));
    }

    #[test]
    fn test_extract_json_path_number() {
        let payload = json!({"count": 42});
        let extracted = extract_json_path(&payload, "$.count");
        assert_eq!(extracted.as_deref(), Some("42"));
    }

    #[test]
    fn test_extract_json_path_missing() {
        let payload = json!({"other": "value"});
        assert!(extract_json_path(&payload, "$.id").is_none());
    }

    #[test]
    fn test_validate_signature_sha256() {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let body = b"test payload";
        let secret = "test_secret";

        let digest = {
            let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
            mac.update(body);
            mac.finalize().into_bytes()
        };
        let bare = encode_hex(&digest);
        let prefixed = format!("sha256={}", bare);

        // Both the bare hex digest and the prefixed form must verify.
        for candidate in [&bare, &prefixed] {
            assert!(validate_signature(
                SignatureAlgorithm::HmacSha256,
                body,
                secret,
                candidate
            ));
        }
    }

    #[test]
    fn test_validate_signature_invalid() {
        let rejected = [
            "invalid_hex",
            "0000000000000000000000000000000000000000000000000000000000000000",
        ];
        for signature in rejected {
            assert!(!validate_signature(
                SignatureAlgorithm::HmacSha256,
                b"test",
                "secret",
                signature
            ));
        }
    }
}