mod handlers;
mod state;
pub(crate) use handlers::UpdateJobRequest;
use super::events::EventBus;
use super::executor::TmuxExecutor;
use super::scheduler_runtime;
use super::state_saver::StateSaverHandle;
use axum::{
extract::Request,
http::HeaderValue,
middleware::{self, Next},
response::Response,
routing::{get, post},
Router,
};
use socket2::{Domain, Protocol, Socket, Type};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::Instrument;
/// Interval between periodic background state saves (also reported in the
/// state-saver startup log so the two can never drift apart).
const STATE_SAVE_INTERVAL: Duration = Duration::from_secs(30);

/// Starts the gflowd daemon: restores scheduler state, spawns the state
/// saver and (when persistence allows mutation) the event-driven scheduler,
/// wires optional webhook/email notifiers, and serves the HTTP API until a
/// shutdown signal arrives.
///
/// # Errors
/// Fails on invalid configuration (`daemon.gpu_poll_interval_secs == 0`),
/// state restoration problems, or when the listen address cannot be
/// resolved or bound.
pub async fn run(config: gflow::config::Config) -> anyhow::Result<()> {
    let state_dir = gflow::paths::get_data_dir()?;
    let allowed_gpus = config.daemon.gpus.clone();
    let gpu_allocation_strategy = config.daemon.gpu_allocation_strategy;
    let gpu_poll_interval_secs = config.daemon.gpu_poll_interval_secs;
    let notifications = config.notifications.clone();
    let daemon_host = config.daemon.host.clone();
    if gpu_poll_interval_secs == 0 {
        anyhow::bail!(
            "Invalid daemon.gpu_poll_interval_secs '0'. Use a value of at least 1 second."
        );
    }
    let gpu_poll_interval = Duration::from_secs(gpu_poll_interval_secs);
    let executor = Box::new(TmuxExecutor);
    let (state_tx, state_rx) = tokio::sync::mpsc::unbounded_channel();
    let state_saver_handle = StateSaverHandle::new(state_tx);
    let mut scheduler_runtime = scheduler_runtime::SchedulerRuntime::with_state_path(
        executor,
        state_dir,
        allowed_gpus,
        gpu_allocation_strategy,
        config.projects.clone(),
    )?;
    scheduler_runtime.set_state_saver(state_saver_handle.clone());
    let scheduler = Arc::new(tokio::sync::RwLock::new(scheduler_runtime));
    let scheduler_clone = Arc::clone(&scheduler);
    let event_bus = Arc::new(EventBus::new(1000));
    let event_bus_clone = Arc::clone(&event_bus);
    let scheduler_for_saver = Arc::clone(&scheduler);
    let state_saver_task = tokio::spawn(
        async move {
            tracing::info!(
                interval_secs = STATE_SAVE_INTERVAL.as_secs(),
                "Starting state saver task"
            );
            super::state_saver::run(scheduler_for_saver, state_rx, STATE_SAVE_INTERVAL).await;
        }
        .instrument(tracing::info_span!("state_saver_task")),
    );
    state_saver_handle.set_task_handle(state_saver_task);
    // Scheduling mutates persisted state; without a writable store the daemon
    // runs read-only and never spawns the scheduler loop.
    let can_schedule = scheduler.read().await.can_mutate();
    if can_schedule {
        tokio::spawn(
            async move {
                tracing::info!("Starting event-driven scheduler");
                scheduler_runtime::run_event_driven(
                    scheduler_clone,
                    event_bus_clone,
                    gpu_poll_interval,
                )
                .await;
            }
            .instrument(tracing::info_span!("event_driven_scheduler")),
        );
    } else {
        tracing::error!(
            persistence_mode = "read_only",
            "No persistence available; gflowd started without scheduling or mutation support"
        );
    }
    let server_state = state::ServerState::new(scheduler, event_bus, state_saver_handle.clone());
    if notifications.enabled
        && (!notifications.webhooks.is_empty() || !notifications.emails.is_empty())
    {
        // One semaphore caps in-flight deliveries across both notifier kinds.
        let delivery_semaphore = Arc::new(Semaphore::new(
            notifications.max_concurrent_deliveries.max(1),
        ));
        super::webhooks::spawn_webhook_notifier(
            notifications.clone(),
            Arc::clone(&delivery_semaphore),
            Arc::clone(&server_state.scheduler),
            Arc::clone(&server_state.event_bus),
            daemon_host.clone(),
        );
        super::emails::spawn_email_notifier(
            notifications,
            delivery_semaphore,
            Arc::clone(&server_state.scheduler),
            Arc::clone(&server_state.event_bus),
            daemon_host,
        );
        server_state
            .event_bus
            .publish(super::events::SchedulerEvent::DaemonStarted);
    }
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .route("/jobs", get(handlers::list_jobs).post(handlers::create_job))
        .route("/jobs/batch", post(handlers::create_jobs_batch))
        .route(
            "/jobs/resolve-dependency",
            get(handlers::resolve_dependency),
        )
        .route(
            "/jobs/{id}",
            get(handlers::get_job).patch(handlers::update_job),
        )
        .route("/jobs/{id}/finish", post(handlers::finish_job))
        .route("/jobs/{id}/fail", post(handlers::fail_job))
        .route("/jobs/{id}/cancel", post(handlers::cancel_job))
        .route("/jobs/{id}/hold", post(handlers::hold_job))
        .route("/jobs/{id}/release", post(handlers::release_job))
        .route("/jobs/{id}/log", get(handlers::get_job_log))
        .route("/info", get(handlers::info))
        .route("/health", get(handlers::get_health))
        .route("/gpus", post(handlers::set_allowed_gpus))
        .route("/gpu-processes", get(handlers::list_ignored_gpu_processes))
        .route("/gpu-processes/ignore", post(handlers::ignore_gpu_process))
        .route(
            "/gpu-processes/unignore",
            post(handlers::unignore_gpu_process),
        )
        .route(
            "/groups/{group_id}/max-concurrency",
            post(handlers::set_group_max_concurrency),
        )
        .route(
            "/reservations",
            get(handlers::list_reservations).post(handlers::create_reservation),
        )
        .route(
            "/reservations/{id}",
            get(handlers::get_reservation).delete(handlers::cancel_reservation),
        )
        .route("/stats", get(handlers::get_stats))
        .route("/metrics", get(handlers::get_metrics))
        .route("/debug/state", get(handlers::debug_state))
        .route("/debug/jobs/{id}", get(handlers::debug_job))
        .route("/debug/metrics", get(handlers::debug_metrics))
        .layer(middleware::from_fn(request_tracing_middleware))
        .with_state(server_state);
    let host = &config.daemon.host;
    let port = config.daemon.port;
    // Bracket bare IPv6 literals so `host:port` parses unambiguously.
    let bind_addr = if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    };
    let addr = tokio::net::lookup_host(&bind_addr)
        .await?
        .next()
        .ok_or_else(|| anyhow::anyhow!("Failed to resolve address: {}", bind_addr))?;
    let listener = bind_reusable_listener(addr)?;
    tracing::info!(%addr, reuse_port = true, "Listening for HTTP requests");
    let shutdown_signal = create_shutdown_signal(state_saver_handle);
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal)
        .await?;
    tracing::info!("Server shutdown complete");
    Ok(())
}

