use crate::basic::clocks::{check_clock, me};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use shuttle::future::{self, batch_semaphore::*};
use shuttle::{check_dfs, check_random, current, thread};
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use test_log::test;
/// A single task drains the semaphore in batches, sees `try_acquire` report `NoPermits` once
/// nothing is left, and can acquire again after handing a permit back.
#[test]
fn batch_semaphore_basic() {
    check_dfs(
        || {
            let sem = BatchSemaphore::new(3, Fairness::StrictlyFair);
            future::spawn(async move {
                // Take all three permits: a batch of two, then the last one.
                sem.acquire(2).await.unwrap();
                sem.acquire(1).await.unwrap();

                // The semaphore is empty, so a non-blocking attempt must fail.
                assert_eq!(sem.try_acquire(1), Err(TryAcquireError::NoPermits));

                // Return one permit and take it again.
                sem.release(1);
                sem.acquire(1).await.unwrap();
            });
        },
        None,
    );
}
/// Under `Fairness::Unfair` the semaphore does not serve waiters in arrival order.
///
/// Three threads announce themselves and then wait for 2, 1 and 1 permits respectively. After
/// two permits are released, either thread 0 alone or threads 1 and 2 may be the ones that get
/// through, independent of the announcement order. Random exploration must observe every
/// combination of announcement order and first-round outcome.
#[test]
fn batch_semaphore_unfair() {
    let observed_values = Arc::new(Mutex::new(HashSet::new()));
    let recorder = Arc::clone(&observed_values);

    check_random(
        move || {
            let semaphore = Arc::new(BatchSemaphore::new(0, Fairness::Unfair));
            // Order in which the threads announced they were about to acquire.
            let announced = Arc::new(Mutex::new(Vec::new()));
            // Order in which the threads finished acquiring, with the permit count each took.
            let acquired = Arc::new(Mutex::new(Vec::new()));

            let mut handles = Vec::new();
            for tid in 0..3 {
                let semaphore = Arc::clone(&semaphore);
                let announced = Arc::clone(&announced);
                let acquired = Arc::clone(&acquired);
                handles.push(thread::spawn(move || {
                    announced.lock().unwrap().push(tid);
                    let wanted = [2, 1, 1][tid];
                    semaphore.acquire_blocking(wanted).unwrap();
                    acquired.lock().unwrap().push((tid, wanted));
                }));
            }

            // Wait until every thread has at least announced itself.
            while announced.lock().unwrap().len() < 3 {
                thread::yield_now();
            }
            let announced_snapshot = announced.lock().unwrap().clone();

            // First round: two permits, consumable by thread 0 or by threads 1 and 2.
            semaphore.release(2);
            loop {
                let granted: usize = acquired.lock().unwrap().iter().map(|&(_, n)| n).sum();
                if granted >= 2 {
                    break;
                }
                thread::yield_now();
            }
            let acquired_snapshot = acquired.lock().unwrap().clone();

            // Second round: enough for whoever is still waiting.
            semaphore.release(2);
            for handle in handles {
                handle.join().unwrap();
            }

            recorder
                .lock()
                .unwrap()
                .insert((announced_snapshot, acquired_snapshot));
        },
        1000, // iteration budget for the random scheduler
    );

    let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();

    // Any of the six announcement orders, each paired with any of the three first-round outcomes.
    let announcement_orders: [[usize; 3]; 6] =
        [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]];
    let first_round_outcomes: [Vec<(usize, usize)>; 3] =
        [vec![(0, 2)], vec![(1, 1), (2, 1)], vec![(2, 1), (1, 1)]];
    let expected: HashSet<(Vec<usize>, Vec<(usize, usize)>)> = announcement_orders
        .iter()
        .flat_map(|order| {
            first_round_outcomes
                .iter()
                .map(move |outcome| (order.to_vec(), outcome.clone()))
        })
        .collect();
    assert_eq!(observed_values, expected);
}
/// Acquiring a permit makes the acquirer causally dependent on the thread that released it.
#[test]
fn batch_semaphore_clock_1() {
    for fairness in [Fairness::StrictlyFair, Fairness::Unfair] {
        check_dfs(
            move || {
                let sem = Arc::new(BatchSemaphore::new(0, fairness));
                let releaser_sem = Arc::clone(&sem);

                // Thread 1 supplies the only permit.
                thread::spawn(move || {
                    assert_eq!(me(), 1);
                    releaser_sem.release(1);
                });

                // Thread 2 has no knowledge of thread 1 before acquiring, and must have some
                // afterwards.
                thread::spawn(move || {
                    assert_eq!(me(), 2);
                    check_clock(|i, c| i != 1 || c == 0);
                    sem.acquire_blocking(1).unwrap();
                    check_clock(|i, c| i != 1 || c > 0);
                });
            },
            None,
        );
    }
}
/// Acquiring two permits that were released by two different threads makes the acquirer
/// causally dependent on both releasers.
#[test]
fn batch_semaphore_clock_2() {
    for fairness in [Fairness::StrictlyFair, Fairness::Unfair] {
        check_dfs(
            move || {
                let sem = Arc::new(BatchSemaphore::new(0, fairness));

                // Threads 1 and 2 each contribute one permit.
                for releaser_id in 1..=2 {
                    let releaser_sem = Arc::clone(&sem);
                    thread::spawn(move || {
                        assert_eq!(me(), releaser_id);
                        releaser_sem.release(1);
                    });
                }

                // Thread 3 needs both permits.
                thread::spawn(move || {
                    assert_eq!(me(), 3);
                    // Before acquiring, only thread 0's clock component is non-zero.
                    check_clock(|i, c| (i == 0) == (c > 0));
                    sem.acquire_blocking(2).unwrap();
                    // Afterwards, every other thread's component is non-zero.
                    check_clock(|i, c| i == 3 || c > 0);
                });
            },
            None,
        );
    }
}
/// When two threads each release one permit and a third acquires just one, the acquirer ends up
/// causally dependent on exactly one of the two releasers.
#[test]
fn batch_semaphore_clock_3() {
    for fairness in [Fairness::StrictlyFair, Fairness::Unfair] {
        check_dfs(
            move || {
                let sem = Arc::new(BatchSemaphore::new(0, fairness));

                // Threads 1 and 2 each contribute one permit.
                for releaser_id in 1..=2 {
                    let releaser_sem = Arc::clone(&sem);
                    thread::spawn(move || {
                        assert_eq!(me(), releaser_id);
                        releaser_sem.release(1);
                    });
                }

                // Thread 3 takes a single permit.
                thread::spawn(move || {
                    assert_eq!(me(), 3);
                    check_clock(|i, c| (i == 0) == (c > 0));
                    sem.acquire_blocking(1).unwrap();
                    // Exactly one of threads 1 and 2 must now be in our causal past.
                    let clock = current::clock();
                    assert!((clock[1] > 0 && clock[2] == 0) || (clock[1] == 0 && clock[2] > 0));
                });
            },
            None,
        );
    }
}
/// Two threads race for a single permit with `try_acquire`.
///
/// The winner must not depend causally on the loser. The loser must depend on the winner,
/// because a failed `try_acquire` synchronizes with the last successful acquire.
#[test]
fn batch_semaphore_clock_4() {
    for fairness in [Fairness::StrictlyFair, Fairness::Unfair] {
        check_dfs(
            move || {
                let s = Arc::new(BatchSemaphore::new(1, fairness));
                for tid in 1..=2 {
                    // The contenders are threads 1 and 2, so the other contender is `3 - tid`.
                    // (`2 - tid` would map thread 1 to itself and thread 2 to thread 0, which
                    // makes the "lost the race" check below vacuous.)
                    let other_tid = 3 - tid;
                    let s2 = s.clone();
                    thread::spawn(move || {
                        assert_eq!(me(), tid);
                        match s2.try_acquire(1) {
                            Ok(()) => {
                                // We won: only thread 0 and we ourselves are in our causal past.
                                check_clock(|i, c| (c > 0) == (i == 0 || i == tid));
                            }
                            Err(TryAcquireError::NoPermits) => {
                                // We lost: the other contender must be in our causal past.
                                check_clock(|i, c| !(i == 0 || i == other_tid) || (c > 0));
                            }
                            Err(TryAcquireError::Closed) => unreachable!(),
                        }
                    });
                }
            },
            None,
        );
    }
}
/// Pins down a known imprecision in clock tracking.
///
/// Each thread returns every permit right after taking it. Even so, the check that a thread
/// depends only on thread 0 and itself is expected to fail for some explored schedule, which is
/// why this test is marked `should_panic`.
#[test]
#[should_panic(expected = "doesn't satisfy predicate")]
fn batch_semaphore_clock_imprecise() {
    check_dfs(
        move || {
            let sem = Arc::new(BatchSemaphore::new(2, Fairness::StrictlyFair));
            for tid in 1..=2 {
                let worker_sem = Arc::clone(&sem);
                thread::spawn(move || {
                    assert_eq!(me(), tid);
                    (0..2).for_each(|_| {
                        worker_sem.try_acquire(1).unwrap();
                        worker_sem.release(1);
                    });
                    check_clock(|i, c| (i == 0 || i == tid) == (c > 0));
                });
            }
        },
        None,
    );
}
async fn semtest(num_permits: usize, counts: Vec<usize>, states: &Arc<Mutex<HashSet<(usize, usize)>>>, mode: Fairness) {
let s = Arc::new(BatchSemaphore::new(num_permits, mode));
let r = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for (i, &c) in counts.iter().enumerate() {
let s = s.clone();
let r = r.clone();
let states = states.clone();
let val = 1usize << i;
handles.push(future::spawn(async move {
s.acquire(c).await.unwrap();
let v = r.fetch_add(val, Ordering::SeqCst);
future::yield_now().await;
let _ = r.fetch_sub(val, Ordering::SeqCst);
states.lock().unwrap().insert((i, v));
s.release(c);
}));
}
for h in handles {
h.await.unwrap();
}
}
/// With 5 permits and three tasks that each need 3, no two tasks can hold permits at once, so
/// every task observes an empty holder mask.
#[test]
fn batch_semaphore_test_1() {
    let states = Arc::new(Mutex::new(HashSet::new()));
    let states_for_runs = Arc::clone(&states);
    check_dfs(
        move || {
            let run_states = Arc::clone(&states_for_runs);
            future::block_on(async move {
                semtest(5, vec![3, 3, 3], &run_states, Fairness::StrictlyFair).await;
            });
        },
        None,
    );
    let states = Arc::try_unwrap(states).unwrap().into_inner().unwrap();
    let expected = HashSet::from([(0, 0), (1, 0), (2, 0)]);
    assert_eq!(states, expected);
}
/// With 5 permits and requests of 3, 3 and 2, tasks 0 and 1 never hold permits at the same time,
/// but either of them may overlap with task 2.
#[test]
fn batch_semaphore_test_2() {
    let states = Arc::new(Mutex::new(HashSet::new()));
    let states_for_runs = Arc::clone(&states);
    check_dfs(
        move || {
            let run_states = Arc::clone(&states_for_runs);
            future::block_on(async move {
                semtest(5, vec![3, 3, 2], &run_states, Fairness::StrictlyFair).await;
            });
        },
        None,
    );
    let states = Arc::try_unwrap(states).unwrap().into_inner().unwrap();
    // Mask bits: 1 = task 0, 2 = task 1, 4 = task 2.
    let expected = HashSet::from([(0, 0), (1, 0), (2, 0), (0, 4), (1, 4), (2, 1), (2, 2)]);
    assert_eq!(states, expected);
}
/// A semaphore release works as a signal: when the waiting task's acquire succeeds, it must see
/// the store the releasing side made just before the matching `release`.
#[test]
fn batch_semaphore_signal() {
    check_dfs(
        move || {
            let sem = Arc::new(BatchSemaphore::new(0, Fairness::StrictlyFair));
            let waiter_sem = Arc::clone(&sem);
            let progress = Arc::new(AtomicUsize::new(0));
            let waiter_progress = Arc::clone(&progress);

            future::spawn(async move {
                waiter_sem.acquire(1).await.unwrap();
                let v = waiter_progress.load(Ordering::SeqCst);
                assert!(v > 0);
                waiter_sem.acquire(1).await.unwrap();
                let v = waiter_progress.load(Ordering::SeqCst);
                assert!(v > 1);
            });

            // Publish each step before signalling it.
            progress.store(1, Ordering::SeqCst);
            sem.release(1);
            progress.store(2, Ordering::SeqCst);
            sem.release(1);
        },
        None,
    );
}
/// Closing a semaphore makes an outstanding `acquire` fail, and any later `try_acquire` then
/// reports `Closed`.
#[test]
fn batch_semaphore_close_acquire() {
    check_dfs(
        || {
            future::block_on(async {
                // `tx` holds the child's single permit; `rx` lets the child report it took it.
                let tx = Arc::new(BatchSemaphore::new(1, Fairness::StrictlyFair));
                let rx = Arc::new(BatchSemaphore::new(0, Fairness::StrictlyFair));

                let child = future::spawn({
                    let tx2 = Arc::clone(&tx);
                    let rx2 = Arc::clone(&rx);
                    async move {
                        tx2.acquire(1).await.unwrap();
                        rx2.release(1);
                        // `tx` is now empty and nobody releases into it, so this can only end
                        // with an error once `tx` is closed.
                        let s = tx2.acquire(1).await;
                        assert!(s.is_err());
                        assert!(matches!(tx2.try_acquire(1), Err(TryAcquireError::Closed)));
                    }
                });

                rx.acquire(1).await.unwrap();
                tx.close();
                child.await.unwrap();
            });
        },
        None,
    );
}
/// A waiter that needs more permits than will ever be released is failed when the semaphore is
/// closed. Here the close comes from the `Drop` of a sender-like handle, after that handle has
/// released one (insufficient) permit.
#[test]
fn batch_semaphore_drop_sender() {
    /// Closes the wrapped semaphore when dropped.
    struct Sender {
        sem: Arc<BatchSemaphore>,
    }

    impl Drop for Sender {
        fn drop(&mut self) {
            self.sem.close();
        }
    }

    check_dfs(
        || {
            future::block_on(async {
                let sem = Arc::new(BatchSemaphore::new(0, Fairness::StrictlyFair));
                let sender = Sender { sem: Arc::clone(&sem) };

                // Wants two permits; only one is ever released.
                future::spawn(async move {
                    let r = sem.acquire(2).await;
                    assert!(r.is_err());
                });

                // Releases one permit. `sender` is owned by this task and dropped with it,
                // which closes the semaphore.
                future::spawn(async move {
                    sender.sem.release(1);
                });
            });
        },
        None,
    );
}
/// A spawned task races two `lock` futures and an empty future inside a `FuturesUnordered`. It
/// waits only for the first one to finish, then drops the collection along with any still-pending
/// `lock` futures. Afterwards the main task must still be able to lock the semaphore again. The
/// test name indicates it guards against incorrect cleanup of such dropped acquires, which would
/// leave that final `lock` deadlocked.
#[test]
fn bugged_cleanup_would_cause_deadlock() {
    /// Returns one permit to `sem` when dropped.
    struct Guard {
        sem: Arc<BatchSemaphore>,
    }

    impl Drop for Guard {
        fn drop(&mut self) {
            self.sem.release(1);
        }
    }

    /// Waits for one permit (ignoring any acquire error) and wraps it in a `Guard`.
    async fn lock(sem: &Arc<BatchSemaphore>) -> Guard {
        let _ = sem.acquire(1).await;
        Guard { sem: Arc::clone(sem) }
    }

    check_dfs(
        || {
            let sem = Arc::new(BatchSemaphore::new(1, Fairness::StrictlyFair));
            let task_sem = Arc::clone(&sem);
            future::block_on(async move {
                let racer = future::spawn(async move {
                    let mut racing = FuturesUnordered::new();
                    racing.push(
                        async {
                            lock(&task_sem).await;
                        }
                        .boxed(),
                    );
                    racing.push(
                        async {
                            lock(&task_sem).await;
                        }
                        .boxed(),
                    );
                    racing.push(async {}.boxed());
                    // Wait for the first completion only; the rest are dropped with `racing`.
                    racing.next().await.unwrap();
                });
                let guard = lock(&sem).await;
                racer.await.unwrap();
                drop(guard);
                lock(&sem).await;
            });
        },
        None,
    )
}
/// Tests that an `Acquire` future dropped after it has been polled (and possibly after it has
/// been granted permits) does not leave the semaphore short of permits.
mod early_acquire_drop_tests {
    use super::*;
    use futures::{
        future::join_all,
        task::{Context, Poll, Waker},
        Future,
    };
    use pin_project::pin_project;
    use proptest::prelude::*;
    use proptest_derive::Arbitrary;
    use shuttle::{
        check_random,
        sync::mpsc::{channel, Sender},
    };
    use std::pin::Pin;

