#![doc = include_str!("../README.md")]
use std::{fmt, marker::PhantomData, pin::Pin};
use apalis_core::{
backend::{Backend, BackendExt, codec::Codec},
error::BoxDynError,
layers::Stack,
task::Task,
worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::context::SqlContext;
use futures::{FutureExt, Stream, StreamExt, stream::BoxStream};
use libsql::Database;
use pin_project::pin_project;
use ulid::Ulid;
pub mod ack;
pub mod config;
pub mod fetcher;
pub mod row;
pub mod sink;
pub use ack::{LibsqlAck, LockTaskLayer, LockTaskService};
pub use config::Config;
pub use fetcher::LibsqlPollFetcher;
pub use sink::LibsqlSink;
/// A task as stored in libsql: user arguments plus SQL execution context,
/// identified by a ULID.
pub type LibsqlTask<Args> = Task<Args, SqlContext, Ulid>;
/// The compact wire format task payloads are encoded to before being written
/// to the database.
pub type CompactType = Vec<u8>;
/// Errors produced by the libsql storage backend.
#[derive(Debug, thiserror::Error)]
pub enum LibsqlError {
    /// A database-level failure, converted automatically from
    /// [`libsql::Error`] via `#[from]`.
    #[error("Database error: {0}")]
    Database(#[from] libsql::Error),
    /// Any other failure (codec errors, missing rows, lock violations, ...).
    #[error("Other error: {0}")]
    Other(String),
}
// Inserts (or refreshes) a worker row, stamping `last_seen` with the current
// unix timestamp. ?1 = worker id, ?2 = worker type.
const REGISTER_WORKER_SQL: &str = r#"
INSERT OR REPLACE INTO Workers (id, worker_type, storage_name, layers, last_seen)
VALUES (?1, ?2, 'LibsqlStorage', '', strftime('%s', 'now'))
"#;
// Refreshes a worker's `last_seen` timestamp so it is not considered dead.
// ?1 = worker id.
const KEEP_ALIVE_SQL: &str = r#"
UPDATE Workers SET last_seen = strftime('%s', 'now') WHERE id = ?1
"#;
// Returns `Running` tasks to `Pending` (clearing their lock) when the worker
// holding them has not been seen for at least ?1 seconds. Scoped to one
// job_type (?2) so each queue only recovers its own tasks.
const REENQUEUE_ORPHANED_SQL: &str = r#"
UPDATE Jobs
SET status = 'Pending', lock_by = NULL, lock_at = NULL
WHERE status = 'Running' AND lock_by IN (
SELECT id FROM Workers WHERE last_seen < strftime('%s', 'now') - ?1
) AND job_type = ?2
"#;
/// A libsql-backed apalis storage for tasks with arguments `T`, encoded with
/// codec `C`.
///
/// Holds a `'static` database handle shared by every clone, plus a pinned
/// sink used to push tasks. `T` and `C` only exist at the type level, hence
/// the `PhantomData` fields.
#[pin_project]
pub struct LibsqlStorage<T, C> {
    // Static handle: callers typically obtain it via `Box::leak` (see tests).
    db: &'static Database,
    config: Config,
    job_type: PhantomData<T>,
    codec: PhantomData<C>,
    // `#[pin]` because the sink is driven through the `futures::Sink` trait.
    #[pin]
    sink: LibsqlSink<T, C>,
}
/// Manual `Debug`: the database handle is opaque, so it is rendered as a
/// placeholder, and the phantom type parameters are shown by name.
impl<T, C> fmt::Debug for LibsqlStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("LibsqlStorage");
        builder.field("db", &"Database");
        builder.field("config", &self.config);
        builder.field("job_type", &std::any::type_name::<T>());
        builder.field("codec", &std::any::type_name::<C>());
        builder.finish()
    }
}
/// Manual `Clone` rather than `#[derive(Clone)]`, so no `T: Clone`/`C: Clone`
/// bounds are imposed on the phantom type parameters.
impl<T, C> Clone for LibsqlStorage<T, C> {
    fn clone(&self) -> Self {
        Self {
            // The `'static` handle is Copy; only config and sink are cloned.
            db: self.db,
            sink: self.sink.clone(),
            config: self.config.clone(),
            job_type: PhantomData,
            codec: PhantomData,
        }
    }
}
impl<T> LibsqlStorage<T, ()> {
    /// Creates a storage for `T` with a default [`Config`] (queue named after
    /// the argument type) and the JSON codec.
    #[must_use]
    pub fn new(
        db: &'static Database,
    ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
        // Delegate so construction logic lives in exactly one place.
        Self::new_with_config(db, Config::new(std::any::type_name::<T>()))
    }

