1use bytes::Bytes;
6use http_body_util::{BodyExt, Full};
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper::{Method, Request, Response, StatusCode};
10use hyper_util::rt::TokioIo;
11use parking_lot::RwLock;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use tokio::net::TcpListener;
16use xerv_core::error::{Result, XervError};
17use xerv_core::traits::{Trigger, TriggerConfig, TriggerEvent, TriggerFuture, TriggerType};
18use xerv_core::types::RelPtr;
19
20struct WebhookState {
22 running: AtomicBool,
24 paused: AtomicBool,
26 shutdown_tx: RwLock<Option<tokio::sync::oneshot::Sender<()>>>,
28}
29
30pub struct WebhookTrigger {
55 id: String,
57 host: String,
59 port: u16,
61 path: String,
63 method: Method,
65 state: Arc<WebhookState>,
67}
68
69impl WebhookTrigger {
70 pub fn new(id: impl Into<String>, host: impl Into<String>, port: u16) -> Self {
72 Self {
73 id: id.into(),
74 host: host.into(),
75 port,
76 path: "/".to_string(),
77 method: Method::POST,
78 state: Arc::new(WebhookState {
79 running: AtomicBool::new(false),
80 paused: AtomicBool::new(false),
81 shutdown_tx: RwLock::new(None),
82 }),
83 }
84 }
85
86 pub fn from_config(config: &TriggerConfig) -> Result<Self> {
88 let host = config.get_string("host").unwrap_or("0.0.0.0").to_string();
89 let port = config.get_i64("port").unwrap_or(8080) as u16;
90 let path = config.get_string("path").unwrap_or("/").to_string();
91 let method_str = config.get_string("method").unwrap_or("POST");
92
93 let method = match method_str.to_uppercase().as_str() {
94 "GET" => Method::GET,
95 "POST" => Method::POST,
96 "PUT" => Method::PUT,
97 "DELETE" => Method::DELETE,
98 "PATCH" => Method::PATCH,
99 _ => {
100 return Err(XervError::ConfigValue {
101 field: "method".to_string(),
102 cause: format!("Invalid HTTP method: {}", method_str),
103 });
104 }
105 };
106
107 Ok(Self {
108 id: config.id.clone(),
109 host,
110 port,
111 path,
112 method,
113 state: Arc::new(WebhookState {
114 running: AtomicBool::new(false),
115 paused: AtomicBool::new(false),
116 shutdown_tx: RwLock::new(None),
117 }),
118 })
119 }
120
121 pub fn with_path(mut self, path: impl Into<String>) -> Self {
123 self.path = path.into();
124 self
125 }
126
127 pub fn with_method(mut self, method: Method) -> Self {
129 self.method = method;
130 self
131 }
132}
133
134impl Trigger for WebhookTrigger {
135 fn trigger_type(&self) -> TriggerType {
136 TriggerType::Webhook
137 }
138
139 fn id(&self) -> &str {
140 &self.id
141 }
142
143 fn start<'a>(
144 &'a self,
145 callback: Box<dyn Fn(TriggerEvent) + Send + Sync + 'static>,
146 ) -> TriggerFuture<'a, ()> {
147 let state = self.state.clone();
148 let host = self.host.clone();
149 let port = self.port;
150 let path = self.path.clone();
151 let method = self.method.clone();
152 let trigger_id = self.id.clone();
153
154 Box::pin(async move {
155 if state.running.load(Ordering::SeqCst) {
156 return Err(XervError::ConfigValue {
157 field: "trigger".to_string(),
158 cause: "Trigger is already running".to_string(),
159 });
160 }
161
162 let addr: SocketAddr =
163 format!("{}:{}", host, port)
164 .parse()
165 .map_err(|e| XervError::ConfigValue {
166 field: "host/port".to_string(),
167 cause: format!("Invalid address: {}", e),
168 })?;
169
170 let listener = TcpListener::bind(addr)
171 .await
172 .map_err(|e| XervError::Network {
173 cause: format!("Failed to bind to {}: {}", addr, e),
174 })?;
175
176 tracing::info!(trigger_id = %trigger_id, addr = %addr, "Webhook trigger started");
177 state.running.store(true, Ordering::SeqCst);
178
179 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
180 *state.shutdown_tx.write() = Some(shutdown_tx);
181
182 let callback = Arc::new(callback);
183
184 loop {
185 tokio::select! {
186 _ = &mut shutdown_rx => {
187 tracing::info!(trigger_id = %trigger_id, "Webhook trigger shutting down");
188 break;
189 }
190 result = listener.accept() => {
191 match result {
192 Ok((stream, remote_addr)) => {
193 if state.paused.load(Ordering::SeqCst) {
194 tracing::debug!(trigger_id = %trigger_id, "Trigger paused, ignoring request");
195 continue;
196 }
197
198 let callback = callback.clone();
199 let trigger_id = trigger_id.clone();
200 let path = path.clone();
201 let method = method.clone();
202 let state = state.clone();
203
204 tokio::spawn(async move {
205 let io = TokioIo::new(stream);
206
207 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
208 let callback = callback.clone();
209 let trigger_id = trigger_id.clone();
210 let path = path.clone();
211 let method = method.clone();
212 let state = state.clone();
213
214 async move {
215 if state.paused.load(Ordering::SeqCst) {
217 return Ok::<_, hyper::Error>(Response::builder()
218 .status(StatusCode::SERVICE_UNAVAILABLE)
219 .body(Full::new(Bytes::from("Trigger paused")))
220 .unwrap());
221 }
222
223 if req.uri().path() != path {
225 return Ok(Response::builder()
226 .status(StatusCode::NOT_FOUND)
227 .body(Full::new(Bytes::from("Not found")))
228 .unwrap());
229 }
230
231 if req.method() != method {
233 return Ok(Response::builder()
234 .status(StatusCode::METHOD_NOT_ALLOWED)
235 .body(Full::new(Bytes::from("Method not allowed")))
236 .unwrap());
237 }
238
239 let body = match req.collect().await {
241 Ok(collected) => collected.to_bytes(),
242 Err(e) => {
243 tracing::error!(error = %e, "Failed to read request body");
244 return Ok(Response::builder()
245 .status(StatusCode::BAD_REQUEST)
246 .body(Full::new(Bytes::from("Failed to read body")))
247 .unwrap());
248 }
249 };
250
251 let event = TriggerEvent::new(&trigger_id, RelPtr::null())
253 .with_metadata(format!("body_size={}", body.len()));
254
255 tracing::debug!(
256 trigger_id = %trigger_id,
257 trace_id = %event.trace_id,
258 body_size = body.len(),
259 "Webhook received request"
260 );
261
262 callback(event);
264
265 Ok(Response::builder()
266 .status(StatusCode::ACCEPTED)
267 .body(Full::new(Bytes::from("Event accepted")))
268 .unwrap())
269 }
270 });
271
272 if let Err(e) = http1::Builder::new()
273 .serve_connection(io, service)
274 .await
275 {
276 tracing::error!(
277 error = %e,
278 remote_addr = %remote_addr,
279 "HTTP connection error"
280 );
281 }
282 });
283 }
284 Err(e) => {
285 tracing::error!(error = %e, "Failed to accept connection");
286 }
287 }
288 }
289 }
290 }
291
292 state.running.store(false, Ordering::SeqCst);
293 Ok(())
294 })
295 }
296
297 fn stop<'a>(&'a self) -> TriggerFuture<'a, ()> {
298 let state = self.state.clone();
299 let trigger_id = self.id.clone();
300
301 Box::pin(async move {
302 if let Some(tx) = state.shutdown_tx.write().take() {
303 let _ = tx.send(());
304 tracing::info!(trigger_id = %trigger_id, "Webhook trigger stopped");
305 }
306 state.running.store(false, Ordering::SeqCst);
307 Ok(())
308 })
309 }
310
311 fn pause<'a>(&'a self) -> TriggerFuture<'a, ()> {
312 let state = self.state.clone();
313 let trigger_id = self.id.clone();
314
315 Box::pin(async move {
316 state.paused.store(true, Ordering::SeqCst);
317 tracing::info!(trigger_id = %trigger_id, "Webhook trigger paused");
318 Ok(())
319 })
320 }
321
322 fn resume<'a>(&'a self) -> TriggerFuture<'a, ()> {
323 let state = self.state.clone();
324 let trigger_id = self.id.clone();
325
326 Box::pin(async move {
327 state.paused.store(false, Ordering::SeqCst);
328 tracing::info!(trigger_id = %trigger_id, "Webhook trigger resumed");
329 Ok(())
330 })
331 }
332
333 fn is_running(&self) -> bool {
334 self.state.running.load(Ordering::SeqCst)
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a webhook `TriggerConfig` whose params mapping holds `pairs`.
    fn webhook_config(id: &str, pairs: &[(&str, serde_yaml::Value)]) -> TriggerConfig {
        let params: serde_yaml::Mapping = pairs
            .iter()
            .map(|(key, value)| (serde_yaml::Value::String((*key).to_string()), value.clone()))
            .collect();
        TriggerConfig::new(id, TriggerType::Webhook)
            .with_params(serde_yaml::Value::Mapping(params))
    }

    /// Shorthand for a YAML string value.
    fn yaml_str(s: &str) -> serde_yaml::Value {
        serde_yaml::Value::String(s.to_string())
    }

    #[test]
    fn webhook_trigger_creation() {
        let trigger = WebhookTrigger::new("test_webhook", "127.0.0.1", 8080);
        assert_eq!(trigger.id(), "test_webhook");
        assert_eq!(trigger.trigger_type(), TriggerType::Webhook);
        assert!(!trigger.is_running());
    }

    #[test]
    fn webhook_trigger_from_config() {
        let config = webhook_config(
            "webhook_test",
            &[
                ("host", yaml_str("localhost")),
                ("port", serde_yaml::Value::Number(9090.into())),
                ("path", yaml_str("/api/events")),
                ("method", yaml_str("POST")),
            ],
        );

        let trigger = WebhookTrigger::from_config(&config).unwrap();
        assert_eq!(trigger.id(), "webhook_test");
        assert_eq!(trigger.host, "localhost");
        assert_eq!(trigger.port, 9090);
        assert_eq!(trigger.path, "/api/events");
        assert_eq!(trigger.method, Method::POST);
    }

    #[test]
    fn webhook_trigger_from_config_defaults() {
        let config = webhook_config("defaults", &[]);

        let trigger = WebhookTrigger::from_config(&config).unwrap();
        assert_eq!(trigger.host, "0.0.0.0");
        assert_eq!(trigger.port, 8080);
        assert_eq!(trigger.path, "/");
        assert_eq!(trigger.method, Method::POST);
        assert!(!trigger.is_running());
    }

    #[test]
    fn webhook_trigger_method_is_case_insensitive() {
        let config = webhook_config("lowercase", &[("method", yaml_str("put"))]);

        let trigger = WebhookTrigger::from_config(&config).unwrap();
        assert_eq!(trigger.method, Method::PUT);
    }

    #[test]
    fn webhook_trigger_rejects_unknown_method() {
        let config = webhook_config("bad_method", &[("method", yaml_str("TRACE"))]);

        assert!(matches!(
            WebhookTrigger::from_config(&config),
            Err(XervError::ConfigValue { ref field, .. }) if field == "method"
        ));
    }

    #[test]
    fn webhook_trigger_builder() {
        let trigger = WebhookTrigger::new("builder_test", "0.0.0.0", 8080)
            .with_path("/webhook")
            .with_method(Method::PUT);

        assert_eq!(trigger.path, "/webhook");
        assert_eq!(trigger.method, Method::PUT);
    }
}