use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::registry::{Registry, WorkerThread};
use crate::sync::{Condvar, Mutex};
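
/// We define various kinds of latches, which are all a primitive
/// signaling mechanism. A latch starts as false. Eventually someone
/// calls `set()` and it becomes true. You can test if it has been set
/// by calling `probe()`.
///
/// Some kinds of latches, but not all, support a `wait()` operation
/// that will wait until the latch is set, blocking efficiently.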
pub(super) trait Latch {
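    /// Set the latch, signaling others.
    ///
    /// # Safety
    ///
    /// Setting a latch can allow other threads to proceed and, in some
    /// cases, deallocate the latch itself (e.g. when it lives on a
    /// waiting thread's stack). The latch is therefore taken by raw
    /// pointer, and `this` must not be used again once the latch is
    /// set: read any fields you need *before* setting it.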
unsafe fn set(this: *const Self);
}
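
/// A latch that exposes its internal `CoreLatch`, so that a worker
/// thread can block on it via `WorkerThread::wait_until`.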
pub(super) trait AsCoreLatch {
fn as_core_latch(&self) -> &CoreLatch;
}
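
/// The latch is not set and the owning thread is awake.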
const UNSET: usize = 0;
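
/// The latch is not set; the owning thread has announced it is about to
/// sleep, but has not yet fallen asleep.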
const SLEEPY: usize = 1;
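
/// The latch is not set; the owning thread is asleep and must be woken
/// when the latch is set.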
const SLEEPING: usize = 2;
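
/// The latch has been set.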
const SET: usize = 3;
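
/// The "core" of a latch: an atomic state machine that tracks both
/// whether the latch has been set and whether the owning worker thread
/// is awake, sleepy, or asleep. Whoever sets the latch can thereby tell
/// whether the owner must be explicitly woken.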
#[derive(Debug)]
pub(super) struct CoreLatch {
state: AtomicUsize,
}
impl CoreLatch {
#[inline]
fn new() -> Self {
Self {
            state: AtomicUsize::new(UNSET),
}
}
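
    /// Invoked by the owning thread as it prepares to sleep. Returns
    /// true if the owning thread may proceed to fall asleep, false if
    /// the latch was set in the meantime.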
#[inline]
pub(super) fn get_sleepy(&self) -> bool {
self.state
.compare_exchange(UNSET, SLEEPY, Ordering::SeqCst, Ordering::Relaxed)
.is_ok()
}
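
    /// Invoked by the owning thread as it actually falls asleep.
    /// Returns true if it may proceed to sleep, false if the latch was
    /// set in the meantime.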
#[inline]
pub(super) fn fall_asleep(&self) -> bool {
self.state
.compare_exchange(SLEEPY, SLEEPING, Ordering::SeqCst, Ordering::Relaxed)
.is_ok()
}
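
    /// Invoked by the owning thread when it wakes for a reason other
    /// than the latch being set: resets a `SLEEPING` state back to
    /// `UNSET`, so that future setters know no wake-up is required.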
#[inline]
pub(super) fn wake_up(&self) {
if !self.probe() {
let _ =
self.state
.compare_exchange(SLEEPING, UNSET, Ordering::SeqCst, Ordering::Relaxed);
}
}
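
    /// Set the latch. If this returns true, the owning thread was
    /// sleeping and the caller must wake it. This is private because
    /// the wrapping latch types encapsulate that wake-up; note that
    /// `this` may be deallocated as soon as the state becomes `SET`.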
#[inline]
unsafe fn set(this: *const Self) -> bool {
let old_state = (*this).state.swap(SET, Ordering::AcqRel);
old_state == SLEEPING
}
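
    /// Test if this latch has been set.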
#[inline]
pub(super) fn probe(&self) -> bool {
self.state.load(Ordering::Acquire) == SET
}
}
impl AsCoreLatch for CoreLatch {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
self
}
}
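
/// Spin latches are the simplest, most efficient kind: a flag that is
/// flipped by `set()`. They do not support a blocking `wait()` from
/// arbitrary threads; instead, the owning worker spins (and steals
/// work) until the latch is set.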
pub(super) struct SpinLatch<'r> {
core_latch: CoreLatch,
registry: &'r Arc<Registry>,
target_worker_index: usize,
cross: bool,
}
impl<'r> SpinLatch<'r> {
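    /// Creates a new spin latch that is owned by `thread`. This means
    /// that `thread` is the only thread that should block on this
    /// latch, and that when the latch is set, we will wake `thread` if
    /// it is sleeping.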
#[inline]
pub(super) fn new(thread: &'r WorkerThread) -> SpinLatch<'r> {
SpinLatch {
core_latch: CoreLatch::new(),
registry: thread.registry(),
target_worker_index: thread.index(),
cross: false,
}
}
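
    /// Creates a new spin latch for cross-threadpool blocking. Notably,
    /// we need to make sure the registry is kept alive after setting,
    /// so we can safely call the notification.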
#[inline]
pub(super) fn cross(thread: &'r WorkerThread) -> SpinLatch<'r> {
SpinLatch {
cross: true,
..SpinLatch::new(thread)
}
}
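
    /// Test if this latch has been set.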
#[inline]
pub(super) fn probe(&self) -> bool {
self.core_latch.probe()
}
}
impl<'r> AsCoreLatch for SpinLatch<'r> {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
&self.core_latch
}
}
impl<'r> Latch for SpinLatch<'r> {
#[inline]
unsafe fn set(this: *const Self) {
let cross_registry;
let registry: &Registry = if (*this).cross {
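            // Ensure the registry stays alive while we notify it.
            // Otherwise, the waiting thread could see the latch get
            // set, exit, and drop the registry before we get a chance
            // to call `notify_worker_latch_is_set`.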
cross_registry = Arc::clone((*this).registry);
&cross_registry
} else {
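            // Not a cross-registry latch: the thread performing `set`
            // belongs to the same registry and itself keeps the
            // registry alive, so a plain borrow suffices.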
(*this).registry
};
let target_worker_index = (*this).target_worker_index;
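        // NOTE: Once we `set`, the target may proceed and invalidate
        // `this`! Everything we need has been copied out above.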
if CoreLatch::set(&(*this).core_latch) {
registry.notify_worker_latch_is_set(target_worker_index);
}
}
}
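
/// A latch backed by a `Mutex` and a `Condvar`. Unlike the core-latch
/// types, it supports a blocking `wait()` from any thread, at the cost
/// of OS-level synchronization, and it can be reset and reused.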
#[derive(Debug)]
pub(super) struct LockLatch {
m: Mutex<bool>,
v: Condvar,
}
impl LockLatch {
#[inline]
pub(super) fn new() -> LockLatch {
LockLatch {
m: Mutex::new(false),
v: Condvar::new(),
}
}
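
    /// Block until the latch is set, then reset it so it can be reused.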
pub(super) fn wait_and_reset(&self) {
let mut guard = self.m.lock().unwrap();
while !*guard {
guard = self.v.wait(guard).unwrap();
}
*guard = false;
}
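
    /// Block until the latch is set.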
pub(super) fn wait(&self) {
let mut guard = self.m.lock().unwrap();
while !*guard {
guard = self.v.wait(guard).unwrap();
}
}
}
impl Latch for LockLatch {
#[inline]
unsafe fn set(this: *const Self) {
let mut guard = (*this).m.lock().unwrap();
*guard = true;
(*this).v.notify_all();
}
}
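
/// Once latches are used to implement one-time blocking, primarily for
/// the termination flag of the threads in the pool.
///
/// Unlike a `SpinLatch`, a `OnceLatch` does not hold a reference to its
/// registry (the registry may own the latch, which would create a
/// cycle); the registry is instead supplied when the latch is set. For
/// this reason it does not implement the `Latch` trait.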
#[derive(Debug)]
pub(super) struct OnceLatch {
core_latch: CoreLatch,
}
impl OnceLatch {
#[inline]
pub(super) fn new() -> OnceLatch {
Self {
core_latch: CoreLatch::new(),
}
}
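
    /// Set the latch, then tickle the specific worker thread, which
    /// should be the one that owns this latch.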
#[inline]
pub(super) unsafe fn set_and_tickle_one(
this: *const Self,
registry: &Registry,
target_worker_index: usize,
) {
if CoreLatch::set(&(*this).core_latch) {
registry.notify_worker_latch_is_set(target_worker_index);
}
}
}
impl AsCoreLatch for OnceLatch {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
&self.core_latch
}
}
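
/// Counting latches are used to implement scopes. They track a counter.
/// Unlike other latches, calling `set()` does not necessarily make the
/// latch fire: it just decrements the counter. The latch only fires
/// once the counter reaches zero.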
#[derive(Debug)]
pub(super) struct CountLatch {
counter: AtomicUsize,
kind: CountLatchKind,
}
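
/// How a `CountLatch` fires once its counter reaches zero: either by
/// setting a `CoreLatch` and waking the owning worker (`Stealing`), or
/// by setting a `LockLatch` that any thread may block on (`Blocking`).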
enum CountLatchKind {
Stealing {
latch: CoreLatch,
registry: Arc<Registry>,
worker_index: usize,
},
Blocking { latch: LockLatch },
}
impl std::fmt::Debug for CountLatchKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CountLatchKind::Stealing { latch, .. } => {
f.debug_tuple("Stealing").field(latch).finish()
}
CountLatchKind::Blocking { latch, .. } => {
f.debug_tuple("Blocking").field(latch).finish()
}
}
}
}
impl CountLatch {
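    /// Creates a counting latch with an initial count of one. If
    /// `owner` is given, the latch belongs to that worker thread, which
    /// can keep working while it waits; otherwise, waiting blocks on an
    /// internal `LockLatch`.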
pub(super) fn new(owner: Option<&WorkerThread>) -> Self {
Self::with_count(1, owner)
}
pub(super) fn with_count(count: usize, owner: Option<&WorkerThread>) -> Self {
Self {
counter: AtomicUsize::new(count),
kind: match owner {
Some(owner) => CountLatchKind::Stealing {
latch: CoreLatch::new(),
registry: Arc::clone(owner.registry()),
worker_index: owner.index(),
},
None => CountLatchKind::Blocking {
latch: LockLatch::new(),
},
},
}
}
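
    /// Adds a reference to this latch. The counter must not already be
    /// zero, and each increment must be balanced by a later call to
    /// `set()` before the latch can fire.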
#[inline]
pub(super) fn increment(&self) {
let old_counter = self.counter.fetch_add(1, Ordering::Relaxed);
debug_assert!(old_counter != 0);
}
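
    /// Block until the counter reaches zero. For the `Stealing` kind,
    /// `owner` must be the worker thread that owns this latch, and it
    /// keeps executing jobs while it waits; for the `Blocking` kind,
    /// this blocks on the internal `LockLatch`.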
pub(super) fn wait(&self, owner: Option<&WorkerThread>) {
match &self.kind {
CountLatchKind::Stealing {
latch,
registry,
worker_index,
} => unsafe {
let owner = owner.expect("owner thread");
debug_assert_eq!(registry.id(), owner.registry().id());
debug_assert_eq!(*worker_index, owner.index());
owner.wait_until(latch);
},
CountLatchKind::Blocking { latch } => latch.wait(),
}
}
}
impl Latch for CountLatch {
#[inline]
unsafe fn set(this: *const Self) {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
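            // NOTE: Once we set the inner latch, the owner may proceed
            // and invalidate `this`! That is why the `Stealing` arm
            // clones the registry `Arc` before setting its `CoreLatch`.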
match (*this).kind {
CountLatchKind::Stealing {
ref latch,
ref registry,
worker_index,
} => {
let registry = Arc::clone(registry);
if CoreLatch::set(latch) {
registry.notify_worker_latch_is_set(worker_index);
}
}
CountLatchKind::Blocking { ref latch } => LockLatch::set(latch),
}
}
}
}
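
/// `&L` without any implication of `dereferenceable` for `Latch::set`:
/// wrapping a borrowed latch in `LatchRef` lets us hand it to code that
/// will eventually `set` (and thereby possibly invalidate) it.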
pub(super) struct LatchRef<'a, L> {
inner: *const L,
marker: PhantomData<&'a L>,
}
impl<L> LatchRef<'_, L> {
pub(super) fn new(inner: &L) -> LatchRef<'_, L> {
LatchRef {
inner,
marker: PhantomData,
}
}
}
unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}
impl<L> Deref for LatchRef<'_, L> {
type Target = L;
fn deref(&self) -> &L {
unsafe { &*self.inner }
}
}
impl<L: Latch> Latch for LatchRef<'_, L> {
#[inline]
unsafe fn set(this: *const Self) {
L::set((*this).inner);
}
}