use std::{
future::Future,
marker::PhantomData,
mem,
ops::Deref,
pin::Pin,
result::Result as StdResult,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::Poll,
};
use builder_states::{Initial, PoolSet, QueueNameSet, QueueSet, StateSet, StepSet};
use jiff::Span;
use sealed::JobState;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sqlx::{PgExecutor, PgPool, Postgres, Transaction};
use tokio::task::{JoinError, JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
queue::{Error as QueueError, InProgressTask, Queue},
scheduler::{Error as SchedulerError, Scheduler, ZonedSchedule},
task::{
Error as TaskError, Result as TaskResult, RetryPolicy, State as TaskState, Task, TaskId,
},
worker::{Error as WorkerError, Worker},
};
type Result<T = ()> = StdResult<T, Error>;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Queue(#[from] QueueError),
#[error(transparent)]
Task(#[from] TaskError),
#[error(transparent)]
Worker(#[from] WorkerError),
#[error(transparent)]
Scheduler(#[from] SchedulerError),
#[error(transparent)]
Join(#[from] JoinError),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Database(#[from] sqlx::Error),
}
type JobQueue<T, S> = Queue<Job<T, S>>;
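/// Execution context threaded through each step of a job.
///
/// Carries the shared state, the step's database transaction, the step's
/// position within the job, the job's ID, and the owning queue's name.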
pub struct Context<S> {
pub state: S,
pub tx: Transaction<'static, Postgres>,
pub step_index: usize,
pub step_count: usize,
pub job_id: JobId,
pub queue_name: String,
}
type StepConfig<S> = (Box<dyn StepExecutor<S>>, RetryPolicy);
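// `JobState` lives in a private module so code outside the crate cannot
// construct or alter a job's persisted step state directly.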
mod sealed {
use serde::{Deserialize, Serialize};
use super::JobId;
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct JobState {
pub step_index: usize,
pub step_input: serde_json::Value,
pub(crate) job_id: JobId,
}
}
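/// Handle to a running job's worker and scheduler tasks.
///
/// Await the handle to join both tasks, or call `shutdown` to cancel them via
/// the shared cancellation token and then wait for completion.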
pub struct JobHandle {
workers: JoinSet<StdResult<Result<()>, JoinError>>,
shutdown_token: CancellationToken,
}
impl JobHandle {
pub async fn shutdown(mut self) -> Result<()> {
self.shutdown_token.cancel();
while let Some(result) = self.workers.join_next().await {
match result? {
Ok(Ok(())) => {}
Ok(Err(err)) => return Err(err),
Err(join_err) => return Err(join_err.into()),
}
}
Ok(())
}
}
impl Unpin for JobHandle {}
impl Future for JobHandle {
type Output = StdResult<Result<()>, JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
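// Drain every task that has already completed, surfacing the first
// failure; stay pending while any task is still running.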
while let Poll::Ready(Some(result)) = self.workers.poll_join_next(cx) {
match result? {
Ok(Ok(())) => continue,
Ok(Err(err)) => return Poll::Ready(Ok(Err(err))),
Err(join_err) => return Poll::Ready(Err(join_err)),
}
}
if self.workers.is_empty() {
Poll::Ready(Ok(Ok(())))
} else {
Poll::Pending
}
}
}
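/// Unique identifier of an enqueued job.
///
/// Backed by a UUID derived from a freshly generated ULID, so IDs remain
/// roughly sortable by creation time.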
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub struct JobId(Uuid);
impl JobId {
fn new() -> Self {
Self(Ulid::new().into())
}
}
impl Deref for JobId {
type Target = Uuid;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct EnqueuedJob<T: Task> {
id: JobId,
queue: Arc<Queue<T>>,
}
impl<T: Task> EnqueuedJob<T> {
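/// Attempts to cancel this job's pending tasks.
///
/// Pending tasks matching the job's ID are locked with `FOR UPDATE SKIP
/// LOCKED` and marked cancelled within a single transaction. Returns `true`
/// if at least one task was cancelled.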
pub async fn cancel(&self) -> Result<bool> {
let mut tx = self.queue.pool.begin().await?;
let in_progress_tasks = sqlx::query_as!(
InProgressTask,
r#"
select
id as "id: TaskId",
task_queue_name as "queue_name",
input,
retry_policy as "retry_policy: RetryPolicy",
timeout,
heartbeat,
concurrency_key
from underway.task
where input->>'job_id' = $1
and state = $2
for update skip locked
"#,
self.id.to_string(),
TaskState::Pending as TaskState
)
.fetch_all(&mut *tx)
.await?;
let mut cancelled = false;
for in_progress_task in in_progress_tasks {
if in_progress_task.mark_cancelled(&mut tx).await? {
cancelled = true;
}
}
tx.commit().await?;
Ok(cancelled)
}
}
pub struct Job<I, S>
where
I: Sync + Send + 'static,
S: Clone + Sync + Send + 'static,
{
queue: Arc<JobQueue<I, S>>,
steps: Arc<Vec<StepConfig<S>>>,
state: S,
current_index: Arc<AtomicUsize>,
_marker: PhantomData<I>,
}
impl<I, S> Job<I, S>
where
I: Serialize + Sync + Send + 'static,
S: Clone + Send + Sync + 'static,
{
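/// Creates a builder for defining this job type's steps and queue.
///
/// A minimal sketch, assuming `Job` and `To` are re-exported at the crate
/// root and that a connected `PgPool` with the crate's migrations applied is
/// available:
///
/// ```no_run
/// # use sqlx::PgPool;
/// # use underway::{Job, To};
/// # async fn example(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
/// let job = Job::builder()
///     .step(|_cx, ()| async move { To::done() })
///     .name("example")
///     .pool(pool)
///     .build()
///     .await?;
/// job.enqueue(&()).await?;
/// # Ok(())
/// # }
/// ```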
pub fn builder() -> Builder<I, I, S, Initial> {
Builder::new()
}
pub async fn enqueue(&self, input: &I) -> Result<EnqueuedJob<Self>> {
let mut conn = self.queue.pool.acquire().await?;
self.enqueue_using(&mut *conn, input).await
}
pub async fn enqueue_using<'a, E>(&self, executor: E, input: &I) -> Result<EnqueuedJob<Self>>
where
E: PgExecutor<'a>,
{
self.enqueue_after_using(executor, input, Span::new()).await
}
pub async fn enqueue_after(&self, input: &I, delay: Span) -> Result<EnqueuedJob<Self>> {
let mut conn = self.queue.pool.acquire().await?;
self.enqueue_after_using(&mut *conn, input, delay).await
}
pub async fn enqueue_after_using<'a, E>(
&self,
executor: E,
input: &I,
delay: Span,
) -> Result<EnqueuedJob<Self>>
where
E: PgExecutor<'a>,
{
let job_input = self.first_job_input(input)?;
self.queue
.enqueue_after(executor, self, &job_input, delay)
.await?;
let enqueue = EnqueuedJob {
id: job_input.job_id,
queue: self.queue.clone(),
};
Ok(enqueue)
}
pub async fn schedule(&self, zoned_schedule: &ZonedSchedule, input: &I) -> Result {
let mut conn = self.queue.pool.acquire().await?;
self.schedule_using(&mut *conn, zoned_schedule, input).await
}
pub async fn schedule_using<'a, E>(
&self,
executor: E,
zoned_schedule: &ZonedSchedule,
input: &I,
) -> Result
where
E: PgExecutor<'a>,
{
let job_input = self.first_job_input(input)?;
self.queue
.schedule(executor, zoned_schedule, &job_input)
.await?;
Ok(())
}
pub async fn unschedule(&self) -> Result {
let mut conn = self.queue.pool.acquire().await?;
self.unschedule_using(&mut *conn).await
}
pub async fn unschedule_using<'a, E>(&self, executor: E) -> Result
where
E: PgExecutor<'a>,
{
self.queue.unschedule(executor).await?;
Ok(())
}
pub fn queue(&self) -> Arc<Queue<Self>> {
Arc::clone(&self.queue)
}
pub fn worker(&self) -> Worker<Self> {
Worker::new(self.queue(), self.clone())
}
pub fn scheduler(&self) -> Scheduler<Self> {
Scheduler::new(self.queue(), self.clone())
}
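/// Runs both the worker and the scheduler, returning on the first error.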
pub async fn run(&self) -> Result {
let worker = self.worker();
let scheduler = self.scheduler();
let mut workers = JoinSet::new();
workers.spawn(async move { worker.run().await.map_err(Error::from) });
workers.spawn(async move { scheduler.run().await.map_err(Error::from) });
while let Some(ret) = workers.join_next().await {
match ret {
Ok(Err(err)) => return Err(err),
Err(err) => return Err(Error::from(err)),
_ => continue,
}
}
Ok(())
}
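/// Starts the worker and scheduler and returns a handle to them.
///
/// Both are spawned directly onto the runtime as detached tasks, with the
/// `JoinSet` only tracking their join handles; dropping the returned
/// `JobHandle` therefore does not stop them. Use `JobHandle::shutdown` (or
/// await the handle) to manage their lifecycle.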
pub fn start(&self) -> JobHandle {
let shutdown_token = CancellationToken::new();
let mut workers = JoinSet::new();
let mut worker = self.worker();
worker.set_shutdown_token(shutdown_token.clone());
let mut scheduler = self.scheduler();
scheduler.set_shutdown_token(shutdown_token.clone());
let worker_handle = tokio::spawn(async move { worker.run().await.map_err(Error::from) });
let scheduler_handle =
tokio::spawn(async move { scheduler.run().await.map_err(Error::from) });
workers.spawn(worker_handle);
workers.spawn(scheduler_handle);
JobHandle {
workers,
shutdown_token,
}
}
fn first_job_input(&self, input: &I) -> Result<JobState> {
let step_input = serde_json::to_value(input)?;
let step_index = self.current_index.load(Ordering::SeqCst);
let job_id = JobId::new();
Ok(JobState {
step_input,
step_index,
job_id,
})
}
}
impl<I, S> Task for Job<I, S>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
type Input = JobState;
type Output = ();
#[instrument(
skip_all,
fields(
job.id = %input.job_id.as_hyphenated(),
step = input.step_index + 1,
steps = self.steps.len()
),
err
)]
async fn execute(
&self,
mut tx: Transaction<'_, Postgres>,
input: Self::Input,
) -> TaskResult<Self::Output> {
let JobState {
step_index,
step_input,
job_id,
} = input;
if step_index >= self.steps.len() {
return Err(TaskError::Fatal("Invalid step index.".into()));
}
let (step, _) = &self.steps[step_index];
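// SAFETY: `transmute_copy` creates a second owned handle to the same
// transaction with its lifetime extended to `'static` so it can be stored
// in `Context`. This is sound only while `tx` outlives the step's
// execution and the commit below remains the transaction's final use.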
let step_tx: Transaction<'static, Postgres> = unsafe { mem::transmute_copy(&tx) };
let cx = Context {
state: self.state.clone(),
tx: step_tx,
step_index,
job_id,
step_count: self.steps.len(),
queue_name: self.queue.name.clone(),
};
let step_result = match step.execute_step(cx, step_input).await {
Ok(result) => result,
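// The transaction is committed even on failure, persisting any
// changes the step made before erroring.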
Err(err) => {
tx.commit().await?;
return Err(err);
}
};
if let Some((next_input, delay)) = step_result {
let next_index = step_index + 1;
self.current_index.store(next_index, Ordering::SeqCst);
let next_job_input = JobState {
step_input: next_input,
step_index: next_index,
job_id,
};
self.queue
.enqueue_after(&mut *tx, self, &next_job_input, delay)
.await
.map_err(|err| TaskError::Retryable(err.to_string()))?;
}
tx.commit().await?;
Ok(())
}
fn retry_policy(&self) -> RetryPolicy {
let current_index = self.current_index.load(Ordering::SeqCst);
let (_, retry_policy) = self.steps[current_index];
retry_policy
}
}
impl<I, S> Clone for Job<I, S>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
state: self.state.clone(),
steps: self.steps.clone(),
current_index: self.current_index.clone(),
_marker: PhantomData,
}
}
}
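/// Flow control returned by a step: continue to the next step, continue
/// after a delay, or mark the job done.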
#[derive(Deserialize, Serialize)]
pub enum To<N> {
Next(N),
Delay {
next: N,
delay: Span,
},
Done,
}
impl<N> To<N> {
pub fn next(step: N) -> TaskResult<Self> {
Ok(Self::Next(step))
}
pub fn delay_for(step: N, delay: Span) -> TaskResult<Self> {
Ok(Self::Delay { next: step, delay })
}
}
impl To<()> {
pub fn done() -> TaskResult<To<()>> {
Ok(To::Done)
}
}
struct StepFn<I, O, S, F>
where
F: Fn(Context<S>, I) -> Pin<Box<dyn Future<Output = TaskResult<To<O>>> + Send>>
+ Send
+ Sync
+ 'static,
{
func: Arc<F>,
_marker: PhantomData<(I, O, S)>,
}
impl<I, O, S, F> StepFn<I, O, S, F>
where
F: Fn(Context<S>, I) -> Pin<Box<dyn Future<Output = TaskResult<To<O>>> + Send>>
+ Send
+ Sync
+ 'static,
{
fn new(func: F) -> Self {
Self {
func: Arc::new(func),
_marker: PhantomData,
}
}
}
type StepResult = TaskResult<Option<(serde_json::Value, Span)>>;
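/// Object-safe wrapper over a single step. Inputs and outputs are
/// type-erased as `serde_json::Value` so steps with different input and
/// output types can be stored together in one `Vec`.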
trait StepExecutor<S>: Send + Sync {
fn execute_step(
&self,
cx: Context<S>,
input: serde_json::Value,
) -> Pin<Box<dyn Future<Output = StepResult> + Send>>;
}
impl<I, O, S, F> StepExecutor<S> for StepFn<I, O, S, F>
where
I: DeserializeOwned + Serialize + Send + Sync + 'static,
O: Serialize + Send + Sync + 'static,
S: Send + Sync + 'static,
F: Fn(Context<S>, I) -> Pin<Box<dyn Future<Output = TaskResult<To<O>>> + Send>>
+ Send
+ Sync
+ 'static,
{
fn execute_step(
&self,
cx: Context<S>,
input: serde_json::Value,
) -> Pin<Box<dyn Future<Output = StepResult> + Send>> {
let deserialized_input: I = match serde_json::from_value(input) {
Ok(val) => val,
Err(e) => return Box::pin(async move { Err(TaskError::Fatal(e.to_string())) }),
};
let fut = (self.func)(cx, deserialized_input);
Box::pin(async move {
match fut.await {
Ok(To::Next(output)) => {
let serialized_output = serde_json::to_value(output)
.map_err(|e| TaskError::Fatal(e.to_string()))?;
Ok(Some((serialized_output, Span::new())))
}
Ok(To::Delay {
next: output,
delay,
}) => {
let serialized_output = serde_json::to_value(output)
.map_err(|e| TaskError::Fatal(e.to_string()))?;
Ok(Some((serialized_output, delay)))
}
Ok(To::Done) => Ok(None),
Err(e) => Err(e),
}
})
}
}
mod builder_states {
use std::marker::PhantomData;
use sqlx::PgPool;
use super::JobQueue;
pub struct Initial;
pub struct StateSet<S> {
pub state: S,
}
pub struct StepSet<Current, S> {
pub state: S,
pub _marker: PhantomData<Current>,
}
pub struct QueueSet<I, S>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub state: S,
pub queue: JobQueue<I, S>,
}
pub struct QueueNameSet<S> {
pub state: S,
pub queue_name: String,
}
pub struct PoolSet<S> {
pub state: S,
pub queue_name: String,
pub pool: PgPool,
}
}
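/// Typestate builder for [`Job`].
///
/// The `B` parameter tracks build progress: steps are defined first, then a
/// queue is attached, either directly via `queue` or built from `name` and
/// `pool`.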
pub struct Builder<I, O, S, B> {
builder_state: B,
steps: Vec<StepConfig<S>>,
_marker: PhantomData<(I, O, S)>,
}
impl<I, S> Default for Builder<I, I, S, Initial> {
fn default() -> Self {
Self::new()
}
}
impl<I, S> Builder<I, I, S, Initial> {
pub fn new() -> Builder<I, I, S, Initial> {
Builder::<I, I, S, _> {
builder_state: Initial,
steps: Vec::new(),
_marker: PhantomData,
}
}
pub fn state(self, state: S) -> Builder<I, I, S, StateSet<S>> {
Builder {
builder_state: StateSet { state },
steps: self.steps,
_marker: PhantomData,
}
}
pub fn step<F, O, Fut>(mut self, func: F) -> Builder<I, O, S, StepSet<O, ()>>
where
I: DeserializeOwned + Serialize + Send + Sync + 'static,
O: Serialize + Send + Sync + 'static,
S: Send + Sync + 'static,
F: Fn(Context<S>, I) -> Fut + Send + Sync + 'static,
Fut: Future<Output = TaskResult<To<O>>> + Send + 'static,
{
let step_fn = StepFn::new(move |cx, input| Box::pin(func(cx, input)));
self.steps.push((Box::new(step_fn), RetryPolicy::default()));
Builder {
builder_state: StepSet {
state: (),
_marker: PhantomData,
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, S> Builder<I, I, S, StateSet<S>> {
pub fn step<F, O, Fut>(mut self, func: F) -> Builder<I, O, S, StepSet<O, S>>
where
I: DeserializeOwned + Serialize + Send + Sync + 'static,
O: Serialize + Send + Sync + 'static,
S: Send + Sync + 'static,
F: Fn(Context<S>, I) -> Fut + Send + Sync + 'static,
Fut: Future<Output = TaskResult<To<O>>> + Send + 'static,
{
let step_fn = StepFn::new(move |cx, input| Box::pin(func(cx, input)));
self.steps.push((Box::new(step_fn), RetryPolicy::default()));
Builder {
builder_state: StepSet {
state: self.builder_state.state,
_marker: PhantomData,
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, Current, S> Builder<I, Current, S, StepSet<Current, S>> {
pub fn step<F, New, Fut>(mut self, func: F) -> Builder<I, New, S, StepSet<New, S>>
where
Current: DeserializeOwned + Serialize + Send + Sync + 'static,
New: Serialize + Send + Sync + 'static,
S: Send + Sync + 'static,
F: Fn(Context<S>, Current) -> Fut + Send + Sync + 'static,
Fut: Future<Output = TaskResult<To<New>>> + Send + 'static,
{
let step_fn = StepFn::new(move |cx, input| Box::pin(func(cx, input)));
self.steps.push((Box::new(step_fn), RetryPolicy::default()));
Builder {
builder_state: StepSet {
state: self.builder_state.state,
_marker: PhantomData,
},
steps: self.steps,
_marker: PhantomData,
}
}
pub fn retry_policy(
mut self,
retry_policy: RetryPolicy,
) -> Builder<I, Current, S, StepSet<Current, S>> {
let (_, default_policy) = self.steps.last_mut().expect("Steps should not be empty");
*default_policy = retry_policy;
Builder {
builder_state: StepSet {
state: self.builder_state.state,
_marker: PhantomData,
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, S> Builder<I, (), S, StepSet<(), S>>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn name(self, name: impl Into<String>) -> Builder<I, (), S, QueueNameSet<S>> {
Builder {
builder_state: QueueNameSet {
state: self.builder_state.state,
queue_name: name.into(),
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, S> Builder<I, (), S, QueueNameSet<S>>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn pool(self, pool: PgPool) -> Builder<I, (), S, PoolSet<S>> {
let QueueNameSet { queue_name, state } = self.builder_state;
Builder {
builder_state: PoolSet {
state,
queue_name,
pool,
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, S> Builder<I, (), S, PoolSet<S>>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub async fn build(self) -> Result<Job<I, S>> {
let PoolSet {
state,
queue_name,
pool,
} = self.builder_state;
let queue = Queue::builder().name(queue_name).pool(pool).build().await?;
Ok(Job {
queue: Arc::new(queue),
steps: Arc::new(self.steps),
state,
current_index: Arc::new(AtomicUsize::new(0)),
_marker: PhantomData,
})
}
}
impl<I, S> Builder<I, (), S, StepSet<(), S>>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn queue(self, queue: JobQueue<I, S>) -> Builder<I, (), S, QueueSet<I, S>> {
Builder {
builder_state: QueueSet {
state: self.builder_state.state,
queue,
},
steps: self.steps,
_marker: PhantomData,
}
}
}
impl<I, S> Builder<I, (), S, QueueSet<I, S>>
where
I: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn build(self) -> Job<I, S> {
let QueueSet { state, queue } = self.builder_state;
Job {
queue: Arc::new(queue),
steps: Arc::new(self.steps),
state,
current_index: Arc::new(AtomicUsize::new(0)),
_marker: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use super::*;
use crate::queue::graceful_shutdown;
#[sqlx::test]
async fn one_step(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("one_step")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Input { message }| async move {
println!("Executing job with message: {message}");
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Input {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 0);
assert_eq!(job_state.step_input, serde_json::to_value(&input)?);
Ok(())
}
#[sqlx::test]
async fn one_step_named(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
async fn step(_cx: Context<()>, Input { message }: Input) -> TaskResult<To<()>> {
println!("Executing job with message: {message}");
To::done()
}
let queue = Queue::builder()
.name("one_step_named")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder().step(step).queue(queue.clone()).build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Input {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 0);
assert_eq!(job_state.step_input, serde_json::to_value(&input)?);
Ok(())
}
#[sqlx::test]
async fn one_step_with_state(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Clone)]
struct State {
data: String,
}
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("one_step_with_state")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.state(State {
data: "data".to_string(),
})
.step(|cx, Input { message }| async move {
println!(
"Executing job with message: {message} and state: {state}",
state = cx.state.data
);
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Input {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 0);
assert_eq!(job_state.step_input, serde_json::to_value(&input)?);
Ok(())
}
#[sqlx::test]
async fn one_step_with_mutable_state(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Clone)]
struct State {
data: Arc<Mutex<String>>,
}
let state = State {
data: Arc::new(Mutex::new("foo".to_string())),
};
let job = Job::builder()
.state(state.clone())
.step(|cx, _| async move {
let mut data = cx.state.data.lock().expect("Mutex should not be poisoned");
*data = "bar".to_string();
To::done()
})
.name("one_step_with_mutable_state")
.pool(pool.clone())
.build()
.await?;
job.enqueue(&()).await?;
job.start();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
assert_eq!(
*state.data.lock().expect("Mutex should not be poisoned"),
"bar".to_string()
);
tokio::spawn(async move { graceful_shutdown(&pool).await });
tokio::time::sleep(std::time::Duration::from_millis(250)).await;
Ok(())
}
#[sqlx::test]
async fn one_step_with_state_named(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Clone)]
struct State {
data: String,
}
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
async fn step(cx: Context<State>, Input { message }: Input) -> TaskResult<To<()>> {
println!(
"Executing job with message: {message} and state: {data}",
data = cx.state.data
);
To::done()
}
let state = State {
data: "data".to_string(),
};
let queue = Queue::builder()
.name("one_step_named")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.state(state)
.step(step)
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Input {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 0);
assert_eq!(job_state.step_input, serde_json::to_value(&input)?);
Ok(())
}
#[sqlx::test]
async fn one_step_enqueue(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("one_step_enqueue")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Input { message }| async move {
println!("Executing job with message: {message}");
To::done()
})
.queue(queue.clone())
.build();
let input = Input {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 0);
assert_eq!(job_state.step_input, serde_json::to_value(&input)?);
Ok(())
}
#[sqlx::test]
async fn one_step_schedule(pool: PgPool) -> sqlx::Result<(), Error> {
let queue = Queue::builder()
.name("one_step_schedule")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_, _| async { To::done() })
.queue(queue.clone())
.build();
let monthly = "@monthly[America/Los_Angeles]"
.parse()
.expect("A valid zoned scheduled should be provided");
job.schedule(&monthly, &()).await?;
let (schedule, _) = queue
.task_schedule(&pool)
.await?
.expect("A schedule should be set");
assert_eq!(
schedule,
"@monthly[America/Los_Angeles]"
.parse()
.expect("A valid zoned scheduled should be provided")
);
Ok(())
}
#[sqlx::test]
async fn one_step_context_attributes(pool: PgPool) -> sqlx::Result<(), Error> {
let job = Job::builder()
.step(|ctx, _| async move {
assert_eq!(ctx.step_index, 0);
assert_eq!(ctx.step_count, 1);
assert_eq!(ctx.queue_name.as_str(), "one_step_context_attributes");
To::done()
})
.name("one_step_context_attributes")
.pool(pool.clone())
.build()
.await?;
job.enqueue(&()).await?;
let task_id = job.worker().process_next_task().await?;
assert!(task_id.is_some());
Ok(())
}
#[sqlx::test]
async fn multi_step(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Step1 {
message: String,
}
#[derive(Serialize, Deserialize)]
struct Step2 {
data: Vec<u8>,
}
let queue = Queue::builder()
.name("multi_step")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Step1 { message }| async move {
println!("Executing job with message: {message}");
To::next(Step2 {
data: message.as_bytes().into(),
})
})
.step(|_cx, Step2 { data }| async move {
println!("Executing job with data: {data:?}");
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Step1 {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
job.worker().process_next_task().await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 1);
assert_eq!(
job_state.step_input,
serde_json::to_value(&Step2 {
data: "Hello, world!".as_bytes().to_vec()
})?
);
Ok(())
}
#[sqlx::test]
async fn multi_step_retry_policy(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Step1 {
message: String,
}
let step1_policy = RetryPolicy::builder().max_attempts(1).build();
#[derive(Serialize, Deserialize)]
struct Step2 {
data: Vec<u8>,
}
let step2_policy = RetryPolicy::builder().max_attempts(15).build();
let queue = Queue::builder()
.name("multi_step_retry_policy")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Step1 { message }| async move {
println!("Executing job with message: {message}");
To::next(Step2 {
data: message.as_bytes().into(),
})
})
.retry_policy(step1_policy)
.step(|_cx, Step2 { data }| async move {
println!("Executing job with data: {data:?}");
To::done()
})
.retry_policy(step2_policy)
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), step1_policy);
let input = Step1 {
message: "Hello, world!".to_string(),
};
let enqueued_job = job.enqueue(&input).await?;
let Some(dequeued_task) = queue.dequeue().await? else {
panic!("Task should exist");
};
assert_eq!(
enqueued_job.id,
dequeued_task
.input
.get("job_id")
.cloned()
.map(serde_json::from_value)
.expect("Failed to deserialize 'job_id'")?
);
assert_eq!(dequeued_task.retry_policy.max_attempts, 1);
sqlx::query!(
"update underway.task set state = $2 where id = $1",
dequeued_task.id as _,
TaskState::Pending as _
)
.execute(&pool)
.await?;
job.worker().process_next_task().await?;
let Some(dequeued_task) = queue.dequeue().await? else {
panic!("Next task should exist");
};
assert_eq!(dequeued_task.retry_policy.max_attempts, 15);
Ok(())
}
#[sqlx::test]
async fn multi_step_with_state(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Clone)]
struct State {
data: String,
}
#[derive(Serialize, Deserialize)]
struct Step1 {
message: String,
}
#[derive(Serialize, Deserialize)]
struct Step2 {
data: Vec<u8>,
}
let queue = Queue::builder()
.name("multi_step_with_state")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.state(State {
data: "data".to_string(),
})
.step(|cx, Step1 { message }| async move {
println!(
"Executing job with message: {message} and state: {state}",
state = cx.state.data
);
To::next(Step2 {
data: message.as_bytes().into(),
})
})
.step(|cx, Step2 { data }| async move {
println!(
"Executing job with data: {data:?} and state: {state}",
state = cx.state.data
);
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let input = Step1 {
message: "Hello, world!".to_string(),
};
job.enqueue(&input).await?;
job.worker().process_next_task().await?;
let pending_task = queue
.dequeue()
.await?
.expect("There should be an enqueued task");
let job_state: JobState = serde_json::from_value(pending_task.input)?;
assert_eq!(job_state.step_index, 1);
assert_eq!(
job_state.step_input,
serde_json::to_value(&Step2 {
data: "Hello, world!".as_bytes().to_vec()
})?
);
Ok(())
}
#[sqlx::test]
async fn multi_step_enqueue(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Step1 {
message: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Step2 {
data: Vec<u8>,
}
let queue = Queue::builder()
.name("multi_step_enqueue")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Step1 { message }| async move {
println!("Executing job with message: {message}",);
To::next(Step2 {
data: message.as_bytes().into(),
})
})
.step(|_cx, Step2 { data }| async move {
println!("Executing job with data: {data:?}");
To::done()
})
.queue(queue.clone())
.build();
let input = Step1 {
message: "Hello, world!".to_string(),
};
let enqueued_job = job.enqueue(&input).await?;
let Some(dequeued_task) = queue.dequeue().await? else {
panic!("Task should exist");
};
assert_eq!(
enqueued_job.id,
dequeued_task
.input
.get("job_id")
.cloned()
.map(serde_json::from_value)
.expect("Failed to deserialize 'job_id'")?
);
let job_state: JobState = serde_json::from_value(dequeued_task.input).unwrap();
assert_eq!(
JobState {
step_index: 0,
step_input: serde_json::to_value(input).unwrap(),
job_id: job_state.job_id
},
job_state
);
sqlx::query!(
"update underway.task set state = $2 where id = $1",
dequeued_task.id as _,
TaskState::Pending as _
)
.execute(&pool)
.await?;
job.worker().process_next_task().await?;
let Some(dequeued_task) = queue.dequeue().await? else {
panic!("Next task should exist");
};
let step2_input = Step2 {
data: "Hello, world!".to_string().as_bytes().into(),
};
let job_state: JobState = serde_json::from_value(dequeued_task.input).unwrap();
assert_eq!(
JobState {
step_index: 1,
step_input: serde_json::to_value(step2_input).unwrap(),
job_id: job_state.job_id
},
job_state
);
Ok(())
}
#[sqlx::test]
async fn schedule(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("schedule")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Input { message }| async move {
println!("Executing job with message: {message}");
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let daily = "@daily[America/Los_Angeles]"
.parse()
.expect("Schedule should parse");
let input = Input {
message: "Hello, world!".to_string(),
};
job.schedule(&daily, &input).await?;
let (zoned_schedule, schedule_input) = queue
.task_schedule(&pool)
.await?
.expect("Schedule should be set");
assert_eq!(zoned_schedule, daily);
assert_eq!(schedule_input.step_index, 0);
assert_eq!(schedule_input.step_input, serde_json::to_value(input)?);
Ok(())
}
#[sqlx::test]
async fn unschedule(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("unschedule")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Input { message }| async move {
println!("Executing job with message: {message}");
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
let daily = "@daily[America/Los_Angeles]"
.parse()
.expect("Schedule should parse");
let input = Input {
message: "Hello, world!".to_string(),
};
job.schedule(&daily, &input).await?;
job.unschedule().await?;
assert!(queue.task_schedule(&pool).await?.is_none());
Ok(())
}
#[sqlx::test]
async fn unschedule_without_schedule(pool: PgPool) -> sqlx::Result<(), Error> {
#[derive(Serialize, Deserialize)]
struct Input {
message: String,
}
let queue = Queue::builder()
.name("unschedule_without_schedule")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, Input { message }| async move {
println!("Executing job with message: {message}");
To::done()
})
.queue(queue.clone())
.build();
assert_eq!(job.retry_policy(), RetryPolicy::default());
assert!(job.unschedule().await.is_ok());
assert!(queue.task_schedule(&pool).await?.is_none());
Ok(())
}
#[sqlx::test]
async fn enqueued_job_cancel(pool: PgPool) -> sqlx::Result<(), Error> {
let queue = Queue::builder()
.name("enqueued_job_cancel")
.pool(pool.clone())
.build()
.await?;
let job = Job::builder()
.step(|_cx, _| async move { To::done() })
.queue(queue.clone())
.build();
let enqueued_job = job.enqueue(&()).await?;
assert!(enqueued_job.cancel().await?);
let task = sqlx::query!(
r#"
select state as "state: TaskState"
from underway.task
where input->>'job_id' = $1
"#,
enqueued_job.id.to_string()
)
.fetch_one(&pool)
.await?;
assert_eq!(task.state, TaskState::Cancelled);
assert!(!enqueued_job.cancel().await?);
Ok(())
}
}