1use apalis_core::{
2 error::AbortError,
3 error::BoxDynError,
4 layers::{Layer, Service},
5 task::{Parts, status::Status},
6 worker::{context::WorkerContext, ext::ack::Acknowledge},
7};
8use futures::{FutureExt, future::BoxFuture};
9use serde::Serialize;
10use sqlx::PgPool;
11use ulid::Ulid;
12
13use crate::{PgContext, PgTask};
14
15#[derive(Debug, Clone)]
16pub struct PgAck {
17 pool: PgPool,
18}
19impl PgAck {
20 pub fn new(pool: PgPool) -> Self {
21 Self { pool }
22 }
23}
24
25impl<Res: Serialize> Acknowledge<Res, PgContext, Ulid> for PgAck {
26 type Error = sqlx::Error;
27 type Future = BoxFuture<'static, Result<(), Self::Error>>;
28 fn ack(
29 &mut self,
30 res: &Result<Res, BoxDynError>,
31 parts: &Parts<PgContext, Ulid>,
32 ) -> Self::Future {
33 let task_id = parts.task_id;
34 let worker_id = parts.ctx.lock_by().clone();
35
36 let response = serde_json::to_value(res.as_ref().map_err(|e| e.to_string()));
37 let status = calculate_status(parts, res);
38 let attempt = parts.attempt.current() as i32;
39 let pool = self.pool.clone();
40 async move {
41 let res = sqlx::query_file!(
42 "queries/task/ack.sql",
43 task_id
44 .ok_or(sqlx::Error::ColumnNotFound("TASK_ID_FOR_ACK".to_owned()))?
45 .to_string(),
46 attempt,
47 &response.map_err(|e| sqlx::Error::Decode(e.into()))?,
48 status.to_string(),
49 worker_id.ok_or(sqlx::Error::ColumnNotFound("WORKER_ID_LOCK_BY".to_owned()))?
50 )
51 .execute(&pool)
52 .await?;
53
54 if res.rows_affected() == 0 {
55 return Err(sqlx::Error::RowNotFound);
56 }
57 Ok(())
58 }
59 .boxed()
60 }
61}
62
63pub fn calculate_status<Res>(
64 parts: &Parts<PgContext, Ulid>,
65 res: &Result<Res, BoxDynError>,
66) -> Status {
67 match &res {
68 Ok(_) => Status::Done,
69 Err(e) => match &e {
70 _ if parts.ctx.max_attempts() as usize <= parts.attempt.current() => Status::Killed,
72 _ => Status::Failed,
73 },
74 }
75}
76
77pub async fn lock_task(pool: &PgPool, task_id: &Ulid, worker_id: &str) -> Result<(), sqlx::Error> {
78 let task_id = vec![task_id.to_string()];
79 sqlx::query_file!("queries/task/lock_by_id.sql", &task_id, &worker_id,)
80 .fetch_one(pool)
81 .await?;
82 Ok(())
83}
84
85#[derive(Debug, Clone)]
86
87pub struct LockTaskLayer {
88 pool: PgPool,
89}
90
91impl LockTaskLayer {
92 pub fn new(pool: PgPool) -> Self {
93 Self { pool }
94 }
95}
96
97impl<S> Layer<S> for LockTaskLayer {
98 type Service = LockTaskService<S>;
99
100 fn layer(&self, inner: S) -> Self::Service {
101 LockTaskService {
102 inner,
103 pool: self.pool.clone(),
104 }
105 }
106}
107
108#[derive(Debug, Clone)]
109pub struct LockTaskService<S> {
110 inner: S,
111 pool: PgPool,
112}
113
114impl<S, Args> Service<PgTask<Args>> for LockTaskService<S>
115where
116 S: Service<PgTask<Args>> + Send + 'static,
117 S::Future: Send + 'static,
118 S::Error: Into<BoxDynError>,
119 Args: Send + 'static,
120{
121 type Response = S::Response;
122 type Error = BoxDynError;
123 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
124
125 fn poll_ready(
126 &mut self,
127 cx: &mut std::task::Context<'_>,
128 ) -> std::task::Poll<Result<(), Self::Error>> {
129 self.inner.poll_ready(cx).map_err(|e| e.into())
130 }
131
132 fn call(&mut self, req: PgTask<Args>) -> Self::Future {
133 let pool = self.pool.clone();
134 let worker_id = req
135 .parts
136 .data
137 .get::<WorkerContext>()
138 .map(|w| w.name().to_owned())
139 .unwrap();
140 let parts = &req.parts;
141 let task_id = match &parts.task_id {
142 Some(id) => *id.inner(),
143 None => {
144 return async {
145 Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into())
146 }
147 .boxed();
148 }
149 };
150 let fut = self.inner.call(req);
151 async move {
152 lock_task(&pool, &task_id, &worker_id)
153 .await
154 .map_err(AbortError::new)?;
155 fut.await.map_err(|e| e.into())
156 }
157 .boxed()
158 }
159}