use apalis_core::{
error::BoxDynError,
layers::{Layer, Service},
task::Parts,
worker::{context::WorkerContext, ext::ack::Acknowledge},
};
use futures::{FutureExt, future::BoxFuture};
use serde::Serialize;
use sqlx::SqlitePool;
use ulid::Ulid;
use crate::{
SqliteContext, SqliteTask,
queries::{
ack_task::{ack_task, calculate_status},
lock_task::lock_task,
},
};
/// Acknowledger that records a task's outcome in SQLite.
///
/// Each call to `ack` clones the held [`SqlitePool`] into the returned future,
/// which hands the serialized result, computed status and attempt count to
/// `ack_task`.
#[derive(Clone, Debug)]
pub struct SqliteAck {
    /// Connection pool used for the acknowledgement query.
    pool: SqlitePool,
}

impl SqliteAck {
    /// Builds an acknowledger backed by `pool`.
    pub fn new(pool: SqlitePool) -> Self {
        SqliteAck { pool }
    }
}
impl<Res: Serialize + 'static> Acknowledge<Res, SqliteContext, Ulid> for SqliteAck {
type Error = sqlx::Error;
type Future = BoxFuture<'static, Result<(), Self::Error>>;
fn ack(
&mut self,
res: &Result<Res, BoxDynError>,
parts: &Parts<SqliteContext, Ulid>,
) -> Self::Future {
let task_id = parts.task_id;
let worker_id = parts.ctx.lock_by().clone();
let response = serde_json::to_string(&res.as_ref().map_err(|e| e.to_string()));
let status = calculate_status(parts, res);
parts.status.store(status.clone());
let attempt = parts.attempt.current() as i32;
let pool = self.pool.clone();
let res = response.map_err(|e| sqlx::Error::Decode(e.into()));
async move {
let task_id = task_id
.ok_or(sqlx::Error::ColumnNotFound("TASK_ID_FOR_ACK".to_owned()))?
.to_string();
let worker_id =
worker_id.ok_or(sqlx::Error::ColumnNotFound("WORKER_ID_LOCK_BY".to_owned()))?;
let res_ok = res?;
ack_task(&pool, &task_id, &worker_id, &res_ok, &status, attempt).await?;
Ok(())
}
.boxed()
}
}
/// Layer that wraps a service in [`LockTaskService`], which locks each task
/// in SQLite for the executing worker before running it.
#[derive(Clone, Debug)]
pub struct LockTaskLayer {
    /// Pool handed to every wrapped service.
    pool: SqlitePool,
}

impl LockTaskLayer {
    /// Builds the layer around `pool`.
    pub fn new(pool: SqlitePool) -> Self {
        LockTaskLayer { pool }
    }
}
impl<S> Layer<S> for LockTaskLayer {
    type Service = LockTaskService<S>;

    /// Wraps `inner`, giving the new service its own clone of the pool.
    fn layer(&self, inner: S) -> Self::Service {
        let pool = self.pool.clone();
        LockTaskService { inner, pool }
    }
}
/// Service produced by [`LockTaskLayer`].
///
/// It runs `lock_task` for each request before awaiting the wrapped service.
#[derive(Clone, Debug)]
pub struct LockTaskService<S> {
    /// The wrapped service that actually processes the task.
    inner: S,
    /// Pool used to issue the lock query.
    pool: SqlitePool,
}
impl<S, Args> Service<SqliteTask<Args>> for LockTaskService<S>
where
S: Service<SqliteTask<Args>> + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxDynError>,
Args: Send + 'static,
{
type Response = S::Response;
type Error = BoxDynError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|e| e.into())
}
fn call(&mut self, mut req: SqliteTask<Args>) -> Self::Future {
let pool = self.pool.clone();
let worker_id = req
.parts
.data
.get::<WorkerContext>()
.map(|w| w.name().to_owned())
.unwrap();
let parts = &req.parts;
let task_id = match &parts.task_id {
Some(id) => *id.inner(),
None => {
return async {
Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into())
}
.boxed();
}
};
req.parts.ctx = req.parts.ctx.with_lock_by(Some(worker_id.clone()));
let fut = self.inner.call(req);
async move {
lock_task(&pool, &task_id.to_string(), &worker_id).await?;
fut.await.map_err(|e| e.into())
}
.boxed()
}
}