// harn_vm/connectors/webhook/mod.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::sync::{Arc, RwLock};
3
4use async_trait::async_trait;
5use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
6use base64::Engine;
7use serde::Deserialize;
8use serde_json::{json, Value as JsonValue};
9use sha2::{Digest, Sha256};
10use time::Duration;
11
12use crate::connectors::{
13    ActivationHandle, ClientError, Connector, ConnectorClient, ConnectorCtx, ConnectorError,
14    ProviderPayloadSchema, RawInbound, TriggerBinding, TriggerKind,
15};
16use crate::secrets::{SecretId, SecretVersion};
17use crate::triggers::{
18    redact_headers, HeaderRedactionPolicy, ProviderId, ProviderPayload, SignatureStatus, TraceId,
19    TriggerEvent, TriggerEventId,
20};
21
22pub mod variants;
23
24pub use variants::WebhookSignatureVariant;
25
26#[cfg(test)]
27mod tests;
28
/// Provider identifier used by the default profile of the generic webhook connector.
pub const WEBHOOK_PROVIDER_ID: &str = "webhook";
/// Event kind reported when neither the inbound request, its headers, nor its payload name one.
const DEFAULT_EVENT_KIND: &str = "webhook";
31
/// Provider-specific settings that parameterize a `GenericWebhookConnector`.
#[derive(Clone, Debug)]
pub(crate) struct WebhookProviderProfile {
    /// Provider id returned by the connector's `provider_id()` and stamped on emitted events.
    provider_id: ProviderId,
    /// Name handed to `ProviderPayloadSchema::named` by the connector's `payload_schema()`.
    payload_schema_name: String,
    /// Signature variant used when a binding does not set `webhook.signature_scheme`.
    default_signature_variant: WebhookSignatureVariant,
}
38
39impl WebhookProviderProfile {
40    pub(crate) fn webhook() -> Self {
41        Self::new(
42            ProviderId::from(WEBHOOK_PROVIDER_ID),
43            "GenericWebhookPayload",
44            WebhookSignatureVariant::Standard,
45        )
46    }
47
48    pub(crate) fn new(
49        provider_id: ProviderId,
50        payload_schema_name: impl Into<String>,
51        default_signature_variant: WebhookSignatureVariant,
52    ) -> Self {
53        Self {
54            provider_id,
55            payload_schema_name: payload_schema_name.into(),
56            default_signature_variant,
57        }
58    }
59}
60
/// Connector that verifies signed inbound webhooks and normalizes them into `TriggerEvent`s.
pub struct GenericWebhookConnector {
    /// Provider-specific settings (provider id, schema name, default signature variant).
    profile: WebhookProviderProfile,
    /// Trigger kinds reported by `kinds()`; set to the single kind `"webhook"` at construction.
    kinds: Vec<TriggerKind>,
    /// Client handed out by `client()`; it rejects every outbound method call.
    client: Arc<GenericWebhookClient>,
    /// Connector context and active bindings, written by `init` and `activate`.
    state: RwLock<ConnectorState>,
}
67
/// Mutable connector state kept behind the connector's `RwLock`.
#[derive(Default)]
struct ConnectorState {
    /// Context supplied to `init`; `None` until the connector has been initialized.
    ctx: Option<ConnectorCtx>,
    /// Active bindings keyed by binding id; replaced wholesale on every `activate` call.
    bindings: HashMap<String, ActivatedWebhookBinding>,
}
73
/// Validated form of a `TriggerBinding`'s webhook configuration, produced by `from_binding`.
#[derive(Clone, Debug)]
struct ActivatedWebhookBinding {
    // NOTE(review): this field is never read in this file; presumably it is kept for the
    // derived `Debug` output — confirm before removing the `allow`.
    #[allow(dead_code)]
    binding_id: String,
    /// Optional `match.path` from the binding config; `activate` rejects duplicate paths.
    path: Option<String>,
    /// Identifier of the secret loaded to verify inbound signatures.
    signing_secret: SecretId,
    /// Signature scheme used to verify inbound requests for this binding.
    signature_variant: WebhookSignatureVariant,
    /// Timestamp tolerance passed to signature verification; falls back to the variant's
    /// default window when the config does not set one.
    timestamp_tolerance: Option<Duration>,
    /// Optional configured `webhook.source`, used to fill in the `X-Webhook-Source` header.
    source: Option<String>,
}
84
/// Outbound client of the generic webhook connector; it exposes no callable methods.
#[derive(Default)]
struct GenericWebhookClient;
87
88#[async_trait]
89impl ConnectorClient for GenericWebhookClient {
90    async fn call(&self, method: &str, _args: JsonValue) -> Result<JsonValue, ClientError> {
91        Err(ClientError::MethodNotFound(format!(
92            "generic webhook connector has no outbound method `{method}`"
93        )))
94    }
95}
96
97impl GenericWebhookConnector {
98    pub fn new() -> Self {
99        Self::with_profile(WebhookProviderProfile::webhook())
100    }
101
102    pub(crate) fn with_profile(profile: WebhookProviderProfile) -> Self {
103        Self {
104            profile,
105            kinds: vec![TriggerKind::from("webhook")],
106            client: Arc::new(GenericWebhookClient),
107            state: RwLock::new(ConnectorState::default()),
108        }
109    }
110
111    fn binding_for_raw(&self, raw: &RawInbound) -> Result<ActivatedWebhookBinding, ConnectorError> {
112        let state = self.state.read().expect("webhook connector state poisoned");
113        let binding = if let Some(binding_id) =
114            raw.metadata.get("binding_id").and_then(JsonValue::as_str)
115        {
116            state.bindings.get(binding_id).cloned().ok_or_else(|| {
117                ConnectorError::Unsupported(format!(
118                    "generic webhook connector has no active binding `{binding_id}`"
119                ))
120            })?
121        } else if state.bindings.len() == 1 {
122            state
123                .bindings
124                .values()
125                .next()
126                .cloned()
127                .expect("checked single binding")
128        } else {
129            return Err(ConnectorError::Unsupported(
130                "generic webhook connector requires raw.metadata.binding_id when multiple bindings are active".to_string(),
131            ));
132        };
133        Ok(binding)
134    }
135
136    fn ctx(&self) -> Result<ConnectorCtx, ConnectorError> {
137        self.state
138            .read()
139            .expect("webhook connector state poisoned")
140            .ctx
141            .clone()
142            .ok_or_else(|| {
143                ConnectorError::Activation(
144                    "generic webhook connector must be initialized before use".to_string(),
145                )
146            })
147    }
148}
149
150impl Default for GenericWebhookConnector {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156#[async_trait]
157impl Connector for GenericWebhookConnector {
158    fn provider_id(&self) -> &ProviderId {
159        &self.profile.provider_id
160    }
161
162    fn kinds(&self) -> &[TriggerKind] {
163        &self.kinds
164    }
165
166    async fn init(&mut self, ctx: ConnectorCtx) -> Result<(), ConnectorError> {
167        self.state
168            .write()
169            .expect("webhook connector state poisoned")
170            .ctx = Some(ctx);
171        Ok(())
172    }
173
174    async fn activate(
175        &self,
176        bindings: &[TriggerBinding],
177    ) -> Result<ActivationHandle, ConnectorError> {
178        let mut configured = HashMap::new();
179        let mut paths = BTreeSet::new();
180        for binding in bindings {
181            let activated = ActivatedWebhookBinding::from_binding(
182                binding,
183                self.profile.default_signature_variant,
184            )?;
185            if let Some(path) = &activated.path {
186                if !paths.insert(path.clone()) {
187                    return Err(ConnectorError::Activation(format!(
188                        "generic webhook connector path `{path}` is configured by multiple bindings"
189                    )));
190                }
191            }
192            configured.insert(binding.binding_id.clone(), activated);
193        }
194
195        self.state
196            .write()
197            .expect("webhook connector state poisoned")
198            .bindings = configured;
199        Ok(ActivationHandle::new(
200            self.provider_id().clone(),
201            bindings.len(),
202        ))
203    }
204
205    fn normalize_inbound(&self, raw: RawInbound) -> Result<TriggerEvent, ConnectorError> {
206        let ctx = self.ctx()?;
207        let binding = self.binding_for_raw(&raw)?;
208        let provider = self.profile.provider_id.clone();
209        let received_at = raw.received_at;
210        let effective_headers = effective_headers(&raw.headers, binding.source.as_deref());
211        let secret = load_secret(&ctx, &binding.signing_secret)?;
212        binding.signature_variant.verify(
213            ctx.event_log.as_ref(),
214            &provider,
215            &raw.body,
216            &effective_headers,
217            secret.as_str(),
218            binding.timestamp_tolerance,
219            received_at,
220        )?;
221
222        let normalized_body = normalize_body(&raw.body, &effective_headers);
223        let kind = derive_kind(&raw, &effective_headers, &normalized_body);
224        let dedupe_key = derive_dedupe_key(
225            binding.signature_variant,
226            &effective_headers,
227            &normalized_body,
228            &raw.body,
229        );
230
231        let provider_payload = ProviderPayload::normalize(
232            &provider,
233            kind.as_str(),
234            &effective_headers,
235            normalized_body,
236        )
237        .map_err(|error| ConnectorError::Unsupported(error.to_string()))?;
238
239        Ok(TriggerEvent {
240            id: TriggerEventId::new(),
241            provider,
242            kind,
243            received_at,
244            occurred_at: raw
245                .occurred_at
246                .or_else(|| infer_occurred_at(&provider_payload)),
247            dedupe_key,
248            trace_id: TraceId::new(),
249            tenant_id: raw.tenant_id.clone(),
250            headers: redact_headers(&effective_headers, &HeaderRedactionPolicy::default()),
251            batch: None,
252            provider_payload,
253            signature_status: SignatureStatus::Verified,
254            dedupe_claimed: false,
255        })
256    }
257
258    fn payload_schema(&self) -> ProviderPayloadSchema {
259        ProviderPayloadSchema::named(self.profile.payload_schema_name.clone())
260    }
261
262    fn client(&self) -> Arc<dyn ConnectorClient> {
263        self.client.clone()
264    }
265}
266
267impl ActivatedWebhookBinding {
268    fn from_binding(
269        binding: &TriggerBinding,
270        default_signature_variant: WebhookSignatureVariant,
271    ) -> Result<Self, ConnectorError> {
272        let config: WebhookBindingConfig =
273            serde_json::from_value(binding.config.clone()).map_err(|error| {
274                ConnectorError::Activation(format!(
275                    "generic webhook binding `{}` has invalid config: {error}",
276                    binding.binding_id
277                ))
278            })?;
279        let signing_secret =
280            parse_secret_id(config.secrets.signing_secret.as_deref()).ok_or_else(|| {
281                ConnectorError::Activation(format!(
282                    "generic webhook binding `{}` requires secrets.signing_secret",
283                    binding.binding_id
284                ))
285            })?;
286        let signature_variant = match config.webhook.signature_scheme.as_deref() {
287            Some(raw) => WebhookSignatureVariant::parse(Some(raw))?,
288            None => default_signature_variant,
289        };
290        let timestamp_tolerance = match config.webhook.timestamp_tolerance_secs {
291            Some(seconds) if seconds < 0 => {
292                return Err(ConnectorError::Activation(format!(
293                    "generic webhook binding `{}` has a negative timestamp_tolerance_secs",
294                    binding.binding_id
295                )))
296            }
297            Some(seconds) => Some(Duration::seconds(seconds)),
298            None => signature_variant.default_timestamp_window(),
299        };
300
301        Ok(Self {
302            binding_id: binding.binding_id.clone(),
303            path: config.match_config.path,
304            signing_secret,
305            signature_variant,
306            timestamp_tolerance,
307            source: config.webhook.source,
308        })
309    }
310}
311
/// Shape of a binding's JSON config as deserialized by `ActivatedWebhookBinding::from_binding`.
///
/// Every section is optional and defaults to an empty section when absent.
#[derive(Default, Deserialize)]
struct WebhookBindingConfig {
    /// The `match` section (renamed because `match` is a Rust keyword).
    #[serde(default, rename = "match")]
    match_config: WebhookMatchConfig,
    /// The `secrets` section.
    #[serde(default)]
    secrets: WebhookSecretsConfig,
    /// The `webhook` section.
    #[serde(default)]
    webhook: WebhookConnectorConfig,
}
321
/// The `match` section of a webhook binding config.
#[derive(Default, Deserialize)]
struct WebhookMatchConfig {
    /// Optional path for the binding; `activate` rejects two bindings with the same path.
    path: Option<String>,
}
326
/// The `secrets` section of a webhook binding config.
#[derive(Default, Deserialize)]
struct WebhookSecretsConfig {
    /// Secret reference in the form accepted by `parse_secret_id`
    /// (`namespace/name`, optionally followed by `@<version>`); required for activation.
    signing_secret: Option<String>,
}
331
/// The `webhook` section of a webhook binding config.
#[derive(Default, Deserialize)]
struct WebhookConnectorConfig {
    /// Signature scheme name parsed by `WebhookSignatureVariant::parse`; the profile's default
    /// variant is used when absent.
    signature_scheme: Option<String>,
    /// Timestamp tolerance in seconds; negative values are rejected during activation.
    timestamp_tolerance_secs: Option<i64>,
    /// Source label used to fill in the `X-Webhook-Source` header.
    source: Option<String>,
}
338
339fn effective_headers(
340    headers: &BTreeMap<String, String>,
341    source: Option<&str>,
342) -> BTreeMap<String, String> {
343    let mut effective = headers.clone();
344    if let Some(content_type) = header_value(headers, "content-type") {
345        effective
346            .entry("Content-Type".to_string())
347            .or_insert_with(|| content_type.to_string());
348    }
349    if let Some(source) = source.or_else(|| header_value(headers, "x-webhook-source")) {
350        effective
351            .entry("X-Webhook-Source".to_string())
352            .or_insert_with(|| source.to_string());
353    }
354    if let Some(event) = header_value(headers, "x-github-event") {
355        effective
356            .entry("X-GitHub-Event".to_string())
357            .or_insert_with(|| event.to_string());
358    }
359    if let Some(delivery) = header_value(headers, "x-github-delivery") {
360        effective
361            .entry("X-GitHub-Delivery".to_string())
362            .or_insert_with(|| delivery.to_string());
363    }
364    effective
365}
366
367fn load_secret(ctx: &ConnectorCtx, secret_id: &SecretId) -> Result<String, ConnectorError> {
368    let secret = futures::executor::block_on(ctx.secrets.get(secret_id))?;
369    secret.with_exposed(|bytes| {
370        std::str::from_utf8(bytes)
371            .map(|value| value.to_string())
372            .map_err(|error| {
373                ConnectorError::Secret(format!(
374                    "generic webhook signing secret `{secret_id}` is not valid UTF-8: {error}"
375                ))
376            })
377    })
378}
379
380fn normalize_body(body: &[u8], headers: &BTreeMap<String, String>) -> JsonValue {
381    let content_type = header_value(headers, "content-type").unwrap_or_default();
382    if content_type.contains("json") {
383        if let Ok(value) = serde_json::from_slice(body) {
384            return value;
385        }
386    }
387    serde_json::from_slice(body).unwrap_or_else(|_| {
388        json!({
389            "raw_base64": BASE64_STANDARD.encode(body),
390            "raw_utf8": std::str::from_utf8(body).ok(),
391        })
392    })
393}
394
395fn derive_kind(raw: &RawInbound, headers: &BTreeMap<String, String>, body: &JsonValue) -> String {
396    if !raw.kind.trim().is_empty() {
397        return raw.kind.clone();
398    }
399    header_value(headers, "x-github-event")
400        .map(ToString::to_string)
401        .or_else(|| {
402            body.get("type")
403                .and_then(JsonValue::as_str)
404                .map(ToString::to_string)
405        })
406        .or_else(|| {
407            body.get("event")
408                .and_then(JsonValue::as_str)
409                .map(ToString::to_string)
410        })
411        .unwrap_or_else(|| DEFAULT_EVENT_KIND.to_string())
412}
413
414fn derive_dedupe_key(
415    variant: WebhookSignatureVariant,
416    headers: &BTreeMap<String, String>,
417    body: &JsonValue,
418    raw_body: &[u8],
419) -> String {
420    match variant {
421        WebhookSignatureVariant::Standard => header_value(headers, "webhook-id")
422            .map(ToString::to_string)
423            .unwrap_or_else(|| fallback_body_digest(raw_body)),
424        WebhookSignatureVariant::Stripe => body
425            .get("id")
426            .and_then(JsonValue::as_str)
427            .map(ToString::to_string)
428            .unwrap_or_else(|| fallback_body_digest(raw_body)),
429        WebhookSignatureVariant::GitHub => header_value(headers, "x-github-delivery")
430            .map(ToString::to_string)
431            .unwrap_or_else(|| fallback_body_digest(raw_body)),
432        WebhookSignatureVariant::Slack => body
433            .get("event_id")
434            .and_then(JsonValue::as_str)
435            .map(ToString::to_string)
436            .unwrap_or_else(|| fallback_body_digest(raw_body)),
437    }
438}
439
440fn infer_occurred_at(provider_payload: &ProviderPayload) -> Option<time::OffsetDateTime> {
441    let ProviderPayload::Known(payload) = provider_payload else {
442        return None;
443    };
444    let raw = match payload {
445        crate::triggers::event::KnownProviderPayload::Webhook(payload) => &payload.raw,
446        _ => return None,
447    };
448    raw.get("timestamp")
449        .and_then(JsonValue::as_str)
450        .and_then(|value| {
451            time::OffsetDateTime::parse(value, &time::format_description::well_known::Rfc3339).ok()
452        })
453}
454
/// Looks up a header by name, ignoring ASCII case.
///
/// When several keys differ only by case, the one that comes first in the map's key order is
/// returned. Returns `None` when no key matches.
fn header_value<'a>(headers: &'a BTreeMap<String, String>, name: &str) -> Option<&'a str> {
    for (key, value) in headers {
        if key.eq_ignore_ascii_case(name) {
            return Some(value.as_str());
        }
    }
    None
}
461
462fn parse_secret_id(raw: Option<&str>) -> Option<SecretId> {
463    let trimmed = raw?.trim();
464    if trimmed.is_empty() {
465        return None;
466    }
467    let (base, version) = match trimmed.rsplit_once('@') {
468        Some((base, version_text)) => {
469            let version = version_text.parse::<u64>().ok()?;
470            (base, SecretVersion::Exact(version))
471        }
472        None => (trimmed, SecretVersion::Latest),
473    };
474    let (namespace, name) = base.split_once('/')?;
475    if namespace.is_empty() || name.is_empty() {
476        return None;
477    }
478    Some(SecretId::new(namespace, name).with_version(version))
479}
480
481fn fallback_body_digest(body: &[u8]) -> String {
482    let digest = Sha256::digest(body);
483    format!("sha256:{}", hex::encode(digest))
484}