apalis_libsql/
ack.rs

1//! Acknowledgment implementation for libSQL backend
2
3use apalis_core::{
4    error::{AbortError, BoxDynError},
5    layers::{Layer, Service},
6    task::{Parts, status::Status},
7    worker::{context::WorkerContext, ext::ack::Acknowledge},
8};
9use apalis_sql::context::SqlContext;
10use futures::{FutureExt, future::BoxFuture};
11use libsql::Database;
12use serde::Serialize;
13use ulid::Ulid;
14
15use crate::{LibsqlError, LibsqlTask};
16
/// SQL query to acknowledge a task completion.
///
/// Positional parameters, as bound by the `Acknowledge` impl in this file:
/// `?1` task id, `?2` attempt count, `?3` JSON-serialized handler result (written to
/// `last_error` for both `Ok` and `Err` outcomes), `?4` new status text, `?5` the worker
/// id that must match the row's `lock_by`. `done_at` is set with SQLite's
/// `strftime('%s', 'now')`.
const ACK_SQL: &str = r#"
UPDATE Jobs
SET status = ?4, attempts = ?2, last_error = ?3, done_at = strftime('%s', 'now')
WHERE id = ?1 AND lock_by = ?5
"#;

/// SQL query to lock a task before processing.
///
/// Positional parameters: `?1` task id, `?2` worker id. Marks the row `'Running'`,
/// records the worker in `lock_by` and the time in `lock_at`. The `WHERE` clause only
/// matches rows that are unlocked or already locked by the same worker.
const LOCK_SQL: &str = r#"
UPDATE Jobs
SET status = 'Running', lock_by = ?2, lock_at = strftime('%s', 'now')
WHERE id = ?1 AND (lock_by IS NULL OR lock_by = ?2)
"#;
30
31/// Acknowledgment handler for libSQL backend
32#[derive(Clone, Debug)]
33pub struct LibsqlAck {
34    db: &'static Database,
35}
36
37impl LibsqlAck {
38    /// Create a new LibsqlAck
39    #[must_use]
40    pub fn new(db: &'static Database) -> Self {
41        Self { db }
42    }
43}
44
45impl<Res: Serialize + 'static> Acknowledge<Res, SqlContext, Ulid> for LibsqlAck {
46    type Error = LibsqlError;
47    type Future = BoxFuture<'static, Result<(), Self::Error>>;
48
49    fn ack(
50        &mut self,
51        res: &Result<Res, BoxDynError>,
52        parts: &Parts<SqlContext, Ulid>,
53    ) -> Self::Future {
54        let task_id = parts.task_id;
55        let worker_id = parts.ctx.lock_by().clone();
56        let db = self.db;
57
58        // Serialize response for storage
59        let response = serde_json::to_string(&res.as_ref().map_err(|e| e.to_string()));
60
61        let status = calculate_status(parts, res);
62        parts.status.store(status.clone());
63        let attempt_result = parts.attempt.current().try_into();
64        let status_str = status.to_string();
65
66        async move {
67            let task_id = task_id
68                .ok_or_else(|| LibsqlError::Other("Missing task_id for ack".into()))?
69                .to_string();
70            let worker_id = worker_id
71                .ok_or_else(|| LibsqlError::Other("Missing worker_id (lock_by)".into()))?;
72            let res_str = response.map_err(|e| LibsqlError::Other(e.to_string()))?;
73
74            let attempt: i32 = attempt_result
75                .map_err(|e| LibsqlError::Other(format!("Attempt count overflow: {}", e)))?;
76
77            let conn = db.connect()?;
78            let rows_affected = conn
79                .execute(
80                    ACK_SQL,
81                    libsql::params![task_id, attempt, res_str, status_str, worker_id],
82                )
83                .await
84                .map_err(LibsqlError::Database)?;
85
86            if rows_affected == 0 {
87                return Err(LibsqlError::Other("Task not found or already acked".into()));
88            }
89
90            Ok(())
91        }
92        .boxed()
93    }
94}
95
96/// Calculate the status based on the result and attempt count
97pub fn calculate_status<Res>(
98    parts: &Parts<SqlContext, Ulid>,
99    res: &Result<Res, BoxDynError>,
100) -> Status {
101    match res {
102        Ok(_) => Status::Done,
103        Err(e) => {
104            // Check if max attempts exceeded or explicitly aborted
105            #[allow(clippy::if_same_then_else)]
106            if parts.ctx.max_attempts() as usize <= parts.attempt.current() {
107                Status::Killed
108            // Check if explicitly aborted
109            } else if e.downcast_ref::<AbortError>().is_some() {
110                Status::Killed
111            } else {
112                Status::Failed
113            }
114        }
115    }
116}
117
118/// Lock a task for processing
119pub async fn lock_task(
120    db: &'static Database,
121    task_id: &Ulid,
122    worker_id: &str,
123) -> Result<(), LibsqlError> {
124    let conn = db.connect()?;
125    let task_id_str = task_id.to_string();
126
127    let rows_affected = conn
128        .execute(LOCK_SQL, libsql::params![task_id_str, worker_id])
129        .await
130        .map_err(LibsqlError::Database)?;
131
132    if rows_affected == 0 {
133        return Err(LibsqlError::Other(
134            "Task not found or already locked".into(),
135        ));
136    }
137
138    Ok(())
139}
140
141/// Layer for locking tasks before processing
142#[derive(Clone, Debug)]
143pub struct LockTaskLayer {
144    db: &'static Database,
145}
146
147impl LockTaskLayer {
148    /// Create a new LockTaskLayer
149    #[must_use]
150    pub fn new(db: &'static Database) -> Self {
151        Self { db }
152    }
153}
154
155impl<S> Layer<S> for LockTaskLayer {
156    type Service = LockTaskService<S>;
157
158    fn layer(&self, inner: S) -> Self::Service {
159        LockTaskService { inner, db: self.db }
160    }
161}
162
163/// Service that locks tasks before passing them to the inner service
164#[derive(Clone, Debug)]
165pub struct LockTaskService<S> {
166    inner: S,
167    db: &'static Database,
168}
169
170impl<S, Args> Service<LibsqlTask<Args>> for LockTaskService<S>
171where
172    S: Service<LibsqlTask<Args>> + Send + 'static + Clone,
173    S::Future: Send + 'static,
174    S::Error: Into<BoxDynError>,
175    Args: Send + 'static,
176{
177    type Response = S::Response;
178    type Error = BoxDynError;
179    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
180
181    fn poll_ready(
182        &mut self,
183        cx: &mut std::task::Context<'_>,
184    ) -> std::task::Poll<Result<(), Self::Error>> {
185        self.inner.poll_ready(cx).map_err(Into::into)
186    }
187
188    fn call(&mut self, mut req: LibsqlTask<Args>) -> Self::Future {
189        let db = self.db;
190        let worker_id = req
191            .parts
192            .data
193            .get::<WorkerContext>()
194            .map(|w| w.name().to_owned())
195            .ok_or_else(|| LibsqlError::Other("Missing WorkerContext for lock".into()));
196
197        let worker_id = match worker_id {
198            Ok(id) => id,
199            Err(e) => return async move { Err(e.into()) }.boxed(),
200        };
201
202        let task_id = match &req.parts.task_id {
203            Some(id) => *id.inner(),
204            None => {
205                return async { Err(LibsqlError::Other("Missing task_id for lock".into()).into()) }
206                    .boxed();
207            }
208        };
209
210        // Update context with lock_by
211        req.parts.ctx = req.parts.ctx.with_lock_by(Some(worker_id.clone()));
212
213        let mut inner = self.inner.clone();
214
215        async move {
216            lock_task(db, &task_id, &worker_id).await?;
217            inner.call(req).await.map_err(Into::into)
218        }
219        .boxed()
220    }
221}