use std::{fmt::Debug, sync::Arc, time::Duration};
use bonsaidb::core::{
async_trait::async_trait,
connection::AsyncConnection,
document::CollectionDocument,
pubsub::{AsyncPubSub, AsyncSubscriber},
schema::{view::map::MappedDocuments, Collection, SerializedCollection},
transaction::{Operation, Transaction},
Error as BonsaiError,
};
use time::OffsetDateTime;
use crate::{
queue::{DueMessages, Id, Message, MessagePayload, Timestamp, MQ_NOTIFY},
AbortOnDropHandle, CurrentJob, Error, JobRegister,
};
// Shared, thread-safe callback invoked with any error surfaced by a running job.
type ErrorHandler = Arc<dyn Fn(Box<dyn std::error::Error + Send + Sync>) + Send + Sync>;
// Type-erased set holding at most one value per concrete type; populated via
// `JobRunner::set_context` and read back by jobs through `JobRunnerHandle::context`.
type Context = erased_set::ErasedSyncSet;
/// Builder/entry point for the job queue: configures a database-backed
/// runner and spawns its polling loop via [`JobRunner::run`].
pub struct JobRunner<DB> {
    // Database used both for queue storage (views/collections) and pub/sub.
    db: DB,
    // Optional user callback for job errors; `None` means errors are dropped.
    error_handler: Option<ErrorHandler>,
    // Per-type context values jobs can look up at runtime.
    context: Context,
}
impl<DB> JobRunner<DB>
where
DB: AsyncConnection + AsyncPubSub + Debug + 'static,
{
pub fn new(db: DB) -> Self {
Self { db, error_handler: None, context: Context::new() }
}
#[must_use]
pub fn set_error_handler<F>(mut self, handler: F) -> Self
where
F: Fn(Box<dyn std::error::Error + Send + Sync>) + Send + Sync + 'static,
{
self.error_handler = Some(Arc::new(handler));
self
}
#[must_use]
pub fn set_context<C: Clone + Send + Sync + 'static>(mut self, context: C) -> Self {
self.context.insert(context);
self
}
#[must_use]
pub fn run<REG>(self) -> AbortOnDropHandle<Result<(), Error>>
where
REG: JobRegister + Send + Sync + 'static,
{
let internal_runner = InternalJobRunner {
db: Arc::new(self.db),
error_handler: self.error_handler,
context: Arc::new(self.context),
};
tokio::task::spawn(internal_runner.job_queue::<REG>()).into()
}
}
impl<DB: Debug> Debug for JobRunner<DB> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The error handler is an opaque trait object, so render a placeholder.
        let mut dbg = f.debug_struct("JobRunner");
        dbg.field("db", &self.db);
        dbg.field("error_handler", &"<err handler fn>");
        dbg.field("context", &self.context);
        dbg.finish()
    }
}
// `Arc`-wrapped counterpart of `JobRunner`, cheap to clone into spawned
// tasks (each running job holds one via `CurrentJob`).
struct InternalJobRunner<DB> {
    db: Arc<DB>,
    error_handler: Option<ErrorHandler>,
    context: Arc<Context>,
}
impl<DB> Clone for InternalJobRunner<DB> {
fn clone(&self) -> Self {
Self {
db: self.db.clone(),
error_handler: self.error_handler.clone(),
context: self.context.clone(),
}
}
}
impl<DB> InternalJobRunner<DB>
where
    DB: AsyncConnection + AsyncPubSub + Debug + 'static,
{
    /// Queries the `DueMessages` view for every message keyed strictly
    /// before `due_at`, returning the mapped collection documents.
    async fn due_messages(
        &self,
        due_at: Timestamp,
    ) -> Result<MappedDocuments<CollectionDocument<Message>, DueMessages>, BonsaiError> {
        self.db.view::<DueMessages>().with_key_range(..due_at).query_with_collection_docs().await
    }

    /// Computes how long the queue loop may sleep before the next message at
    /// or after `from` comes due.
    async fn next_message_due_in(&self, from: Timestamp) -> Result<Duration, BonsaiError> {
        let nanos = self
            .db
            .view::<DueMessages>()
            .with_key_range(from..)
            .reduce()
            .await?
            // No reduced value => nothing scheduled; poll again in 10s
            // (10_000_000_000 ns). NOTE(review): presumably the view reduces
            // to the earliest due timestamp — confirm against the view impl.
            .map_or(10_000_000_000, |target| target - from);
        // An already-overdue message yields a negative delta; clamp to the
        // u64 range before the narrowing cast so the cast cannot wrap.
        let duration = Duration::from_nanos(nanos.clamp(0, u64::MAX.into()) as u64);
        Ok(duration)
    }

    /// Loads the stored JSON/byte payloads for message `id`; both halves are
    /// `None` when no payload document exists.
    async fn message_payloads(
        &self,
        id: Id,
    ) -> Result<(Option<serde_json::Value>, Option<Vec<u8>>), BonsaiError> {
        Ok(MessagePayload::get_async(&id, self.db.as_ref())
            .await?
            .map_or((None, None), |payload| {
                (payload.contents.payload_json, payload.contents.payload_bytes)
            }))
    }

    /// Scheduler loop: launches due, registered jobs, then sleeps until
    /// either the next message is due or a `MQ_NOTIFY` pub/sub event wakes
    /// it. Runs until the surrounding task is aborted or an error occurs.
    #[tracing::instrument(level = "debug", skip_all, err)]
    async fn job_queue<REG>(self) -> Result<(), Error>
    where
        REG: JobRegister + Send + Sync,
        DB::Subscriber: AsyncSubscriber,
    {
        tracing::debug!("Running JobRunner..");
        // Subscribe before the first poll so notifications published while
        // we are processing messages are not missed.
        let subscriber = self.db.create_subscriber().await?;
        subscriber.subscribe_to(&MQ_NOTIFY).await?;
        loop {
            let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
            let messages = self.due_messages(now).await?;
            tracing::trace!("Found {} due messages.", messages.len());
            for msg in &messages {
                if let Some(job) = REG::from_name(&msg.document.contents.name) {
                    // A message with `execute_after` waits until its
                    // dependency message has been deleted (i.e. completed).
                    if let Some(dependency) = msg.document.contents.execute_after {
                        if Message::get_async(&dependency, self.db.as_ref()).await?.is_some() {
                            continue;
                        }
                    }
                    // `job_update` bumps the execution counter and pushes the
                    // message's next attempt into the future; `false` means
                    // the message vanished or hit `max_executions`, so skip.
                    if self.job_update(msg.document.contents.id).await? {
                        let payloads = self.message_payloads(msg.document.contents.id).await?;
                        let current_job = CurrentJob {
                            id: msg.document.contents.id,
                            name: job.name(),
                            db: Arc::new(self.clone()),
                            payload_json: payloads.0,
                            payload_bytes: payloads.1,
                            keep_alive: None,
                        };
                        // Handle is dropped deliberately: jobs execute
                        // concurrently with this loop. NOTE(review): confirm
                        // `CurrentJob::run` detaches rather than aborting the
                        // job when the handle is dropped.
                        let _jh = current_job.run(job.function());
                    }
                } else {
                    tracing::trace!(
                        "Job {} is not registered and will be ignored.",
                        msg.document.contents.name
                    );
                }
            }
            // Sleep until the next message is due, but wake early if anyone
            // publishes on MQ_NOTIFY. A timeout (`Err(Elapsed)`) is a normal
            // wakeup — `.ok()` discards it — while a subscriber receive error
            // propagates through `transpose()?`.
            let next_due_in = self.next_message_due_in(now).await?;
            tokio::time::timeout(next_due_in, subscriber.receiver().receive_async())
                .await
                .ok()
                .transpose()?;
        }
    }
}
impl<DB: Debug> Debug for InternalJobRunner<DB> {
    /// Formats the runner for diagnostics. The error handler is an opaque
    /// trait object, so it is rendered as a placeholder string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fix: previously reported the struct name as "JobRunner" (copy-paste
        // from the public type's impl), which made debug output from the
        // internal runner indistinguishable from `JobRunner`'s.
        f.debug_struct("InternalJobRunner")
            .field("db", &self.db)
            .field("error_handler", &"<err handler fn>")
            .field("context", &self.context)
            .finish()
    }
}
/// Operations a running job needs from its runner: context access, error
/// reporting, completion, heartbeats, and checkpointing. Object-safe so it
/// can be held behind an `Arc` (see the `db` field of `CurrentJob`).
#[async_trait]
pub(crate) trait JobRunnerHandle: Debug {
    /// Type-erased context values supplied via `JobRunner::set_context`.
    fn context(&self) -> &Context;
    /// Forwards `err` to the user-installed error handler, if any.
    fn handle_job_error(&self, err: Box<dyn std::error::Error + Send + Sync>);
    /// Removes the message (and its payload) for `id` and notifies the queue.
    async fn complete(&self, id: Id) -> Result<(), Error>;
    /// Pushes the message's next attempt time forward; returns the extension
    /// granted (zero if the message no longer exists).
    async fn keep_alive(&self, id: Id) -> Result<Duration, Error>;
    /// Records an execution attempt; `Ok(true)` when the job should run now.
    async fn job_update(&self, id: Id) -> Result<bool, Error>;
    /// Publishes a wake-up on the queue's notify topic.
    async fn notify(&self) -> Result<(), Error>;
    /// Overwrites the stored payloads for `id` so a retried job can resume
    /// from the last checkpoint.
    async fn checkpoint(
        &self,
        id: Id,
        payload_json: Option<serde_json::Value>,
        payload_bytes: Option<Vec<u8>>,
    ) -> Result<(), Error>;
}
#[async_trait]
impl<DB> JobRunnerHandle for InternalJobRunner<DB>
where
    DB: AsyncConnection + AsyncPubSub + Debug + 'static,
{
    /// Exposes the shared, type-erased context set to running jobs.
    fn context(&self) -> &Context {
        &self.context
    }

    /// Invokes the configured error handler; errors are dropped silently
    /// when no handler was installed.
    fn handle_job_error(&self, err: Box<dyn std::error::Error + Send + Sync>) {
        if let Some(err_handler) = &self.error_handler {
            err_handler(err);
        }
    }

    /// Deletes the message and its payload document (whichever exist) in a
    /// single transaction, then publishes on `MQ_NOTIFY` so waiting runners
    /// re-evaluate their schedules.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn complete(&self, id: Id) -> Result<(), Error> {
        tracing::trace!("Completing job {id}.");
        let del_message = Message::get_async(&id, self.db.as_ref()).await?.map(|msg| msg.header);
        let del_payload =
            MessagePayload::get_async(&id, self.db.as_ref()).await?.map(|payload| payload.header);
        let mut tx = Transaction::new();
        if let Some(header) = del_message {
            tx.push(Operation::delete(Message::collection_name(), header.try_into()?));
        }
        if let Some(header) = del_payload {
            tx.push(Operation::delete(MessagePayload::collection_name(), header.try_into()?));
        }
        // A concurrent deletion between the fetches above and this apply is
        // benign — treat `DocumentNotFound` as already completed.
        match tx.apply_async(self.db.as_ref()).await {
            Err(BonsaiError::DocumentNotFound(_, _)) => {}
            Err(err) => return Err(err.into()),
            Ok(_) => {}
        };
        self.db.publish(&MQ_NOTIFY, &()).await?;
        Ok(())
    }

    /// Extends the message's lease: moves `attempt_at` to now plus the
    /// retry-timing duration for its current execution count. Returns that
    /// duration, or a zero `Duration` when the message no longer exists.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn keep_alive(&self, id: Id) -> Result<Duration, Error> {
        if let Some(mut message) = Message::get_async(&id, self.db.as_ref()).await? {
            tracing::trace!("Keeping job {id} alive.");
            let duration = message.contents.retry_timing.next_duration(message.contents.executions);
            let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
            message.contents.attempt_at = now + Timestamp::try_from(duration.as_nanos())?;
            message.update_async(self.db.as_ref()).await?;
            Ok(duration)
        } else {
            Ok(Duration::default())
        }
    }

    /// Records an execution attempt for `id`. Returns `Ok(true)` when the
    /// job should run now, `Ok(false)` when the message is gone or has
    /// exceeded `max_executions` (in which case it is completed/deleted).
    /// Rescheduling `attempt_at` before the run doubles as a retry safety
    /// net: if the job dies, the message becomes due again on its own.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn job_update(&self, id: Id) -> Result<bool, Error> {
        if let Some(mut message) = Message::get_async(&id, self.db.as_ref()).await? {
            tracing::trace!("Updating job {id} for execution/retry.");
            message.contents.executions += 1;
            if message
                .contents
                .max_executions
                .map_or(false, |max| message.contents.executions > max)
            {
                // Attempts exhausted: remove the message instead of running.
                self.complete(id).await?;
                return Ok(false);
            }
            let duration = message.contents.retry_timing.next_duration(message.contents.executions);
            let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
            message.contents.attempt_at = now + Timestamp::try_from(duration.as_nanos())?;
            message.update_async(self.db.as_ref()).await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Wakes any queue loops blocked on the `MQ_NOTIFY` topic.
    async fn notify(&self) -> Result<(), Error> {
        self.db.publish(&MQ_NOTIFY, &()).await?;
        Ok(())
    }

    /// Persists updated payloads for `id` so a retried execution can resume
    /// from the latest checkpoint. Silently a no-op when no payload document
    /// exists for the message.
    async fn checkpoint(
        &self,
        id: Id,
        payload_json: Option<serde_json::Value>,
        payload_bytes: Option<Vec<u8>>,
    ) -> Result<(), Error> {
        if let Some(mut payloads) = MessagePayload::get_async(&id, self.db.as_ref()).await? {
            payloads.contents.payload_json = payload_json;
            payloads.contents.payload_bytes = payload_bytes;
            payloads.update_async(self.db.as_ref()).await?;
        }
        Ok(())
    }
}