use std::ops::Drop;
use std::mem;
use std::thread::{self, Thread, JoinHandle, Result};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
/// A boolean flag that can be read and written through a shared reference.
///
/// Thin wrapper around `AtomicBool` whose accessors always use
/// `Ordering::SeqCst`, so callers never have to pick a memory ordering.
/// `Default` yields a flag holding `false`.
#[derive(Debug, Default)]
pub struct SimpleAtomicBool(AtomicBool);

impl SimpleAtomicBool {
    /// Creates a new flag initialised to `v`.
    pub fn new(v: bool) -> SimpleAtomicBool {
        SimpleAtomicBool(AtomicBool::new(v))
    }

    /// Returns the current value of the flag.
    pub fn get(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }

    /// Overwrites the flag with `v`.
    pub fn set(&self, v: bool) {
        self.0.store(v, Ordering::SeqCst)
    }
}
/// Handle to a thread that was started together with a stop flag.
///
/// Bundles the thread's `JoinHandle` with a weak reference to the
/// `SimpleAtomicBool` that the thread's closure receives.
pub struct StoppableHandle<T> {
    join_handle: JoinHandle<T>,
    // Only a weak reference is kept here; `stop` has to upgrade it before
    // it can touch the flag, and silently does nothing if that fails.
    stopped: Weak<SimpleAtomicBool>,
}

impl<T> StoppableHandle<T> {
    /// Returns the `Thread` behind this handle.
    pub fn thread(&self) -> &Thread {
        self.join_handle.thread()
    }

    /// Waits for the thread to finish and returns its result.
    ///
    /// This does not touch the stop flag; it only joins.
    pub fn join(self) -> Result<T> {
        let StoppableHandle { join_handle, .. } = self;
        join_handle.join()
    }

