use crate::monitor::event::EventType;
use crate::monitor::MonitorConfig;
use crate::thread::recovery::{PanicMarker, RecoveryThread};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{Builder, JoinHandle};
use std::time::Instant;
/// Queueing delay, in milliseconds, above which a worker reports
/// `EventType::ThreadPoolOverload` to the registered monitor. It is compared against
/// `Instant::elapsed().as_millis()` of a task's enqueue time when the task is dequeued.
const OVERLOAD_THRESHOLD: u128 = 100;
/// A fixed-size pool of worker threads fed through a shared task channel.
///
/// Worker threads are not spawned until `start` is called. At that point a recovery
/// thread is also created; it receives worker ids over its own channel.
pub struct ThreadPool {
/// Number of worker threads spawned by `start`.
thread_count: usize,
/// Set once `start` has run; `execute` asserts on it.
started: bool,
/// Worker handles, shared (via `Arc`) with the recovery thread.
threads: Arc<Mutex<Vec<Thread>>>,
/// Recovery thread handle; `None` until `start` runs or after it has been taken for joining.
recovery_thread: Option<RecoveryThread>,
/// Sending half of the task channel. Before `start` this is a placeholder whose
/// receiver has already been dropped.
tx: Sender<Message>,
/// Optional monitor, cloned into each worker and the recovery thread by `start`.
monitor: Option<MonitorConfig>,
}
/// A single worker: its index in the pool plus the handle of its OS thread.
pub struct Thread {
/// Index of the worker within the pool; also used as the OS thread's name.
// The field is only written, never read, in this module — hence the lint allowance.
#[allow(dead_code)]
pub id: usize,
/// Join handle of the worker thread; `None` once it has been taken for joining.
pub os_thread: Option<JoinHandle<()>>,
}
/// A message sent from the pool to its workers over the task channel.
pub enum Message {
/// A task to run, paired with the instant it was queued (used to measure queueing delay).
Function(Task, Instant),
/// Tells the receiving worker to leave its loop and exit.
Shutdown,
}
/// Lets a boxed `FnOnce` closure be invoked by value through `Box<Self>`.
///
/// NOTE(review): nothing in this file calls `call_box`. Workers call `Task` directly, and
/// `Box<dyn FnOnce()>` has been directly callable since Rust 1.35, so this looks like a
/// leftover — confirm it is unused elsewhere before removing it.
trait CallableTask {
/// Consumes the box and runs the closure it contains.
fn call_box(self: Box<Self>);
}
/// Blanket implementation: any `FnOnce()` closure can be run through its box.
impl<F> CallableTask for F
where
    F: FnOnce(),
{
    fn call_box(self: Box<Self>) {
        // Move the closure out of its heap allocation, then invoke it by value.
        let closure = *self;
        closure()
    }
}
/// A heap-allocated, one-shot job that can be sent to and run on another thread.
pub type Task = Box<dyn FnOnce() + Send + 'static>;
impl ThreadPool {
    /// Creates an idle pool that will run `thread_count` workers once
    /// [`ThreadPool::start`] is called. No threads are spawned here.
    ///
    /// # Panics
    ///
    /// Panics if `thread_count` is zero.
    pub fn new(thread_count: usize) -> Self {
        assert!(thread_count > 0);
        Self {
            thread_count,
            started: false,
            threads: Arc::new(Mutex::new(Vec::new())),
            recovery_thread: None,
            // Placeholder: the matching receiver is dropped immediately, so sends on it
            // fail until `start` installs the real channel.
            tx: channel().0,
            monitor: None,
        }
    }

    /// Opens the task channel and spawns the worker threads and the recovery thread.
    ///
    /// Calling `start` on a pool that is already running does nothing. Re-running the
    /// setup would replace `tx`, cutting the existing workers off from new work. It
    /// would also drop the old worker and recovery-thread handles without joining them.
    pub fn start(&mut self) {
        if self.started {
            return;
        }
        let (tx, rx): (Sender<Message>, Receiver<Message>) = channel();
        // Every worker competes for messages through the same locked receiver.
        let rx = Arc::new(Mutex::new(rx));
        let mut threads = Vec::with_capacity(self.thread_count);
        // Workers report their id on this channel; the recovery thread consumes it.
        let (recovery_tx, recovery_rx): (Sender<usize>, Receiver<usize>) = channel();
        for id in 0..self.thread_count {
            threads.push(Thread::new(
                id,
                rx.clone(),
                recovery_tx.clone(),
                self.monitor.clone(),
            ))
        }
        self.threads = Arc::new(Mutex::new(threads));
        self.tx = tx;
        let recovery_thread = RecoveryThread::new(
            recovery_rx,
            recovery_tx,
            rx,
            self.threads.clone(),
            self.monitor.clone(),
        );
        self.recovery_thread = Some(recovery_thread);
        self.started = true;
    }