    /// Creates a storage for `T` with an explicit [`Config`] and the JSON
    /// codec.
    #[must_use]
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_with_config(
        db: &'static Database,
        config: Config,
    ) -> LibsqlStorage<T, apalis_core::backend::codec::json::JsonCodec<CompactType>> {
        // Build the sink first so `config` can then be moved into the struct
        // without the redundant clone the previous version needed.
        let sink = LibsqlSink::new(db, &config);
        LibsqlStorage {
            db,
            config,
            job_type: PhantomData,
            codec: PhantomData,
            sink,
        }
    }
}
impl<T, C> LibsqlStorage<T, C> {
    /// Returns the underlying database handle.
    #[must_use]
    pub fn db(&self) -> &'static Database {
        self.db
    }

    /// Returns the storage configuration.
    #[must_use]
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Runs the bundled schema migration against the database.
    ///
    /// # Errors
    /// Returns [`LibsqlError::Database`] if connecting or executing the
    /// migration batch fails.
    pub async fn setup(&self) -> Result<(), LibsqlError> {
        let conn = self.db.connect()?;
        // `#[from]` on LibsqlError::Database lets `?` convert the error
        // directly; no map_err needed.
        conn.execute_batch(include_str!("../migrations/001_initial.sql"))
            .await?;
        Ok(())
    }

    /// Rebuilds this storage with a different codec type `D`.
    #[must_use]
    pub fn with_codec<D>(self) -> LibsqlStorage<T, D> {
        // The sink is parameterized by the codec type, so a fresh one is
        // required; build it before moving `config`, avoiding the redundant
        // clone the previous version performed (self is consumed here).
        let sink = LibsqlSink::new(self.db, &self.config);
        LibsqlStorage {
            db: self.db,
            config: self.config,
            job_type: PhantomData,
            codec: PhantomData,
            sink,
        }
    }
}
/// Registers (or refreshes) this worker's row in the `Workers` table.
///
/// # Errors
/// Returns [`LibsqlError::Database`] if connecting or executing fails.
async fn register_worker(
    db: &'static Database,
    worker_id: &str,
    worker_type: &str,
) -> Result<(), LibsqlError> {
    let conn = db.connect()?;
    // `?` converts libsql errors via `#[from]`, consistent with `connect()?`
    // above — the explicit map_err was redundant.
    conn.execute(REGISTER_WORKER_SQL, libsql::params![worker_id, worker_type])
        .await?;
    Ok(())
}
/// Refreshes the worker's `last_seen` timestamp so it is not treated as dead.
///
/// # Errors
/// Returns [`LibsqlError::Database`] if connecting or executing fails.
async fn keep_alive(db: &'static Database, worker_id: &str) -> Result<(), LibsqlError> {
    let conn = db.connect()?;
    // `?` converts libsql errors via `#[from]`; the map_err was redundant.
    conn.execute(KEEP_ALIVE_SQL, libsql::params![worker_id]).await?;
    Ok(())
}
/// Returns tasks locked by dead workers (not seen within
/// `config.reenqueue_orphaned_after()`) to the `Pending` state, scoped to
/// this config's queue. Returns the number of rows updated.
///
/// # Errors
/// Returns [`LibsqlError::Database`] if connecting or executing fails.
pub async fn reenqueue_orphaned(
    db: &'static Database,
    config: &Config,
) -> Result<u64, LibsqlError> {
    let conn = db.connect()?;
    // Cast is safe in practice: a configured timeout in seconds will not
    // approach i64::MAX, though `as` would silently wrap if it did.
    let dead_for = config.reenqueue_orphaned_after().as_secs() as i64;
    let queue = config.queue().to_string();
    let rows = conn
        .execute(REENQUEUE_ORPHANED_SQL, libsql::params![dead_for, queue])
        .await?;
    if rows > 0 {
        log::info!("Re-enqueued {} orphaned tasks", rows);
    }
    Ok(rows)
}
/// One-shot startup work for a worker: recover tasks orphaned by dead
/// workers, then register this worker in the `Workers` table.
#[allow(clippy::needless_pass_by_value)]
async fn initial_heartbeat(
    db: &'static Database,
    config: Config,
    worker: WorkerContext,
) -> Result<(), LibsqlError> {
    let id = worker.name().to_string();
    let kind = config.queue().to_string();
    // Recover abandoned tasks before announcing this worker.
    reenqueue_orphaned(db, &config).await?;
    register_worker(db, &id, &kind).await?;
    Ok(())
}
/// Builds an infinite stream that, once per `config.keep_alive()` interval,
/// refreshes this worker's `last_seen` timestamp and re-enqueues tasks
/// orphaned by dead workers. Each tick yields `Ok(())`, or the first error
/// encountered; the stream itself never terminates.
#[allow(clippy::needless_pass_by_value)]
fn heartbeat_stream(
    db: &'static Database,
    config: Config,
    worker: WorkerContext,
) -> impl Stream<Item = Result<(), LibsqlError>> + Send + 'static {
    let worker_id = worker.name().to_string();
    let interval = config.keep_alive();
    futures::stream::unfold((), move |()| {
        // `db` and `interval` are Copy and captured directly; only the owned
        // values need cloning into each generated future. (The previous
        // version also rebound `db` and `interval` per tick — dead code.)
        let worker_id = worker_id.clone();
        let config = config.clone();
        async move {
            tokio::time::sleep(interval).await;
            if let Err(e) = keep_alive(db, &worker_id).await {
                return Some((Err(e), ()));
            }
            // Discard the re-enqueued row count; the beat only reports errors.
            let tick = reenqueue_orphaned(db, &config).await.map(|_| ());
            Some((tick, ()))
        }
    })
}
/// The worker-facing backend: supplies the task stream, the heartbeat
/// stream, and the per-task middleware stack.
impl<Args, Decode> Backend for LibsqlStorage<Args, Decode>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;
    type IdType = Ulid;
    type Context = SqlContext;
    type Error = LibsqlError;
    type Stream = apalis_core::backend::TaskStream<LibsqlTask<Args>, LibsqlError>;
    type Beat = BoxStream<'static, Result<(), LibsqlError>>;
    // Stack order: lock the task row first, then acknowledge after handling.
    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<LibsqlAck>>;

    /// Periodic keep-alive plus orphan recovery (see `heartbeat_stream`).
    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let db = self.db;
        let config = self.config.clone();
        let worker = worker.clone();
        heartbeat_stream(db, config, worker).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.db);
        let ack = AcknowledgeLayer::new(LibsqlAck::new(self.db));
        Stack::new(lock, ack)
    }

    /// Produces the task stream: a one-shot `initial_heartbeat` (worker
    /// registration + orphan recovery) that yields no task, chained with the
    /// poll fetcher, decoding each compact payload via `Decode`.
    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let db = self.db;
        let config = self.config.clone();
        let worker = worker.clone();
        // `.map(|res| res.map(|_| None))` adapts the registration future to
        // the stream's item type: one item that carries no task.
        let register = futures::stream::once(
            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
        );
        let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
        register
            .chain(fetcher)
            .map(move |result| match result {
                Ok(Some(task)) => {
                    // Decode failures surface as LibsqlError::Other rather
                    // than killing the stream typewise.
                    let decoded = task
                        .try_map(|t| Decode::decode(&t))
                        .map_err(|e| LibsqlError::Other(e.to_string()))?;
                    Ok(Some(decoded))
                }
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}
/// Extension backend: exposes the codec/compact types and a poll variant
/// that yields raw (undecoded) payloads.
impl<Args, Decode> BackendExt for LibsqlStorage<Args, Decode>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Codec = Decode;
    type Compact = CompactType;
    type CompactStream = apalis_core::backend::TaskStream<LibsqlTask<CompactType>, LibsqlError>;

    /// Same registration + fetch pipeline as `Backend::poll`, but tasks keep
    /// their compact byte payloads — no decoding step.
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        let db = self.db;
        let config = self.config.clone();
        let worker = worker.clone();
        let register = futures::stream::once(
            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
        );
        let fetcher = LibsqlPollFetcher::<Decode>::new(db, &config, &worker);
        register.chain(fetcher).boxed()
    }
}
impl<Args, Decode> LibsqlStorage<Args, Decode>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    /// Like `Backend::poll` but yields tasks with their raw compact payloads,
    /// using the unit-codec fetcher. Starts with the same registration and
    /// orphan-recovery step.
    pub fn poll_default(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<LibsqlTask<CompactType>>, LibsqlError>> + Send + 'static
    {
        let db = self.db;
        let config = self.config.clone();
        let worker = worker.clone();
        let register = futures::stream::once(
            initial_heartbeat(db, config.clone(), worker.clone()).map(|res| res.map(|_| None)),
        );
        // `()` as the codec parameter: payloads pass through untouched.
        let fetcher = LibsqlPollFetcher::<()>::new(db, &config, &worker);
        register.chain(fetcher).boxed()
    }

    /// Manually acknowledges a task's outcome, writing its final status.
    ///
    /// Status mapping:
    /// - `Ok(_)`  -> `Done` (attempts unchanged)
    /// - `Err(_)` -> `Failed`, or `Killed` once this failure would reach
    ///   `max_attempts`
    ///
    /// # Errors
    /// Fails if the task does not exist, is not locked by any worker, the
    /// result cannot be serialized, or the final UPDATE matches no row
    /// (already acked, or the lock holder changed in the meantime).
    pub async fn ack<Res>(
        &mut self,
        task_id: &Ulid,
        result: Result<Res, BoxDynError>,
    ) -> Result<(), LibsqlError>
    where
        Res: serde::Serialize + Send,
    {
        use apalis_core::task::status::Status;
        let task_id_str = task_id.to_string();
        // Persist the handler outcome as JSON; errors are stringified first
        // so the value is always serializable.
        let response = serde_json::to_string(&result.as_ref().map_err(|e| e.to_string()))
            .map_err(|e| LibsqlError::Other(e.to_string()))?;
        let conn = self.db.connect()?;
        // Read the lock holder and attempt counters before deciding status.
        let mut rows = conn
            .query(
                "SELECT lock_by, attempts, max_attempts FROM Jobs WHERE id = ?1",
                libsql::params![task_id_str.clone()],
            )
            .await
            .map_err(LibsqlError::Database)?;
        let (lock_by, current_attempts, max_attempts) =
            match rows.next().await.map_err(LibsqlError::Database)? {
                Some(row) => {
                    let lock_by: Option<String> = row.get(0).map_err(LibsqlError::Database)?;
                    let attempts: i64 = row.get(1).map_err(LibsqlError::Database)?;
                    let max_attempts: i64 = row.get(2).map_err(LibsqlError::Database)?;
                    (lock_by, attempts as i32, max_attempts as i32)
                }
                None => return Err(LibsqlError::Other("Task not found".into())),
            };
        let status = match &result {
            Ok(_) => Status::Done,
            Err(_) => {
                // Retries exhausted -> Killed; otherwise Failed (retryable).
                if current_attempts + 1 >= max_attempts {
                    Status::Killed
                } else {
                    Status::Failed
                }
            }
        };
        let status_str = status.to_string();
        let worker_id =
            lock_by.ok_or_else(|| LibsqlError::Other("Task is not locked by any worker".into()))?;
        // Only failures consume an attempt.
        let new_attempts = match &result {
            Ok(_) => current_attempts,
            Err(_) => current_attempts + 1,
        };
        // NOTE(review): `done_at` is stamped even for `Failed` (retryable)
        // outcomes — confirm the fetcher/schema expect that, or gate it on
        // terminal statuses. The `lock_by` guard below makes the update
        // race-safe against the task being re-locked by another worker.
        let rows_affected = conn
            .execute(
                "UPDATE Jobs SET status = ?1, attempts = ?2, last_error = ?3, done_at = strftime('%s', 'now') WHERE id = ?4 AND lock_by = ?5",
                libsql::params![status_str, new_attempts, response, task_id_str, worker_id],
            )
            .await
            .map_err(LibsqlError::Database)?;
        if rows_affected == 0 {
            return Err(LibsqlError::Other("Task not found or already acked".into()));
        }
        Ok(())
    }
}
/// Switches the database journal to WAL mode for better concurrent access.
///
/// # Errors
/// Returns [`LibsqlError::Database`] if connecting or issuing the PRAGMA
/// fails.
pub async fn enable_wal_mode(db: &'static Database) -> Result<(), LibsqlError> {
    let conn = db.connect()?;
    // The PRAGMA reports the resulting journal mode as a row, which we
    // deliberately ignore; `?` converts the error via `#[from]`.
    conn.query("PRAGMA journal_mode=WAL", libsql::params![]).await?;
    Ok(())
}
/// Forwards the `Sink` interface to the inner [`LibsqlSink`], letting callers
/// push already-encoded (compact) tasks straight into storage.
///
/// NOTE(review): the `Codec` type parameter here shadows the imported `Codec`
/// trait name; consider renaming it (e.g. `C`) for clarity.
impl<Args, Codec> futures::Sink<LibsqlTask<CompactType>> for LibsqlStorage<Args, Codec>
where
    Args: Send + Sync + 'static,
{
    type Error = LibsqlError;

    fn poll_ready(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // Pin-project to the `#[pin]`-marked sink field and delegate.
        self.project().sink.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: LibsqlTask<CompactType>) -> Result<(), Self::Error> {
        self.project().sink.start_send(item)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.project().sink.poll_flush(cx)
    }

    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.project().sink.poll_close(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Smoke test: open a local database, verify raw connectivity, run the
    /// migration, and confirm the `Jobs` table exists afterwards.
    #[tokio::test]
    async fn test_basic_connectivity() -> Result<(), Box<dyn std::error::Error>> {
        let tmp = TempDir::new()?;
        let path = tmp.path().join("test.db");
        let db = libsql::Builder::new_local(path.to_str().unwrap())
            .build()
            .await?;
        // Leak the database to obtain the `'static` handle the storage needs.
        let db_static: &'static Database = Box::leak(Box::new(db));
        let storage = LibsqlStorage::<(), ()>::new(db_static);

        // Sanity-check raw connectivity before touching the schema.
        let conn = db_static.connect()?;
        let mut rows = conn.query("SELECT 1", libsql::params![]).await?;
        let one: i32 = rows.next().await?.unwrap().get(0)?;
        assert_eq!(one, 1);

        storage.setup().await?;
        enable_wal_mode(db_static).await?;

        // Migrations must have created the Jobs table.
        let mut rows = conn
            .query(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='Jobs'",
                libsql::params![],
            )
            .await?;
        match rows.next().await? {
            Some(row) => {
                let name: String = row.get(0)?;
                assert_eq!(name, "Jobs");
            }
            None => panic!("Jobs table should exist after setup"),
        }
        drop(conn);
        Ok(())
    }
}