use std::{
fmt::{self, Display},
future::Future,
ops::Deref,
result::Result as StdResult,
};
use jiff::{SignedDuration, Span, ToSpan};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sqlx::{Postgres, Transaction};
use ulid::Ulid;
use uuid::Uuid;
pub(crate) use self::retry_policy::RetryCount;
pub use self::retry_policy::RetryPolicy;
mod retry_policy;
/// Identifier for a task: a newtype wrapper around a [`Uuid`].
///
/// `#[sqlx(transparent)]` makes the type encode and decode in the database
/// exactly like the wrapped `Uuid`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
#[sqlx(transparent)]
pub struct TaskId(Uuid);
impl TaskId {
pub(crate) fn new() -> Self {
Self(Ulid::new().into())
}
}
impl Deref for TaskId {
type Target = Uuid;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Display for TaskId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Result alias whose error type is this module's [`Error`].
///
/// This is the return type of [`Task::execute`].
pub type Result<T> = StdResult<T, Error>;
/// Errors that a task execution can report.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// An error returned by `sqlx`.
    ///
    /// `#[error(transparent)]` forwards the message and source of the wrapped
    /// error, and `#[from]` lets `?` convert a `sqlx::Error` into this variant.
    #[error(transparent)]
    Database(#[from] sqlx::Error),

    /// The task timed out; carries the duration reported in the message.
    #[error("Task timed out after {0} during execution")]
    TimedOut(SignedDuration),

    /// A failure described by the contained message, marked as fatal.
    // NOTE(review): presumably fatal errors are never retried — the retry
    // handling is not visible in this file; confirm against the worker.
    #[error("{0}")]
    Fatal(String),

    /// A failure described by the contained message, marked as retryable.
    // NOTE(review): presumably retried according to the task's `RetryPolicy`
    // — confirm against the worker.
    #[error("{0}")]
    Retryable(String),
}
/// Converts a value into a task result whose error is either
/// [`Error::Retryable`] or [`Error::Fatal`].
///
/// Both methods consume `self` and yield a `Result` with the same success
/// type `T`.
pub trait ToTaskResult<T> {
    /// Converts `self` so that a failure is reported as retryable.
    fn retryable(self) -> StdResult<T, Error>;

    /// Converts `self` so that a failure is reported as fatal.
    fn fatal(self) -> StdResult<T, Error>;
}
impl<T, E: std::fmt::Display> ToTaskResult<T> for StdResult<T, E> {
fn retryable(self) -> StdResult<T, Error> {
self.map_err(|err| Error::Retryable(err.to_string()))
}
fn fatal(self) -> StdResult<T, Error> {
self.map_err(|err| Error::Fatal(err.to_string()))
}
}
/// A unit of work that can be executed with a typed input inside a Postgres
/// transaction.
///
/// Only [`Task::execute`] must be implemented; every other method has a
/// default that implementors may override.
// NOTE(review): how the configuration methods below are consumed (scheduling,
// retries, expiry) is decided by code outside this file — the per-method notes
// describe only the defaults visible here.
pub trait Task: Send + 'static {
    /// Input passed to [`Task::execute`]; must be serializable and
    /// deserializable with serde.
    type Input: DeserializeOwned + Serialize + Send + 'static;

    /// Value produced by a successful execution; must be serializable.
    type Output: Serialize + Send + 'static;

    /// Runs the task with `input`.
    ///
    /// The transaction is passed by value, so the implementation owns it and
    /// decides whether to commit; sqlx rolls back a transaction that is
    /// dropped without being committed.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] describing why the execution failed.
    fn execute(
        &self,
        tx: Transaction<'_, Postgres>,
        input: Self::Input,
    ) -> impl Future<Output = Result<Self::Output>> + Send;

    /// Retry policy for this task. Defaults to `RetryPolicy::default()`.
    fn retry_policy(&self) -> RetryPolicy {
        RetryPolicy::default()
    }

    /// Timeout for this task. Defaults to 15 minutes.
    fn timeout(&self) -> Span {
        15.minutes()
    }

