use super::proto::Client;
use crate::error::Error;
use crate::proto::{Ack, Fail, Job};
use fnv::FnvHashMap;
use std::future::Future;
use std::pin::Pin;
use std::process;
use std::sync::{atomic, Arc};
use std::time::Duration;
use std::{error::Error as StdError, sync::atomic::AtomicUsize};
use tokio::task::{spawn, spawn_blocking, AbortHandle, JoinError, JoinSet};
use tokio::time::sleep as tokio_sleep;
mod builder;
mod health;
mod runner;
mod state;
mod stop;
pub use builder::WorkerBuilder;
pub use runner::JobRunner;
pub use stop::{StopDetails, StopReason};
/// Per-worker status flag value: a worker task keeps fetching and processing jobs only while
/// its status holds this value. Any other value ends the worker loop once the job currently
/// being handled has been processed.
pub(crate) const STATUS_RUNNING: usize = 0;
/// Per-worker status flag value that, like every non-running value, ends the worker loop.
/// NOTE(review): not referenced in this file's visible code — presumably set by the heartbeat
/// handling (the `health` module) when the server asks the worker to go quiet; confirm there.
pub(crate) const STATUS_QUIET: usize = 1;
/// Per-worker status flag value stored by a worker task when its loop stops, either because
/// its status stopped being running or because processing a job returned an error.
pub(crate) const STATUS_TERMINATING: usize = 2;

/// A boxed, sendable future; once it resolves, `Worker::run` performs a graceful shutdown.
type ShutdownSignal = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// A registered job handler, generic over the error type `E` the handler may return.
///
/// `run_job` drives `Async` handlers on a spawned Tokio task and runs `Sync` handlers via
/// `spawn_blocking`.
pub(crate) enum Callback<E> {
    /// A handler that is awaited as a future.
    Async(runner::BoxedJobRunner<E>),
    /// A plain (possibly blocking) closure taking ownership of the job.
    Sync(Box<dyn Fn(Job) -> Result<(), E> + Send + Sync + 'static>),
}

/// Registered job handlers, looked up by the job's kind.
type CallbacksRegistry<E> = FnvHashMap<String, Callback<E>>;
/// A job-processing worker (called the "coordinator" in assertion messages).
///
/// `run` spawns one task per entry of the worker-state registry; each task gets its own copy
/// of this struct (see `for_worker`) with its own connection, sharing the handler registry and
/// the worker-state registry through `Arc`.
pub struct Worker<E> {
    // Connection used to fetch jobs and to report their outcome (ACK / FAIL).
    c: Client,
    // Per-worker-slot state (running job, last outcome), shared with the spawned worker copies.
    worker_states: Arc<state::WorkerStatesRegistry>,
    // Job handlers keyed by job kind, shared with the spawned worker copies.
    callbacks: Arc<CallbacksRegistry<E>>,
    // Set by `run` once it has started; `run` and `run_one` refuse to work on a terminated
    // worker. `run` resets it to `false` if the heartbeat listener fails, so it can be re-run.
    terminated: bool,
    // NOTE(review): not read anywhere in this file's visible code (only copied by
    // `for_worker`); presumably consulted by another module — confirm.
    forever: bool,
    // Upper bound on how long `run` spends force-failing in-flight jobs after the user's
    // shutdown signal fires; `None` means no bound.
    shutdown_timeout: Option<Duration>,
    // shutdown_signal comment: this is an `Option` only so that `run` can `take()` the future
    // while it races it against the heartbeat listener. It is `Some` whenever `run` can be
    // entered: `new` and `for_worker` always set it, and `run` puts it back on the heartbeat
    // path. On the user-shutdown path it is left as `None`, but `terminated` stays `true`, so
    // `run` cannot be called again.
    shutdown_signal: Option<ShutdownSignal>,
}
impl Worker<()> {
pub fn builder<E>() -> WorkerBuilder<E> {
WorkerBuilder::default()
}
}
impl<E> Worker<E> {
    /// Re-establishes this worker's connection by delegating to the underlying client.
    ///
    /// # Errors
    ///
    /// Whatever `Client::reconnect` reports.
    async fn reconnect(&mut self) -> Result<(), Error> {
        let client = &mut self.c;
        client.reconnect().await
    }
}
impl<E> Worker<E> {
    /// Assembles a worker around an already-established client connection.
    ///
    /// * `workers_count` sizes the per-worker state registry (and thus the number of worker
    ///   tasks `run` spawns);
    /// * `callbacks` maps job kinds to handlers and is shared with the spawned workers;
    /// * `shutdown_timeout` bounds the force-fail phase after a user shutdown signal;
    /// * if `shutdown_signal` is `None`, a future that never resolves is installed, so the
    ///   user-shutdown branch of `run` can never fire.
    fn new(
        c: Client,
        workers_count: usize,
        callbacks: CallbacksRegistry<E>,
        shutdown_timeout: Option<Duration>,
        shutdown_signal: Option<ShutdownSignal>,
    ) -> Self {
        let shared_callbacks = Arc::new(callbacks);
        let shared_states = Arc::new(state::WorkerStatesRegistry::new(workers_count));
        let signal: ShutdownSignal = match shutdown_signal {
            Some(user_signal) => user_signal,
            None => Box::pin(std::future::pending::<()>()),
        };
        Self {
            c,
            worker_states: shared_states,
            callbacks: shared_callbacks,
            terminated: false,
            forever: false,
            shutdown_timeout,
            shutdown_signal: Some(signal),
        }
    }

