1use std::collections::HashMap;
4use std::sync::Arc;
5
6use axum::{
7 Json,
8 body::Bytes,
9 extract::{Path, State},
10 http::{HeaderMap, StatusCode},
11 response::IntoResponse,
12};
13use forge_core::CircuitBreakerClient;
14use forge_core::function::JobDispatch;
15use forge_core::webhook::{IdempotencySource, SignatureAlgorithm, WebhookContext};
16use hmac::{Hmac, Mac};
17use serde_json::{Value, json};
18use sha1::Sha1;
19use sha2::{Sha256, Sha512};
20use sqlx::PgPool;
21use tracing::{error, info, warn};
22use uuid::Uuid;
23
24use super::registry::WebhookRegistry;
25
26#[derive(Clone)]
28pub struct WebhookState {
29 registry: Arc<WebhookRegistry>,
30 pool: PgPool,
31 http_client: CircuitBreakerClient,
32 job_dispatcher: Option<Arc<dyn JobDispatch>>,
33}
34
35impl WebhookState {
36 pub fn new(registry: Arc<WebhookRegistry>, pool: PgPool) -> Self {
38 Self {
39 registry,
40 pool,
41 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
42 job_dispatcher: None,
43 }
44 }
45
46 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
48 self.job_dispatcher = Some(dispatcher);
49 self
50 }
51}
52
53pub async fn webhook_handler(
62 State(state): State<Arc<WebhookState>>,
63 Path(path): Path<String>,
64 headers: HeaderMap,
65 body: Bytes,
66) -> impl IntoResponse {
67 let full_path = format!("/webhooks/{}", path);
68 let request_id = Uuid::new_v4().to_string();
69
70 let entry = match state.registry.get_by_path(&full_path) {
72 Some(e) => e,
73 None => {
74 warn!(path = %full_path, "Webhook not found");
75 return (
76 StatusCode::NOT_FOUND,
77 Json(json!({"error": "Webhook not found"})),
78 );
79 }
80 };
81
82 let info = &entry.info;
83 info!(
84 webhook = info.name,
85 path = %full_path,
86 request_id = %request_id,
87 "Webhook request received"
88 );
89
90 if let Some(ref sig_config) = info.signature {
92 let signature = match headers
94 .get(sig_config.header_name)
95 .and_then(|v| v.to_str().ok())
96 {
97 Some(s) => s,
98 None => {
99 warn!(webhook = info.name, "Missing signature header");
100 return (
101 StatusCode::UNAUTHORIZED,
102 Json(json!({"error": "Missing signature"})),
103 );
104 }
105 };
106
107 let secret = match std::env::var(sig_config.secret_env) {
109 Ok(s) => s,
110 Err(_) => {
111 error!(
112 webhook = info.name,
113 env = sig_config.secret_env,
114 "Webhook secret not configured"
115 );
116 return (
117 StatusCode::INTERNAL_SERVER_ERROR,
118 Json(json!({"error": "Webhook configuration error"})),
119 );
120 }
121 };
122
123 if !validate_signature(sig_config.algorithm, &body, &secret, signature) {
125 warn!(webhook = info.name, "Invalid signature");
126 return (
127 StatusCode::UNAUTHORIZED,
128 Json(json!({"error": "Invalid signature"})),
129 );
130 }
131 }
132
133 let idempotency_key = if let Some(ref idem_config) = info.idempotency {
135 match &idem_config.source {
136 IdempotencySource::Header(header_name) => headers
137 .get(*header_name)
138 .and_then(|v| v.to_str().ok())
139 .map(|s| s.to_string()),
140 IdempotencySource::Body(json_path) => {
141 if let Ok(payload) = serde_json::from_slice::<Value>(&body) {
143 extract_json_path(&payload, json_path)
144 } else {
145 None
146 }
147 }
148 }
149 } else {
150 None
151 };
152
153 if let Some(ref key) = idempotency_key {
155 match check_idempotency(&state.pool, info.name, key).await {
156 Ok(true) => {
157 info!(
158 webhook = info.name,
159 idempotency_key = %key,
160 "Request already processed (idempotent)"
161 );
162 return (StatusCode::OK, Json(json!({"status": "already_processed"})));
163 }
164 Ok(false) => {}
165 Err(e) => {
166 warn!(webhook = info.name, error = %e, "Failed to check idempotency");
167 }
168 }
169 }
170
171 let payload: Value = match serde_json::from_slice(&body) {
173 Ok(v) => v,
174 Err(e) => {
175 warn!(webhook = info.name, error = %e, "Invalid JSON payload");
176 return (
177 StatusCode::BAD_REQUEST,
178 Json(json!({"error": "Invalid JSON"})),
179 );
180 }
181 };
182
183 let header_map: HashMap<String, String> = headers
185 .iter()
186 .filter_map(|(k, v)| {
187 v.to_str()
188 .ok()
189 .map(|v| (k.as_str().to_lowercase(), v.to_string()))
190 })
191 .collect();
192
193 let mut ctx = WebhookContext::new(
195 info.name.to_string(),
196 request_id.clone(),
197 header_map,
198 state.pool.clone(),
199 state.http_client.inner().clone(),
200 )
201 .with_idempotency_key(idempotency_key.clone());
202
203 if let Some(ref dispatcher) = state.job_dispatcher {
204 ctx = ctx.with_job_dispatch(dispatcher.clone());
205 }
206
207 let result = tokio::time::timeout(info.timeout, (entry.handler)(&ctx, payload)).await;
209
210 match result {
211 Ok(Ok(webhook_result)) => {
212 if let Some(ref key) = idempotency_key {
214 if let Some(ref idem_config) = info.idempotency {
215 if let Err(e) =
216 record_idempotency(&state.pool, info.name, key, idem_config.ttl).await
217 {
218 warn!(webhook = info.name, error = %e, "Failed to record idempotency");
219 }
220 }
221 }
222
223 let status =
224 StatusCode::from_u16(webhook_result.status_code()).unwrap_or(StatusCode::OK);
225 (status, Json(webhook_result.body()))
226 }
227 Ok(Err(e)) => {
228 error!(webhook = info.name, error = %e, "Webhook handler error");
229 (
230 StatusCode::INTERNAL_SERVER_ERROR,
231 Json(json!({"error": e.to_string()})),
232 )
233 }
234 Err(_) => {
235 error!(
236 webhook = info.name,
237 timeout = ?info.timeout,
238 "Webhook handler timed out"
239 );
240 (
241 StatusCode::GATEWAY_TIMEOUT,
242 Json(json!({"error": "Request timeout"})),
243 )
244 }
245 }
246}
247
248fn validate_signature(
250 algorithm: SignatureAlgorithm,
251 body: &[u8],
252 secret: &str,
253 signature: &str,
254) -> bool {
255 let sig_hex = signature
257 .strip_prefix(algorithm.prefix())
258 .unwrap_or(signature);
259
260 let expected = match hex::decode(sig_hex) {
262 Ok(b) => b,
263 Err(_) => return false,
264 };
265
266 match algorithm {
267 SignatureAlgorithm::HmacSha256 => {
268 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
269 .expect("HMAC can take key of any size");
270 mac.update(body);
271 mac.verify_slice(&expected).is_ok()
272 }
273 SignatureAlgorithm::HmacSha1 => {
274 let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes())
275 .expect("HMAC can take key of any size");
276 mac.update(body);
277 mac.verify_slice(&expected).is_ok()
278 }
279 SignatureAlgorithm::HmacSha512 => {
280 let mut mac = Hmac::<Sha512>::new_from_slice(secret.as_bytes())
281 .expect("HMAC can take key of any size");
282 mac.update(body);
283 mac.verify_slice(&expected).is_ok()
284 }
285 }
286}
287
288fn extract_json_path(value: &Value, path: &str) -> Option<String> {
290 let path = path.strip_prefix("$.").unwrap_or(path);
291 let parts: Vec<&str> = path.split('.').collect();
292
293 let mut current = value;
294 for part in parts {
295 current = current.get(part)?;
296 }
297
298 match current {
299 Value::String(s) => Some(s.clone()),
300 Value::Number(n) => Some(n.to_string()),
301 _ => Some(current.to_string()),
302 }
303}
304
305async fn check_idempotency(
307 pool: &PgPool,
308 webhook_name: &str,
309 key: &str,
310) -> Result<bool, sqlx::Error> {
311 let result: Option<(i32,)> = sqlx::query_as(
312 r#"
313 SELECT 1 FROM forge_webhook_events
314 WHERE webhook_name = $1 AND idempotency_key = $2 AND expires_at > NOW()
315 "#,
316 )
317 .bind(webhook_name)
318 .bind(key)
319 .fetch_optional(pool)
320 .await?;
321
322 Ok(result.is_some())
323}
324
325async fn record_idempotency(
327 pool: &PgPool,
328 webhook_name: &str,
329 key: &str,
330 ttl: std::time::Duration,
331) -> Result<(), sqlx::Error> {
332 let expires_at =
333 chrono::Utc::now() + chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::hours(24));
334
335 sqlx::query(
336 r#"
337 INSERT INTO forge_webhook_events (idempotency_key, webhook_name, processed_at, expires_at)
338 VALUES ($1, $2, NOW(), $3)
339 ON CONFLICT (idempotency_key) DO NOTHING
340 "#,
341 )
342 .bind(key)
343 .bind(webhook_name)
344 .bind(expires_at)
345 .execute(pool)
346 .await?;
347
348 Ok(())
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex-encoded HMAC of `body` under `secret` for `algorithm`, as a sender would produce it.
    fn sign(algorithm: SignatureAlgorithm, body: &[u8], secret: &str) -> String {
        match algorithm {
            SignatureAlgorithm::HmacSha256 => {
                let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
                mac.update(body);
                hex::encode(mac.finalize().into_bytes())
            }
            SignatureAlgorithm::HmacSha1 => {
                let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes()).unwrap();
                mac.update(body);
                hex::encode(mac.finalize().into_bytes())
            }
            SignatureAlgorithm::HmacSha512 => {
                let mut mac = Hmac::<Sha512>::new_from_slice(secret.as_bytes()).unwrap();
                mac.update(body);
                hex::encode(mac.finalize().into_bytes())
            }
        }
    }

    const ALL_ALGORITHMS: [SignatureAlgorithm; 3] = [
        SignatureAlgorithm::HmacSha256,
        SignatureAlgorithm::HmacSha1,
        SignatureAlgorithm::HmacSha512,
    ];

    #[test]
    fn test_extract_json_path_simple() {
        let value = json!({"id": "test-123"});
        assert_eq!(
            extract_json_path(&value, "$.id"),
            Some("test-123".to_string())
        );
    }

    #[test]
    fn test_extract_json_path_nested() {
        let value = json!({"data": {"id": "nested-456"}});
        assert_eq!(
            extract_json_path(&value, "$.data.id"),
            Some("nested-456".to_string())
        );
    }

    #[test]
    fn test_extract_json_path_number() {
        let value = json!({"count": 42});
        assert_eq!(extract_json_path(&value, "$.count"), Some("42".to_string()));
    }

    #[test]
    fn test_extract_json_path_missing() {
        let value = json!({"other": "value"});
        assert_eq!(extract_json_path(&value, "$.id"), None);
    }

    #[test]
    fn test_validate_signature_sha256() {
        let body = b"test payload";
        let secret = "test_secret";

        let signature = sign(SignatureAlgorithm::HmacSha256, body, secret);

        assert!(validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &signature
        ));

        let sig_with_prefix = format!("sha256={}", signature);
        assert!(validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &sig_with_prefix
        ));
    }

    #[test]
    fn test_validate_signature_every_algorithm_round_trips() {
        let body = b"test payload";
        let secret = "test_secret";

        for algorithm in ALL_ALGORITHMS {
            let signature = sign(algorithm, body, secret);
            assert!(validate_signature(algorithm, body, secret, &signature));

            // Senders may prepend the algorithm's prefix (e.g. `sha256=`).
            let prefixed = format!("{}{}", algorithm.prefix(), signature);
            assert!(validate_signature(algorithm, body, secret, &prefixed));
        }
    }

    #[test]
    fn test_validate_signature_rejects_wrong_secret_or_tampered_body() {
        let body = b"test payload";
        let secret = "test_secret";

        for algorithm in ALL_ALGORITHMS {
            let signature = sign(algorithm, body, secret);
            assert!(!validate_signature(algorithm, body, "other_secret", &signature));
            assert!(!validate_signature(algorithm, b"test payloaD", secret, &signature));
        }
    }

    #[test]
    fn test_validate_signature_rejects_algorithm_mismatch() {
        let body = b"test payload";
        let secret = "test_secret";

        // A SHA-1 digest (20 bytes) must not satisfy the SHA-256 check, and vice versa.
        let sha1_sig = sign(SignatureAlgorithm::HmacSha1, body, secret);
        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            body,
            secret,
            &sha1_sig
        ));

        let sha256_sig = sign(SignatureAlgorithm::HmacSha256, body, secret);
        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha512,
            body,
            secret,
            &sha256_sig
        ));
    }

    #[test]
    fn test_validate_signature_invalid() {
        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            b"test",
            "secret",
            "invalid_hex"
        ));

        assert!(!validate_signature(
            SignatureAlgorithm::HmacSha256,
            b"test",
            "secret",
            "0000000000000000000000000000000000000000000000000000000000000000"
        ));
    }
}