Skip to main content

harn_vm/connectors/webhook/
mod.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::sync::{Arc, RwLock};
3
4use async_trait::async_trait;
5use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
6use base64::Engine;
7use serde::Deserialize;
8use serde_json::{json, Value as JsonValue};
9use sha2::{Digest, Sha256};
10use time::Duration;
11
12use crate::connectors::{
13    ActivationHandle, ClientError, Connector, ConnectorClient, ConnectorCtx, ConnectorError,
14    ProviderPayloadSchema, RawInbound, TriggerBinding, TriggerKind,
15};
16use crate::secrets::{SecretId, SecretVersion};
17use crate::triggers::{
18    redact_headers, HeaderRedactionPolicy, ProviderId, ProviderPayload, SignatureStatus, TraceId,
19    TriggerEvent, TriggerEventId,
20};
21
22pub mod variants;
23
24pub use variants::WebhookSignatureVariant;
25
26#[cfg(test)]
27mod tests;
28
/// Provider identifier used by the default generic webhook profile.
pub const WEBHOOK_PROVIDER_ID: &str = "webhook";
/// Event kind reported when neither the raw request, its headers, nor its body name one.
const DEFAULT_EVENT_KIND: &str = "webhook";
31
32#[derive(Clone, Debug)]
33pub(crate) struct WebhookProviderProfile {
34    provider_id: ProviderId,
35    payload_schema_name: String,
36    default_signature_variant: WebhookSignatureVariant,
37}
38
39impl WebhookProviderProfile {
40    pub(crate) fn webhook() -> Self {
41        Self::new(
42            ProviderId::from(WEBHOOK_PROVIDER_ID),
43            "GenericWebhookPayload",
44            WebhookSignatureVariant::Standard,
45        )
46    }
47
48    pub(crate) fn new(
49        provider_id: ProviderId,
50        payload_schema_name: impl Into<String>,
51        default_signature_variant: WebhookSignatureVariant,
52    ) -> Self {
53        Self {
54            provider_id,
55            payload_schema_name: payload_schema_name.into(),
56            default_signature_variant,
57        }
58    }
59}
60
61pub struct GenericWebhookConnector {
62    profile: WebhookProviderProfile,
63    kinds: Vec<TriggerKind>,
64    client: Arc<GenericWebhookClient>,
65    state: RwLock<ConnectorState>,
66}
67
68#[derive(Default)]
69struct ConnectorState {
70    ctx: Option<ConnectorCtx>,
71    bindings: HashMap<String, ActivatedWebhookBinding>,
72}
73
74#[derive(Clone, Debug)]
75struct ActivatedWebhookBinding {
76    #[allow(dead_code)]
77    binding_id: String,
78    path: Option<String>,
79    signing_secret: SecretId,
80    signature_variant: WebhookSignatureVariant,
81    timestamp_tolerance: Option<Duration>,
82    source: Option<String>,
83}
84
85#[derive(Default)]
86struct GenericWebhookClient;
87
88#[async_trait]
89impl ConnectorClient for GenericWebhookClient {
90    async fn call(&self, method: &str, _args: JsonValue) -> Result<JsonValue, ClientError> {
91        Err(ClientError::MethodNotFound(format!(
92            "generic webhook connector has no outbound method `{method}`"
93        )))
94    }
95}
96
97impl GenericWebhookConnector {
98    pub fn new() -> Self {
99        Self::with_profile(WebhookProviderProfile::webhook())
100    }
101
102    pub(crate) fn with_profile(profile: WebhookProviderProfile) -> Self {
103        Self {
104            profile,
105            kinds: vec![TriggerKind::from("webhook")],
106            client: Arc::new(GenericWebhookClient),
107            state: RwLock::new(ConnectorState::default()),
108        }
109    }
110
111    fn binding_for_raw(&self, raw: &RawInbound) -> Result<ActivatedWebhookBinding, ConnectorError> {
112        let state = self.state.read().expect("webhook connector state poisoned");
113        let binding = if let Some(binding_id) =
114            raw.metadata.get("binding_id").and_then(JsonValue::as_str)
115        {
116            state.bindings.get(binding_id).cloned().ok_or_else(|| {
117                ConnectorError::Unsupported(format!(
118                    "generic webhook connector has no active binding `{binding_id}`"
119                ))
120            })?
121        } else if state.bindings.len() == 1 {
122            state
123                .bindings
124                .values()
125                .next()
126                .cloned()
127                .expect("checked single binding")
128        } else {
129            return Err(ConnectorError::Unsupported(
130                "generic webhook connector requires raw.metadata.binding_id when multiple bindings are active".to_string(),
131            ));
132        };
133        Ok(binding)
134    }
135
136    fn ctx(&self) -> Result<ConnectorCtx, ConnectorError> {
137        self.state
138            .read()
139            .expect("webhook connector state poisoned")
140            .ctx
141            .clone()
142            .ok_or_else(|| {
143                ConnectorError::Activation(
144                    "generic webhook connector must be initialized before use".to_string(),
145                )
146            })
147    }
148}
149
150impl Default for GenericWebhookConnector {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156#[async_trait]
157impl Connector for GenericWebhookConnector {
158    fn provider_id(&self) -> &ProviderId {
159        &self.profile.provider_id
160    }
161
162    fn kinds(&self) -> &[TriggerKind] {
163        &self.kinds
164    }
165
166    async fn init(&mut self, ctx: ConnectorCtx) -> Result<(), ConnectorError> {
167        self.state
168            .write()
169            .expect("webhook connector state poisoned")
170            .ctx = Some(ctx);
171        Ok(())
172    }
173
174    async fn activate(
175        &self,
176        bindings: &[TriggerBinding],
177    ) -> Result<ActivationHandle, ConnectorError> {
178        let mut configured = HashMap::new();
179        let mut paths = BTreeSet::new();
180        for binding in bindings {
181            let activated = ActivatedWebhookBinding::from_binding(
182                binding,
183                self.profile.default_signature_variant,
184            )?;
185            if let Some(path) = &activated.path {
186                if !paths.insert(path.clone()) {
187                    return Err(ConnectorError::Activation(format!(
188                        "generic webhook connector path `{path}` is configured by multiple bindings"
189                    )));
190                }
191            }
192            configured.insert(binding.binding_id.clone(), activated);
193        }
194
195        self.state
196            .write()
197            .expect("webhook connector state poisoned")
198            .bindings = configured;
199        Ok(ActivationHandle::new(
200            self.provider_id().clone(),
201            bindings.len(),
202        ))
203    }
204
205    async fn normalize_inbound(&self, raw: RawInbound) -> Result<TriggerEvent, ConnectorError> {
206        let ctx = self.ctx()?;
207        let binding = self.binding_for_raw(&raw)?;
208        let provider = self.profile.provider_id.clone();
209        let received_at = raw.received_at;
210        let effective_headers = effective_headers(&raw.headers, binding.source.as_deref());
211        let secret = load_secret(&ctx, &binding.signing_secret)?;
212        binding.signature_variant.verify(
213            ctx.event_log.as_ref(),
214            &provider,
215            &raw.body,
216            &effective_headers,
217            secret.as_str(),
218            binding.timestamp_tolerance,
219            received_at,
220        )?;
221
222        let normalized_body = normalize_body(&raw.body, &effective_headers);
223        let kind = derive_kind(&raw, &effective_headers, &normalized_body);
224        let dedupe_key = derive_dedupe_key(
225            binding.signature_variant,
226            &effective_headers,
227            &normalized_body,
228            &raw.body,
229        );
230
231        let provider_payload = ProviderPayload::normalize(
232            &provider,
233            kind.as_str(),
234            &effective_headers,
235            normalized_body,
236        )
237        .map_err(|error| ConnectorError::Unsupported(error.to_string()))?;
238
239        Ok(TriggerEvent {
240            id: TriggerEventId::new(),
241            provider,
242            kind,
243            received_at,
244            occurred_at: raw
245                .occurred_at
246                .or_else(|| infer_occurred_at(&provider_payload)),
247            dedupe_key,
248            trace_id: TraceId::new(),
249            tenant_id: raw.tenant_id.clone(),
250            headers: redact_headers(&effective_headers, &HeaderRedactionPolicy::default()),
251            batch: None,
252            raw_body: Some(raw.body.clone()),
253            provider_payload,
254            signature_status: SignatureStatus::Verified,
255            dedupe_claimed: false,
256        })
257    }
258
259    fn payload_schema(&self) -> ProviderPayloadSchema {
260        ProviderPayloadSchema::named(self.profile.payload_schema_name.clone())
261    }
262
263    fn client(&self) -> Arc<dyn ConnectorClient> {
264        self.client.clone()
265    }
266}
267
268impl ActivatedWebhookBinding {
269    fn from_binding(
270        binding: &TriggerBinding,
271        default_signature_variant: WebhookSignatureVariant,
272    ) -> Result<Self, ConnectorError> {
273        let config: WebhookBindingConfig =
274            serde_json::from_value(binding.config.clone()).map_err(|error| {
275                ConnectorError::Activation(format!(
276                    "generic webhook binding `{}` has invalid config: {error}",
277                    binding.binding_id
278                ))
279            })?;
280        let signing_secret =
281            parse_secret_id(config.secrets.signing_secret.as_deref()).ok_or_else(|| {
282                ConnectorError::Activation(format!(
283                    "generic webhook binding `{}` requires secrets.signing_secret",
284                    binding.binding_id
285                ))
286            })?;
287        let signature_variant = match config.webhook.signature_scheme.as_deref() {
288            Some(raw) => WebhookSignatureVariant::parse(Some(raw))?,
289            None => default_signature_variant,
290        };
291        let timestamp_tolerance = match config.webhook.timestamp_tolerance_secs {
292            Some(seconds) if seconds < 0 => {
293                return Err(ConnectorError::Activation(format!(
294                    "generic webhook binding `{}` has a negative timestamp_tolerance_secs",
295                    binding.binding_id
296                )))
297            }
298            Some(seconds) => Some(Duration::seconds(seconds)),
299            None => signature_variant.default_timestamp_window(),
300        };
301
302        Ok(Self {
303            binding_id: binding.binding_id.clone(),
304            path: config.match_config.path,
305            signing_secret,
306            signature_variant,
307            timestamp_tolerance,
308            source: config.webhook.source,
309        })
310    }
311}
312
313#[derive(Default, Deserialize)]
314struct WebhookBindingConfig {
315    #[serde(default, rename = "match")]
316    match_config: WebhookMatchConfig,
317    #[serde(default)]
318    secrets: WebhookSecretsConfig,
319    #[serde(default)]
320    webhook: WebhookConnectorConfig,
321}
322
323#[derive(Default, Deserialize)]
324struct WebhookMatchConfig {
325    path: Option<String>,
326}
327
328#[derive(Default, Deserialize)]
329struct WebhookSecretsConfig {
330    signing_secret: Option<String>,
331}
332
333#[derive(Default, Deserialize)]
334struct WebhookConnectorConfig {
335    signature_scheme: Option<String>,
336    timestamp_tolerance_secs: Option<i64>,
337    source: Option<String>,
338}
339
340fn effective_headers(
341    headers: &BTreeMap<String, String>,
342    source: Option<&str>,
343) -> BTreeMap<String, String> {
344    let mut effective = headers.clone();
345    if let Some(content_type) = header_value(headers, "content-type") {
346        effective
347            .entry("Content-Type".to_string())
348            .or_insert_with(|| content_type.to_string());
349    }
350    if let Some(source) = source.or_else(|| header_value(headers, "x-webhook-source")) {
351        effective
352            .entry("X-Webhook-Source".to_string())
353            .or_insert_with(|| source.to_string());
354    }
355    if let Some(event) = header_value(headers, "x-github-event") {
356        effective
357            .entry("X-GitHub-Event".to_string())
358            .or_insert_with(|| event.to_string());
359    }
360    if let Some(delivery) = header_value(headers, "x-github-delivery") {
361        effective
362            .entry("X-GitHub-Delivery".to_string())
363            .or_insert_with(|| delivery.to_string());
364    }
365    effective
366}
367
368fn load_secret(ctx: &ConnectorCtx, secret_id: &SecretId) -> Result<String, ConnectorError> {
369    let secret = futures::executor::block_on(ctx.secrets.get(secret_id))?;
370    secret.with_exposed(|bytes| {
371        std::str::from_utf8(bytes)
372            .map(|value| value.to_string())
373            .map_err(|error| {
374                ConnectorError::Secret(format!(
375                    "generic webhook signing secret `{secret_id}` is not valid UTF-8: {error}"
376                ))
377            })
378    })
379}
380
381fn normalize_body(body: &[u8], headers: &BTreeMap<String, String>) -> JsonValue {
382    let content_type = header_value(headers, "content-type").unwrap_or_default();
383    if content_type.contains("json") {
384        if let Ok(value) = serde_json::from_slice(body) {
385            return value;
386        }
387    }
388    serde_json::from_slice(body).unwrap_or_else(|_| {
389        json!({
390            "raw_base64": BASE64_STANDARD.encode(body),
391            "raw_utf8": std::str::from_utf8(body).ok(),
392        })
393    })
394}
395
396fn derive_kind(raw: &RawInbound, headers: &BTreeMap<String, String>, body: &JsonValue) -> String {
397    if !raw.kind.trim().is_empty() {
398        return raw.kind.clone();
399    }
400    header_value(headers, "x-github-event")
401        .map(ToString::to_string)
402        .or_else(|| {
403            body.get("type")
404                .and_then(JsonValue::as_str)
405                .map(ToString::to_string)
406        })
407        .or_else(|| {
408            body.get("event")
409                .and_then(JsonValue::as_str)
410                .map(ToString::to_string)
411        })
412        .unwrap_or_else(|| DEFAULT_EVENT_KIND.to_string())
413}
414
415fn derive_dedupe_key(
416    variant: WebhookSignatureVariant,
417    headers: &BTreeMap<String, String>,
418    body: &JsonValue,
419    raw_body: &[u8],
420) -> String {
421    match variant {
422        WebhookSignatureVariant::Standard => header_value(headers, "webhook-id")
423            .map(ToString::to_string)
424            .unwrap_or_else(|| fallback_body_digest(raw_body)),
425        WebhookSignatureVariant::Stripe => body
426            .get("id")
427            .and_then(JsonValue::as_str)
428            .map(ToString::to_string)
429            .unwrap_or_else(|| fallback_body_digest(raw_body)),
430        WebhookSignatureVariant::GitHub => header_value(headers, "x-github-delivery")
431            .map(ToString::to_string)
432            .unwrap_or_else(|| fallback_body_digest(raw_body)),
433        WebhookSignatureVariant::Slack => body
434            .get("event_id")
435            .and_then(JsonValue::as_str)
436            .map(ToString::to_string)
437            .unwrap_or_else(|| fallback_body_digest(raw_body)),
438    }
439}
440
441fn infer_occurred_at(provider_payload: &ProviderPayload) -> Option<time::OffsetDateTime> {
442    let ProviderPayload::Known(payload) = provider_payload else {
443        return None;
444    };
445    let raw = match payload {
446        crate::triggers::event::KnownProviderPayload::Webhook(payload) => &payload.raw,
447        _ => return None,
448    };
449    raw.get("timestamp")
450        .and_then(JsonValue::as_str)
451        .and_then(|value| {
452            time::OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339).ok()
453        })
454}
455
/// Looks up a header by name, ignoring ASCII case.
///
/// When several keys differ only by case, the first one in the map's sort
/// order wins.
fn header_value<'a>(headers: &'a BTreeMap<String, String>, name: &str) -> Option<&'a str> {
    for (key, value) in headers {
        if key.eq_ignore_ascii_case(name) {
            return Some(value.as_str());
        }
    }
    None
}
462
463fn parse_secret_id(raw: Option<&str>) -> Option<SecretId> {
464    let trimmed = raw?.trim();
465    if trimmed.is_empty() {
466        return None;
467    }
468    let (base, version) = match trimmed.rsplit_once('@') {
469        Some((base, version_text)) => {
470            let version = version_text.parse::<u64>().ok()?;
471            (base, SecretVersion::Exact(version))
472        }
473        None => (trimmed, SecretVersion::Latest),
474    };
475    let (namespace, name) = base.split_once('/')?;
476    if namespace.is_empty() || name.is_empty() {
477        return None;
478    }
479    Some(SecretId::new(namespace, name).with_version(version))
480}
481
482fn fallback_body_digest(body: &[u8]) -> String {
483    let digest = Sha256::digest(body);
484    format!("sha256:{}", hex::encode(digest))
485}