//! wild_thread_pool 0.4.0
//!
//! wild thread pool
//! Documentation
#[macro_use]
extern crate log;
extern crate thread_tryjoin;

use std::thread;
use std::time::Duration;
use std::error::Error;
use std::borrow::Cow;
use std::sync::Arc;
use std::sync::atomic::{Ordering, AtomicBool};
use thread_tryjoin::TryJoinHandle;

/// Error types usable by the pool.
///
/// They must be constructible from a plain message so the pool itself can
/// report thread panics and stop timeouts through the worker's error type.
pub trait ErrorExt: Error + Send + 'static {
    /// Builds an error carrying `msg`.
    fn new_err(msg: Cow<'static, str>) -> Self;
}

/// Factory that builds one `Worker` per pool slot.
pub trait NewWorker {
    /// The worker type produced by this factory.
    type Worker: Worker;

    /// Builds the worker for slot `worker_id`.
    ///
    /// The pool calls this once per id in `0..workers`, in increasing order.
    fn new_worker(&self, worker_id: u64) -> Self::Worker;
}

/// A unit of work driven by a pool thread.
///
/// Each thread spawned for a worker calls `on_start` once, then `on_tick`
/// repeatedly until the pool's shared `closed` flag is set, then `on_stop`.
/// The worker value is cloned into every thread spawned for it, including
/// replacement threads started after a thread dies.
pub trait Worker: Send + Clone + 'static {
    /// Shutdown handle type produced by `new_shutdown`.
    type Shutdown: Shutdown;
    /// Error type returned by the life-cycle hooks.
    type Err: ErrorExt;

    /// Id of this worker; used in the pool's log messages.
    fn worker_id(&self) -> u64;

    /// Performs one iteration of work.
    ///
    /// An error is passed to `on_tick_err` and the tick loop keeps running.
    fn on_tick(&self) -> Result<(), Self::Err>;

    /// Creates the handle whose `on_stop_timeout` runs if this worker's thread
    /// does not stop within the pool's stop timeout.
    fn new_shutdown(&self) -> Self::Shutdown;

    /// Called once when a worker thread starts.
    ///
    /// Returning an error sets the pool's `closed` flag and ends that thread.
    /// The default implementation only logs.
    fn on_start(&self) -> Result<(), Self::Err> {
        info!("worker-{} started", self.worker_id());

        Ok(())
    }

    /// Reports an error from `on_tick`.
    ///
    /// The pool also routes errors it builds for dead or panicked threads
    /// through this hook. The default implementation logs at error level.
    fn on_tick_err(&self, err: &Self::Err) {
        error!("worker-{} on_tick_err: {:?}", self.worker_id(), err);
    }

    /// Called once when a thread leaves its tick loop.
    ///
    /// The pool also calls it after detecting a dead or panicked thread.
    /// The default implementation only logs.
    fn on_stop(&self) -> Result<(), Self::Err> {
        // Fixed spelling ("stoped" -> "stopped") so the message is greppable.
        info!("worker-{} stopped", self.worker_id());
        Ok(())
    }
}

pub trait Shutdown: Send {
    fn worker_id(&self) -> u64;

    fn on_stop_timeout(&self) {
        error!("worker-{} stop_timeout", self.worker_id());
    }
}

pub struct ThreadPool {
    workers: u64,
    closed: Arc<AtomicBool>,
    stop_timeout: Duration,
    watch_dead_thread: Duration,
}

impl ThreadPool {
    pub fn new(
        workers: u64,
        stop_timeout: Duration,
        watch_dead_thread: Duration,
        closed: Arc<AtomicBool>,
    ) -> ThreadPool {
        ThreadPool {
            workers,
            closed,
            stop_timeout,
            watch_dead_thread,
        }
    }

    pub fn run<N, W, S, E>(self, new_worker: N) -> Workers<W, S, E>
    where
        N: NewWorker<Worker = W>,
        W: Worker<Shutdown = S, Err = E>,
        S: Shutdown,
        E: ErrorExt,
    {
        let mut inner_workers = Vec::with_capacity(self.workers as usize);
        for worker_id in 0..self.workers {
            let worker = new_worker.new_worker(worker_id);
            let shutdown = worker.new_shutdown();
            let inner_worker =
                InnerWorker::new(worker, shutdown, self.closed.clone(), self.stop_timeout);
            inner_workers.push(inner_worker);
        }

        Workers {
            closed: self.closed,
            watch_dead_thread: self.watch_dead_thread,
            inner_workers: inner_workers,
        }
    }
}

/// Handle to a running pool, returned by `ThreadPool::run`.
pub struct Workers<W, S, E> {
    closed: Arc<AtomicBool>,
    watch_dead_thread: Duration,
    inner_workers: Vec<InnerWorker<W, S, E>>,
}

