1use apalis_core::{
4 error::{AbortError, BoxDynError},
5 layers::{Layer, Service},
6 task::{Parts, status::Status},
7 worker::{context::WorkerContext, ext::ack::Acknowledge},
8};
9use apalis_sql::context::SqlContext;
10use futures::{FutureExt, future::BoxFuture};
11use libsql::Database;
12use serde::Serialize;
13use ulid::Ulid;
14
15use crate::{LibsqlError, LibsqlTask};
16
/// SQL executed by `LibsqlAck` to record the final outcome of a task.
///
/// Bound parameters:
/// * `?1` – task id (string form of the `Ulid`)
/// * `?2` – attempt count
/// * `?3` – JSON-serialized handler result; note this is written to `last_error`
///   for successful results too, not only for failures
/// * `?4` – final status string
/// * `?5` – worker id that must currently hold the lock
///
/// `done_at` is set to the current Unix time in seconds. The row is only updated
/// while `lock_by` still equals `?5`; `lock_by` itself is left unchanged.
const ACK_SQL: &str = r#"
UPDATE Jobs
SET status = ?4, attempts = ?2, last_error = ?3, done_at = strftime('%s', 'now')
WHERE id = ?1 AND lock_by = ?5
"#;
23
/// SQL executed by `lock_task` to claim a task for a worker.
///
/// Bound parameters:
/// * `?1` – task id (string form of the `Ulid`)
/// * `?2` – worker id claiming the task
///
/// Marks the row `Running`, records `?2` in `lock_by` and stamps `lock_at` with
/// the current Unix time in seconds. Only rows that are unlocked, or already
/// locked by the same worker, are updated; the current `status` is not checked.
const LOCK_SQL: &str = r#"
UPDATE Jobs
SET status = 'Running', lock_by = ?2, lock_at = strftime('%s', 'now')
WHERE id = ?1 AND (lock_by IS NULL OR lock_by = ?2)
"#;
30
31#[derive(Clone, Debug)]
33pub struct LibsqlAck {
34 db: &'static Database,
35}
36
37impl LibsqlAck {
38 #[must_use]
40 pub fn new(db: &'static Database) -> Self {
41 Self { db }
42 }
43}
44
45impl<Res: Serialize + 'static> Acknowledge<Res, SqlContext, Ulid> for LibsqlAck {
46 type Error = LibsqlError;
47 type Future = BoxFuture<'static, Result<(), Self::Error>>;
48
49 fn ack(
50 &mut self,
51 res: &Result<Res, BoxDynError>,
52 parts: &Parts<SqlContext, Ulid>,
53 ) -> Self::Future {
54 let task_id = parts.task_id;
55 let worker_id = parts.ctx.lock_by().clone();
56 let db = self.db;
57
58 let response = serde_json::to_string(&res.as_ref().map_err(|e| e.to_string()));
60
61 let status = calculate_status(parts, res);
62 parts.status.store(status.clone());
63 let attempt_result = parts.attempt.current().try_into();
64 let status_str = status.to_string();
65
66 async move {
67 let task_id = task_id
68 .ok_or_else(|| LibsqlError::Other("Missing task_id for ack".into()))?
69 .to_string();
70 let worker_id = worker_id
71 .ok_or_else(|| LibsqlError::Other("Missing worker_id (lock_by)".into()))?;
72 let res_str = response.map_err(|e| LibsqlError::Other(e.to_string()))?;
73
74 let attempt: i32 = attempt_result
75 .map_err(|e| LibsqlError::Other(format!("Attempt count overflow: {}", e)))?;
76
77 let conn = db.connect()?;
78 let rows_affected = conn
79 .execute(
80 ACK_SQL,
81 libsql::params![task_id, attempt, res_str, status_str, worker_id],
82 )
83 .await
84 .map_err(LibsqlError::Database)?;
85
86 if rows_affected == 0 {
87 return Err(LibsqlError::Other("Task not found or already acked".into()));
88 }
89
90 Ok(())
91 }
92 .boxed()
93 }
94}
95
96pub fn calculate_status<Res>(
98 parts: &Parts<SqlContext, Ulid>,
99 res: &Result<Res, BoxDynError>,
100) -> Status {
101 match res {
102 Ok(_) => Status::Done,
103 Err(e) => {
104 #[allow(clippy::if_same_then_else)]
106 if parts.ctx.max_attempts() as usize <= parts.attempt.current() {
107 Status::Killed
108 } else if e.downcast_ref::<AbortError>().is_some() {
110 Status::Killed
111 } else {
112 Status::Failed
113 }
114 }
115 }
116}
117
118pub async fn lock_task(
120 db: &'static Database,
121 task_id: &Ulid,
122 worker_id: &str,
123) -> Result<(), LibsqlError> {
124 let conn = db.connect()?;
125 let task_id_str = task_id.to_string();
126
127 let rows_affected = conn
128 .execute(LOCK_SQL, libsql::params![task_id_str, worker_id])
129 .await
130 .map_err(LibsqlError::Database)?;
131
132 if rows_affected == 0 {
133 return Err(LibsqlError::Other(
134 "Task not found or already locked".into(),
135 ));
136 }
137
138 Ok(())
139}
140
141#[derive(Clone, Debug)]
143pub struct LockTaskLayer {
144 db: &'static Database,
145}
146
147impl LockTaskLayer {
148 #[must_use]
150 pub fn new(db: &'static Database) -> Self {
151 Self { db }
152 }
153}
154
155impl<S> Layer<S> for LockTaskLayer {
156 type Service = LockTaskService<S>;
157
158 fn layer(&self, inner: S) -> Self::Service {
159 LockTaskService { inner, db: self.db }
160 }
161}
162
163#[derive(Clone, Debug)]
165pub struct LockTaskService<S> {
166 inner: S,
167 db: &'static Database,
168}
169
170impl<S, Args> Service<LibsqlTask<Args>> for LockTaskService<S>
171where
172 S: Service<LibsqlTask<Args>> + Send + 'static + Clone,
173 S::Future: Send + 'static,
174 S::Error: Into<BoxDynError>,
175 Args: Send + 'static,
176{
177 type Response = S::Response;
178 type Error = BoxDynError;
179 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
180
181 fn poll_ready(
182 &mut self,
183 cx: &mut std::task::Context<'_>,
184 ) -> std::task::Poll<Result<(), Self::Error>> {
185 self.inner.poll_ready(cx).map_err(Into::into)
186 }
187
188 fn call(&mut self, mut req: LibsqlTask<Args>) -> Self::Future {
189 let db = self.db;
190 let worker_id = req
191 .parts
192 .data
193 .get::<WorkerContext>()
194 .map(|w| w.name().to_owned())
195 .ok_or_else(|| LibsqlError::Other("Missing WorkerContext for lock".into()));
196
197 let worker_id = match worker_id {
198 Ok(id) => id,
199 Err(e) => return async move { Err(e.into()) }.boxed(),
200 };
201
202 let task_id = match &req.parts.task_id {
203 Some(id) => *id.inner(),
204 None => {
205 return async { Err(LibsqlError::Other("Missing task_id for lock".into()).into()) }
206 .boxed();
207 }
208 };
209
210 req.parts.ctx = req.parts.ctx.with_lock_by(Some(worker_id.clone()));
212
213 let mut inner = self.inner.clone();
214
215 async move {
216 lock_task(db, &task_id, &worker_id).await?;
217 inner.call(req).await.map_err(Into::into)
218 }
219 .boxed()
220 }
221}