use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use parking_lot::Mutex;
use scc::HashMap as SccHashMap;
use tokio::sync::Notify;
/// Errors surfaced by the job queue.
#[derive(Debug)]
pub enum QueueError {
    /// No handler was registered under the given job name.
    UnknownJob(String),
    /// The job payload could not be serialized to JSON.
    SerializeError(String),
    /// A job handler returned an error.
    HandlerError(String),
    /// The queue was shut down and no longer accepts work.
    Shutdown,
}

impl std::fmt::Display for QueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QueueError::Shutdown => f.write_str("queue has been shut down"),
            QueueError::UnknownJob(name) => {
                write!(f, "no handler registered for job '{name}'")
            }
            QueueError::SerializeError(e) => {
                write!(f, "failed to serialize job payload: {e}")
            }
            QueueError::HandlerError(e) => write!(f, "job handler error: {e}"),
        }
    }
}

impl std::error::Error for QueueError {}
/// Retry behaviour applied when a job handler returns an error.
#[derive(Debug, Clone)]
pub enum RetryPolicy {
    /// Fail immediately; never retry.
    None,
    /// Retry up to `max_retries` times with a constant delay between attempts.
    Fixed {
        max_retries: u32,
        delay: Duration,
    },
    /// Retry up to `max_retries` times, doubling the delay on each attempt
    /// (`base_delay`, 2×, 4×, ...).
    Exponential {
        max_retries: u32,
        base_delay: Duration,
    },
}

impl RetryPolicy {
    /// Constant-delay retry policy.
    pub fn fixed(max_retries: u32, delay: Duration) -> Self {
        Self::Fixed { max_retries, delay }
    }

    /// Exponential-backoff retry policy starting at `base_delay`.
    pub fn exponential(max_retries: u32, base_delay: Duration) -> Self {
        Self::Exponential {
            max_retries,
            base_delay,
        }
    }

    /// Maximum number of retries before a job is dead-lettered.
    fn max_retries(&self) -> u32 {
        match self {
            Self::None => 0,
            Self::Fixed { max_retries, .. } | Self::Exponential { max_retries, .. } => *max_retries,
        }
    }

    /// Delay to wait before re-running a job whose 0-based attempt number was
    /// `attempt`. Saturates at `Duration::MAX` instead of panicking when the
    /// exponential backoff overflows.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        match self {
            Self::None => Duration::ZERO,
            Self::Fixed { delay, .. } => *delay,
            // BUGFIX: `Duration * u32` panics on overflow, so a large base
            // delay (or high attempt count) could take down a worker task.
            // `saturating_mul` caps at Duration::MAX instead.
            Self::Exponential { base_delay, .. } => {
                base_delay.saturating_mul(2u32.saturating_pow(attempt))
            }
        }
    }
}

impl Default for RetryPolicy {
    /// Jobs are not retried unless a policy is configured.
    fn default() -> Self {
        Self::None
    }
}
/// A unit of work handed to a registered handler.
pub struct Job {
    /// JSON-serialized payload, as produced at push time.
    pub(crate) payload: Vec<u8>,
    /// Name the job was pushed under; selects the handler.
    pub name: String,
    /// 0-based attempt counter; incremented on each retry.
    pub attempt: u32,
    /// Queue-assigned id, unique and increasing per queue instance.
    pub id: u64,
}
impl Job {
    /// Deserialize the JSON payload into `T`.
    ///
    /// # Errors
    /// Returns `QueueError::HandlerError` when the payload does not parse as `T`.
    pub fn deserialize<T: serde::de::DeserializeOwned>(&self) -> Result<T, QueueError> {
        match serde_json::from_slice(&self.payload) {
            Ok(value) => Ok(value),
            Err(e) => Err(QueueError::HandlerError(e.to_string())),
        }
    }

