use std::cell::RefCell;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_channel::Receiver;
use crossbeam_channel::RecvTimeoutError;
use crossbeam_channel::Select;
use crossbeam_channel::SelectedOperation;
use super::super::ErrorKind;
use super::super::Result;
/// Boxed join closure held by [`MapThread`]; `MapThread::join` takes it out and calls it at most once.
///
/// The trait object has no `Send` bound, so a `MapThread` holding one is not `Send` either.
type MapThreadFn<T> = Box<dyn FnMut() -> Result<T>>;
/// Handle for a thread whose join result is produced by a stored closure.
///
/// NOTE(review): in this crate it appears to be created by mapping a thread handle
/// (the tests call `.map(|_| true)` and then `.join()`); confirm against the caller.
pub struct MapThread<T: Send + 'static> {
    // Join closure; `None` once `join` has taken it. The `RefCell` lets `join`
    // move it out through `&self`.
    // NOTE: fields drop in declaration order, so keep this order unless you have
    // checked what the join closure owns.
    join: RefCell<Option<MapThreadFn<T>>>,
    // Completion signal waited on by `join_timeout`, `select_add` and `select_join`.
    join_check: Receiver<()>,
    // Shared shutdown flag; `request_shutdown` stores `true` into it.
    shutdown: Arc<AtomicBool>,
}
impl<T: Send + 'static> MapThread<T> {
    /// Bundles a join closure with the completion channel and the shared shutdown flag.
    ///
    /// `join_check` is the receiver that `join_timeout`, `select_add` and
    /// `select_join` wait on. `shutdown` is the flag that `request_shutdown` sets.
    pub(crate) fn new<F>(
        join: F,
        join_check: Receiver<()>,
        shutdown: Arc<AtomicBool>,
    ) -> MapThread<T>
    where
        F: FnMut() -> Result<T> + 'static,
    {
        let boxed: MapThreadFn<T> = Box::new(join);
        MapThread {
            join: RefCell::new(Some(boxed)),
            join_check,
            shutdown,
        }
    }

    /// Runs the stored join closure and returns whatever it produces.
    ///
    /// The closure is moved out of its slot before it is invoked, so only the
    /// first call ever reaches it.
    ///
    /// # Errors
    /// * `ErrorKind::JoinedAlready` if an earlier call already took the closure.
    ///   A failed `try_borrow_mut` on the slot is mapped to the same error.
    /// * Any error returned by the join closure itself.
    pub fn join(&self) -> Result<T> {
        // Keep the RefCell borrow confined to this block so it is released
        // before the closure runs.
        let taken = {
            let mut slot = self
                .join
                .try_borrow_mut()
                .map_err(|_| ErrorKind::JoinedAlready)?;
            slot.take()
        };
        match taken {
            Some(mut handle) => handle(),
            None => Err(ErrorKind::JoinedAlready.into()),
        }
    }

    /// Waits up to `timeout` on the completion channel, then joins.
    ///
    /// One pending message on `join_check`, if any, is consumed.
    ///
    /// # Errors
    /// * `ErrorKind::JoinTimeout` if nothing arrives and the channel stays
    ///   connected for the whole `timeout`.
    /// * Otherwise, whatever [`MapThread::join`] returns.
    pub fn join_timeout(&self, timeout: Duration) -> Result<T> {
        let waited = self.join_check.recv_timeout(timeout);
        if let Err(RecvTimeoutError::Timeout) = waited {
            return Err(ErrorKind::JoinTimeout.into());
        }
        // Any other outcome (a message was received, or the sender
        // disconnected) falls through to the join.
        self.join()
    }

    /// Stores `true` into the shared shutdown flag using `Relaxed` ordering.
    ///
    /// This does not wait for the thread. The thread's body is expected to poll
    /// the flag, as the tests in this module do via `scope.should_shutdown()`.
    pub fn request_shutdown(&self) {
        let flag = &self.shutdown;
        flag.store(true, Ordering::Relaxed);
    }

    /// Registers a receive on this thread's completion channel with `select`.
    ///
    /// Returns the operation index assigned by `select`. Pass the matching
    /// `SelectedOperation` to [`MapThread::select_join`].
    pub fn select_add<'a>(&'a self, select: &mut Select<'a>) -> usize {
        let receiver = &self.join_check;
        select.recv(receiver)
    }

    /// Completes a selected receive on the completion channel, then joins.
    ///
    /// # Panics
    /// crossbeam's `SelectedOperation::recv` panics if `operation` was not
    /// selected for this thread's receiver.
    ///
    /// # Errors
    /// Whatever [`MapThread::join`] returns.
    pub fn select_join(&self, operation: SelectedOperation) -> Result<T> {
        // crossbeam requires every selected operation to be completed. The
        // received value is ignored because the channel only carries `()`;
        // the outcome comes from `join` instead.
        let _ = operation.recv(&self.join_check);
        self.join()
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crossbeam_channel::Select;

    use super::super::super::Builder;

    /// Upper bound for the select/ready/join waits in these tests.
    ///
    /// These calls return as soon as an operation is ready, so a generous bound
    /// costs nothing when the tests pass. It also stops them from failing
    /// spuriously on a slow or heavily loaded machine, which the previous 30 ms
    /// bound invited.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn spawn_and_join() {
        let flag: bool = Builder::new("spawn_and_join")
            .spawn(|_| {})
            .expect("failed to spawn thread")
            .map(|_| true)
            .join()
            .expect("failed to join thread");
        assert!(flag);
    }

    #[test]
    fn request_shutdown() {
        let thread = Builder::new("request_shutdown")
            .spawn(|scope| loop {
                ::std::thread::sleep(Duration::from_millis(10));
                if scope.should_shutdown() {
                    break;
                }
            })
            .expect("to spawn test thread")
            .map(|_| true);
        thread.request_shutdown();
        let flag = thread.join().expect("the thread to stop");
        assert!(flag);
    }

    #[test]
    fn select_interface() {
        let thread = Builder::new("select_interface")
            .spawn(|_| {
                ::std::thread::sleep(Duration::from_millis(10));
            })
            .expect("to spawn test thread")
            .map(|_| true);
        let mut set = Select::new();
        let idx = thread.select_add(&mut set);
        let op = set.select_timeout(WAIT_TIMEOUT).unwrap();
        // Capture the index before `op` is consumed by `select_join`.
        let selected = op.index();
        let flag = thread.select_join(op).unwrap();
        assert_eq!(0, idx);
        assert_eq!(idx, selected);
        assert!(flag);
    }

    #[test]
    fn select_multiple_threads() {
        let thread1 = Builder::new("select_multiple_threads_1")
            .spawn(|_| {
                ::std::thread::sleep(Duration::from_millis(50));
            })
            .expect("to spawn test thread")
            .map(|_| true);
        let thread2 = Builder::new("select_multiple_threads_2")
            .spawn(|_| {
                ::std::thread::sleep(Duration::from_millis(10));
            })
            .expect("to spawn test thread")
            .map(|_| true);
        let mut set = Select::new();
        thread1.select_add(&mut set);
        thread2.select_add(&mut set);
        let op = set.select_timeout(WAIT_TIMEOUT).unwrap();
        let idx = op.index();
        // Check the index first. Completing `op` against the wrong receiver
        // would panic inside crossbeam with a far less helpful message.
        assert_eq!(1, idx);
        thread2.select_join(op).unwrap();
    }

    #[test]
    fn select_panic() {
        let thread = Builder::new("select_panic")
            .spawn(|_| {
                ::std::thread::sleep(Duration::from_millis(10));
                panic!("this panic is expected");
            })
            .expect("to spawn test thread")
            .map(|_| true);
        let mut set = Select::new();
        thread.select_add(&mut set);
        let op = set.select_timeout(WAIT_TIMEOUT).unwrap();
        let idx = op.index();
        let result = thread.select_join(op);
        assert_eq!(0, idx);
        assert!(result.is_err());
    }

    #[test]
    fn select_ready_interface() {
        let thread = Builder::new("select_ready_interface")
            .spawn(|_| {
                ::std::thread::sleep(Duration::from_millis(10));
            })
            .expect("to spawn test thread")
            .map(|_| true);
        let mut set = Select::new();
        thread.select_add(&mut set);
        let idx = set.ready_timeout(WAIT_TIMEOUT).unwrap();
        assert_eq!(0, idx);
        thread.join_timeout(WAIT_TIMEOUT).unwrap();
    }
}