Skip to main content

forge_runtime/webhook/
handler.rs

//! Axum handler for webhook requests: signature validation, idempotency checks, and timed dispatch to registered handlers.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use axum::{
7    Json,
8    body::Bytes,
9    extract::{Path, State},
10    http::{HeaderMap, StatusCode},
11    response::IntoResponse,
12};
13use forge_core::CircuitBreakerClient;
14use forge_core::function::JobDispatch;
15use forge_core::webhook::{IdempotencySource, SignatureAlgorithm, WebhookContext};
16use hmac::{Hmac, Mac};
17use serde_json::{Value, json};
18use sha1::Sha1;
19use sha2::{Sha256, Sha512};
20use sqlx::PgPool;
21use tracing::{error, info, warn};
22use uuid::Uuid;
23
24use super::registry::WebhookRegistry;
25
26/// State for webhook handler.
27#[derive(Clone)]
28pub struct WebhookState {
29    registry: Arc<WebhookRegistry>,
30    pool: PgPool,
31    http_client: CircuitBreakerClient,
32    job_dispatcher: Option<Arc<dyn JobDispatch>>,
33}
34
35impl WebhookState {
36    /// Create new webhook state.
37    pub fn new(registry: Arc<WebhookRegistry>, pool: PgPool) -> Self {
38        Self {
39            registry,
40            pool,
41            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
42            job_dispatcher: None,
43        }
44    }
45
46    /// Set job dispatcher.
47    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
48        self.job_dispatcher = Some(dispatcher);
49        self
50    }
51}
52
53/// Handle webhook requests.
54///
55/// This handler:
56/// 1. Looks up webhook by path
57/// 2. Validates signature if configured
58/// 3. Checks idempotency
59/// 4. Executes handler
60/// 5. Records idempotency key
61pub async fn webhook_handler(
62    State(state): State<Arc<WebhookState>>,
63    Path(path): Path<String>,
64    headers: HeaderMap,
65    body: Bytes,
66) -> impl IntoResponse {
67    let full_path = format!("/webhooks/{}", path);
68    let request_id = Uuid::new_v4().to_string();
69
70    // Look up webhook by path
71    let entry = match state.registry.get_by_path(&full_path) {
72        Some(e) => e,
73        None => {
74            warn!(path = %full_path, "Webhook not found");
75            return (
76                StatusCode::NOT_FOUND,
77                Json(json!({"error": "Webhook not found"})),
78            );
79        }
80    };
81
82    let info = &entry.info;
83    info!(
84        webhook = info.name,
85        path = %full_path,
86        request_id = %request_id,
87        "Webhook request received"
88    );
89
90    // Validate signature if configured
91    if let Some(ref sig_config) = info.signature {
92        // Get signature from header
93        let signature = match headers
94            .get(sig_config.header_name)
95            .and_then(|v| v.to_str().ok())
96        {
97            Some(s) => s,
98            None => {
99                warn!(webhook = info.name, "Missing signature header");
100                return (
101                    StatusCode::UNAUTHORIZED,
102                    Json(json!({"error": "Missing signature"})),
103                );
104            }
105        };
106
107        // Get secret from environment
108        let secret = match std::env::var(sig_config.secret_env) {
109            Ok(s) => s,
110            Err(_) => {
111                error!(
112                    webhook = info.name,
113                    env = sig_config.secret_env,
114                    "Webhook secret not configured"
115                );
116                return (
117                    StatusCode::INTERNAL_SERVER_ERROR,
118                    Json(json!({"error": "Webhook configuration error"})),
119                );
120            }
121        };
122
123        // Validate signature
124        if !validate_signature(sig_config.algorithm, &body, &secret, signature) {
125            warn!(webhook = info.name, "Invalid signature");
126            return (
127                StatusCode::UNAUTHORIZED,
128                Json(json!({"error": "Invalid signature"})),
129            );
130        }
131    }
132
133    // Extract idempotency key if configured
134    let idempotency_key = if let Some(ref idem_config) = info.idempotency {
135        match &idem_config.source {
136            IdempotencySource::Header(header_name) => headers
137                .get(*header_name)
138                .and_then(|v| v.to_str().ok())
139                .map(|s| s.to_string()),
140            IdempotencySource::Body(json_path) => {
141                // Parse body and extract value using JSON path
142                if let Ok(payload) = serde_json::from_slice::<Value>(&body) {
143                    extract_json_path(&payload, json_path)
144                } else {
145                    None
146                }
147            }
148        }
149    } else {
150        None
151    };
152
153    // Check idempotency
154    if let Some(ref key) = idempotency_key {
155        match check_idempotency(&state.pool, info.name, key).await {
156            Ok(true) => {
157                info!(
158                    webhook = info.name,
159                    idempotency_key = %key,
160                    "Request already processed (idempotent)"
161                );
162                return (StatusCode::OK, Json(json!({"status": "already_processed"})));
163            }
164            Ok(false) => {}
165            Err(e) => {
166                warn!(webhook = info.name, error = %e, "Failed to check idempotency");
167            }
168        }
169    }
170
171    // Parse payload
172    let payload: Value = match serde_json::from_slice(&body) {
173        Ok(v) => v,
174        Err(e) => {
175            warn!(webhook = info.name, error = %e, "Invalid JSON payload");
176            return (
177                StatusCode::BAD_REQUEST,
178                Json(json!({"error": "Invalid JSON"})),
179            );
180        }
181    };
182
183    // Build headers map (lowercase keys)
184    let header_map: HashMap<String, String> = headers
185        .iter()
186        .filter_map(|(k, v)| {
187            v.to_str()
188                .ok()
189                .map(|v| (k.as_str().to_lowercase(), v.to_string()))
190        })
191        .collect();
192
193    // Create context
194    let mut ctx = WebhookContext::new(
195        info.name.to_string(),
196        request_id.clone(),
197        header_map,
198        state.pool.clone(),
199        state.http_client.inner().clone(),
200    )
201    .with_idempotency_key(idempotency_key.clone());
202
203    if let Some(ref dispatcher) = state.job_dispatcher {
204        ctx = ctx.with_job_dispatch(dispatcher.clone());
205    }
206
207    // Execute handler with timeout
208    let result = tokio::time::timeout(info.timeout, (entry.handler)(&ctx, payload)).await;
209
210    match result {
211        Ok(Ok(webhook_result)) => {
212            // Record idempotency on success
213            if let Some(ref key) = idempotency_key {
214                if let Some(ref idem_config) = info.idempotency {
215                    if let Err(e) =
216                        record_idempotency(&state.pool, info.name, key, idem_config.ttl).await
217                    {
218                        warn!(webhook = info.name, error = %e, "Failed to record idempotency");
219                    }
220                }
221            }
222
223            let status =
224                StatusCode::from_u16(webhook_result.status_code()).unwrap_or(StatusCode::OK);
225            (status, Json(webhook_result.body()))
226        }
227        Ok(Err(e)) => {
228            error!(webhook = info.name, error = %e, "Webhook handler error");
229            (
230                StatusCode::INTERNAL_SERVER_ERROR,
231                Json(json!({"error": e.to_string()})),
232            )
233        }
234        Err(_) => {
235            error!(
236                webhook = info.name,
237                timeout = ?info.timeout,
238                "Webhook handler timed out"
239            );
240            (
241                StatusCode::GATEWAY_TIMEOUT,
242                Json(json!({"error": "Request timeout"})),
243            )
244        }
245    }
246}
247
248/// Validate HMAC signature.
249fn validate_signature(
250    algorithm: SignatureAlgorithm,
251    body: &[u8],
252    secret: &str,
253    signature: &str,
254) -> bool {
255    // Strip algorithm prefix if present (e.g., "sha256=")
256    let sig_hex = signature
257        .strip_prefix(algorithm.prefix())
258        .unwrap_or(signature);
259
260    // Decode expected signature from hex
261    let expected = match hex::decode(sig_hex) {
262        Ok(b) => b,
263        Err(_) => return false,
264    };
265
266    match algorithm {
267        SignatureAlgorithm::HmacSha256 => {
268            let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
269                .expect("HMAC can take key of any size");
270            mac.update(body);
271            mac.verify_slice(&expected).is_ok()
272        }
273        SignatureAlgorithm::HmacSha1 => {
274            let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes())
275                .expect("HMAC can take key of any size");
276            mac.update(body);
277            mac.verify_slice(&expected).is_ok()
278        }
279        SignatureAlgorithm::HmacSha512 => {
280            let mut mac = Hmac::<Sha512>::new_from_slice(secret.as_bytes())
281                .expect("HMAC can take key of any size");
282            mac.update(body);
283            mac.verify_slice(&expected).is_ok()
284        }
285    }
286}
287
288/// Extract value from JSON using a simple path (e.g., "$.id" or "$.data.id").
289fn extract_json_path(value: &Value, path: &str) -> Option<String> {
290    let path = path.strip_prefix("$.").unwrap_or(path);
291    let parts: Vec<&str> = path.split('.').collect();
292
293    let mut current = value;
294    for part in parts {
295        current = current.get(part)?;
296    }
297
298    match current {
299        Value::String(s) => Some(s.clone()),
300        Value::Number(n) => Some(n.to_string()),
301        _ => Some(current.to_string()),
302    }
303}
304
305/// Check if idempotency key was already processed.
306async fn check_idempotency(
307    pool: &PgPool,
308    webhook_name: &str,
309    key: &str,
310) -> Result<bool, sqlx::Error> {
311    let result: Option<(i32,)> = sqlx::query_as(
312        r#"
313        SELECT 1 FROM forge_webhook_events
314        WHERE webhook_name = $1 AND idempotency_key = $2 AND expires_at > NOW()
315        "#,
316    )
317    .bind(webhook_name)
318    .bind(key)
319    .fetch_optional(pool)
320    .await?;
321
322    Ok(result.is_some())
323}
324
325/// Record idempotency key after successful processing.
326async fn record_idempotency(
327    pool: &PgPool,
328    webhook_name: &str,
329    key: &str,
330    ttl: std::time::Duration,
331) -> Result<(), sqlx::Error> {
332    let expires_at =
333        chrono::Utc::now() + chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::hours(24));
334
335    sqlx::query(
336        r#"
337        INSERT INTO forge_webhook_events (idempotency_key, webhook_name, processed_at, expires_at)
338        VALUES ($1, $2, NOW(), $3)
339        ON CONFLICT (idempotency_key) DO NOTHING
340        "#,
341    )
342    .bind(key)
343    .bind(webhook_name)
344    .bind(expires_at)
345    .execute(pool)
346    .await?;
347
348    Ok(())
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex-encoded HMAC-SHA256 of `body` keyed with `secret`.
    fn sign_sha256(secret: &str, body: &[u8]) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(body);
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn test_extract_json_path_simple() {
        let value = json!({"id": "test-123"});
        assert_eq!(
            extract_json_path(&value, "$.id"),
            Some("test-123".to_string())
        );
    }

    #[test]
    fn test_extract_json_path_nested() {
        let value = json!({"data": {"id": "nested-456"}});
        assert_eq!(
            extract_json_path(&value, "$.data.id"),
            Some("nested-456".to_string())
        );
    }

    #[test]
    fn test_extract_json_path_number() {
        let value = json!({"count": 42});
        assert_eq!(extract_json_path(&value, "$.count"), Some("42".to_string()));
    }

    #[test]
    fn test_extract_json_path_missing() {
        let value = json!({"other": "value"});
        assert_eq!(extract_json_path(&value, "$.id"), None);
    }

    #[test]
    fn test_extract_json_path_without_root_prefix() {
        let value = json!({"id": "plain"});
        assert_eq!(extract_json_path(&value, "id"), Some("plain".to_string()));
    }

    #[test]
    fn test_extract_json_path_through_scalar() {
        let value = json!({"data": "scalar"});
        assert_eq!(extract_json_path(&value, "$.data.id"), None);
    }

    #[test]
    fn test_validate_signature_sha256() {
        let body = b"test payload";
        let secret = "test_secret";
        let signature = sign_sha256(secret, body);

        assert!(validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &signature
        ));

        // With prefix
        let sig_with_prefix = format!("sha256={}", signature);
        assert!(validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &sig_with_prefix
        ));
    }

    #[test]
    fn test_validate_signature_sha1() {
        let body = b"test payload";
        let secret = "test_secret";

        let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(body);
        let signature = hex::encode(mac.finalize().into_bytes());

        assert!(validate_signature(
            SignatureAlgorithm::HmacSha1,
            body,
            secret,
            &signature
        ));
    }

    #[test]
    fn test_validate_signature_sha512() {
        let body = b"test payload";
        let secret = "test_secret";

        let mut mac = Hmac::<Sha512>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(body);
        let signature = hex::encode(mac.finalize().into_bytes());

        assert!(validate_signature(
            SignatureAlgorithm::HmacSha512,
            body,
            secret,
            &signature
        ));
    }

    #[test]
    fn test_validate_signature_invalid() {
        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            b"test",
            "secret",
            "invalid_hex"
        ));

        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            b"test",
            "secret",
            "0000000000000000000000000000000000000000000000000000000000000000"
        ));
    }

    #[test]
    fn test_validate_signature_wrong_secret() {
        let body = b"test payload";
        let signature = sign_sha256("other_secret", body);

        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            "test_secret",
            &signature
        ));
    }

    #[test]
    fn test_validate_signature_tampered_body() {
        let secret = "test_secret";
        let signature = sign_sha256(secret, b"test payload");

        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            b"test payloae",
            secret,
            &signature
        ));
    }

    #[test]
    fn test_validate_signature_truncated() {
        let body = b"test payload";
        let secret = "test_secret";
        let signature = sign_sha256(secret, body);

        // A valid prefix of the tag must not be accepted as the full signature.
        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &signature[..32]
        ));
    }
}