oxanus 1.0.1

A simple & fast job queue system.
Documentation
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinSet;

use crate::config::Config;
use crate::context::ContextValue;
use crate::error::OxanusError;
use crate::executor::ExecutionError;
use crate::job_envelope::JobEnvelope;
use crate::queue::{QueueConfig, QueueKind};
use crate::result_collector::{JobResult, JobResultKind};
use crate::semaphores_map::SemaphoresMap;
use crate::worker_event::WorkerJob;
use crate::{
    dispatcher, executor,
    result_collector::{self, Stats},
};

/// Drives the worker loop for a single queue configuration.
///
/// Two long-lived tasks are spawned onto a [`JoinSet`]: the result collector
/// (fed through the result channel) and the queue watcher (which hands out
/// clones of the job sender). The loop then spawns one `process_job` task per
/// received [`WorkerJob`] until `config.cancel_token` is cancelled, after
/// which `wait_for_workers_to_finish` is awaited.
///
/// # Errors
///
/// Returns immediately if any task in the join set fails to join or finishes
/// with an `Err`.
pub async fn run<DT, ET>(
    config: Arc<Config<DT, ET>>,
    stats: Arc<Mutex<Stats>>,
    ctx: ContextValue<DT>,
    queue_config: QueueConfig,
) -> Result<(), OxanusError>
where
    DT: Send + Sync + Clone + 'static,
    ET: std::error::Error + Send + Sync + 'static,
{
    // The queue's concurrency sizes both channels and the semaphores map.
    let capacity = queue_config.concurrency;
    let (result_tx, result_rx) = mpsc::channel::<JobResult>(capacity);
    let (job_tx, mut job_rx) = mpsc::channel::<WorkerJob>(capacity);
    let semaphores = Arc::new(SemaphoresMap::new(capacity));
    let mut tasks = JoinSet::new();

    let collector = result_collector::run(result_rx, Arc::clone(&config), Arc::clone(&stats));
    tasks.spawn(collector);

    let watcher = run_queue_watcher(
        Arc::clone(&config),
        queue_config.clone(),
        job_tx.clone(),
        Arc::clone(&semaphores),
    );
    tasks.spawn(watcher);

    loop {
        tokio::select! {
            _ = config.cancel_token.cancelled() => {
                break;
            }
            Some(joined) = tasks.join_next() => {
                // First `?` surfaces a join failure, the second the task's own error.
                joined??;
            }
            incoming = job_rx.recv() => {
                if let Some(job) = incoming {
                    let worker = process_job(
                        Arc::clone(&config),
                        ctx.clone(),
                        result_tx.clone(),
                        job,
                    );
                    tasks.spawn(worker);
                }
            }
        }
    }

    wait_for_workers_to_finish(config, Arc::clone(&semaphores)).await;

    Ok(())
}

/// Loads, builds, executes and reports a single job.
///
/// Steps, in order:
/// 1. Fetch the [`JobEnvelope`] for `job_event.job_id` from internal storage.
///    A missing envelope is logged and its id deleted; a storage error is
///    logged. Both cases return `Ok(())` without running anything.
/// 2. Build the job from the registry using the envelope's name and args.
///    On failure the envelope is killed with an "Invalid job" message and
///    `Ok(())` is returned.
/// 3. Run the job through `executor::run`, release the worker permit, and
///    forward the outcome via `process_result`.
///
/// # Errors
///
/// Only an `Err` from the outer result of `executor::run` is propagated.
async fn process_job<DT, ET>(
    config: Arc<Config<DT, ET>>,
    ctx: ContextValue<DT>,
    result_tx: mpsc::Sender<JobResult>,
    job_event: WorkerJob,
) -> Result<(), OxanusError>
where
    DT: Send + Sync + Clone + 'static,
    ET: std::error::Error + Send + Sync + 'static,
{
    tracing::trace!("Processing job: {:?}", job_event);

    let mut envelope: JobEnvelope = match config.storage.internal.get_job(&job_event.job_id).await {
        Ok(Some(envelope)) => envelope,
        Ok(None) => {
            tracing::warn!("Job {} not found", job_event.job_id);
            if let Err(e) = config.storage.internal.delete_job(&job_event.job_id).await {
                #[cfg(feature = "sentry")]
                sentry_core::capture_error(&e);
                tracing::error!("Failed to delete job: {}", e);
            }
            return Ok(());
        }
        Err(e) => {
            #[cfg(feature = "sentry")]
            sentry_core::capture_error(&e);
            tracing::error!("Failed to get job envelope: {}", e);
            return Ok(());
        }
    };

    // Diagnostics must never decide whether a job runs: if the envelope cannot
    // be rendered as JSON, log `null` for that field instead of bailing out of
    // the function (which would also bubble up and stop the worker loop).
    tracing::debug!(
        job_id = envelope.id,
        latency_ms = envelope.meta.latency_millis(),
        envelope = %serde_json::to_value(&envelope).unwrap_or_default(),
        "Received envelope"
    );
    let job = match config
        .registry
        .build(&envelope.job.name, envelope.job.args.clone(), &ctx.0)
    {
        Ok(job) => job,
        Err(e) => {
            let err_msg = format!("Invalid job: {} - {}", &envelope.job.name, e);
            tracing::error!("{}", err_msg);
            if let Err(e) = config.storage.internal.kill(&envelope, err_msg).await {
                #[cfg(feature = "sentry")]
                sentry_core::capture_error(&e);
                tracing::error!("Failed to kill job: {}", e);
            }
            return Ok(());
        }
    };

    let result = executor::run(config, job, &mut envelope).await?;
    // Release the worker permit as soon as execution is over, before the
    // result is handed to the collector.
    drop(job_event.permit);

    process_result(result_tx, result, envelope).await;

    Ok(())
}

