use apalis_core::{
error::{AbortError, BoxDynError},
layers::{Layer, Service},
task::{Parts, status::Status},
worker::{context::WorkerContext, ext::ack::Acknowledge},
};
use futures::{FutureExt, future::BoxFuture};
use serde::Serialize;
use sqlx::MySqlPool;
use ulid::Ulid;
use crate::{MySqlContext, MySqlTask};
/// Acknowledgement handler that writes task outcomes through a MySQL pool.
///
/// Cloning the handler clones the contained [`MySqlPool`] handle.
#[derive(Debug, Clone)]
pub struct MySqlAck {
    /// Connection pool used to execute the acknowledgement query.
    pool: MySqlPool,
}
impl MySqlAck {
pub fn new(pool: MySqlPool) -> Self {
Self { pool }
}
}
impl<Res: Serialize + 'static> Acknowledge<Res, MySqlContext, Ulid> for MySqlAck {
type Error = sqlx::Error;
type Future = BoxFuture<'static, Result<(), Self::Error>>;
fn ack(
&mut self,
res: &Result<Res, BoxDynError>,
parts: &Parts<MySqlContext, Ulid>,
) -> Self::Future {
let task_id = parts.task_id;
let worker_id = parts.ctx.lock_by().clone();
let response = serde_json::to_string(&res.as_ref().map_err(|e| e.to_string()));
let status = calculate_status(parts, res);
parts.status.store(status.clone());
let attempt = parts.attempt.current() as i32;
let pool = self.pool.clone();
let res = response.map_err(|e| sqlx::Error::Decode(e.into()));
let status = status.to_string();
async move {
let task_id = task_id
.ok_or(sqlx::Error::ColumnNotFound("TASK_ID_FOR_ACK".to_owned()))?
.to_string();
let worker_id =
worker_id.ok_or(sqlx::Error::ColumnNotFound("WORKER_ID_LOCK_BY".to_owned()))?;
let res_ok = res?;
let res = sqlx::query_file!(
"queries/task/ack.sql",
status,
attempt,
res_ok,
task_id,
worker_id
)
.execute(&pool)
.await?;
if res.rows_affected() == 0 {
return Err(sqlx::Error::RowNotFound);
}
Ok(())
}
.boxed()
}
}
/// Derives the final [`Status`] of a task from its execution result.
///
/// - `Ok(_)` yields [`Status::Done`].
/// - An error yields [`Status::Killed`] when the current attempt count has
///   reached the context's maximum attempts, or when the error is an
///   [`AbortError`].
/// - Any other error yields [`Status::Failed`].
pub(crate) fn calculate_status<Res>(
    parts: &Parts<MySqlContext, Ulid>,
    res: &Result<Res, BoxDynError>,
) -> Status {
    let err = match res {
        Ok(_) => return Status::Done,
        Err(err) => err,
    };

    // The attempt budget is checked first; the abort check only runs when
    // there are attempts left.
    let attempts_exhausted = parts.ctx.max_attempts() as usize <= parts.attempt.current();
    if attempts_exhausted || err.downcast_ref::<AbortError>().is_some() {
        Status::Killed
    } else {
        Status::Failed
    }
}
pub(crate) async fn lock_task(
pool: &MySqlPool,
task_id: &Ulid,
worker_id: &str,
) -> Result<(), sqlx::Error> {
let task_id = task_id.to_string();
let res = sqlx::query_file!("queries/task/lock.sql", worker_id, task_id)
.execute(pool)
.await?;
if res.rows_affected() == 0 {
return Err(sqlx::Error::RowNotFound);
}
Ok(())
}
/// Layer that wraps a service in a [`LockTaskService`].
///
/// Each wrapped service receives its own clone of the pool handle.
#[derive(Debug, Clone)]
pub struct LockTaskLayer {
    /// Connection pool handed to every service produced by this layer.
    pool: MySqlPool,
}
impl LockTaskLayer {
pub fn new(pool: MySqlPool) -> Self {
Self { pool }
}
}
impl<S> Layer<S> for LockTaskLayer {
    type Service = LockTaskService<S>;

    /// Wraps `inner` in a [`LockTaskService`] that shares this layer's pool.
    fn layer(&self, inner: S) -> Self::Service {
        let pool = self.pool.clone();
        LockTaskService { inner, pool }
    }
}
/// Service wrapper that locks a task in MySQL before the inner service's
/// future is awaited.
#[derive(Debug, Clone)]
pub struct LockTaskService<S> {
    /// The wrapped service that actually processes the task.
    inner: S,
    /// Connection pool used to execute the lock query.
    pool: MySqlPool,
}
impl<S, Args> Service<MySqlTask<Args>> for LockTaskService<S>
where
S: Service<MySqlTask<Args>> + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxDynError>,
Args: Send + 'static,
{
type Response = S::Response;
type Error = BoxDynError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|e| e.into())
}
fn call(&mut self, mut req: MySqlTask<Args>) -> Self::Future {
let pool = self.pool.clone();
let worker_id = req
.parts
.data
.get::<WorkerContext>()
.map(|w| w.name().to_owned())
.unwrap();
let parts = &req.parts;
let task_id = match &parts.task_id {
Some(id) => *id.inner(),
None => {
return async {
Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into())
}
.boxed();
}
};
req.parts.ctx = req.parts.ctx.with_lock_by(Some(worker_id.clone()));
let fut = self.inner.call(req);
async move {
lock_task(&pool, &task_id, &worker_id).await?;
fut.await.map_err(|e| e.into())
}
.boxed()
}
}