1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318

//! A barrier for blocking a main thread until the completion of work which has been offloaded to worker threads,
//! without blocking the worker threads.
//!
//! This barrier allows blocking on `wait()` until `n` `Checkpoints` have been cleared using `check_in()` or `drop()`.
//! Threads which call check_in() do not block, in contrast to `std::sync::Barrier`
//! which blocks all threads and potentially deadlocks when used with an over-utilised threadpool.
//!
//! To use and reuse the `Barrier` an `ActiveBarrier` must be generated using `activate()`, which can then be used to generate checkpoints using `checkpoint()`.
//! An ActiveBarrier cannot be dropped without blocking until all checkpoints are cleared.
//! Generating more than `n` `Checkpoints` results in a panic. Generating fewer than `n` `Checkpoints` will result in an error being returned from `wait()`.
//! If a Checkpoint is cleared (by `check_in()` or by being dropped) on a panicking thread, `wait()` will return an error.
//!
//! # Example
//! ```
//! use pool_barrier::{Barrier, ActiveBarrier};
//!
//! const THREADS: usize = 5;
//!
//! let mut barrier = Barrier::new(THREADS);
//! run(barrier.activate());
//! run(barrier.activate());                            // a barrier can be reused once checkpoints are cleared
//!
//! fn run(mut barrier: ActiveBarrier){
//! 	for i in 0..THREADS{
//! 		let mut checkpoint = barrier.checkpoint();
//! 		std::thread::spawn(move||{
//! 			println!("thread_id: {}", i);           // all of these occur in arbitrary order
//! 			checkpoint.check_in();                  // this does not block the spawned thread
//! 		});
//! 	}
//! 	barrier.wait().unwrap();                        // main thread blocks here until all checkpoints are cleared
//! 	println!("main thread");                        // this occurs last 
//! }
//!
//! ```

use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
use std::sync::{Condvar, Mutex};
use std::ptr;

/// A reusable synchronisation barrier that can live on the stack: worker threads reach it through
/// `Checkpoint`s rather than an `Arc`. See the crate docs for usage.
pub struct Barrier{
	/// Number of checkpoints every activation must generate and clear.
	n: usize,
	/// Number of checkpoints handed out during the current activation.
	checkpoints_created: usize,
	/// Checkpoints still to be cleared in the current activation (missing ones are cleared implicitly by `wait()`).
	checkpoints_remaining: AtomicUsize,
	/// Set when any checkpoint of the current activation was cleared by a panicking thread.
	checkpoint_panicked: AtomicBool,
	/// Becomes `true` once `checkpoints_remaining` reaches zero; guarded by a mutex so `cvar` can wait on it.
	finished: Mutex<bool>,
	/// Notified when `finished` becomes `true`.
	cvar: Condvar,
}

impl Barrier{
	/// Create a new barrier.
	///
	/// - `n` : the exact number of checkpoints to be generated, all of which must be cleared before `wait()` unblocks.
	///   `n` may be zero, in which case every activation is finished as soon as it starts.
	pub fn new(n: usize) -> Barrier{
		Barrier{
			n: n,
			cvar: Condvar::new(),
			finished: Mutex::new(false),
			checkpoints_created: 0,
			checkpoints_remaining: AtomicUsize::new(n),
			checkpoint_panicked: AtomicBool::new(false),
		}
	}
	
	/// Change the number of checkpoints that have to be cleared on the next barrier activation.
	pub fn set_n(&mut self, n: usize){
		self.n = n;
	}