impl<W, S, E> Workers<W, S, E>
where
    W: Worker<Shutdown = S, Err = E>,
    S: Shutdown,
    E: ErrorExt,
{
    /// Blocks the calling thread until the pool is closed, then waits for
    /// every worker thread to finish.
    ///
    /// While `closed` is unset, this sleeps `watch_dead_thread` between passes
    /// and lets each worker restart its thread if that thread has died. Once
    /// `closed` is observed, it waits on every worker in order, even after an
    /// earlier one failed, and returns the error of the last failing worker.
    pub fn wait(mut self) -> Result<(), E> {
        // Supervision phase: always sleep before each check of `closed`.
        thread::sleep(self.watch_dead_thread);
        while !self.closed.load(Ordering::Relaxed) {
            for inner_worker in self.inner_workers.iter_mut() {
                inner_worker.check_dead_thread();
            }
            thread::sleep(self.watch_dead_thread);
        }

        // Shutdown phase: wait on every worker, remembering the last error.
        let mut last_err = None;
        for inner_worker in self.inner_workers {
            if let Err(err) = inner_worker.wait() {
                last_err = Some(err);
            }
        }

        match last_err {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }
}

struct InnerWorker<W, S, E> {
    worker_id: u64,
    closed: Arc<AtomicBool>,
    worker: W,
    shutdown: S,
    stop_timeout: Duration,
    thread: thread::JoinHandle<Result<(), E>>,
    normal_exit: Arc<AtomicBool>,
}

impl<W, S, E> InnerWorker<W, S, E>
where
    W: Worker<Shutdown = S, Err = E>,
    S: Shutdown,
    E: ErrorExt,
{
    fn new(
        worker: W,
        shutdown: S,
        closed: Arc<AtomicBool>,
        stop_timeout: Duration,
    ) -> InnerWorker<W, S, E> {
        let worker_id = worker.worker_id();
        let normal_exit = Arc::new(AtomicBool::new(false));
        let worker = worker;
        let thread = InnerWorker::<W, S, E>::start_thread(
            worker.clone(),
            closed.clone(),
            normal_exit.clone(),
        );

        InnerWorker {
            worker_id,
            closed,
            worker,
            shutdown,
            stop_timeout,
            thread,
            normal_exit,
        }
    }

    fn start_thread(
        worker: W,
        closed: Arc<AtomicBool>,
        normal_exit: Arc<AtomicBool>,
    ) -> thread::JoinHandle<Result<(), E>> {
        thread::spawn(move || {
            if let Err(err) = worker.on_start() {
                closed.store(true, Ordering::Relaxed);

                return Err(err);
            }

            loop {
                if closed.load(Ordering::Relaxed) {
                    break;
                }

                if let Err(err) = worker.on_tick() {
                    worker.on_tick_err(&err);
                }
            }

            worker.on_stop()?;

            normal_exit.store(true, Ordering::Relaxed);

            Ok(())
        })
    }

    fn check_dead_thread(&mut self) {
        if self.thread.try_join().is_ok() {
            debug!("worker-{} check_dead_thread: dead", self.worker_id);

            self.worker.on_tick_err(
                &E::new_err(Cow::Borrowed("ThreadPanic")),
            );

            if let Err(err) = self.worker.on_stop() {
                error!("worker-{} on_stop err: {:?}", self.worker_id, err);
            }

            self.thread = InnerWorker::<W, S, E>::start_thread(
                self.worker.clone(),
                self.closed.clone(),
                self.normal_exit.clone(),
            );
        } else {
            debug!("worker-{} check_dead_thread: live", self.worker_id);
        }
    }

    fn wait(self) -> Result<(), E> {
        if self.stop_timeout.as_secs() == 0 {
            self.wait_no_timeout()
        } else {
            self.wait_with_timeout()
        }
    }

    fn wait_no_timeout(self) -> Result<(), E> {
        let err = match self.thread.join() {
            Ok(_) => return Ok(()),
            Err(err) => err,
        };

        let err_msg = if let Some(err) = err.downcast_ref::<String>() {
            format!(
                "worker-{} wait_for_thread panic_err: {:?}",
                self.worker_id,
                err,
            )
        } else if let Some(err) = err.downcast_ref::<&'static str>() {
            format!(
                "worker-{} wait_for_thread panic_err: {:?}",
                self.worker_id,
                err,
            )
        } else {
            format!(
                "worker-{} wait_for_thread unknown panic_err, {:?}",
                self.worker_id,
                err,
            )
        };

        let err = E::new_err(Cow::from(err_msg));
        self.worker.on_tick_err(&err);

        if let Err(err) = self.worker.on_stop() {
            error!("worker-{} on_stop err: {:?}", self.worker_id, err);
        }

        Err(err)
    }

    #[cfg(target_os = "linux")]
    fn wait_with_timeout(self) -> Result<(), E> {
        if self.thread.try_timed_join(self.stop_timeout).is_ok() {
            return self.wait_no_timeout();
        };

        if self.normal_exit.swap(true, Ordering::Relaxed) {
            return Ok(());
        }

        self.shutdown.on_stop_timeout();

        let err: E = E::new_err(Cow::from(format!("worker-{} stop_timeout", self.worker_id)));

        Err(err)
    }

    #[cfg(not(target_os = "linux"))]
    fn wait_with_timeout(self) -> Result<(), E> {
        let _ = self.shutdown;
        self.wait_no_timeout()
    }
}

/// `Shutdown` implementation that only stores its worker id.
///
/// It keeps the trait's default `on_stop_timeout`.
pub struct DefaultShutdown(u64);

impl DefaultShutdown {
    /// Creates a shutdown handle for worker `worker_id`.
    pub fn new(worker_id: u64) -> DefaultShutdown {
        let handle = DefaultShutdown(worker_id);
        handle
    }
}

impl Shutdown for DefaultShutdown {
    fn worker_id(&self) -> u64 {
        match *self {
            DefaultShutdown(id) => id,
        }
    }
}