use crate::Schedule;
use crate::command::{CoordinatorCommand, JobUpdateData, ShutdownMode};
use crate::coordinator::{Coordinator, CoordinatorState};
use crate::error::{BuildError, QueryError, ShutdownError, SubmitError};
use crate::job::{BoxedExecFn, InstanceId, JobDetails, JobSummary, MaxRetries, TKJobId, TKJobRequest};
use crate::metrics::{MetricsSnapshot, SchedulerMetrics};
use crate::worker::Worker;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::time::Duration;
use chrono::{DateTime, Utc};
use fibre::oneshot::oneshot;
use fibre::{SendError, TrySendError, mpmc, mpsc};
use fibre_cache::CacheBuilder;
use futures::future::try_join_all;
use parking_lot::{Mutex, RwLock};
use tokio::runtime::Handle;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Default capacity for the staging and command buffers of a `SchedulerBuilder`.
const DEFAULT_CHANNEL_BOUND: usize = 128;
/// Default capacity for the job dispatch buffer of a `SchedulerBuilder`.
const DEFAULT_JOB_DISPATCH_BOUND: usize = 1;
/// Message type carried by the job dispatch channel created in `SchedulerBuilder::build`.
///
/// NOTE(review): the `DateTime<Utc>` element is presumably the scheduled run time of the
/// instance — confirm against the coordinator/worker code.
type JobDispatchTuple = (InstanceId, TKJobId, DateTime<Utc>);
/// Selector handed to the coordinator state to choose its priority queue implementation.
///
/// NOTE(review): the exact trade-offs of each variant live in the coordinator and are not
/// visible here — confirm before documenting them further.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PriorityQueueType {
    /// Presumably a queue backed by a binary heap — confirm against the coordinator.
    BinaryHeap,
    /// The variant used by `SchedulerBuilder::default()`.
    HandleBased,
}
/// Collects the settings needed to construct a `TurnKeeper` scheduler.
#[derive(Debug)]
pub struct SchedulerBuilder {
    /// Number of worker tasks to spawn; `build` fails while this is `None`.
    max_workers: Option<usize>,
    /// Priority queue implementation forwarded to the coordinator state.
    pq_type: PriorityQueueType,
    /// Capacity of the staging channel that carries newly submitted jobs.
    staging_buffer_size: usize,
    /// Capacity of the command channel; also used for the worker outcome channel.
    command_buffer_size: usize,
    /// Capacity of the job dispatch channel.
    job_dispatch_buffer_size: usize,
}
impl Default for SchedulerBuilder {
fn default() -> Self {
Self {
max_workers: None,
pq_type: PriorityQueueType::HandleBased, staging_buffer_size: DEFAULT_CHANNEL_BOUND,
command_buffer_size: DEFAULT_CHANNEL_BOUND,
job_dispatch_buffer_size: DEFAULT_JOB_DISPATCH_BOUND,
}
}
}
impl SchedulerBuilder {
    /// Creates a builder holding the [`Default`] settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of worker tasks spawned by [`build`](Self::build).
    ///
    /// This is the only setting without a default; `build` returns an error if it
    /// was never provided.
    pub fn max_workers(self, count: usize) -> Self {
        Self {
            max_workers: Some(count),
            ..self
        }
    }

    /// Chooses the priority queue implementation forwarded to the coordinator.
    pub fn priority_queue(self, pq_type: PriorityQueueType) -> Self {
        Self { pq_type, ..self }
    }

    /// Sets the capacity of the staging channel that receives submitted jobs.
    pub fn staging_buffer_size(self, size: usize) -> Self {
        Self {
            staging_buffer_size: size,
            ..self
        }
    }

    /// Sets the capacity of the command channel.
    ///
    /// The same value is also used for the worker outcome channel.
    pub fn command_buffer_size(self, size: usize) -> Self {
        Self {
            command_buffer_size: size,
            ..self
        }
    }

    /// Sets the capacity of the job dispatch channel; values below 1 are raised to 1.
    pub fn job_dispatch_buffer_size(self, size: usize) -> Self {
        Self {
            job_dispatch_buffer_size: usize::max(size, 1),
            ..self
        }
    }