	/// Activate the barrier producing an ActiveBarrier. The returned ActiveBarrier can then produce checkpoints which may be passed to worker threads, and will block on wait() or drop() until checkpoints are cleared.
	pub fn activate<'a>(&'a mut self) -> ActiveBarrier<'a>{
		self.reset();
		ActiveBarrier{barrier: self}
	}

	/// The number of `Checkpoint`s that must be generated and cleared each time the barrier is activated.
	pub fn n(&self) -> usize{
		self.n
	}

	/// Clear all per-activation state so the barrier can be activated again.
	fn reset(&mut self){
		// With `n == 0` no check-in can ever bring the counter to zero and set `finished`,
		// so the activation must start out finished. Otherwise `wait()` (and therefore
		// dropping the ActiveBarrier) would block forever.
		*self.finished.lock().unwrap() = self.n == 0;
		self.checkpoints_created = 0;
		self.checkpoints_remaining.store(self.n, Ordering::Release);
		self.checkpoint_panicked.store(false, Ordering::Release);
	}

	/// Clear `x` checkpoints at once. The caller whose subtraction brings the remaining count
	/// to zero marks the barrier finished and wakes every thread blocked in `wait()`.
	fn check_in_x(&self, x: usize){
		
		let result = self.checkpoints_remaining.fetch_sub(x, Ordering::AcqRel);
		debug_assert!(result >= x); // assert that fetch_sub didn't just underflow
		debug_assert!(result <= self.n); // assert that underflow hasn't already occurred
		if result == x {
			let mut finished = self.finished.lock().unwrap();
			*finished = true;
			self.cvar.notify_all();
			// Cannot use &self after this point as mutex guard drops and barrier might be dropped.
		}
	}
}

/// An ActiveBarrier can be used to generate checkpoints which must be cleared (usually by worker threads) before `wait()` and `drop()` unblock.
///
/// It mutably borrows its `Barrier`, so the barrier cannot be reconfigured (`set_n`) or re-activated while this value is alive.
///
/// NOTE(review): `Checkpoint`s hold a raw pointer to the `Barrier` and carry no lifetime. They stay valid only
/// because dropping this value blocks in `wait()` until every checkpoint is cleared. If an ActiveBarrier is leaked
/// (e.g. via `std::mem::forget`), its destructor never runs and the `Barrier` could be dropped while checkpoints
/// are still live. Confirm whether the API should be made leak-safe (e.g. a closure-based scope).
pub struct ActiveBarrier<'a>{
	barrier: &'a mut Barrier,
}

impl<'a> ActiveBarrier<'a>{

	/// Generate a new `Checkpoint` to be cleared.
	///
	/// # Panics
	/// This function panics if called more than `n` times. It also panics if called after `wait()` has
	/// already released the barrier: at that point the checkpoints that were missing have been cleared
	/// implicitly, so a new checkpoint would not be waited for.
	pub fn checkpoint(&mut self) -> Checkpoint{
		if self.barrier.checkpoints_created >= self.barrier.n{
			panic!("More than n checkpoints generated.");
		}
		// While fewer than `n` checkpoints exist, the counter can only have reached zero
		// because `wait()` implicitly cleared the missing ones. Handing out another checkpoint
		// would underflow the counter and let `wait()`/`drop()` return while a worker still
		// holds a pointer to the barrier.
		if self.finished(){
			panic!("Checkpoint generated after wait() released the barrier.");
		}
		self.barrier.checkpoints_created += 1;
		Checkpoint{barrier: self.barrier as *const Barrier}
	}
	
	/// Returns true if all checkpoints have been cleared and any calls to `wait()` or `drop` will not block.
	pub fn finished(&self) -> bool {
		*self.barrier.finished.lock().unwrap()
	}

	/// Block the calling thread until all checkpoints are cleared.
	///
	/// If fewer than `n` checkpoints were generated, the missing ones are cleared implicitly on the
	/// first call so that this cannot deadlock. After that, no further checkpoints can be generated.
	///
	/// # Errors
	/// - `WaitError::CheckpointPanic` if a checkpoint was cleared by a panicking thread (takes precedence).
	/// - `WaitError::InsufficientCheckpoints` if fewer than `n` `Checkpoint`s were generated.
	pub fn wait(&self) -> WaitResult{
		
		// Guard against deadlock if not enough checkpoints were created by falsely checking in the missing ones.
		// This only happens on the first call to wait(): afterwards checkpoints_remaining is zero, and
		// checkpoint() refuses to hand out new checkpoints once the barrier is finished.
		let missing = self.barrier.n - self.barrier.checkpoints_created;
		if self.barrier.checkpoints_remaining.load(Ordering::Acquire) != 0 && missing != 0{
			self.barrier.check_in_x(missing);
		}

		// Wait until all checkpoints have been cleared.
		let mut finished = self.barrier.finished.lock().unwrap();
		while !*finished {
			finished = self.barrier.cvar.wait(finished).unwrap();
		}
		debug_assert_eq!(0, self.barrier.checkpoints_remaining.load(Ordering::Acquire));

		if self.barrier.checkpoint_panicked.load(Ordering::Acquire) {
			Err(WaitError::CheckpointPanic)
		} else if missing != 0 {
			Err(WaitError::InsufficientCheckpoints)
		} else {
			Ok(())
		}
	}