    /// What a [`Task`] does once it is polled for the second time.
    #[derive(Arbitrary, Clone, Copy, Debug)]
    enum Behavior {
        /// Complete immediately, dropping the still-unfinished `Acquire` future.
        EarlyDrop,
        /// Finish acquiring; the spawning task then releases the permits again.
        Release,
        /// Finish acquiring and never release the permits.
        Hold,
    }

    /// A future wrapping an `Acquire` whose re-polling is controlled by the test driver.
    ///
    /// On its first poll it polls the inner acquire (which must still be pending), sends its
    /// waker over `tx`, and returns `Pending`. On later polls it either completes immediately
    /// (`EarlyDrop`) or keeps driving the acquire to completion.
    #[pin_project]
    struct Task {
        /// 0 until the first poll has sent its waker, 1 afterwards.
        poll_count: usize,
        behavior: Behavior,
        requested_permits: usize,
        /// Hands the waker from the first poll back to the test driver.
        tx: Sender<Waker>,
        #[pin]
        acquire: Acquire<'static>,
    }

    impl Task {
        fn new(behavior: Behavior, requested_permits: usize, tx: Sender<Waker>, sem: &'static BatchSemaphore) -> Self {
            Self {
                poll_count: 0,
                behavior,
                requested_permits,
                tx,
                acquire: sem.acquire(requested_permits),
            }
        }
    }

