apalis_postgres/
ack.rs

1use apalis_core::{
2    error::AbortError,
3    error::BoxDynError,
4    layers::{Layer, Service},
5    task::{Parts, status::Status},
6    worker::{context::WorkerContext, ext::ack::Acknowledge},
7};
8use futures::{FutureExt, future::BoxFuture};
9use serde::Serialize;
10use sqlx::PgPool;
11use ulid::Ulid;
12
13use crate::{PgContext, PgTask};
14
15#[derive(Debug, Clone)]
16pub struct PgAck {
17    pool: PgPool,
18}
19impl PgAck {
20    pub fn new(pool: PgPool) -> Self {
21        Self { pool }
22    }
23}
24
25impl<Res: Serialize> Acknowledge<Res, PgContext, Ulid> for PgAck {
26    type Error = sqlx::Error;
27    type Future = BoxFuture<'static, Result<(), Self::Error>>;
28    fn ack(
29        &mut self,
30        res: &Result<Res, BoxDynError>,
31        parts: &Parts<PgContext, Ulid>,
32    ) -> Self::Future {
33        let task_id = parts.task_id;
34        let worker_id = parts.ctx.lock_by().clone();
35
36        let response = serde_json::to_value(res.as_ref().map_err(|e| e.to_string()));
37        let status = calculate_status(parts, res);
38        let attempt = parts.attempt.current() as i32;
39        let pool = self.pool.clone();
40        async move {
41            let res = sqlx::query_file!(
42                "queries/task/ack.sql",
43                task_id
44                    .ok_or(sqlx::Error::ColumnNotFound("TASK_ID_FOR_ACK".to_owned()))?
45                    .to_string(),
46                attempt,
47                &response.map_err(|e| sqlx::Error::Decode(e.into()))?,
48                status.to_string(),
49                worker_id.ok_or(sqlx::Error::ColumnNotFound("WORKER_ID_LOCK_BY".to_owned()))?
50            )
51            .execute(&pool)
52            .await?;
53
54            if res.rows_affected() == 0 {
55                return Err(sqlx::Error::RowNotFound);
56            }
57            Ok(())
58        }
59        .boxed()
60    }
61}
62
63pub fn calculate_status<Res>(
64    parts: &Parts<PgContext, Ulid>,
65    res: &Result<Res, BoxDynError>,
66) -> Status {
67    match &res {
68        Ok(_) => Status::Done,
69        Err(e) => match &e {
70            // Error::Abort(_) => State::Killed,
71            _ if parts.ctx.max_attempts() as usize <= parts.attempt.current() => Status::Killed,
72            _ => Status::Failed,
73        },
74    }
75}
76
77pub async fn lock_task(pool: &PgPool, task_id: &Ulid, worker_id: &str) -> Result<(), sqlx::Error> {
78    let task_id = vec![task_id.to_string()];
79    sqlx::query_file!("queries/task/lock_by_id.sql", &task_id, &worker_id,)
80        .fetch_one(pool)
81        .await?;
82    Ok(())
83}
84
85#[derive(Debug, Clone)]
86
87pub struct LockTaskLayer {
88    pool: PgPool,
89}
90
91impl LockTaskLayer {
92    pub fn new(pool: PgPool) -> Self {
93        Self { pool }
94    }
95}
96
97impl<S> Layer<S> for LockTaskLayer {
98    type Service = LockTaskService<S>;
99
100    fn layer(&self, inner: S) -> Self::Service {
101        LockTaskService {
102            inner,
103            pool: self.pool.clone(),
104        }
105    }
106}
107
108#[derive(Debug, Clone)]
109pub struct LockTaskService<S> {
110    inner: S,
111    pool: PgPool,
112}
113
114impl<S, Args> Service<PgTask<Args>> for LockTaskService<S>
115where
116    S: Service<PgTask<Args>> + Send + 'static,
117    S::Future: Send + 'static,
118    S::Error: Into<BoxDynError>,
119    Args: Send + 'static,
120{
121    type Response = S::Response;
122    type Error = BoxDynError;
123    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
124
125    fn poll_ready(
126        &mut self,
127        cx: &mut std::task::Context<'_>,
128    ) -> std::task::Poll<Result<(), Self::Error>> {
129        self.inner.poll_ready(cx).map_err(|e| e.into())
130    }
131
132    fn call(&mut self, req: PgTask<Args>) -> Self::Future {
133        let pool = self.pool.clone();
134        let worker_id = req
135            .parts
136            .data
137            .get::<WorkerContext>()
138            .map(|w| w.name().to_owned())
139            .unwrap();
140        let parts = &req.parts;
141        let task_id = match &parts.task_id {
142            Some(id) => *id.inner(),
143            None => {
144                return async {
145                    Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into())
146                }
147                .boxed();
148            }
149        };
150        let fut = self.inner.call(req);
151        async move {
152            lock_task(&pool, &task_id, &worker_id)
153                .await
154                .map_err(AbortError::new)?;
155            fut.await.map_err(|e| e.into())
156        }
157        .boxed()
158    }
159}