/// Classifies an execution outcome and forwards it to the result collector.
///
/// `Ok(())` maps to `Success`, a `NotPanic` error to `Failed` and a `Panic`
/// error to `Panicked`. The resulting [`JobResult`] is sent over `result_tx`;
/// a failed send is deliberately ignored.
async fn process_result<ET>(
    result_tx: mpsc::Sender<JobResult>,
    result: Result<(), ExecutionError<ET>>,
    envelope: JobEnvelope,
) where
    ET: std::error::Error + Send + Sync + 'static,
{
    let kind = match result {
        Ok(()) => JobResultKind::Success,
        Err(ExecutionError::NotPanic(_)) => JobResultKind::Failed,
        Err(ExecutionError::Panic()) => JobResultKind::Panicked,
    };

    let _ = result_tx.send(JobResult { envelope, kind }).await;
}

/// Discovers queues for `queue_config` and spawns one dispatcher per queue.
///
/// For a static queue the single key is tracked and the function returns
/// after one pass. For a dynamic queue the keys matching `"{prefix}*"` are
/// looked up on every pass, a dispatcher is spawned for each key not seen
/// before, and the function sleeps for `sleep_period` between passes until
/// `config.cancel_token` is cancelled.
///
/// Dispatchers run as detached tasks; a dispatcher error is only logged.
///
/// # Errors
///
/// Propagates the error produced by `track_redis_result` for the key lookup.
async fn run_queue_watcher<DT, ET>(
    config: Arc<Config<DT, ET>>,
    queue_config: QueueConfig,
    job_tx: mpsc::Sender<WorkerJob>,
    semaphores: Arc<SemaphoresMap>,
) -> Result<(), OxanusError>
where
    DT: Send + Sync + Clone + 'static,
    ET: std::error::Error + Send + Sync + 'static,
{
    let mut tracked_queues = HashSet::new();

    loop {
        let all_queues: HashSet<String> = match &queue_config.kind {
            QueueKind::Static { key } => HashSet::from([key.clone()]),
            QueueKind::Dynamic { prefix, .. } => {
                let storage = &config.storage.internal;
                let pattern = format!("{prefix}*");
                let keys = storage.queue_keys(&pattern).await;
                storage.track_redis_result(keys)?.unwrap_or_default()
            }
        };

        for queue in all_queues {
            // Only queues that have no dispatcher yet need attention.
            if tracked_queues.contains(&queue) {
                continue;
            }

            tracing::info!(
                queue = queue,
                config.throttle = format!("{:?}", queue_config.throttle),
                "Tracking queue"
            );

            let task_config = Arc::clone(&config);
            let task_queue_config = queue_config.clone();
            let task_job_tx = job_tx.clone();
            let task_semaphores = Arc::clone(&semaphores);
            let task_queue = queue.clone();
            tokio::spawn(async move {
                let outcome = dispatcher::run(
                    task_config,
                    task_queue_config,
                    task_queue,
                    task_job_tx,
                    task_semaphores,
                )
                .await;

                if let Err(e) = outcome {
                    tracing::error!(error = %e, "Dispatcher exited with error");
                }
            });

            tracked_queues.insert(queue);
        }

        if config.cancel_token.is_cancelled() {
            return Ok(());
        }

        match &queue_config.kind {
            // Dynamic queues are re-scanned after a pause.
            QueueKind::Dynamic { sleep_period, .. } => tokio::time::sleep(*sleep_period).await,
            // A static queue never changes, so one pass is enough.
            QueueKind::Static { .. } => return Ok(()),
        }
    }
}

/// Waits until no worker is busy or the shutdown timeout has elapsed.
///
/// `semaphores.busy_count()` is polled every 10 ms. Every 200th poll logs how
/// many workers are still busy. Once the time since entry exceeds
/// `config.shutdown_timeout`, an error is logged and the wait is abandoned.
async fn wait_for_workers_to_finish<DT, ET>(
    config: Arc<Config<DT, ET>>,
    semaphores: Arc<SemaphoresMap>,
) where
    DT: Send + Sync + Clone + 'static,
    ET: std::error::Error + Send + Sync + 'static,
{
    let started_at = std::time::Instant::now();
    let poll_interval = std::time::Duration::from_millis(10);
    let mut polls = 0;

    loop {
        polls += 1;

        let still_busy = semaphores.busy_count().await;
        if still_busy == 0 {
            return;
        }

        // Progress message on every 200th poll only.
        if polls % 200 == 0 {
            tracing::info!("Waiting for {} workers to finish...", still_busy);
        }

        if started_at.elapsed() > config.shutdown_timeout {
            tracing::error!("Shutdown timeout reached");
            return;
        }

        tokio::time::sleep(poll_interval).await;
    }
}