    impl Future for Task {
        /// Number of permits the task completed with (0 for `EarlyDrop`).
        type Output = usize;

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut this = self.project();
            if *this.poll_count == 0 {
                // First poll: the driver releases permits only after collecting every task's
                // waker, so the acquire cannot have completed yet.
                let s: Poll<Result<(), AcquireError>> = this.acquire.as_mut().poll(cx);
                assert!(s.is_pending());
                // Let the driver wake us once it has released permits.
                this.tx.send(cx.waker().clone()).unwrap();
                *this.poll_count += 1;
                Poll::Pending
            } else if matches!(*this.behavior, Behavior::EarlyDrop) {
                // Give up without finishing the acquire; it is dropped together with this task.
                Poll::Ready(0)
            } else {
                this.acquire.as_mut().poll(cx).map(|_| *this.requested_permits)
            }
        }
    }

    /// Runs one execution of the scenario.
    ///
    /// Steps:
    /// 1. Spawn a [`Task`] for every `(behavior, requested_permits)` entry of `task_config`.
    /// 2. Wait until each task has polled its acquire once.
    /// 3. Release permits on `sem` and wake every task.
    /// 4. Wait for all tasks to finish.
    ///
    /// The number of permits released is the total requested by `Hold` tasks (never returned)
    /// plus the largest request among the remaining tasks (those permits are returned by
    /// `Release` tasks and must not be lost by `EarlyDrop` tasks). If a dropped acquire kept its
    /// permits, the remaining waiters could never be satisfied and `join_all` would not finish.
    fn dropped_acquire_must_release(sem: &'static BatchSemaphore, task_config: Vec<(Behavior, usize)>) {
        future::block_on(async move {
            let mut wakers = vec![];
            let mut handles = vec![];
            let mut total_held = 0usize;
            let mut max_requested = 0usize;
            for (behavior, requested_permits) in task_config {
                let (tx, rx) = channel();
                // Exhaustive on purpose: a new `Behavior` must decide how it affects accounting.
                match behavior {
                    // Never returned, so every held permit must be supplied.
                    Behavior::Hold => total_held += requested_permits,
                    // Returned (or never kept), so only the largest request matters.
                    Behavior::EarlyDrop | Behavior::Release => {
                        max_requested = std::cmp::max(max_requested, requested_permits)
                    }
                }
                handles.push(future::spawn(async move {
                    let task: Task = Task::new(behavior, requested_permits, tx, sem);
                    let p = task.await;
                    if matches!(behavior, Behavior::Release) {
                        sem.release(p);
                    }
                }));
                // Block until the task's first poll (which got `Pending`) has sent its waker, so
                // every task is waiting before any permit exists.
                wakers.push(rx.recv().unwrap());
            }
            sem.release(total_held + max_requested);
            for w in wakers.into_iter() {
                w.wake();
            }
            join_all(handles).await;
        });
    }

    /// Instantiates the early-drop tests in module `$mod_name` for the given fairness mode.
    macro_rules! sem_tests {
        ($mod_name:ident, $fairness:expr) => {
            mod $mod_name {
                use super::*;

                // One early-dropped acquire racing one releasing acquire, explored exhaustively.
                #[test_log::test]
                fn dropped_acquire_must_release_exhaustive() {
                    shuttle::lazy_static! {
                        static ref SEM: BatchSemaphore = BatchSemaphore::new(0, $fairness);
                    }
                    check_dfs(
                        || dropped_acquire_must_release(&SEM, vec![(Behavior::EarlyDrop, 1), (Behavior::Release, 1)]),
                        None,
                    );
                }

                // A configuration that deadlocks if the dropped 2-permit acquire leaks permits.
                #[test_log::test]
                fn dropped_acquire_must_release_deadlock() {
                    shuttle::lazy_static! {
                        static ref SEM: BatchSemaphore = BatchSemaphore::new(0, $fairness);
                    }
                    check_dfs(
                        || dropped_acquire_must_release(&SEM, vec![(Behavior::Hold, 1), (Behavior::EarlyDrop, 2), (Behavior::Release, 1)]),
                        None,
                    );
                }

                const MAX_REQUESTED_PERMITS: usize = 3;
                const MAX_TASKS: usize = 7;

                // Random task configurations, each checked with random scheduling.
                proptest! {
                    #[test_log::test]
                    fn dropped_acquire_must_release_random(behavior in proptest::collection::vec((proptest::arbitrary::any::<Behavior>(), 1..=MAX_REQUESTED_PERMITS), 1..=MAX_TASKS)) {
                        check_random(
                            move || {
                                shuttle::lazy_static! {
                                    static ref SEM: BatchSemaphore = BatchSemaphore::new(0, $fairness);
                                }
                                dropped_acquire_must_release(&SEM, behavior.clone())
                            },
                            1000,
                        );
                    }
                }
            }
        };
    }

    sem_tests!(unfair, Fairness::Unfair);
    sem_tests!(fair, Fairness::StrictlyFair);
}