killable_thread/
lib.rs

//! A stoppable, thin wrapper around `std::thread`.
//!
//! Uses `std::sync::atomic::AtomicBool` and `std::thread` to create stoppable
//! threads.
//!
//! The interface is very similar to that of `std::thread::Thread` (or rather
//! `std::thread::JoinHandle`), except that every closure passed in must accept
//! a `stopped` parameter, allowing it to check whether or not a stop was
//! requested.
//!
//! Since all stops must happen gracefully, i.e. by requesting the child thread
//! to stop, partial values can be returned if needed.
//!
//! Example:
//!
//! ```
//! use killable_thread;
//!
//! let handle = killable_thread::spawn(|stopped| {
//!     let mut count: u64 = 0;
//!
//!     while !stopped.get() {
//!         count += 1
//!     }
//!
//!     count
//! });
//!
//! // work in main thread
//!
//! // stop the thread. we also want to collect partial results
//! let child_count = handle.stop().join().unwrap();
//! ```
34
35use std::ops::Drop;
36use std::mem;
37use std::thread::{self, Thread, JoinHandle, Result};
38use std::sync::atomic::{AtomicBool, Ordering};
39use std::sync::{Arc, Weak};
40
/// A simplified [`std::sync::atomic::AtomicBool`].
///
/// Every load and store uses `Ordering::SeqCst`, so callers never have to
/// choose a memory ordering. `SimpleAtomicBool::default()` holds `false`.
#[derive(Debug, Default)]
pub struct SimpleAtomicBool(AtomicBool);

impl SimpleAtomicBool {
    /// Create a new instance holding `v`.
    pub fn new(v: bool) -> SimpleAtomicBool {
        SimpleAtomicBool(AtomicBool::new(v))
    }

    /// Return the current value (sequentially consistent load).
    pub fn get(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }

