use arr_macro::arr;
use std::{
collections::BTreeMap,
fmt,
sync::{
atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering},
Arc,
Mutex,
},
thread,
thread::Thread,
};
/// Creates a parker for at most [`SmallThreadData::MAX_THREADS`] (32) threads.
///
/// Announced sleepers are tracked in a single `u32` bit set.
pub fn small() -> StaticParker<SmallThreadData> {
    let data = SmallThreadData::new();
    StaticParker::new(data)
}
/// Creates a parker for at most [`LargeThreadData::MAX_THREADS`] (64) threads.
///
/// Announced sleepers are tracked in a single `u64` bit set.
pub fn large() -> StaticParker<LargeThreadData> {
    let data = LargeThreadData::new();
    StaticParker::new(data)
}
/// Creates a parker with no fixed upper bound on the number of threads.
///
/// Parked threads are kept in a map behind a mutex, with an atomic counter of announced sleepers.
pub fn dynamic() -> StaticParker<DynamicThreadData> {
    let data = DynamicThreadData::new();
    StaticParker::new(data)
}
/// Outcome of a call to [`Parker::park`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare results directly.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum ParkResult {
    /// The parker's internal lock could not be acquired without blocking; the thread did not
    /// sleep and the caller should try again.
    Retry,
    /// A wake-up was registered between `prepare_park` and `park`, so the thread did not sleep.
    Abort,
    /// The thread parked and has since returned from parking.
    Woken,
}
/// Coordinates parking (blocking) and waking of worker threads, each identified by a
/// `thread_id`.
///
/// The protocol followed by the parkers in this module is:
/// 1. [`Parker::prepare_park`] announces that `thread_id` intends to sleep;
/// 2. the caller then either withdraws with [`Parker::abort_park`] or calls [`Parker::park`];
/// 3. other threads call [`Parker::unpark_one`] / [`Parker::unpark_all`] to wake sleepers
///    (or to stop announced threads from going to sleep).
pub trait Parker: Sync + Send + Clone {
    /// Maximum number of distinct `thread_id`s supported, or `None` if unbounded.
    fn max_threads(&self) -> Option<usize>;
    /// Resets whatever parking state is recorded for `thread_id`.
    fn init(&self, thread_id: usize);
    /// Announces that `thread_id` is about to park; expected to precede [`Parker::park`].
    fn prepare_park(&self, thread_id: usize);
    /// Withdraws a previous [`Parker::prepare_park`] without parking.
    fn abort_park(&self, thread_id: usize);
    /// Parks the calling thread as `thread_id`; see [`ParkResult`] for the possible outcomes.
    fn park(&self, thread_id: usize) -> ParkResult;
    /// Wakes (or prevents from sleeping) one announced thread, if there is one.
    fn unpark_one(&self);
    /// Wakes (or prevents from sleeping) all announced threads.
    fn unpark_all(&self);
}
/// A type-erased [`Parker`] that dispatches to its [`ThreadData`] through a trait object.
///
/// Produced by [`StaticParker::dynamic`]. Cloning only clones the `Arc`, so all clones share
/// the same underlying state.
#[derive(Clone, Debug)]
pub struct DynParker {
    inner: Arc<dyn ThreadData + Send + Sync>,
}
impl DynParker {
    /// Moves `data` into a freshly allocated, reference-counted trait object.
    fn new<T>(data: T) -> Self
    where
        T: ThreadData + Send + Sync + 'static,
    {
        let inner: Arc<dyn ThreadData + Send + Sync> = Arc::new(data);
        Self { inner }
    }
}
impl Parker for DynParker {
#[inline(always)]
fn max_threads(&self) -> Option<usize> {
self.inner.max_threads()
}
#[inline(always)]
fn init(&self, thread_id: usize) -> () {
self.inner.init(thread_id);
}
#[inline(always)]
fn prepare_park(&self, thread_id: usize) -> () {
self.inner.prepare_park(thread_id);
}
#[inline(always)]
fn abort_park(&self, thread_id: usize) -> () {
self.inner.abort_park(thread_id);
}
#[inline(always)]
fn park(&self, thread_id: usize) -> ParkResult {
self.inner.park(thread_id)
}
#[inline(always)]
fn unpark_one(&self) -> () {
if (!self.inner.all_awake()) {
self.inner.unpark_one();
}
}
#[inline(always)]
fn unpark_all(&self) -> () {
if (!self.inner.all_awake()) {
self.inner.unpark_all();
}
}
}
/// A [`Parker`] that is statically dispatched to a concrete [`ThreadData`] type `T`.
///
/// The thread data lives behind an `Arc`, so clones of a parker share one state.
#[derive(Debug)]
pub struct StaticParker<T: ThreadData + Send + Sync + 'static> {
    inner: Arc<T>,
}
impl<T: ThreadData + Send + Sync + 'static> StaticParker<T> {
    /// Wraps `data` in a new, uniquely owned parker.
    fn new(data: T) -> Self {
        Self {
            inner: Arc::new(data),
        }
    }

    /// Converts this parker into a type-erased [`DynParker`].
    ///
    /// # Errors
    /// Returns `self` unchanged in `Err` when other clones of this parker still share the
    /// underlying data, since it can then not be moved out of the `Arc`.
    pub fn dynamic(self) -> Result<DynParker, Self> {
        match Arc::try_unwrap(self.inner) {
            Ok(data) => Ok(DynParker::new(data)),
            Err(shared) => Err(Self { inner: shared }),
        }
    }
}
impl<T> Clone for StaticParker<T>
where
T: ThreadData + Sync + Send + 'static,
{
fn clone(&self) -> Self {
StaticParker {
inner: self.inner.clone(),
}
}
}
impl<T> Parker for StaticParker<T>
where
T: ThreadData + Send + Sync + 'static,
{
#[inline(always)]
fn max_threads(&self) -> Option<usize> {
self.inner.max_threads()
}
#[inline(always)]
fn init(&self, thread_id: usize) -> () {
self.inner.init(thread_id);
}
#[inline(always)]
fn prepare_park(&self, thread_id: usize) -> () {
self.inner.prepare_park(thread_id);
}
#[inline(always)]
fn abort_park(&self, thread_id: usize) -> () {
self.inner.abort_park(thread_id);
}
#[inline(always)]
fn park(&self, thread_id: usize) -> ParkResult {
self.inner.park(thread_id)
}
#[inline(always)]
fn unpark_one(&self) -> () {
if (!self.inner.all_awake()) {
self.inner.unpark_one();
}
}
#[inline(always)]
fn unpark_all(&self) -> () {
if (!self.inner.all_awake()) {
self.inner.unpark_all();
}
}
}
/// Shared state backing a [`Parker`].
///
/// Records which threads have announced they will sleep or are asleep, and wakes them. Method
/// semantics mirror the same-named methods of [`Parker`].
pub trait ThreadData: std::fmt::Debug {
    /// Maximum number of distinct `thread_id`s supported, or `None` if unbounded.
    fn max_threads(&self) -> Option<usize>;
    /// Resets whatever parking state is recorded for `thread_id`.
    fn init(&self, thread_id: usize);
    /// Announces that `thread_id` is about to park.
    fn prepare_park(&self, thread_id: usize);
    /// Withdraws a previous `prepare_park` for `thread_id`.
    fn abort_park(&self, thread_id: usize);
    /// Parks the calling thread as `thread_id`.
    fn park(&self, thread_id: usize) -> ParkResult;
    /// Returns `true` when no thread is currently announced as (about to be) sleeping.
    ///
    /// The parkers use this as a cheap check before calling `unpark_one` / `unpark_all`.
    fn all_awake(&self) -> bool;
    /// Wakes (or prevents from sleeping) one announced thread.
    fn unpark_one(&self);
    /// Wakes (or prevents from sleeping) all announced threads.
    fn unpark_all(&self);
}
/// Per-thread parking state kept under the thread-data mutex.
#[derive(Debug)]
enum ParkState {
    /// Not parked and no pending wake-up.
    Awake,
    /// Parked inside `park`; holds the handle used to unpark the thread.
    Asleep(Thread),
    /// A wake-up arrived before the thread parked; its next `park` returns `Abort`.
    NoSleep,
    /// `Thread::unpark` has been called; the thread still has to clean up after `park`.
    Waking,
}
/// Thread data for up to 32 threads.
#[derive(Debug)]
pub struct SmallThreadData {
    /// Bit `i` is set while thread `i` has announced (via `prepare_park`) that it will sleep.
    sleep_set: AtomicU32,
    /// Parking state of each thread slot, indexed by `thread_id`.
    sleeping: Mutex<[ParkState; 32]>,
}
impl SmallThreadData {
pub const MAX_THREADS: usize = 32;
fn new() -> SmallThreadData {
SmallThreadData {
sleep_set: AtomicU32::new(0),
sleeping: Mutex::new(arr![ParkState::Awake; 32]),
}
}
}
impl ThreadData for SmallThreadData {
    fn max_threads(&self) -> Option<usize> {
        Some(Self::MAX_THREADS)
    }

