use crate::error::Result;
use crate::{Job, Queue};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
use postgres::{
params::{ConnectParams, IntoConnectParams},
transaction::Transaction,
Connection, TlsMode,
};
use std::thread;
use std::time::Duration;
/// Base retry delay in milliseconds; `next_run_fib` multiplies it by a
/// Fibonacci factor derived from the job's error count.
const RETRY_DELAY_MS: u32 = 50;
/// Upper bound on the retry delay in milliseconds
/// (1_209_600_000 ms = 14 days).
const MAX_DELAY_MS: u32 = 1_209_600_000;
/// Everything a [`Handler`] receives for a single job execution.
pub struct JobContext<'a> {
/// The job being processed.
pub job: &'a Job,
/// Transaction the handler may run its own statements on. The worker in this
/// file passes the savepoint it opens around the handler call.
pub tx: &'a Transaction<'a>,
}
/// Job-processing callback invoked by the workers for every reserved job.
pub trait Handler {
/// Error reported when a job could not be handled; the worker stores its
/// `Display` text when marking the job as errored or failed.
type Error: std::error::Error;
/// Processes one job.
///
/// Returning `Ok(())` makes the worker commit the savepoint and complete the
/// job. Returning `Err` makes the worker finish the savepoint without
/// committing it and either schedule a retry or fail the job, depending on
/// the job's error count and `max_attempts`.
fn handle(&self, context: JobContext) -> std::result::Result<(), Self::Error>;
}
/// Settings shared by a [`WorkerPool`] and each of its workers (every worker
/// gets its own clone).
#[derive(Clone)]
pub struct WorkerPoolConfig<H> {
// Queue the workers reserve jobs from and report results to.
queue: Queue,
// Number of worker threads started by `WorkerPool::start`.
num_workers: usize,
// How long a worker sleeps before polling again when the previous poll
// found no job or ended in an error.
worker_poll_interval: Duration,
// Parameters used to open each worker's database connection.
connect_params: ConnectParams,
// Handler invoked for every reserved job.
handler: H,
}
/// Handle to a set of running worker threads.
pub struct WorkerPool<H> {
// Configuration the pool was started with.
config: WorkerPoolConfig<H>,
// One sender per worker; `join` sends `()` on each to request shutdown.
shutdown_chans: Vec<Sender<()>>,
// Receives a worker id each time a worker acknowledges the shutdown request.
shutdown_ack_chan: Receiver<usize>,
}
/// State of a single worker thread.
struct Worker<H> {
// Identifier used in log messages and sent back as the shutdown ack.
worker_id: usize,
// This worker's own copy of the pool configuration.
config: WorkerPoolConfig<H>,
// Receives the shutdown request from the pool.
shutdown_chan: Receiver<()>,
// Used to acknowledge the shutdown request with `worker_id`.
shutdown_ack_chan: Sender<usize>,
}
/// Guard held for the duration of `Worker::run`. When it is dropped while
/// still `active` (see its `Drop` impl) it spawns a replacement worker.
struct Sentinel<'a, H: Handler + Send + Clone + 'static> {
// Worker to clone and respawn.
worker: &'a Worker<H>,
// `true` until `cancel` is called; only an active sentinel respawns.
active: bool,
// Id of the job currently being processed, reported if the thread panics.
running_job_id: Option<u64>,
}
impl<H> WorkerPoolConfig<H> {
    /// Creates a configuration for `queue` and `handler`.
    ///
    /// The worker count defaults to `num_cpus::get()` and the poll interval
    /// to 200 ms; both can be changed with the setters below.
    ///
    /// # Errors
    ///
    /// Fails when `db_connect_params` cannot be converted into connection
    /// parameters.
    #[inline]
    pub fn new<C: IntoConnectParams>(
        queue: Queue,
        db_connect_params: C,
        handler: H,
    ) -> Result<WorkerPoolConfig<H>> {
        let connect_params = db_connect_params.into_connect_params()?;
        let config = WorkerPoolConfig {
            queue,
            num_workers: num_cpus::get(),
            worker_poll_interval: Self::default_duration(),
            connect_params,
            handler,
        };
        Ok(config)
    }

    /// Sets how long an idle worker sleeps between polls of the queue.
    #[inline]
    pub fn set_worker_poll_interval(&mut self, interval: Duration) {
        self.worker_poll_interval = interval;
    }

    /// Sets how many worker threads the pool starts.
    #[inline]
    pub fn set_num_workers(&mut self, size: usize) {
        self.num_workers = size;
    }

