use std::boxed::FnBox;
use std::collections::BinaryHeap;
use std::cmp::{PartialOrd, Ord, PartialEq, Eq, Ordering};
use std::sync::{Arc, Mutex, Condvar};
use std::thread;
use std::time::Duration;
use time;
/// The kind of work carried by a scheduled `Job`.
enum JobType {
/// A closure that a worker calls a single time.
Once(Box<FnBox()+'static+Send>),
/// A closure that a worker calls repeatedly: after every call the worker
/// submits it again, due `rate` after the deadline it just ran for.
FixedRate {
f: Box<FnMut() + Send + 'static>,
rate: Duration,
},
}
/// A piece of work together with the moment at which it becomes runnable.
struct Job {
type_: JobType,
// Deadline in nanoseconds; it is computed from and compared against
// `time::precise_time_ns()`.
time: u64,
}
/// Partial ordering for `Job`; always defined because it defers to `Ord`.
impl PartialOrd for Job {
    fn partial_cmp(&self, other: &Job) -> Option<Ordering> {
        // Delegate so that `PartialOrd` and `Ord` can never disagree.
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Total ordering for `Job`, inverted with respect to the deadline.
///
/// `BinaryHeap` is a max-heap, so the job with the *smallest* `time` must
/// compare as the greatest in order to sit at the top of the queue.
impl Ord for Job {
    fn cmp(&self, other: &Job) -> Ordering {
        // Swapping the operands yields the reversed ordering.
        other.time.cmp(&self.time)
    }
}
/// Two jobs are considered equal when they share the same deadline; the
/// work they carry is not compared.
impl PartialEq for Job {
    fn eq(&self, other: &Job) -> bool {
        let (mine, theirs) = (self.time, other.time);
        mine == theirs
    }
}
// Marker impl: `Ord` requires `Eq`, and `PartialEq` above compares the
// `u64` deadline only.
impl Eq for Job {}
/// The mutable pool state that `SharedPool` keeps behind its mutex.
struct InnerPool {
// Pending jobs; `Job`'s reversed ordering keeps the earliest deadline on top.
queue: BinaryHeap<Job>,
// Set when the `ScheduledThreadPool` handle is dropped. Once set, newly
// submitted jobs are discarded and workers exit when the queue is empty.
shutdown: bool,
}
/// State shared, through an `Arc`, between the pool handle and its workers.
struct SharedPool {
inner: Mutex<InnerPool>,
// Workers wait on this while they have nothing runnable; it is notified
// when a new earliest job is queued and when the pool shuts down.
cvar: Condvar,
}
impl SharedPool {
    /// Adds `job` to the queue, or silently discards it if the pool has
    /// already been shut down.
    ///
    /// Waiting workers are woken only when `job` becomes the new earliest
    /// entry, because only then can a wait that is already in progress be
    /// too long.
    fn run(&self, job: Job) {
        let mut inner = self.inner.lock().unwrap();
        if inner.shutdown {
            return;
        }

        let becomes_head = match inner.queue.peek() {
            Some(head) => head.time > job.time,
            None => true,
        };
        if becomes_head {
            // The lock is still held, so woken workers cannot inspect the
            // queue until the push below has happened.
            self.cvar.notify_all();
        }

        inner.queue.push(job);
    }
}
/// A pool of worker threads that runs closures immediately, after a delay,
/// or repeatedly at a fixed rate.
///
/// Dropping the pool marks it as shut down and wakes every worker.
pub struct ScheduledThreadPool {
shared: Arc<SharedPool>,
}
/// Dropping the handle marks the shared state as shut down and wakes every
/// worker so that each of them re-checks the queue.
impl Drop for ScheduledThreadPool {
    fn drop(&mut self) {
        {
            // Scope the guard so the mutex is released before notifying.
            let mut inner = self.shared.inner.lock().unwrap();
            inner.shutdown = true;
        }
        self.shared.cvar.notify_all();
    }
}
impl ScheduledThreadPool {
pub fn new(size: usize) -> ScheduledThreadPool {
assert!(size > 0, "size must be positive");
let inner = InnerPool {
queue: BinaryHeap::new(),
shutdown: false,
};
let shared = SharedPool {
inner: Mutex::new(inner),
cvar: Condvar::new(),
};
let pool = ScheduledThreadPool {
shared: Arc::new(shared),
};
for _ in (0..size) {
let mut worker = Worker {
shared: pool.shared.clone(),
};
thread::spawn(move || worker.run());
}
pool
}
#[allow(dead_code)]
pub fn run<F>(&self, job: F) where F: FnOnce() + Send + 'static {
self.run_after(Duration::zero(), job)
}
pub fn run_after<F>(&self, dur: Duration, job: F) where F: FnOnce() + Send + 'static {
let job = Job {
type_: JobType::Once(Box::new(job)),
time: (time::precise_time_ns() as i64 + dur.num_nanoseconds().unwrap()) as u64,
};
self.shared.run(job)
}
#[allow(dead_code)]
pub fn run_at_fixed_rate<F>(&self, rate: Duration, f: F) where F: FnMut() + Send + 'static {
let job = Job {
type_: JobType::FixedRate { f: Box::new(f), rate: rate },
time: (time::precise_time_ns() as i64 + rate.num_nanoseconds().unwrap()) as u64,
};
self.shared.run(job)
}
pub fn clear(&self) {
self.shared.inner.lock().unwrap().queue.clear();
}
}
/// The state owned by one worker thread: a handle to the pool's shared queue.
struct Worker {
shared: Arc<SharedPool>,
}
/// Keeps the pool at full strength: when a worker is dropped because its
/// thread is unwinding from a panic, a replacement worker thread is started.
impl Drop for Worker {
    fn drop(&mut self) {
        if !thread::panicking() {
            // Ordinary exit: nothing to replace.
            return;
        }

        let mut replacement = Worker {
            shared: self.shared.clone(),
        };
        thread::spawn(move || replacement.run());
    }
}
impl Worker {
    /// Executes jobs one after another until `get_job` returns `None`,
    /// which happens once the pool is shut down and its queue is empty.
    fn run(&mut self) {
        loop {
            match self.get_job() {
                Some(job) => self.run_job(job),
                None => break,
            }
        }
    }