    /// Returns `true` once `run` has terminated this worker; such a worker must not be run
    /// again.
    pub fn is_terminated(&self) -> bool {
        self.terminated
    }
}
/// Why processing a single job failed.
///
/// `E` is the handler's own error type and `JE` the error produced when joining the task the
/// handler ran on. No trait bounds are placed on the type definition itself; the impls that
/// need `E: StdError` state that requirement themselves.
enum Failed<E, JE> {
    /// The handler ran and returned an error.
    Application(E),
    /// The task running the handler panicked or was cancelled.
    HandlerPanic(JE),
    /// No handler is registered for this job kind (the kind is carried along).
    BadJobType(String),
}
impl<E: StdError + 'static + Send> Worker<E> {
async fn run_job(&mut self, job: Job) -> Result<(), Failed<E, JoinError>> {
let handler = self
.callbacks
.get(&job.kind)
.ok_or(Failed::BadJobType(job.kind().to_string()))?;
let spawning_result = match handler {
Callback::Async(_) => {
let callbacks = self.callbacks.clone();
let processing_task = async move {
let callback = callbacks.get(&job.kind).unwrap();
if let Callback::Async(cb) = callback {
cb.run(job).await
} else {
unreachable!()
}
};
spawn(processing_task).await
}
Callback::Sync(_) => {
let callbacks = self.callbacks.clone();
let processing_task = move || {
let callback = callbacks.get(&job.kind).unwrap();
if let Callback::Sync(cb) = callback {
cb(job)
} else {
unreachable!()
}
};
spawn_blocking(processing_task).await
}
};
match spawning_result {
Err(join_error) => Err(Failed::HandlerPanic(join_error)),
Ok(processing_result) => processing_result.map_err(Failed::Application),
}
}
async fn report_on_all_workers(&mut self) -> Result<(), Error> {
let worker_states = Arc::get_mut(&mut self.worker_states)
.expect("all workers are scoped to &mut of the user-code-visible Worker");
for wstate in worker_states {
let wstate = wstate.get_mut().unwrap();
if let Some(res) = wstate.take_last_result() {
let r = match res {
Ok(ref jid) => self.c.issue(&Ack::new(jid.clone())).await,
Err(ref fail) => self.c.issue(fail).await,
};
let r = match r {
Ok(r) => r,
Err(e) => {
wstate.save_last_result(res);
return Err(e);
}
};
if let Err(e) = r.read_ok().await {
if let Error::IO(_) = e {
wstate.save_last_result(res);
return Err(e);
}
}
}
}
Ok(())
}
async fn force_fail_all_workers(&mut self, reason: &str) -> usize {
let mut running = 0;
for wstate in &*self.worker_states {
let may_be_jid = wstate.lock().unwrap().take_currently_running();
if let Some(jid) = may_be_jid {
running += 1;
let f = Fail::generic(jid, reason);
let _ = match self.c.issue(&f).await {
Ok(r) => r.read_ok().await,
Err(_) => continue,
}
.is_ok();
}
}
running
}
pub async fn run_one<Q>(&mut self, worker: usize, queues: &[Q]) -> Result<bool, Error>
where
Q: AsRef<str> + Sync,
{
assert!(
!self.terminated,
"do not re-run a terminated worker (coordinator)"
);
let job = match self.c.fetch(queues).await? {
None => return Ok(false),
Some(j) => j,
};
let jid = job.jid.clone();
self.worker_states.register_running(worker, jid.clone());
match self.run_job(job).await {
Ok(_) => {
self.worker_states.register_success(worker, jid.clone());
self.c.issue(&Ack::new(jid)).await?.read_ok().await?;
}
Err(e) => {
let fail = match e {
Failed::BadJobType(jt) => Fail::generic(jid, format!("No handler for {}", jt)),
Failed::Application(e) => Fail::generic_with_backtrace(jid, e),
Failed::HandlerPanic(e) => {
if e.is_cancelled() {
Fail::generic(jid, "job processing was cancelled")
} else if e.is_panic() {
let panic_obj = e.into_panic();
if panic_obj.is::<String>() {
Fail::generic(jid, *panic_obj.downcast::<String>().unwrap())
} else if panic_obj.is::<&'static str>() {
Fail::generic(jid, *panic_obj.downcast::<&'static str>().unwrap())
} else {
Fail::generic(jid, "job processing panicked")
}
} else {
Fail::generic_with_backtrace(jid, e)
}
}
};
self.worker_states.register_failure(worker, fail.clone());
self.c.issue(&fail).await?.read_ok().await?;
}
}
self.worker_states.reset(worker);
Ok(true)
}
}
impl<E: StdError + 'static + Send> Worker<E> {
    /// Creates a copy of this worker for use inside a single spawned worker task.
    ///
    /// The copy gets its own connection (`connect_again`) but shares the handler registry and
    /// the worker-state registry through `Arc` clones. Its shutdown signal never resolves,
    /// because only the coordinator (`run`) listens for the user's shutdown signal.
    ///
    /// # Errors
    ///
    /// Fails if the new connection cannot be established.
    async fn for_worker(&mut self) -> Result<Self, Error> {
        Ok(Worker {
            c: self.c.connect_again().await?,
            callbacks: Arc::clone(&self.callbacks),
            worker_states: Arc::clone(&self.worker_states),
            terminated: self.terminated,
            forever: self.forever,
            shutdown_timeout: self.shutdown_timeout,
            shutdown_signal: Some(Box::pin(std::future::pending())),
        })
    }

    /// Spawns the task for worker slot `worker` into `set` and returns its abort handle.
    ///
    /// The task calls `run_one` on its own `for_worker` copy for as long as `status` holds
    /// `STATUS_RUNNING`. It stores `STATUS_TERMINATING` into `status` whenever it stops: either
    /// because the status changed (the task then yields `Ok(())`) or because `run_one` failed
    /// (the task then yields that error).
    ///
    /// # Errors
    ///
    /// Fails if the worker's own connection cannot be established.
    async fn spawn_worker_into<Q>(
        &mut self,
        set: &mut JoinSet<Result<(), Error>>,
        status: Arc<AtomicUsize>,
        worker: usize,
        queues: &[Q],
    ) -> Result<AbortHandle, Error>
    where
        Q: AsRef<str>,
    {
        let mut w = self.for_worker().await?;
        // The spawned task must be `'static`, so it gets owned copies of the queue names.
        let queues: Vec<_> = queues.iter().map(|s| s.as_ref().to_string()).collect();
        Ok(set.spawn(async move {
            while status.load(atomic::Ordering::SeqCst) == STATUS_RUNNING {
                if let Err(e) = w.run_one(worker, &queues[..]).await {
                    status.store(STATUS_TERMINATING, atomic::Ordering::SeqCst);
                    return Err(e);
                }
            }
            status.store(STATUS_TERMINATING, atomic::Ordering::SeqCst);
            Ok(())
        }))
    }

    /// Processes jobs from `queues` on one task per worker slot until the user's shutdown
    /// signal fires or the heartbeat listener returns.
    ///
    /// Steps:
    /// 1. re-report outcomes left over from a previous run (`report_on_all_workers`);
    /// 2. spawn one worker task per slot, each with its own connection and status flag;
    /// 3. race the user's shutdown signal against `listen_for_heartbeats`.
    ///
    /// **User shutdown signal.** Jobs still marked as running are force-failed. If
    /// `shutdown_timeout` is set and elapses first, the reported count is `0`. The worker tasks
    /// are not joined on this path; they are aborted when the `JoinSet` is dropped as `run`
    /// returns. The result is `StopReason::GracefulShutdown`.
    ///
    /// **Heartbeat listener returns.**
    /// * On `Ok(true)`, running jobs are force-failed. If any were found, `run` returns
    ///   `StopReason::ServerInstruction` with that count immediately, without joining the worker
    ///   tasks.
    /// * Otherwise all worker tasks are joined. On `Ok(_)`, the first worker error is returned,
    ///   or `StopReason::ServerInstruction` with a count of `0`. On `Err(e)`, the first worker
    ///   error is returned, or else `e`. After a heartbeat error the worker is marked not
    ///   terminated, so the caller may reconnect and call `run` again.
    ///
    /// # Panics
    ///
    /// * if this worker is terminated;
    /// * if joining a worker task fails, e.g. because the task panicked.
    pub async fn run<Q>(&mut self, queues: &[Q]) -> Result<StopDetails, Error>
    where
        Q: AsRef<str>,
    {
        assert!(
            !self.terminated,
            "do not re-run a terminated worker (coordinator)"
        );
        self.report_on_all_workers().await?;
        let nworkers = self.worker_states.len();
        // One status flag per worker task, also handed to the heartbeat listener below.
        let statuses: Vec<_> = (0..nworkers)
            .map(|_| Arc::new(atomic::AtomicUsize::new(STATUS_RUNNING)))
            .collect();
        let mut join_set = JoinSet::new();
        for (worker, status) in statuses.iter().enumerate() {
            let _abort_handle = self
                .spawn_worker_into(&mut join_set, Arc::clone(status), worker, queues)
                .await?;
        }
        // Taken (not borrowed) so it can be awaited while `self` is used by the other branch;
        // put back on the heartbeat path below.
        let mut shutdown_signal = self
            .shutdown_signal
            .take()
            .expect("see shutdown_signal comment");
        self.terminated = true;
        let maybe_shutdown_timeout = self.shutdown_timeout;
        let report = tokio::select! {
            _ = &mut shutdown_signal => {
                // The timeout bounds how long force-failing may take. The `unwrap` sits inside
                // an `async` block, so it only runs if the branch is enabled by its
                // precondition and actually polled.
                let nrunning = tokio::select! {
                    _ = async { tokio_sleep(maybe_shutdown_timeout.unwrap()).await; }, if maybe_shutdown_timeout.is_some() => {
                        0
                    },
                    nrunning = self.force_fail_all_workers("termination signal received from user space") => {
                        nrunning
                    }
                };
                Ok(stop::StopDetails::new(StopReason::GracefulShutdown, nrunning))
            },
            exit = self.listen_for_heartbeats(&statuses) => {
                // A heartbeat failure leaves the worker re-runnable (e.g. after a reconnect).
                if exit.is_err() {
                    self.terminated = false;
                }
                self.shutdown_signal = Some(shutdown_signal);
                if let Ok(true) = exit {
                    let nrunning = self.force_fail_all_workers("terminated").await;
                    if nrunning != 0 {
                        return Ok(stop::StopDetails::new(StopReason::ServerInstruction, nrunning));
                    }
                }
                let mut results = Vec::with_capacity(nworkers);
                while let Some(res) = join_set.join_next().await {
                    results.push(res.expect("joined ok"));
                }
                // Keeps the first worker error, if any.
                let results = results.into_iter().collect::<Result<Vec<_>, _>>();
                match exit {
                    Ok(_) => results.map(|_| stop::StopDetails::new(StopReason::ServerInstruction, 0)),
                    Err(e) => results.and(Err(e)),
                }
            }
        };
        report
    }

    /// Calls `run` until it returns `Ok`, reconnecting after each error, and gives up if
    /// reconnecting fails. Never returns: the process is terminated when the loop ends.
    ///
    /// NOTE(review): the exit status is `0` even when the loop ended because reconnecting
    /// failed, which may hide the failure from a process supervisor. Consider a non-zero
    /// status.
    pub async fn run_to_completion<Q>(mut self, queues: &[Q]) -> !
    where
        Q: AsRef<str>,
    {
        while self.run(queues).await.is_err() {
            if self.reconnect().await.is_err() {
                break;
            }
        }
        process::exit(0);
    }
}