use std::{sync::Arc, time::Duration};
use jiff::{Span, ToSpan};
use serde::Deserialize;
use sqlx::{
postgres::{types::PgInterval, PgListener, PgNotification},
Acquire, PgConnection,
};
use tokio::{sync::Semaphore, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use crate::{
queue::{shutdown_channel, Error as QueueError, InProgressTask, Queue},
task::{Error as TaskError, RetryCount, RetryPolicy, Task, TaskId},
};
pub(crate) type Result<T = ()> = std::result::Result<T, Error>;
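/// Errors a worker may produce while listening for, executing, or finalizing tasks.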
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Queue(#[from] QueueError),
#[error(transparent)]
Task(#[from] TaskError),
#[error(transparent)]
Database(#[from] sqlx::Error),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Jiff(#[from] jiff::Error),
}
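/// A worker that pulls tasks from a [`Queue`] and executes them.
///
/// Workers react to `task_change` notifications, poll on an interval as a
/// fallback, and process tasks concurrently up to a configurable limit.
/// Cloning a worker is cheap: the queue and task are shared via `Arc`.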
#[derive(Debug)]
pub struct Worker<T: Task> {
queue: Arc<Queue<T>>,
task: Arc<T>,
concurrency_limit: usize,
shutdown_token: CancellationToken,
}
impl<T: Task> Clone for Worker<T> {
fn clone(&self) -> Self {
Self {
queue: Arc::clone(&self.queue),
task: Arc::clone(&self.task),
concurrency_limit: self.concurrency_limit,
shutdown_token: self.shutdown_token.clone(),
}
}
}
impl<T: Task + Sync> Worker<T> {
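    /// Creates a new worker over the given queue and task.
    ///
    /// The concurrency limit defaults to the number of available CPUs and can
    /// be changed with [`set_concurrency_limit`](Self::set_concurrency_limit).
    ///
    /// A minimal sketch (not compiled as a doctest; `ExampleTask` stands in
    /// for any type implementing [`Task`]):
    ///
    /// ```ignore
    /// let queue = Arc::new(
    ///     Queue::builder()
    ///         .name("example")
    ///         .pool(pool.clone())
    ///         .build()
    ///         .await?,
    /// );
    /// let worker = Worker::new(queue, ExampleTask);
    /// worker.run().await?;
    /// ```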
pub fn new(queue: Arc<Queue<T>>, task: T) -> Self {
let task = Arc::new(task);
Self {
queue,
task,
concurrency_limit: num_cpus::get(),
shutdown_token: CancellationToken::new(),
}
}
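    /// Sets the maximum number of tasks this worker will process concurrently
    /// (defaults to the number of available CPUs).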
pub fn set_concurrency_limit(&mut self, concurrency_limit: usize) {
self.concurrency_limit = concurrency_limit;
}
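    /// Sets the cancellation token used to signal this worker to shut down.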
pub fn set_shutdown_token(&mut self, shutdown_token: CancellationToken) {
self.shutdown_token = shutdown_token;
}
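    /// Signals the worker to shut down by cancelling its shutdown token.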
pub fn shutdown(&self) {
self.shutdown_token.cancel();
}
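    /// Runs the worker, polling for new tasks once per minute.
    ///
    /// Equivalent to `run_every(1.minute())`.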
pub async fn run(&self) -> Result {
self.run_every(1.minute()).await
}
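    /// Runs the worker with the given polling period.
    ///
    /// Besides polling, the worker reacts immediately to `task_change`
    /// notifications and stops once a shutdown notification is received or
    /// the shutdown token is cancelled.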
#[instrument(skip(self), fields(queue.name = self.queue.name), err)]
pub async fn run_every(&self, period: Span) -> Result {
let mut polling_interval = tokio::time::interval(period.try_into()?);
let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
let chan = shutdown_channel();
shutdown_listener.listen(chan).await?;
let mut task_change_listener = PgListener::connect_with(&self.queue.pool).await?;
task_change_listener.listen("task_change").await?;
let concurrency_limit = Arc::new(Semaphore::new(self.concurrency_limit));
let mut processing_tasks = JoinSet::new();
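        // React to whichever happens first: a shutdown notification, token
        // cancellation, a task-change notification, or the polling tick.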
loop {
tokio::select! {
notify_shutdown = shutdown_listener.recv() => {
match notify_shutdown {
Ok(_) => {
self.shutdown_token.cancel();
},
Err(err) => {
tracing::error!(%err, "Postgres shutdown notification error");
}
}
}
_ = self.shutdown_token.cancelled() => {
self.handle_shutdown(&mut processing_tasks).await?;
break
}
notify_task_change = task_change_listener.recv() => {
match notify_task_change {
Ok(task_change) => self.handle_task_change(task_change, concurrency_limit.clone(), &mut processing_tasks).await?,
Err(err) => {
tracing::error!(%err, "Postgres task change notification error");
}
};
}
_ = polling_interval.tick() => {
self.trigger_task_processing(concurrency_limit.clone(), &mut processing_tasks).await;
}
}
}
Ok(())
}
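    /// Waits for in-flight processing tasks to finish, up to the task's timeout.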
async fn handle_shutdown(&self, processing_tasks: &mut JoinSet<()>) -> Result {
let task_timeout = self.task.timeout();
tracing::debug!(
task.timeout = ?task_timeout,
"Waiting for all processing tasks or timeout"
);
let shutdown_result = tokio::time::timeout(task_timeout.try_into()?, async {
while let Some(res) = processing_tasks.join_next().await {
if let Err(err) = res {
tracing::error!(%err, "A processing task failed during shutdown");
}
}
})
.await;
match shutdown_result {
Ok(_) => {
tracing::debug!("All processing tasks completed gracefully");
}
Err(_) => {
let remaining_tasks = processing_tasks.len();
tracing::warn!(
remaining_tasks,
"Reached task timeout before all tasks completed"
);
}
}
Ok(())
}
async fn handle_task_change(
&self,
task_change: PgNotification,
concurrency_limit: Arc<Semaphore>,
processing_tasks: &mut JoinSet<()>,
) -> Result {
let payload = task_change.payload();
        let decoded: TaskChange = match serde_json::from_str(payload) {
            Ok(decoded) => decoded,
            Err(err) => {
                // An unparseable payload shouldn't take the worker down; log and skip it.
                tracing::warn!(%err, "Invalid task change payload; ignoring");
                return Ok(());
            }
        };
if decoded.queue_name == self.queue.name {
self.trigger_task_processing(concurrency_limit, processing_tasks)
.await;
}
Ok(())
}
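    /// Spawns a task-processing loop if a concurrency permit is available; the
    /// loop keeps processing tasks until none remain or shutdown is requested.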
#[instrument(
skip_all,
fields(
processing = processing_tasks.len(),
permits = concurrency_limit.available_permits()
)
)]
async fn trigger_task_processing(
&self,
concurrency_limit: Arc<Semaphore>,
processing_tasks: &mut JoinSet<()>,
) {
let Ok(permit) = concurrency_limit.try_acquire_owned() else {
tracing::trace!("Concurrency limit reached");
return;
};
processing_tasks.spawn({
let worker = self.clone();
async move {
while !worker.shutdown_token.is_cancelled() {
match worker.process_next_task().await {
Err(err) => {
tracing::error!(err = %err, "Error processing next task");
continue;
}
Ok(Some(_)) => {
continue;
}
Ok(None) => {
tracing::trace!("No task found");
break;
}
}
}
drop(permit);
}
});
}
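    /// Dequeues and processes at most one task, returning its ID if one was
    /// processed.
    ///
    /// Execution runs inside a nested transaction and races against the task's
    /// timeout; failures are retried or finalized according to the task's
    /// retry policy.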
#[instrument(
skip(self),
fields(
queue.name = self.queue.name,
task.id = tracing::field::Empty,
),
err
)]
pub async fn process_next_task(&self) -> Result<Option<TaskId>> {
let Some(in_progress_task) = self.queue.dequeue().await? else {
return Ok(None);
};
let task_id = in_progress_task.id;
tracing::Span::current().record("task.id", task_id.as_hyphenated().to_string());
let mut tx = self.queue.pool.begin().await?;
if !in_progress_task.try_acquire_lock(&mut tx).await? {
return Ok(None);
}
let input: T::Input = serde_json::from_value(in_progress_task.input.clone())?;
let timeout = pg_interval_to_span(&in_progress_task.timeout)
.try_into()
.expect("Task timeout should be compatible with std::time");
let heartbeat = pg_interval_to_span(&in_progress_task.heartbeat)
.try_into()
.expect("Task heartbeat should be compatible with std::time");
let heartbeat_task = tokio::spawn({
let pool = self.queue.pool.clone();
let in_progress_task = in_progress_task.clone();
async move {
let mut heartbeat_interval = tokio::time::interval(heartbeat);
heartbeat_interval.tick().await;
loop {
tracing::trace!("Recording task heartbeat");
if let Err(err) = in_progress_task.record_heartbeat(&pool).await {
tracing::error!(err = %err, "Failed to record task heartbeat");
};
heartbeat_interval.tick().await;
}
}
});
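        // Run the task in a nested transaction (savepoint) so its writes roll
        // back independently if it fails or times out, while the bookkeeping
        // recorded on `tx` is preserved.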
let execute_tx = tx.begin().await?;
tokio::select! {
result = self.task.execute(execute_tx, input) => {
match result {
Ok(_) => {
in_progress_task.mark_succeeded(&mut tx).await?;
}
Err(ref error) => {
let retry_policy = &in_progress_task.retry_policy;
self.handle_task_error(&mut tx, &in_progress_task, retry_policy, error)
.await?;
}
}
}
_ = tokio::time::sleep(timeout) => {
tracing::error!("Task execution timed out");
let retry_policy = &in_progress_task.retry_policy;
self.handle_task_timeout(&mut tx, &in_progress_task, retry_policy, timeout).await?;
}
}
heartbeat_task.abort();
tx.commit().await?;
Ok(Some(task_id))
}
async fn handle_task_error(
&self,
conn: &mut PgConnection,
in_progress_task: &InProgressTask,
retry_policy: &RetryPolicy,
error: &TaskError,
) -> Result {
tracing::error!(err = %error, "Task execution encountered an error");
if matches!(error, TaskError::Fatal(_)) {
return self.finalize_task_failure(conn, in_progress_task).await;
}
in_progress_task.record_failure(conn, error).await?;
let retry_count = in_progress_task.retry_count(&mut *conn).await?;
if retry_count < retry_policy.max_attempts {
self.schedule_task_retry(conn, in_progress_task, retry_count, retry_policy)
.await?;
} else {
self.finalize_task_failure(conn, in_progress_task).await?;
}
Ok(())
}
async fn handle_task_timeout(
&self,
conn: &mut PgConnection,
in_progress_task: &InProgressTask,
retry_policy: &RetryPolicy,
timeout: Duration,
) -> Result {
tracing::error!("Task execution timed out");
let error = &TaskError::TimedOut(timeout.try_into()?);
in_progress_task.record_failure(&mut *conn, error).await?;
let retry_count = in_progress_task.retry_count(&mut *conn).await?;
if retry_count < retry_policy.max_attempts {
self.schedule_task_retry(conn, in_progress_task, retry_count, retry_policy)
.await?;
} else {
self.finalize_task_failure(conn, in_progress_task).await?;
}
Ok(())
}
async fn schedule_task_retry(
&self,
conn: &mut PgConnection,
in_progress_task: &InProgressTask,
retry_count: RetryCount,
retry_policy: &RetryPolicy,
) -> Result {
tracing::debug!("Retry policy available, scheduling retry");
let delay = retry_policy.calculate_delay(retry_count);
in_progress_task.retry_after(&mut *conn, delay).await?;
Ok(())
}
async fn finalize_task_failure(
&self,
conn: &mut PgConnection,
in_progress_task: &InProgressTask,
) -> Result {
tracing::debug!("Retry policy exhausted, handling failed task");
in_progress_task.mark_failed(&mut *conn).await?;
if let Some(dlq_name) = &self.queue.dlq_name {
self.queue
.move_task_to_dlq(&mut *conn, in_progress_task.id, dlq_name)
.await?;
}
Ok(())
}
}
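/// Payload of a `task_change` notification; used to check whether the change
/// concerns this worker's queue.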
#[derive(Debug, Deserialize)]
struct TaskChange {
#[serde(rename = "task_queue_name")]
queue_name: String,
}
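/// Converts a Postgres `interval` into a jiff [`Span`] composed of months,
/// days, and microseconds.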
pub(crate) fn pg_interval_to_span(
PgInterval {
months,
days,
microseconds,
}: &PgInterval,
) -> Span {
Span::new()
.months(*months)
.days(*days)
.microseconds(*microseconds)
}
#[cfg(test)]
mod tests {
use std::{
sync::Arc,
time::{Duration as StdDuration, Instant},
};
use sqlx::{PgPool, Postgres, Transaction};
use tokio::sync::Mutex;
use super::*;
use crate::{
queue::graceful_shutdown,
task::{Result as TaskResult, State as TaskState},
};
struct TestTask;
impl Task for TestTask {
type Input = ();
type Output = ();
async fn execute(
&self,
_: Transaction<'_, Postgres>,
_: Self::Input,
) -> TaskResult<Self::Output> {
Ok(())
}
}
#[derive(Clone)]
struct FailingTask {
fail_times: Arc<Mutex<u32>>,
}
impl Task for FailingTask {
type Input = ();
type Output = ();
async fn execute(
&self,
_: Transaction<'_, Postgres>,
_: Self::Input,
) -> TaskResult<Self::Output> {
let mut fail_times = self.fail_times.lock().await;
if *fail_times > 0 {
*fail_times -= 1;
Err(TaskError::Retryable("Simulated failure".into()))
} else {
Ok(())
}
}
}
#[sqlx::test]
async fn process_next_task(pool: PgPool) -> sqlx::Result<(), Error> {
let queue = Queue::builder()
.name("process_next_task")
.pool(pool.clone())
.build()
.await?;
let task = TestTask;
let task_id = queue.enqueue(&pool, &task, &()).await?;
let queue = Arc::new(queue);
let worker = Worker::new(queue.clone(), task);
let processed_task_id = worker
.process_next_task()
.await?
.expect("A task should be processed");
assert_eq!(task_id, processed_task_id);
let task_row = sqlx::query!(
r#"select state as "state: TaskState" from underway.task where id = $1"#,
task_id as TaskId
)
.fetch_one(&pool)
.await?;
assert_eq!(task_row.state, TaskState::Succeeded);
assert!(queue.dequeue().await?.is_none());
Ok(())
}
#[sqlx::test]
async fn process_retries(pool: PgPool) -> sqlx::Result<(), Error> {
let queue = Queue::builder()
.name("process_retries")
.pool(pool.clone())
.build()
.await?;
let fail_times = Arc::new(Mutex::new(2));
let task = FailingTask {
fail_times: fail_times.clone(),
};
let queue = Arc::new(queue);
let worker = Worker::new(queue.clone(), task.clone());
let task_id = queue.enqueue(&pool, &worker.task, &()).await?;
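        // The task fails twice before succeeding, so three attempts (the
        // initial run plus two retries) should drain `fail_times` to zero.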
for retries in 0..3 {
let delay = task.retry_policy().calculate_delay(retries);
tokio::time::sleep(delay.try_into()?).await;
let processed_task_id = worker
.process_next_task()
.await?
.expect("A task should be processed");
assert_eq!(task_id, processed_task_id);
}
let remaining_fail_times = *fail_times.lock().await;
assert_eq!(remaining_fail_times, 0);
let dequeued_task = sqlx::query!(
r#"
select state as "state: TaskState"
from underway.task
where id = $1
"#,
task_id as TaskId
)
.fetch_one(&pool)
.await?;
assert_eq!(dequeued_task.state, TaskState::Succeeded);
Ok(())
}
#[sqlx::test]
async fn gracefully_shutdown(pool: PgPool) -> sqlx::Result<(), Error> {
struct LongRunningTask;
impl Task for LongRunningTask {
type Input = ();
type Output = ();
async fn execute(
&self,
_: Transaction<'_, Postgres>,
_: Self::Input,
) -> TaskResult<Self::Output> {
tokio::time::sleep(StdDuration::from_secs(1)).await;
Ok(())
}
}
let queue = Queue::builder()
.name("gracefully_shutdown")
.pool(pool.clone())
.build()
.await?;
let queue = Arc::new(queue);
let worker = Worker::new(queue.clone(), LongRunningTask);
for _ in 0..2 {
let worker = worker.clone();
tokio::spawn(async move { worker.run().await });
}
tokio::time::sleep(StdDuration::from_secs(1)).await;
for _ in 0..5 {
queue.enqueue(&pool, &LongRunningTask, &()).await?;
}
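        // Request a graceful shutdown; in-flight tasks should be allowed to finish.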
graceful_shutdown(&pool).await?;
tokio::time::sleep(StdDuration::from_secs(2)).await;
let succeeded = sqlx::query_scalar!(
r#"
select count(*)
from underway.task
where state = $1
"#,
TaskState::Succeeded as _
)
.fetch_one(&pool)
.await?;
assert_eq!(succeeded, Some(5));
queue.enqueue(&pool, &LongRunningTask, &()).await?;
tokio::time::sleep(StdDuration::from_secs(1)).await;
let succeeded = sqlx::query_scalar!(
r#"
select count(*)
from underway.task
where state = $1
"#,
TaskState::Succeeded as _
)
.fetch_one(&pool)
.await?;
assert_eq!(succeeded, Some(5));
Ok(())
}
#[sqlx::test]
async fn heartbeat_stops_after_task_completion(pool: PgPool) -> sqlx::Result<(), Error> {
struct SleepTask;
impl Task for SleepTask {
type Input = ();
type Output = ();
async fn execute(
&self,
_: Transaction<'_, Postgres>,
_: Self::Input,
) -> TaskResult<Self::Output> {
tokio::time::sleep(StdDuration::from_secs(5)).await;
Ok(())
}
fn heartbeat(&self) -> Span {
1.second()
}
}
let queue = Queue::builder()
.name("heartbeat_stops_after_task_completion")
.pool(pool.clone())
.build()
.await?;
let task_id = queue.enqueue(&pool, &SleepTask, &()).await?;
let queue = Arc::new(queue);
let worker = Worker::new(queue.clone(), SleepTask);
let worker_handle = tokio::spawn(async move { worker.run_every(1.second()).await });
tokio::time::sleep(StdDuration::from_secs(1)).await;
let mut last_heartbeat_at = None;
let start_time = Instant::now();
while start_time.elapsed() < StdDuration::from_secs(6) {
let task_row = sqlx::query!(
r#"
select last_heartbeat_at as "last_heartbeat_at: i64"
from underway.task
where id = $1
and task_queue_name = $2
"#,
task_id as TaskId,
queue.name
)
.fetch_one(&pool)
.await?;
let current_heartbeat = task_row
.last_heartbeat_at
.expect("A heartbeat should be set");
if let Some(prev_heartbeat) = last_heartbeat_at {
assert!(current_heartbeat > prev_heartbeat);
}
last_heartbeat_at = Some(current_heartbeat);
tokio::time::sleep(StdDuration::from_secs(1)).await;
}
worker_handle.abort();
let final_task_row = sqlx::query!(
r#"
select last_heartbeat_at as "last_heartbeat_at: i64"
from underway.task
where id = $1
"#,
task_id as TaskId
)
.fetch_one(&pool)
.await?;
tokio::time::sleep(StdDuration::from_secs(2)).await;
let post_completion_task_row = sqlx::query!(
r#"
select last_heartbeat_at as "last_heartbeat_at: i64"
from underway.task
where id = $1
"#,
task_id as TaskId
)
.fetch_one(&pool)
.await?;
assert_eq!(
final_task_row.last_heartbeat_at,
post_completion_task_row.last_heartbeat_at
);
let task_state = sqlx::query_scalar!(
r#"
select state as "state: TaskState"
from underway.task
where id = $1
"#,
task_id as TaskId
)
.fetch_one(&pool)
.await?;
assert_eq!(task_state, TaskState::Succeeded);
Ok(())
}
}