use crate::{flags::to_flags, flags::RundownFlags, guard::RundownGuard};
use lazy_init::Lazy;
use rsevents::{Awaitable, ManualResetEvent, State};
use std::{result::Result, sync::atomic::AtomicU64, sync::atomic::Ordering};
/// Errors returned by rundown-protection operations.
///
/// Fieldless and trivially copyable; also comparable so callers can
/// match or assert on the exact failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RundownError {
    /// Rundown is already in progress, so no new reference
    /// can be acquired until the object is re-initialized.
    RundownInProgress,
}

impl std::fmt::Display for RundownError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RundownError::RundownInProgress => write!(f, "rundown is in progress"),
        }
    }
}

// Implementing the standard error trait lets this type interoperate
// with `Box<dyn Error>`, `?`-conversions, and error-reporting crates.
impl std::error::Error for RundownError {}
/// Shared state for rundown protection of an object.
///
/// A single atomic word packs both the outstanding reference count and
/// the rundown state flags (the packing is defined by the `flags`
/// module), so acquire/release/rundown transitions each happen with one
/// compare-exchange.
#[derive(Default)]
pub struct RundownRef {
/// Packed reference count + rundown flag bits; see `flags::RundownFlags`.
ref_count: AtomicU64,
/// Completion event, signaled when the last reference is released while
/// rundown is in progress. Created lazily, and only if a waiter observes
/// active references (see `wait_for_rundown`).
event: Lazy<ManualResetEvent>,
}
// All atomic operations use sequential consistency — the simplest
// ordering to reason about for this lock-free state machine; no weaker
// ordering has been justified here.
const ORDERING_VAL: Ordering = Ordering::SeqCst;
impl RundownRef {
/// Creates a new `RundownRef` with a zero reference count and no
/// rundown flags set (via `Default`).
#[inline]
pub fn new() -> Self {
Self::default()
}
/// Resets this object so rundown protection can be used again.
///
/// # Panics
///
/// Panics unless a previous rundown has fully completed — i.e. the
/// state is no longer "pre-rundown" and no references remain active.
/// (Presumably `is_pre_rundown` means the rundown flag has not been
/// set yet — defined in the `flags` module.)
pub fn re_init(&self) {
let current = self.load_flags();
// Re-init is only legal once rundown finished: flag set, count zero.
if current.is_pre_rundown() || current.is_ref_active() {
panic!("Attempt to re-init before rundown is complete");
}
// Un-signal the completion event for reuse; it may never have been
// created if rundown completed without any waiter blocking.
if let Some(event) = self.event.get() {
event.reset();
}
// Clear both the count and the flag bits in a single store.
self.ref_count.store(0, ORDERING_VAL);
}
/// Attempts to acquire rundown protection, returning a `RundownGuard`
/// that releases the reference when dropped.
///
/// # Errors
///
/// Returns `RundownError::RundownInProgress` once rundown has started;
/// from then on acquisition fails until `re_init` is called.
pub fn try_acquire(&self) -> Result<RundownGuard<'_>, RundownError> {
let mut current = self.load_flags();
loop {
// Rundown already started: refuse new references.
if current.is_rundown_in_progress() {
return Err(RundownError::RundownInProgress);
}
let new_bits_with_ref = current.add_ref();
// Publish the incremented count; on contention the failed exchange
// hands back the freshly observed word and we retry from it.
match self.compare_exchange(current.bits(), new_bits_with_ref) {
Ok(_) => return Ok(RundownGuard::new(self)),
Err(new_current) => current = to_flags(new_current),
}
}
}
/// Releases one previously acquired reference.
///
/// If this releases the final reference while rundown is in progress,
/// the completion event is signaled to wake the thread blocked in
/// `wait_for_rundown`.
pub fn release(&self) {
let mut current = self.load_flags();
loop {
let bits_with_decrement = current.dec_ref();
match self.compare_exchange(current.bits(), bits_with_decrement) {
Ok(_) => {
// Keep the value *we* stored so the post-loop check reflects
// this thread's own transition, not a later one.
current = to_flags(bits_with_decrement);
break;
}
Err(new_current) => current = to_flags(new_current),
}
}
// Only the thread whose decrement hit zero during rundown signals
// completion. The event must exist: wait_for_rundown creates it
// before publishing the rundown flag while references are active.
if current.is_ref_zero() && current.is_rundown_in_progress() {
let event = self.event.get().expect("Must have been set");
event.set();
}
}
/// Marks the object as running down and blocks until all outstanding
/// references are released. Returns immediately if none are held.
pub fn wait_for_rundown(&self) {
let mut current = self.load_flags();
loop {
// Create the completion event *before* publishing the rundown
// flag, so release() can rely on it existing whenever it sees
// the flag set with a zero count. Harmless if the CAS below
// fails or the refs drain first — the event is just unused.
if current.is_ref_active() {
self.event
.get_or_create(|| ManualResetEvent::new(State::Unset));
}
let bits_with_rundown = current.set_rundown_in_progress();
match self.compare_exchange(current.bits(), bits_with_rundown) {
Ok(_) => {
// Record the word we stored: the ref-count portion at the
// instant the flag was published decides whether to wait.
current = to_flags(bits_with_rundown);
break;
}
Err(new_current) => current = to_flags(new_current),
}
}
// References were still active when the flag went up: block until
// the final release() signals the event.
if current.is_ref_active() {
let event = self.event.get().expect("Must have been set");
event.wait();
}
}
/// Loads the packed count+flag word and wraps it as `RundownFlags`.
#[inline]
fn load_flags(&self) -> RundownFlags {
to_flags(self.ref_count.load(ORDERING_VAL))
}
/// Thin wrapper over `AtomicU64::compare_exchange` with the module-wide
/// ordering; on failure the `Err` carries the freshly observed value.
#[inline]
fn compare_exchange(&self, current: u64, new: u64) -> Result<u64, u64> {
self.ref_count
.compare_exchange(current, new, ORDERING_VAL, ORDERING_VAL)
}
}
#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use std::thread;
#[test]
#[allow(clippy::result_unwrap_used)]
fn test_wait_when_protected() {
    // Hold a reference so rundown cannot complete immediately.
    let rundown = Arc::new(RundownRef::new());
    let guard = rundown.try_acquire().unwrap();

    // Start a waiter thread that must block until the guard is dropped.
    let waiter_ref = Arc::clone(&rundown);
    let waiter = thread::spawn(move || waiter_ref.wait_for_rundown());

    // Spin until the waiter has published the rundown-in-progress flag.
    while rundown.load_flags().is_pre_rundown() {
        thread::yield_now();
    }

    // Releasing the final reference unblocks the waiter.
    drop(guard);
    waiter.join().unwrap();

    // With rundown complete, the object is reusable.
    rundown.re_init();
}