    /// Time-to-live for this task. Defaults to 14 days.
    fn ttl(&self) -> Span {
        14.days()
    }

    /// Delay for this task. Defaults to an empty (zero) span.
    fn delay(&self) -> Span {
        Span::new()
    }

    /// Heartbeat interval for this task. Defaults to 30 seconds.
    fn heartbeat(&self) -> Span {
        30.seconds()
    }

    /// Optional concurrency key. Defaults to `None`.
    // NOTE(review): presumably tasks sharing a key are not run concurrently —
    // confirm against the queue implementation.
    fn concurrency_key(&self) -> Option<String> {
        None
    }

    /// Priority of this task. Defaults to `0`.
    // NOTE(review): whether larger or smaller values run first is not visible
    // here — confirm against the dequeue query.
    fn priority(&self) -> i32 {
        0
    }
}
/// State recorded for a task.
///
/// Maps to the Postgres enum type `underway.task_state`; variant names are
/// converted to snake_case labels (for example `InProgress` is stored as
/// `in_progress`).
///
/// The derived ordering follows declaration order:
/// `Pending < InProgress < Succeeded < Cancelled < Failed`.
///
/// The enum is fieldless, so it derives `Clone` and `Copy`: values can be
/// passed and compared by value without borrowing or moving out of a row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, sqlx::Type)]
#[sqlx(type_name = "underway.task_state", rename_all = "snake_case")]
pub enum State {
    /// Stored as `pending`.
    Pending,
    /// Stored as `in_progress`.
    InProgress,
    /// Stored as `succeeded`.
    Succeeded,
    /// Stored as `cancelled`.
    Cancelled,
    /// Stored as `failed`.
    Failed,
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use sqlx::PgPool;

    use super::*;

    /// Payload consumed by [`TestTask`].
    #[derive(Debug, Deserialize, Serialize)]
    struct TestTaskInput {
        message: String,
    }

    /// Builds a [`TestTaskInput`] carrying `message`.
    fn input_with(message: &str) -> TestTaskInput {
        TestTaskInput {
            message: message.to_string(),
        }
    }

    /// Minimal task: prints its message and fails with a retryable error when
    /// the message is exactly `"fail"`.
    struct TestTask;

    impl Task for TestTask {
        type Input = TestTaskInput;
        type Output = ();

        async fn execute(
            &self,
            _tx: Transaction<'_, Postgres>,
            input: Self::Input,
        ) -> Result<Self::Output> {
            let TestTaskInput { message } = input;
            println!("Executing task with message: {}", message);
            match message.as_str() {
                "fail" => Err(Error::Retryable("Task failed".to_string())),
                _ => Ok(()),
            }
        }
    }

    /// A non-failing message yields `Ok`.
    #[sqlx::test]
    async fn task_execution_success(pool: PgPool) {
        let payload = input_with("Hello, World!");
        let tx = pool.begin().await.unwrap();
        let outcome = TestTask.execute(tx, payload).await;
        assert!(outcome.is_ok());
    }

    /// The `"fail"` message yields `Err`.
    #[sqlx::test]
    async fn task_execution_failure(pool: PgPool) {
        let payload = input_with("fail");
        let tx = pool.begin().await.unwrap();
        let outcome = TestTask.execute(tx, payload).await;
        assert!(outcome.is_err());
    }

    /// `RetryPolicy::default()` exposes the expected default values.
    #[test]
    fn retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_interval_ms, 1_000);
        assert_eq!(policy.max_interval_ms, 60_000);
        assert_eq!(policy.backoff_coefficient, 2.0);
    }

    /// Values set through the builder are reflected in the built policy.
    #[test]
    fn retry_policy_custom() {
        let builder = RetryPolicy::builder()
            .max_attempts(3)
            .initial_interval_ms(500)
            .max_interval_ms(5_000)
            .backoff_coefficient(1.5);
        let policy = builder.build();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_interval_ms, 500);
        assert_eq!(policy.max_interval_ms, 5_000);
        assert_eq!(policy.backoff_coefficient, 1.5);
    }
}