1use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use axum::Router;
17use axum::body::Bytes;
18use axum::extract::{DefaultBodyLimit, State};
19use axum::http::{HeaderMap, StatusCode, header};
20use axum::middleware;
21use axum::routing::{get, post};
22use serde_json::{Value, from_slice};
23use tokio::sync::{Mutex, Semaphore};
24use tokio::task::JoinSet;
25use tokio_cron_scheduler::{Job, JobScheduler};
26use tracing::{error, info, warn};
27
28use crate::cron::CronJob;
29use crate::error::RuntimeError;
30use crate::webhook::WebhookAuth;
31
/// Request body limit, in bytes, that [`Runtime::new`] starts with (2 MiB).
const DEFAULT_MAX_BODY_SIZE: usize = 2 * 1024 * 1024;

/// Concurrent-handler limit that [`Runtime::new`] starts with.
const DEFAULT_MAX_CONCURRENT_HANDLERS: usize = 64;
37
38type WebhookHandler = Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
39
/// Metric names and label values used when the `prometheus` feature is on.
#[cfg(feature = "prometheus")]
mod metric_names {
    /// Counter incremented for every webhook request that reaches the handler.
    pub const WEBHOOK_RECEIVED_TOTAL: &str = "ironflow_webhook_received_total";
    /// Counter incremented each time a cron job is triggered.
    pub const CRON_RUNS_TOTAL: &str = "ironflow_cron_runs_total";