    /// Borrow the raw serialized payload bytes.
    pub fn raw_payload(&self) -> &[u8] {
        self.payload.as_slice()
    }
}
/// A job that permanently failed (retries exhausted or no handler registered)
/// and was moved to the dead-letter store.
#[derive(Debug, Clone)]
pub struct DeadJob {
    /// Id assigned when the job was pushed.
    pub id: u64,
    /// Name the job was pushed under.
    pub name: String,
    /// Raw serialized payload.
    pub payload: Vec<u8>,
    /// Total attempts made (initial run plus retries).
    pub attempts: u32,
    /// Final error message (or "no handler registered").
    pub error: String,
    /// When the job was dead-lettered.
    pub failed_at: Instant,
}
/// Internal queue entry; converted into a [`Job`] when handed to a handler.
struct PendingJob {
    id: u64,
    name: String,
    payload: Vec<u8>,
    /// 0-based attempt counter carried across retries.
    attempt: u32,
    /// When set, the job must not run before this instant
    /// (push_delayed or retry backoff).
    run_after: Option<Instant>,
}
/// Type-erased async job handler: takes a [`Job`] and returns a boxed future.
/// Stored behind `Arc` so workers can cheaply clone it out of the registry.
type BoxHandler =
    Arc<dyn Fn(Job) -> Pin<Box<dyn Future<Output = Result<(), QueueError>> + Send>> + Send + Sync>;
/// Shared state behind every clone of a [`Queue`] handle.
struct QueueInner {
    /// FIFO of jobs waiting to run (delayed jobs stay here until due).
    pending: Mutex<VecDeque<PendingJob>>,
    /// Job-name -> handler registry.
    handlers: SccHashMap<String, BoxHandler>,
    /// Jobs that permanently failed.
    dead_letters: Mutex<Vec<DeadJob>>,
    /// Wakes a worker when work is pushed or re-queued.
    notify: Notify,
    /// Source of unique job ids (starts at 1).
    next_id: AtomicU64,
    /// Number of worker tasks spawned by `start()`.
    num_workers: usize,
    /// Retry policy applied when handlers fail.
    retry_policy: RetryPolicy,
    /// Set once `shutdown()` is called; rejects further pushes.
    shutdown: AtomicBool,
    /// Count of jobs whose handler is currently executing.
    inflight: AtomicU64,
    /// Signaled when the last in-flight job finishes during shutdown.
    drain_notify: Notify,
}
/// Cloneable handle to an in-process async job queue.
///
/// All clones share the same underlying state; cloning is an `Arc` bump.
#[derive(Clone)]
pub struct Queue {
    inner: Arc<QueueInner>,
}
/// Builder for [`Queue`]; obtained via `Queue::builder()`.
pub struct QueueBuilder {
    /// Worker task count (clamped to at least 1 by `workers()`).
    workers: usize,
    /// Retry policy for failing jobs.
    retry: RetryPolicy,
}
impl QueueBuilder {
pub fn workers(mut self, n: usize) -> Self {
self.workers = n.max(1);
self
}
pub fn retry(mut self, policy: RetryPolicy) -> Self {
self.retry = policy;
self
}
pub fn build(self) -> Queue {
Queue {
inner: Arc::new(QueueInner {
pending: Mutex::new(VecDeque::new()),
handlers: SccHashMap::new(),
dead_letters: Mutex::new(Vec::new()),
notify: Notify::new(),
next_id: AtomicU64::new(1),
num_workers: self.workers,
retry_policy: self.retry,
shutdown: AtomicBool::new(false),
inflight: AtomicU64::new(0),
drain_notify: Notify::new(),
}),
}
}
}
impl Queue {
    /// Create a queue with the default configuration (4 workers, no retries).
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Begin configuring a queue; defaults: 4 workers, `RetryPolicy::None`.
    pub fn builder() -> QueueBuilder {
        QueueBuilder {
            workers: 4,
            retry: RetryPolicy::default(),
        }
    }

