1use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use axum::Router;
17use axum::body::Bytes;
18use axum::extract::{DefaultBodyLimit, State};
19use axum::http::{HeaderMap, StatusCode, header};
20use axum::middleware;
21use axum::routing::{get, post};
22use serde_json::{Value, from_slice};
23use tokio::sync::{Mutex, Semaphore};
24use tokio::task::JoinSet;
25use tokio_cron_scheduler::{Job, JobScheduler};
26use tracing::{error, info, warn};
27
28use crate::cron::CronJob;
29use crate::error::RuntimeError;
30use crate::webhook::WebhookAuth;
31
/// Default request-body limit applied to every route: 2 MiB (`2 << 20` bytes).
const DEFAULT_MAX_BODY_SIZE: usize = 2 << 20;

/// Default number of webhook handlers allowed to execute at the same time.
const DEFAULT_MAX_CONCURRENT_HANDLERS: usize = 64;
37
38type WebhookHandler = Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
39
/// Metric names and label values recorded when the `prometheus` feature is enabled.
#[cfg(feature = "prometheus")]
mod metric_names {
    /// Counter bumped once for every request that reaches the webhook handler, labelled by
    /// `path` and by the outcome in the `auth` label.
    pub const WEBHOOK_RECEIVED_TOTAL: &'static str = "ironflow_webhook_received_total";
    /// Counter bumped each time a cron job actually starts a run, labelled by `job`.
    pub const CRON_RUNS_TOTAL: &'static str = "ironflow_cron_runs_total";

    // Values used for the `auth` label of WEBHOOK_RECEIVED_TOTAL.

