// archimedes 0.4.0
//
// High-performance Rust/PostgreSQL job queue (also suitable for getting jobs generated by PostgreSQL triggers/functions out into a different work queue)
// Documentation
use crate::runner::WorkerFn;
use crate::sql::task_identifiers::get_tasks_details;
use crate::utils::escape_identifier;
use crate::{Worker, WorkerContext};
use archimedes_crontab_parser::{parse_crontab, CrontabParseError};
use archimedes_crontab_types::Crontab;
use archimedes_migrations::migrate;
use archimedes_shutdown_signal::shutdown_signal;
use archimedes_task_handler::{TaskDefinition, TaskHandler};
use futures::FutureExt;
use rand::RngCore;
use serde::Deserialize;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;

/// Builder-style configuration that is turned into a [`Worker`] by
/// `WorkerOptions::init`.
///
/// All fields start at their `Default` value; the chainable methods on this
/// type fill them in one at a time.
#[derive(Default)]
pub struct WorkerOptions {
    /// Value copied into the worker's `concurrency`; when `None`, `init` uses
    /// `num_cpus::get()`.
    concurrency: Option<usize>,
    /// Value copied into the worker's `poll_interval`; when `None`, `init`
    /// uses 1000 ms.
    poll_interval: Option<Duration>,
    /// Registered handlers keyed by task identifier, filled by
    /// `define_raw_job` and `define_job`.
    jobs: HashMap<String, WorkerFn>,
    /// Caller-supplied connection pool; when `None`, `init` creates one from
    /// `database_url`.
    pg_pool: Option<PgPool>,
    /// Connection string, only read by `init` when `pg_pool` is `None`.
    database_url: Option<String>,
    /// Maximum connections for the pool created by `init` (20 when `None`);
    /// not read when `pg_pool` is supplied.
    max_pg_conn: Option<u32>,
    /// Database schema name; `init` uses `"archimedes_worker"` when `None`.
    schema: Option<String>,
    /// Flags handed to the worker unchanged.
    // NOTE(review): presumably jobs carrying one of these flags are skipped —
    // confirm against the worker's job-fetching code.
    forbidden_flags: Vec<String>,
    /// Crontab entries accumulated by `with_crontab`; `init` uses an empty
    /// list when `None`.
    crontabs: Option<Vec<Crontab>>,
    /// Flag handed to the worker unchanged.
    // NOTE(review): presumably selects local time instead of UTC for cron
    // scheduling — confirm against the cron runner.
    use_local_time: bool,
}

