use crate::sync::atomic::AtomicPtr;
use crate::{record::HazPtrRecord, Domain};
use core::marker::PhantomData;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
#[cfg(doc)]
use crate::*;
/// A single hazard pointer: one hazard record acquired from a [`Domain`].
///
/// The record holds the address this hazard pointer currently protects (set through the
/// `protect*` / `try_protect*` methods). On drop, the record is reset and handed back to
/// `domain`.
pub struct HazardPointer<'domain, F = crate::Global> {
    /// The domain the record was acquired from and is released back to on drop.
    pub(crate) domain: &'domain Domain<F>,
    /// The record through which protected addresses are published.
    hazard: &'domain HazPtrRecord,
}
impl Default for HazardPointer<'static, crate::Global> {
fn default() -> Self {
Self::new()
}
}
impl HazardPointer<'static, crate::Global> {
    /// Acquires one hazard pointer from the global domain ([`Domain::global`]).
    pub fn new() -> Self {
        Self::new_in_domain(Domain::global())
    }

    /// Acquires `N` hazard pointers from the global domain in one go.
    pub fn many<const N: usize>() -> HazardPointerArray<'static, crate::Global, N> {
        Self::many_in_domain::<N>(Domain::global())
    }
}
impl<'domain, F> HazardPointer<'domain, F> {
pub fn new_in_domain(domain: &'domain Domain<F>) -> Self {
Self {
hazard: domain.acquire(),
domain,
}
}
pub fn many_in_domain<const N: usize>(
domain: &'domain Domain<F>,
) -> HazardPointerArray<'domain, F, N> {
let haz_ptrs = domain
.acquire_many::<N>()
.map(|hazard| ManuallyDrop::new(HazardPointer { hazard, domain }));
HazardPointerArray { haz_ptrs }
}
pub unsafe fn protect<'l, T>(&'l mut self, src: &'_ AtomicPtr<T>) -> Option<&'l T>
where
T: Sync,
F: 'static,
{
let (ptr, _proof): (_, PhantomData<&'l T>) = self.protect_ptr(src)?;
Some(unsafe { ptr.as_ref() })
}
pub fn protect_ptr<'l, T>(
&'l mut self,
src: &'_ AtomicPtr<T>,
) -> Option<(NonNull<T>, PhantomData<&'l T>)>
where
F: 'static,
{
let mut ptr = src.load(Ordering::Relaxed);
loop {
match self.try_protect_ptr(ptr, src) {
Ok(None) => break None,
Ok(Some((ptr, _h))) => {
break Some((ptr, PhantomData));
}
Err(ptr2) => {
ptr = ptr2;
}
}
}
}
pub unsafe fn try_protect<'l, T>(
&'l mut self,
ptr: *mut T,
src: &'_ AtomicPtr<T>,
) -> Result<Option<&'l T>, *mut T>
where
T: Sync,
F: 'static,
{
if ptr.is_null() {
return Ok(None);
}
let ptr: Option<(_, PhantomData<&'l T>)> = self.try_protect_ptr(ptr, src)?;
let (ptr, _) = ptr.expect("ptr was not null, but try_protect_ptr returned null");
Ok(Some(unsafe { ptr.as_ref() }))
}
#[allow(clippy::type_complexity)]
pub fn try_protect_ptr<'l, T>(
&'l mut self,
ptr: *mut T,
src: &'_ AtomicPtr<T>,
) -> Result<Option<(NonNull<T>, PhantomData<&'l T>)>, *mut T>
where
F: 'static,
{
self.hazard.protect(ptr as *mut u8);
crate::asymmetric_light_barrier();
let ptr2 = src.load(Ordering::Acquire);
if ptr != ptr2 {
self.hazard.reset();
Err(ptr2)
} else {
Ok(core::ptr::NonNull::new(ptr).map(|ptr| (ptr, PhantomData)))
}
}
pub fn reset_protection(&mut self) {
self.hazard.reset();
}
pub fn protect_raw<T>(&mut self, ptr: *mut T)
where
F: 'static,
{
self.hazard.protect(ptr as *mut u8);
}
}
impl<F> Drop for HazardPointer<'_, F> {
    /// Clears whatever this hazard pointer protects, then gives its record back to the
    /// domain it was acquired from.
    fn drop(&mut self) {
        let record = self.hazard;
        record.reset();
        self.domain.release(record);
    }
}
/// A fixed-size group of `N` hazard pointers acquired together from one domain.
///
/// Created by [`HazardPointer::many`] or [`HazardPointer::many_in_domain`]. When dropped,
/// every element's protection is reset and the records are released through a single
/// `release_many` call.
pub struct HazardPointerArray<'domain, F, const N: usize> {
    /// Wrapped in `ManuallyDrop` so each element's own `Drop` (individual release) never
    /// runs; the array's `Drop` impl releases all records in bulk instead.
    haz_ptrs: [ManuallyDrop<HazardPointer<'domain, F>>; N],
}
impl<const N: usize> Default for HazardPointerArray<'static, crate::Global, N> {
    /// Same as [`HazardPointer::many`]: `N` hazard pointers from the global domain.
    fn default() -> Self {
        HazardPointer::<'static, crate::Global>::many::<N>()
    }
}
impl<'domain, F, const N: usize> HazardPointerArray<'domain, F, N> {
    /// Borrows each hazard pointer in the array individually, so that all of them can be
    /// used at the same time (for example, each protecting a different source).
    pub fn as_refs<'array>(&'array mut self) -> [&'array mut HazardPointer<'domain, F>; N] {
        let mut slots: [MaybeUninit<&'array mut HazardPointer<'domain, F>>; N] =
            [(); N].map(|()| MaybeUninit::uninit());
        for (slot, hazptr) in slots.iter_mut().zip(self.haz_ptrs.iter_mut()) {
            slot.write(hazptr);
        }
        // SAFETY: `slots` and `haz_ptrs` both have exactly `N` elements, so the zip above
        // visited, and initialized, every slot.
        slots.map(|slot| unsafe { slot.assume_init() })
    }

    /// Protects the current value of `sources[i]` with the `i`-th hazard pointer, for every
    /// `i`, and returns the resulting references (`None` where a source holds null).
    ///
    /// # Safety
    ///
    /// Each element must satisfy the requirements of [`HazardPointer::protect`] for its
    /// corresponding source.
    pub unsafe fn protect_all<'l, T>(
        &'l mut self,
        sources: [&'_ AtomicPtr<T>; N],
    ) -> [Option<&'l T>; N]
    where
        T: Sync,
        F: 'static,
    {
        let mut protected = [None; N];
        let pairs = self.haz_ptrs.iter_mut().zip(sources.iter());
        for (result, (hazptr, src)) in protected.iter_mut().zip(pairs) {
            // SAFETY: forwarded from this function's own `# Safety` contract.
            *result = unsafe { hazptr.protect(src) };
        }
        protected
    }

    /// Clears the protection held by every hazard pointer in the array.
    pub fn reset_protection(&mut self) {
        self.haz_ptrs
            .iter_mut()
            .for_each(|hazptr| hazptr.reset_protection());
    }
}
impl<'domain, F, const N: usize> Drop for HazardPointerArray<'domain, F, N> {
    /// Clears every element's protection, then hands all `N` records back to their domain
    /// in a single `release_many` call.
    ///
    /// An empty array (`N == 0`, e.g. from `HazardPointer::many::<0>()`) holds no records
    /// and has no element to read the domain from. It returns early instead of panicking
    /// on an out-of-bounds index inside `drop`.
    fn drop(&mut self) {
        self.reset_protection();
        let domain = match self.haz_ptrs.first() {
            Some(first) => first.domain,
            // N == 0: no records are held, so there is nothing to release.
            None => return,
        };
        let records = self.as_refs().map(|hazptr| hazptr.hazard);
        domain.release_many(records);
    }
}