    /// Creates the shared state and channels, spawns the worker tasks and the
    /// coordinator task, and returns the [`TurnKeeper`] handle.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::MissingOrZeroMaxWorkers`] when `max_workers` was never set.
    /// A value of `0` is accepted and only logs a warning.
    ///
    /// # Panics
    ///
    /// Tasks are spawned through `Handle::current()`, which panics when called outside
    /// a Tokio runtime. Also panics if the job history cache cannot be built.
    pub fn build(self) -> Result<TurnKeeper, BuildError> {
        let Self {
            max_workers,
            pq_type,
            staging_buffer_size,
            command_buffer_size,
            job_dispatch_buffer_size,
        } = self;

        let worker_count = match max_workers {
            Some(count) => count,
            None => return Err(BuildError::MissingOrZeroMaxWorkers),
        };
        if worker_count == 0 {
            warn!("Scheduler built with 0 workers. No jobs will execute.");
        }

        // State shared between the coordinator and (partly) the workers.
        let metrics = SchedulerMetrics::new();
        let job_definitions = Arc::new(RwLock::new(HashMap::new()));
        let cancellations = Arc::new(RwLock::new(HashSet::new()));
        let quarantined_jobs = Arc::new(RwLock::new(HashSet::new()));
        let instance_to_lineage = Arc::new(RwLock::new(HashMap::new()));
        let active_workers_counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        // Staged submissions carry the job id together with the request and its executor.
        let (staging_tx, staging_rx) = mpsc::bounded_async::<(
            TKJobId,
            TKJobRequest,
            Arc<BoxedExecFn>,
        )>(staging_buffer_size);
        let (cmd_tx, cmd_rx) = mpsc::bounded_async::<CoordinatorCommand>(command_buffer_size);
        let (shutdown_tx, shutdown_rx) = watch::channel::<Option<ShutdownMode>>(None);
        let (job_dispatch_tx, job_dispatch_rx) =
            mpmc::bounded_async::<JobDispatchTuple>(job_dispatch_buffer_size);
        let (worker_outcome_tx, worker_outcome_rx) =
            mpsc::bounded_async::<crate::command::WorkerOutcome>(command_buffer_size);

        // One task per worker; each receives its own clones of the shared handles
        // and constructs its `Worker` inside the spawned task.
        let worker_handles: Vec<_> = (0..worker_count)
            .map(|worker_id| {
                let definitions = job_definitions.clone();
                let task_metrics = metrics.clone();
                let task_shutdown_rx = shutdown_rx.clone();
                let task_outcome_tx = worker_outcome_tx.clone();
                let task_dispatch_rx = job_dispatch_rx.clone();
                let task_active_counter = active_workers_counter.clone();
                Handle::current().spawn(async move {
                    let mut worker = Worker::new(
                        worker_id,
                        definitions,
                        task_metrics,
                        task_shutdown_rx,
                        task_outcome_tx,
                        task_dispatch_rx,
                        task_active_counter,
                    );
                    worker.run().await;
                })
            })
            .collect();

        // Job history cache: at most 10 000 entries, each kept for one hour.
        let job_history = Arc::new(
            CacheBuilder::new()
                .capacity(10_000)
                .time_to_live(Duration::from_secs(3600))
                .build()
                .expect("Failed to build cache"),
        );

        let coordinator_state = CoordinatorState::new(
            pq_type,
            staging_rx,
            cmd_rx,
            shutdown_rx.clone(),
            job_dispatch_tx,
            job_dispatch_rx,
            worker_outcome_tx,
            worker_outcome_rx,
            job_definitions.clone(),
            job_history,
            cancellations.clone(),
            quarantined_jobs.clone(),
            instance_to_lineage.clone(),
            metrics.clone(),
            active_workers_counter.clone(),
            worker_count,
            worker_handles,
        );

        let coordinator_task = Handle::current().spawn(async move {
            let mut coordinator = Coordinator::new(coordinator_state);
            coordinator.run().await;
            info!("Coordinator task finished.");
        });

        Ok(TurnKeeper {
            metrics,
            staging_tx,
            cmd_tx,
            shutdown_tx,
            coordinator_handle: Arc::new(Mutex::new(Some(coordinator_task))),
        })
    }
}
/// Handle to a running scheduler: submits jobs, queries the coordinator and
/// triggers shutdown.
pub struct TurnKeeper {
    /// Scheduler metrics; submission counters are updated by the `add_job*` methods.
    metrics: SchedulerMetrics,
    /// Sending half of the staging channel: job id, request and executor function.
    staging_tx: mpsc::BoundedAsyncSender<(TKJobId, TKJobRequest, Arc<BoxedExecFn>)>,
    /// Sending half of the coordinator command channel.
    cmd_tx: mpsc::BoundedAsyncSender<CoordinatorCommand>,
    /// Publishes the requested shutdown mode; starts out as `None`.
    shutdown_tx: watch::Sender<Option<ShutdownMode>>,
    /// Join handle of the coordinator task; shutdown calls `take()` it out of the
    /// `Option` in order to await it.
    coordinator_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
impl TurnKeeper {
    /// Returns a [`SchedulerBuilder`] with default settings.
    pub fn builder() -> SchedulerBuilder {
        SchedulerBuilder::new()
    }

