use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use axum::Router;
use axum::body::Bytes;
use axum::extract::{DefaultBodyLimit, State};
use axum::http::{HeaderMap, StatusCode, header};
use axum::middleware;
use axum::routing::{get, post};
use serde_json::{Value, from_slice};
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinSet;
use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{error, info, warn};
use crate::cron::CronJob;
use crate::error::RuntimeError;
use crate::webhook::WebhookAuth;
/// Default upper bound on a webhook request body: 2 MiB.
const DEFAULT_MAX_BODY_SIZE: usize = 2 << 20;
/// Default number of webhook handlers allowed to execute at the same time.
const DEFAULT_MAX_CONCURRENT_HANDLERS: usize = 64;
/// Heap-allocated, type-erased `Send` future that resolves to `()`.
type BoxedUnitFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased async webhook callback; it receives the parsed JSON request body.
type WebhookHandler = Arc<dyn Fn(Value) -> BoxedUnitFuture + Send + Sync>;
/// Future whose completion tells the runtime to shut down.
type ShutdownSignal = BoxedUnitFuture;
/// Metric names and label values emitted when the `prometheus` feature is enabled.
#[cfg(feature = "prometheus")]
mod metric_names {
    /// Counter bumped once per webhook request, labelled with `path` and `auth`.
    pub const WEBHOOK_RECEIVED_TOTAL: &str = "ironflow_webhook_received_total";
    /// Counter bumped each time a cron job actually starts a run, labelled with `job`.
    pub const CRON_RUNS_TOTAL: &str = "ironflow_cron_runs_total";

    // Values used for the `auth` label of `WEBHOOK_RECEIVED_TOTAL`.