    /// Blocks until the earliest queued job is due and then removes and
    /// returns it.
    ///
    /// Returns `None` only when the queue is empty and the pool has been
    /// shut down; jobs that were queued before shutdown are still handed out.
    fn get_job(&self) -> Option<Job> {
        enum Need {
            /// Nothing is queued: sleep until notified.
            Wait,
            /// The head job is not due yet: sleep at most this many
            /// milliseconds.
            WaitTimeout(u32),
        }
        let mut inner = self.shared.inner.lock().unwrap();
        loop {
            let now = time::precise_time_ns();
            let need = match inner.queue.peek() {
                None if inner.shutdown => return None,
                None => Need::Wait,
                Some(e) if e.time <= now => break,
                Some(e) => {
                    // `e.time > now` in this arm, so the subtraction cannot
                    // underflow and `remaining_ns` is at least 1.
                    let remaining_ns = e.time - now;
                    // Round *up* to whole milliseconds. Truncating would turn
                    // any remainder below 1 ms into a zero-length wait and
                    // make every idle worker spin on the mutex until the
                    // deadline passes.
                    let millis = (remaining_ns - 1) / 1_000_000 + 1;
                    // Saturate instead of wrapping when the delay does not
                    // fit in the `u32` that `wait_timeout_ms` accepts; the
                    // loop simply waits again after waking.
                    let capped = if millis > ::std::u32::MAX as u64 {
                        ::std::u32::MAX
                    } else {
                        millis as u32
                    };
                    Need::WaitTimeout(capped)
                }
            };
            // Waking up early (notification, timeout or spurious wake-up) is
            // harmless: the next iteration re-evaluates the queue.
            inner = match need {
                Need::Wait => self.shared.cvar.wait(inner).unwrap(),
                Need::WaitTimeout(ms) => {
                    self.shared.cvar.wait_timeout_ms(inner, ms).unwrap().0
                }
            };
        }
        Some(inner.queue.pop().unwrap())
    }

