use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
use crate::error::Result;
use crate::job::JobRecord;
use crate::queue::Queue;
/// Boxed, thread-safe error returned by [`Worker::process`].
///
/// Runners inspect it with `downcast_ref::<PermanentFailure>()` to decide whether a
/// failed job is dead-lettered or nacked. The trait object's lifetime defaults to
/// `'static` because it sits behind a `Box`.
pub type WorkerError = Box<dyn std::error::Error + Send + Sync>;
/// Error a [`Worker`] returns to signal that a job must be dead-lettered
/// instead of nacked.
///
/// Box it into a [`WorkerError`]. The runners detect it with
/// `downcast_ref::<PermanentFailure>()`.
#[derive(Clone, Debug)]
pub struct PermanentFailure {
    /// Human-readable explanation; also the error's `Display` output.
    pub reason: String,
}

impl PermanentFailure {
    /// Creates a `PermanentFailure` from anything convertible into a `String`.
    pub fn new(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        PermanentFailure { reason }
    }
}

impl std::fmt::Display for PermanentFailure {
    // Emits `reason` verbatim; width/fill/precision flags are intentionally not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.reason)
    }
}

impl std::error::Error for PermanentFailure {}
/// A job handler driven by [`run_worker`] and [`run_worker_concurrent`].
///
/// Implementors must be `Send + Sync`. The concurrent runner shares a single
/// instance across spawned tasks through an `Arc`.
pub trait Worker: Send + Sync {
    /// Processes one claimed job.
    ///
    /// The runner settles the job according to the result:
    /// - `Ok(())`: the job is acked.
    /// - `Err` wrapping a [`PermanentFailure`]: the job is dead-lettered.
    /// - any other `Err`: the job is nacked.
    ///
    /// In the dead-letter and nack cases, the error's `Display` text is passed
    /// along as the reason. The returned future must be `Send` so it can run on
    /// a spawned Tokio task.
    fn process(
        &self,
        job: &JobRecord,
    ) -> impl Future<Output = core::result::Result<(), WorkerError>> + Send;
}
/// Runs a single worker against `queue`, processing one job at a time until
/// `shutdown` resolves.
///
/// Every iteration first checks (without blocking) whether `shutdown` has
/// already resolved. Only if it has not does the worker claim the next job.
/// A claimed job is passed to [`Worker::process`] and then settled:
///
/// - `Ok(())`: acked via `ack`.
/// - an error that downcasts to [`PermanentFailure`]: `dead_letter`.
/// - any other error: `nack`.
///
/// When no job is available, the worker waits on `wait_for_jobs(poll_interval)`,
/// racing it against `shutdown`. Shutdown is polled first.
///
/// A job that has already been claimed is always processed and settled before
/// shutdown is honoured, unless settling it fails.
///
/// # Errors
///
/// Returns the first error from claiming, acking, nacking or dead-lettering.
/// Errors returned by the worker itself are not propagated; they only decide
/// how the job is settled.
pub async fn run_worker<W, F>(
    queue_handle: &Queue,
    queue: &str,
    worker: &W,
    poll_interval: Duration,
    shutdown: F,
) -> Result<()>
where
    W: Worker,
    F: Future<Output = ()>,
{
    let mut shutdown = std::pin::pin!(shutdown);
    loop {
        // Checked before claiming, so an already-resolved shutdown (e.g. before the
        // first iteration) never causes a fresh job to be claimed.
        if check_shutdown(shutdown.as_mut()) {
            debug!(queue = queue, "worker shutdown requested");
            return Ok(());
        }
        let Some(job) = queue_handle.claim_next(queue).await? else {
            // Idle: sleep until new work may be available, but wake immediately on shutdown.
            tokio::select! {
                biased;
                _ = &mut shutdown => {
                    debug!(queue = queue, "worker shutdown requested");
                    return Ok(());
                }
                _ = queue_handle.wait_for_jobs(poll_interval) => {}
            }
            continue;
        };
        match worker.process(&job).await {
            Ok(()) => queue_handle.ack(&job).await?,
            Err(e) if e.downcast_ref::<PermanentFailure>().is_some() => {
                queue_handle.dead_letter(job, &e.to_string()).await?
            }
            Err(e) => queue_handle.nack(job, &e.to_string()).await?,
        }
    }
}
/// Runs up to `concurrency` jobs from `queue` in parallel until `shutdown`
/// resolves, then waits for every in-flight job to finish.
///
/// Each claimed job is processed on its own task in a Tokio `JoinSet` and
/// settled like in [`run_worker`]:
///
/// - ack on success;
/// - dead-letter on a [`PermanentFailure`];
/// - nack otherwise.
///
/// Settlement failures inside a task are logged with `warn!` rather than
/// propagated, because the task has no caller to report them to.
///
/// # Errors
///
/// Returns the error from `claim_next` if claiming fails. In-flight jobs are
/// still drained before the error is returned. Dropping a `JoinSet` would
/// abort them mid-processing and leave them unsettled.
///
/// # Panics
///
/// Panics if `concurrency` is 0.
pub async fn run_worker_concurrent<W, F>(
    queue_handle: &Arc<Queue>,
    queue: &str,
    worker: Arc<W>,
    concurrency: usize,
    poll_interval: Duration,
    shutdown: F,
) -> Result<()>
where
    W: Worker + 'static,
    F: Future<Output = ()>,
{
    assert!(concurrency > 0, "concurrency must be at least 1");
    let mut set = tokio::task::JoinSet::new();
    let shutdown = std::pin::pin!(shutdown);
    let outcome = dispatch_jobs(
        queue_handle,
        queue,
        &worker,
        concurrency,
        poll_interval,
        shutdown,
        &mut set,
    )
    .await;
    if outcome.is_ok() {
        debug!(
            queue = queue,
            in_flight = set.len(),
            "draining workers on shutdown"
        );
    } else {
        warn!(
            queue = queue,
            in_flight = set.len(),
            "claiming a job failed; draining workers before returning the error"
        );
    }
    // Always drain, on both paths, so no in-flight job is aborted by dropping the set.
    while let Some(result) = set.join_next().await {
        if let Err(e) = result {
            warn!(queue = queue, "worker task panicked during drain: {e}");
        }
    }
    outcome
}