    /// The request failed authentication.
    pub const AUTH_REJECTED: &'static str = "rejected";
    /// Authentication passed and the body parsed as JSON.
    pub const AUTH_ACCEPTED: &'static str = "accepted";
    /// Authentication passed but the body was not valid JSON.
    pub const AUTH_INVALID_BODY: &'static str = "invalid_body";
}
50
/// A webhook registered through `Runtime::webhook`, kept until the router is built.
struct WebhookRoute {
    /// Path the `POST` route is mounted at; `Runtime::webhook` asserts it starts with `/`.
    path: String,
    /// Verification applied to every incoming request before the handler is scheduled.
    auth: WebhookAuth,
    /// Type-erased async callback invoked with the parsed JSON body.
    handler: WebhookHandler,
}
56
/// Builder and host for ironflow webhooks and cron jobs.
///
/// Register HTTP routes with `webhook` and scheduled jobs with `cron`, then either start both
/// with `serve`, run only the cron jobs with `run_crons`, or take the HTTP routes as an axum
/// `Router` with `into_router`.
pub struct Runtime {
    /// Webhook routes in registration order.
    webhooks: Vec<WebhookRoute>,
    /// Cron jobs handed to the scheduler by `serve` / `run_crons`.
    crons: Vec<CronJob>,
    /// Maximum request body size in bytes, applied to the router via `DefaultBodyLimit`.
    max_body_size: usize,
    /// Maximum number of webhook handlers executing at once; the setter asserts it is > 0.
    max_concurrent_handlers: usize,
}
98
99impl Runtime {
100 pub fn new() -> Self {
110 Self {
111 webhooks: Vec::new(),
112 crons: Vec::new(),
113 max_body_size: DEFAULT_MAX_BODY_SIZE,
114 max_concurrent_handlers: DEFAULT_MAX_CONCURRENT_HANDLERS,
115 }
116 }
117
118 pub fn max_body_size(mut self, bytes: usize) -> Self {
131 self.max_body_size = bytes;
132 self
133 }
134
135 pub fn max_concurrent_handlers(mut self, limit: usize) -> Self {
153 assert!(limit > 0, "max_concurrent_handlers must be greater than 0");
154 self.max_concurrent_handlers = limit;
155 self
156 }
157
158 pub fn webhook<F, Fut>(mut self, path: &str, auth: WebhookAuth, handler: F) -> Self
182 where
183 F: Fn(Value) -> Fut + Send + Sync + Clone + 'static,
184 Fut: Future<Output = ()> + Send + 'static,
185 {
186 assert!(
187 path.starts_with('/'),
188 "webhook path must start with '/', got: {path}"
189 );
190 if matches!(auth, WebhookAuth::None) {
191 warn!(path = %path, "webhook registered with WebhookAuth::None - all requests will be accepted without authentication");
192 }
193 let handler: WebhookHandler = Arc::new(move |payload| {
194 let handler = handler.clone();
195 Box::pin(async move { handler(payload).await })
196 });
197 self.webhooks.push(WebhookRoute {
198 path: path.to_string(),
199 auth,
200 handler,
201 });
202 self
203 }
204
205 pub fn cron<F, Fut>(mut self, schedule: &str, name: &str, handler: F) -> Self
227 where
228 F: Fn() -> Fut + Send + Sync + 'static,
229 Fut: Future<Output = ()> + Send + 'static,
230 {
231 let handler_fn: Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync> =
232 Box::new(move || Box::pin(handler()));
233 self.crons.push(CronJob {
234 schedule: schedule.to_string(),
235 name: name.to_string(),
236 handler: handler_fn,
237 });
238 self
239 }
240
241 fn build_router(
247 webhooks: Vec<WebhookRoute>,
248 handler_tracker: Arc<HandlerTracker>,
249 max_body_size: usize,
250 #[cfg(feature = "prometheus")] prom_handle: Option<
251 metrics_exporter_prometheus::PrometheusHandle,
252 >,
253 ) -> Router {
254 let mut router = Router::new();
255
256 for webhook in webhooks {
257 let auth = Arc::new(webhook.auth);
258 let handler = webhook.handler;
259 let path = webhook.path.clone();
260
261 let name: Arc<str> = Arc::from(path.as_str());
262 let route_state = WebhookState {
263 auth,
264 handler,
265 name,
266 tracker: handler_tracker.clone(),
267 };
268
269 router = router.route(&path, post(webhook_handler).with_state(route_state));
270 info!(path = %path, "registered webhook");
271 }
272
273 router = router.route("/health", get(|| async { "ok" }));
274
275 #[cfg(feature = "prometheus")]
276 if let Some(handle) = prom_handle {
277 router = router.route(
278 "/metrics",
279 get(move || {
280 let h = handle.clone();
281 async move { h.render() }
282 }),
283 );
284 info!("registered /metrics endpoint");
285 }
286
287 router
288 .layer(middleware::from_fn(security_headers))
289 .layer(DefaultBodyLimit::max(max_body_size))
290 }
291
292 pub fn into_router(self) -> Router {
308 if !self.crons.is_empty() {
309 warn!(
310 cron_count = self.crons.len(),
311 "into_router() drops registered cron jobs - use serve() or run_crons() to start them"
312 );
313 }
314 let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
315 Self::build_router(
316 self.webhooks,
317 tracker,
318 self.max_body_size,
319 #[cfg(feature = "prometheus")]
320 None,
321 )
322 }
323
324 async fn start_scheduler(crons: Vec<CronJob>) -> Result<JobScheduler, RuntimeError> {
329 let scheduler = JobScheduler::new().await?;
330
331 for cron_job in crons {
332 let handler = Arc::new(cron_job.handler);
333 let name = cron_job.name.clone();
334 let running = Arc::new(std::sync::atomic::AtomicBool::new(false));
335 let job = Job::new_async(cron_job.schedule.as_str(), move |_uuid, _lock| {
336 let handler = handler.clone();
337 let name = name.clone();
338 let running = running.clone();
339 Box::pin(async move {
340 if running.swap(true, std::sync::atomic::Ordering::AcqRel) {
341 warn!(cron = %name, "cron job still running, skipping this tick");
342 return;
343 }
344 info!(cron = %name, "cron job triggered");
345 #[cfg(feature = "prometheus")]
346 metrics::counter!(metric_names::CRON_RUNS_TOTAL, "job" => name.clone())
347 .increment(1);
348 (handler)().await;
349 running.store(false, std::sync::atomic::Ordering::Release);
350 })
351 })?;
352 info!(cron = %cron_job.name, schedule = %cron_job.schedule, "registered cron job");
353 scheduler.add(job).await?;
354 }
355
356 scheduler.start().await?;
357 Ok(scheduler)
358 }
359
360 pub async fn run_crons(self) -> Result<(), RuntimeError> {
391 let _ = dotenvy::dotenv();
392
393 if !self.webhooks.is_empty() {
394 warn!(
395 webhook_count = self.webhooks.len(),
396 "run_crons() ignores registered webhooks - use serve() to start both webhooks and crons"
397 );
398 }
399
400 #[cfg(feature = "prometheus")]
401 {
402 match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
403 Ok(_) => info!("prometheus metrics recorder installed"),
404 Err(_) => {
405 info!("prometheus metrics recorder already installed, reusing existing")
406 }
407 }
408 }
409
410 let mut scheduler = Self::start_scheduler(self.crons).await?;
411
412 info!("ironflow cron scheduler running (no HTTP server)");
413 shutdown_signal().await;
414
415 info!("shutting down scheduler");
416 scheduler.shutdown().await.map_err(RuntimeError::Shutdown)?;
417 info!("ironflow cron scheduler stopped");
418
419 Ok(())
420 }
421
422 pub async fn serve(self, addr: &str) -> Result<(), RuntimeError> {
458 let _ = dotenvy::dotenv();
459
460 #[cfg(feature = "prometheus")]
461 let prom_handle = {
462 match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
463 Ok(handle) => {
464 info!("prometheus metrics recorder installed");
465 Some(handle)
466 }
467 Err(_) => {
468 info!("prometheus metrics recorder already installed, reusing existing");
469 None
470 }
471 }
472 };
473
474 let mut scheduler = Self::start_scheduler(self.crons).await?;
475
476 let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
477 let router = Self::build_router(
478 self.webhooks,
479 tracker.clone(),
480 self.max_body_size,
481 #[cfg(feature = "prometheus")]
482 prom_handle,
483 );
484
485 let listener = tokio::net::TcpListener::bind(addr)
486 .await
487 .map_err(RuntimeError::Bind)?;
488 info!(addr = %addr, "ironflow runtime listening");
489
490 axum::serve(listener, router)
491 .with_graceful_shutdown(shutdown_signal())
492 .await
493 .map_err(RuntimeError::Serve)?;
494
495 info!("waiting for in-flight webhook handlers to complete");
497 tracker.wait().await;
498
499 info!("shutting down scheduler");
500 scheduler.shutdown().await.map_err(RuntimeError::Shutdown)?;
501 info!("ironflow runtime stopped");
502
503 Ok(())
504 }
505}
506
507impl Default for Runtime {
508 fn default() -> Self {
509 Self::new()
510 }
511}
512
513struct HandlerTracker {
518 semaphore: Arc<Semaphore>,
519 join_set: Mutex<JoinSet<()>>,
520}
521
522impl HandlerTracker {
523 fn new(max_concurrent: usize) -> Self {
524 Self {
525 semaphore: Arc::new(Semaphore::new(max_concurrent)),
526 join_set: Mutex::new(JoinSet::new()),
527 }
528 }
529
530 async fn spawn(&self, name: String, handler: WebhookHandler, payload: Value) {
532 let semaphore = self.semaphore.clone();
533 let mut js = self.join_set.lock().await;
534 while let Some(result) = js.try_join_next() {
536 if let Err(e) = result {
537 error!(error = %e, "webhook handler panicked");
538 }
539 }
540 use tracing::Instrument;
541 let span = tracing::info_span!("webhook", path = %name);
542 js.spawn(
543 async move {
544 let _permit = semaphore
545 .acquire()
546 .await
547 .expect("semaphore closed unexpectedly");
548 info!("webhook workflow started");
549 handler(payload).await;
550 info!("webhook workflow completed");
551 }
552 .instrument(span),
553 );
554 }
555
556 async fn wait(&self) {
558 let mut js = self.join_set.lock().await;
559 while let Some(result) = js.join_next().await {
560 if let Err(e) = result {
561 error!(error = %e, "webhook handler panicked");
562 }
563 }
564 }
565}
566
/// Per-route state handed to `webhook_handler` through axum's `State` extractor.
///
/// It must be `Clone` for that extractor; every field is an `Arc`, so clones are cheap.
#[derive(Clone)]
struct WebhookState {
    /// Verification applied to each request on this route.
    auth: Arc<WebhookAuth>,
    /// Type-erased callback scheduled with the parsed JSON body.
    handler: WebhookHandler,
    /// The route path, used in log fields, the tracing span, and the metrics `path` label.
    name: Arc<str>,
    /// Tracker shared by all routes of one router; spawns and bounds handler tasks.
    tracker: Arc<HandlerTracker>,
}
574
575async fn webhook_handler(
576 State(state): State<WebhookState>,
577 headers: HeaderMap,
578 body: Bytes,
579) -> StatusCode {
580 let name = &state.name;
581 if !state.auth.verify(&headers, &body) {
582 warn!(webhook = %name, "webhook auth failed");
583 #[cfg(feature = "prometheus")]
584 {
585 let label: String = name.to_string();
586 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_REJECTED).increment(1);
587 }
588 return StatusCode::UNAUTHORIZED;
589 }
590
591 let payload: Value = match from_slice(&body) {
592 Ok(v) => v,
593 Err(e) => {
594 warn!(webhook = %name, error = %e, "invalid JSON body");
595 #[cfg(feature = "prometheus")]
596 {
597 let label: String = name.to_string();
598 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_INVALID_BODY).increment(1);
599 }
600 return StatusCode::BAD_REQUEST;
601 }
602 };
603
604 #[cfg(feature = "prometheus")]
605 {
606 let label: String = name.to_string();
607 metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_ACCEPTED).increment(1);
608 }
609
610 state
611 .tracker
612 .spawn(name.to_string(), state.handler.clone(), payload)
613 .await;
614
615 StatusCode::ACCEPTED
616}
617
618async fn security_headers(
619 request: axum::http::Request<axum::body::Body>,
620 next: axum::middleware::Next,
621) -> axum::response::Response {
622 let mut response = next.run(request).await;
623 let headers = response.headers_mut();
624 headers.insert(
625 header::X_CONTENT_TYPE_OPTIONS,
626 "nosniff".parse().expect("valid header value"),
627 );
628 headers.insert(
629 header::X_FRAME_OPTIONS,
630 "DENY".parse().expect("valid header value"),
631 );
632 headers.insert(
633 "x-xss-protection",
634 "1; mode=block".parse().expect("valid header value"),
635 );
636 headers.insert(
637 header::STRICT_TRANSPORT_SECURITY,
638 "max-age=31536000; includeSubDomains"
639 .parse()
640 .expect("valid header value"),
641 );
642 headers.insert(
643 header::CONTENT_SECURITY_POLICY,
644 "default-src 'none'".parse().expect("valid header value"),
645 );
646 response
647}
648
649async fn shutdown_signal() {
650 let ctrl_c = async {
651 if let Err(e) = tokio::signal::ctrl_c().await {
652 warn!("failed to install ctrl+c handler: {e}");
653 }
654 };
655
656 #[cfg(unix)]
657 {
658 use tokio::signal::unix::{SignalKind, signal};
659 let mut sigterm =
660 signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
661 tokio::select! {
662 () = ctrl_c => info!("received SIGINT, shutting down"),
663 _ = sigterm.recv() => info!("received SIGTERM, shutting down"),
664 }
665 }
666
667 #[cfg(not(unix))]
668 {
669 ctrl_c.await;
670 info!("received ctrl+c, shutting down");
671 }
672}