1#![doc = include_str!("../README.md")]
2
3use std::{fmt, marker::PhantomData, pin::Pin};
4
5use apalis_core::{
6 backend::{Backend, BackendExt, codec::Codec},
7 error::BoxDynError,
8 layers::Stack,
9 task::Task,
10 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
11};
12pub use apalis_sql::context::SqlContext;
13use futures::{FutureExt, Stream, StreamExt, stream::BoxStream};
14use libsql::Database;
15use pin_project::pin_project;
16use ulid::Ulid;
17
18pub mod ack;
19pub mod config;
21pub mod fetcher;
23pub mod row;
25pub mod sink;
27
28pub use ack::{LibsqlAck, LockTaskLayer, LockTaskService};
29pub use config::Config;
30pub use fetcher::LibsqlPollFetcher;
31pub use sink::LibsqlSink;
32
/// A task as handled by this backend: arguments of type `Args`, an
/// [`SqlContext`] as its context and a [`Ulid`] as its identifier.
pub type LibsqlTask<Args> = Task<Args, SqlContext, Ulid>;
35
/// Compact (encoded) representation of task arguments used by this crate:
/// a plain byte buffer.
pub type CompactType = Vec<u8>;
38
/// Errors returned by the libSQL storage backend.
#[derive(Debug, thiserror::Error)]
pub enum LibsqlError {
    /// An error reported by the `libsql` client; convertible via `From`, so
    /// `?` can be applied directly to `libsql` results.
    #[error("Database error: {0}")]
    Database(#[from] libsql::Error),
    /// Any other failure, carried as a plain message (for example codec or
    /// serialization errors rendered with `to_string()`).
    #[error("Other error: {0}")]
    Other(String),
}
49
// Inserts (or replaces) a row in `Workers`.
//   ?1 = worker id, ?2 = worker type.
// `storage_name` is fixed to 'LibsqlStorage', `layers` is left empty and
// `last_seen` is set to the current time via strftime('%s', 'now').
const REGISTER_WORKER_SQL: &str = r#"
INSERT OR REPLACE INTO Workers (id, worker_type, storage_name, layers, last_seen)
VALUES (?1, ?2, 'LibsqlStorage', '', strftime('%s', 'now'))
"#;
55
// Refreshes `last_seen` of the worker whose id is ?1 to the current time.
const KEEP_ALIVE_SQL: &str = r#"
UPDATE Workers SET last_seen = strftime('%s', 'now') WHERE id = ?1
"#;
60
// Resets 'Running' jobs back to 'Pending' (clearing `lock_by` / `lock_at`)
// when the worker holding the lock has a `last_seen` older than now - ?1.
//   ?1 = staleness threshold (compared against strftime('%s', 'now')),
//   ?2 = job type the reset is restricted to.
const REENQUEUE_ORPHANED_SQL: &str = r#"
UPDATE Jobs
SET status = 'Pending', lock_by = NULL, lock_at = NULL
WHERE status = 'Running' AND lock_by IN (
 SELECT id FROM Workers WHERE last_seen < strftime('%s', 'now') - ?1
) AND job_type = ?2
"#;
69
/// Storage handle bundling a libSQL database reference, a [`Config`] and a
/// [`LibsqlSink`].
///
/// `T` is the job argument type and `C` the codec type; both are carried
/// only as [`PhantomData`] markers.
#[pin_project]
pub struct LibsqlStorage<T, C> {
    /// Database handle; a `'static` reference that is copied into clones.
    db: &'static Database,
    /// Configuration used when building the sink, heartbeat and fetchers.
    config: Config,
    /// Marker for the job argument type.
    job_type: PhantomData<T>,
    /// Marker for the codec type.
    codec: PhantomData<C>,
    /// Sink that the `futures::Sink` implementation forwards to; marked
    /// `#[pin]` so it can be projected as a pinned field.
    #[pin]
    sink: LibsqlSink<T, C>,
}
80
81impl<T, C> fmt::Debug for LibsqlStorage<T, C> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_struct("LibsqlStorage")
84 .field("db", &"Database")
85 .field("config", &self.config)
86 .field("job_type", &std::any::type_name::<T>())
87 .field("codec", &std::any::type_name::<C>())
88 .finish()
89 }
90}
91
92impl<T, C> Clone for LibsqlStorage<T, C> {
93 fn clone(&self) -> Self {
94 Self {
95 db: self.db,
96 config: self.config.clone(),
97 job_type: PhantomData,
98 codec: PhantomData,
99 sink: self.sink.clone(),
100 }
101 }
102}
103
104impl<T> LibsqlStorage<T, ()> {
105 #[must_use]
107 pub fn new(
108 db: &'static Database,
109 ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
110 let config = Config::new(std::any::type_name::<T>());
111 LibsqlStorage {
112 db,
113 config: config.clone(),
114 job_type: PhantomData,
115 codec: PhantomData,
116 sink: LibsqlSink::new(db, &config),
117 }
118 }
119
120 #[must_use]
122 #[allow(clippy::needless_pass_by_value)]
123 pub fn new_with_config(
124 db: &'static Database,
125 config: Config,
126 ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
127 LibsqlStorage {
128 db,
129 config: config.clone(),
130 job_type: PhantomData,
131 codec: PhantomData,
132 sink: LibsqlSink::new(db, &config),
133 }
134 }
135}
136
137impl<T, C> LibsqlStorage<T, C> {
138 #[must_use]
140 pub fn db(&self) -> &'static Database {
141 self.db
142 }
143
144 #[must_use]
146 pub fn config(&self) -> &Config {
147 &self.config
148 }
149
150 pub async fn setup(&self) -> Result<(), LibsqlError> {
152 let conn = self.db.connect()?;
153
154 let migration_sql = include_str!("../migrations/001_initial.sql");
156
157 conn.execute_batch(migration_sql)
159 .await
160 .map_err(LibsqlError::Database)?;
161
162 Ok(())
163 }
164
165 #[must_use]
167 pub fn with_codec<D>(self) -> LibsqlStorage<T, D> {
168 LibsqlStorage {
169 db: self.db,
170 config: self.config.clone(),
171 job_type: PhantomData,
172 codec: PhantomData,
173 sink: LibsqlSink::new(self.db, &self.config),
174 }
175 }
176}
177
178async fn register_worker(
180 db: &'static Database,
181 worker_id: &str,
182 worker_type: &str,
183) -> Result<(), LibsqlError> {
184 let conn = db.connect()?;
185 conn.execute(REGISTER_WORKER_SQL, libsql::params![worker_id, worker_type])
186 .await
187 .map_err(LibsqlError::Database)?;
188 Ok(())
189}
190
191async fn keep_alive(db: &'static Database, worker_id: &str) -> Result<(), LibsqlError> {
193 let conn = db.connect()?;
194 conn.execute(KEEP_ALIVE_SQL, libsql::params![worker_id])
195 .await
196 .map_err(LibsqlError::Database)?;
197 Ok(())
198}
199
200pub async fn reenqueue_orphaned(
202 db: &'static Database,
203 config: &Config,
204) -> Result<u64, LibsqlError> {
205 let conn = db.connect()?;
206 let dead_for = config.reenqueue_orphaned_after().as_secs() as i64;
207 let queue = config.queue().to_string();
208
209 let rows = conn
210 .execute(REENQUEUE_ORPHANED_SQL, libsql::params![dead_for, queue])
211 .await
212 .map_err(LibsqlError::Database)?;
213
214 if rows > 0 {
215 log::info!("Re-enqueued {} orphaned tasks", rows);
216 }
217
218 Ok(rows)
219}
220
221#[allow(clippy::needless_pass_by_value)]
223async fn initial_heartbeat(
224 db: &'static Database,
225 config: Config,
226 worker: WorkerContext,
227) -> Result<(), LibsqlError> {
228 let worker_id = worker.name().to_string();
229 let worker_type = config.queue().to_string();
230
231 reenqueue_orphaned(db, &config).await?;
233
234 register_worker(db, &worker_id, &worker_type).await?;
236
237 Ok(())
238}
239
240#[allow(clippy::needless_pass_by_value)]
242fn heartbeat_stream(
243 db: &'static Database,
244 config: Config,
245 worker: WorkerContext,
246) -> impl Stream<Item = Result<(), LibsqlError>> + Send + 'static {
247 let worker_id = worker.name().to_string();
248 let keep_alive_interval = config.keep_alive();
249
250 futures::stream::unfold((), move |_| {
251 let db = db;
252 let worker_id = worker_id.clone();
253 let interval = keep_alive_interval;
254 let config = config.clone();
255
256 async move {
257 tokio::time::sleep(interval).await;
259
260 if let Err(e) = keep_alive(db, &worker_id).await {
262 return Some((Err(e), ()));
263 }
264
265 if let Err(e) = reenqueue_orphaned(db, &config).await {
267 return Some((Err(e), ()));
268 }
269
270 Some((Ok(()), ()))
271 }
272 })
273}
274
275impl<Args, Decode> Backend for LibsqlStorage<Args, Decode>
276where
277 Args: Send + 'static + Unpin,
278 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
279 Decode::Error: std::error::Error + Send + Sync + 'static,
280{
281 type Args = Args;
282 type IdType = Ulid;
283 type Context = SqlContext;
284 type Error = LibsqlError;
285 type Stream = apalis_core::backend::TaskStream<LibsqlTask<Args>, LibsqlError>;
286 type Beat = BoxStream<'static, Result<(), LibsqlError>>;
287 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<LibsqlAck>>;
288
289 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
290 let db = self.db;
291 let config = self.config.clone();
292 let worker = worker.clone();
293
294 heartbeat_stream(db, config, worker).boxed()
296 }
297
298 fn middleware(&self) -> Self::Layer {
299 let lock = LockTaskLayer::new(self.db);
300 let ack = AcknowledgeLayer::new(LibsqlAck::new(self.db));
301 Stack::new(lock, ack)
302 }
303
304 fn poll(self, worker: &WorkerContext) -> Self::Stream {
305 let db = self.db;
306 let config = self.config.clone();
307 let worker = worker.clone();
308
309 let register = futures::stream::once(
311 initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
312 );
313
314 let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
317
318 register
320 .chain(fetcher)
321 .map(move |result| match result {
322 Ok(Some(task)) => {
323 let decoded = task
324 .try_map(|t| Decode::decode(&t))
325 .map_err(|e| LibsqlError::Other(e.to_string()))?;
326 Ok(Some(decoded))
327 }
328 Ok(None) => Ok(None),
329 Err(e) => Err(e),
330 })
331 .boxed()
332 }
333}
334
335impl<Args, Decode> BackendExt for LibsqlStorage<Args, Decode>
336where
337 Args: Send + 'static + Unpin,
338 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
339 Decode::Error: std::error::Error + Send + Sync + 'static,
340{
341 type Codec = Decode;
342 type Compact = CompactType;
343 type CompactStream = apalis_core::backend::TaskStream<LibsqlTask<CompactType>, LibsqlError>;
344
345 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
346 let db = self.db;
347 let config = self.config.clone();
348 let worker = worker.clone();
349
350 let register = futures::stream::once(
352 initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
353 );
354
355 let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
357
358 register.chain(fetcher).boxed()
359 }
360}
361
362impl<Args, Decode> LibsqlStorage<Args, Decode>
363where
364 Args: Send + 'static + Unpin,
365 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
366 Decode::Error: std::error::Error + Send + Sync + 'static,
367{
368 pub fn poll_default(
370 self,
371 worker: &WorkerContext,
372 ) -> impl Stream<Item = Result<Option<LibsqlTask<CompactType>>, LibsqlError>> + Send + 'static
373 {
374 let db = self.db;
375 let config = self.config.clone();
376 let worker = worker.clone();
377
378 let register = futures::stream::once(
380 initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
381 );
382
383 let fetcher = LibsqlPollFetcher::<()>::new(db, &config, &worker);
385
386 register.chain(fetcher).boxed()
387 }
388
389 pub async fn ack<Res>(
391 &mut self,
392 task_id: &Ulid,
393 result: Result<Res, BoxDynError>,
394 ) -> Result<(), LibsqlError>
395 where
396 Res: serde::Serialize + Send,
397 {
398 use apalis_core::task::status::Status;
399
400 let task_id_str = task_id.to_string();
401 let response = serde_json::to_string(&result.as_ref().map_err(|e| e.to_string()))
402 .map_err(|e| LibsqlError::Other(e.to_string()))?;
403
404 let conn = self.db.connect()?;
406 let mut rows = conn
407 .query(
408 "SELECT lock_by, attempts, max_attempts FROM Jobs WHERE id = ?1",
409 libsql::params![task_id_str.clone()],
410 )
411 .await
412 .map_err(LibsqlError::Database)?;
413
414 let (lock_by, current_attempts, max_attempts) =
415 match rows.next().await.map_err(LibsqlError::Database)? {
416 Some(row) => {
417 let lock_by: Option<String> = row.get(0).map_err(LibsqlError::Database)?;
418 let attempts: i64 = row.get(1).map_err(LibsqlError::Database)?;
419 let max_attempts: i64 = row.get(2).map_err(LibsqlError::Database)?;
420 (lock_by, attempts as i32, max_attempts as i32)
421 }
422 None => return Err(LibsqlError::Other("Task not found".into())),
423 };
424
425 let status = match &result {
426 Ok(_) => Status::Done,
427 Err(_) => {
428 if current_attempts + 1 >= max_attempts {
431 Status::Killed
432 } else {
433 Status::Failed
434 }
435 }
436 };
437 let status_str = status.to_string();
438
439 let worker_id =
440 lock_by.ok_or_else(|| LibsqlError::Other("Task is not locked by any worker".into()))?;
441 let new_attempts = match &result {
442 Ok(_) => current_attempts, Err(_) => current_attempts + 1, };
445
446 let rows_affected = conn
447 .execute(
448 "UPDATE Jobs SET status = ?1, attempts = ?2, last_error = ?3, done_at = strftime('%s', 'now') WHERE id = ?4 AND lock_by = ?5",
449 libsql::params![status_str, new_attempts, response, task_id_str, worker_id],
450 )
451 .await
452 .map_err(LibsqlError::Database)?;
453
454 if rows_affected == 0 {
455 return Err(LibsqlError::Other("Task not found or already acked".into()));
456 }
457
458 Ok(())
459 }
460}
461
462pub async fn enable_wal_mode(db: &'static Database) -> Result<(), LibsqlError> {
467 let conn = db.connect()?;
468 conn.query("PRAGMA journal_mode=WAL", libsql::params![])
469 .await
470 .map_err(LibsqlError::Database)?;
471 Ok(())
472}
473
474impl<Args, Codec> futures::Sink<LibsqlTask<CompactType>> for LibsqlStorage<Args, Codec>
476where
477 Args: Send + Sync + 'static,
478{
479 type Error = LibsqlError;
480
481 fn poll_ready(
482 self: Pin<&mut Self>,
483 cx: &mut std::task::Context<'_>,
484 ) -> std::task::Poll<Result<(), Self::Error>> {
485 self.project().sink.poll_ready(cx)
486 }
487
488 fn start_send(self: Pin<&mut Self>, item: LibsqlTask<CompactType>) -> Result<(), Self::Error> {
489 self.project().sink.start_send(item)
490 }
491
492 fn poll_flush(
493 self: Pin<&mut Self>,
494 cx: &mut std::task::Context<'_>,
495 ) -> std::task::Poll<Result<(), Self::Error>> {
496 self.project().sink.poll_flush(cx)
497 }
498
499 fn poll_close(
500 self: Pin<&mut Self>,
501 cx: &mut std::task::Context<'_>,
502 ) -> std::task::Poll<Result<(), Self::Error>> {
503 self.project().sink.poll_close(cx)
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Smoke test against a local database file in a temporary directory:
    /// a trivial query works, `setup` and `enable_wal_mode` succeed, and the
    /// `Jobs` table exists afterwards.
    #[tokio::test]
    async fn test_basic_connectivity() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let db_file = temp_dir.path().join("test.db");

        let database = libsql::Builder::new_local(db_file.to_str().unwrap())
            .build()
            .await?;
        // The storage API requires a `'static` handle; leaking is fine in a test.
        let db_static: &'static Database = Box::leak(Box::new(database));

        let storage = LibsqlStorage::<(), ()>::new(db_static);
        let conn = db_static.connect()?;

        // Plain connectivity check.
        let mut probe = conn.query("SELECT 1", libsql::params![]).await?;
        let first = probe.next().await?.unwrap();
        let value: i32 = first.get(0)?;
        assert_eq!(value, 1);

        storage.setup().await?;
        enable_wal_mode(db_static).await?;

        // The migration must have created the `Jobs` table.
        let mut tables = conn
            .query(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='Jobs'",
                libsql::params![],
            )
            .await?;

        match tables.next().await? {
            Some(found) => {
                let name: String = found.get(0)?;
                assert_eq!(name, "Jobs");
            }
            None => panic!("Jobs table should exist after setup"),
        }

        drop(conn);

        Ok(())
    }
}
559}