    /// Submits a job through a synchronous view of the staging sender.
    ///
    /// A fresh UUID v4 is generated as the job id and `staging_submitted_total` is
    /// incremented before the send is attempted.
    ///
    /// NOTE(review): the sender is converted with `to_sync()`, so this presumably blocks
    /// the calling thread while the staging buffer is full — confirm against the `fibre`
    /// docs and prefer [`add_job_async`](Self::add_job_async) inside async code.
    ///
    /// # Errors
    ///
    /// Returns [`SubmitError::ChannelClosed`] carrying the request and the boxed
    /// function when the staging channel is closed.
    pub fn add_job<F>(
        &self,
        request: TKJobRequest,
        exec_fn: F,
    ) -> Result<TKJobId, SubmitError<(TKJobRequest, Arc<BoxedExecFn>)>>
    where
        F: Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
    {
        let lineage_id = Uuid::new_v4();
        let boxed_fn = Arc::new(Box::new(exec_fn) as BoxedExecFn);
        self.metrics
            .staging_submitted_total
            .fetch_add(1, AtomicOrdering::Relaxed);
        // The error branch is written without access to the payload, so keep copies
        // that can be handed back to the caller.
        let req = request.clone();
        let send_payload = (lineage_id, request, boxed_fn.clone());
        match self.staging_tx.clone().to_sync().send(send_payload) {
            Ok(()) => Ok(lineage_id),
            Err(SendError::Closed) => Err(SubmitError::ChannelClosed((req, boxed_fn))),
            Err(SendError::Sent) => unreachable!(),
        }
    }

    /// Submits a job with `try_send`, returning immediately when the staging
    /// buffer has no room.
    ///
    /// `staging_submitted_total` is incremented before the attempt.
    ///
    /// # Errors
    ///
    /// * [`SubmitError::StagingFull`] when the buffer is full; `staging_rejected_full`
    ///   is incremented as well.
    /// * [`SubmitError::ChannelClosed`] when the staging channel is closed.
    ///
    /// Both variants hand the request and the boxed function back to the caller.
    pub fn try_add_job<F>(
        &self,
        request: TKJobRequest,
        exec_fn: F,
    ) -> Result<TKJobId, SubmitError<(TKJobRequest, Arc<BoxedExecFn>)>>
    where
        F: Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
    {
        let lineage_id = Uuid::new_v4();
        let boxed_fn = Arc::new(Box::new(exec_fn) as BoxedExecFn);
        self.metrics
            .staging_submitted_total
            .fetch_add(1, AtomicOrdering::Relaxed);
        let send_payload = (lineage_id, request, boxed_fn);
        match self.staging_tx.try_send(send_payload) {
            Ok(()) => Ok(lineage_id),
            Err(TrySendError::Full((_, req, func))) => {
                self.metrics
                    .staging_rejected_full
                    .fetch_add(1, AtomicOrdering::Relaxed);
                Err(SubmitError::StagingFull((req, func)))
            }
            Err(TrySendError::Closed((_, req, func))) => {
                Err(SubmitError::ChannelClosed((req, func)))
            }
            Err(TrySendError::Sent(_)) => unreachable!(),
        }
    }

