use std::cell::UnsafeCell;
use std::pin::{Pin, pin};
use std::ptr;
use std::sync::atomic::Ordering;
use saa::lock::Mode;
use saa::{Lock, Pager};
use sdd::{AtomicShared, Guard};
/// Holder for an optional [`Guard`] that can be created, inspected, and dropped through a
/// shared reference.
///
/// The [`Guard`] lives inside an [`UnsafeCell`], so every method on this type takes `&self` and
/// reaches the slot through a raw pointer.
#[derive(Debug, Default)]
pub(crate) struct AsyncGuard {
/// The guard slot; `None` until a guard is first requested and again after a reset.
guard: UnsafeCell<Option<Guard>>,
}
/// Wrapper around a [`Pager`] that is used to wait on a [`Lock`].
///
/// The pager is stored with a `'static` lifetime parameter; the code in this file rewrites that
/// lifetime with `transmute` when the value is constructed and when it is registered in a lock.
#[derive(Debug)]
pub(crate) struct AsyncWait {
/// The pager that is registered in a `Lock` and later polled.
pager: Pager<'static, Lock>,
}
/// Abstraction over the ways a caller can wait on a [`Lock`].
///
/// In this file it is implemented for `Pin<&mut AsyncWait>`, which only registers its pager in
/// the lock, and for `()`, which registers a temporary pager and polls it synchronously.
pub(crate) trait TryWait {
/// Attempts to wait on `lock`; what "waiting" means is defined by the implementation.
fn try_wait(&mut self, lock: &Lock);
}
impl AsyncGuard {
    /// Returns `true` if a [`Guard`] is currently stored in the slot.
    #[inline]
    pub(crate) fn has_guard(&self) -> bool {
        let slot = self.guard.get();
        // SAFETY: NOTE(review): assumes no mutable access to the slot is in progress while it is
        // being read - confirm against callers.
        unsafe { &*slot }.is_some()
    }

    /// Returns a reference to the stored [`Guard`], creating one with `Guard::new` if the slot
    /// is empty.
    #[inline]
    pub(crate) fn guard(&self) -> &Guard {
        // SAFETY: NOTE(review): assumes the slot is not accessed concurrently and that it is not
        // reset while a reference returned from here is still in use - confirm against callers.
        let slot = unsafe { &mut *self.guard.get() };
        slot.get_or_insert_with(Guard::new)
    }

    /// Empties the slot, dropping the stored [`Guard`] if there is one.
    #[inline]
    pub(crate) fn reset(&self) {
        let slot = self.guard.get();
        // SAFETY: NOTE(review): assumes no reference previously obtained from `guard` is still
        // alive - confirm against callers.
        unsafe {
            *slot = None;
        }
    }

    /// Loads `atomic_ptr` with the ordering `mo` under this guard.
    ///
    /// Returns the result of `as_ref` on the loaded pointer, i.e. `None` when the loaded pointer
    /// yields no reference. A [`Guard`] is created on demand if none is stored yet.
    #[inline]
    pub(crate) fn load<T>(&self, atomic_ptr: &AtomicShared<T>, mo: Ordering) -> Option<&T> {
        let guard = self.guard();
        atomic_ptr.load(mo, guard).as_ref()
    }

    /// Returns `true` if `atomic_ptr`, loaded with the ordering `mo`, currently refers to the
    /// very same object as `r`.
    ///
    /// The comparison is by address (`ptr::eq`), not by value.
    #[inline]
    pub(crate) fn check_ref<T>(&self, atomic_ptr: &AtomicShared<T>, r: &T, mo: Ordering) -> bool {
        self.load(atomic_ptr, mo)
            .is_some_and(|current| ptr::eq(current, r))
    }
}
// NOTE(review): `AsyncGuard` reaches its `UnsafeCell` without any synchronization visible in
// this file, so these implementations presumably rely on callers never using the same
// `AsyncGuard` from several threads at the same time - confirm against the call sites.
unsafe impl Send for AsyncGuard {}
unsafe impl Sync for AsyncGuard {}
impl AsyncWait {
#[inline]
pub async fn wait(self: &mut Pin<&mut Self>) {
let this = unsafe { ptr::read(self) };
let mut pinned_pager = unsafe { Pin::new_unchecked(&mut this.get_unchecked_mut().pager) };
let _result = pinned_pager.poll_async().await;
}
}
impl Default for AsyncWait {
    /// Creates an `AsyncWait` that wraps a default-constructed [`Pager`].
    #[inline]
    fn default() -> Self {
        // Only the lifetime parameter of the pager type is changed by the transmute; the value
        // itself is a freshly default-constructed pager.
        // SAFETY: NOTE(review): assumes a default pager holds no borrowed data, so treating its
        // lifetime as `'static` is harmless - confirm against the `Pager` implementation.
        let pager = unsafe {
            std::mem::transmute::<Pager<'_, Lock>, Pager<'static, Lock>>(Pager::default())
        };
        Self { pager }
    }
}
impl TryWait for Pin<&mut AsyncWait> {
    /// Registers the embedded pager in `lock` in `Mode::WaitExclusive`, passing `false` for the
    /// synchronous flag; the method itself does not poll the pager.
    ///
    /// The pinned `AsyncWait` is reborrowed with [`Pin::as_mut`] rather than being bitwise
    /// duplicated, so no second copy of the exclusive pinned reference is created.
    #[inline]
    fn try_wait(&mut self, lock: &Lock) {
        let mut pinned_pager = unsafe {
            // The stored pager is declared with a `'static` lifetime parameter; the transmute
            // rewrites it to the lifetime that `register_pager` expects for `lock`.
            // SAFETY: NOTE(review): assumes the pager never stays registered beyond the
            // lifetime of `lock` - confirm against the call sites.
            let pager_ref = std::mem::transmute::<&mut Pager<'static, Lock>, &mut Pager<Lock>>(
                &mut self.as_mut().get_unchecked_mut().pager,
            );
            // SAFETY: the pager is reached through a pinned `AsyncWait` and is re-pinned
            // without being moved.
            Pin::new_unchecked(pager_ref)
        };
        lock.register_pager(&mut pinned_pager, Mode::WaitExclusive, false);
    }
}
impl TryWait for () {
    /// Registers a temporary, stack-pinned pager in `lock` in `Mode::WaitExclusive`, passing
    /// `true` for the synchronous flag, and then polls it synchronously.
    ///
    /// The outcome of the poll is intentionally discarded.
    #[inline]
    fn try_wait(&mut self, lock: &Lock) {
        let mut local_pager = pin!(Pager::default());
        lock.register_pager(&mut local_pager, Mode::WaitExclusive, true);
        let _: Result<_, _> = local_pager.poll_sync();
    }
}