Skip to main content

construct/channels/
webhook.rs

1use super::traits::{Channel, ChannelMessage, SendMessage};
2use anyhow::{Result, bail};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
/// Generic webhook channel: the "universal adapter" for any system that speaks
/// webhooks. Inbound messages arrive as HTTP POST requests on a local listener;
/// replies are delivered to a configurable outbound URL.
pub struct WebhookChannel {
    // --- inbound side -------------------------------------------------------
    /// TCP port the inbound HTTP listener binds.
    listen_port: u16,
    /// Route path for inbound POSTs (normalized by `new` to start with `/`).
    listen_path: String,
    /// Shared secret for HMAC-SHA256 verification of inbound requests;
    /// `None` disables verification.
    secret: Option<String>,

    // --- outbound side ------------------------------------------------------
    /// Destination for replies; `None` turns sending into a no-op.
    send_url: Option<String>,
    /// HTTP method for replies (normalized by `new` to upper case).
    send_method: String,
    /// Value sent verbatim as the `Authorization` header on outbound requests.
    auth_header: Option<String>,
}
17
/// Incoming webhook payload format.
///
/// Expected JSON body of an inbound request, e.g.
/// `{"sender": "bob", "content": "hi", "thread_id": "t1"}`.
/// `sender` and `content` are required; fields not listed here are ignored.
#[derive(Debug, Deserialize)]
struct IncomingWebhook {
    /// Identifier of whoever sent the message; also used as the reply target
    /// when no `thread_id` is supplied.
    sender: String,
    /// Message text; the listener rejects an empty string with `400 Bad Request`.
    content: String,
    /// Optional thread/conversation identifier; `None` when the field is absent.
    #[serde(default)]
    thread_id: Option<String>,
}
26
/// Outgoing webhook payload format.
///
/// JSON body sent to the configured `send_url` for each reply. `None` fields
/// are omitted from the output entirely rather than serialized as `null`.
#[derive(Debug, Serialize)]
struct OutgoingWebhook {
    /// Reply text.
    content: String,
    /// Thread the reply belongs to (taken from the outbound message's `thread_ts`).
    #[serde(skip_serializing_if = "Option::is_none")]
    thread_id: Option<String>,
    /// Addressee of the reply; left `None` by `send` when the recipient is empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    recipient: Option<String>,
}
36
37impl WebhookChannel {
38    pub fn new(
39        listen_port: u16,
40        listen_path: Option<String>,
41        send_url: Option<String>,
42        send_method: Option<String>,
43        auth_header: Option<String>,
44        secret: Option<String>,
45    ) -> Self {
46        let path = listen_path.unwrap_or_else(|| "/webhook".to_string());
47        // Ensure path starts with /
48        let listen_path = if path.starts_with('/') {
49            path
50        } else {
51            format!("/{path}")
52        };
53
54        Self {
55            listen_port,
56            listen_path,
57            send_url,
58            send_method: send_method
59                .unwrap_or_else(|| "POST".to_string())
60                .to_uppercase(),
61            auth_header,
62            secret,
63        }
64    }
65
66    fn http_client(&self) -> reqwest::Client {
67        crate::config::build_runtime_proxy_client("channel.webhook")
68    }
69
70    /// Verify an incoming request's signature if a secret is configured.
71    fn verify_signature(&self, body: &[u8], signature: Option<&str>) -> bool {
72        let Some(ref secret) = self.secret else {
73            return true; // No secret configured, accept all
74        };
75
76        let Some(sig) = signature else {
77            return false; // Secret is set but no signature header provided
78        };
79
80        // HMAC-SHA256 verification
81        use hmac::{Hmac, Mac};
82        use sha2::Sha256;
83
84        type HmacSha256 = Hmac<Sha256>;
85
86        let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
87            return false;
88        };
89        mac.update(body);
90
91        // Signature should be hex-encoded
92        let Ok(expected) = hex::decode(sig.trim_start_matches("sha256=")) else {
93            return false;
94        };
95
96        mac.verify_slice(&expected).is_ok()
97    }
98}
99
100#[async_trait]
101impl Channel for WebhookChannel {
102    fn name(&self) -> &str {
103        "webhook"
104    }
105
106    async fn send(&self, message: &SendMessage) -> Result<()> {
107        let Some(ref send_url) = self.send_url else {
108            tracing::debug!("Webhook channel: no send_url configured, skipping outbound message");
109            return Ok(());
110        };
111
112        let client = self.http_client();
113        let payload = OutgoingWebhook {
114            content: message.content.clone(),
115            thread_id: message.thread_ts.clone(),
116            recipient: if message.recipient.is_empty() {
117                None
118            } else {
119                Some(message.recipient.clone())
120            },
121        };
122
123        let mut request = match self.send_method.as_str() {
124            "PUT" => client.put(send_url),
125            _ => client.post(send_url),
126        };
127
128        if let Some(ref auth) = self.auth_header {
129            request = request.header("Authorization", auth);
130        }
131
132        let resp = request.json(&payload).send().await?;
133
134        let status = resp.status();
135        if !status.is_success() {
136            let body = resp
137                .text()
138                .await
139                .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
140            bail!("Webhook send failed ({status}): {body}");
141        }
142
143        Ok(())
144    }
145
146    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
147        use axum::{
148            Router,
149            body::Bytes,
150            extract::State,
151            http::{HeaderMap, StatusCode},
152            routing::post,
153        };
154        use portable_atomic::{AtomicU64, Ordering};
155        use std::sync::Arc;
156
157        let counter = Arc::new(AtomicU64::new(0));
158
159        struct WebhookState {
160            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
161            secret: Option<String>,
162            counter: Arc<AtomicU64>,
163        }
164
165        let state = Arc::new(WebhookState {
166            tx: tx.clone(),
167            secret: self.secret.clone(),
168            counter: counter.clone(),
169        });
170
171        let listen_path = self.listen_path.clone();
172
173        async fn handle_webhook(
174            State(state): State<Arc<WebhookState>>,
175            headers: HeaderMap,
176            body: Bytes,
177        ) -> StatusCode {
178            // Verify signature if secret is configured
179            if let Some(ref secret) = state.secret {
180                use hmac::{Hmac, Mac};
181                use sha2::Sha256;
182                type HmacSha256 = Hmac<Sha256>;
183
184                let signature = headers
185                    .get("x-webhook-signature")
186                    .and_then(|v| v.to_str().ok());
187
188                let valid = if let Some(sig) = signature {
189                    if let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) {
190                        mac.update(&body);
191                        let expected =
192                            hex::decode(sig.trim_start_matches("sha256=")).unwrap_or_default();
193                        mac.verify_slice(&expected).is_ok()
194                    } else {
195                        false
196                    }
197                } else {
198                    false
199                };
200
201                if !valid {
202                    tracing::warn!("Webhook: invalid signature, rejecting request");
203                    return StatusCode::UNAUTHORIZED;
204                }
205            }
206
207            let payload: IncomingWebhook = match serde_json::from_slice(&body) {
208                Ok(p) => p,
209                Err(e) => {
210                    tracing::warn!("Webhook: invalid JSON payload: {e}");
211                    return StatusCode::BAD_REQUEST;
212                }
213            };
214
215            if payload.content.is_empty() {
216                return StatusCode::BAD_REQUEST;
217            }
218
219            let seq = state.counter.fetch_add(1, Ordering::Relaxed);
220
221            #[allow(clippy::cast_possible_truncation)]
222            let timestamp = std::time::SystemTime::now()
223                .duration_since(std::time::UNIX_EPOCH)
224                .unwrap_or_default()
225                .as_secs();
226
227            let reply_target = payload
228                .thread_id
229                .clone()
230                .unwrap_or_else(|| payload.sender.clone());
231
232            let msg = ChannelMessage {
233                id: format!("webhook_{seq}"),
234                sender: payload.sender,
235                reply_target,
236                content: payload.content,
237                channel: "webhook".to_string(),
238                timestamp,
239                thread_ts: payload.thread_id,
240                interruption_scope_id: None,
241                attachments: vec![],
242            };
243
244            if state.tx.send(msg).await.is_err() {
245                return StatusCode::SERVICE_UNAVAILABLE;
246            }
247
248            StatusCode::OK
249        }
250
251        let app = Router::new()
252            .route(&listen_path, post(handle_webhook))
253            .with_state(state);
254
255        let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.listen_port));
256        tracing::info!(
257            "Webhook channel listening on http://0.0.0.0:{}{} ...",
258            self.listen_port,
259            self.listen_path
260        );
261
262        let listener = tokio::net::TcpListener::bind(addr).await?;
263        axum::serve(listener, app)
264            .await
265            .map_err(|e| anyhow::anyhow!("Webhook server error: {e}"))?;
266
267        Ok(())
268    }
269
270    async fn health_check(&self) -> bool {
271        // Webhook channel is healthy if the port can be bound (basic check).
272        // In practice, once listen() starts the server is running.
273        true
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    fn make_channel() -> WebhookChannel {
        WebhookChannel::new(
            8080,
            Some("/webhook".into()),
            Some("https://example.com/callback".into()),
            None,
            None,
            None,
        )
    }

    fn make_channel_with_secret() -> WebhookChannel {
        WebhookChannel::new(
            8080,
            None,
            Some("https://example.com/callback".into()),
            None,
            None,
            Some("mysecret".into()),
        )
    }

    /// Hex-encoded HMAC-SHA256 of `body` keyed by `secret`, computed the way a
    /// webhook sender would.
    fn sign(secret: &[u8], body: &[u8]) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let mut mac =
            Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts keys of any length");
        mac.update(body);
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn default_path() {
        let ch = WebhookChannel::new(8080, None, None, None, None, None);
        assert_eq!(ch.listen_path, "/webhook");
    }

    #[test]
    fn path_normalized() {
        let ch = WebhookChannel::new(8080, Some("hooks/incoming".into()), None, None, None, None);
        assert_eq!(ch.listen_path, "/hooks/incoming");
    }

    #[test]
    fn empty_path_becomes_root() {
        let ch = WebhookChannel::new(8080, Some(String::new()), None, None, None, None);
        assert_eq!(ch.listen_path, "/");
    }

    #[test]
    fn send_method_default() {
        let ch = make_channel();
        assert_eq!(ch.send_method, "POST");
    }

    #[test]
    fn send_method_put() {
        let ch = WebhookChannel::new(
            8080,
            None,
            Some("https://example.com".into()),
            Some("put".into()),
            None,
            None,
        );
        assert_eq!(ch.send_method, "PUT");
    }

    #[test]
    fn incoming_payload_deserializes_all_fields() {
        let json = r#"{"sender": "construct_user", "content": "hello", "thread_id": "t1"}"#;
        let payload: IncomingWebhook = serde_json::from_str(json).unwrap();
        assert_eq!(payload.sender, "construct_user");
        assert_eq!(payload.content, "hello");
        assert_eq!(payload.thread_id.as_deref(), Some("t1"));
    }

    #[test]
    fn incoming_payload_without_thread() {
        let json = r#"{"sender": "bob", "content": "hi"}"#;
        let payload: IncomingWebhook = serde_json::from_str(json).unwrap();
        assert_eq!(payload.sender, "bob");
        assert_eq!(payload.content, "hi");
        assert!(payload.thread_id.is_none());
    }

    #[test]
    fn incoming_payload_requires_sender_and_content() {
        assert!(serde_json::from_str::<IncomingWebhook>(r#"{"content": "hi"}"#).is_err());
        assert!(serde_json::from_str::<IncomingWebhook>(r#"{"sender": "bob"}"#).is_err());
    }

    #[test]
    fn outgoing_payload_serializes_content() {
        let payload = OutgoingWebhook {
            content: "response".into(),
            thread_id: Some("t1".into()),
            recipient: Some("construct_user".into()),
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert_eq!(json["content"], "response");
        assert_eq!(json["thread_id"], "t1");
        assert_eq!(json["recipient"], "construct_user");
    }

    #[test]
    fn outgoing_payload_omits_none_fields() {
        let payload = OutgoingWebhook {
            content: "response".into(),
            thread_id: None,
            recipient: None,
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert_eq!(json["content"], "response");
        assert!(json.get("thread_id").is_none());
        assert!(json.get("recipient").is_none());
    }

    #[test]
    fn verify_signature_no_secret() {
        let ch = make_channel();
        assert!(ch.verify_signature(b"body", None));
    }

    #[test]
    fn verify_signature_no_secret_ignores_signature() {
        let ch = make_channel();
        assert!(ch.verify_signature(b"body", Some("not-even-hex")));
    }

    #[test]
    fn verify_signature_missing_header() {
        let ch = make_channel_with_secret();
        assert!(!ch.verify_signature(b"body", None));
    }

    #[test]
    fn verify_signature_valid() {
        let ch = make_channel_with_secret();
        let body = b"test body";
        let sig = sign(b"mysecret", body);
        assert!(ch.verify_signature(body, Some(&sig)));
    }

    #[test]
    fn verify_signature_accepts_sha256_prefix() {
        let ch = make_channel_with_secret();
        let body = b"test body";
        let sig = format!("sha256={}", sign(b"mysecret", body));
        assert!(ch.verify_signature(body, Some(&sig)));
    }

    #[test]
    fn verify_signature_rejects_tampered_body() {
        let ch = make_channel_with_secret();
        let sig = sign(b"mysecret", b"original body");
        assert!(!ch.verify_signature(b"tampered body", Some(&sig)));
    }

    #[test]
    fn verify_signature_rejects_wrong_secret() {
        let ch = make_channel_with_secret();
        let body = b"test body";
        let sig = sign(b"not-the-secret", body);
        assert!(!ch.verify_signature(body, Some(&sig)));
    }

    #[test]
    fn verify_signature_invalid() {
        let ch = make_channel_with_secret();
        assert!(!ch.verify_signature(b"body", Some("badhex")));
    }
}