use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use super::routing::Router;
use crate::storage::Storage;
use crate::storage::error::Result;
use crate::storage::types::{EnqueueOutcome, EnqueueRequest, JobId};
/// A handler for one kind of background job.
///
/// Handlers are stored in a [`HandlerRegistry`] under the string returned by
/// [`JobHandler::kind`] and invoked with a per-job [`JobCtx`] plus the job's JSON payload.
#[async_trait]
pub trait JobHandler: 'static + Send + Sync {
    /// The job kind this handler serves; used as its key in the registry.
    fn kind(&self) -> &'static str;

    /// Executes one job and reports how it ended.
    async fn run(&self, ctx: JobCtx<'_>, payload: serde_json::Value) -> JobOutcome;
}
/// How a single [`JobHandler::run`] invocation ended.
#[non_exhaustive]
#[derive(Debug)]
pub enum JobOutcome {
    /// The job completed.
    Done,
    /// The job should not be attempted again before `retry_after` has passed.
    /// NOTE(review): presumably produced when a rate limit is hit — confirm with the worker loop.
    Throttled {
        /// Minimum wait before the next attempt.
        retry_after: Duration,
    },
    /// The job failed with the given message.
    /// NOTE(review): presumably retryable, as opposed to `Dead` — confirm with the worker loop.
    Failed(String),
    /// The job failed with the given message.
    /// NOTE(review): presumably terminal (no further retries) — confirm with the worker loop.
    Dead(String),
}
/// Per-invocation context handed to [`JobHandler::run`].
///
/// Everything except `job_id` and `cancel` is borrowed for the lifetime `'a`.
pub struct JobCtx<'a> {
    /// Storage handle; `JobCtx::enqueue` uses its `jobs` store.
    pub storage: &'a Storage,
    /// Maps a job kind to a queue name when an enqueue request does not name one.
    pub router: &'a (dyn Router + Send + Sync),
    /// Shared rate limiter. NOTE(review): usage is not visible in this file.
    pub rate_limit: &'a super::RateLimiter,
    /// Identifier of the job being executed.
    pub job_id: JobId,
    /// Identifier of the worker process. NOTE(review): meaning assumed from the name.
    pub process_id: &'a str,
    /// Identifier of the host. NOTE(review): meaning assumed from the name.
    pub host_id: &'a str,
    /// Cancellation signal for this job.
    pub cancel: CancellationToken,
}
impl std::fmt::Debug for JobCtx<'_> {
    // Only the identifying fields are printed; the borrowed services and the cancel token
    // are omitted, and `finish_non_exhaustive` renders a trailing `..` to signal that.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JobCtx");
        out.field("job_id", &self.job_id);
        out.field("process_id", &self.process_id);
        out.field("host_id", &self.host_id);
        out.finish_non_exhaustive()
    }
}
/// Backoff waits used by `JobCtx::enqueue` / `JobCtx::enqueue_bulk` after a transient
/// storage conflict. Entry `i` is the wait before retry `i + 1`, so the slice length is
/// also the maximum number of retries.
const ENQUEUE_RETRY_DELAYS: &[Duration] = &[
    Duration::from_millis(100),
    Duration::from_millis(300),
    Duration::from_millis(1_000),
];
impl JobCtx<'_> {
pub async fn enqueue(&self, req: EnqueueRequest) -> Result<EnqueueOutcome> {
let mut req = req;
if req.queue_name.is_none() {
req.queue_name = Some(Cow::Borrowed(self.router.route(req.kind.as_ref())));
}
let mut attempt = 0usize;
loop {
match self.storage.jobs.enqueue(req.clone()).await {
Ok(v) => return Ok(v),
Err(e) if e.is_transient_conflict() && attempt < ENQUEUE_RETRY_DELAYS.len() => {
tracing::warn!(
kind = %req.kind,
attempt,
delay_ms = ENQUEUE_RETRY_DELAYS[attempt].as_millis(),
err = %e,
"ctx.enqueue: transient conflict; retrying"
);
tokio::time::sleep(ENQUEUE_RETRY_DELAYS[attempt]).await;
attempt += 1;
}
Err(e) => return Err(e),
}
}
}
pub async fn enqueue_bulk(&self, reqs: Vec<EnqueueRequest>) -> Result<Vec<EnqueueOutcome>> {
let routed: Vec<EnqueueRequest> = reqs
.into_iter()
.map(|mut req| {
if req.queue_name.is_none() {
req.queue_name = Some(Cow::Borrowed(self.router.route(req.kind.as_ref())));
}
req
})
.collect();
let mut attempt = 0usize;
loop {
match self.storage.jobs.enqueue_bulk(routed.clone()).await {
Ok(v) => return Ok(v),
Err(e) if e.is_transient_conflict() && attempt < ENQUEUE_RETRY_DELAYS.len() => {
tracing::warn!(
batch_size = routed.len(),
attempt,
delay_ms = ENQUEUE_RETRY_DELAYS[attempt].as_millis(),
err = %e,
"ctx.enqueue_bulk: transient conflict; retrying"
);
tokio::time::sleep(ENQUEUE_RETRY_DELAYS[attempt]).await;
attempt += 1;
}
Err(e) => return Err(e),
}
}
}
}
/// Lookup table from job kind to the handler that executes it.
///
/// `Default` yields an empty registry.
#[derive(Default)]
pub struct HandlerRegistry {
    /// Handlers keyed by the value each one returned from `JobHandler::kind`.
    handlers: HashMap<&'static str, Arc<dyn JobHandler>>,
}
impl HandlerRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register<H: JobHandler>(&mut self, handler: H) {
self.handlers.insert(handler.kind(), Arc::new(handler));
}
#[must_use]
pub fn get(&self, kind: &str) -> Option<Arc<dyn JobHandler>> {
self.handlers.get(kind).cloned()
}
#[must_use]
pub fn kinds(&self) -> Vec<&'static str> {
self.handlers.keys().copied().collect()
}
}
impl std::fmt::Debug for HandlerRegistry {
    // `JobHandler` has no `Debug` bound, so the registry is shown as just its list of kinds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.kinds();
        let mut builder = f.debug_struct("HandlerRegistry");
        builder.field("kinds", &registered);
        builder.finish()
    }
}