use crate::context::SqlContext;
use crate::{calculate_status, Config};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::Error;
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Request, RequestStream};
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::WorkerId;
use apalis_core::{Backend, Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures::StreamExt;
use futures::{select, stream, SinkExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::postgres::PgListener;
use sqlx::{Pool, Postgres, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};
pub use sqlx::postgres::PgPool;

use crate::from_row::SqlRequest;

/// The timestamp representation (Unix seconds) used by this storage.
type Timestamp = i64;
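/// A [`Storage`] for PostgreSQL, keeping jobs in the `apalis.jobs` table.
///
/// Jobs are fetched in batches via the `apalis.get_jobs` stored procedure,
/// and workers can additionally be woken up in real time through [`PgListen`].
///
/// A minimal setup sketch (the connection string and the job type are
/// placeholders, and `setup` requires the `migrate` feature):
///
/// ```rust,ignore
/// use apalis_sql::postgres::{PgPool, PostgresStorage};
///
/// async fn connect() -> Result<(), sqlx::Error> {
///     let pool = PgPool::connect("postgres://localhost/apalis").await?;
///     // Run the bundled migrations once per database.
///     PostgresStorage::setup(&pool).await?;
///     let _storage: PostgresStorage<serde_json::Value> = PostgresStorage::new(pool);
///     Ok(())
/// }
/// ```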
pub struct PostgresStorage<T, C = JsonCodec<serde_json::Value>>
where
C: Codec,
{
pool: PgPool,
job_type: PhantomData<T>,
codec: PhantomData<C>,
config: Config,
controller: Controller,
ack_notify: Notify<(SqlContext, Result<C::Compact, Error>)>,
subscription: Option<PgSubscription>,
}
impl<T, C: Codec> Clone for PostgresStorage<T, C> {
fn clone(&self) -> Self {
PostgresStorage {
pool: self.pool.clone(),
job_type: PhantomData,
codec: PhantomData,
config: self.config.clone(),
controller: self.controller.clone(),
ack_notify: self.ack_notify.clone(),
subscription: self.subscription.clone(),
}
}
}
impl<T, C: Codec> fmt::Debug for PostgresStorage<T, C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PostgresStorage")
.field("pool", &self.pool)
.field("job_type", &"PhantomData<T>")
.field("controller", &self.controller)
.field("config", &self.config)
.field("codec", &std::any::type_name::<C>())
.finish()
}
}
impl<T, C, Res> Backend<Request<T>, Res> for PostgresStorage<T, C>
where
T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
C: Codec<Compact = serde_json::Value> + Send + 'static,
{
type Stream = BackendStream<RequestStream<Request<T>>>;
type Layer = AckLayer<PostgresStorage<T, C>, T, Res>;
fn poll<Svc>(mut self, worker: WorkerId) -> Poller<Self::Stream, Self::Layer> {
let layer = AckLayer::new(self.clone());
let subscription = self.subscription.clone();
let config = self.config.clone();
let controller = self.controller.clone();
let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
let ack_notify = self.ack_notify.clone();
let pool = self.pool.clone();
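        // Background future driving this storage: it refreshes the worker's
        // keep-alive row, flushes buffered acknowledgements in batches, and
        // fetches new jobs both on a poll interval and on LISTEN/NOTIFY wakeups.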
let heartbeat = async move {
let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();
let mut ack_stream = ack_notify.clone().ready_chunks(config.buffer_size).fuse();
let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();
            let mut pg_notification = subscription
                .map(|stm| stm.notify.boxed().fuse())
                .unwrap_or_else(|| stream::iter(vec![]).boxed().fuse());
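            // Drains one batch of ready jobs from the database into the
            // bounded channel feeding the worker stream.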
async fn fetch_next_batch<
T: Unpin + DeserializeOwned + Send + 'static,
C: Codec<Compact = Value>,
>(
storage: &mut PostgresStorage<T, C>,
worker: &WorkerId,
tx: &mut mpsc::Sender<Result<Option<Request<T>>, Error>>,
) -> Result<(), Error> {
let res = storage
.fetch_next(worker)
.await
.map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
for job in res {
tx.send(Ok(Some(job)))
.await
.map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
}
Ok(())
}
if let Err(e) = self
.keep_alive_at::<Self::Layer>(&worker, Utc::now().timestamp())
.await
{
error!("KeepAliveError: {}", e);
}
loop {
select! {
_ = keep_alive_stm.next() => {
if let Err(e) = self.keep_alive_at::<Self::Layer>(&worker, Utc::now().timestamp()).await {
error!("KeepAliveError: {}", e);
}
}
ids = ack_stream.next() => {
if let Some(ids) = ids {
                            // Each tuple is (id, worker_id, result, status, attempts);
                            // the whole batch is bound below as one JSON array.
                            let ack_ids: Vec<(String, String, String, String, u64)> = ids
                                .iter()
                                .map(|(ctx, res)| {
                                    (
                                        ctx.id().to_string(),
                                        ctx.lock_by()
                                            .clone()
                                            .expect("an acknowledged job must be locked by a worker")
                                            .to_string(),
                                        serde_json::to_string(&res.as_ref().map_err(|e| e.to_string()))
                                            .expect("serializing a JSON result cannot fail"),
                                        calculate_status(res).to_string(),
                                        // The attempt that just finished is one past the stored counter.
                                        (ctx.attempts().current() + 1) as u64,
                                    )
                                })
                                .collect();
let query =
"UPDATE apalis.jobs
SET status = Q.status,
done_at = now(),
lock_by = Q.worker_id,
last_error = Q.result,
attempts = Q.attempts
FROM (
SELECT (value->>0)::text as id,
(value->>1)::text as worker_id,
(value->>2)::text as result,
(value->>3)::text as status,
(value->>4)::int as attempts
FROM json_array_elements($1::json)
) Q
WHERE apalis.jobs.id = Q.id;
";
if let Err(e) = sqlx::query(query)
.bind(serde_json::to_value(&ack_ids).unwrap())
.execute(&pool)
.await
{
panic!("AckError: {e}");
}
}
}
_ = poll_next_stm.next() => {
if let Err(e) = fetch_next_batch(&mut self, &worker, &mut tx).await {
error!("FetchNextError: {e}");
}
}
_ = pg_notification.next() => {
if let Err(e) = fetch_next_batch(&mut self, &worker, &mut tx).await {
error!("PgNotificationError: {e}");
}
}
};
}
};
Poller::new_with_layer(BackendStream::new(rx.boxed(), controller), heartbeat, layer)
}
}
impl PostgresStorage<()> {
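    /// Returns the migrations bundled with this backend
    /// (the `migrations/postgres` directory of this crate).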
#[cfg(feature = "migrate")]
pub fn migrations() -> sqlx::migrate::Migrator {
sqlx::migrate!("migrations/postgres")
}
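    /// Runs the bundled migrations against `pool`, creating the `apalis`
    /// schema, tables, and stored procedures if they are missing.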
#[cfg(feature = "migrate")]
pub async fn setup(pool: &Pool<Postgres>) -> Result<(), sqlx::Error> {
Self::migrations().run(pool).await?;
Ok(())
}
}
impl<T> PostgresStorage<T> {
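    /// Creates a storage for jobs of type `T` with the default [`Config`],
    /// namespaced by `type_name::<T>()`.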
pub fn new(pool: PgPool) -> Self {
Self::new_with_config(pool, Config::new(type_name::<T>()))
}
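    /// Creates a storage with an explicit [`Config`], e.g. to override the
    /// namespace, buffer size, or poll interval.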
pub fn new_with_config(pool: PgPool, config: Config) -> Self {
Self {
pool,
job_type: PhantomData,
codec: PhantomData,
config,
controller: Controller::new(),
ack_notify: Notify::new(),
subscription: None,
}
}
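    /// Expose the underlying connection pool.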
pub fn pool(&self) -> &Pool<Postgres> {
&self.pool
}
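    /// Expose the configuration used by this storage.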
pub fn config(&self) -> &Config {
&self.config
}
}
impl<T, C: Codec> PostgresStorage<T, C> {
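    /// Expose the codec marker. `C` exists only at the type level, so this
    /// returns the [`PhantomData`] that carries it.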
pub fn codec(&self) -> &PhantomData<C> {
&self.codec
}
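    /// Registers the worker, or refreshes its `last_seen` timestamp if it is
    /// already present in `apalis.workers`. `last_seen` is in Unix seconds.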
async fn keep_alive_at<Service>(
&mut self,
worker_id: &WorkerId,
last_seen: Timestamp,
) -> Result<(), sqlx::Error> {
let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io(
io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"),
))?;
let worker_type = self.config.namespace.clone();
let storage_name = std::any::type_name::<Self>();
let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO
UPDATE SET last_seen = EXCLUDED.last_seen";
sqlx::query(query)
.bind(worker_id.to_string())
.bind(worker_type)
.bind(storage_name)
.bind(std::any::type_name::<Service>())
.bind(last_seen)
.execute(&self.pool)
.await?;
Ok(())
}
}
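/// A shared `LISTEN` connection that fans notifications on the `apalis::job`
/// channel out to per-namespace subscriptions, so workers pick up new jobs
/// without waiting for the next poll interval.
///
/// A minimal wiring sketch (the pool, the job type, and `tokio::spawn` as
/// the executor are placeholders):
///
/// ```rust,ignore
/// use apalis_sql::postgres::{PgListen, PgPool, PostgresStorage};
///
/// async fn wire_up(pool: PgPool) -> Result<(), sqlx::Error> {
///     let mut storage: PostgresStorage<serde_json::Value> =
///         PostgresStorage::new(pool.clone());
///     let mut listener = PgListen::new(pool).await?;
///     listener.subscribe_with(&mut storage);
///     // Keep the listener running for as long as the workers do.
///     tokio::spawn(listener.listen());
///     Ok(())
/// }
/// ```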
#[derive(Debug)]
pub struct PgListen {
listener: PgListener,
subscriptions: Vec<(String, PgSubscription)>,
}
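/// A handle to notifications for a single namespace, produced by
/// [`PgListen::subscribe`] or [`PgListen::subscribe_with`].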
#[derive(Debug, Clone)]
pub struct PgSubscription {
notify: Notify<()>,
}
impl PgListen {
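    /// Opens a dedicated `LISTEN` connection from the pool.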
pub async fn new(pool: PgPool) -> Result<Self, sqlx::Error> {
let listener = PgListener::connect_with(&pool).await?;
Ok(Self {
listener,
subscriptions: Vec::new(),
})
}
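    /// Subscribes a storage to notifications for its namespace, so its
    /// poller is woken as soon as a matching job is pushed.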
pub fn subscribe_with<T>(&mut self, storage: &mut PostgresStorage<T>) {
let sub = PgSubscription {
notify: Notify::new(),
};
self.subscriptions
.push((storage.config.namespace.to_owned(), sub.clone()));
storage.subscription = Some(sub)
}
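    /// Subscribes to notifications for an arbitrary namespace.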
pub fn subscribe(&mut self, namespace: &str) -> PgSubscription {
let sub = PgSubscription {
notify: Notify::new(),
};
self.subscriptions.push((namespace.to_owned(), sub.clone()));
sub
}
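    /// Starts listening on the `apalis::job` channel, forwarding each
    /// notification to the subscriptions whose namespace matches its
    /// payload. Runs until the notification stream ends.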
pub async fn listen(mut self) -> Result<(), sqlx::Error> {
self.listener.listen("apalis::job").await?;
let mut notification = self.listener.into_stream();
        while let Some(Ok(res)) = notification.next().await {
            // The payload is the namespace of the job that was just pushed;
            // wake every subscription registered under it.
            for (_, subscription) in self
                .subscriptions
                .iter()
                .filter(|(namespace, _)| namespace.as_str() == res.payload())
            {
                let _ = subscription.notify.notify(());
            }
        }
Ok(())
}
}
impl<T, C> PostgresStorage<T, C>
where
T: DeserializeOwned + Send + Unpin + 'static,
C: Codec<Compact = serde_json::Value>,
{
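    /// Fetches up to `buffer_size` jobs for `worker_id` via the
    /// `apalis.get_jobs` stored procedure, which also marks them as
    /// `Running` and locks them to the worker.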
async fn fetch_next(&mut self, worker_id: &WorkerId) -> Result<Vec<Request<T>>, sqlx::Error> {
let config = &self.config;
let job_type = &config.namespace;
let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);";
let jobs: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
.bind(worker_id.to_string())
.bind(job_type)
.bind(
i32::try_from(config.buffer_size)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
)
.fetch_all(&self.pool)
.await?;
        let jobs = jobs
            .into_iter()
            .map(|job| {
                let (req, ctx) = job.into_tuple();
                // Propagate decode failures to the caller instead of panicking mid-poll.
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let req = SqlRequest::new(req, ctx);
                let mut req: Request<T> = req.into();
                req.insert(Namespace(self.config.namespace.clone()));
                Ok(req)
            })
            .collect::<Result<Vec<_>, sqlx::Error>>()?;
        Ok(jobs)
}
}
impl<T, C> Storage for PostgresStorage<T, C>
where
T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
C: Codec<Compact = Value> + Send + 'static,
{
type Job = T;
type Error = sqlx::Error;
type Identifier = TaskId;
async fn push(&mut self, job: Self::Job) -> Result<TaskId, sqlx::Error> {
let id = TaskId::new();
let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
let job = C::encode(&job)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
let job_type = self.config.namespace.clone();
sqlx::query(query)
.bind(job)
.bind(id.to_string())
.bind(&job_type)
.execute(&self.pool)
.await?;
Ok(id)
}
async fn schedule(&mut self, job: Self::Job, on: Timestamp) -> Result<TaskId, sqlx::Error> {
let query =
"INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
let id = TaskId::new();
        // Reject out-of-range timestamps instead of silently inserting a NULL `run_at`.
        let on = DateTime::from_timestamp(on, 0).ok_or(sqlx::Error::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Invalid Timestamp",
        )))?;
        let job = C::encode(&job)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
let job_type = self.config.namespace.clone();
sqlx::query(query)
.bind(job)
.bind(id.to_string())
.bind(job_type)
.bind(on)
.execute(&self.pool)
.await?;
Ok(id)
}
async fn fetch_by_id(
&mut self,
job_id: &TaskId,
) -> Result<Option<Request<Self::Job>>, sqlx::Error> {
let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
.bind(job_id.to_string())
.fetch_optional(&self.pool)
.await?;
match res {
None => Ok(None),
Some(job) => Ok(Some({
let (req, ctx) = job.into_tuple();
let req = C::decode(req)
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
let req = SqlRequest::new(req, ctx);
let mut req: Request<T> = req.into();
req.insert(Namespace(self.config.namespace.clone()));
req
})),
}
}
async fn len(&mut self) -> Result<i64, sqlx::Error> {
let query = "Select Count(*) as count from apalis.jobs where status='Pending'";
let record = sqlx::query(query).fetch_one(&self.pool).await?;
record.try_get("count")
}
async fn reschedule(&mut self, job: Request<T>, wait: Duration) -> Result<(), sqlx::Error> {
let ctx = job
.get::<SqlContext>()
.ok_or(sqlx::Error::Io(io::Error::new(
io::ErrorKind::InvalidData,
"Missing SqlContext",
)))?;
let job_id = ctx.id();
let on = Utc::now() + wait;
let mut tx = self.pool.acquire().await?;
let query =
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
sqlx::query(query)
.bind(job_id.to_string())
.bind(on)
.execute(&mut *tx)
.await?;
Ok(())
}
async fn update(&mut self, job: Request<Self::Job>) -> Result<(), sqlx::Error> {
let ctx = job
.get::<SqlContext>()
.ok_or(sqlx::Error::Io(io::Error::new(
io::ErrorKind::InvalidData,
"Missing SqlContext",
)))?;
let job_id = ctx.id();
let status = ctx.status().to_string();
let attempts: i32 = ctx
.attempts()
.current()
.try_into()
.map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
let done_at = *ctx.done_at();
let lock_by = ctx.lock_by().clone();
let lock_at = *ctx.lock_at();
let last_error = ctx.last_error().clone();
let mut tx = self.pool.acquire().await?;
let query =
"UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
sqlx::query(query)
.bind(status.to_owned())
.bind(attempts)
.bind(done_at)
.bind(lock_by.map(|w| w.name().to_string()))
.bind(lock_at)
.bind(last_error)
.bind(job_id.to_string())
.execute(&mut *tx)
.await?;
Ok(())
}
async fn is_empty(&mut self) -> Result<bool, sqlx::Error> {
Ok(self.len().await? == 0)
}
async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
let query = "Delete from apalis.jobs where status='Done'";
let record = sqlx::query(query).execute(&self.pool).await?;
Ok(record.rows_affected().try_into().unwrap_or_default())
}
}
impl<T, Res, C> Ack<T, Res> for PostgresStorage<T, C>
where
T: Sync + Send,
Res: Serialize + Sync,
C: Codec<Compact = Value> + Send,
{
type Context = SqlContext;
type AckError = sqlx::Error;
async fn ack(
&mut self,
ctx: &Self::Context,
res: &Result<Res, apalis_core::error::Error>,
) -> Result<(), sqlx::Error> {
        // Encode the result up front so an encoding failure surfaces as an
        // error on the acknowledgement instead of panicking the worker.
        let result = res
            .as_ref()
            .map_err(|e| e.clone())
            .and_then(|r| C::encode(r).map_err(|e| Error::SourceError(Arc::new(e.into()))));
        self.ack_notify
            .notify((ctx.clone(), result))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))?;
Ok(())
}
}
impl<T, C: Codec> PostgresStorage<T, C> {
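    /// Marks a job as `Killed`, provided it is still locked by `worker_id`.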
pub async fn kill(
&mut self,
worker_id: &WorkerId,
task_id: &TaskId,
) -> Result<(), sqlx::Error> {
let mut tx = self.pool.acquire().await?;
let query =
"UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
sqlx::query(query)
.bind(task_id.to_string())
.bind(worker_id.to_string())
.execute(&mut *tx)
.await?;
Ok(())
}
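    /// Puts a job locked by `worker_id` back into the `Pending` state so it
    /// can be picked up again.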
pub async fn retry(
&mut self,
worker_id: &WorkerId,
task_id: &TaskId,
) -> Result<(), sqlx::Error> {
let mut tx = self.pool.acquire().await?;
let query =
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
sqlx::query(query)
.bind(task_id.to_string())
.bind(worker_id.to_string())
.execute(&mut *tx)
.await?;
Ok(())
}
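    /// Re-enqueues up to `count` jobs stuck in `Running` whose workers have
    /// not been seen for over five minutes (the interval is fixed in the
    /// query below), recording the error `Job was abandoned`.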
pub async fn reenqueue_orphaned(&mut self, count: i32) -> Result<(), sqlx::Error> {
let job_type = self.config.namespace.clone();
let mut tx = self.pool.acquire().await?;
let query = "Update apalis.jobs
SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
WHERE id in
(SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
WHERE status= 'Running' AND workers.last_seen < (NOW() - INTERVAL '300 seconds')
AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
sqlx::query(query)
.bind(job_type)
.bind(count)
.execute(&mut *tx)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::context::State;
use crate::sql_storage_tests;
use super::*;
use apalis_core::test_utils::DummyService;
use chrono::Utc;
use email_service::Email;
use apalis_core::generic_storage_test;
use apalis_core::test_utils::apalis_test_service_fn;
use apalis_core::test_utils::TestWrapper;
generic_storage_test!(setup);
sql_storage_tests!(setup::<Email>, PostgresStorage<Email>, Email);
async fn setup<T: Serialize + DeserializeOwned>() -> PostgresStorage<T> {
        let db_url = std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = PgPool::connect(&db_url).await.unwrap();
PostgresStorage::setup(&pool).await.unwrap();
let mut storage = PostgresStorage::new(pool);
cleanup(&mut storage, &WorkerId::new("test-worker")).await;
storage
}
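    /// Rolls back whatever a test created, since all tests share one database.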
async fn cleanup<T>(storage: &mut PostgresStorage<T>, worker_id: &WorkerId) {
let mut tx = storage
.pool
.acquire()
.await
.expect("failed to get connection");
sqlx::query("Delete from apalis.jobs where job_type = $1 OR lock_by = $2")
.bind(storage.config.namespace())
.bind(worker_id.to_string())
.execute(&mut *tx)
.await
.expect("failed to delete jobs");
sqlx::query("Delete from apalis.workers where id = $1")
.bind(worker_id.to_string())
.execute(&mut *tx)
.await
.expect("failed to delete worker");
}
fn example_email() -> Email {
Email {
subject: "Test Subject".to_string(),
to: "example@postgres".to_string(),
text: "Some Text".to_string(),
}
}
async fn consume_one(
storage: &mut PostgresStorage<Email>,
worker_id: &WorkerId,
) -> Request<Email> {
        let req = storage.fetch_next(worker_id).await;
        req.unwrap()
            .first()
            .expect("no job was fetched")
            .clone()
}
async fn register_worker_at(
storage: &mut PostgresStorage<Email>,
last_seen: Timestamp,
) -> WorkerId {
let worker_id = WorkerId::new("test-worker");
storage
.keep_alive_at::<DummyService>(&worker_id, last_seen)
.await
.expect("failed to register worker");
worker_id
}
async fn register_worker(storage: &mut PostgresStorage<Email>) -> WorkerId {
register_worker_at(storage, Utc::now().timestamp()).await
}
async fn push_email(storage: &mut PostgresStorage<Email>, email: Email) {
storage.push(email).await.expect("failed to push a job");
}
async fn get_job(storage: &mut PostgresStorage<Email>, job_id: &TaskId) -> Request<Email> {
apalis_core::sleep(Duration::from_secs(2)).await;
storage
.fetch_by_id(job_id)
.await
.expect("failed to fetch job by id")
.expect("no job found by id")
}
#[tokio::test]
async fn test_consume_last_pushed_job() {
let mut storage = setup().await;
push_email(&mut storage, example_email()).await;
let worker_id = register_worker(&mut storage).await;
let job = consume_one(&mut storage, &worker_id).await;
let ctx = job.get::<SqlContext>().unwrap();
let job_id = ctx.id();
let job = get_job(&mut storage, job_id).await;
let ctx = job.get::<SqlContext>().unwrap();
assert_eq!(*ctx.status(), State::Running);
assert_eq!(*ctx.lock_by(), Some(worker_id.clone()));
assert!(ctx.lock_at().is_some());
}
#[tokio::test]
async fn test_kill_job() {
let mut storage = setup().await;
push_email(&mut storage, example_email()).await;
let worker_id = register_worker(&mut storage).await;
let job = consume_one(&mut storage, &worker_id).await;
let ctx = job.get::<SqlContext>().unwrap();
let job_id = ctx.id();
storage
.kill(&worker_id, job_id)
.await
.expect("failed to kill job");
let job = get_job(&mut storage, job_id).await;
let ctx = job.get::<SqlContext>().unwrap();
assert_eq!(*ctx.status(), State::Killed);
assert!(ctx.done_at().is_some());
}
#[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
let mut storage = setup().await;
push_email(&mut storage, example_email()).await;
let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
let worker_id = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;
let job = consume_one(&mut storage, &worker_id).await;
storage
.reenqueue_orphaned(5)
.await
.expect("failed to heartbeat");
let ctx = job.get::<SqlContext>().unwrap();
let job_id = ctx.id();
let job = get_job(&mut storage, job_id).await;
let ctx = job.get::<SqlContext>().unwrap();
assert_eq!(*ctx.status(), State::Pending);
assert!(ctx.done_at().is_none());
assert!(ctx.lock_by().is_none());
assert!(ctx.lock_at().is_none());
assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_string()));
}
#[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
let mut storage = setup().await;
push_email(&mut storage, example_email()).await;
let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
let worker_id = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;
let job = consume_one(&mut storage, &worker_id).await;
let ctx = job.get::<SqlContext>().unwrap();
assert_eq!(*ctx.status(), State::Running);
storage
.reenqueue_orphaned(5)
.await
.expect("failed to heartbeat");
let job_id = ctx.id();
let job = get_job(&mut storage, job_id).await;
let ctx = job.get::<SqlContext>().unwrap();
assert_eq!(*ctx.status(), State::Running);
assert_eq!(*ctx.lock_by(), Some(worker_id.clone()));
}
}