apalis_libsql/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{fmt, marker::PhantomData, pin::Pin};
4
5use apalis_core::{
6    backend::{Backend, BackendExt, codec::Codec},
7    error::BoxDynError,
8    layers::Stack,
9    task::Task,
10    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
11};
12pub use apalis_sql::context::SqlContext;
13use futures::{FutureExt, Stream, StreamExt, stream::BoxStream};
14use libsql::Database;
15use pin_project::pin_project;
16use ulid::Ulid;
17
/// Acknowledgement and lock layer types (`LibsqlAck`, `LockTaskLayer`, `LockTaskService`)
pub mod ack;
/// Configuration for the libSQL storage backend
pub mod config;
/// Fetcher implementation for polling tasks
pub mod fetcher;
/// Row mapping from database rows to task structs
pub mod row;
/// Sink implementation for pushing tasks
pub mod sink;

// Re-export the main types so users can reach them from the crate root.
pub use ack::{LibsqlAck, LockTaskLayer, LockTaskService};
pub use config::Config;
pub use fetcher::LibsqlPollFetcher;
pub use sink::LibsqlSink;
32
/// Type alias for a task stored in libsql backend
///
/// This is a [`Task`] carrying `Args`, with [`SqlContext`] as its context and a
/// [`Ulid`] as its identifier (the same `Context` / `IdType` the `Backend`
/// implementation of [`LibsqlStorage`] declares).
pub type LibsqlTask<Args> = Task<Args, SqlContext, Ulid>;
35
/// CompactType is the type used for compact serialization in libsql backend
///
/// Codecs used with [`LibsqlStorage`] are required to have
/// `Compact = CompactType`, i.e. they work on raw bytes.
pub type CompactType = Vec<u8>;
38
/// Error type for libSQL storage operations
///
/// `libsql::Error` converts into this type automatically (via `#[from]`), so
/// database calls can be propagated with `?`.
#[derive(Debug, thiserror::Error)]
pub enum LibsqlError {
    /// Database error from libsql
    #[error("Database error: {0}")]
    Database(#[from] libsql::Error),
    /// Other errors
    ///
    /// Carries a plain message. In this file it is used for codec/serialization
    /// failures and for acknowledgement problems (task missing, not locked, or
    /// already acked).
    #[error("Other error: {0}")]
    Other(String),
}
49
/// Upsert statement that records a worker in the `Workers` table.
///
/// Parameters: `?1` = worker id, `?2` = worker type. `last_seen` is set to the
/// current Unix time and an existing row with the same key is replaced.
const REGISTER_WORKER_SQL: &str = r"
INSERT OR REPLACE INTO Workers (id, worker_type, storage_name, layers, last_seen)
VALUES (?1, ?2, 'LibsqlStorage', '', strftime('%s', 'now'))
";
55
/// Heartbeat statement: refreshes `last_seen` for one worker.
///
/// Parameters: `?1` = worker id.
const KEEP_ALIVE_SQL: &str = r"
UPDATE Workers SET last_seen = strftime('%s', 'now') WHERE id = ?1
";
60
/// Statement that puts orphaned jobs back into the `Pending` state.
///
/// A job qualifies when it is `Running`, belongs to job type `?2`, and is
/// locked by a worker whose `last_seen` is more than `?1` seconds in the past.
/// The lock columns are cleared so another worker can pick the job up.
const REENQUEUE_ORPHANED_SQL: &str = r"
UPDATE Jobs
SET status = 'Pending', lock_by = NULL, lock_at = NULL
WHERE status = 'Running' AND lock_by IN (
    SELECT id FROM Workers WHERE last_seen < strftime('%s', 'now') - ?1
) AND job_type = ?2
";
69
/// LibsqlStorage is a storage backend for apalis using libsql as the database.
///
/// `T` is the task argument type and `C` the codec type; the storage itself
/// only carries them as `PhantomData` markers and through its inner sink.
#[pin_project]
pub struct LibsqlStorage<T, C> {
    // Shared database handle; operations in this file open a connection from it on demand.
    db: &'static Database,
    // Queue / polling configuration for this storage.
    config: Config,
    // Marker for the task argument type.
    job_type: PhantomData<T>,
    // Marker for the codec type.
    codec: PhantomData<C>,
    // Pinned so the `futures::Sink` implementation can forward to it via `project()`.
    #[pin]
    sink: LibsqlSink<T, C>,
}
80
81impl<T, C> fmt::Debug for LibsqlStorage<T, C> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("LibsqlStorage")
84            .field("db", &"Database")
85            .field("config", &self.config)
86            .field("job_type", &std::any::type_name::<T>())
87            .field("codec", &std::any::type_name::<C>())
88            .finish()
89    }
90}
91
92impl<T, C> Clone for LibsqlStorage<T, C> {
93    fn clone(&self) -> Self {
94        Self {
95            db: self.db,
96            config: self.config.clone(),
97            job_type: PhantomData,
98            codec: PhantomData,
99            sink: self.sink.clone(),
100        }
101    }
102}
103
104impl<T> LibsqlStorage<T, ()> {
105    /// Create a new LibsqlStorage with default JSON codec
106    #[must_use]
107    pub fn new(
108        db: &'static Database,
109    ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
110        let config = Config::new(std::any::type_name::<T>());
111        LibsqlStorage {
112            db,
113            config: config.clone(),
114            job_type: PhantomData,
115            codec: PhantomData,
116            sink: LibsqlSink::new(db, &config),
117        }
118    }
119
120    /// Create a new LibsqlStorage with custom config
121    #[must_use]
122    #[allow(clippy::needless_pass_by_value)]
123    pub fn new_with_config(
124        db: &'static Database,
125        config: Config,
126    ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
127        LibsqlStorage {
128            db,
129            config: config.clone(),
130            job_type: PhantomData,
131            codec: PhantomData,
132            sink: LibsqlSink::new(db, &config),
133        }
134    }
135}
136
137impl<T, C> LibsqlStorage<T, C> {
138    /// Get the database reference
139    #[must_use]
140    pub fn db(&self) -> &'static Database {
141        self.db
142    }
143
144    /// Get the config
145    #[must_use]
146    pub fn config(&self) -> &Config {
147        &self.config
148    }
149
150    /// Setup the database schema by running migrations
151    pub async fn setup(&self) -> Result<(), LibsqlError> {
152        let conn = self.db.connect()?;
153
154        // Read and execute the migration SQL
155        let migration_sql = include_str!("../migrations/001_initial.sql");
156
157        // Execute the migration as a batch
158        conn.execute_batch(migration_sql)
159            .await
160            .map_err(LibsqlError::Database)?;
161
162        Ok(())
163    }
164
165    /// Change the codec used for serialization/deserialization
166    #[must_use]
167    pub fn with_codec<D>(self) -> LibsqlStorage<T, D> {
168        LibsqlStorage {
169            db: self.db,
170            config: self.config.clone(),
171            job_type: PhantomData,
172            codec: PhantomData,
173            sink: LibsqlSink::new(self.db, &self.config),
174        }
175    }
176}
177
178/// Register a worker in the database
179async fn register_worker(
180    db: &'static Database,
181    worker_id: &str,
182    worker_type: &str,
183) -> Result<(), LibsqlError> {
184    let conn = db.connect()?;
185    conn.execute(REGISTER_WORKER_SQL, libsql::params![worker_id, worker_type])
186        .await
187        .map_err(LibsqlError::Database)?;
188    Ok(())
189}
190
191/// Update worker heartbeat
192async fn keep_alive(db: &'static Database, worker_id: &str) -> Result<(), LibsqlError> {
193    let conn = db.connect()?;
194    conn.execute(KEEP_ALIVE_SQL, libsql::params![worker_id])
195        .await
196        .map_err(LibsqlError::Database)?;
197    Ok(())
198}
199
200/// Re-enqueue orphaned tasks from dead workers
201pub async fn reenqueue_orphaned(
202    db: &'static Database,
203    config: &Config,
204) -> Result<u64, LibsqlError> {
205    let conn = db.connect()?;
206    let dead_for = config.reenqueue_orphaned_after().as_secs() as i64;
207    let queue = config.queue().to_string();
208
209    let rows = conn
210        .execute(REENQUEUE_ORPHANED_SQL, libsql::params![dead_for, queue])
211        .await
212        .map_err(LibsqlError::Database)?;
213
214    if rows > 0 {
215        log::info!("Re-enqueued {} orphaned tasks", rows);
216    }
217
218    Ok(rows)
219}
220
221/// Initial heartbeat: register worker and re-enqueue orphaned tasks
222#[allow(clippy::needless_pass_by_value)]
223async fn initial_heartbeat(
224    db: &'static Database,
225    config: Config,
226    worker: WorkerContext,
227) -> Result<(), LibsqlError> {
228    let worker_id = worker.name().to_string();
229    let worker_type = config.queue().to_string();
230
231    // Re-enqueue orphaned tasks first
232    reenqueue_orphaned(db, &config).await?;
233
234    // Register worker
235    register_worker(db, &worker_id, &worker_type).await?;
236
237    Ok(())
238}
239
240/// Create a heartbeat stream that periodically updates worker status
241#[allow(clippy::needless_pass_by_value)]
242fn heartbeat_stream(
243    db: &'static Database,
244    config: Config,
245    worker: WorkerContext,
246) -> impl Stream<Item = Result<(), LibsqlError>> + Send + 'static {
247    let worker_id = worker.name().to_string();
248    let keep_alive_interval = config.keep_alive();
249
250    futures::stream::unfold((), move |_| {
251        let db = db;
252        let worker_id = worker_id.clone();
253        let interval = keep_alive_interval;
254        let config = config.clone();
255
256        async move {
257            // Wait for the keep-alive interval
258            tokio::time::sleep(interval).await;
259
260            // Update heartbeat
261            if let Err(e) = keep_alive(db, &worker_id).await {
262                return Some((Err(e), ()));
263            }
264
265            // Re-enqueue orphaned tasks periodically
266            if let Err(e) = reenqueue_orphaned(db, &config).await {
267                return Some((Err(e), ()));
268            }
269
270            Some((Ok(()), ()))
271        }
272    })
273}
274
275impl<Args, Decode> Backend for LibsqlStorage<Args, Decode>
276where
277    Args: Send + 'static + Unpin,
278    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
279    Decode::Error: std::error::Error + Send + Sync + 'static,
280{
281    type Args = Args;
282    type IdType = Ulid;
283    type Context = SqlContext;
284    type Error = LibsqlError;
285    type Stream = apalis_core::backend::TaskStream<LibsqlTask<Args>, LibsqlError>;
286    type Beat = BoxStream<'static, Result<(), LibsqlError>>;
287    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<LibsqlAck>>;
288
289    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
290        let db = self.db;
291        let config = self.config.clone();
292        let worker = worker.clone();
293
294        // Start heartbeat stream
295        heartbeat_stream(db, config, worker).boxed()
296    }
297
298    fn middleware(&self) -> Self::Layer {
299        let lock = LockTaskLayer::new(self.db);
300        let ack = AcknowledgeLayer::new(LibsqlAck::new(self.db));
301        Stack::new(lock, ack)
302    }
303
304    fn poll(self, worker: &WorkerContext) -> Self::Stream {
305        let db = self.db;
306        let config = self.config.clone();
307        let worker = worker.clone();
308
309        // Initial registration - create a stream that owns the data
310        let register = futures::stream::once(
311            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
312        );
313
314        // Polling stream - we need to use a concrete type for the fetcher
315        // Since we're in the Backend impl, we can use the Decode type parameter
316        let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
317
318        // Chain registration with polling, and decode tasks
319        register
320            .chain(fetcher)
321            .map(move |result| match result {
322                Ok(Some(task)) => {
323                    let decoded = task
324                        .try_map(|t| Decode::decode(&t))
325                        .map_err(|e| LibsqlError::Other(e.to_string()))?;
326                    Ok(Some(decoded))
327                }
328                Ok(None) => Ok(None),
329                Err(e) => Err(e),
330            })
331            .boxed()
332    }
333}
334
335impl<Args, Decode> BackendExt for LibsqlStorage<Args, Decode>
336where
337    Args: Send + 'static + Unpin,
338    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
339    Decode::Error: std::error::Error + Send + Sync + 'static,
340{
341    type Codec = Decode;
342    type Compact = CompactType;
343    type CompactStream = apalis_core::backend::TaskStream<LibsqlTask<CompactType>, LibsqlError>;
344
345    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
346        let db = self.db;
347        let config = self.config.clone();
348        let worker = worker.clone();
349
350        // Initial registration
351        let register = futures::stream::once(
352            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
353        );
354
355        // Polling stream (compact tasks) - use the Decode type parameter
356        let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
357
358        register.chain(fetcher).boxed()
359    }
360}
361
362impl<Args, Decode> LibsqlStorage<Args, Decode>
363where
364    Args: Send + 'static + Unpin,
365    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
366    Decode::Error: std::error::Error + Send + Sync + 'static,
367{
368    /// Poll for tasks using the default polling strategy
369    pub fn poll_default(
370        self,
371        worker: &WorkerContext,
372    ) -> impl Stream<Item = Result<Option<LibsqlTask<CompactType>>, LibsqlError>> + Send + 'static
373    {
374        let db = self.db;
375        let config = self.config.clone();
376        let worker = worker.clone();
377
378        // Initial registration
379        let register = futures::stream::once(
380            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
381        );
382
383        // Polling stream (compact tasks) - use () as the codec since we want compact tasks
384        let fetcher = LibsqlPollFetcher::<()>::new(db, &config, &worker);
385
386        register.chain(fetcher).boxed()
387    }
388
389    /// Acknowledge a task completion
390    pub async fn ack<Res>(
391        &mut self,
392        task_id: &Ulid,
393        result: Result<Res, BoxDynError>,
394    ) -> Result<(), LibsqlError>
395    where
396        Res: serde::Serialize + Send,
397    {
398        use apalis_core::task::status::Status;
399
400        let task_id_str = task_id.to_string();
401        let response = serde_json::to_string(&result.as_ref().map_err(|e| e.to_string()))
402            .map_err(|e| LibsqlError::Other(e.to_string()))?;
403
404        // First, get the current task information to find the lock_by, attempts, and max_attempts
405        let conn = self.db.connect()?;
406        let mut rows = conn
407            .query(
408                "SELECT lock_by, attempts, max_attempts FROM Jobs WHERE id = ?1",
409                libsql::params![task_id_str.clone()],
410            )
411            .await
412            .map_err(LibsqlError::Database)?;
413
414        let (lock_by, current_attempts, max_attempts) =
415            match rows.next().await.map_err(LibsqlError::Database)? {
416                Some(row) => {
417                    let lock_by: Option<String> = row.get(0).map_err(LibsqlError::Database)?;
418                    let attempts: i64 = row.get(1).map_err(LibsqlError::Database)?;
419                    let max_attempts: i64 = row.get(2).map_err(LibsqlError::Database)?;
420                    (lock_by, attempts as i32, max_attempts as i32)
421                }
422                None => return Err(LibsqlError::Other("Task not found".into())),
423            };
424
425        let status = match &result {
426            Ok(_) => Status::Done,
427            Err(_) => {
428                // Implement proper retry logic based on attempt count
429                // If we've reached max_attempts, mark as Killed, otherwise Failed for retry
430                if current_attempts + 1 >= max_attempts {
431                    Status::Killed
432                } else {
433                    Status::Failed
434                }
435            }
436        };
437        let status_str = status.to_string();
438
439        let worker_id =
440            lock_by.ok_or_else(|| LibsqlError::Other("Task is not locked by any worker".into()))?;
441        let new_attempts = match &result {
442            Ok(_) => current_attempts,      // Don't increment on success
443            Err(_) => current_attempts + 1, // Only increment on failure
444        };
445
446        let rows_affected = conn
447            .execute(
448                "UPDATE Jobs SET status = ?1, attempts = ?2, last_error = ?3, done_at = strftime('%s', 'now') WHERE id = ?4 AND lock_by = ?5",
449                libsql::params![status_str, new_attempts, response, task_id_str, worker_id],
450            )
451            .await
452            .map_err(LibsqlError::Database)?;
453
454        if rows_affected == 0 {
455            return Err(LibsqlError::Other("Task not found or already acked".into()));
456        }
457
458        Ok(())
459    }
460}
461
462/// Enable WAL mode for better concurrency
463///
464/// This should be called after setup() to enable Write-Ahead Logging which
465/// improves performance for concurrent database access.
466pub async fn enable_wal_mode(db: &'static Database) -> Result<(), LibsqlError> {
467    let conn = db.connect()?;
468    conn.query("PRAGMA journal_mode=WAL", libsql::params![])
469        .await
470        .map_err(LibsqlError::Database)?;
471    Ok(())
472}
473
474/// Implementation of Sink for LibsqlStorage to push tasks
475impl<Args, Codec> futures::Sink<LibsqlTask<CompactType>> for LibsqlStorage<Args, Codec>
476where
477    Args: Send + Sync + 'static,
478{
479    type Error = LibsqlError;
480
481    fn poll_ready(
482        self: Pin<&mut Self>,
483        cx: &mut std::task::Context<'_>,
484    ) -> std::task::Poll<Result<(), Self::Error>> {
485        self.project().sink.poll_ready(cx)
486    }
487
488    fn start_send(self: Pin<&mut Self>, item: LibsqlTask<CompactType>) -> Result<(), Self::Error> {
489        self.project().sink.start_send(item)
490    }
491
492    fn poll_flush(
493        self: Pin<&mut Self>,
494        cx: &mut std::task::Context<'_>,
495    ) -> std::task::Poll<Result<(), Self::Error>> {
496        self.project().sink.poll_flush(cx)
497    }
498
499    fn poll_close(
500        self: Pin<&mut Self>,
501        cx: &mut std::task::Context<'_>,
502    ) -> std::task::Poll<Result<(), Self::Error>> {
503        self.project().sink.poll_close(cx)
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Smoke test: open a file-backed database, run a trivial query, apply the
    /// schema, enable WAL and check that the `Jobs` table exists.
    #[tokio::test]
    async fn test_basic_connectivity() -> Result<(), Box<dyn std::error::Error>> {
        // A file-based database is used for testing (in-memory might have issues).
        let temp_dir = TempDir::new()?;
        let db_path = temp_dir.path().join("test.db");

        let db = libsql::Builder::new_local(db_path.to_str().unwrap())
            .build()
            .await?;
        // The storage API requires a `'static` handle; leaking is fine in a test.
        let db_static: &'static Database = Box::leak(Box::new(db));

        let storage = LibsqlStorage::<(), ()>::new(db_static);

        // The database must answer a trivial query.
        let conn = db_static.connect()?;
        let mut probe = conn.query("SELECT 1", libsql::params![]).await?;
        let first = probe.next().await?.unwrap();
        let one: i32 = first.get(0)?;
        assert_eq!(one, 1);

        // Apply the schema, then switch to WAL for better concurrency.
        storage.setup().await?;
        enable_wal_mode(db_static).await?;

        // The schema must be visible on the connection opened before setup.
        let mut tables = conn
            .query(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='Jobs'",
                libsql::params![],
            )
            .await?;

        match tables.next().await? {
            Some(row) => {
                let name: String = row.get(0)?;
                assert_eq!(name, "Jobs");
            }
            None => panic!("Jobs table should exist after setup"),
        }

        // Clean up
        drop(conn);

        Ok(())
    }
}
559}