xerv_nodes/triggers/
webhook.rs

1//! Webhook trigger (HTTP endpoint).
2//!
3//! Listens for incoming HTTP requests and converts them to trigger events.
4
5use bytes::Bytes;
6use http_body_util::{BodyExt, Full};
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper::{Method, Request, Response, StatusCode};
10use hyper_util::rt::TokioIo;
11use parking_lot::RwLock;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::TcpListener;
16use xerv_core::error::{Result, XervError};
17use xerv_core::traits::{Trigger, TriggerConfig, TriggerEvent, TriggerFuture, TriggerType};
18use xerv_core::types::RelPtr;
19
20/// State for the webhook trigger.
21struct WebhookState {
22    /// Whether the trigger is running.
23    running: AtomicBool,
24    /// Whether the trigger is paused.
25    paused: AtomicBool,
26    /// Shutdown signal sender.
27    shutdown_tx: RwLock<Option<tokio::sync::oneshot::Sender<()>>>,
28}
29
30/// HTTP webhook trigger.
31///
32/// Listens on a specified host:port for incoming HTTP requests.
33/// Converts request body to a trigger event.
34///
35/// # Configuration
36///
37/// ```yaml
38/// triggers:
39///   - id: webhook_orders
40///     type: trigger::webhook
41///     params:
42///       host: "0.0.0.0"
43///       port: 8080
44///       path: "/orders"
45///       method: "POST"
46/// ```
47///
48/// # Parameters
49///
50/// - `host` - Host to bind to (default: "0.0.0.0")
51/// - `port` - Port to listen on (default: 8080)
52/// - `path` - URL path to match (default: "/")
53/// - `method` - HTTP method to accept (default: "POST")
54pub struct WebhookTrigger {
55    /// Trigger ID.
56    id: String,
57    /// Host to bind to.
58    host: String,
59    /// Port to listen on.
60    port: u16,
61    /// URL path to match.
62    path: String,
63    /// HTTP method to accept.
64    method: Method,
65    /// Internal state.
66    state: Arc<WebhookState>,
67}
68
69impl WebhookTrigger {
70    /// Create a new webhook trigger.
71    pub fn new(id: impl Into<String>, host: impl Into<String>, port: u16) -> Self {
72        Self {
73            id: id.into(),
74            host: host.into(),
75            port,
76            path: "/".to_string(),
77            method: Method::POST,
78            state: Arc::new(WebhookState {
79                running: AtomicBool::new(false),
80                paused: AtomicBool::new(false),
81                shutdown_tx: RwLock::new(None),
82            }),
83        }
84    }
85
86    /// Create from configuration.
87    pub fn from_config(config: &TriggerConfig) -> Result<Self> {
88        let host = config.get_string("host").unwrap_or("0.0.0.0").to_string();
89        let port = config.get_i64("port").unwrap_or(8080) as u16;
90        let path = config.get_string("path").unwrap_or("/").to_string();
91        let method_str = config.get_string("method").unwrap_or("POST");
92
93        let method = match method_str.to_uppercase().as_str() {
94            "GET" => Method::GET,
95            "POST" => Method::POST,
96            "PUT" => Method::PUT,
97            "DELETE" => Method::DELETE,
98            "PATCH" => Method::PATCH,
99            _ => {
100                return Err(XervError::ConfigValue {
101                    field: "method".to_string(),
102                    cause: format!("Invalid HTTP method: {}", method_str),
103                });
104            }
105        };
106
107        Ok(Self {
108            id: config.id.clone(),
109            host,
110            port,
111            path,
112            method,
113            state: Arc::new(WebhookState {
114                running: AtomicBool::new(false),
115                paused: AtomicBool::new(false),
116                shutdown_tx: RwLock::new(None),
117            }),
118        })
119    }
120
121    /// Set the URL path.
122    pub fn with_path(mut self, path: impl Into<String>) -> Self {
123        self.path = path.into();
124        self
125    }
126
127    /// Set the HTTP method.
128    pub fn with_method(mut self, method: Method) -> Self {
129        self.method = method;
130        self
131    }
132}
133
134impl Trigger for WebhookTrigger {
135    fn trigger_type(&self) -> TriggerType {
136        TriggerType::Webhook
137    }
138
139    fn id(&self) -> &str {
140        &self.id
141    }
142
143    fn start<'a>(
144        &'a self,
145        callback: Box<dyn Fn(TriggerEvent) + Send + Sync + 'static>,
146    ) -> TriggerFuture<'a, ()> {
147        let state = self.state.clone();
148        let host = self.host.clone();
149        let port = self.port;
150        let path = self.path.clone();
151        let method = self.method.clone();
152        let trigger_id = self.id.clone();
153
154        Box::pin(async move {
155            if state.running.load(Ordering::SeqCst) {
156                return Err(XervError::ConfigValue {
157                    field: "trigger".to_string(),
158                    cause: "Trigger is already running".to_string(),
159                });
160            }
161
162            let addr: SocketAddr =
163                format!("{}:{}", host, port)
164                    .parse()
165                    .map_err(|e| XervError::ConfigValue {
166                        field: "host/port".to_string(),
167                        cause: format!("Invalid address: {}", e),
168                    })?;
169
170            let listener = TcpListener::bind(addr)
171                .await
172                .map_err(|e| XervError::Network {
173                    cause: format!("Failed to bind to {}: {}", addr, e),
174                })?;
175
176            tracing::info!(trigger_id = %trigger_id, addr = %addr, "Webhook trigger started");
177            state.running.store(true, Ordering::SeqCst);
178
179            let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
180            *state.shutdown_tx.write() = Some(shutdown_tx);
181
182            let callback = Arc::new(callback);
183
184            loop {
185                tokio::select! {
186                    _ = &mut shutdown_rx => {
187                        tracing::info!(trigger_id = %trigger_id, "Webhook trigger shutting down");
188                        break;
189                    }
190                    result = listener.accept() => {
191                        match result {
192                            Ok((stream, remote_addr)) => {
193                                if state.paused.load(Ordering::SeqCst) {
194                                    tracing::debug!(trigger_id = %trigger_id, "Trigger paused, ignoring request");
195                                    continue;
196                                }
197
198                                let callback = callback.clone();
199                                let trigger_id = trigger_id.clone();
200                                let path = path.clone();
201                                let method = method.clone();
202                                let state = state.clone();
203
204                                tokio::spawn(async move {
205                                    let io = TokioIo::new(stream);
206
207                                    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
208                                        let callback = callback.clone();
209                                        let trigger_id = trigger_id.clone();
210                                        let path = path.clone();
211                                        let method = method.clone();
212                                        let state = state.clone();
213
214                                        async move {
215                                            // Check if paused
216                                            if state.paused.load(Ordering::SeqCst) {
217                                                return Ok::<_, hyper::Error>(Response::builder()
218                                                    .status(StatusCode::SERVICE_UNAVAILABLE)
219                                                    .body(Full::new(Bytes::from("Trigger paused")))
220                                                    .unwrap());
221                                            }
222
223                                            // Check path
224                                            if req.uri().path() != path {
225                                                return Ok(Response::builder()
226                                                    .status(StatusCode::NOT_FOUND)
227                                                    .body(Full::new(Bytes::from("Not found")))
228                                                    .unwrap());
229                                            }
230
231                                            // Check method
232                                            if req.method() != method {
233                                                return Ok(Response::builder()
234                                                    .status(StatusCode::METHOD_NOT_ALLOWED)
235                                                    .body(Full::new(Bytes::from("Method not allowed")))
236                                                    .unwrap());
237                                            }
238
239                                            // Read body
240                                            let body = match req.collect().await {
241                                                Ok(collected) => collected.to_bytes(),
242                                                Err(e) => {
243                                                    tracing::error!(error = %e, "Failed to read request body");
244                                                    return Ok(Response::builder()
245                                                        .status(StatusCode::BAD_REQUEST)
246                                                        .body(Full::new(Bytes::from("Failed to read body")))
247                                                        .unwrap());
248                                                }
249                                            };
250
251                                            // Create event with null pointer (data will be written by pipeline)
252                                            let event = TriggerEvent::new(&trigger_id, RelPtr::null())
253                                                .with_metadata(format!("body_size={}", body.len()));
254
255                                            tracing::debug!(
256                                                trigger_id = %trigger_id,
257                                                trace_id = %event.trace_id,
258                                                body_size = body.len(),
259                                                "Webhook received request"
260                                            );
261
262                                            // Call the callback
263                                            callback(event);
264
265                                            Ok(Response::builder()
266                                                .status(StatusCode::ACCEPTED)
267                                                .body(Full::new(Bytes::from("Event accepted")))
268                                                .unwrap())
269                                        }
270                                    });
271
272                                    if let Err(e) = http1::Builder::new()
273                                        .serve_connection(io, service)
274                                        .await
275                                    {
276                                        tracing::error!(
277                                            error = %e,
278                                            remote_addr = %remote_addr,
279                                            "HTTP connection error"
280                                        );
281                                    }
282                                });
283                            }
284                            Err(e) => {
285                                tracing::error!(error = %e, "Failed to accept connection");
286                            }
287                        }
288                    }
289                }
290            }
291
292            state.running.store(false, Ordering::SeqCst);
293            Ok(())
294        })
295    }
296
297    fn stop<'a>(&'a self) -> TriggerFuture<'a, ()> {
298        let state = self.state.clone();
299        let trigger_id = self.id.clone();
300
301        Box::pin(async move {
302            if let Some(tx) = state.shutdown_tx.write().take() {
303                let _ = tx.send(());
304                tracing::info!(trigger_id = %trigger_id, "Webhook trigger stopped");
305            }
306            state.running.store(false, Ordering::SeqCst);
307            Ok(())
308        })
309    }
310
311    fn pause<'a>(&'a self) -> TriggerFuture<'a, ()> {
312        let state = self.state.clone();
313        let trigger_id = self.id.clone();
314
315        Box::pin(async move {
316            state.paused.store(true, Ordering::SeqCst);
317            tracing::info!(trigger_id = %trigger_id, "Webhook trigger paused");
318            Ok(())
319        })
320    }
321
322    fn resume<'a>(&'a self) -> TriggerFuture<'a, ()> {
323        let state = self.state.clone();
324        let trigger_id = self.id.clone();
325
326        Box::pin(async move {
327            state.paused.store(false, Ordering::SeqCst);
328            tracing::info!(trigger_id = %trigger_id, "Webhook trigger resumed");
329            Ok(())
330        })
331    }
332
333    fn is_running(&self) -> bool {
334        self.state.running.load(Ordering::SeqCst)
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a webhook `TriggerConfig` whose params mapping contains `pairs`.
    fn webhook_config(id: &str, pairs: &[(&str, serde_yaml::Value)]) -> TriggerConfig {
        let params: serde_yaml::Mapping = pairs
            .iter()
            .map(|(key, value)| (serde_yaml::Value::String((*key).to_string()), value.clone()))
            .collect();
        TriggerConfig::new(id, TriggerType::Webhook).with_params(serde_yaml::Value::Mapping(params))
    }

    /// Shorthand for a YAML string value.
    fn yaml_str(value: &str) -> serde_yaml::Value {
        serde_yaml::Value::String(value.to_string())
    }

    #[test]
    fn webhook_trigger_creation() {
        let trigger = WebhookTrigger::new("test_webhook", "127.0.0.1", 8080);
        assert_eq!(trigger.id(), "test_webhook");
        assert_eq!(trigger.trigger_type(), TriggerType::Webhook);
        assert!(!trigger.is_running());
    }

    #[test]
    fn webhook_trigger_from_config() {
        let config = webhook_config(
            "webhook_test",
            &[
                ("host", yaml_str("localhost")),
                ("port", serde_yaml::Value::Number(9090.into())),
                ("path", yaml_str("/api/events")),
                ("method", yaml_str("POST")),
            ],
        );

        let trigger = WebhookTrigger::from_config(&config).unwrap();
        assert_eq!(trigger.id(), "webhook_test");
        assert_eq!(trigger.host, "localhost");
        assert_eq!(trigger.port, 9090);
        assert_eq!(trigger.path, "/api/events");
        assert_eq!(trigger.method, Method::POST);
    }

    #[test]
    fn webhook_trigger_from_config_uses_defaults() {
        let trigger = WebhookTrigger::from_config(&webhook_config("defaults", &[])).unwrap();
        assert_eq!(trigger.id(), "defaults");
        assert_eq!(trigger.host, "0.0.0.0");
        assert_eq!(trigger.port, 8080);
        assert_eq!(trigger.path, "/");
        assert_eq!(trigger.method, Method::POST);
        assert!(!trigger.is_running());
    }

    #[test]
    fn webhook_trigger_method_is_case_insensitive() {
        let config = webhook_config("lowercase", &[("method", yaml_str("patch"))]);
        let trigger = WebhookTrigger::from_config(&config).unwrap();
        assert_eq!(trigger.method, Method::PATCH);
    }

    #[test]
    fn webhook_trigger_rejects_unknown_method() {
        let config = webhook_config("bad_method", &[("method", yaml_str("TRACE"))]);
        assert!(matches!(
            WebhookTrigger::from_config(&config),
            Err(XervError::ConfigValue { ref field, .. }) if field == "method"
        ));
    }

    #[test]
    fn webhook_trigger_builder() {
        let trigger = WebhookTrigger::new("builder_test", "0.0.0.0", 8080)
            .with_path("/webhook")
            .with_method(Method::PUT);

        assert_eq!(trigger.path, "/webhook");
        assert_eq!(trigger.method, Method::PUT);
    }
}