    /// Submits a job by awaiting the asynchronous send on the staging channel.
    ///
    /// `staging_submitted_total` is incremented before the send is attempted.
    ///
    /// # Errors
    ///
    /// Any send failure is reported as [`SubmitError::ChannelClosed`] carrying the
    /// request and the boxed function.
    pub async fn add_job_async<F>(
        &self,
        request: TKJobRequest,
        exec_fn: F,
    ) -> Result<TKJobId, SubmitError<(TKJobRequest, Arc<BoxedExecFn>)>>
    where
        F: Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
    {
        let lineage_id = Uuid::new_v4();
        let boxed_fn = Arc::new(Box::new(exec_fn) as BoxedExecFn);
        self.metrics
            .staging_submitted_total
            .fetch_add(1, AtomicOrdering::Relaxed);
        // Copies kept so the caller gets the request and function back on failure.
        let req = request.clone();
        let send_payload = (lineage_id, request, boxed_fn.clone());
        self.staging_tx
            .send(send_payload)
            .await
            .map(|()| lineage_id)
            .map_err(|_| SubmitError::ChannelClosed((req, boxed_fn)))
    }

    /// Asks the coordinator for the details of `job_id`.
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    /// * Any [`QueryError`] the coordinator returns in its response.
    pub async fn get_job_details(&self, job_id: TKJobId) -> Result<JobDetails, QueryError> {
        let (responder, response_rx) = oneshot();
        let cmd = CoordinatorCommand::GetJobDetails { job_id, responder };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        // The response is itself a `Result`; `?` strips the transport layer only.
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)?
    }

