#![warn(missing_docs)]
use std::sync::{mpsc, mpsc::TryRecvError};
use std::{thread, thread::JoinHandle};
/// Stop-request receiver handed to the closure run by [`spawn_owned`].
///
/// The worker polls [`Signal::should_stop`] (or [`Signal::should_continue`])
/// to find out whether its owner wants it to finish.
#[derive(Debug)]
pub struct Signal {
    stop_receiver: mpsc::Receiver<()>,
    // Latched to `true` the first time a stop condition is observed.
    // `try_recv` consumes the stop message, so without this latch only the
    // first poll after a stop request would report it and later polls would
    // claim the thread may keep running.
    stop_requested: std::cell::Cell<bool>,
}

impl Signal {
    /// Wraps the receiving end of the stop channel; no stop is recorded yet.
    fn new(stop_receiver: mpsc::Receiver<()>) -> Self {
        Signal {
            stop_receiver,
            stop_requested: std::cell::Cell::new(false),
        }
    }

    /// Returns `true` as long as no stop has been requested.
    ///
    /// This is exactly the negation of [`Signal::should_stop`].
    pub fn should_continue(&self) -> bool {
        !self.should_stop()
    }

    /// Returns `true` once a stop has been requested.
    ///
    /// A stop counts as requested when a stop message has arrived or when the
    /// sending side of the channel has been dropped. The answer is sticky:
    /// after the first `true`, every later call returns `true` as well, so the
    /// signal can safely be polled from several places in the worker.
    pub fn should_stop(&self) -> bool {
        if !self.stop_requested.get() {
            // `Empty` is the only outcome meaning "sender alive, nothing
            // sent"; a received message or a disconnected channel both mean
            // the worker has to finish.
            let observed = !matches!(self.stop_receiver.try_recv(), Err(TryRecvError::Empty));
            self.stop_requested.set(observed);
        }
        self.stop_requested.get()
    }
}

/// Sending half of the stop channel, kept by the owning [`OwnedThread`].
#[derive(Debug)]
struct Controller {
    stop_sender: mpsc::Sender<()>,
}

impl Controller {
    /// Sends a stop message to the worker.
    ///
    /// A send error only means the receiving [`Signal`] is already gone,
    /// i.e. there is nobody left to notify, so it is deliberately ignored.
    pub fn stop(&self) {
        self.stop_sender.send(()).ok();
    }
}

/// Handle to a thread started with [`spawn_owned`].
///
/// Dropping the handle requests a stop and then waits for the thread to
/// finish, so the thread never outlives its handle.
#[derive(Debug)]
#[must_use = "dropping an OwnedThread immediately stops and joins its thread"]
pub struct OwnedThread<T> {
    // `None` only after the handle has been taken by `join` or `drop`.
    join_handle: Option<JoinHandle<T>>,
    stop_controller: Controller,
}

impl<T> OwnedThread<T> {
    /// Requests a stop, waits for the thread to finish and returns its result.
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying the panic payload if the thread panicked.
    pub fn join(mut self) -> std::thread::Result<T> {
        self.stop();
        // `join` consumes `self` and is the only method that takes the
        // handle, so it is still present here. The `Drop` impl runs
        // afterwards and finds `None`, which it handles gracefully.
        self.join_handle
            .take()
            .expect("joinhandle of OwnedThread does not exist")
            .join()
    }

    /// Requests that the thread stop, without waiting for it to finish.
    ///
    /// The thread only terminates if its closure polls the [`Signal`] it was
    /// given and returns once a stop is reported.
    pub fn stop(&self) {
        self.stop_controller.stop();
    }
}

impl<T> Drop for OwnedThread<T> {
    /// Requests a stop and blocks until the thread has finished.
    fn drop(&mut self) {
        self.stop();
        if let Some(handle) = self.join_handle.take() {
            // A panic inside the worker is intentionally discarded here:
            // re-raising it from `drop` could abort the process.
            handle.join().ok();
        }
    }
}

/// Spawns a thread that is stopped and joined when the returned handle is
/// dropped.
///
/// `thread_function` runs on the new thread and receives a [`Signal`] that it
/// is expected to poll; it should return once the signal reports a stop. The
/// closure's return value is available through [`OwnedThread::join`].
pub fn spawn_owned<T, F>(thread_function: F) -> OwnedThread<T>
where
    T: Send + 'static,
    F: FnOnce(Signal) -> T + Send + 'static,
{
    let (stop_sender, stop_receiver) = mpsc::channel();
    let signal = Signal::new(stop_receiver);
    let join_handle = thread::spawn(move || thread_function(signal));
    OwnedThread {
        join_handle: Some(join_handle),
        stop_controller: Controller { stop_sender },
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::mpsc::{self, TryRecvError},
        thread,
        time::Duration,
    };

    /// Dropping the handle must make the spinning worker finish.
    #[test]
    fn owned_thread_terminates_when_dropped() {
        let (done_tx, done_rx) = mpsc::channel::<()>();
        let worker = crate::spawn_owned(move |signal| {
            // Spin until a stop is reported, then announce completion.
            while !signal.should_stop() {}
            done_tx.send(())
        });

        // Give the worker time to run; it must not have finished by itself.
        thread::sleep(Duration::from_secs(1));
        assert_eq!(done_rx.try_recv(), Err(TryRecvError::Empty));

        drop(worker);
        done_rx.recv().unwrap();
    }

    /// An explicit `stop()` must make the spinning worker finish.
    #[test]
    fn owned_thread_terminates_when_told_to_stop() {
        let (done_tx, done_rx) = mpsc::channel::<()>();
        let worker = crate::spawn_owned(move |signal| {
            // Spin until a stop is reported, then announce completion.
            while !signal.should_stop() {}
            done_tx.send(())
        });

        // Give the worker time to run; it must not have finished by itself.
        thread::sleep(Duration::from_secs(1));
        assert_eq!(done_rx.try_recv(), Err(TryRecvError::Empty));

        worker.stop();
        done_rx.recv().unwrap();
    }
}