    /// Resets the slot of `thread_id` to `ParkState::Awake`.
    ///
    /// # Panics
    /// Panics if `thread_id >= MAX_THREADS` or the state mutex is poisoned.
    fn init(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        match self.sleeping.lock() {
            Ok(mut guard) => {
                guard[thread_id] = ParkState::Awake;
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Announces that `thread_id` is about to park by setting its bit in `sleep_set`.
    fn prepare_park(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        self.sleep_set.set_at(thread_id);
    }

    /// Withdraws a `prepare_park` by clearing the bit of `thread_id` in `sleep_set`.
    fn abort_park(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        self.sleep_set.unset_at(thread_id);
    }

    /// Parks the calling thread in slot `thread_id`.
    ///
    /// Returns:
    /// - `Retry` if the state mutex cannot be acquired without blocking. The thread does not
    ///   sleep and its `sleep_set` bit stays set.
    /// - `Abort` if an unpark already marked the slot `NoSleep`.
    /// - `Woken` once the thread has returned from `thread::park`.
    fn park(&self, thread_id: usize) -> ParkResult {
        assert!(thread_id < Self::MAX_THREADS);
        match self.sleeping.try_lock() {
            Ok(mut guard) => {
                match guard[thread_id] {
                    ParkState::Awake => {
                        guard[thread_id] = ParkState::Asleep(std::thread::current());
                    }
                    ParkState::Asleep(_) => unreachable!("Threads must clean up after waking up!"),
                    ParkState::NoSleep => {
                        self.sleep_set.unset_at(thread_id);
                        guard[thread_id] = ParkState::Awake;
                        return ParkResult::Abort;
                    }
                    ParkState::Waking => unreachable!("Threads must clean up after waking up!"),
                }
            }
            _ => {
                return ParkResult::Retry;
            }
        }
        // The guard was dropped at the end of the match above, so wakers can take the lock
        // while this thread is parked.
        thread::park();
        match self.sleeping.lock() {
            Ok(mut guard) => {
                self.sleep_set.unset_at(thread_id);
                match guard[thread_id] {
                    ParkState::Awake => unreachable!("Threads must be asleep to wake from park!"),
                    // `Waking`: an unpark targeted this slot. `Asleep`: `thread::park` returned
                    // without an unpark from this parker (e.g. a spurious wake-up). Both are
                    // reported as woken.
                    ParkState::Waking | ParkState::Asleep(_) => {
                        guard[thread_id] = ParkState::Awake;
                        ParkResult::Woken
                    }
                    ParkState::NoSleep => {
                        unreachable!("Threads must be awake to be prevented from sleeping!");
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    #[inline(always)]
    fn all_awake(&self) -> bool {
        self.sleep_set.load(Ordering::SeqCst) == 0u32
    }

    /// Wakes the lowest-numbered announced thread.
    ///
    /// If that thread has not parked yet, it is marked `NoSleep` so its `park` aborts. If it is
    /// already `Waking`/`NoSleep`, the first actually parked thread (if any) is woken instead.
    fn unpark_one(&self) {
        match self.sleeping.lock() {
            Ok(mut guard) => {
                if let Ok(index) = self.sleep_set.get_lowest() {
                    match guard[index] {
                        ParkState::Awake => {
                            guard[index] = ParkState::NoSleep;
                        }
                        ParkState::Asleep(ref t) => {
                            t.unpark();
                            guard[index] = ParkState::Waking;
                        }
                        ParkState::Waking | ParkState::NoSleep => {
                            // NOTE(review): other announced-but-not-yet-parked (`Awake`) slots
                            // are not considered here; confirm that callers re-check for work
                            // before parking.
                            for state in guard.iter_mut() {
                                if let ParkState::Asleep(t) = state {
                                    t.unpark();
                                    *state = ParkState::Waking;
                                    return;
                                }
                            }
                        }
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Wakes every parked thread and marks every `Awake` slot `NoSleep`, so each such
    /// thread's next `park` aborts.
    fn unpark_all(&self) {
        match self.sleeping.lock() {
            Ok(mut guard) => {
                if !self.all_awake() {
                    for state in guard.iter_mut() {
                        match state {
                            ParkState::Awake => {
                                *state = ParkState::NoSleep;
                            }
                            ParkState::Asleep(t) => {
                                t.unpark();
                                *state = ParkState::Waking;
                            }
                            ParkState::NoSleep | ParkState::Waking => (),
                        }
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }
}
/// Thread data for up to 64 threads.
///
/// `Debug` is implemented by hand so the 64-entry state array is not printed.
pub struct LargeThreadData {
    /// Bit `i` is set while thread `i` has announced (via `prepare_park`) that it will sleep.
    sleep_set: AtomicU64,
    /// Parking state of each thread slot, indexed by `thread_id`.
    sleeping: Mutex<[ParkState; 64]>,
}
impl LargeThreadData {
pub const MAX_THREADS: usize = 64;
fn new() -> LargeThreadData {
LargeThreadData {
sleep_set: AtomicU64::new(0),
sleeping: Mutex::new(arr![ParkState::Awake; 64]),
}
}
}
/// Prints the full 64-bit sleep set and elides the 64-entry state array.
impl fmt::Debug for LargeThreadData {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sset = self.sleep_set.load(Ordering::SeqCst);
        // Zero-pad to 64 digits: the sleep set is a u64 with one bit per thread slot, so
        // every slot is always shown (the previous `{:032b}` only padded to 32).
        write!(
            f,
            "LargeThreadData {{ sleep_set: {:064b}, sleeping: Mutex([...;64]) [array contents elided for brevity] }}",
            sset
        )
    }
}
impl ThreadData for LargeThreadData {
    fn max_threads(&self) -> Option<usize> {
        Some(Self::MAX_THREADS)
    }

    /// Resets the slot of `thread_id` to `ParkState::Awake`.
    ///
    /// # Panics
    /// Panics if `thread_id >= MAX_THREADS` or the state mutex is poisoned.
    fn init(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        match self.sleeping.lock() {
            Ok(mut guard) => {
                guard[thread_id] = ParkState::Awake;
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Announces that `thread_id` is about to park by setting its bit in `sleep_set`.
    fn prepare_park(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        self.sleep_set.set_at(thread_id);
    }

    /// Withdraws a `prepare_park` by clearing the bit of `thread_id` in `sleep_set`.
    fn abort_park(&self, thread_id: usize) {
        assert!(thread_id < Self::MAX_THREADS);
        self.sleep_set.unset_at(thread_id);
    }

    /// Parks the calling thread in slot `thread_id`.
    ///
    /// Returns:
    /// - `Retry` if the state mutex cannot be acquired without blocking. The thread does not
    ///   sleep and its `sleep_set` bit stays set.
    /// - `Abort` if an unpark already marked the slot `NoSleep`.
    /// - `Woken` once the thread has returned from `thread::park`.
    fn park(&self, thread_id: usize) -> ParkResult {
        assert!(thread_id < Self::MAX_THREADS);
        match self.sleeping.try_lock() {
            Ok(mut guard) => {
                match guard[thread_id] {
                    ParkState::Awake => {
                        guard[thread_id] = ParkState::Asleep(std::thread::current());
                    }
                    ParkState::Asleep(_) => unreachable!("Threads must clean up after waking up!"),
                    ParkState::NoSleep => {
                        self.sleep_set.unset_at(thread_id);
                        guard[thread_id] = ParkState::Awake;
                        return ParkResult::Abort;
                    }
                    ParkState::Waking => unreachable!("Threads must clean up after waking up!"),
                }
            }
            _ => {
                return ParkResult::Retry;
            }
        }
        // The guard was dropped at the end of the match above, so wakers can take the lock
        // while this thread is parked.
        thread::park();
        match self.sleeping.lock() {
            Ok(mut guard) => {
                self.sleep_set.unset_at(thread_id);
                match guard[thread_id] {
                    ParkState::Awake => unreachable!("Threads must be asleep to wake from park!"),
                    // `Waking`: an unpark targeted this slot. `Asleep`: `thread::park` returned
                    // without an unpark from this parker (e.g. a spurious wake-up). Both are
                    // reported as woken.
                    ParkState::Waking | ParkState::Asleep(_) => {
                        guard[thread_id] = ParkState::Awake;
                        ParkResult::Woken
                    }
                    ParkState::NoSleep => {
                        unreachable!("Threads must be awake to be prevented from sleeping!");
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    #[inline(always)]
    fn all_awake(&self) -> bool {
        self.sleep_set.load(Ordering::SeqCst) == 0u64
    }

    /// Wakes the lowest-numbered announced thread.
    ///
    /// If that thread has not parked yet, it is marked `NoSleep` so its `park` aborts. If it is
    /// already `Waking`/`NoSleep`, the first actually parked thread (if any) is woken instead.
    fn unpark_one(&self) {
        match self.sleeping.lock() {
            Ok(mut guard) => {
                if let Ok(index) = self.sleep_set.get_lowest() {
                    match guard[index] {
                        ParkState::Awake => {
                            guard[index] = ParkState::NoSleep;
                        }
                        ParkState::Asleep(ref t) => {
                            t.unpark();
                            guard[index] = ParkState::Waking;
                        }
                        ParkState::Waking | ParkState::NoSleep => {
                            // NOTE(review): other announced-but-not-yet-parked (`Awake`) slots
                            // are not considered here; confirm that callers re-check for work
                            // before parking.
                            for state in guard.iter_mut() {
                                if let ParkState::Asleep(t) = state {
                                    t.unpark();
                                    *state = ParkState::Waking;
                                    return;
                                }
                            }
                        }
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Wakes every parked thread and marks every `Awake` slot `NoSleep`, so each such
    /// thread's next `park` aborts.
    fn unpark_all(&self) {
        match self.sleeping.lock() {
            Ok(mut guard) => {
                if !self.all_awake() {
                    for state in guard.iter_mut() {
                        match state {
                            ParkState::Awake => {
                                *state = ParkState::NoSleep;
                            }
                            ParkState::Asleep(t) => {
                                t.unpark();
                                *state = ParkState::Waking;
                            }
                            ParkState::NoSleep | ParkState::Waking => (),
                        }
                    }
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }
}
/// Mutex-protected bookkeeping for [`DynamicThreadData`].
#[derive(Debug)]
struct InnerData {
    /// Parked threads keyed by `thread_id`. `park` inserts an entry and removes it again after
    /// waking; `init` also removes it.
    sleeping: BTreeMap<usize, ParkState>,
    /// Number of upcoming `park` calls that should abort instead of sleeping.
    no_sleep: usize,
}
impl InnerData {
fn new() -> InnerData {
InnerData {
sleeping: BTreeMap::new(),
no_sleep: 0,
}
}
}
/// Thread data for an unbounded number of threads.
///
/// An atomic counter tracks how many threads have announced that they will sleep, while the
/// mutex guards the map of actually parked threads.
#[derive(Debug)]
pub struct DynamicThreadData {
    /// Number of threads between `prepare_park` and wake-up/abort; zero means all are awake.
    sleep_count: AtomicUsize,
    data: Mutex<InnerData>,
}
impl DynamicThreadData {
    /// Creates an instance with no announced or parked threads.
    fn new() -> Self {
        let data = Mutex::new(InnerData::new());
        Self {
            sleep_count: AtomicUsize::new(0),
            data,
        }
    }
}
impl ThreadData for DynamicThreadData {
    fn max_threads(&self) -> Option<usize> {
        None
    }

    /// Removes any parked-state entry recorded for `thread_id`.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    fn init(&self, thread_id: usize) {
        match self.data.lock() {
            Ok(mut guard) => {
                let _ = guard.sleeping.remove(&thread_id);
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Counts one more announced sleeper; the `thread_id` is not needed for that.
    fn prepare_park(&self, _thread_id: usize) {
        self.sleep_count.fetch_add(1usize, Ordering::SeqCst);
    }

    /// Counts one fewer announced sleeper.
    fn abort_park(&self, _thread_id: usize) {
        self.sleep_count.fetch_sub(1usize, Ordering::SeqCst);
    }

    /// Parks the calling thread as `thread_id`.
    ///
    /// Returns:
    /// - `Retry` if the state mutex cannot be acquired without blocking. `sleep_count` is left
    ///   unchanged.
    /// - `Abort` (consuming one pending `no_sleep` wake-up) if a wake-up is pending.
    /// - `Woken` after returning from `thread::park`.
    fn park(&self, thread_id: usize) -> ParkResult {
        match self.data.try_lock() {
            Ok(mut guard) => {
                if guard.no_sleep == 0 {
                    let old = guard
                        .sleeping
                        .insert(thread_id, ParkState::Asleep(std::thread::current()));
                    if old.is_some() {
                        unreachable!("Inconsistent sleeping map (before park)!");
                    }
                } else {
                    guard.no_sleep -= 1;
                    self.sleep_count.fetch_sub(1usize, Ordering::SeqCst);
                    return ParkResult::Abort;
                }
            }
            _ => {
                return ParkResult::Retry;
            }
        }
        // The guard was dropped at the end of the match above, so wakers can take the lock
        // while this thread is parked.
        thread::park();
        match self.data.lock() {
            Ok(mut guard) => {
                self.sleep_count.fetch_sub(1usize, Ordering::SeqCst);
                guard
                    .sleeping
                    .remove(&thread_id)
                    .expect("Inconsistent sleeping map (after park)!");
                ParkResult::Woken
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    #[inline(always)]
    fn all_awake(&self) -> bool {
        self.sleep_count.load(Ordering::SeqCst) == 0usize
    }

    /// Wakes the first parked thread in `thread_id` order.
    ///
    /// If none is parked (all entries are already `Waking`, or announced threads have not
    /// parked yet), one future `park` is told to abort via `no_sleep`.
    fn unpark_one(&self) {
        match self.data.lock() {
            Ok(mut guard) => {
                if self.sleep_count.load(Ordering::SeqCst) > 0usize {
                    for state in guard.sleeping.values_mut() {
                        match state {
                            ParkState::Asleep(t) => {
                                t.unpark();
                                *state = ParkState::Waking;
                                return;
                            }
                            // Already being woken; look for another sleeper.
                            ParkState::Waking => (),
                            ParkState::Awake | ParkState::NoSleep => {
                                unreachable!("These should not be in the map at all!");
                            }
                        }
                    }
                    guard.no_sleep = guard.no_sleep.saturating_add(1);
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }

    /// Wakes every parked thread and makes subsequent `park` calls abort.
    fn unpark_all(&self) {
        match self.data.lock() {
            Ok(mut guard) => {
                if self.sleep_count.load(Ordering::SeqCst) > 0usize {
                    for state in guard.sleeping.values_mut() {
                        match state {
                            ParkState::Asleep(t) => {
                                t.unpark();
                                *state = ParkState::Waking;
                            }
                            ParkState::Waking => (),
                            ParkState::Awake | ParkState::NoSleep => {
                                unreachable!("These should not be in the map at all!");
                            }
                        }
                    }
                    // NOTE(review): each aborted `park` only decrements `no_sleep` by one, so
                    // this effectively disables sleeping from now on (unlike the fixed-size
                    // parkers, which mark each slot `NoSleep` only once). Presumably intended
                    // for shutdown; confirm.
                    guard.no_sleep = usize::MAX;
                }
            }
            _ => {
                panic!("Mutex is poisoned!");
            }
        }
    }
}
/// Error returned by [`AtomicBitSet::get_lowest`] when no bit is set.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
struct BitSetEmpty;

/// The single [`BitSetEmpty`] value.
const BIT_SET_EMPTY: BitSetEmpty = BitSetEmpty;
/// A set of small indices stored as bits of an atomic integer.
///
/// Each operation is a single atomic load/RMW with `SeqCst` ordering in the impls below.
trait AtomicBitSet {
    /// Sets bit `at`.
    fn set_at(&self, at: usize);
    /// Clears bit `at`.
    fn unset_at(&self, at: usize);
    /// Returns the index of the lowest set bit, or [`BitSetEmpty`] if no bit is set.
    fn get_lowest(&self) -> Result<usize, BitSetEmpty>;
}
impl AtomicBitSet for AtomicU32 {
fn set_at(&self, at: usize) -> () {
assert!(at < 32);
let mask = 1u32 << at;
self.fetch_or(mask, Ordering::SeqCst);
}
fn unset_at(&self, at: usize) -> () {
assert!(at < 32);
let mask = !(1u32 << at);
self.fetch_and(mask, Ordering::SeqCst);
}
fn get_lowest(&self) -> Result<usize, BitSetEmpty> {
let cur = self.load(Ordering::SeqCst);
if (cur == 0u32) {
Err(BIT_SET_EMPTY)
} else {
let mut mask = 1u32;
let mut index = 0;
while index < 32 {
if (mask & cur) != 0u32 {
return Ok(index);
} else {
index += 1;
mask <<= 1;
}
}
unreachable!("Bitset was empty despite empty check!");
}
}
}
impl AtomicBitSet for AtomicU64 {
fn set_at(&self, at: usize) -> () {
assert!(at < 64);
let mask = 1u64 << at;
self.fetch_or(mask, Ordering::SeqCst);
}
fn unset_at(&self, at: usize) -> () {
assert!(at < 64);
let mask = !(1u64 << at);
self.fetch_and(mask, Ordering::SeqCst);
}
fn get_lowest(&self) -> Result<usize, BitSetEmpty> {
let cur = self.load(Ordering::SeqCst);
if (cur == 0u64) {
Err(BIT_SET_EMPTY)
} else {
let mut mask = 1u64;
let mut index = 0;
while index < 64 {
if (mask & cur) != 0u64 {
return Ok(index);
} else {
index += 1;
mask <<= 1;
}
}
unreachable!("Bitset was empty despite empty check!");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parker_printing() {
        {
            let p = small();
            println!("Small Parker: {:?}", p);
            let dyn_p = p.dynamic();
            println!("Small Parker (Dynamic): {:?}", dyn_p);
        }
        {
            let p = large();
            println!("Large Parker: {:?}", p);
            let dyn_p = p.dynamic();
            println!("Large Parker (Dynamic): {:?}", dyn_p);
        }
        {
            let p = dynamic();
            println!("Dynamic Parker: {:?}", p);
            let dyn_p = p.dynamic();
            println!("Dynamic Parker (Dynamic): {:?}", dyn_p);
        }
    }

    // Kept wrapped in `Result` so it can be compared directly against `get_lowest()`.
    #[allow(clippy::unnecessary_wraps)]
    fn res_ok(v: usize) -> Result<usize, BitSetEmpty> {
        Ok(v)
    }

    fn res_err() -> Result<usize, BitSetEmpty> {
        Err(BIT_SET_EMPTY)
    }

    #[test]
    fn test_bitset() {
        let data = AtomicU32::new(0);
        let bs: &dyn AtomicBitSet = &data;
        assert_eq!(res_err(), bs.get_lowest());
        bs.set_at(1);
        assert_eq!(res_ok(1), bs.get_lowest());
        bs.set_at(5);
        assert_eq!(res_ok(1), bs.get_lowest());
        bs.unset_at(1);
        assert_eq!(res_ok(5), bs.get_lowest());
        bs.unset_at(5);
        assert_eq!(res_err(), bs.get_lowest());
    }

    /// Covers the 64-bit set, including bits above the 32-bit range.
    #[test]
    fn test_bitset_u64() {
        let data = AtomicU64::new(0);
        let bs: &dyn AtomicBitSet = &data;
        assert_eq!(res_err(), bs.get_lowest());
        bs.set_at(63);
        assert_eq!(res_ok(63), bs.get_lowest());
        bs.set_at(32);
        assert_eq!(res_ok(32), bs.get_lowest());
        bs.unset_at(32);
        assert_eq!(res_ok(63), bs.get_lowest());
        bs.unset_at(63);
        assert_eq!(res_err(), bs.get_lowest());
    }

    /// An unpark that arrives after `prepare_park` but before `park` must make `park` abort
    /// instead of sleeping, for every parker kind.
    #[test]
    fn test_unpark_before_park_aborts() {
        fn check<P: Parker>(p: P) {
            p.init(0);
            p.prepare_park(0);
            p.unpark_one();
            assert!(matches!(p.park(0), ParkResult::Abort));
        }
        check(small());
        check(large());
        check(dynamic());
    }
}