    /// Default poll interval: 200 milliseconds.
    fn default_duration() -> Duration {
        Duration::from_millis(200)
    }
}
impl<H: Handler + Send + Clone + 'static> WorkerPool<H> {
pub fn start(config: WorkerPoolConfig<H>) -> WorkerPool<H> {
let size = config.num_workers;
info!("starting worker pool with {} workers", size);
let (sa_send, sa_recv) = unbounded();
let mut shutdown_chans = Vec::with_capacity(size);
for idx in 0..size {
let (s_send, s_recv) = bounded(1);
Worker::new(idx + 1, config.clone(), s_recv, sa_send.clone()).spawn();
shutdown_chans.push(s_send);
}
WorkerPool {
config,
shutdown_chans,
shutdown_ack_chan: sa_recv,
}
}
pub fn config(&self) -> &WorkerPoolConfig<H> {
&self.config
}
pub fn join(self) {
for w in self.shutdown_chans.iter() {
let _ = w.send(());
}
for _ in 0..(self.shutdown_chans.len()) {
let _ = self.shutdown_ack_chan.recv();
}
}
}
impl<H: Handler + Send + Clone + 'static> Worker<H> {
    /// Assembles a worker from its id, configuration and shutdown channels.
    fn new(
        worker_id: usize,
        config: WorkerPoolConfig<H>,
        shutdown_chan: Receiver<()>,
        shutdown_ack_chan: Sender<usize>,
    ) -> Worker<H> {
        Worker {
            worker_id,
            config,
            shutdown_chan,
            shutdown_ack_chan,
        }
    }

    /// Moves the worker onto a new, detached thread and runs it there.
    fn spawn(self) {
        let worker_id = self.worker_id;
        // The join handle is dropped on purpose: shutdown is coordinated over
        // the channels, not by joining threads.
        let _ = thread::spawn(move || {
            self.run();
        });
        debug!("spawned worker {}", worker_id);
    }

    /// Worker main loop: poll the queue and process jobs until a shutdown
    /// request arrives.
    ///
    /// A `Sentinel` is armed for the whole call. Every exit path other than
    /// the acknowledged shutdown (a panic, or a failed database connection)
    /// drops the sentinel while it is still active, which respawns the worker.
    fn run(&self) {
        let mut sentinel = Sentinel::new(self);
        let mut conn = match new_conn(self.config.connect_params.clone(), self.worker_id) {
            Some(conn) => conn,
            // Leave with the sentinel armed so a replacement retries the connection.
            None => return,
        };
        // Starts as `true` so the very first poll happens without sleeping.
        let mut job_last_loop = true;
        loop {
            select! {
                recv(self.shutdown_chan) -> _ => {
                    debug!("worker {} received shutdown signal", self.worker_id);
                    let res = self.shutdown_ack_chan.send(self.worker_id);
                    if let Err(err) = res {
                        warn!("error acking shutdown request. {}", err);
                    }
                    // Disarm so that dropping the sentinel does not respawn us.
                    sentinel.cancel();
                    return;
                }
                default => {
                    // Back off only when the previous iteration did no work.
                    if !job_last_loop {
                        thread::sleep(self.config.worker_poll_interval);
                    }
                    if conn.is_desynchronized() {
                        conn = match new_conn(self.config.connect_params.clone(), self.worker_id) {
                            Some(conn) => conn,
                            None => return,
                        };
                    }
                    let res = self.run_job(&conn, &mut sentinel);
                    match res {
                        Ok(found_job) => {
                            job_last_loop = found_job;
                        }
                        Err(err) => {
                            error!("error in worker {}. {}", self.worker_id, err);
                            // `run_job` bailed out before clearing the job id;
                            // the job is no longer running, so forget it.
                            // Otherwise a later, unrelated panic would be
                            // blamed on this job by the sentinel.
                            sentinel.clear_running_job();
                            job_last_loop = false;
                        }
                    }
                }
            }
        }
    }

    /// Reserves at most one job and processes it inside a transaction.
    ///
    /// The handler runs inside a savepoint named `"job"`. On success the
    /// savepoint is committed and the job completed. On handler failure the
    /// savepoint is finished without committing and the job is either
    /// scheduled for a retry (see `next_run_fib`) or failed once
    /// `max_attempts` is reached.
    ///
    /// Returns `Ok(true)` when a job was processed and `Ok(false)` when the
    /// queue had nothing to reserve.
    ///
    /// # Errors
    ///
    /// Propagates any database or queue error. The outer transaction is then
    /// dropped without being committed.
    fn run_job(&self, conn: &Connection, sentinel: &mut Sentinel<H>) -> Result<bool> {
        let tx = conn.transaction()?;
        let job = match self.config.queue.reserve(&tx)? {
            Some(job) => job,
            None => return Ok(false),
        };
        debug!("worker {} starting job", self.worker_id);
        // Recorded so the sentinel can name the job if the handler panics.
        sentinel.set_running_job(job.id);
        let sub_tx = tx.savepoint("job")?;
        let job_ctx = JobContext {
            job: &job,
            tx: &sub_tx,
        };
        let outcome = match self.config.handler.handle(job_ctx) {
            Ok(()) => {
                debug!("worker {} completed job", self.worker_id);
                sub_tx.commit()?;
                self.config.queue.complete(&job, &tx)
            }
            Err(err) => {
                sub_tx.finish()?;
                let attempts = job.error_count + 1;
                let message = err.to_string();
                if attempts < job.max_attempts {
                    warn!(
                        "worker {} error in job {} with error count {}. retrying. {}",
                        self.worker_id, job.id, attempts, &err
                    );
                    let next_run_at = next_run_fib(job.error_count, Utc::now());
                    debug!("next run at {}", next_run_at);
                    self.config
                        .queue
                        .mark_error(&job, &message, next_run_at, &tx)
                } else {
                    error!(
                        "worker {} job {} failed after {} errors. {}",
                        self.worker_id, job.id, attempts, &err
                    );
                    self.config.queue.fail(&job, &message, &tx)
                }
            }
        };
        outcome?;
        tx.commit()?;
        sentinel.clear_running_job();
        Ok(true)
    }
}
/// Manual `Clone`: copies the id and clones the configuration and both
/// channel endpoints, so the copy listens on the same shutdown channel and
/// acknowledges on the same ack channel as the original.
impl<H: Clone> Clone for Worker<H> {
    fn clone(&self) -> Self {
        let Worker {
            worker_id,
            ref config,
            ref shutdown_chan,
            ref shutdown_ack_chan,
        } = *self;
        Self {
            worker_id,
            config: config.clone(),
            shutdown_chan: shutdown_chan.clone(),
            shutdown_ack_chan: shutdown_ack_chan.clone(),
        }
    }
}
impl<'a, H: Handler + Send + Clone + 'static> Sentinel<'a, H> {
    /// Creates an armed sentinel for `worker` with no job recorded.
    fn new(worker: &'a Worker<H>) -> Sentinel<'a, H> {
        Self {
            running_job_id: None,
            active: true,
            worker,
        }
    }

    /// Records the id of the job that is about to run.
    fn set_running_job(&mut self, job_id: u64) {
        self.running_job_id = Some(job_id);
    }

    /// Forgets the recorded job id.
    fn clear_running_job(&mut self) {
        self.running_job_id = None;
    }

    /// Disarms and consumes the sentinel, so the drop that follows does not
    /// respawn the worker.
    fn cancel(self) {
        let mut disarmed = self;
        disarmed.active = false;
        drop(disarmed);
    }
}
impl<'a, H: Handler + Send + Clone + 'static> Drop for Sentinel<'a, H> {
fn drop(&mut self) {
if self.active {
debug!("running sentinel for worker {}", self.worker.worker_id);
if thread::panicking() {
error!("panic in worker {}", self.worker.worker_id);
if let Some(rj_id) = self.running_job_id {
error!(
"job id {} may have cause a panic and be in an incomplete state",
rj_id
);
}
}
debug!(
"sentinel for worker {} restarting worker",
self.worker.worker_id
);
self.worker.clone().spawn();
}
}
}
/// Opens a database connection without TLS for worker `worker_id`.
///
/// On failure the error is logged, the calling thread sleeps for one second
/// and `None` is returned.
fn new_conn(connect_params: ConnectParams, worker_id: usize) -> Option<Connection> {
    let attempt = Connection::connect(connect_params, TlsMode::None);
    if let Err(ref err) = attempt {
        error!(
            "worker {} unable to connect to the database. {}",
            worker_id, err
        );
        // Pause so that callers that retry do not hammer the database.
        thread::sleep(Duration::from_secs(1));
    }
    attempt.ok()
}
/// Computes when a failed job should run next, using Fibonacci back-off.
///
/// The delay is `RETRY_DELAY_MS * fib(error_count + 1)` milliseconds, capped
/// at `MAX_DELAY_MS`, and is added to `from`.
///
/// The arithmetic is overflow-safe: an `error_count` of `u32::MAX` or a
/// product that does not fit in `u32` yields the maximum delay instead of
/// panicking (debug builds) or wrapping to a short delay (release builds).
///
/// # Panics
///
/// Panics if `from` plus the delay is outside the range `DateTime` can
/// represent.
fn next_run_fib(error_count: u32, from: DateTime<Utc>) -> DateTime<Utc> {
    let delay_factor = fib(error_count.saturating_add(1));
    let delay_ms = RETRY_DELAY_MS
        .checked_mul(delay_factor)
        .map_or(MAX_DELAY_MS, |ms| std::cmp::min(ms, MAX_DELAY_MS));
    from.checked_add_signed(ChronoDuration::milliseconds(i64::from(delay_ms)))
        .expect("next run time is outside the representable DateTime range")
}
/// Returns the `n`-th Fibonacci number (`fib(0) == 0`, `fib(1) == 1`).
///
/// The result saturates at `u32::MAX` once the true value no longer fits
/// (from `n == 48` on) instead of overflowing. The loop stops as soon as the
/// result has saturated, so very large `n` stay cheap.
fn fib(n: u32) -> u32 {
    let saturated = u32::max_value();
    let mut prev: u32 = 0;
    let mut curr: u32 = 1;
    for _ in 0..n {
        // No Fibonacci number equals u32::MAX, so reaching it means saturation.
        if prev == saturated {
            break;
        }
        let next = prev.saturating_add(curr);
        prev = curr;
        curr = next;
    }
    prev
}