    /// Register the async handler invoked for jobs pushed under `name`.
    ///
    /// NOTE(review): scc's `insert_sync` does not overwrite an existing
    /// entry, so re-registering the same name appears to silently keep the
    /// first handler — confirm that is the intended semantics.
    pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
    where
        F: Fn(Job) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), QueueError>> + Send + 'static,
    {
        let name = name.into();
        // Erase the handler's concrete future type for uniform storage.
        let handler: BoxHandler = Arc::new(move |job| Box::pin(handler(job)));
        let _ = self.inner.handlers.insert_sync(name, handler);
    }

    /// Serialize `payload` as JSON and enqueue it under `name`; returns the
    /// assigned job id.
    ///
    /// # Errors
    /// `QueueError::Shutdown` after `shutdown()` was called;
    /// `QueueError::SerializeError` if the payload fails to serialize.
    pub async fn push(
        &self,
        name: impl Into<String>,
        payload: &(impl serde::Serialize + ?Sized),
    ) -> Result<u64, QueueError> {
        self.push_inner(name.into(), payload, None)
    }

    /// Like [`Queue::push`], but the job only becomes runnable after `delay`.
    pub async fn push_delayed(
        &self,
        name: impl Into<String>,
        payload: &(impl serde::Serialize + ?Sized),
        delay: Duration,
    ) -> Result<u64, QueueError> {
        self.push_inner(name.into(), payload, Some(Instant::now() + delay))
    }

    /// Shared enqueue path: reject if shut down, serialize, assign an id,
    /// append to the pending deque, and wake one worker.
    fn push_inner(
        &self,
        name: String,
        payload: &(impl serde::Serialize + ?Sized),
        run_after: Option<Instant>,
    ) -> Result<u64, QueueError> {
        if self.inner.shutdown.load(Ordering::SeqCst) {
            return Err(QueueError::Shutdown);
        }
        let bytes =
            serde_json::to_vec(payload).map_err(|e| QueueError::SerializeError(e.to_string()))?;
        // Ids start at 1 and strictly increase per queue instance.
        let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
        self.inner.pending.lock().push_back(PendingJob {
            id,
            name,
            payload: bytes,
            attempt: 0,
            run_after,
        });
        self.inner.notify.notify_one();
        Ok(id)
    }

    /// Spawn the configured number of worker tasks on the tokio runtime.
    /// Must be called from within a runtime context (`tokio::spawn` panics
    /// otherwise).
    #[cfg(not(feature = "compio"))]
    pub fn start(&self) {
        for _ in 0..self.inner.num_workers {
            let inner = self.inner.clone();
            tokio::spawn(async move { worker_loop(inner).await });
        }
        tracing::debug!("Queue started with {} workers", self.inner.num_workers);
    }

    /// Spawn the configured number of worker tasks on the compio runtime.
    /// Tasks are detached so they keep running after the handle is dropped.
    #[cfg(feature = "compio")]
    pub fn start(&self) {
        for _ in 0..self.inner.num_workers {
            let inner = self.inner.clone();
            compio::runtime::spawn(async move { worker_loop(inner).await }).detach();
        }
        tracing::debug!("Queue started with {} workers", self.inner.num_workers);
    }

    /// Stop accepting new jobs and wait up to `timeout` for in-flight
    /// handlers to finish.
    ///
    /// NOTE(review): this waits only for handlers that are already running,
    /// not for jobs still sitting in the pending deque — confirm that
    /// queued-but-unstarted jobs are intentionally abandoned at shutdown.
    pub async fn shutdown(&self, timeout: Duration) {
        self.inner.shutdown.store(true, Ordering::SeqCst);
        // Wake every worker once so each can observe the shutdown flag.
        for _ in 0..self.inner.num_workers {
            self.inner.notify.notify_one();
        }
        if self.inner.inflight.load(Ordering::SeqCst) > 0 {
            #[cfg(not(feature = "compio"))]
            {
                let _ = tokio::time::timeout(timeout, self.inner.drain_notify.notified()).await;
            }
            #[cfg(feature = "compio")]
            {
                let drain = std::pin::pin!(self.inner.drain_notify.notified());
                let sleep = std::pin::pin!(compio::time::sleep(timeout));
                let _ = futures_util::future::select(drain, sleep).await;
            }
        }
        tracing::debug!("Queue shut down");
    }

    /// Snapshot of all dead-lettered jobs (cloned out of the store).
    pub fn dead_letters(&self) -> Vec<DeadJob> {
        self.inner.dead_letters.lock().clone()
    }

    /// Discard all dead-lettered jobs.
    pub fn clear_dead_letters(&self) {
        self.inner.dead_letters.lock().clear();
    }

    /// Number of jobs waiting in the queue (including not-yet-due delayed jobs).
    pub fn pending_count(&self) -> usize {
        self.inner.pending.lock().len()
    }

    /// Number of jobs whose handler is currently executing.
    pub fn inflight_count(&self) -> u64 {
        self.inner.inflight.load(Ordering::SeqCst)
    }
}
impl Default for Queue {
fn default() -> Self {
Self::new()
}
}
/// Body of one worker task: wait for work, pop one runnable job, run its
/// handler, then perform retry / dead-letter bookkeeping.
///
/// Workers wake at least every 100ms even without a notification; that
/// polling is what eventually runs delayed (`run_after`) jobs and lets the
/// loop observe the shutdown flag.
async fn worker_loop(inner: Arc<QueueInner>) {
    loop {
        // Wait for a push/retry notification, capped at 100ms.
        #[cfg(not(feature = "compio"))]
        {
            let _ = tokio::time::timeout(Duration::from_millis(100), inner.notify.notified()).await;
        }
        #[cfg(feature = "compio")]
        {
            let notified = std::pin::pin!(inner.notify.notified());
            let sleep = std::pin::pin!(compio::time::sleep(Duration::from_millis(100)));
            let _ = futures_util::future::select(notified, sleep).await;
        }
        // Exit only once shutdown was requested AND the backlog is empty.
        // NOTE(review): a delayed job whose run_after is far in the future
        // keeps `pending` non-empty and therefore keeps workers alive after
        // shutdown — confirm that is intended.
        if inner.shutdown.load(Ordering::SeqCst) && inner.pending.lock().is_empty() {
            break;
        }
        // Pop the first job that is due (no run_after, or run_after passed).
        // The pending lock is released before the handler runs.
        let job = {
            let mut pending = inner.pending.lock();
            let now = Instant::now();
            let pos = pending.iter().position(|j| match j.run_after {
                Some(t) => now >= t,
                None => true,
            });
            pos.and_then(|i| pending.remove(i))
        };
        let Some(pending_job) = job else {
            continue;
        };
        // Clone the Arc'd handler out of the registry so no map entry is
        // held while the job executes.
        let handler = inner
            .handlers
            .get_async(&pending_job.name)
            .await
            .map(|e| e.get().clone());
        let Some(handler) = handler else {
            // No handler registered: dead-letter immediately, no retries.
            tracing::warn!("No handler for job '{}', moving to DLQ", pending_job.name);
            inner.dead_letters.lock().push(DeadJob {
                id: pending_job.id,
                name: pending_job.name,
                payload: pending_job.payload,
                attempts: pending_job.attempt + 1,
                error: "no handler registered".into(),
                failed_at: Instant::now(),
            });
            continue;
        };
        // Track in-flight work so shutdown() can wait for it to drain.
        // NOTE(review): if the handler panics, this counter is never
        // decremented and shutdown() waits out its full timeout — a drop
        // guard would make this panic-safe.
        inner.inflight.fetch_add(1, Ordering::SeqCst);
        let job = Job {
            payload: pending_job.payload.clone(),
            name: pending_job.name.clone(),
            attempt: pending_job.attempt,
            id: pending_job.id,
        };
        let result = handler(job).await;
        if let Err(e) = result {
            let max_retries = inner.retry_policy.max_retries();
            if pending_job.attempt < max_retries {
                // Re-enqueue with the policy's backoff for this attempt.
                let next_attempt = pending_job.attempt + 1;
                let delay = inner.retry_policy.delay_for_attempt(pending_job.attempt);
                tracing::debug!(
                    "Job '{}' (id={}) failed (attempt {}/{}), retrying in {:?}",
                    pending_job.name,
                    pending_job.id,
                    next_attempt,
                    max_retries,
                    delay
                );
                inner.pending.lock().push_back(PendingJob {
                    id: pending_job.id,
                    name: pending_job.name,
                    payload: pending_job.payload,
                    attempt: next_attempt,
                    run_after: Some(Instant::now() + delay),
                });
                inner.notify.notify_one();
            } else {
                // Retries exhausted: record the failure in the dead-letter store.
                tracing::warn!(
                    "Job '{}' (id={}) exhausted {} retries, moving to DLQ: {}",
                    pending_job.name,
                    pending_job.id,
                    max_retries,
                    e
                );
                inner.dead_letters.lock().push(DeadJob {
                    id: pending_job.id,
                    name: pending_job.name,
                    payload: pending_job.payload,
                    attempts: pending_job.attempt + 1,
                    error: e.to_string(),
                    failed_at: Instant::now(),
                });
            }
        }
        // The last in-flight job finishing during shutdown wakes the drainer.
        let prev = inner.inflight.fetch_sub(1, Ordering::SeqCst);
        if prev == 1 && inner.shutdown.load(Ordering::SeqCst) {
            inner.drain_notify.notify_one();
        }
    }
}