    /// Registers a monitor to receive pool events.
    ///
    /// The monitor is cloned into the threads when [`ThreadPool::start`] runs, so it
    /// only affects threads spawned by a subsequent `start`.
    pub fn register_monitor(&mut self, monitor: MonitorConfig) {
        self.monitor = Some(monitor);
    }

    /// Queues `task` to be run on one of the worker threads.
    ///
    /// The enqueue time is recorded so workers can detect queueing overload.
    ///
    /// # Panics
    ///
    /// Panics if the pool has not been started, or if the task channel has no
    /// receiver left.
    pub fn execute<F>(&self, task: F)
    where
        F: FnOnce() + Send + 'static,
    {
        assert!(self.started);
        let boxed_task = Box::new(task);
        let time_into_pool = Instant::now();
        self.tx
            .send(Message::Function(boxed_task, time_into_pool))
            .unwrap();
    }

    /// Number of worker threads this pool runs.
    pub fn thread_count(&self) -> usize {
        self.thread_count
    }
}
impl Thread {
    /// Spawns worker `id`, a named OS thread that repeatedly takes one [`Message`]
    /// from the shared receiver and handles it.
    ///
    /// For `Message::Function`, a task may have waited in the queue longer than
    /// `OVERLOAD_THRESHOLD` milliseconds. If it did and a monitor is registered, the
    /// worker sends `EventType::ThreadPoolOverload` before running the task. The worker
    /// exits on `Message::Shutdown`, and also when every sender of the task channel has
    /// been dropped.
    ///
    /// `panic_tx` is moved into a `PanicMarker` that stays alive for the whole worker
    /// loop. NOTE(review): presumably its `Drop` reports `id` to the recovery thread
    /// when the worker unwinds from a panicking task — confirm in `thread::recovery`.
    ///
    /// # Panics
    ///
    /// Panics if the OS thread cannot be spawned.
    pub fn new(
        id: usize,
        rx: Arc<Mutex<Receiver<Message>>>,
        panic_tx: Sender<usize>,
        monitor: Option<MonitorConfig>,
    ) -> Self {
        let thread = Builder::new()
            .name(id.to_string())
            .spawn(move || {
                let panic_marker = PanicMarker(id, panic_tx);
                loop {
                    // The guard is a temporary of this statement. The lock is therefore
                    // released before the task runs, letting other workers receive meanwhile.
                    let message = rx.lock().unwrap().recv();
                    match message {
                        Ok(Message::Function(task, queued_at)) => {
                            if let Some(monitor) = &monitor {
                                if queued_at.elapsed().as_millis() > OVERLOAD_THRESHOLD {
                                    monitor.send(EventType::ThreadPoolOverload);
                                }
                            }
                            task();
                        }
                        // On an explicit shutdown, or once every sender is gone, no more
                        // work can arrive, so leave the loop. Unwrapping the `RecvError`
                        // here would panic while the receiver lock is still held. That
                        // would poison the mutex for every other worker.
                        Ok(Message::Shutdown) | Err(_) => break,
                    }
                }
                drop(panic_marker);
            })
            .expect("Thread could not be spawned");
        Self {
            id,
            os_thread: Some(thread),
        }
    }
}
impl Drop for ThreadPool {
fn drop(&mut self) {
if let Some(mut recovery_thread) = self.recovery_thread.take() {
if let Some(thread) = recovery_thread.0.take() {
thread.join().unwrap();
}
}
for thread in &mut *self.threads.lock().unwrap() {
if let Some(thread) = thread.os_thread.take() {
thread.join().unwrap();
}
}
}
}