    /// Runs `job`. A fixed-rate job is submitted again afterwards, due
    /// `rate` after the deadline it was just run for (not after the time it
    /// finished). The re-submission is dropped by `SharedPool::run` if the
    /// pool has been shut down in the meantime.
    ///
    /// # Panics
    ///
    /// Panics if `rate.num_nanoseconds()` returns `None`, and propagates any
    /// panic raised by the job itself.
    fn run_job(&self, job: Job) {
        match job.type_ {
            JobType::Once(f) => f(),
            JobType::FixedRate { mut f, rate } => {
                f();
                let new_job = Job {
                    type_: JobType::FixedRate { f: f, rate: rate },
                    time: (job.time as i64 + rate.num_nanoseconds().unwrap()) as u64,
                };
                self.shared.run(new_job)
            }
        }
    }
}
// Tests for `ScheduledThreadPool`. Several of them depend on real time
// (delays of 500 ms and 1 s).
#[cfg(test)]
mod test {
use std::iter::AdditiveIterator;
use std::sync::mpsc::channel;
use std::sync::{Arc, Barrier};
use std::time::Duration;
use super::ScheduledThreadPool;
// Number of worker threads, and of jobs submitted per batch, in the tests.
const TEST_TASKS: usize = 4;
// Every immediately-submitted job runs: each one sends 1 and the sum of the
// received values must equal the number of jobs.
#[test]
fn test_works() {
let pool = ScheduledThreadPool::new(TEST_TASKS);
let (tx, rx) = channel();
for _ in 0..TEST_TASKS {
let tx = tx.clone();
pool.run(move|| {
tx.send(1usize).unwrap();
});
}
assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
}
// Constructing a pool with zero workers must trip the size assertion.
#[test]
#[should_panic(expected = "size must be positive")]
fn test_zero_tasks_panic() {
ScheduledThreadPool::new(0);
}
// The pool keeps working after jobs panic.
#[test]
fn test_recovery_from_subtask_panic() {
let pool = ScheduledThreadPool::new(TEST_TASKS);
// The barrier only releases once TEST_TASKS jobs are waiting on it at the
// same time, so all TEST_TASKS workers are inside a job when it panics.
let waiter = Arc::new(Barrier::new(TEST_TASKS as usize));
for _ in 0..TEST_TASKS {
let waiter = waiter.clone();
pool.run(move || -> () {
waiter.wait();
panic!();
});
}
// The second batch also needs TEST_TASKS concurrent workers to get past
// its barrier, so it can only complete if the panicked workers were replaced.
let (tx, rx) = channel();
let waiter = Arc::new(Barrier::new(TEST_TASKS as usize));
for _ in 0..TEST_TASKS {
let tx = tx.clone();
let waiter = waiter.clone();
pool.run(move || {
waiter.wait();
tx.send(1usize).unwrap();
});
}
assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
}
// Delayed jobs run in deadline order, not submission order: the 500 ms job
// is submitted second but must be received first.
#[test]
fn test_run_after() {
let pool = ScheduledThreadPool::new(TEST_TASKS);
let (tx, rx) = channel();
let tx1 = tx.clone();
pool.run_after(Duration::seconds(1), move || tx1.send(1usize).unwrap());
pool.run_after(Duration::milliseconds(500), move || tx.send(2usize).unwrap());
assert_eq!(2, rx.recv().unwrap());
assert_eq!(1, rx.recv().unwrap());
}
// Jobs that were already queued still run, in deadline order, after the
// pool handle has been dropped.
#[test]
fn test_jobs_complete_after_drop() {
let pool = ScheduledThreadPool::new(TEST_TASKS);
let (tx, rx) = channel();
let tx1 = tx.clone();
pool.run_after(Duration::seconds(1), move || tx1.send(1usize).unwrap());
pool.run_after(Duration::milliseconds(500), move || tx.send(2usize).unwrap());
drop(pool);
assert_eq!(2, rx.recv().unwrap());
assert_eq!(1, rx.recv().unwrap());
}
// A fixed-rate job stops being rescheduled once the pool is dropped.
#[test]
fn test_fixed_delay_jobs_stop_after_drop() {
let pool = Arc::new(ScheduledThreadPool::new(TEST_TASKS));
let (tx, rx) = channel();
let (tx2, rx2) = channel();
// The job itself holds the second reference to the pool and releases it
// during its second run.
let mut pool2 = Some(pool.clone());
let mut i = 0i32;
pool.run_at_fixed_rate(Duration::milliseconds(500), move || {
i += 1;
tx.send(i).unwrap();
// Wait for the test thread's go-ahead before continuing.
rx2.recv().unwrap();
if i == 2 {
drop(pool2.take().unwrap());
}
});
// Drops only the test's reference; the closure's clone keeps the pool alive.
drop(pool);
assert_eq!(Ok(1), rx.recv());
tx2.send(()).unwrap();
assert_eq!(Ok(2), rx.recv());
tx2.send(()).unwrap();
// After the second run the pool is shut down, so the job is not queued
// again; the closure and its `tx` are dropped and `recv` reports an error.
assert!(rx.recv().is_err());
}
}