use std::cell::UnsafeCell;
use std::marker::PhantomPinned;
use std::mem::ManuallyDrop;
use std::pin::{Pin, pin};
use std::ptr;
use std::sync::atomic::Ordering;
use saa::lock::Mode as LockMode;
use saa::{Gate, Lock, Pager};
use sdd::{AtomicShared, Guard};
use crate::exit_guard::ExitGuard;
/// Storage for the `Pager<'static, Lock>` that `SendableGuard::wait_acquire` registers in a
/// `Lock` and then awaits.
///
/// The pager sits inside `ManuallyDrop`, so it is not dropped automatically together with the
/// `AsyncPager`; the `Drop` implementation of `SendableGuard` drops it explicitly instead.
#[derive(Debug, Default)]
pub(crate) struct AsyncPager {
/// Interior-mutable pager slot; `SendableGuard` reaches it through `UnsafeCell::get`.
pager: UnsafeCell<ManuallyDrop<Pager<'static, Lock>>>,
/// Makes the type `!Unpin`; `wait_acquire` pins `pager` in place via `Pin::new_unchecked`.
_pinned: PhantomPinned,
}
/// Couples an optional, lazily created `Guard` with a reference to an `AsyncPager`.
///
/// `Send` and `Sync` are implemented for this type with `unsafe impl` further down in this
/// file, and its `Drop` implementation drops the pager stored in the referenced `AsyncPager`.
#[derive(Debug)]
pub(crate) struct SendableGuard {
/// `None` until `guard()` creates a `Guard`; set back to `None` by `reset()`.
guard: UnsafeCell<Option<Guard>>,
/// Reference whose lifetime was erased to `'static` in `new`.
// NOTE(review): the referenced `AsyncPager` must outlive this value; nothing here enforces it.
pager: &'static AsyncPager,
}
/// Wrapper around a `Gate` that callers wait on, either synchronously (`wait_sync`) or by
/// registering an `AsyncWait` (`push_async_entry`), and that `signal` permits.
#[derive(Debug, Default)]
pub(crate) struct WaitQueue {
/// Gate in which pagers are registered and on which `signal` calls `permit`.
gate: Gate,
}
/// Holds the pager that `WaitQueue::push_async_entry` registers in a `Gate` and that `wait`
/// later polls asynchronously.
#[derive(Debug)]
pub(crate) struct AsyncWait {
/// Declared with a `'static` lifetime; `push_async_entry` transmutes it to the lifetime of the
/// `WaitQueue` borrow before registering it.
pager: Pager<'static, Gate>,
}
/// Abstraction over "an `AsyncWait` may or may not be available".
///
/// In this file it is implemented for `Pin<&mut AsyncWait>` (always yields `Some`) and for `()`
/// (always yields `None`).
pub(crate) trait DeriveAsyncWait {
/// Returns a mutable reference to the underlying `AsyncWait`, or `None` if there is none.
fn derive(&mut self) -> Option<&mut AsyncWait>;
}
impl SendableGuard {
    /// Creates a `SendableGuard` without a `Guard`, bound to `async_pager`.
    ///
    /// The lifetime of the borrow is erased to `'static`, so the compiler no longer checks that
    /// `async_pager` outlives the returned value.
    pub(crate) fn new(async_pager: &AsyncPager) -> Self {
        // SAFETY: NOTE(review) - only sound if every caller keeps `async_pager` alive (and in
        // place) for as long as the returned value exists; this cannot be verified from here.
        let pager =
            unsafe { std::mem::transmute::<&AsyncPager, &'static AsyncPager>(async_pager) };
        Self {
            guard: UnsafeCell::new(None),
            pager,
        }
    }

    /// Returns `true` if a `Guard` is currently stored.
    #[inline]
    pub(crate) fn has_guard(&self) -> bool {
        // SAFETY: NOTE(review) - relies on no `&mut` to the slot being live at the same time.
        let slot = unsafe { &*self.guard.get() };
        slot.is_some()
    }

    /// Registers the shared pager in `lock` and awaits it.
    ///
    /// The pager is registered in exclusive mode when `writer` is `true` and in shared mode
    /// otherwise. The stored `Guard` is dropped via `reset` before the await point. Returns the
    /// boolean produced by the pager, or `false` when polling yields an error.
    ///
    /// An `ExitGuard` armed with `true` is created first and disarmed only after the await has
    /// completed; while it is still armed its callback calls `guard()`, re-creating the `Guard`.
    // NOTE(review): presumably `ExitGuard` invokes the callback on drop, which would cover a
    // future that is dropped mid-await - confirm against `crate::exit_guard`.
    pub(crate) async fn wait_acquire<'l>(&self, lock: &'l Lock, writer: bool) -> bool {
        let mut rearm_guard = ExitGuard::new(true, |still_armed| {
            if still_armed {
                self.guard();
            }
        });

        let lock_mode = if writer {
            LockMode::Exclusive
        } else {
            LockMode::Shared
        };

        // SAFETY: NOTE(review) - shortens the pager's `'static` parameter to `'l` and pins the
        // pager in place; relies on the `AsyncPager` not moving and on exclusive access to the
        // slot for the duration of this call.
        let mut pinned_pager = unsafe {
            Pin::new_unchecked(std::mem::transmute::<
                &mut Pager<'static, Lock>,
                &mut Pager<'l, Lock>,
            >(&mut **self.pager.pager.get()))
        };
        lock.register_pager(&mut pinned_pager, lock_mode, false);

        // The `Guard` must be gone before the task can be suspended.
        self.reset();
        let acquired = matches!(pinned_pager.poll_async().await, Ok(true));

        // Completed normally: the exit callback must not re-create the guard.
        *rearm_guard = false;
        acquired
    }

