use std::sync::Arc;
use serde::{de::DeserializeOwned, Serialize};
use tokio::task::JoinHandle;
use tokio_retry::{strategy::FixedInterval, RetryIf};
use tracing_futures::Instrument;
use crate::{queue::Id, runner::JobRunnerHandle, AbortOnDropHandle, Error, JobFunctionType};
/// Shared, thread-safe handle to the job runner's storage backend.
type JobRunnerHandler = Arc<dyn JobRunnerHandle + Send + Sync>;
/// State handle for a job currently being executed by the runner.
///
/// Passed to the job function; exposes the job's identity and payloads, and
/// provides completion and checkpoint operations backed by the runner's store.
#[derive(Debug)]
pub struct CurrentJob {
    // Unique identifier of this job in the queue.
    pub(crate) id: Id,
    // Static name the job was registered under.
    pub(crate) name: &'static str,
    // Handle to the runner/store used for keep-alive, completion and checkpoints.
    pub(crate) db: JobRunnerHandler,
    // Optional JSON payload supplied when the job was enqueued.
    pub(crate) payload_json: Option<serde_json::Value>,
    // Optional raw byte payload supplied when the job was enqueued.
    pub(crate) payload_bytes: Option<Vec<u8>>,
    // Background keep-alive task handle; aborted on completion (and, per the
    // AbortOnDropHandle wrapper, presumably on drop — confirm in its definition).
    pub(crate) keep_alive: Option<AbortOnDropHandle<Result<(), Error>>>,
}
impl CurrentJob {
#[must_use]
pub fn id(&self) -> Id {
self.id
}
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
#[must_use]
pub fn payload_json<D: DeserializeOwned>(&self) -> Option<Result<D, serde_json::Error>> {
self.payload_json.as_ref().map(|payload| serde_json::from_value(payload.clone()))
}
#[must_use]
pub fn payload_bytes(&self) -> Option<&Vec<u8>> {
self.payload_bytes.as_ref()
}
#[must_use]
pub fn context<C: Clone + Send + Sync + 'static>(&self) -> Option<C> {
self.db.context().get::<C>().cloned()
}
pub async fn complete(&mut self) -> Result<(), Error> {
RetryIf::spawn(
FixedInterval::from_millis(10).take(2),
|| self.db.complete(self.id),
Error::should_retry,
)
.await?;
if let Some(keep_alive) = self.keep_alive.take() {
keep_alive.abort();
};
Ok(())
}
#[must_use]
pub fn checkpoint(&mut self) -> Checkpoint<'_> {
Checkpoint::new(self)
}
pub(crate) fn keep_alive(db: JobRunnerHandler, id: Id) -> JoinHandle<Result<(), Error>> {
let span = tracing::debug_span!("job-keep-alive");
tokio::task::spawn(
async move {
loop {
let duration = RetryIf::spawn(
FixedInterval::from_millis(10).take(2),
|| db.keep_alive(id),
Error::should_retry,
)
.await?;
tokio::time::sleep(duration.div_f32(2.0)).await;
}
}
.instrument(span),
)
}
pub(crate) fn run(mut self, mut function: JobFunctionType) -> JoinHandle<Result<(), Error>> {
self.keep_alive = Some(Self::keep_alive(self.db.clone(), self.id).into());
let span = tracing::debug_span!("job-run");
tokio::task::spawn(
async move {
let id = self.id;
let db = self.db.clone();
tracing::trace!("Starting job with ID {id}.");
let res = function(self).await;
if let Err(err) = res {
db.handle_job_error(err);
}
db.notify().await?;
tracing::trace!("Job with ID {id} finished execution.");
Ok(())
}
.instrument(span),
)
}
}
/// Builder for persisting intermediate job state.
///
/// Created via [`CurrentJob::checkpoint`]; staged payloads are written to the
/// store with [`Checkpoint::set`], which also mirrors them into the job's
/// in-memory payloads on success.
#[derive(Debug)]
pub struct Checkpoint<'a> {
    // The job being checkpointed; its payloads are updated on a successful `set`.
    job: &'a mut CurrentJob,
    // Staged JSON payload, seeded from the job's current payload.
    payload_json: Option<serde_json::Value>,
    // Staged byte payload, seeded from the job's current payload.
    payload_bytes: Option<Vec<u8>>,
}
impl<'a> Checkpoint<'a> {
fn new(job: &'a mut CurrentJob) -> Self {
let payload_json = job.payload_json.clone();
let payload_bytes = job.payload_bytes.clone();
Self { job, payload_json, payload_bytes }
}
pub fn payload_json<S: Serialize>(
mut self,
payload: impl Into<Option<S>>,
) -> Result<Self, Error> {
let payload_json = payload.into().map(|s| serde_json::to_value(s)).transpose()?;
self.payload_json = payload_json;
Ok(self)
}
#[must_use]
pub fn payload_bytes(mut self, payload: impl Into<Option<Vec<u8>>>) -> Self {
self.payload_bytes = payload.into();
self
}
pub async fn set(self) -> Result<(), Error> {
RetryIf::spawn(
FixedInterval::from_millis(10).take(2),
|| {
self.job.db.checkpoint(
self.job.id,
self.payload_json.clone(),
self.payload_bytes.clone(),
)
},
Error::should_retry,
)
.await?;
self.job.payload_json = self.payload_json;
self.job.payload_bytes = self.payload_bytes;
Ok(())
}
}