/// Errors that `WorkerOptions::init` can return while building a worker.
#[derive(Error, Debug)]
pub enum WorkerBuildError {
    /// Wraps an [`sqlx::Error`]. Because of `#[from]`, any `sqlx::Error`
    /// propagated with `?` in `init` ends up in this variant.
    #[error("Error occured while connecting to the postgres database : {0}")]
    ConnectError(#[from] sqlx::Error),
    /// Wraps a crate-level `ArchimedesError`, converted automatically via
    /// `#[from]` when propagated with `?`.
    #[error("Error occured while querying : {0}")]
    QueryError(#[from] crate::errors::ArchimedesError),
    /// Returned by `init` when neither a connection pool nor a database URL
    /// was configured.
    #[error("Missing database_url config")]
    MissingDatabaseUrl,
}

impl WorkerOptions {
    /// Consumes the options and builds a [`Worker`].
    ///
    /// In order, this:
    /// 1. takes the supplied pool, or connects a new one to `database_url`
    ///    (at most `max_pg_conn` connections, 20 when unset);
    /// 2. escapes the schema name (`"archimedes_worker"` when unset);
    /// 3. runs the migrations against that schema;
    /// 4. fetches the task details for every registered job identifier;
    /// 5. generates a worker id from 9 random bytes, hex-encoded.
    ///
    /// # Errors
    ///
    /// Returns [`WorkerBuildError::MissingDatabaseUrl`] when neither a pool
    /// nor a database URL was configured. Failures from connecting, escaping
    /// the schema, migrating, or fetching task details are propagated through
    /// the `From` conversions of [`WorkerBuildError`].
    pub async fn init(self) -> Result<Worker, WorkerBuildError> {
        let Self {
            concurrency,
            poll_interval,
            jobs,
            pg_pool,
            database_url,
            max_pg_conn,
            schema,
            forbidden_flags,
            crontabs,
            use_local_time,
        } = self;

        // A caller-provided pool always wins; the URL is only consulted as a fallback.
        let pg_pool = if let Some(existing_pool) = pg_pool {
            existing_pool
        } else {
            let db_url = database_url.ok_or(WorkerBuildError::MissingDatabaseUrl)?;
            let max_connections = max_pg_conn.unwrap_or(20);

            PgPoolOptions::new()
                .max_connections(max_connections)
                .connect(&db_url)
                .await?
        };

        let schema = schema.unwrap_or_else(|| String::from("archimedes_worker"));
        let escaped_schema = escape_identifier(&pg_pool, &schema).await?;

        migrate(&pg_pool, &escaped_schema).await?;

        let task_details =
            get_tasks_details(&pg_pool, &escaped_schema, jobs.keys().cloned().collect()).await?;

        // Random suffix so each worker instance gets its own id.
        let mut id_suffix = [0u8; 9];
        rand::thread_rng().fill_bytes(&mut id_suffix);
        let worker_id = format!("archimedes_worker_{}", hex::encode(id_suffix));

        Ok(Worker {
            worker_id,
            concurrency: concurrency.unwrap_or_else(num_cpus::get),
            poll_interval: poll_interval.unwrap_or(Duration::from_millis(1000)),
            jobs,
            pg_pool,
            escaped_schema,
            task_details,
            forbidden_flags,
            crontabs: crontabs.unwrap_or_default(),
            use_local_time,
            shutdown_signal: shutdown_signal(),
        })
    }

    /// Sets the database schema name (`init` uses `"archimedes_worker"` when
    /// this is never called).
    pub fn schema(self, value: &str) -> Self {
        Self {
            schema: Some(value.to_owned()),
            ..self
        }
    }

    /// Sets the worker's `concurrency` value (`init` uses `num_cpus::get()`
    /// when this is never called).
    pub fn concurrency(self, value: usize) -> Self {
        Self {
            concurrency: Some(value),
            ..self
        }
    }

    /// Sets the worker's `poll_interval` (`init` uses 1000 ms when this is
    /// never called).
    pub fn poll_interval(self, value: Duration) -> Self {
        Self {
            poll_interval: Some(value),
            ..self
        }
    }

    /// Supplies an existing connection pool; `init` then skips creating one
    /// and ignores `database_url` / `max_pg_conn`.
    pub fn pg_pool(self, value: PgPool) -> Self {
        Self {
            pg_pool: Some(value),
            ..self
        }
    }

    /// Sets the connection string `init` uses when no pool was supplied.
    pub fn database_url(self, value: &str) -> Self {
        Self {
            database_url: Some(value.to_owned()),
            ..self
        }
    }

    /// Sets the maximum number of connections for the pool created by `init`
    /// (20 when this is never called).
    pub fn max_pg_conn(self, value: u32) -> Self {
        Self {
            max_pg_conn: Some(value),
            ..self
        }
    }

    /// Registers `job_fn` as the handler for `identifier`.
    ///
    /// When the handler's future runs, the payload string is deserialized
    /// from JSON into `T` and passed to `job_fn` together with the context.
    /// A deserialization failure, or an `Err` returned by `job_fn`, is
    /// reported as the error's `Debug` rendering.
    ///
    /// A handler already registered under the same identifier is replaced.
    pub fn define_raw_job<T, E, Fut, F>(mut self, identifier: &str, job_fn: F) -> Self
    where
        T: for<'de> Deserialize<'de> + Send,
        E: Debug,
        Fut: Future<Output = Result<(), E>> + Send,
        F: Fn(WorkerContext, T) -> Fut + Send + Sync + 'static,
    {
        // Shared so that every invocation of the handler can own a handle to `job_fn`.
        let shared_fn = Arc::new(job_fn);

        let worker_fn = move |ctx: WorkerContext, payload: String| {
            let shared_fn = Arc::clone(&shared_fn);

            async move {
                let parsed: T = match serde_json::from_str(&payload) {
                    Ok(parsed) => parsed,
                    Err(err) => return Err(format!("{err:?}")),
                };

                shared_fn(ctx, parsed)
                    .await
                    .map_err(|err| format!("{err:?}"))
            }
            .boxed()
        };

        self.jobs
            .insert(identifier.to_owned(), Box::new(worker_fn));
        self
    }

    /// Registers the handler described by `task` under `T::identifier()`.
    ///
    /// The payload string is deserialized from JSON as soon as the handler is
    /// invoked, before its future is created; the future then either reports
    /// the deserialization failure or runs the task runner with the parsed
    /// payload and the context. Errors are reported as their `Debug`
    /// rendering.
    ///
    /// A handler already registered under the same identifier is replaced.
    pub fn define_job<T>(mut self, task: T) -> Self
    where
        T: TaskDefinition<WorkerContext>,
    {
        let runner = task.get_task_runner();
        let identifier = T::identifier();

        let worker_fn = move |ctx: WorkerContext, payload: String| {
            // Parsed outside the future on purpose: the raw string is not moved into it.
            let parsed = serde_json::from_str(&payload);
            let runner = runner.clone();

            async move {
                let args = match parsed {
                    Ok(args) => args,
                    Err(err) => return Err(format!("{err:?}")),
                };

                runner
                    .run(args, ctx)
                    .await
                    .map_err(|err| format!("{err:?}"))
            }
            .boxed()
        };

        self.jobs
            .insert(identifier.to_string(), Box::new(worker_fn));
        self
    }

    /// Appends `flag` to the list of forbidden flags handed to the worker.
    pub fn add_forbidden_flag(mut self, flag: &str) -> Self {
        self.forbidden_flags.push(flag.to_owned());
        self
    }

    /// Parses `input` as a crontab and appends the resulting entries to any
    /// entries added by earlier calls.
    ///
    /// # Errors
    ///
    /// Returns the [`CrontabParseError`] produced by `parse_crontab` when
    /// `input` cannot be parsed; the options are dropped in that case.
    pub fn with_crontab(mut self, input: &str) -> Result<Self, CrontabParseError> {
        let parsed = parse_crontab(input)?;
        self.crontabs.get_or_insert_with(Vec::new).extend(parsed);
        Ok(self)
    }

    /// Sets the `use_local_time` flag handed to the worker.
    // NOTE(review): presumably switches cron scheduling from UTC to local
    // time — confirm against the cron runner.
    pub fn use_local_time(self, value: bool) -> Self {
        Self {
            use_local_time: value,
            ..self
        }
    }
}