use crate::command::{CoordinatorCommand, ShutdownMode};
use crate::coordinator::{Coordinator, CoordinatorState};
use crate::error::{BuildError, QueryError, ShutdownError, SubmitError};
use crate::job::{
BoxedExecFn, InstanceId, JobDetails, JobSummary, RecurringJobId, RecurringJobRequest,
};
use crate::metrics::{MetricsSnapshot, SchedulerMetrics};
use crate::worker::Worker;
use async_channel;
use chrono::{DateTime, Utc};
use futures::future::try_join_all;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::sync::{mpsc, oneshot, watch, Mutex, RwLock};
use tokio::task::JoinHandle;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Default capacity used for both the staging channel and the coordinator
/// command channel when the builder is not told otherwise.
const DEFAULT_CHANNEL_BOUND: usize = 128;

/// Default capacity of the coordinator-to-worker job dispatch channel.
const DEFAULT_JOB_DISPATCH_BOUND: usize = 1;
/// Selects which priority-queue implementation the coordinator is created with.
///
/// The value is forwarded unchanged to `CoordinatorState::new` by
/// `SchedulerBuilder::build`; the builder defaults to [`PriorityQueueType::HandleBased`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PriorityQueueType {
    /// NOTE(review): presumably a plain binary-heap backed queue — confirm in the coordinator.
    BinaryHeap,
    /// NOTE(review): presumably a queue addressing entries through handles — confirm in the coordinator.
    HandleBased,
}
/// Collects the configuration needed to construct and start a scheduler.
///
/// Every setter consumes and returns the builder so calls can be chained;
/// `build` turns the collected settings into a running `TurnKeeper`.
#[derive(Debug)]
pub struct SchedulerBuilder {
    /// Number of worker tasks to spawn. `None` until explicitly set, in which
    /// case `build` returns an error.
    max_workers: Option<usize>,
    /// Priority-queue implementation handed to the coordinator.
    pq_type: PriorityQueueType,
    /// Capacity of the bounded channel that carries newly submitted jobs.
    staging_buffer_size: usize,
    /// Capacity of the coordinator command channel; the worker-outcome
    /// channel is created with the same capacity.
    command_buffer_size: usize,
    /// Capacity of the bounded channel used to dispatch job instances to workers.
    job_dispatch_buffer_size: usize,
}
impl Default for SchedulerBuilder {
fn default() -> Self {
Self {
max_workers: None,
pq_type: PriorityQueueType::HandleBased, staging_buffer_size: DEFAULT_CHANNEL_BOUND,
command_buffer_size: DEFAULT_CHANNEL_BOUND,
job_dispatch_buffer_size: DEFAULT_JOB_DISPATCH_BOUND,
}
}
}
impl SchedulerBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn max_workers(mut self, count: usize) -> Self {
self.max_workers = Some(count);
self
}
pub fn priority_queue(mut self, pq_type: PriorityQueueType) -> Self {
self.pq_type = pq_type;
self
}
pub fn staging_buffer_size(mut self, size: usize) -> Self {
self.staging_buffer_size = size;
self
}
pub fn command_buffer_size(mut self, size: usize) -> Self {
self.command_buffer_size = size;
self
}
pub fn job_dispatch_buffer_size(mut self, size: usize) -> Self {
self.job_dispatch_buffer_size = size.max(1);
self
}
pub fn build(self) -> Result<TurnKeeper, BuildError> {
let max_workers = self
.max_workers
.ok_or(BuildError::MissingOrZeroMaxWorkers)?;
if max_workers == 0 {
warn!("Scheduler built with 0 workers. No jobs will execute.");
}
let metrics = SchedulerMetrics::new();
let job_definitions = Arc::new(RwLock::new(HashMap::new()));
let cancellations = Arc::new(RwLock::new(HashSet::new()));
let instance_to_lineage = Arc::new(RwLock::new(HashMap::new()));
let active_workers_counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let (staging_tx, staging_rx) = mpsc::channel::<(
RecurringJobId, // Add the ID here
RecurringJobRequest,
Arc<BoxedExecFn>,
)>(self.staging_buffer_size);
let (cmd_tx, cmd_rx) = mpsc::channel::<CoordinatorCommand>(self.command_buffer_size);
let (shutdown_tx, shutdown_rx) = watch::channel::<Option<ShutdownMode>>(None);
let (job_dispatch_tx, job_dispatch_rx) =
async_channel::bounded::<(InstanceId, RecurringJobId)>(self.job_dispatch_buffer_size);
let (worker_outcome_tx, worker_outcome_rx) =
mpsc::channel::<crate::command::WorkerOutcome>(self.command_buffer_size);
let coordinator_state = CoordinatorState::new(
self.pq_type,
staging_rx,
cmd_rx,
shutdown_rx.clone(),
job_dispatch_tx,
worker_outcome_rx, job_definitions.clone(),
cancellations.clone(),
instance_to_lineage.clone(), metrics.clone(),
active_workers_counter.clone(),
max_workers,
);
let coordinator_handle = Handle::current().spawn(async move {
let mut coordinator = Coordinator::new(coordinator_state);
coordinator.run().await;
info!("Coordinator task finished.");
});
let mut worker_handles = Vec::with_capacity(max_workers);
for worker_id in 0..max_workers {
let worker_job_definitions = job_definitions.clone();
let worker_metrics = metrics.clone();
let worker_shutdown_rx = shutdown_rx.clone();
let worker_active_counter = active_workers_counter.clone();
let worker_job_dispatch_rx = job_dispatch_rx.clone();
let worker_outcome_tx_clone = worker_outcome_tx.clone();
let handle = Handle::current().spawn(async move {
let mut worker = Worker::new(
worker_id,
worker_job_definitions,
worker_metrics,
worker_shutdown_rx,
worker_outcome_tx_clone, worker_job_dispatch_rx,
worker_active_counter,
);
worker.run().await;
});
worker_handles.push(handle);
}
drop(worker_outcome_tx);
Ok(TurnKeeper {
metrics,
staging_tx,
cmd_tx,
shutdown_tx,
coordinator_handle: Arc::new(Mutex::new(Some(coordinator_handle))),
worker_handles: Arc::new(Mutex::new(worker_handles)),
})
}
}
/// Handle to a running scheduler: submits jobs, queries the coordinator and
/// drives shutdown of the coordinator and worker tasks.
#[derive(Debug)]
pub struct TurnKeeper {
    /// Metrics counters updated on job submission.
    metrics: SchedulerMetrics,
    /// Sending half of the staging channel; each item is
    /// (lineage id, job request, executable).
    staging_tx: mpsc::Sender<(RecurringJobId, RecurringJobRequest, Arc<BoxedExecFn>)>,
    /// Sending half of the coordinator command channel used by the query and
    /// cancel methods.
    cmd_tx: mpsc::Sender<CoordinatorCommand>,
    /// Broadcasts the requested shutdown mode; `None` means no shutdown requested.
    shutdown_tx: watch::Sender<Option<ShutdownMode>>,
    /// Join handle of the coordinator task; taken out when shutdown is awaited.
    coordinator_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Join handles of the worker tasks; drained when shutdown is awaited.
    worker_handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
}
impl TurnKeeper {
pub fn builder() -> SchedulerBuilder {
SchedulerBuilder::new()
}
pub fn try_add_job<F>(
&self,
request: RecurringJobRequest,
exec_fn: F,
) -> Result<RecurringJobId, SubmitError<(RecurringJobRequest, Arc<BoxedExecFn>)>>
where
F: Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
{
let lineage_id = Uuid::new_v4();
let boxed_fn = Arc::new(Box::new(exec_fn) as BoxedExecFn);
self
.metrics
.staging_submitted_total
.fetch_add(1, AtomicOrdering::Relaxed);
let send_payload = (lineage_id, request, boxed_fn);
match self.staging_tx.try_send(send_payload) {
Ok(()) => Ok(lineage_id),
Err(mpsc::error::TrySendError::Full((_, req, func))) => {
self
.metrics
.staging_rejected_full
.fetch_add(1, AtomicOrdering::Relaxed);
Err(SubmitError::StagingFull((req, func)))
}
Err(mpsc::error::TrySendError::Closed((_, req, func))) => {
Err(SubmitError::ChannelClosed((req, func)))
}
}
}
pub async fn add_job_async<F>(
&self,
request: RecurringJobRequest,
exec_fn: F,
) -> Result<RecurringJobId, SubmitError<(RecurringJobRequest, Arc<BoxedExecFn>)>>
where
F: Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
{
let lineage_id = Uuid::new_v4();
let boxed_fn = Arc::new(Box::new(exec_fn) as BoxedExecFn);
self
.metrics
.staging_submitted_total
.fetch_add(1, AtomicOrdering::Relaxed);
let send_payload = (lineage_id, request, boxed_fn);
self
.staging_tx
.send(send_payload)
.await
.map(|()| lineage_id) .map_err(|mpsc::error::SendError((_, req, func))| SubmitError::ChannelClosed((req, func)))
}
pub async fn get_job_details(&self, job_id: RecurringJobId) -> Result<JobDetails, QueryError> {
let (responder, response_rx) = oneshot::channel();
let cmd = CoordinatorCommand::GetJobDetails { job_id, responder };
self
.cmd_tx
.send(cmd)
.await
.map_err(|_| QueryError::SchedulerShutdown)?;
response_rx.await.map_err(|_| QueryError::ResponseFailed)? }
pub async fn list_all_jobs(&self) -> Result<Vec<JobSummary>, QueryError> {
let (responder, response_rx) = oneshot::channel();
let cmd = CoordinatorCommand::ListAllJobs { responder };
self
.cmd_tx
.send(cmd)
.await
.map_err(|_| QueryError::SchedulerShutdown)?;
response_rx.await.map_err(|_| QueryError::ResponseFailed)
}
pub async fn get_metrics_snapshot(&self) -> Result<MetricsSnapshot, QueryError> {
let (responder, response_rx) = oneshot::channel();
let cmd = CoordinatorCommand::GetMetricsSnapshot { responder };
self
.cmd_tx
.send(cmd)
.await
.map_err(|_| QueryError::SchedulerShutdown)?;
response_rx.await.map_err(|_| QueryError::ResponseFailed)
}
pub async fn cancel_job(&self, job_id: RecurringJobId) -> Result<(), QueryError> {
let (responder, response_rx) = oneshot::channel();
let cmd = CoordinatorCommand::CancelJob { job_id, responder };
self
.cmd_tx
.send(cmd)
.await
.map_err(|_| QueryError::SchedulerShutdown)?;
response_rx.await.map_err(|_| QueryError::ResponseFailed)? }
pub async fn shutdown_graceful(&self, timeout: Option<Duration>) -> Result<(), ShutdownError> {
info!("Initiating graceful shutdown...");
self
.shutdown_tx
.send(Some(ShutdownMode::Graceful))
.map_err(|_| ShutdownError::SignalFailed)?;
self.await_shutdown(timeout).await
}
pub async fn shutdown_force(&self, timeout: Option<Duration>) -> Result<(), ShutdownError> {
info!("Initiating forced shutdown...");
self
.shutdown_tx
.send(Some(ShutdownMode::Force))
.map_err(|_| ShutdownError::SignalFailed)?;
self.await_shutdown(timeout).await
}
async fn await_shutdown(&self, timeout_duration: Option<Duration>) -> Result<(), ShutdownError> {
let mut coordinator_handle_opt = self.coordinator_handle.lock().await.take();
let worker_handles = {
let mut guard = self.worker_handles.lock().await;
std::mem::take(&mut *guard) };
let mut tasks = Vec::with_capacity(1 + worker_handles.len());
if let Some(coord_handle) = coordinator_handle_opt.take() {
tasks.push(tokio::spawn(async move {
match coord_handle.await {
Ok(()) => {
info!("Coordinator task joined.");
Ok(())
}
Err(e) => {
error!("Coordinator task panicked: {:?}", e);
Err(ShutdownError::TaskPanic)
}
}
}));
} else {
warn!("Coordinator handle missing during shutdown wait.");
}
for (i, handle) in worker_handles.into_iter().enumerate() {
tasks.push(tokio::spawn(async move {
match handle.await {
Ok(()) => {
Ok(())
} Err(e) => {
error!(worker_id = i, "Worker task panicked: {:?}", e);
Err(ShutdownError::TaskPanic)
}
}
}));
}
if tasks.is_empty() {
warn!("No tasks found to await during shutdown.");
return Ok(());
}
let join_all_fut = try_join_all(tasks);
let result = if let Some(timeout) = timeout_duration {
match tokio::time::timeout(timeout, join_all_fut).await {
Ok(Ok(results)) => {
let _ = results; Ok(())
}
Ok(Err(join_err)) => {
error!("A task panicked during shutdown: {:?}", join_err);
Err(ShutdownError::TaskPanic)
}
Err(_) => {
error!("Shutdown timed out after {:?}", timeout);
Err(ShutdownError::Timeout)
}
}
} else {
match join_all_fut.await {
Ok(results) => {
let _ = results;
Ok(())
}
Err(join_err) => {
error!(
"A task panicked during shutdown (no timeout): {:?}",
join_err
);
Err(ShutdownError::TaskPanic)
}
}
};
if result.is_ok() {
info!("All tasks joined successfully.");
} else {
error!("Error during shutdown task joining: {:?}", result);
}
result
}
}