    /// Sets the stop flag to `true` and hands back the plain `JoinHandle`.
    ///
    /// The flag is only written if the weak reference can still be upgraded.
    /// This call does not wait for the thread; join the returned handle to
    /// do that.
    pub fn stop(self) -> JoinHandle<T> {
        let StoppableHandle { join_handle, stopped } = self;
        if let Some(flag) = stopped.upgrade() {
            flag.set(true);
        }
        join_handle
    }
}
/// Spawns a thread running `f` and returns a `StoppableHandle` for it.
///
/// `f` receives a reference to a fresh stop flag that starts out `false`.
/// The flag's only strong reference is moved into the thread's closure; the
/// returned handle keeps a weak reference to it. `f` is expected to poll the
/// flag and return once it reads `true`.
pub fn spawn<F, T>(f: F) -> StoppableHandle<T>
where
    F: FnOnce(&SimpleAtomicBool) -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let flag = Arc::new(SimpleAtomicBool::new(false));
    let weak_flag = Arc::downgrade(&flag);
    let worker = thread::spawn(move || f(&*flag));
    StoppableHandle {
        join_handle: worker,
        stopped: weak_flag,
    }
}
/// Like `spawn`, but starts the thread through the given `thread::Builder`.
///
/// `f` receives a reference to a fresh stop flag that starts out `false`.
///
/// # Errors
///
/// Returns the `std::io::Error` reported by `thread::Builder::spawn` if the
/// thread could not be started.
pub fn spawn_with_builder<F, T>(
    thread_builder: thread::Builder,
    f: F,
) -> std::io::Result<StoppableHandle<T>>
where
    F: FnOnce(&SimpleAtomicBool) -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let flag = Arc::new(SimpleAtomicBool::new(false));
    let weak_flag = Arc::downgrade(&flag);
    thread_builder
        .spawn(move || f(&*flag))
        .map(|worker| StoppableHandle {
            join_handle: worker,
            stopped: weak_flag,
        })
}
pub struct Stopping<T> {
handle: Option<StoppableHandle<T>>
}
impl<T> Stopping<T> {
pub fn new(handle: StoppableHandle<T>) -> Stopping<T> {
Stopping{
handle: Some(handle)
}
}
}
impl<T> Drop for Stopping<T> {
fn drop(&mut self) {
let handle = mem::replace(&mut self.handle, None);
if let Some(h) = handle {
h.stop();
};
}
}
pub struct Joining<T> {
handle: Option<StoppableHandle<T>>
}
impl<T> Joining<T> {
pub fn new(handle: StoppableHandle<T>) -> Joining<T> {
Joining{
handle: Some(handle)
}
}
}
impl<T> Drop for Joining<T> {
fn drop(&mut self) {
let handle = mem::replace(&mut self.handle, None);
if let Some(h) = handle {
h.stop().join().ok();
};
}
}
#[cfg(test)]
#[test]
fn test_stoppable_thead() {
    use std::thread::sleep;
    use std::time::Duration;

    // The worker counts one tick roughly every 10 ms until it is told to stop.
    let worker = spawn(|stop_flag| {
        let mut ticks: u64 = 0;
        loop {
            if stop_flag.get() {
                break;
            }
            ticks += 1;
            sleep(Duration::from_millis(10));
        }
        ticks
    });

    // Let the worker run for a while before signalling it.
    sleep(Duration::from_millis(100));

    let result = worker.stop().join().unwrap();
    assert!(result > 1);
}
#[cfg(test)]
#[test]
fn test_guard() {
    use std::sync;
    use std::thread::sleep;
    use std::time::Duration;

    let stopping_count = sync::Arc::new(sync::Mutex::new(0));
    let joining_count = sync::Arc::new(sync::Mutex::new(0));

    // Increments the shared counter about every 10 ms until the stop flag is
    // set; the 500 cap keeps a thread that is never stopped from running
    // forever.
    fn count_upwards(stopped: &SimpleAtomicBool, var: sync::Arc<sync::Mutex<u64>>) {
        while !stopped.get() {
            let mut guard = var.lock().unwrap();
            *guard += 1;
            if *guard > 500 {
                break;
            }
            // Note: the mutex guard is still held while sleeping.
            sleep(Duration::from_millis(10))
        }
    }

    {
        // Named `_stopping` / `_joining` (not `_`) so the guards live until
        // the end of this scope without triggering unused-variable warnings.
        // Locals drop in reverse declaration order: `_joining` first (stop
        // and join), then `_stopping` (stop only).
        let scount = stopping_count.clone();
        let _stopping = Stopping::new(spawn(move |stopped| count_upwards(stopped, scount)));
        let jcount = joining_count.clone();
        let _joining = Joining::new(spawn(move |stopped| count_upwards(stopped, jcount)));
        sleep(Duration::from_millis(1));
    }

    // Give the merely-signalled thread time to notice the flag and exit.
    sleep(Duration::from_millis(100));

    let sc = stopping_count.lock().unwrap();
    assert!(*sc > 1 && *sc < 5);

    // Only an upper bound is checked for the joined thread: it is signalled
    // about 1 ms after being spawned and sleeps 10 ms per iteration, so it
    // may complete as little as a single iteration.
    let jc = joining_count.lock().unwrap();
    assert!(*jc < 5);
}
#[cfg(test)]
#[test]
fn test_stoppable_thead_builder_with_name() {
    use std::thread::sleep;
    use std::time::Duration;

    let thread_name = "test_builder";
    let named_builder = thread::Builder::new().name(thread_name.to_owned());

    // The worker counts one tick roughly every 10 ms until it is told to stop.
    let spawn_result = spawn_with_builder(named_builder, |stop_flag| {
        let mut ticks: u64 = 0;
        loop {
            if stop_flag.get() {
                break;
            }
            ticks += 1;
            sleep(Duration::from_millis(10));
        }
        ticks
    });

    // Let the worker run for a while before inspecting and stopping it.
    sleep(Duration::from_millis(100));

    let stoppable_handle = spawn_result.unwrap();
    // The name given to the builder must be visible on the spawned thread.
    assert!(stoppable_handle.thread().name().unwrap() == thread_name);

    let result = stoppable_handle.stop().join().unwrap();
    assert!(result > 1);
}