    /// Asks the coordinator for a summary of every job it knows about.
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    pub async fn list_all_jobs(&self) -> Result<Vec<JobSummary>, QueryError> {
        let (responder, response_rx) = oneshot();
        let cmd = CoordinatorCommand::ListAllJobs { responder };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)
    }

    /// Asks the coordinator for a snapshot of the scheduler metrics.
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    pub async fn get_metrics_snapshot(&self) -> Result<MetricsSnapshot, QueryError> {
        let (responder, response_rx) = oneshot();
        let cmd = CoordinatorCommand::GetMetricsSnapshot { responder };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)
    }

    /// Sends a cancel request for `job_id` to the coordinator and waits for its answer.
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    /// * Any [`QueryError`] the coordinator returns in its response.
    pub async fn cancel_job(&self, job_id: TKJobId) -> Result<(), QueryError> {
        let (responder, response_rx) = oneshot();
        let cmd = CoordinatorCommand::CancelJob { job_id, responder };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)?
    }

    /// Sends an update for `job_id` carrying the optional new schedule and retry limit.
    ///
    /// Both values are forwarded as given inside a [`JobUpdateData`].
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    /// * Any [`QueryError`] the coordinator returns in its response.
    pub async fn update_job(
        &self,
        job_id: TKJobId,
        schedule: Option<Schedule>,
        max_retries: Option<MaxRetries>,
    ) -> Result<(), QueryError> {
        let (responder, response_rx) = oneshot();
        let update_data = JobUpdateData {
            schedule,
            max_retries,
        };
        let cmd = CoordinatorCommand::UpdateJob {
            job_id,
            update_data,
            responder,
        };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)?
    }

    /// Sends a trigger-now request for `job_id` to the coordinator and waits for its answer.
    ///
    /// # Errors
    ///
    /// * [`QueryError::SchedulerShutdown`] when the command cannot be sent.
    /// * [`QueryError::ResponseFailed`] when no response arrives on the oneshot channel.
    /// * Any [`QueryError`] the coordinator returns in its response.
    pub async fn trigger_job_now(&self, job_id: TKJobId) -> Result<(), QueryError> {
        let (responder, response_rx) = oneshot();
        let cmd = CoordinatorCommand::TriggerJobNow { job_id, responder };
        self.cmd_tx
            .send(cmd)
            .await
            .map_err(|_| QueryError::SchedulerShutdown)?;
        response_rx
            .recv()
            .await
            .map_err(|_| QueryError::ResponseFailed)?
    }

    /// Publishes [`ShutdownMode::Graceful`] and waits for the coordinator task to finish.
    ///
    /// With `timeout` set, the wait is bounded; after a timeout the call may be
    /// repeated (or escalated with [`shutdown_force`](Self::shutdown_force)) to keep
    /// waiting.
    ///
    /// # Errors
    ///
    /// * [`ShutdownError::SignalFailed`] when the shutdown signal cannot be sent.
    /// * [`ShutdownError::AlreadyShuttingDown`] when the coordinator handle is not
    ///   available (already joined, or another call is currently waiting on it).
    /// * [`ShutdownError::Timeout`] when the coordinator did not finish in time.
    /// * [`ShutdownError::TaskPanic`] when joining the coordinator task fails.
    pub async fn shutdown_graceful(&self, timeout: Option<Duration>) -> Result<(), ShutdownError> {
        info!("Initiating graceful shutdown...");
        self.shutdown_tx
            .send(Some(ShutdownMode::Graceful))
            .map_err(|_| ShutdownError::SignalFailed)?;
        self.await_shutdown(timeout).await
    }

    /// Publishes [`ShutdownMode::Force`] and waits for the coordinator task to finish.
    ///
    /// With `timeout` set, the wait is bounded; after a timeout the call may be
    /// repeated to keep waiting.
    ///
    /// # Errors
    ///
    /// * [`ShutdownError::SignalFailed`] when the shutdown signal cannot be sent.
    /// * [`ShutdownError::AlreadyShuttingDown`] when the coordinator handle is not
    ///   available (already joined, or another call is currently waiting on it).
    /// * [`ShutdownError::Timeout`] when the coordinator did not finish in time.
    /// * [`ShutdownError::TaskPanic`] when joining the coordinator task fails.
    pub async fn shutdown_force(&self, timeout: Option<Duration>) -> Result<(), ShutdownError> {
        info!("Initiating forced shutdown...");
        self.shutdown_tx
            .send(Some(ShutdownMode::Force))
            .map_err(|_| ShutdownError::SignalFailed)?;
        self.await_shutdown(timeout).await
    }

    /// Takes the coordinator join handle and awaits it, optionally bounded by
    /// `timeout_duration`.
    ///
    /// When the timeout elapses the handle is stored back, so a later shutdown call
    /// can wait for the coordinator again instead of failing with
    /// [`ShutdownError::AlreadyShuttingDown`].
    async fn await_shutdown(&self, timeout_duration: Option<Duration>) -> Result<(), ShutdownError> {
        // The lock guard is a temporary of this statement and is released before any `.await`.
        let taken = self.coordinator_handle.lock().take();
        let mut coordinator_handle = match taken {
            Some(handle) => handle,
            None => {
                warn!("Shutdown called, but coordinator handle was already taken (already shut down?).");
                return Err(ShutdownError::AlreadyShuttingDown);
            }
        };

        let join_result = match timeout_duration {
            Some(timeout) => {
                // Await through `&mut` so the handle survives an elapsed timeout.
                let waited = tokio::time::timeout(timeout, &mut coordinator_handle).await;
                match waited {
                    Ok(joined) => joined,
                    Err(_) => {
                        error!("Shutdown timed out waiting for the coordinator after {:?}", timeout);
                        // Put the handle back; dropping it here would detach the task
                        // and make every later shutdown call unable to join it.
                        *self.coordinator_handle.lock() = Some(coordinator_handle);
                        return Err(ShutdownError::Timeout);
                    }
                }
            }
            None => coordinator_handle.await,
        };

        match join_result {
            Ok(()) => {
                info!("Coordinator task joined successfully.");
                Ok(())
            }
            Err(e) => {
                error!("Coordinator task panicked during shutdown: {:?}", e);
                Err(ShutdownError::TaskPanic)
            }
        }
    }
}