    /// Returns the stored `Guard`, creating it with `Guard::new` if none is stored yet.
    #[inline]
    pub(crate) fn guard(&self) -> &Guard {
        // SAFETY: NOTE(review) - relies on no other reference into the slot being live.
        let slot = unsafe { &mut *self.guard.get() };
        slot.get_or_insert_with(Guard::new)
    }

    /// Drops the stored `Guard`, if any, leaving the slot empty.
    #[inline]
    pub(crate) fn reset(&self) {
        // SAFETY: NOTE(review) - relies on no reference obtained from `guard()` still being used.
        let slot = unsafe { &mut *self.guard.get() };
        drop(slot.take());
    }

    /// Loads `atomic_ptr` with ordering `mo` under this guard and returns the referenced value,
    /// or `None` if the loaded pointer does not reference anything.
    #[inline]
    pub(crate) fn load<T>(&self, atomic_ptr: &AtomicShared<T>, mo: Ordering) -> Option<&T> {
        let loaded = atomic_ptr.load(mo, self.guard());
        loaded.as_ref()
    }

    /// Returns `true` if `atomic_ptr`, loaded with ordering `mo`, currently points at exactly
    /// the object `r` refers to (address comparison, not value comparison).
    #[inline]
    pub(crate) fn check_ref<T>(&self, atomic_ptr: &AtomicShared<T>, r: &T, mo: Ordering) -> bool {
        match atomic_ptr.load(mo, self.guard()).as_ref() {
            Some(current) => ptr::eq(current, r),
            None => false,
        }
    }
}
impl Drop for SendableGuard {
    /// Drops the pager stored inside the referenced `AsyncPager`.
    ///
    /// The pager is wrapped in `ManuallyDrop`, so this is where its destructor runs.
    fn drop(&mut self) {
        let pager_slot = self.pager.pager.get();
        // SAFETY: NOTE(review) - requires that the referenced `AsyncPager` is still alive and
        // that its pager is dropped at most once, i.e. that no second `SendableGuard` is ever
        // created for the same `AsyncPager`; confirm against the callers.
        unsafe { ManuallyDrop::drop(&mut *pager_slot) };
    }
}
// SAFETY: NOTE(review) - `SendableGuard` stores an `Option<Guard>` in an `UnsafeCell` and is
// therefore not automatically `Send`/`Sync`. These impls presumably rely on callers clearing
// the guard (as `wait_acquire` does via `reset`) before the value can cross threads, and on
// the value never being accessed concurrently - confirm against the callers.
unsafe impl Send for SendableGuard {}
unsafe impl Sync for SendableGuard {}
impl WaitQueue {
    /// Runs `f` after registering a stack-pinned pager in the gate, then waits on that pager.
    ///
    /// If `f` succeeds the gate is signalled before waiting. The result of `f` is returned
    /// unchanged; the outcome of the wait itself is ignored. Under Miri `f` is simply run
    /// without registering or waiting.
    #[inline]
    pub(crate) fn wait_sync<T, F: FnOnce() -> Result<T, ()>>(&self, f: F) -> Result<T, ()> {
        if cfg!(miri) {
            return f();
        }

        let mut waiter = pin!(Pager::default());
        // `true` is passed here and `false` in `push_async_entry`.
        // NOTE(review): presumably the flag selects synchronous waiting - confirm in `saa`.
        self.gate.register_pager(&mut waiter, true);

        let outcome = f();
        if outcome.is_ok() {
            self.signal();
        }
        let _: Result<_, _> = waiter.poll_sync();
        outcome
    }

