apalis_postgres/
ack.rs

1use apalis_core::{
2    error::BoxDynError,
3    layers::{Layer, Service},
4    task::{Parts, status::Status},
5    worker::{context::WorkerContext, ext::ack::Acknowledge},
6};
7use futures::{FutureExt, future::BoxFuture};
8use serde::Serialize;
9use sqlx::PgPool;
10use ulid::Ulid;
11
12use crate::{PgContext, PgTask};
13
14#[derive(Debug, Clone)]
15pub struct PgAck {
16    pool: PgPool,
17}
18impl PgAck {
19    pub fn new(pool: PgPool) -> Self {
20        Self { pool }
21    }
22}
23
24impl<Res: Serialize> Acknowledge<Res, PgContext, Ulid> for PgAck {
25    type Error = sqlx::Error;
26    type Future = BoxFuture<'static, Result<(), Self::Error>>;
27    fn ack(
28        &mut self,
29        res: &Result<Res, BoxDynError>,
30        parts: &Parts<PgContext, Ulid>,
31    ) -> Self::Future {
32        let task_id = parts.task_id;
33        let worker_id = parts.ctx.lock_by().clone();
34
35        let response = serde_json::to_value(res.as_ref().map_err(|e| e.to_string()));
36        let status = calculate_status(parts, res);
37        let attempt = parts.attempt.current() as i32;
38        let pool = self.pool.clone();
39        async move {
40            let res = sqlx::query_file!(
41                "queries/task/ack.sql",
42                task_id
43                    .ok_or(sqlx::Error::ColumnNotFound("TASK_ID_FOR_ACK".to_owned()))?
44                    .to_string(),
45                attempt,
46                &response.map_err(|e| sqlx::Error::Decode(e.into()))?,
47                status.to_string(),
48                worker_id.ok_or(sqlx::Error::ColumnNotFound("WORKER_ID_LOCK_BY".to_owned()))?
49            )
50            .execute(&pool)
51            .await?;
52
53            if res.rows_affected() == 0 {
54                return Err(sqlx::Error::RowNotFound);
55            }
56            Ok(())
57        }
58        .boxed()
59    }
60}
61
62pub fn calculate_status<Res>(
63    parts: &Parts<PgContext, Ulid>,
64    res: &Result<Res, BoxDynError>,
65) -> Status {
66    match &res {
67        Ok(_) => Status::Done,
68        Err(e) => match &e {
69            // Error::Abort(_) => State::Killed,
70            _ if parts.ctx.max_attempts() as usize <= parts.attempt.current() => Status::Killed,
71            _ => Status::Failed,
72        },
73    }
74}
75
76pub async fn lock_task(pool: &PgPool, task_id: &Ulid, worker_id: &str) -> Result<(), sqlx::Error> {
77    let task_id = vec![task_id.to_string()];
78    sqlx::query_file!("queries/task/lock_by_id.sql", &task_id, &worker_id,)
79        .fetch_one(pool)
80        .await?;
81    Ok(())
82}
83
84#[derive(Debug, Clone)]
85
86pub struct LockTaskLayer {
87    pool: PgPool,
88}
89
90impl LockTaskLayer {
91    pub fn new(pool: PgPool) -> Self {
92        Self { pool }
93    }
94}
95
96impl<S> Layer<S> for LockTaskLayer {
97    type Service = LockTaskService<S>;
98
99    fn layer(&self, inner: S) -> Self::Service {
100        LockTaskService {
101            inner,
102            pool: self.pool.clone(),
103        }
104    }
105}
106
107#[derive(Debug, Clone)]
108pub struct LockTaskService<S> {
109    inner: S,
110    pool: PgPool,
111}
112
113impl<S, Args> Service<PgTask<Args>> for LockTaskService<S>
114where
115    S: Service<PgTask<Args>> + Send + 'static,
116    S::Future: Send + 'static,
117    S::Error: Into<BoxDynError>,
118    Args: Send + 'static,
119{
120    type Response = S::Response;
121    type Error = BoxDynError;
122    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
123
124    fn poll_ready(
125        &mut self,
126        cx: &mut std::task::Context<'_>,
127    ) -> std::task::Poll<Result<(), Self::Error>> {
128        self.inner.poll_ready(cx).map_err(|e| e.into())
129    }
130
131    fn call(&mut self, req: PgTask<Args>) -> Self::Future {
132        let pool = self.pool.clone();
133        let worker_id = req
134            .parts
135            .data
136            .get::<WorkerContext>()
137            .map(|w| w.name().to_owned())
138            .unwrap();
139        let parts = &req.parts;
140        let task_id = match &parts.task_id {
141            Some(id) => *id.inner(),
142            None => {
143                return async {
144                    Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into())
145                }
146                .boxed();
147            }
148        };
149        let fut = self.inner.call(req);
150        async move {
151            lock_task(&pool, &task_id, &worker_id).await.unwrap();
152            fut.await.map_err(|e| e.into())
153        }
154        .boxed()
155    }
156}