use shuttle::scheduler::PctScheduler;
use shuttle::sync::Mutex;
use shuttle::{check_dfs, check_random, thread, Runner};
use std::collections::HashSet;
use std::sync::{Arc, TryLockError};
use test_log::test;
/// The main thread and one spawned thread each lock the shared counter and
/// add 1; every schedule explored by `check_dfs` must finish without a panic
/// (including shuttle's deadlock panic).
#[test]
fn basic_lock_test() {
    let body = || {
        let counter = Arc::new(Mutex::new(0usize));
        let for_child = Arc::clone(&counter);
        thread::spawn(move || {
            *for_child.lock().unwrap() += 1;
        });
        *counter.lock().unwrap() += 1;
    };
    check_dfs(body, None);
}
fn deadlock() {
let lock1 = Arc::new(Mutex::new(0usize));
let lock2 = Arc::new(Mutex::new(0usize));
let lock1_clone = Arc::clone(&lock1);
let lock2_clone = Arc::clone(&lock2);
thread::spawn(move || {
let _l1 = lock1_clone.lock().unwrap();
let _l2 = lock2_clone.lock().unwrap();
});
let _l2 = lock2.lock().unwrap();
let _l1 = lock1.lock().unwrap();
}
/// DFS exploration of `deadlock` must find the lock-order inversion and
/// panic with a message containing "deadlock".
#[test]
#[should_panic(expected = "deadlock")]
fn deadlock_default() {
    let max_iterations = None;
    check_dfs(deadlock, max_iterations);
}
/// Random scheduling of `deadlock` over 200 iterations is expected to hit
/// the lock-order inversion and panic with a message containing "deadlock".
#[test]
#[should_panic(expected = "deadlock")]
fn deadlock_random() {
    const ITERATIONS: usize = 200;
    check_random(deadlock, ITERATIONS);
}
/// The PCT scheduler is expected to drive `deadlock` into its lock-order
/// inversion and panic with a message containing "deadlock".
#[test]
#[should_panic(expected = "deadlock")]
fn deadlock_pct() {
    Runner::new(PctScheduler::new(2, 100), Default::default()).run(deadlock);
}
/// Deliberately racy increment under PCT scheduling.
///
/// Each thread reads the counter under one lock acquisition and writes back
/// `snapshot + 1` under a second acquisition. An update can therefore be
/// lost, and the test expects the final `assert_eq!` to fail with
/// "racing increments".
#[test]
#[should_panic(expected = "racing increments")]
fn concurrent_increment_buggy() {
    let runner = Runner::new(PctScheduler::new(2, 100), Default::default());
    runner.run(|| {
        let counter = Arc::new(Mutex::new(0usize));
        let mut handles = Vec::new();
        for _ in 0..2 {
            let counter = Arc::clone(&counter);
            handles.push(thread::spawn(move || {
                // Read and write are separate critical sections: this is the bug.
                let snapshot = *counter.lock().unwrap();
                *counter.lock().unwrap() = snapshot + 1;
            }));
        }
        handles.into_iter().for_each(|h| h.join().unwrap());
        let total = *counter.lock().unwrap();
        assert_eq!(total, 2, "racing increments");
    });
}
/// Correct counterpart of `concurrent_increment_buggy`: each thread performs
/// its read-modify-write inside a single lock acquisition, so the final count
/// must be 2 in every PCT schedule.
#[test]
fn concurrent_increment() {
    let scheduler = PctScheduler::new(2, 100);
    Runner::new(scheduler, Default::default()).run(|| {
        let counter = Arc::new(Mutex::new(0usize));
        let workers: Vec<_> = (0..2)
            .map(|_| Arc::clone(&counter))
            .map(|c| {
                thread::spawn(move || {
                    *c.lock().unwrap() += 1;
                })
            })
            .collect();
        workers.into_iter().for_each(|w| w.join().unwrap());
        let total = *counter.lock().unwrap();
        assert_eq!(total, 2);
    });
}
/// Checks that releasing the lock lets another thread run in between.
///
/// The adder takes and releases the lock twice (+1 each time), and a second
/// thread doubles the value once. Starting from 0, DFS must reach all three
/// outcomes: 2 (doubled first), 3 (doubled between the two increments) and
/// 4 (doubled last).
#[test]
fn unlock_yields() {
    let seen = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let recorder = Arc::clone(&seen);
    check_dfs(
        move || {
            let lock = Arc::new(Mutex::new(0usize));
            let adder = thread::spawn({
                let lock = Arc::clone(&lock);
                move || {
                    for _ in 0..2 {
                        *lock.lock().unwrap() += 1;
                    }
                }
            });
            let doubler = thread::spawn({
                let lock = Arc::clone(&lock);
                move || {
                    *lock.lock().unwrap() *= 2;
                }
            });
            adder.join().unwrap();
            doubler.join().unwrap();
            let final_value = Arc::try_unwrap(lock).unwrap().into_inner().unwrap();
            recorder.lock().unwrap().insert(final_value);
        },
        None,
    );
    let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
    assert_eq!(seen, HashSet::from([2, 3, 4]));
}
/// Mixes a shuttle `Mutex` with a shuttle `RwLock`.
///
/// The add thread runs three steps against the value:
/// 1. It adds 1 under the mutex.
/// 2. While holding the `RwLock` write guard, it `try_lock`s the mutex and
///    adds 2 on success or 6 otherwise.
/// 3. It adds 3 under the mutex.
///
/// A second thread doubles the value under the mutex. DFS must observe
/// exactly the listed set of final values.
#[test]
fn mutex_rwlock_interaction() {
    let seen = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let recorder = Arc::clone(&seen);
    check_dfs(
        move || {
            let lock = Arc::new(Mutex::new(()));
            let rwlock = Arc::new(shuttle::sync::RwLock::new(()));
            let value = Arc::new(std::sync::Mutex::new(0usize));
            let add_thread = {
                let lock = Arc::clone(&lock);
                let value = Arc::clone(&value);
                thread::spawn(move || {
                    let bump = |amount: usize| {
                        *value.lock().unwrap() += amount;
                    };
                    {
                        let _guard = lock.lock().unwrap();
                        bump(1);
                    }
                    {
                        let _guard = rwlock.write().unwrap();
                        // The try_lock guard (if any) stays alive until the
                        // bump in its arm has run.
                        match lock.try_lock() {
                            Ok(_g) => bump(2),
                            Err(_) => bump(6),
                        }
                    }
                    {
                        let _guard = lock.lock().unwrap();
                        bump(3);
                    }
                })
            };
            let mul_thread = {
                let lock = Arc::clone(&lock);
                let value = Arc::clone(&value);
                thread::spawn(move || {
                    let _guard = lock.lock().unwrap();
                    *value.lock().unwrap() *= 2;
                })
            };
            add_thread.join().unwrap();
            mul_thread.join().unwrap();
            let final_value = Arc::try_unwrap(value).unwrap().into_inner().unwrap();
            recorder.lock().unwrap().insert(final_value);
        },
        None,
    );
    let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
    assert_eq!(seen, HashSet::from([6, 7, 9, 11, 12, 17]));
}
/// Two threads each make one `try_lock` attempt and increment on success.
///
/// DFS must observe a final value of 1 (one attempt hit `WouldBlock`) or
/// 2 (both attempts succeeded), and never 0.
#[test]
fn concurrent_try_increment() {
    let seen = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let recorder = Arc::clone(&seen);
    check_dfs(
        move || {
            let lock = Arc::new(Mutex::new(0usize));
            let mut handles = Vec::with_capacity(2);
            for _ in 0..2 {
                let lock = Arc::clone(&lock);
                handles.push(thread::spawn(move || {
                    let attempt = lock.try_lock();
                    match attempt {
                        Ok(mut guard) => *guard += 1,
                        Err(TryLockError::WouldBlock) => (),
                        Err(TryLockError::Poisoned(_)) => panic!("unexpected TryLockError"),
                    }
                }));
            }
            handles.into_iter().for_each(|h| h.join().unwrap());
            let final_value = Arc::try_unwrap(lock).unwrap().into_inner().unwrap();
            recorder.lock().unwrap().insert(final_value);
        },
        None,
    );
    let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
    assert_eq!(seen, HashSet::from([1, 2]));
}
/// A blocking `lock` increment (+1) races with a thread that makes two
/// `try_lock` attempts, adding 1 and then 2 when each succeeds.
///
/// Either attempt may hit `WouldBlock`, so DFS must observe every total
/// from 1 to 4.
#[test]
fn concurrent_lock_try_lock() {
    let seen = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let recorder = Arc::clone(&seen);
    check_dfs(
        move || {
            let lock = Arc::new(Mutex::new(0usize));
            let locker = {
                let lock = Arc::clone(&lock);
                thread::spawn(move || {
                    *lock.lock().unwrap() += 1;
                })
            };
            let try_locker = {
                let lock = Arc::clone(&lock);
                thread::spawn(move || {
                    for amount in 1..=2 {
                        match lock.try_lock() {
                            Ok(mut guard) => *guard += amount,
                            Err(TryLockError::WouldBlock) => {}
                            Err(TryLockError::Poisoned(_)) => panic!("unexpected TryLockError"),
                        }
                    }
                })
            };
            for handle in [locker, try_locker] {
                handle.join().unwrap();
            }
            let final_value = Arc::try_unwrap(lock).unwrap().into_inner().unwrap();
            recorder.lock().unwrap().insert(final_value);
        },
        None,
    );
    let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
    assert_eq!(seen, HashSet::from([1, 2, 3, 4]));
}
/// Re-locking a `Mutex` from the thread that already holds it must panic
/// with shuttle's "already holds" diagnostic rather than hang.
#[test]
#[should_panic(expected = "tried to acquire a Mutex it already holds")]
fn double_lock() {
    let relock = || {
        let mutex = Mutex::new(());
        let _held = mutex.lock().unwrap();
        let _second = mutex.lock();
    };
    check_dfs(relock, None)
}
/// Unlike a second `lock`, a second `try_lock` by the thread already holding
/// the mutex must simply report `WouldBlock`.
#[test]
fn double_try_lock() {
    let scenario = || {
        let mutex = Mutex::new(());
        let _first = mutex.try_lock().unwrap();
        assert!(matches!(mutex.try_lock(), Err(TryLockError::WouldBlock)));
    };
    check_dfs(scenario, None)
}
/// Panics while a `MutexGuard` is still alive. The test expects the original
/// "expected panic" message to be what propagates out of `check_dfs`, i.e.
/// dropping the guard during unwinding must not replace it.
#[test]
#[should_panic(expected = "expected panic")]
fn panic_drop() {
    let scenario = || {
        let lock = Mutex::new(0);
        let _guard = lock.lock().unwrap();
        panic!("expected panic");
    };
    check_dfs(scenario, None)
}
/// Two threads each increment a shared counter.
///
/// Once both are joined, the `Arc` must have a single owner again, so the
/// `Mutex` can be unwrapped and consumed with `into_inner`. The result must
/// be 2 in every DFS schedule.
#[test]
fn mutex_into_inner() {
    check_dfs(
        || {
            let lock = Arc::new(Mutex::new(0u64));
            let threads = (0..2)
                .map(|_| {
                    // `Arc::clone` (as in the rest of this file) makes it explicit
                    // that only the refcount is bumped, not the Mutex copied.
                    let lock = Arc::clone(&lock);
                    thread::spawn(move || {
                        *lock.lock().unwrap() += 1;
                    })
                })
                .collect::<Vec<_>>();
            // Loop variable is `thd`, not `thread`, so it does not shadow the
            // imported `thread` module name.
            for thd in threads {
                thd.join().unwrap();
            }
            let lock = Arc::try_unwrap(lock).unwrap();
            assert_eq!(lock.into_inner().unwrap(), 2);
        },
        None,
    )
}