use crate::spinwait::SpinWait;
use crate::thread_parker::{ThreadParker, ThreadParkerT, UnparkHandleT};
use core::{
cell::Cell,
mem, ptr,
sync::atomic::{fence, AtomicUsize, Ordering},
};
/// Per-thread node of the intrusive wait queue used by `WordLock`.
/// A pointer to the queue-head `ThreadData` is stored in the upper bits of
/// the lock's state word, so this type must be at least 4-byte aligned
/// (checked in `ThreadData::new`).
struct ThreadData {
    // Parker used to block/wake this thread.
    parker: ThreadParker,
    // Cached pointer to the queue tail. Set to `self` by the first queued
    // thread; later entries leave it null and `unlock_slow` fills in the
    // head's cache after walking the list.
    queue_tail: Cell<*const ThreadData>,
    // Backward link toward the head; filled in lazily by `unlock_slow`
    // while it scans for the tail.
    prev: Cell<*const ThreadData>,
    // Forward link toward the tail (set when pushing onto the queue head).
    next: Cell<*const ThreadData>,
}
impl ThreadData {
    /// Creates a fresh, unqueued `ThreadData` with a newly constructed parker.
    #[inline]
    fn new() -> ThreadData {
        // The low bits of the state word are flag bits (`!QUEUE_MASK`), so a
        // `ThreadData` pointer stored in the remaining bits must have strictly
        // greater alignment for those bits to always be zero.
        assert!(mem::align_of::<ThreadData>() > !QUEUE_MASK);
        let unlinked = || Cell::new(ptr::null());
        ThreadData {
            parker: ThreadParker::new(),
            queue_tail: unlinked(),
            prev: unlinked(),
            next: unlinked(),
        }
    }
}
#[inline]
fn with_thread_data<T>(f: impl FnOnce(&ThreadData) -> T) -> T {
let mut thread_data_ptr = ptr::null();
if !ThreadParker::IS_CHEAP_TO_CONSTRUCT {
thread_local!(static THREAD_DATA: ThreadData = ThreadData::new());
if let Ok(tls_thread_data) = THREAD_DATA.try_with(|x| x as *const ThreadData) {
thread_data_ptr = tls_thread_data;
}
}
let mut thread_data_storage = None;
if thread_data_ptr.is_null() {
thread_data_ptr = thread_data_storage.get_or_insert_with(ThreadData::new);
}
f(unsafe { &*thread_data_ptr })
}
/// Bit 0 of the state word: set while the lock is held.
const LOCKED_BIT: usize = 1;
/// Bit 1 of the state word: set while one thread holds the "queue lock"
/// that guards traversal/modification of the wait queue.
const QUEUE_LOCKED_BIT: usize = 2;
/// Remaining bits of the state word: pointer to the head `ThreadData` of
/// the wait queue (null when no threads are queued).
const QUEUE_MASK: usize = !3;
/// A word-sized mutex that packs its entire state — locked bit, queue-lock
/// bit, and a pointer to the head of an intrusive queue of parked threads —
/// into a single `AtomicUsize`.
pub struct WordLock {
    state: AtomicUsize,
}
impl WordLock {
    /// Creates a new, unlocked `WordLock`.
    pub const fn new() -> Self {
        WordLock {
            state: AtomicUsize::new(0),
        }
    }
    /// Acquires the lock, parking the current thread if it is contended.
    #[inline]
    pub fn lock(&self) {
        // Fast path: a single weak CAS from the fully-zero (unlocked, empty
        // queue) state. Any failure — including spurious — takes the slow path.
        if self
            .state
            .compare_exchange_weak(0, LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            return;
        }
        self.lock_slow();
    }
    /// Releases the lock, waking a queued thread if necessary.
    ///
    /// # Safety
    ///
    /// The lock must currently be held by this thread.
    #[inline]
    pub unsafe fn unlock(&self) {
        let state = self.state.fetch_sub(LOCKED_BIT, Ordering::Release);
        // Nothing to do if no thread is queued. If QUEUE_LOCKED_BIT is set,
        // the thread holding it is responsible for unparking a waiter.
        if state.is_queue_locked() || state.queue_head().is_null() {
            return;
        }
        self.unlock_slow();
    }
    // Contended acquire: spin briefly, then push this thread onto the front
    // of the wait queue and park until an unlocker wakes us.
    #[cold]
    fn lock_slow(&self) {
        let mut spinwait = SpinWait::new();
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // If the lock is free, try to grab it (even while others are
            // queued — this lock does not guarantee FIFO fairness here).
            if !state.is_locked() {
                match self.state.compare_exchange_weak(
                    state,
                    state | LOCKED_BIT,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(x) => state = x,
                }
                continue;
            }
            // Spin only while the queue is empty and the spin budget lasts;
            // otherwise fall through and park.
            if state.queue_head().is_null() && spinwait.spin() {
                state = self.state.load(Ordering::Relaxed);
                continue;
            }
            state = with_thread_data(|thread_data| {
                #[allow(unused_unsafe)]
                unsafe {
                    thread_data.parker.prepare_park();
                }
                // Push ourselves onto the front of the queue. The first
                // queued thread records itself as the tail; later threads
                // leave queue_tail null and link `next` to the old head
                // (prev links are filled in lazily by unlock_slow).
                let queue_head = state.queue_head();
                if queue_head.is_null() {
                    thread_data.queue_tail.set(thread_data);
                    thread_data.prev.set(ptr::null());
                } else {
                    thread_data.queue_tail.set(ptr::null());
                    thread_data.prev.set(ptr::null());
                    thread_data.next.set(queue_head);
                }
                // Publish ourselves as the new queue head. On CAS failure,
                // return the observed state and retry the outer loop.
                if let Err(x) = self.state.compare_exchange_weak(
                    state,
                    state.with_queue_head(thread_data),
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    return x;
                }
                // Block until unlock_slow unparks us, then start over with a
                // fresh spin budget — we must still win the lock CAS above.
                #[allow(unused_unsafe)]
                unsafe {
                    thread_data.parker.park();
                }
                spinwait.reset();
                self.state.load(Ordering::Relaxed)
            });
        }
    }
    // Contended release: acquire QUEUE_LOCKED_BIT, locate the queue tail,
    // dequeue it and unpark it — or hand the job off if the lock was
    // re-acquired in the meantime.
    #[cold]
    fn unlock_slow(&self) {
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // Abandon if the queue emptied or another thread already holds
            // the queue lock (that thread will do the unparking).
            if state.is_queue_locked() || state.queue_head().is_null() {
                return;
            }
            match self.state.compare_exchange_weak(
                state,
                state | QUEUE_LOCKED_BIT,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(x) => state = x,
            }
        }
        'outer: loop {
            let queue_head = state.queue_head();
            // Walk from the head toward the tail, filling in prev links as
            // we go. The walk stops at the first node whose queue_tail is
            // non-null (the first-queued thread set its own, and the head
            // caches it below after a successful walk).
            let mut queue_tail;
            let mut current = queue_head;
            loop {
                queue_tail = unsafe { (*current).queue_tail.get() };
                if !queue_tail.is_null() {
                    break;
                }
                unsafe {
                    let next = (*current).next.get();
                    (*next).prev.set(current);
                    current = next;
                }
            }
            // Cache the discovered tail on the head so later unlocks can
            // skip the walk.
            unsafe {
                (*queue_head).queue_tail.set(queue_tail);
            }
            // If the lock was re-acquired while we scanned, don't unpark
            // anyone: just drop the queue lock and let the new holder's
            // unlock deal with the queue.
            if state.is_locked() {
                match self.state.compare_exchange_weak(
                    state,
                    state & !QUEUE_LOCKED_BIT,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(x) => state = x,
                }
                // CAS failed: re-synchronize with queue updates made by
                // other threads before retrying.
                fence_acquire(&self.state);
                continue;
            }
            let new_tail = unsafe { (*queue_tail).prev.get() };
            if new_tail.is_null() {
                // The tail is the only queued thread: clear both the queue
                // head and the queue lock, preserving only the locked bit.
                loop {
                    match self.state.compare_exchange_weak(
                        state,
                        state & LOCKED_BIT,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(x) => state = x,
                    }
                    // If a new thread enqueued itself concurrently, restart
                    // the whole scan; otherwise just retry the CAS.
                    if state.queue_head().is_null() {
                        continue;
                    } else {
                        fence_acquire(&self.state);
                        continue 'outer;
                    }
                }
            } else {
                // Detach the tail by updating the head's cached tail
                // pointer, then release the queue lock.
                unsafe {
                    (*queue_head).queue_tail.set(new_tail);
                }
                self.state.fetch_and(!QUEUE_LOCKED_BIT, Ordering::Release);
            }
            // Wake the dequeued tail thread.
            unsafe {
                (*queue_tail).parker.unpark_lock().unpark();
            }
            break;
        }
    }
}
/// Issues an acquire barrier on `a`.
///
/// ThreadSanitizer does not model standalone fences, so when built with
/// `tsan_enabled` we perform an acquire load on the atomic instead, which
/// TSan does understand.
#[inline]
fn fence_acquire(a: &AtomicUsize) {
    if !cfg!(tsan_enabled) {
        fence(Ordering::Acquire);
    } else {
        let _ = a.load(Ordering::Acquire);
    }
}
/// Helpers for decoding and encoding a `WordLock` state word.
trait LockState {
    /// Whether `LOCKED_BIT` is set (the lock is held).
    fn is_locked(self) -> bool;
    /// Whether `QUEUE_LOCKED_BIT` is set (the wait queue is being edited).
    fn is_queue_locked(self) -> bool;
    /// The queue-head pointer stored in the upper bits (null if empty).
    fn queue_head(self) -> *const ThreadData;
    /// Returns this state with the queue-head bits replaced by `thread_data`.
    fn with_queue_head(self, thread_data: *const ThreadData) -> Self;
}
impl LockState for usize {
    /// True when the lock itself is held.
    #[inline]
    fn is_locked(self) -> bool {
        self & LOCKED_BIT == LOCKED_BIT
    }
    /// True when some thread holds the queue lock.
    #[inline]
    fn is_queue_locked(self) -> bool {
        self & QUEUE_LOCKED_BIT == QUEUE_LOCKED_BIT
    }
    /// Extracts the queue-head pointer from the pointer bits.
    #[inline]
    fn queue_head(self) -> *const ThreadData {
        (self & QUEUE_MASK) as *const ThreadData
    }
    /// Replaces the pointer bits with `thread_data`, keeping the flag bits.
    #[inline]
    fn with_queue_head(self, thread_data: *const ThreadData) -> Self {
        (self & !QUEUE_MASK) | thread_data as usize
    }
}