    /// Registers the pager of `async_wait` in the gate and then runs `f`.
    ///
    /// Returns `Ok` with the value produced by `f` only if `f` succeeds and, after signalling
    /// the gate, `try_poll` on the pager reports `Ok`; in that case the pager is replaced with a
    /// fresh default one. In every other case `Err(())` is returned and any value produced by
    /// `f` is dropped.
    #[inline]
    pub(crate) fn push_async_entry<'w, T, F: FnOnce() -> Result<T, ()>>(
        &'w self,
        async_wait: &'w mut AsyncWait,
        f: F,
    ) -> Result<T, ()> {
        // SAFETY: NOTE(review) - shortens the pager's `'static` parameter to `'w` and pins it in
        // place; relies on `async_wait` itself being pinned by the caller.
        let mut pinned_pager = unsafe {
            Pin::new_unchecked(std::mem::transmute::<
                &mut Pager<'static, Gate>,
                &mut Pager<'w, Gate>,
            >(&mut async_wait.pager))
        };
        self.gate.register_pager(&mut pinned_pager, false);

        let Ok(result) = f() else {
            return Err(());
        };
        self.signal();
        if pinned_pager.try_poll().is_err() {
            return Err(());
        }
        // The pager has already completed: hand back a pristine one for the next use.
        async_wait.pager = Pager::default();
        Ok(result)
    }

    /// Calls `permit` on the gate, ignoring the outcome.
    #[inline]
    pub(crate) fn signal(&self) {
        let _: Result<_, _> = self.gate.permit();
    }
}
impl AsyncWait {
#[inline]
pub async fn wait(self: &mut Pin<&mut Self>) {
let this = unsafe { ptr::read(self) };
let mut pinned_pager = unsafe { Pin::new_unchecked(&mut this.get_unchecked_mut().pager) };
let _result = pinned_pager.poll_async().await;
}
}
impl Default for AsyncWait {
    /// Creates an `AsyncWait` holding a default-constructed pager.
    #[inline]
    fn default() -> Self {
        // The field type fixes the pager's lifetime parameter to `'static`, so a plain
        // `Pager::default()` is sufficient - `WaitQueue::push_async_entry` assigns one to this
        // very field the same way - and no lifetime `transmute` is required.
        Self {
            pager: Pager::default(),
        }
    }
}
impl DeriveAsyncWait for Pin<&mut AsyncWait> {
    /// Always returns `Some` with a mutable reference to the pinned `AsyncWait`.
    ///
    /// The returned reference borrows from `self`, so the pin cannot be used again until the
    /// reference is gone.
    #[inline]
    fn derive(&mut self) -> Option<&mut AsyncWait> {
        // Reborrow through `Pin::as_mut` rather than bitwise-copying the `Pin<&mut AsyncWait>`.
        //
        // SAFETY: NOTE(review) - the caller receives `&mut AsyncWait` and must not move the value
        // out of it; in this file it is only handed to `WaitQueue::push_async_entry`, which
        // re-pins the contained pager in place.
        Some(unsafe { self.as_mut().get_unchecked_mut() })
    }
}
/// The unit type stands for "no `AsyncWait` available".
impl DeriveAsyncWait for () {
/// Always returns `None`.
#[inline]
fn derive(&mut self) -> Option<&mut AsyncWait> {
None
}
}