1use std::ops::Drop;
36use std::mem;
37use std::thread::{self, Thread, JoinHandle, Result};
38use std::sync::atomic::{AtomicBool, Ordering};
39use std::sync::{Arc, Weak};
40
/// A boolean flag that can be shared between threads, backed by an [`AtomicBool`].
///
/// Every load and store uses [`Ordering::SeqCst`]. The default value is `false`.
#[derive(Debug, Default)]
pub struct SimpleAtomicBool(AtomicBool);

impl SimpleAtomicBool {
    /// Creates a new flag initialised to `v`.
    pub fn new(v: bool) -> SimpleAtomicBool {
        SimpleAtomicBool(AtomicBool::new(v))
    }

    /// Returns the current value of the flag.
    pub fn get(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }

    /// Sets the flag to `v`.
    pub fn set(&self, v: bool) {
        self.0.store(v, Ordering::SeqCst)
    }
}
60
61pub struct StoppableHandle<T> {
66 join_handle: JoinHandle<T>,
67 stopped: Weak<SimpleAtomicBool>,
68}
69
70impl<T> StoppableHandle<T> {
71 pub fn thread(&self) -> &Thread {
72 self.join_handle.thread()
73 }
74
75 pub fn join(self) -> Result<T> {
76 self.join_handle.join()
77 }
78
79 pub fn stop(self) -> JoinHandle<T> {
86 if let Some(v) = self.stopped.upgrade() {
87 v.set(true)
88 }
89
90 self.join_handle
91 }
92}
93
94pub fn spawn<F, T>(f: F) -> StoppableHandle<T> where
99 F: FnOnce(&SimpleAtomicBool) -> T,
100 F: Send + 'static, T: Send + 'static {
101 let stopped = Arc::new(SimpleAtomicBool::new(false));
102 let stopped_w = Arc::downgrade(&stopped);
103
104 StoppableHandle{
105 join_handle: thread::spawn(move || f(&*stopped)),
106 stopped: stopped_w,
107 }
108}
109
110pub fn spawn_with_builder<F, T>(thread_builder: thread::Builder, f: F) -> std::io::Result<StoppableHandle<T>> where
111 F: FnOnce(&SimpleAtomicBool) -> T,
112 F: Send + 'static, T: Send + 'static {
113 let stopped = Arc::new(SimpleAtomicBool::new(false));
114 let stopped_w = Arc::downgrade(&stopped);
115
116 let handle = thread_builder.spawn(move || f(&*stopped))?;
117
118 return Ok(StoppableHandle{
119 join_handle: handle,
120 stopped: stopped_w,
121 });
122}
123
124pub struct Stopping<T> {
132 handle: Option<StoppableHandle<T>>
133}
134
135impl<T> Stopping<T> {
136 pub fn new(handle: StoppableHandle<T>) -> Stopping<T> {
137 Stopping{
138 handle: Some(handle)
139 }
140 }
141}
142
143impl<T> Drop for Stopping<T> {
144 fn drop(&mut self) {
145 let handle = mem::replace(&mut self.handle, None);
146
147 if let Some(h) = handle {
148 h.stop();
149 };
150 }
151}
152
153pub struct Joining<T> {
158 handle: Option<StoppableHandle<T>>
159}
160
161impl<T> Joining<T> {
162 pub fn new(handle: StoppableHandle<T>) -> Joining<T> {
163 Joining{
164 handle: Some(handle)
165 }
166 }
167}
168
169impl<T> Drop for Joining<T> {
170 fn drop(&mut self) {
171 let handle = mem::replace(&mut self.handle, None);
172
173 if let Some(h) = handle {
174 h.stop().join().ok();
175 };
176 }
177}
178
179
#[cfg(test)]
#[test]
fn test_stoppable_thead() {
    use std::thread::sleep;
    use std::time::Duration;

    // Worker: count loop iterations until the stop flag is raised, then report the count.
    let worker = spawn(|stop_flag| {
        let mut iterations: u64 = 0;
        loop {
            if stop_flag.get() {
                break iterations;
            }
            iterations += 1;
            sleep(Duration::from_millis(10));
        }
    });

    // Give the worker time for several iterations before asking it to stop.
    sleep(Duration::from_millis(100));

    let iterations = worker.stop().join().unwrap();
    assert!(iterations > 1);
}
204
#[cfg(test)]
#[test]
fn test_guard() {
    use std::sync;
    use std::thread::sleep;
    use std::time::Duration;

    let stopping_count = sync::Arc::new(sync::Mutex::new(0));
    let joining_count = sync::Arc::new(sync::Mutex::new(0));

    // Increments `var` roughly every 10 ms until `stopped` is set or the count exceeds 500.
    fn count_upwards(stopped: &SimpleAtomicBool, var: sync::Arc<sync::Mutex<u64>>) {
        while !stopped.get() {
            {
                // Scope the lock so it is released before sleeping.
                let mut guard = var.lock().unwrap();
                *guard += 1;
                if *guard > 500 {
                    break;
                }
            }
            sleep(Duration::from_millis(10));
        }
    }

    {
        let scount = stopping_count.clone();
        // Named (not `_`) so the guards live until the end of this scope; they drop in
        // reverse order, so `_joining` is stopped and joined first.
        let _stopping = Stopping::new(spawn(move |stopped| count_upwards(stopped, scount)));

        let jcount = joining_count.clone();
        let _joining = Joining::new(spawn(move |stopped| count_upwards(stopped, jcount)));
        sleep(Duration::from_millis(1))
    }

    // `Joining` has already joined its thread, so this count can no longer change.
    let jc_at_drop = *joining_count.lock().unwrap();

    sleep(Duration::from_millis(100));

    let sc = *stopping_count.lock().unwrap();
    assert!(sc > 1 && sc < 5);
    let jc = *joining_count.lock().unwrap();
    // Previously this re-checked `sc` and never bounded `jc` from below.
    assert!(jc >= 1 && jc < 5);
    assert_eq!(jc, jc_at_drop);
}
252
253
#[cfg(test)]
#[test]
fn test_stoppable_thead_builder_with_name() {
    use std::thread::sleep;
    use std::time::Duration;

    let expected_name = "test_builder";
    let builder = thread::Builder::new().name(expected_name.to_string());

    // Worker: count loop iterations until the stop flag is raised, then report the count.
    let spawned = spawn_with_builder(builder, |stop_flag| {
        let mut iterations: u64 = 0;
        loop {
            if stop_flag.get() {
                break iterations;
            }
            iterations += 1;
            sleep(Duration::from_millis(10));
        }
    });

    sleep(Duration::from_millis(100));

    let handle = spawned.unwrap();
    let actual_name = handle.thread().name().unwrap();
    assert!(actual_name == expected_name);

    let iterations = handle.stop().join().unwrap();
    assert!(iterations > 1);
}