	/// The number of `Checkpoint`s that must be generated and cleared each time the barrier is activated.
	pub fn n(&self) -> usize{
		self.barrier.n
	}
}

impl<'a> Drop for ActiveBarrier<'a>{
	/// Blocks until every outstanding checkpoint has been cleared, so no worker is left holding a
	/// pointer to a `Barrier` that may be about to go away. Any error from `wait()` is ignored.
	fn drop(&mut self){
		let _ = self.wait();
	}
}

/// Reasons `ActiveBarrier::wait()` can report failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
	/// A `Checkpoint` was cleared (checked in or dropped) by a thread that was panicking.
	CheckpointPanic,
	/// Fewer than `n` `Checkpoint`s were generated during the activation.
	InsufficientCheckpoints,
}

impl std::fmt::Display for WaitError {
	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
		let msg = match *self {
			WaitError::CheckpointPanic => "a checkpoint was cleared by a panicking thread",
			WaitError::InsufficientCheckpoints => "fewer than n checkpoints were generated",
		};
		f.write_str(msg)
	}
}

// Lets callers propagate `WaitResult` with `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for WaitError {}

/// Outcome of `ActiveBarrier::wait()`: `Ok(())` when every checkpoint was generated and cleared without a panic.
pub type WaitResult = ::std::result::Result<(), WaitError>;

/// A checkpoint which must be cleared, by calling `check_in()`.
/// All checkpoints must be cleared before `wait()` on the parent ActiveBarrier unblocks.
/// Can be sent to other threads. Automatically calls `check_in()` when dropped.
///
/// If the thread clearing the checkpoint is panicking (including when the checkpoint is dropped during
/// unwinding), `wait()` reports `WaitError::CheckpointPanic`.
pub struct Checkpoint{
	// Raw pointer to the parent `Barrier`; set to null once the checkpoint has been cleared.
	// NOTE(review): it stays valid only because the owning `ActiveBarrier` blocks on drop until every checkpoint is cleared.
	barrier: *const Barrier,
}

// SAFETY: the pointed-to `Barrier` is only accessed through `&Barrier` in `check_in()`, which touches the
// atomics, the `finished` mutex, the condvar and the `n` field. `n` cannot change while the owning
// `ActiveBarrier` mutably borrows the barrier. The pointer stays valid because that `ActiveBarrier` blocks on
// drop until every checkpoint is cleared. NOTE(review): this relies on the ActiveBarrier not being leaked.
unsafe impl Send for Checkpoint{}

impl Checkpoint{

	/// Clears the checkpoint. Calling multiple times does nothing.
	pub fn check_in(&mut self){
		if !self.barrier.is_null() {
			let barrier = unsafe{&*self.barrier};
			if std::thread::panicking() {
				barrier.checkpoint_panicked.store(true, Ordering::Release);
			}
			barrier.check_in_x(1);
			self.barrier = ptr::null();
		}
	}
}

impl Drop for Checkpoint{
	fn drop(&mut self){
		self.check_in();
	}
}



/// Run tests with `cargo test -- --nocapture` to see that main thread unblocks after worker threads finish
#[cfg(test)]
mod tests{
	use super::*;
	use std::time::Duration;
	const THREADS: usize = 5;

