use crate::background_job::DEFAULT_QUEUE;
use crate::job_registry::JobRegistry;
use crate::worker::Worker;
use crate::{BackgroundJob, storage};
use anyhow::anyhow;
use futures_util::future::join_all;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tracing::{Instrument, info, info_span, warn};
/// Poll interval given to a queue's workers unless overridden via [`Queue::poll_interval`].
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(1_000);
/// Builder and launcher for background-job workers.
///
/// Holds one [`Queue`] configuration per queue name. [`Runner::start`] spawns
/// worker tasks that share `connection_pool` and clones of `context`.
pub struct Runner<Context> {
    /// Database pool cloned into every spawned worker.
    connection_pool: PgPool,
    /// Queue configurations keyed by queue name.
    queues: HashMap<String, Queue<Context>>,
    /// Application context cloned into every spawned worker.
    context: Context,
    /// Forwarded unchanged to each worker's `shutdown_when_queue_empty` field.
    shutdown_when_queue_empty: bool,
}
impl<Context: std::fmt::Debug> std::fmt::Debug for Runner<Context> {
    /// Debug-formats the runner, listing queues by name only and omitting the
    /// connection pool.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let queue_names: Vec<&String> = self.queues.keys().collect();

        let mut builder = f.debug_struct("Runner");
        builder.field("queues", &queue_names);
        builder.field("context", &self.context);
        builder.field("shutdown_when_queue_empty", &self.shutdown_when_queue_empty);
        builder.finish()
    }
}
impl<Context: Clone + Send + Sync + 'static> Runner<Context> {
    /// Creates a runner with no queues and `shutdown_when_queue_empty` disabled.
    ///
    /// Queues come into existence when a job type is registered via
    /// [`Runner::register_job_type`] or when one is configured via
    /// [`Runner::configure_queue`].
    pub fn new(connection_pool: PgPool, context: Context) -> Self {
        Self {
            connection_pool,
            queues: HashMap::new(),
            context,
            shutdown_when_queue_empty: false,
        }
    }

    /// Registers job type `J` on the queue named by `J::QUEUE`.
    ///
    /// If that queue does not exist yet, it is created with [`Queue::default`] settings.
    pub fn register_job_type<J: BackgroundJob<Context = Context>>(mut self) -> Self {
        let queue = self.queues.entry(J::QUEUE.into()).or_default();
        queue.job_registry.register::<J>();
        self
    }

    /// Configures the queue named [`DEFAULT_QUEUE`] through `f`.
    ///
    /// Shorthand for `self.configure_queue(DEFAULT_QUEUE, f)`.
    pub fn configure_default_queue<F>(self, f: F) -> Self
    where
        F: FnOnce(&mut Queue<Context>) -> &Queue<Context>,
    {
        self.configure_queue(DEFAULT_QUEUE, f)
    }

    /// Configures the queue called `name` through `f`.
    ///
    /// If the queue does not exist yet, it is created with [`Queue::default`] settings.
    /// The reference returned by `f` is ignored. The return type lets chained setters
    /// such as `|q| q.num_workers(2)` be passed directly.
    pub fn configure_queue<F>(mut self, name: &str, f: F) -> Self
    where
        F: FnOnce(&mut Queue<Context>) -> &Queue<Context>,
    {
        f(self.queues.entry(name.into()).or_default());
        self
    }

    /// Sets the flag that is forwarded to every worker as `shutdown_when_queue_empty`.
    ///
    /// NOTE(review): presumably this makes workers exit once no jobs remain.
    /// Confirm against `Worker::run`.
    pub fn shutdown_when_queue_empty(mut self) -> Self {
        self.shutdown_when_queue_empty = true;
        self
    }

    /// Spawns `num_workers` worker tasks for every configured queue.
    ///
    /// Each worker is named `background-worker-{queue}-{n}`, where `n` starts at 1,
    /// and runs inside an `info_span` carrying that name. A queue configured with
    /// zero workers gets no tasks. All workers of one queue share a single copy of
    /// that queue's job registry.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because workers are started with
    /// [`tokio::spawn`].
    pub fn start(&self) -> RunHandle {
        let mut handles = Vec::new();
        for (queue_name, queue) in &self.queues {
            // Clone the registry once per queue and share it between that queue's
            // workers, instead of deep-cloning it for every single worker.
            let job_registry = Arc::new(queue.job_registry.clone());

            for i in 1..=queue.num_workers {
                let name = format!("background-worker-{queue_name}-{i}");
                info!(worker.name = %name, "Starting worker…");

                let worker = Worker {
                    connection_pool: self.connection_pool.clone(),
                    context: self.context.clone(),
                    job_registry: Arc::clone(&job_registry),
                    shutdown_when_queue_empty: self.shutdown_when_queue_empty,
                    poll_interval: queue.poll_interval,
                };

                let span = info_span!("worker", worker.name = %name);
                let handle = tokio::spawn(async move { worker.run().instrument(span).await });
                handles.push(handle);
            }
        }
        RunHandle { handles }
    }

    /// Checks whether storage reports any failed jobs.
    ///
    /// # Errors
    ///
    /// Returns an error if the failed-job count cannot be queried. Also returns an
    /// error if the count is non-zero; in that case the message includes the count.
    pub async fn check_for_failed_jobs(&self) -> anyhow::Result<()> {
        let failed_jobs = storage::failed_job_count(&self.connection_pool).await?;
        if failed_jobs == 0 {
            Ok(())
        } else {
            Err(anyhow!("{failed_jobs} jobs failed"))
        }
    }
}
/// Handle to the worker tasks spawned by [`Runner::start`].
///
/// Await [`RunHandle::wait_for_shutdown`] to join all of them.
#[derive(Debug)]
pub struct RunHandle {
    /// Join handles of the spawned worker tasks, one per worker.
    handles: Vec<JoinHandle<()>>,
}
impl RunHandle {
pub async fn wait_for_shutdown(self) {
join_all(self.handles).await.into_iter().for_each(|result| {
if let Err(error) = result {
warn!(%error, "Background worker task panicked");
}
});
}
}
/// Configuration for one named job queue.
///
/// Adjust it with [`Queue::num_workers`] and [`Queue::poll_interval`].
#[derive(Debug)]
pub struct Queue<Context> {
    /// Job types handled by this queue's workers.
    job_registry: JobRegistry<Context>,
    /// Number of worker tasks spawned for this queue; zero spawns none.
    num_workers: usize,
    /// Handed to each worker as its poll interval.
    /// NOTE(review): presumably the idle delay between polls. Confirm in `Worker`.
    poll_interval: Duration,
}
impl<Context> Default for Queue<Context> {
fn default() -> Self {
Self {
job_registry: JobRegistry::default(),
num_workers: 1,
poll_interval: DEFAULT_POLL_INTERVAL,
}
}
}
impl<Context> Queue<Context> {
    /// Sets how many worker tasks [`Runner::start`] spawns for this queue.
    ///
    /// A value of `0` means no workers are spawned for the queue. Returns the
    /// queue so further setters can be chained.
    pub fn num_workers(&mut self, num_workers: usize) -> &mut Self {
        let queue = self;
        queue.num_workers = num_workers;
        queue
    }

    /// Sets the poll interval handed to each of this queue's workers.
    ///
    /// Returns the queue so further setters can be chained.
    pub fn poll_interval(&mut self, poll_interval: Duration) -> &mut Self {
        let queue = self;
        queue.poll_interval = poll_interval;
        queue
    }
}