    /// Set a new value (sequentially consistent store).
    pub fn set(&self, v: bool) {
        self.0.store(v, Ordering::SeqCst)
    }
}
60
61/// A handle for a stoppable thread
62///
63/// The interface is similar to `std::thread::JoinHandle<T>`, supporting
64/// `thread` and `join` with the same signature.
65pub struct StoppableHandle<T> {
66    join_handle: JoinHandle<T>,
67    stopped: Weak<SimpleAtomicBool>,
68}
69
70impl<T> StoppableHandle<T> {
71    pub fn thread(&self) -> &Thread {
72        self.join_handle.thread()
73    }
74
75    pub fn join(self) -> Result<T> {
76        self.join_handle.join()
77    }
78
79    /// Stop the thread
80    ///
81    /// This will signal the thread to stop by setting the shared atomic
82    /// `stopped` variable to `True`. The function will return immediately
83    /// after, to wait for the thread to stop, use the returned `JoinHandle<T>`
84    /// and `wait()`.
85    pub fn stop(self) -> JoinHandle<T> {
86        if let Some(v) = self.stopped.upgrade() {
87            v.set(true)
88        }
89
90        self.join_handle
91    }
92}
93
94/// Spawn a stoppable thread
95///
96/// Works similar to like `std::thread::spawn`, except that a
97/// `&SimpleAtomicBool` is passed into `f`.
98pub fn spawn<F, T>(f: F) -> StoppableHandle<T> where
99    F: FnOnce(&SimpleAtomicBool) -> T,
100    F: Send + 'static, T: Send + 'static {
101    let stopped = Arc::new(SimpleAtomicBool::new(false));
102    let stopped_w = Arc::downgrade(&stopped);
103
104    StoppableHandle{
105        join_handle: thread::spawn(move || f(&*stopped)),
106        stopped: stopped_w,
107    }
108}
109
110pub fn spawn_with_builder<F, T>(thread_builder: thread::Builder, f: F) -> std::io::Result<StoppableHandle<T>> where
111  F: FnOnce(&SimpleAtomicBool) -> T,
112  F: Send + 'static, T: Send + 'static {
113    let stopped = Arc::new(SimpleAtomicBool::new(false));
114    let stopped_w = Arc::downgrade(&stopped);
115
116    let handle = thread_builder.spawn(move || f(&*stopped))?;
117
118    return Ok(StoppableHandle{
119        join_handle: handle,
120        stopped: stopped_w,
121    });
122}
123
124/// Guard a stoppable thread
125///
126/// When `Stopping` is dropped (usually by going out of scope), the contained
127/// thread will be stopped.
128///
129/// Note: This does not guarantee that `stop()` will be called (the original
130/// scoped thread was removed from stdlib for this reason).
131pub struct Stopping<T> {
132    handle: Option<StoppableHandle<T>>
133}
134
135impl<T> Stopping<T> {
136    pub fn new(handle: StoppableHandle<T>) -> Stopping<T> {
137        Stopping{
138            handle: Some(handle)
139        }
140    }
141}
142
143impl<T> Drop for Stopping<T> {
144    fn drop(&mut self) {
145        let handle = mem::replace(&mut self.handle, None);
146
147        if let Some(h) = handle {
148            h.stop();
149        };
150    }
151}
152
153/// Guard and join stoppable thread
154///
155/// Like `Stopping`, but waits for the thread to finish. See notes about
156/// guarantees on `Stopping`.
157pub struct Joining<T> {
158    handle: Option<StoppableHandle<T>>
159}
160
161impl<T> Joining<T> {
162    pub fn new(handle: StoppableHandle<T>) -> Joining<T> {
163        Joining{
164            handle: Some(handle)
165        }
166    }
167}
168
169impl<T> Drop for Joining<T> {
170    fn drop(&mut self) {
171        let handle = mem::replace(&mut self.handle, None);
172
173        if let Some(h) = handle {
174            h.stop().join().ok();
175        };
176    }
177}
178
179
#[cfg(test)]
#[test]
fn test_stoppable_thead() {
    use std::thread::sleep;
    use std::time::Duration;

    // Child thread: bump a counter every 10 ms until asked to stop, then hand
    // the count back as the thread's result.
    let worker = spawn(|stopped| {
        let mut ticks: u64 = 0;
        loop {
            if stopped.get() {
                break ticks;
            }
            ticks += 1;
            sleep(Duration::from_millis(10));
        }
    });

    // Give the child roughly ten iterations.
    sleep(Duration::from_millis(100));

    let ticks = worker.stop().join().unwrap();

    // Some work must have happened before the stop arrived.
    assert!(ticks > 1);
}
204
#[cfg(test)]
#[test]
fn test_guard() {
    use std::sync;
    use std::thread::sleep;
    use std::time::Duration;

    let stopping_count = sync::Arc::new(sync::Mutex::new(0));
    let joining_count = sync::Arc::new(sync::Mutex::new(0));

    fn count_upwards(stopped: &SimpleAtomicBool, var: sync::Arc<sync::Mutex<u64>>) {
        // Increase a mutex-protected counter every 10 ms; exit once the value
        // exceeds 500 (an upper bound in case the stop never arrives).
        while !stopped.get() {
            {
                let mut guard = var.lock().unwrap();
                *guard += 1;
                if *guard > 500 {
                    break;
                }
            } // release the lock before sleeping

            sleep(Duration::from_millis(10))
        }
    }

    {
        // Separate scope to cause early drops. Locals drop in reverse
        // declaration order: `_joining` stops and joins its thread first,
        // then `_stopping` only signals its thread. The `_` prefix silences
        // the unused-variable lint while keeping the guards alive until the
        // end of this scope.
        let scount = sync::Arc::clone(&stopping_count);
        let _stopping = Stopping::new(spawn(move |stopped| count_upwards(stopped, scount)));

        let jcount = sync::Arc::clone(&joining_count);
        let _joining = Joining::new(spawn(move |stopped| count_upwards(stopped, jcount)));

        sleep(Duration::from_millis(1))
    }

    // Give the merely-signalled thread time to notice the stop.
    sleep(Duration::from_millis(100));

    // Each thread is expected to have run its first iteration within the
    // 1 ms window (timing-dependent), but neither may have counted far.
    let sc = stopping_count.lock().unwrap();
    assert!(*sc >= 1 && *sc < 5, "stopping thread counted to {}", *sc);
    let jc = joining_count.lock().unwrap();
    assert!(*jc >= 1 && *jc < 5, "joining thread counted to {}", *jc);
}
252
253
#[cfg(test)]
#[test]
fn test_stoppable_thead_builder_with_name() {
    use std::thread::sleep;
    use std::time::Duration;

    let thread_name = "test_builder";
    let builder = thread::Builder::new().name(thread_name.to_string());

    // Same counting worker as the plain spawn test, but on a named thread.
    let spawned = spawn_with_builder(builder, |stopped| {
        let mut ticks: u64 = 0;
        loop {
            if stopped.get() {
                break ticks;
            }
            ticks += 1;
            sleep(Duration::from_millis(10));
        }
    });

    // Let the child loop a few times.
    sleep(Duration::from_millis(100));

    let handle = spawned.unwrap();
    assert_eq!(handle.thread().name(), Some(thread_name));

    let ticks = handle.stop().join().unwrap();

    // Some work must have happened before the stop arrived.
    assert!(ticks > 1);
}