    /// `auth` label value: the request failed authentication.
    pub const AUTH_REJECTED: &str = "rejected";
    /// `auth` label value: the request was authenticated and its body parsed.
    pub const AUTH_ACCEPTED: &str = "accepted";
    /// `auth` label value: the request was authenticated but its body was not valid JSON.
    pub const AUTH_INVALID_BODY: &str = "invalid_body";
}
50
51struct WebhookRoute {
52 path: String,
53 auth: WebhookAuth,
54 handler: WebhookHandler,
55}
56
57pub struct Runtime {
91 webhooks: Vec<WebhookRoute>,
92 crons: Vec<CronJob>,
93 max_body_size: usize,
94 max_concurrent_handlers: usize,
95}
96
97impl Runtime {
98 pub fn new() -> Self {
108 Self {
109 webhooks: Vec::new(),
110 crons: Vec::new(),
111 max_body_size: DEFAULT_MAX_BODY_SIZE,
112 max_concurrent_handlers: DEFAULT_MAX_CONCURRENT_HANDLERS,
113 }
114 }
115
116 pub fn max_body_size(mut self, bytes: usize) -> Self {
129 self.max_body_size = bytes;
130 self
131 }
132
133 pub fn max_concurrent_handlers(mut self, limit: usize) -> Self {
151 assert!(limit > 0, "max_concurrent_handlers must be greater than 0");
152 self.max_concurrent_handlers = limit;
153 self
154 }
155
156 pub fn webhook<F, Fut>(mut self, path: &str, auth: WebhookAuth, handler: F) -> Self
180 where
181 F: Fn(Value) -> Fut + Send + Sync + Clone + 'static,
182 Fut: Future<Output = ()> + Send + 'static,
183 {
184 assert!(
185 path.starts_with('/'),
186 "webhook path must start with '/', got: {path}"
187 );
188 if matches!(auth, WebhookAuth::None) {
189 warn!(path = %path, "webhook registered with WebhookAuth::None - all requests will be accepted without authentication");
190 }
191 let handler: WebhookHandler = Arc::new(move |payload| {
192 let handler = handler.clone();
193 Box::pin(async move { handler(payload).await })
194 });
195 self.webhooks.push(WebhookRoute {
196 path: path.to_string(),
197 auth,
198 handler,
199 });
200 self
201 }
202
203 pub fn cron<F, Fut>(mut self, schedule: &str, name: &str, handler: F) -> Self
225 where
226 F: Fn() -> Fut + Send + Sync + 'static,
227 Fut: Future<Output = ()> + Send + 'static,
228 {
229 let handler_fn: Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync> =
230 Box::new(move || Box::pin(handler()));
231 self.crons.push(CronJob {
232 schedule: schedule.to_string(),
233 name: name.to_string(),
234 handler: handler_fn,
235 });
236 self
237 }
238
239 fn build_router(
245 webhooks: Vec<WebhookRoute>,
246 handler_tracker: Arc<HandlerTracker>,
247 max_body_size: usize,
248 #[cfg(feature = "prometheus")] prom_handle: Option<
249 metrics_exporter_prometheus::PrometheusHandle,
250 >,
251 ) -> Router {
252 let mut router = Router::new();
253
254 for webhook in webhooks {
255 let auth = Arc::new(webhook.auth);
256 let handler = webhook.handler;
257 let path = webhook.path.clone();
258
259 let name: Arc<str> = Arc::from(path.as_str());
260 let route_state = WebhookState {
261 auth,
262 handler,
263 name,
264 tracker: handler_tracker.clone(),
265 };
266
267 router = router.route(&path, post(webhook_handler).with_state(route_state));
268 info!(path = %path, "registered webhook");
269 }
270
271 router = router.route("/health", get(|| async { "ok" }));
272
273 #[cfg(feature = "prometheus")]
274 if let Some(handle) = prom_handle {
275 router = router.route(
276 "/metrics",
277 get(move || {
278 let h = handle.clone();
279 async move { h.render() }
280 }),
281 );
282 info!("registered /metrics endpoint");
283 }
284
285 router
286 .layer(middleware::from_fn(security_headers))
287 .layer(DefaultBodyLimit::max(max_body_size))
288 }
289
290 pub fn into_router(self) -> Router {
306 if !self.crons.is_empty() {
307 warn!(
308 cron_count = self.crons.len(),
309 "into_router() drops registered cron jobs - use serve() to start both webhooks and crons"
310 );
311 }
312 let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
313 Self::build_router(
314 self.webhooks,
315 tracker,
316 self.max_body_size,
317 #[cfg(feature = "prometheus")]
318 None,
319 )
320 }
321
322 pub async fn serve(self, addr: &str) -> Result<(), RuntimeError> {
355 let _ = dotenvy::dotenv();
356
357 #[cfg(feature = "prometheus")]
358 let prom_handle = {
359 match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
360 Ok(handle) => {
361 info!("prometheus metrics recorder installed");
362 Some(handle)
363 }
364 Err(_) => {
365 info!("prometheus metrics recorder already installed, reusing existing");
366 None
367 }
368 }
369 };
370
371 let mut scheduler = JobScheduler::new().await?;
372
373 for cron_job in self.crons {
374 let handler = Arc::new(cron_job.handler);
375 let name = cron_job.name.clone();
376 let running = Arc::new(std::sync::atomic::AtomicBool::new(false));
377 let job = Job::new_async(cron_job.schedule.as_str(), move |_uuid, _lock| {
378 let handler = handler.clone();
379 let name = name.clone();
380 let running = running.clone();
381 Box::pin(async move {
382 if running.swap(true, std::sync::atomic::Ordering::AcqRel) {
383 warn!(cron = %name, "cron job still running, skipping this tick");
384 return;
385 }
386 info!(cron = %name, "cron job triggered");
387 #[cfg(feature = "prometheus")]
388 metrics::counter!(metric_names::CRON_RUNS_TOTAL, "job" => name.clone())
389 .increment(1);
390 (handler)().await;
391 running.store(false, std::sync::atomic::Ordering::Release);
392 })
393 })?;
394 info!(cron = %cron_job.name, schedule = %cron_job.schedule, "registered cron job");
395 scheduler.add(job).await?;
396 }
397
398 scheduler.start().await?;
399
400 let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
401 let router = Self::build_router(
402 self.webhooks,
403 tracker.clone(),
404 self.max_body_size,
405 #[cfg(feature = "prometheus")]
406 prom_handle,
407 );
408
409 let listener = tokio::net::TcpListener::bind(addr)
410 .await
411 .map_err(RuntimeError::Bind)?;
412 info!(addr = %addr, "ironflow runtime listening");
413
414 axum::serve(listener, router)
415 .with_graceful_shutdown(shutdown_signal())
416 .await
417 .map_err(RuntimeError::Serve)?;
418
419 info!("waiting for in-flight webhook handlers to complete");
421 tracker.wait().await;
422
423 info!("shutting down scheduler");
424 scheduler.shutdown().await.map_err(RuntimeError::Shutdown)?;
425 info!("ironflow runtime stopped");
426
427 Ok(())
428 }
429}
430
431impl Default for Runtime {
432 fn default() -> Self {
433 Self::new()
434 }
435}
436
437struct HandlerTracker {
442 semaphore: Arc<Semaphore>,
443 join_set: Mutex<JoinSet<()>>,
444}
445
446impl HandlerTracker {
447 fn new(max_concurrent: usize) -> Self {
448 Self {
449 semaphore: Arc::new(Semaphore::new(max_concurrent)),
450 join_set: Mutex::new(JoinSet::new()),
451 }
452 }
453
454 async fn spawn(&self, name: String, handler: WebhookHandler, payload: Value) {
456 let semaphore = self.semaphore.clone();
457 let mut js = self.join_set.lock().await;
458 while let Some(result) = js.try_join_next() {
460 if let Err(e) = result {
461 error!(error = %e, "webhook handler panicked");
462 }
463 }
464 use tracing::Instrument;
465 let span = tracing::info_span!("webhook", path = %name);
466 js.spawn(
467 async move {
468 let _permit = semaphore
469 .acquire()
470 .await
471 .expect("semaphore closed unexpectedly");
472 info!("webhook workflow started");
473 handler(payload).await;
474 info!("webhook workflow completed");
475 }
476 .instrument(span),
477 );
478 }
479
480 async fn wait(&self) {
482 let mut js = self.join_set.lock().await;
483 while let Some(result) = js.join_next().await {
484 if let Err(e) = result {
485 error!(error = %e, "webhook handler panicked");
486 }
487 }
488 }
489}
490
491#[derive(Clone)]
492struct WebhookState {
493 auth: Arc<WebhookAuth>,
494 handler: WebhookHandler,
495 name: Arc<str>,
496 tracker: Arc<HandlerTracker>,
497}
498
499async fn webhook_handler(
500 State(state): State<WebhookState>,
501 headers: HeaderMap,
502 body: Bytes,
503) -> StatusCode {
504 let name = &state.name;
505 if !state.auth.verify(&headers, &body) {
506 warn!(webhook = %name, "webhook auth failed");
507 #[cfg(feature = "prometheus")]
508 {
509 let label: String = name.to_string();
510 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_REJECTED).increment(1);
511 }
512 return StatusCode::UNAUTHORIZED;
513 }
514
515 let payload: Value = match from_slice(&body) {
516 Ok(v) => v,
517 Err(e) => {
518 warn!(webhook = %name, error = %e, "invalid JSON body");
519 #[cfg(feature = "prometheus")]
520 {
521 let label: String = name.to_string();
522 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_INVALID_BODY).increment(1);
523 }
524 return StatusCode::BAD_REQUEST;
525 }
526 };
527
528 #[cfg(feature = "prometheus")]
529 {
530 let label: String = name.to_string();
531 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_ACCEPTED).increment(1);
532 }
533
534 state
535 .tracker
536 .spawn(name.to_string(), state.handler.clone(), payload)
537 .await;
538
539 StatusCode::ACCEPTED
540}
541
542async fn security_headers(
543 request: axum::http::Request<axum::body::Body>,
544 next: axum::middleware::Next,
545) -> axum::response::Response {
546 let mut response = next.run(request).await;
547 let headers = response.headers_mut();
548 headers.insert(
549 header::X_CONTENT_TYPE_OPTIONS,
550 "nosniff".parse().expect("valid header value"),
551 );
552 headers.insert(
553 header::X_FRAME_OPTIONS,
554 "DENY".parse().expect("valid header value"),
555 );
556 headers.insert(
557 "x-xss-protection",
558 "1; mode=block".parse().expect("valid header value"),
559 );
560 headers.insert(
561 header::STRICT_TRANSPORT_SECURITY,
562 "max-age=31536000; includeSubDomains"
563 .parse()
564 .expect("valid header value"),
565 );
566 headers.insert(
567 header::CONTENT_SECURITY_POLICY,
568 "default-src 'none'".parse().expect("valid header value"),
569 );
570 response
571}
572
573async fn shutdown_signal() {
574 let ctrl_c = async {
575 if let Err(e) = tokio::signal::ctrl_c().await {
576 warn!("failed to install ctrl+c handler: {e}");
577 }
578 };
579
580 #[cfg(unix)]
581 {
582 use tokio::signal::unix::{SignalKind, signal};
583 let mut sigterm =
584 signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
585 tokio::select! {
586 () = ctrl_c => info!("received SIGINT, shutting down"),
587 _ = sigterm.recv() => info!("received SIGTERM, shutting down"),
588 }
589 }
590
591 #[cfg(not(unix))]
592 {
593 ctrl_c.await;
594 info!("received ctrl+c, shutting down");
595 }
596}