use super::backoff::Backoff;
use futures::{
channel::{mpsc::Receiver, oneshot::Receiver as OneShotReceiver},
future::{abortable, AbortHandle, Aborted},
lock::Mutex,
prelude::*,
};
use log::{debug, error, info, warn};
use std::{
collections::HashMap,
fmt,
ops::Deref,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
sync::{oneshot, watch::Sender as WatchSender},
task,
task::JoinHandle,
time::sleep,
};
use super::job::Job;
use super::task_manager::{ResourceStatus, TaskManager};
/// Monotonically increasing id handed to each spawned task / job instance.
static TASK_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Lifecycle state of a scheduled job, as tracked in [`JobScheduler::status`].
#[derive(Debug)]
pub enum JobStatus {
    /// The job instance has signalled readiness. Holds the graceful-termination
    /// sender when the job supports graceful shutdown, `None` otherwise.
    Ready(Option<WatchSender<Option<()>>>),
    /// The job's lifecycle has started but readiness has not been signalled yet.
    Startup,
    /// A failed or aborted instance is about to be re-run.
    Restarting,
    /// The instance crashed; the scheduler is backing off before retrying.
    CrashLoopBackOff,
    /// The job was aborted externally or exhausted its retry budget.
    Terminated,
    /// The job's future completed successfully; it will not be restarted.
    Finished,
}
impl fmt::Display for JobStatus {
    /// Writes the variant name only: `Ready` is printed without its channel
    /// payload, every other variant reuses its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let JobStatus::Ready(_) = self {
            f.write_str("Ready")
        } else {
            write!(f, "{:?}", self)
        }
    }
}
impl PartialEq for JobStatus {
fn eq(&self, other: &Self) -> bool {
matches!(
(self, other),
(&JobStatus::Startup, &JobStatus::Startup)
| (&JobStatus::Restarting, &JobStatus::Restarting)
| (&JobStatus::CrashLoopBackOff, &JobStatus::CrashLoopBackOff)
| (&JobStatus::Terminated, &JobStatus::Terminated)
| (&JobStatus::Ready(_), &JobStatus::Ready(_))
)
}
}
impl Eq for JobStatus {}
impl JobStatus {
    /// A job can be asked to shut down gracefully only when it is `Ready` and
    /// still holds the graceful-termination sender.
    fn is_gracefully_terminatable(&self) -> bool {
        matches!(*self, JobStatus::Ready(Some(_)))
    }
}
/// Supervises a set of named jobs: tracks their statuses, and holds the
/// handles needed to signal readiness and to terminate them.
#[derive(Default)]
pub struct JobScheduler {
    /// Current status of each live job, keyed by job name. Entries are removed
    /// once a job's lifecycle ends.
    pub(crate) status: Arc<Mutex<HashMap<String, JobStatus>>>,
    /// Abort handles for each job's lifecycle future, used for forceful
    /// termination; keyed by job name.
    termination_handles: Arc<Mutex<HashMap<String, AbortHandle>>>,
    /// One-shot senders fired the first time a job reports readiness; keyed by
    /// job name and consumed on first use.
    readiness_oneshots: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
}
impl JobScheduler {
fn add_dependency_watcher(
mut rx: Receiver<ResourceStatus>,
abort_handle: AbortHandle,
) -> AbortableJoinHandle<()> {
spawn_abortable(async move {
#[allow(clippy::never_loop)]
while let Some(status) = rx.next().await {
match status {
ResourceStatus::Dead => {
abort_handle.abort();
break;
}
};
}
})
}
fn add_status_watcher(
readiness_rx: OneShotReceiver<()>,
termination_tx: Option<WatchSender<Option<()>>>,
status_map: Arc<Mutex<HashMap<String, JobStatus>>>,
readiness_oneshots: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
job_name: String,
) -> AbortableJoinHandle<()> {
spawn_abortable(async move {
if readiness_rx.await.is_ok() {
JobScheduler::change_status(
&status_map,
&job_name,
JobStatus::Ready(termination_tx),
)
.await;
if let Some(oneshot) = readiness_oneshots.lock().await.remove(&job_name) {
if oneshot.send(()).is_err() {
log::trace!(
"Failed to react to readiness oneshot, sender might have been dropped."
);
}
}
}
})
}
async fn change_status(
status_map: &Arc<Mutex<HashMap<String, JobStatus>>>,
job_name: &str,
status: JobStatus,
) {
info!("{:<16} {}", format!("{}", status), job_name);
status_map.lock().await.insert(job_name.to_owned(), status);
}
pub fn spawn_task<T, F: 'static + Send, O: 'static + Send, Context>(
task: &T,
ctx: Context,
) -> JoinHandle<Result<O, Aborted>>
where
F: Future<Output = O>,
T: Fn(TaskManager<Context>) -> F,
{
let task_id = TASK_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
let (manager, rx, _, _) = TaskManager::new(task_id, ctx);
let (future, abort_handle) = abortable(task(manager));
let dependency_handle = JobScheduler::add_dependency_watcher(rx, abort_handle);
task::spawn(async move {
let result = future.await;
dependency_handle.cancel();
result
})
}
async fn manage_job_lifecycle<J: 'static + Job + Send>(
job: J,
status_map: Arc<Mutex<HashMap<String, JobStatus>>>,
readiness_oneshots: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
) {
let job_name = job.name().to_owned();
let mut backoff = Backoff::default();
JobScheduler::change_status(&status_map, &job_name, JobStatus::Startup).await;
loop {
let job_instance_id = TASK_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
let (manager, dependency_rx, readiness_rx, termination_tx) =
TaskManager::new(job_instance_id, ());
let wrapped_termination_tx = if job.supports_graceful_termination() {
Some(termination_tx)
} else {
None
};
let instance = job.execute(manager);
let (dependent_future, dependency_abort_handle) = abortable(instance);
let dependency_handle =
JobScheduler::add_dependency_watcher(dependency_rx, dependency_abort_handle);
let status_handle = JobScheduler::add_status_watcher(
readiness_rx,
wrapped_termination_tx,
status_map.clone(),
readiness_oneshots.clone(),
job_name.clone(),
);
let result = dependent_future.await;
dependency_handle.cancel();
status_handle.cancel();
match result {
Ok(return_value) => match return_value {
Ok(_) => {
JobScheduler::change_status(&status_map, &job_name, JobStatus::Finished)
.await;
status_map.lock().await.remove(&job_name);
break;
}
Err(e) => {
error!("{} crashed: {:?}", job_name.clone(), e);
JobScheduler::change_status(
&status_map,
&job_name,
JobStatus::CrashLoopBackOff,
)
.await;
if let Some(sleep_duration) = backoff.next() {
debug!("{} backing off for {:?}", &job_name, sleep_duration);
sleep(sleep_duration).await;
} else {
error!("{} exceeded its retry limit!", &job_name);
JobScheduler::change_status(
&status_map,
&job_name,
JobStatus::Terminated,
)
.await;
return;
}
}
},
Err(_) => warn!("{} lost a resource lock", &job_name),
}
JobScheduler::change_status(&status_map, &job_name, JobStatus::Restarting).await;
}
}
pub async fn spawn_job<J: 'static + Job + Send>(&self, job: J) -> oneshot::Receiver<()> {
let status_map = self.status.clone();
let readiness_oneshots = self.readiness_oneshots.clone();
let termination_handles = self.termination_handles.clone();
let job_name = job.name().to_owned();
let (readiness_tx, readiness_rx) = oneshot::channel();
readiness_oneshots
.lock()
.await
.insert(job_name.clone(), readiness_tx);
let (job_lifecycle, termination_handle) = abortable(JobScheduler::manage_job_lifecycle(
job,
status_map.clone(),
readiness_oneshots.clone(),
));
termination_handles
.lock()
.await
.insert(job_name.clone(), termination_handle);
task::spawn(async move {
if job_lifecycle.await.is_err() {
JobScheduler::change_status(&status_map, &job_name, JobStatus::Terminated).await;
}
termination_handles.lock().await.remove(&job_name);
status_map.lock().await.remove(&job_name);
});
readiness_rx
}
pub async fn terminate_job(&self, name: &String, grace_period: Duration) {
{
let status_map = self.status.lock().await;
if let Some(JobStatus::Ready(Some(job))) = status_map.get(name) {
job.send(Some(())).ok();
} else if let Some(forceful_handle) = self.termination_handles.lock().await.get(name) {
forceful_handle.abort();
}
}
let check_interval = Duration::from_millis(10);
let mut passed_duration = Duration::ZERO;
while passed_duration < grace_period {
{
let termination_handles = self.termination_handles.lock().await;
if !termination_handles.contains_key(name) {
break;
}
}
sleep(check_interval).await;
passed_duration += check_interval;
}
let termination_handles = self.termination_handles.lock().await;
if let Some(handle) = termination_handles.get(name) {
handle.abort();
}
}
pub async fn terminate_jobs(&self, grace_period: Duration) {
{
let status = self.status.lock().await;
for (job_name, status) in status.iter() {
if let JobStatus::Ready(Some(graceful_handle)) = status {
graceful_handle.send(Some(())).ok();
} else if let Some(forceful_handle) =
self.termination_handles.lock().await.get(job_name)
{
forceful_handle.abort();
}
}
}
let check_interval = Duration::from_millis(10);
let mut passed_duration = Duration::ZERO;
while passed_duration < grace_period {
{
let termination_handles = self.termination_handles.lock().await;
let status = self.status.lock().await;
let graceful_handles: Vec<&String> = termination_handles
.keys()
.filter(|job_name| {
if let Some(job_status) = status.get(*job_name) {
job_status.is_gracefully_terminatable()
} else {
false
}
})
.collect();
if graceful_handles.is_empty() {
break;
}
}
sleep(check_interval).await;
passed_duration += check_interval;
}
for (job_name, handle) in self.termination_handles.lock().await.iter() {
warn!("{} ignored graceful termination request", job_name);
handle.abort()
}
}
pub async fn wait_for_ready(&self) {
let mut ready = false;
while !ready {
ready = true;
for (_, status) in self.status.lock().await.iter() {
match status {
JobStatus::Ready(_) => ready = ready && true,
_ => {
ready = false;
break;
}
}
}
sleep(Duration::from_millis(100)).await;
}
}
}
/// Spawns each listed job on `$scheduler`, awaiting each `spawn_job` call.
/// Does NOT wait for readiness — the readiness receiver returned by
/// `spawn_job` is dropped; use `schedule_and_wait!` for that.
#[macro_export]
macro_rules! schedule {
    ($scheduler:expr, { $($job:ident$(,)? )+ }) => {
        $(
            $scheduler.spawn_job($job).await;
        )+
    };
}
/// Spawns each listed job and awaits its readiness signal, bounding each wait
/// by `$timeout`. Expands to `?` on the timeout result, so the surrounding
/// function must return a `Result` whose error type can absorb the timeout
/// error; a dropped readiness sender is silently ignored via `.ok()`.
#[macro_export]
macro_rules! schedule_and_wait {
    ($scheduler:expr, $timeout:expr, { $($job:ident$(,)? )+ }) => {
        $(
            tokio::time::timeout($timeout, $scheduler.spawn_job($job).await).await?.ok();
        )+
    };
}
/// A tokio `JoinHandle` paired with the `AbortHandle` of the abortable future
/// it is driving, so the spawned work can be cancelled from the outside.
pub struct AbortableJoinHandle<O> {
    // Handle to the spawned, abort-wrapped future.
    join_handle: JoinHandle<Result<O, Aborted>>,
    // Aborts the wrapped future (not the tokio task itself).
    abort_handle: AbortHandle,
}
impl<O> AbortableJoinHandle<O> {
    /// Aborts the wrapped future; if it had not yet completed, joining will
    /// yield `Err(Aborted)`.
    pub fn cancel(&self) {
        self.abort_handle.abort()
    }
}
// Expose the inner `JoinHandle` so callers can use its API (e.g. awaiting by
// reference) directly on the wrapper.
impl<O> Deref for AbortableJoinHandle<O> {
    type Target = JoinHandle<Result<O, Aborted>>;
    fn deref(&self) -> &Self::Target {
        &self.join_handle
    }
}
/// Wraps `fut` in an abortable shell, spawns it on the tokio runtime, and
/// returns a handle that can both join and cancel it.
pub fn spawn_abortable<F: 'static + Send, O: 'static + Send>(fut: F) -> AbortableJoinHandle<O>
where
    F: Future<Output = O>,
{
    let (abortable_future, abort_handle) = abortable(fut);
    let join_handle = task::spawn(abortable_future);
    AbortableJoinHandle {
        join_handle,
        abort_handle,
    }
}