backyard-core 0.1.0

Core traits and types for the Backyard async job queue
Documentation
use crate::{
    error::Result,
    job::{JobContext, RawJob},
    queue::Queue,
    registry::build_dispatch_table,
};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};

/// Signature of a job handler stored in the dispatch table.
///
/// A handler is a plain function pointer that receives the job's raw payload bytes
/// and a [`JobContext`], and returns a boxed, `Send` future resolving to `Result<()>`.
pub type HandlerFn = fn(&[u8], JobContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>;

/// Settings that control how a worker pool pulls and executes jobs.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Names of the queues the fetch loop asks the backend for jobs from.
    pub queues: Vec<String>,
    /// Number of worker tasks the pool spawns.
    pub concurrency: usize,
    /// How long the fetch loop sleeps after an empty poll or a fetch error.
    pub poll_interval: std::time::Duration,
}

impl Default for WorkerConfig {
    /// Returns a configuration that watches the single `"default"` queue with
    /// ten workers and a 500 ms poll interval.
    fn default() -> Self {
        use std::time::Duration;

        let queues = vec![String::from("default")];
        let poll_interval = Duration::from_millis(500);

        WorkerConfig {
            queues,
            concurrency: 10,
            poll_interval,
        }
    }
}

/// A pool of async workers that pull jobs from a [`Queue`] and run them through
/// the registered handlers.
pub struct WorkerPool {
    /// Backend the pool pops jobs from and reports outcomes (ack / retry / fail) to.
    queue: Arc<dyn Queue>,
    /// Queue names, worker count and poll interval used by `run`.
    config: WorkerConfig,
    /// Maps a job's `job_type` string to the handler function that executes it.
    dispatch: HashMap<&'static str, HandlerFn>,
    /// Token observed by the fetch loop and every worker; once cancelled they stop.
    shutdown: CancellationToken,
}

impl WorkerPool {
    pub fn new(queue: Arc<dyn Queue>, config: WorkerConfig) -> Self {
        Self {
            queue,
            config,
            dispatch: build_dispatch_table(),
            shutdown: CancellationToken::new(),
        }
    }

    /// Returns a handle to the pool's shutdown token.
    ///
    /// Cancelling the returned token cancels the token that the fetch loop and
    /// every worker observe, so it can be used to stop a pool after `run` has
    /// taken ownership of it. The handle can also be used to merely await
    /// cancellation.
    pub fn shutdown_token(&self) -> CancellationToken {
        // A clone shares cancellation state with the pool's token. A `child_token()`
        // would only *receive* cancellation from the parent: cancelling the child
        // never propagates upward, so callers could not stop the pool with it.
        self.shutdown.clone()
    }

    pub async fn run(self) -> Result<()> {
        let (tx, rx): (mpsc::Sender<RawJob>, mpsc::Receiver<RawJob>) =
            mpsc::channel(self.config.concurrency * 2);
        let rx = Arc::new(tokio::sync::Mutex::new(rx));

        let mut handles = vec![];
        for worker_id in 0..self.config.concurrency {
            let rx = rx.clone();
            let queue = self.queue.clone();
            let dispatch = self.dispatch.clone();
            let shutdown = self.shutdown.clone();
            let handle = tokio::spawn(async move {
                Self::worker_loop(worker_id.to_string(), rx, queue, dispatch, shutdown).await
            });
            handles.push(handle);
        }

        let fetch_shutdown = self.shutdown.clone();
        let queue = self.queue.clone();
        let queues: Vec<String> = self.config.queues.clone();
        let poll_interval = self.config.poll_interval;

        tokio::spawn(async move {
            Self::fetch_loop(queue, queues, tx, fetch_shutdown, poll_interval).await
        });

        futures::future::join_all(handles).await;
        Ok(())
    }

    async fn fetch_loop(
        queue: Arc<dyn Queue>,
        queues: Vec<String>,
        tx: mpsc::Sender<RawJob>,
        shutdown: CancellationToken,
        poll_interval: std::time::Duration,
    ) {
        let queue_refs: Vec<&str> = queues.iter().map(|s| s.as_str()).collect();
        loop {
            tokio::select! {
                _ = shutdown.cancelled() => break,
                result = queue.pop(&queue_refs) => {
                    match result {
                        Ok(Some(job)) => {
                            let _ = tx.send(job).await;
                        }
                        Ok(None) => {
                            tokio::time::sleep(poll_interval).await;
                        }
                        Err(e) => {
                            error!("fetch error: {e}");
                            tokio::time::sleep(poll_interval).await;
                        }
                    }
                }
            }
        }
    }

    async fn worker_loop(
        worker_id: String,
        rx: Arc<tokio::sync::Mutex<mpsc::Receiver<RawJob>>>,
        queue: Arc<dyn Queue>,
        dispatch: HashMap<&'static str, HandlerFn>,
        shutdown: CancellationToken,
    ) {
        loop {
            let job: Option<RawJob> = {
                let mut rx = rx.lock().await;
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    job = rx.recv() => job
                }
            };

            let job = match job {
                Some(j) => j,
                None => break,
            };

            let ctx = JobContext {
                queue: queue.clone(),
                worker_id: worker_id.clone(),
            };

            match dispatch.get(job.job_type.as_str()) {
                None => {
                    error!(job_type = %job.job_type, "no handler registered");
                    let _ = queue.fail(job.id, "no handler registered").await;
                }
                Some(handler) => {
                    let span = tracing::info_span!(
                        "execute_job",
                        job_id = %job.id,
                        job_type = %job.job_type,
                        queue = %job.queue,
                        attempt = job.attempts,
                    );
                    let result = handler(&job.payload, ctx.clone()).instrument(span).await;
                    match result {
                        Ok(()) => {
                            info!(job_id = %job.id, "job succeeded");
                            let _ = queue.ack(job.id).await;
                        }
                        Err(e) => {
                            warn!(job_id = %job.id, error = %e, "job failed");
                            if job.attempts >= job.max_retries {
                                let _ = queue.fail(job.id, &e.to_string()).await;
                            } else {
                                let retry_at = crate::retry::next_retry_at(job.attempts);
                                let _ = queue.retry(job.id, retry_at).await;
                            }
                        }
                    }
                }
            }
        }
    }
}