    /// The request failed authentication.
    pub const AUTH_REJECTED: &str = "rejected";
    /// The request was authenticated and its body parsed as JSON.
    pub const AUTH_ACCEPTED: &str = "accepted";
    /// The request was authenticated but its body was not valid JSON.
    pub const AUTH_INVALID_BODY: &str = "invalid_body";
}
/// A webhook registration recorded by the builder until the router is built.
struct WebhookRoute {
    /// URL path the route is mounted at; `Runtime::webhook` guarantees a leading `/`.
    path: String,
    /// Authentication scheme checked for every request to `path`.
    auth: WebhookAuth,
    /// Type-erased async callback invoked with the parsed JSON body.
    handler: WebhookHandler,
}
/// Builder and entry point for an ironflow service.
///
/// It collects webhook routes and cron jobs, then runs them through `serve`,
/// `run_crons`, or `into_router`.
pub struct Runtime {
    /// Webhook routes registered with `webhook`, in registration order.
    webhooks: Vec<WebhookRoute>,
    /// Cron jobs registered with `cron`, in registration order.
    crons: Vec<CronJob>,
    /// Upper bound, in bytes, on webhook request bodies.
    max_body_size: usize,
    /// Maximum number of webhook handlers allowed to execute simultaneously.
    max_concurrent_handlers: usize,
    /// Caller-provided shutdown trigger; `None` selects the built-in signal listener.
    custom_shutdown: Option<ShutdownSignal>,
}
impl Runtime {
pub fn new() -> Self {
Self {
webhooks: Vec::new(),
crons: Vec::new(),
max_body_size: DEFAULT_MAX_BODY_SIZE,
max_concurrent_handlers: DEFAULT_MAX_CONCURRENT_HANDLERS,
custom_shutdown: None,
}
}
pub fn max_body_size(mut self, bytes: usize) -> Self {
self.max_body_size = bytes;
self
}
pub fn max_concurrent_handlers(mut self, limit: usize) -> Self {
assert!(limit > 0, "max_concurrent_handlers must be greater than 0");
self.max_concurrent_handlers = limit;
self
}
pub fn with_shutdown<F>(mut self, signal: F) -> Self
where
F: Future<Output = ()> + Send + 'static,
{
self.custom_shutdown = Some(Box::pin(signal));
self
}
pub fn webhook<F, Fut>(mut self, path: &str, auth: WebhookAuth, handler: F) -> Self
where
F: Fn(Value) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
assert!(
path.starts_with('/'),
"webhook path must start with '/', got: {path}"
);
if matches!(auth, WebhookAuth::None) {
warn!(path = %path, "webhook registered with WebhookAuth::None - all requests will be accepted without authentication");
}
let handler: WebhookHandler = Arc::new(move |payload| {
let handler = handler.clone();
Box::pin(async move { handler(payload).await })
});
self.webhooks.push(WebhookRoute {
path: path.to_string(),
auth,
handler,
});
self
}
pub fn cron<F, Fut>(mut self, schedule: &str, name: &str, handler: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let handler_fn: Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync> =
Box::new(move || Box::pin(handler()));
self.crons.push(CronJob {
schedule: schedule.to_string(),
name: name.to_string(),
handler: handler_fn,
});
self
}
fn build_router(
webhooks: Vec<WebhookRoute>,
handler_tracker: Arc<HandlerTracker>,
max_body_size: usize,
#[cfg(feature = "prometheus")] prom_handle: Option<
metrics_exporter_prometheus::PrometheusHandle,
>,
) -> Router {
let mut router = Router::new();
for webhook in webhooks {
let auth = Arc::new(webhook.auth);
let handler = webhook.handler;
let path = webhook.path.clone();
let name: Arc<str> = Arc::from(path.as_str());
let route_state = WebhookState {
auth,
handler,
name,
tracker: handler_tracker.clone(),
};
router = router.route(&path, post(webhook_handler).with_state(route_state));
info!(path = %path, "registered webhook");
}
router = router.route("/health", get(|| async { "ok" }));
#[cfg(feature = "prometheus")]
if let Some(handle) = prom_handle {
router = router.route(
"/metrics",
get(move || {
let h = handle.clone();
async move { h.render() }
}),
);
info!("registered /metrics endpoint");
}
router
.layer(middleware::from_fn(security_headers))
.layer(DefaultBodyLimit::max(max_body_size))
}
pub fn into_router(self) -> Router {
if !self.crons.is_empty() {
warn!(
cron_count = self.crons.len(),
"into_router() drops registered cron jobs - use serve() or run_crons() to start them"
);
}
let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
Self::build_router(
self.webhooks,
tracker,
self.max_body_size,
#[cfg(feature = "prometheus")]
None,
)
}
async fn start_scheduler(crons: Vec<CronJob>) -> Result<JobScheduler, RuntimeError> {
let scheduler = JobScheduler::new().await?;
for cron_job in crons {
let handler = Arc::new(cron_job.handler);
let name = cron_job.name.clone();
let running = Arc::new(std::sync::atomic::AtomicBool::new(false));
let job = Job::new_async(cron_job.schedule.as_str(), move |_uuid, _lock| {
let handler = handler.clone();
let name = name.clone();
let running = running.clone();
Box::pin(async move {
if running.swap(true, std::sync::atomic::Ordering::AcqRel) {
warn!(cron = %name, "cron job still running, skipping this tick");
return;
}
info!(cron = %name, "cron job triggered");
#[cfg(feature = "prometheus")]
metrics::counter!(metric_names::CRON_RUNS_TOTAL, "job" => name.clone())
.increment(1);
(handler)().await;
running.store(false, std::sync::atomic::Ordering::Release);
})
})?;
info!(cron = %cron_job.name, schedule = %cron_job.schedule, "registered cron job");
scheduler.add(job).await?;
}
scheduler.start().await?;
Ok(scheduler)
}
pub async fn run_crons(self) -> Result<(), RuntimeError> {
let _ = dotenvy::dotenv();
if !self.webhooks.is_empty() {
warn!(
webhook_count = self.webhooks.len(),
"run_crons() ignores registered webhooks - use serve() to start both webhooks and crons"
);
}
#[cfg(feature = "prometheus")]
{
match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
Ok(_) => info!("prometheus metrics recorder installed"),
Err(_) => {
info!("prometheus metrics recorder already installed, reusing existing")
}
}
}
let mut scheduler = Self::start_scheduler(self.crons).await?;
info!("ironflow cron scheduler running (no HTTP server)");
match self.custom_shutdown {
Some(signal) => signal.await,
None => shutdown_signal().await,
}
info!("shutting down scheduler");
scheduler.shutdown().await.map_err(RuntimeError::Shutdown)?;
info!("ironflow cron scheduler stopped");
Ok(())
}
pub async fn serve(self, addr: &str) -> Result<(), RuntimeError> {
let _ = dotenvy::dotenv();
#[cfg(feature = "prometheus")]
let prom_handle = {
match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
Ok(handle) => {
info!("prometheus metrics recorder installed");
Some(handle)
}
Err(_) => {
info!("prometheus metrics recorder already installed, reusing existing");
None
}
}
};
let mut scheduler = Self::start_scheduler(self.crons).await?;
let tracker = Arc::new(HandlerTracker::new(self.max_concurrent_handlers));
let router = Self::build_router(
self.webhooks,
tracker.clone(),
self.max_body_size,
#[cfg(feature = "prometheus")]
prom_handle,
);
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(RuntimeError::Bind)?;
info!(addr = %addr, "ironflow runtime listening");
let graceful_shutdown = match self.custom_shutdown {
Some(signal) => signal,
None => Box::pin(shutdown_signal()),
};
axum::serve(listener, router)
.with_graceful_shutdown(graceful_shutdown)
.await
.map_err(RuntimeError::Serve)?;
info!("waiting for in-flight webhook handlers to complete");
tracker.wait().await;
info!("shutting down scheduler");
scheduler.shutdown().await.map_err(RuntimeError::Shutdown)?;
info!("ironflow runtime stopped");
Ok(())
}
}
impl Default for Runtime {
fn default() -> Self {
Self::new()
}
}
/// Runs webhook handlers as background tasks and bounds how many execute at once.
///
/// It also lets shutdown wait for the handlers that are still in flight.
struct HandlerTracker {
    /// Limits concurrent execution; a task holds one permit while its handler runs.
    semaphore: Arc<Semaphore>,
    /// Spawned handler tasks that have not been reaped yet.
    join_set: Mutex<JoinSet<()>>,
}
impl HandlerTracker {
fn new(max_concurrent: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
join_set: Mutex::new(JoinSet::new()),
}
}
async fn spawn(&self, name: String, handler: WebhookHandler, payload: Value) {
let semaphore = self.semaphore.clone();
let mut js = self.join_set.lock().await;
while let Some(result) = js.try_join_next() {
if let Err(e) = result {
error!(error = %e, "webhook handler panicked");
}
}
use tracing::Instrument;
let span = tracing::info_span!("webhook", path = %name);
js.spawn(
async move {
let _permit = semaphore
.acquire()
.await
.expect("semaphore closed unexpectedly");
info!("webhook workflow started");
handler(payload).await;
info!("webhook workflow completed");
}
.instrument(span),
);
}
async fn wait(&self) {
let mut js = self.join_set.lock().await;
while let Some(result) = js.join_next().await {
if let Err(e) = result {
error!(error = %e, "webhook handler panicked");
}
}
}
}
/// Per-route state handed to `webhook_handler` by axum.
#[derive(Clone)]
struct WebhookState {
    /// Authentication scheme for this route.
    auth: Arc<WebhookAuth>,
    /// Callback spawned for each accepted request.
    handler: WebhookHandler,
    /// Route path, used as the log and metric label.
    name: Arc<str>,
    /// Tracker shared by every route of the router; runs and bounds handler tasks.
    tracker: Arc<HandlerTracker>,
}
async fn webhook_handler(
State(state): State<WebhookState>,
headers: HeaderMap,
body: Bytes,
) -> StatusCode {
let name = &state.name;
if !state.auth.verify(&headers, &body) {
warn!(webhook = %name, "webhook auth failed");
#[cfg(feature = "prometheus")]
{
let label: String = name.to_string();
metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_REJECTED).increment(1);
}
return StatusCode::UNAUTHORIZED;
}
let payload: Value = match from_slice(&body) {
Ok(v) => v,
Err(e) => {
warn!(webhook = %name, error = %e, "invalid JSON body");
#[cfg(feature = "prometheus")]
{
let label: String = name.to_string();
metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_INVALID_BODY).increment(1);
}
return StatusCode::BAD_REQUEST;
}
};
#[cfg(feature = "prometheus")]
{
let label: String = name.to_string();
metrics::counter!(metric_names::WEBHOOK_RECEIVED_TOTAL, "path" => label, "auth" => metric_names::AUTH_ACCEPTED).increment(1);
}
state
.tracker
.spawn(name.to_string(), state.handler.clone(), payload)
.await;
StatusCode::ACCEPTED
}
/// Middleware that sets a fixed group of security-related headers on every response.
///
/// Values are built with `HeaderValue::from_static`, so nothing is parsed per request
/// and there is no runtime panic path. Any value an inner handler set for one of these
/// headers is replaced.
async fn security_headers(
    request: axum::http::Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::http::HeaderValue;

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    headers.insert(
        header::X_CONTENT_TYPE_OPTIONS,
        HeaderValue::from_static("nosniff"),
    );
    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
    headers.insert(
        header::X_XSS_PROTECTION,
        HeaderValue::from_static("1; mode=block"),
    );
    headers.insert(
        header::STRICT_TRANSPORT_SECURITY,
        HeaderValue::from_static("max-age=31536000; includeSubDomains"),
    );
    headers.insert(
        header::CONTENT_SECURITY_POLICY,
        HeaderValue::from_static("default-src 'none'"),
    );
    response
}
async fn shutdown_signal() {
let ctrl_c = async {
if let Err(e) = tokio::signal::ctrl_c().await {
warn!("failed to install ctrl+c handler: {e}");
}
};
#[cfg(unix)]
{
use tokio::signal::unix::{SignalKind, signal};
let mut sigterm =
signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
tokio::select! {
() = ctrl_c => info!("received SIGINT, shutting down"),
_ = sigterm.recv() => info!("received SIGTERM, shutting down"),
}
}
#[cfg(not(unix))]
{
ctrl_c.await;
info!("received ctrl+c, shutting down");
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a webhook handler that increments `counter` once per invocation.
    fn counting_handler(counter: Arc<AtomicUsize>) -> WebhookHandler {
        Arc::new(move |_payload: Value| {
            let counter = Arc::clone(&counter);
            let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
            });
            fut
        })
    }

    /// Builds route state for `/test`, with no authentication and a small tracker.
    fn test_state(handler: WebhookHandler) -> WebhookState {
        WebhookState {
            auth: Arc::new(WebhookAuth::none()),
            handler,
            name: Arc::from("/test"),
            tracker: Arc::new(HandlerTracker::new(4)),
        }
    }

    #[test]
    fn runtime_new_creates_with_defaults() {
        let rt = Runtime::new();
        assert_eq!(rt.webhooks.len(), 0);
        assert_eq!(rt.crons.len(), 0);
        assert_eq!(rt.max_body_size, DEFAULT_MAX_BODY_SIZE);
        assert_eq!(rt.max_concurrent_handlers, DEFAULT_MAX_CONCURRENT_HANDLERS);
        assert!(rt.custom_shutdown.is_none());
    }

    #[test]
    fn runtime_default_equals_new() {
        let rt_new = Runtime::new();
        let rt_default = Runtime::default();
        assert_eq!(rt_new.webhooks.len(), rt_default.webhooks.len());
        assert_eq!(rt_new.crons.len(), rt_default.crons.len());
        assert_eq!(rt_new.max_body_size, rt_default.max_body_size);
        assert_eq!(
            rt_new.max_concurrent_handlers,
            rt_default.max_concurrent_handlers
        );
    }

    #[test]
    fn max_body_size_sets_value_and_returns_self() {
        let rt = Runtime::new().max_body_size(512 * 1024);
        assert_eq!(rt.max_body_size, 512 * 1024);
    }

    #[test]
    fn max_body_size_chainable() {
        let rt = Runtime::new()
            .max_body_size(1024)
            .webhook("/test", WebhookAuth::none(), |_| async {});
        assert_eq!(rt.max_body_size, 1024);
        assert_eq!(rt.webhooks.len(), 1);
    }

    #[test]
    fn max_body_size_can_be_zero() {
        let rt = Runtime::new().max_body_size(0);
        assert_eq!(rt.max_body_size, 0);
    }

    #[test]
    fn max_body_size_can_be_large() {
        let large_size = 1024 * 1024 * 1024;
        let rt = Runtime::new().max_body_size(large_size);
        assert_eq!(rt.max_body_size, large_size);
    }

    #[test]
    #[should_panic(expected = "max_concurrent_handlers must be greater than 0")]
    fn max_concurrent_handlers_zero_panics() {
        let _ = Runtime::new().max_concurrent_handlers(0);
    }

    #[test]
    fn max_concurrent_handlers_sets_valid_values() {
        let rt = Runtime::new().max_concurrent_handlers(16);
        assert_eq!(rt.max_concurrent_handlers, 16);
    }

    #[test]
    fn max_concurrent_handlers_one_is_valid() {
        let rt = Runtime::new().max_concurrent_handlers(1);
        assert_eq!(rt.max_concurrent_handlers, 1);
    }

    #[test]
    fn max_concurrent_handlers_large_value_is_valid() {
        let large_limit = 10000;
        let rt = Runtime::new().max_concurrent_handlers(large_limit);
        assert_eq!(rt.max_concurrent_handlers, large_limit);
    }

    #[test]
    fn max_concurrent_handlers_chainable() {
        let rt = Runtime::new().max_concurrent_handlers(32).webhook(
            "/test",
            WebhookAuth::none(),
            |_| async {},
        );
        assert_eq!(rt.max_concurrent_handlers, 32);
        assert_eq!(rt.webhooks.len(), 1);
    }

    #[tokio::test]
    async fn with_shutdown_sets_signal_and_returns_self() {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let rt = Runtime::new().with_shutdown(async move {
            let _ = rx.await;
        });
        assert!(rt.custom_shutdown.is_some());
        let _ = tx.send(());
    }

    #[tokio::test]
    async fn with_shutdown_chainable() {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let rt = Runtime::new()
            .with_shutdown(async move {
                let _ = rx.await;
            })
            .webhook("/test", WebhookAuth::none(), |_| async {});
        assert!(rt.custom_shutdown.is_some());
        assert_eq!(rt.webhooks.len(), 1);
        let _ = tx.send(());
    }

    #[test]
    fn webhook_registers_route_and_returns_self() {
        let rt = Runtime::new().webhook("/hooks/test", WebhookAuth::none(), |_| async {});
        assert_eq!(rt.webhooks.len(), 1);
        assert_eq!(rt.webhooks[0].path, "/hooks/test");
    }

    #[test]
    #[should_panic(expected = "webhook path must start with '/'")]
    fn webhook_path_without_slash_panics() {
        let _ = Runtime::new().webhook("no-slash", WebhookAuth::none(), |_| async {});
    }

    #[test]
    fn webhook_accepts_valid_paths() {
        let rt = Runtime::new()
            .webhook("/", WebhookAuth::none(), |_| async {})
            .webhook("/simple", WebhookAuth::none(), |_| async {})
            .webhook("/nested/path", WebhookAuth::none(), |_| async {})
            .webhook("/with-dashes", WebhookAuth::none(), |_| async {})
            .webhook("/with_underscores", WebhookAuth::none(), |_| async {})
            .webhook("/with/numbers/123", WebhookAuth::none(), |_| async {});
        assert_eq!(rt.webhooks.len(), 6);
    }

    #[test]
    fn webhook_chainable() {
        let rt = Runtime::new()
            .webhook("/hook-a", WebhookAuth::none(), |_| async {})
            .webhook("/hook-b", WebhookAuth::none(), |_| async {})
            .webhook("/hook-c", WebhookAuth::none(), |_| async {});
        assert_eq!(rt.webhooks.len(), 3);
        assert_eq!(rt.webhooks[0].path, "/hook-a");
        assert_eq!(rt.webhooks[1].path, "/hook-b");
        assert_eq!(rt.webhooks[2].path, "/hook-c");
    }

    #[test]
    fn webhook_with_various_auth_types() {
        let rt = Runtime::new()
            .webhook("/none", WebhookAuth::none(), |_| async {})
            .webhook(
                "/header",
                WebhookAuth::header("x-api-key", "secret"),
                |_| async {},
            )
            .webhook("/github", WebhookAuth::github("secret"), |_| async {})
            .webhook("/gitlab", WebhookAuth::gitlab("token"), |_| async {});
        assert_eq!(rt.webhooks.len(), 4);
    }

    #[test]
    fn cron_registers_job_and_returns_self() {
        let rt = Runtime::new().cron("0 0 * * * *", "daily-task", || async {});
        assert_eq!(rt.crons.len(), 1);
        assert_eq!(rt.crons[0].name, "daily-task");
        assert_eq!(rt.crons[0].schedule, "0 0 * * * *");
    }

    #[test]
    fn cron_chainable() {
        let rt = Runtime::new()
            .cron("0 0 * * * *", "midnight", || async {})
            .cron("0 */5 * * * *", "every-5-minutes", || async {});
        assert_eq!(rt.crons.len(), 2);
    }

    #[test]
    fn cron_preserves_schedule_and_name() {
        let rt = Runtime::new()
            .cron("0 12 * * * MON", "noon-mondays", || async {})
            .cron("0 0 1 * * *", "first-of-month", || async {});
        assert_eq!(rt.crons[0].name, "noon-mondays");
        assert_eq!(rt.crons[0].schedule, "0 12 * * * MON");
        assert_eq!(rt.crons[1].name, "first-of-month");
        assert_eq!(rt.crons[1].schedule, "0 0 1 * * *");
    }

    #[test]
    fn into_router_returns_router() {
        let rt = Runtime::new();
        let _router = rt.into_router();
    }

    #[test]
    fn into_router_with_webhooks_returns_router() {
        let rt = Runtime::new()
            .webhook("/hook-a", WebhookAuth::none(), |_| async {})
            .webhook("/hook-b", WebhookAuth::github("secret"), |_| async {});
        let _router = rt.into_router();
    }

    #[test]
    fn into_router_with_crons_returns_router() {
        let rt = Runtime::new()
            .cron("0 0 * * * *", "daily", || async {})
            .cron("0 */5 * * * *", "every-5-min", || async {});
        let _router = rt.into_router();
    }

    #[test]
    fn into_router_respects_max_body_size_config() {
        let rt = Runtime::new()
            .max_body_size(100)
            .webhook("/hook", WebhookAuth::none(), |_| async {});
        let _router = rt.into_router();
    }

    #[test]
    fn into_router_respects_max_concurrent_handlers_config() {
        let rt = Runtime::new().max_concurrent_handlers(16).webhook(
            "/hook",
            WebhookAuth::none(),
            |_| async {},
        );
        let _router = rt.into_router();
    }

    #[test]
    fn builder_chain_multiple_methods() {
        let rt = Runtime::new()
            .max_body_size(512 * 1024)
            .max_concurrent_handlers(32)
            .webhook("/hook-a", WebhookAuth::none(), |_| async {})
            .webhook("/hook-b", WebhookAuth::github("secret"), |_| async {})
            .cron("0 0 * * * *", "daily", || async {});
        assert_eq!(rt.max_body_size, 512 * 1024);
        assert_eq!(rt.max_concurrent_handlers, 32);
        assert_eq!(rt.webhooks.len(), 2);
        assert_eq!(rt.crons.len(), 1);
    }

    #[test]
    fn into_router_with_crons_doesnt_start_them() {
        let rt = Runtime::new().cron("0 0 * * * *", "test-cron", || async {});
        let _router = rt.into_router();
    }

    #[tokio::test]
    async fn webhook_handler_rejects_invalid_json() {
        let counter = Arc::new(AtomicUsize::new(0));
        let state = test_state(counting_handler(Arc::clone(&counter)));
        let tracker = Arc::clone(&state.tracker);
        let status = webhook_handler(
            State(state),
            HeaderMap::new(),
            Bytes::from_static(b"not json"),
        )
        .await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        tracker.wait().await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn webhook_handler_accepts_json_and_runs_handler() {
        let counter = Arc::new(AtomicUsize::new(0));
        let state = test_state(counting_handler(Arc::clone(&counter)));
        let tracker = Arc::clone(&state.tracker);
        let status = webhook_handler(
            State(state),
            HeaderMap::new(),
            Bytes::from_static(br#"{"ok":true}"#),
        )
        .await;
        assert_eq!(status, StatusCode::ACCEPTED);
        tracker.wait().await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn handler_tracker_respects_concurrency_limit() {
        let tracker = HandlerTracker::new(1);
        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        for _ in 0..4 {
            let active = Arc::clone(&active);
            let peak = Arc::clone(&peak);
            let handler: WebhookHandler = Arc::new(move |_payload: Value| {
                let active = Arc::clone(&active);
                let peak = Arc::clone(&peak);
                let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
                    let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    // Give other queued handlers a chance to run if the limit were broken.
                    tokio::task::yield_now().await;
                    active.fetch_sub(1, Ordering::SeqCst);
                });
                fut
            });
            tracker
                .spawn("/limited".to_string(), handler, Value::Null)
                .await;
        }
        tracker.wait().await;
        assert_eq!(peak.load(Ordering::SeqCst), 1);
        assert_eq!(active.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn handler_tracker_wait_survives_panicking_handler() {
        let tracker = HandlerTracker::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let panicking: WebhookHandler = Arc::new(|_payload: Value| {
            let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async {
                let fail = true;
                assert!(!fail, "intentional test panic");
            });
            fut
        });
        tracker
            .spawn("/panic".to_string(), panicking, Value::Null)
            .await;
        tracker
            .spawn(
                "/ok".to_string(),
                counting_handler(Arc::clone(&counter)),
                Value::Null,
            )
            .await;
        tracker.wait().await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn serve_rejects_unbindable_address() {
        let result = Runtime::new().serve("not-a-socket-address").await;
        assert!(result.is_err());
    }
}