	/// Sleep for a pseudo-random 10-90 ms (in 10 ms steps) so threads finish in arbitrary order.
	///
	/// Uses the per-instance random keys of `RandomState` instead of the `rand` crate, whose
	/// two-argument `gen_range(low, high)` was replaced by `gen_range(low..high)` in rand 0.8.
	fn random_delay(){
		use std::collections::hash_map::RandomState;
		use std::hash::{BuildHasher, Hasher};
		let bits = RandomState::new().build_hasher().finish();
		let steps = 1 + bits % 9; // 1..=9, matching the previous gen_range(1, 10)
		std::thread::sleep(Duration::from_millis(steps * 10));
	}

	fn threaded_run(barrier: &mut ActiveBarrier, n_threads: usize) -> WaitResult{
		for i in 0..n_threads{
			let mut checkpoint = barrier.checkpoint();
			std::thread::spawn(move||{
				random_delay();
				println!("thread_id: {}", i);         // all of these occur in arbitrary order
				checkpoint.check_in();                // this does not block the spawned thread
			});
		}
		random_delay();
		let result = barrier.wait();                  // main thread blocks here until checkpoints are cleared
		println!("main thread");                      // this occurs last
		result
	}

	fn panic_run(barrier: &mut ActiveBarrier){
		for i in 0..THREADS{
			let mut checkpoint = barrier.checkpoint();
			std::thread::spawn(move||{
				random_delay();
				if i%2 == 1 {panic!("Deliberate panic")};
				println!("thread_id: {}", i);
				checkpoint.check_in();
			});
		}
		random_delay();
		let result = barrier.wait();
		assert_eq!(result, Err(WaitError::CheckpointPanic)); // detect panic on worker thread with error
		println!("main thread");
	}

	#[test]
	fn same_thread() {
		
		fn run(mut barrier: ActiveBarrier){
			for i in 0..THREADS{
				let mut checkpoint = barrier.checkpoint();
				println!("thread_id: {}", i);
				checkpoint.check_in();
			}
			barrier.wait().unwrap();
			println!("main thread");
		}

		let mut barrier = Barrier::new(THREADS);
		run(barrier.activate());
	}

	#[test]
	fn single_use() {
		let mut barrier = Barrier::new(THREADS);
		threaded_run(&mut barrier.activate(), THREADS).unwrap();
	}

	#[test]
	fn reuse() {
		let mut barrier = Barrier::new(THREADS);
		for _ in 0..5 {
			threaded_run(&mut barrier.activate(), THREADS).unwrap();
		}
	}

	#[test]
	fn set_n_changes_next_activation() {
		let mut barrier = Barrier::new(THREADS);
		barrier.set_n(THREADS + 2);
		assert_eq!(THREADS + 2, barrier.n());
		let mut active_barrier = barrier.activate();
		assert_eq!(THREADS + 2, active_barrier.n());
		threaded_run(&mut active_barrier, THREADS + 2).unwrap();
	}

	#[test]
	fn test_checkpoint_panic_detection() {
		let mut barrier = Barrier::new(THREADS);
		panic_run(&mut barrier.activate());
	}

	#[test]
	fn not_enough_checkpoints() {
		let mut barrier = Barrier::new(THREADS);
		assert_eq!(threaded_run(&mut barrier.activate(), THREADS-1), Err(WaitError::InsufficientCheckpoints));
	}

	#[test]
	#[should_panic]
	fn too_many_checkpoints() {
		let mut barrier = Barrier::new(THREADS);
		threaded_run(&mut barrier.activate(), THREADS+1).unwrap();
	}

	#[test]
	fn test_finished_true() {
		let mut barrier = Barrier::new(THREADS);
		let mut active_barrier = barrier.activate();
		threaded_run(&mut active_barrier, THREADS).unwrap();
		assert_eq!(true, active_barrier.finished());
	}

	#[test]
	fn test_finished_false() {
		let mut barrier = Barrier::new(THREADS);
		let mut active_barrier = barrier.activate();
		assert_eq!(false, active_barrier.finished());
		threaded_run(&mut active_barrier, THREADS).unwrap();
	}
}