/// Claims jobs and spawns them onto `set` until `shutdown` resolves or a claim
/// fails. It never waits for the tasks it spawned; the caller drains `set`.
async fn dispatch_jobs<W, F>(
    queue_handle: &Arc<Queue>,
    queue: &str,
    worker: &Arc<W>,
    concurrency: usize,
    poll_interval: Duration,
    mut shutdown: std::pin::Pin<&mut F>,
    set: &mut tokio::task::JoinSet<()>,
) -> Result<()>
where
    W: Worker + 'static,
    F: Future<Output = ()>,
{
    // Shared once so each spawned task only bumps a refcount instead of
    // allocating its own copy of the queue name.
    let queue_name: Arc<str> = Arc::from(queue);
    loop {
        // Reap tasks that already finished so `set.len()` reflects real in-flight work.
        while let Some(result) = set.try_join_next() {
            if let Err(e) = result {
                warn!(queue = queue, "worker task panicked: {e}");
            }
        }
        // Checked before claiming, so an already-resolved shutdown never claims a new job.
        if check_shutdown(shutdown.as_mut()) {
            return Ok(());
        }
        if set.len() >= concurrency {
            // At capacity: wait for a slot to free up, or for shutdown.
            tokio::select! {
                biased;
                _ = &mut shutdown => return Ok(()),
                r = set.join_next() => {
                    if let Some(Err(e)) = r {
                        warn!(queue = queue, "worker task panicked: {e}");
                    }
                }
            }
            continue;
        }
        match queue_handle.claim_next(queue).await? {
            Some(job) => {
                let handle = Arc::clone(queue_handle);
                let w = Arc::clone(worker);
                let name = Arc::clone(&queue_name);
                set.spawn(async move {
                    // `dead_letter` and `nack` consume the job, so capture its id up
                    // front to keep it available for logging on every settlement path.
                    let job_id = job.id.to_string();
                    match w.process(&job).await {
                        Ok(()) => {
                            if let Err(e) = handle.ack(&job).await {
                                warn!(queue = %name, job_id = %job_id, "ack failed: {e}");
                            }
                        }
                        Err(e) if e.downcast_ref::<PermanentFailure>().is_some() => {
                            if let Err(se) = handle.dead_letter(job, &e.to_string()).await {
                                warn!(queue = %name, job_id = %job_id, "dead_letter failed: {se}");
                            }
                        }
                        Err(e) => {
                            if let Err(se) = handle.nack(job, &e.to_string()).await {
                                warn!(queue = %name, job_id = %job_id, "nack failed: {se}");
                            }
                        }
                    }
                });
            }
            None => {
                // Idle: wait for new work, but wake immediately on shutdown.
                tokio::select! {
                    biased;
                    _ = &mut shutdown => return Ok(()),
                    _ = queue_handle.wait_for_jobs(poll_interval) => {}
                }
            }
        }
    }
}
/// Polls `shutdown` exactly once without blocking and reports whether it has
/// already completed.
///
/// A no-op waker is used, so a `Pending` result registers no useful wake-up.
/// Callers must poll the future again later, with a real waker, to observe
/// completion. Once this returns `true` the future has finished and, per the
/// `Future` contract, should not be polled again.
fn check_shutdown<F: Future<Output = ()>>(shutdown: std::pin::Pin<&mut F>) -> bool {
    let mut cx = std::task::Context::from_waker(std::task::Waker::noop());
    shutdown.poll(&mut cx).is_ready()
}