/// Builds a non-blocking TCP listener on `addr` with `SO_REUSEADDR` and
/// `SO_REUSEPORT` set, ready for use with `axum::serve`.
// NOTE(review): SO_REUSEPORT presumably exists to let a replacement daemon
// bind while the old one drains — confirm against deployment docs.
fn bind_reusable_listener(addr: std::net::SocketAddr) -> anyhow::Result<tokio::net::TcpListener> {
    let domain = if addr.is_ipv4() {
        Domain::IPV4
    } else {
        Domain::IPV6
    };
    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
    socket.set_reuse_address(true)?;
    socket.set_reuse_port(true)?;
    // tokio's `TcpListener::from_std` requires a non-blocking socket; set it
    // once here (the previous code set it redundantly on both the socket2
    // handle and the converted std listener).
    socket.set_nonblocking(true)?;
    socket.bind(&addr.into())?;
    socket.listen(1024)?;
    let std_listener: std::net::TcpListener = socket.into();
    Ok(tokio::net::TcpListener::from_std(std_listener)?)
}
/// Axum middleware that tags each request with a fresh UUID, wraps handling
/// in an `http_request` tracing span, echoes the id back via an
/// `x-request-id` response header, and logs completion status and latency.
async fn request_tracing_middleware(req: Request, next: Next) -> Response {
    let id = uuid::Uuid::new_v4().to_string();
    let method = req.method().clone();
    let uri = req.uri().clone();
    let path = uri.path().to_string();
    let span = tracing::info_span!(
        "http_request",
        request_id = %id,
        method = %method,
        route = %path,
        uri = %uri
    );
    let handle = async move {
        let timer = std::time::Instant::now();
        tracing::info!("Request received");
        let mut response = next.run(req).await;
        // Echo the id so clients and log aggregators can correlate entries.
        if let Ok(value) = HeaderValue::from_str(&id) {
            response.headers_mut().insert("x-request-id", value);
        }
        tracing::info!(
            status = response.status().as_u16(),
            latency_ms = timer.elapsed().as_millis() as u64,
            "Request completed"
        );
        response
    };
    handle.instrument(span).await
}
async fn create_shutdown_signal(state_saver: StateSaverHandle) {
use tokio::signal::unix::{signal, SignalKind};
let mut sigterm = signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
let mut sigint = signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
let mut sigusr2 =
signal(SignalKind::user_defined2()).expect("Failed to register SIGUSR2 handler");
tokio::select! {
_ = sigterm.recv() => {
tracing::info!(signal = "SIGTERM", "Initiating graceful shutdown");
}
_ = sigint.recv() => {
tracing::info!(signal = "SIGINT", "Initiating graceful shutdown");
}
_ = sigusr2.recv() => {
tracing::info!(signal = "SIGUSR2", reload = true, "Initiating graceful shutdown");
}
}
tracing::info!("Saving state before shutdown");
if let Err(e) = state_saver.shutdown_and_wait().await {
tracing::error!(error = %e, "Failed to save state during shutdown");
} else {
tracing::info!("State saved successfully");
}
}