#![cfg_attr(feature = "_shuttle", doc = "```ignore")]
#![cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
#![cfg_attr(feature = "_shuttle", doc = "```ignore")]
#![cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
use std::marker::PhantomData;
#[cfg(feature = "_shuttle")]
pub(crate) use shuttle::sync::{
atomic::{AtomicPtr, Ordering},
Once,
};
#[cfg(not(feature = "_shuttle"))]
pub(crate) use std::sync::{
atomic::{AtomicPtr, Ordering},
Once,
};
/// A thread-safe cell that can be filled at most once and emptied at most once.
///
/// [`TakeOnce::store`] succeeds only for the first caller ever; later stores hand
/// their value back. [`TakeOnce::take`] moves the stored value out, and across all
/// threads at most one call observes it. The value lives in a heap allocation that
/// is released either by `take` or, if it was never taken, when the cell is dropped.
// NOTE(review): the `cfg_attr` doc attributes below open a code fence
// (```ignore under `_shuttle`, ```rust otherwise) but no example body or closing
// fence follows, so the rendered docs / doctest look truncated — confirm whether
// the example text was lost.
#[cfg_attr(feature = "_shuttle", doc = "```ignore")]
#[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
#[cfg_attr(feature = "_shuttle", doc = "```ignore")]
#[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
#[derive(Debug)]
pub struct TakeOnce<T> {
// Guards the single initialisation performed by `store`.
once: Once,
// Pointer to the boxed value: null before `store` runs and after `take` claims it.
value: AtomicPtr<T>,
// Records that the cell logically owns a `T` (the allocation behind `value`).
_marker: PhantomData<T>,
}
impl<T> TakeOnce<T> {
    /// Creates an empty cell.
    ///
    /// Because this is a `const fn`, it can be used to initialise a `static`.
    #[inline]
    #[must_use]
    pub const fn new() -> TakeOnce<T> {
        let empty: *mut T = std::ptr::null_mut();
        TakeOnce {
            once: Once::new(),
            value: AtomicPtr::new(empty),
            _marker: PhantomData,
        }
    }

    /// Creates a cell that already holds `val`.
    ///
    /// This is [`TakeOnce::new`] followed by [`TakeOnce::store`]. The store cannot
    /// fail on a freshly created cell, so its result is discarded.
    // NOTE(review): the two `cfg_attr` doc lines below open a code fence with no
    // example body or closing fence — confirm whether an example was lost.
    #[cfg_attr(feature = "_shuttle", doc = "```ignore")]
    #[cfg_attr(not(feature = "_shuttle"), doc = "```rust")]
    #[must_use]
    pub fn new_with(val: T) -> TakeOnce<T> {
        let fresh = Self::new();
        // A brand-new cell has never run its `Once`, so this store always succeeds.
        let _ = fresh.store(val);
        fresh
    }

    /// Tries to place `val` in the cell.
    ///
    /// Only the first call ever made on a cell performs the store. A concurrent
    /// caller waits until that first store has finished and then gets its value back.
    ///
    /// # Errors
    ///
    /// Returns `Err(val)`, handing the value back unchanged, if the cell has already
    /// been initialised. This also applies after the earlier value has been taken.
    #[inline]
    pub fn store(&self, val: T) -> Result<(), T> {
        let mut pending = Some(val);
        self.once.call_once(|| {
            // The closure runs at most once, so `pending` is still `Some` here.
            if let Some(inner) = pending.take() {
                let raw = Box::into_raw(Box::new(inner));
                // Release pairs with the Acquire swap in `take` / `drop`.
                self.value.store(raw, Ordering::Release);
            }
        });
        // If the closure did not run, the value was never consumed: hand it back.
        match pending {
            None => Ok(()),
            Some(rejected) => Err(rejected),
        }
    }

    /// Removes and returns the stored value, if there is one.
    ///
    /// Returns `None` in three cases: nothing has been stored yet, a store is still
    /// in progress, or the value has already been taken. Across all threads, at most
    /// one call returns `Some`.
    #[inline]
    #[must_use]
    pub fn take(&self) -> Option<T> {
        if !self.once.is_completed() {
            return None;
        }
        // Swapping in null atomically claims the pointer, so only one caller wins.
        let claimed = self.value.swap(std::ptr::null_mut(), Ordering::Acquire);
        if claimed.is_null() {
            return None;
        }
        // SAFETY: every non-null pointer in `value` comes from `Box::into_raw` in
        // `store`. The swap above removed it from the cell, so this call is its sole
        // owner and it is reconstructed exactly once.
        let boxed = unsafe { Box::from_raw(claimed) };
        Some(*boxed)
    }

    /// Reports whether a value has ever been stored in this cell.
    ///
    /// This stays `true` after the value has been taken, so it does not mean a value
    /// is currently present.
    #[inline]
    #[must_use]
    pub fn is_completed(&self) -> bool {
        let guard = &self.once;
        guard.is_completed()
    }
}
impl<T> Default for TakeOnce<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for TakeOnce<T> {
    /// Frees a value that was stored but never taken.
    fn drop(&mut self) {
        if !self.once.is_completed() {
            // `store` never ran, so there is no allocation to release.
            return;
        }
        let leftover = self.value.swap(std::ptr::null_mut(), Ordering::Acquire);
        if leftover.is_null() {
            // The value was already moved out by `take`.
            return;
        }
        // SAFETY: a non-null pointer in `value` was produced by `Box::into_raw` in
        // `store` and has not been reclaimed by `take` (which nulls it out). Swapping
        // in null gives this call sole ownership.
        let boxed = unsafe { Box::from_raw(leftover) };
        drop(boxed);
    }
}
// SAFETY: a `TakeOnce<T>` owns at most one heap-allocated `T`. Sending the cell to
// another thread moves that `T` with it, so `T: Send` is required and sufficient.
unsafe impl<T> Send for TakeOnce<T> where T: Send {}
// SAFETY: through `&TakeOnce<T>` a thread can only move a `T` in (`store`) or out
// (`take`). No `&T` is ever handed out, so sharing the cell needs `T: Send` but not
// `T: Sync`, the same reasoning as for `std::sync::Mutex`.
unsafe impl<T> Sync for TakeOnce<T> where T: Send {}
#[cfg(test)]
mod tests {
    // Randomized (shuttle) concurrency tests for `TakeOnce`.
    //
    // NOTE(review): these tests always drive shuttle threads, but `TakeOnce` only
    // uses shuttle's `Once`/`AtomicPtr` when the `_shuttle` feature is enabled.
    // Without it, shuttle likely cannot interleave inside the std primitives, so
    // consider gating this module on `feature = "_shuttle"`. Confirm against the
    // crate's Cargo features.
    use super::TakeOnce;
    use shuttle::sync::Arc;
    use shuttle::thread;

    /// Number of random schedules explored per test.
    const SHUTTLE_ITERS: usize = 10_000;

    /// Exactly one concurrent `store` wins. Every loser gets its own value back,
    /// and the winner's value is the one left in the cell.
    #[test]
    fn concurrent_store_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let num_threads = 6;
                let threads: Vec<_> = (0..num_threads)
                    .map(|i| {
                        let once_take = Arc::clone(&once_take);
                        thread::spawn(move || once_take.store(i))
                    })
                    .collect();
                let results: Vec<_> = threads.into_iter().map(|t| t.join().unwrap()).collect();
                assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 1);
                assert_eq!(
                    results.iter().filter(|r| r.is_err()).count(),
                    num_threads - 1
                );
                // A rejected store must hand back exactly the value it was given.
                for (i, result) in results.iter().enumerate() {
                    if let Err(rejected) = result {
                        assert_eq!(*rejected, i);
                    }
                }
                let winner = results
                    .iter()
                    .position(Result::is_ok)
                    .expect("exactly one store succeeds");
                assert_eq!(once_take.take(), Some(winner));
            },
            SHUTTLE_ITERS,
        );
    }

    /// After a completed store, exactly one of several concurrent takers gets the value.
    #[test]
    fn concurrent_take_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                assert_eq!(once_take.store(42), Ok(()));
                let threads: Vec<_> = (0..3)
                    .map(|_| {
                        let once_take = Arc::clone(&once_take);
                        thread::spawn(move || once_take.take())
                    })
                    .collect();
                let results: Vec<_> = threads.into_iter().map(|t| t.join().unwrap()).collect();
                assert_eq!(results.iter().filter(|r| r.is_some()).count(), 1);
                assert!(results.iter().any(|r| r == &Some(42)));
                assert_eq!(results.iter().filter(|r| r.is_none()).count(), 2);
            },
            SHUTTLE_ITERS,
        );
    }

    /// Interleaved stores (even indices) and takes (odd indices). Exactly one store
    /// wins, at most one take succeeds, and an unclaimed value stays in the cell.
    #[test]
    fn mixed_store_take_operations() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let num_threads = 6;
                let threads: Vec<_> = (0..num_threads)
                    .map(|i| {
                        let once_take = Arc::clone(&once_take);
                        thread::spawn(move || {
                            if i % 2 == 0 {
                                once_take.store(i)
                            } else {
                                once_take.take().map_or(Err(i), |_| Ok(()))
                            }
                        })
                    })
                    .collect();
                let results = threads
                    .into_iter()
                    .map(|t| t.join().unwrap())
                    .collect::<Vec<_>>();
                let stores_ok = results.iter().step_by(2).filter(|r| r.is_ok()).count();
                let takes_ok = results.iter().skip(1).step_by(2).filter(|r| r.is_ok()).count();
                assert_eq!(stores_ok, 1);
                assert!(takes_ok <= 1);
                // If no taker claimed the value, it must still be in the cell.
                assert_eq!(once_take.take().is_some(), takes_ok == 0);
            },
            SHUTTLE_ITERS,
        );
    }

    /// Completion becomes visible to the storing thread and to a joining thread.
    /// Taking the value does not reset it.
    #[test]
    fn completion_status_consistency() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                assert!(!once_take.is_completed());
                let writer = Arc::clone(&once_take);
                let t1 = thread::spawn(move || {
                    writer.store(42).unwrap();
                    writer.is_completed()
                });
                assert!(t1.join().unwrap());
                assert!(once_take.is_completed());
                assert_eq!(once_take.take(), Some(42));
                assert!(once_take.is_completed());
            },
            SHUTTLE_ITERS,
        );
    }

    /// Readers that observe completion must be able to see the stored value. Across
    /// both readers and the main thread, the value is claimed exactly once.
    #[test]
    fn store_take_ordering() {
        // Takes only after observing completion. At that point the value must be
        // visible unless another reader already claimed it.
        fn try_take(cell: &TakeOnce<i32>) -> Option<i32> {
            if cell.is_completed() {
                cell.take()
            } else {
                None
            }
        }

        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let writer = Arc::clone(&once_take);
                let reader_a = Arc::clone(&once_take);
                let reader_b = Arc::clone(&once_take);
                let t1 = thread::spawn(move || writer.store(42));
                let t2 = thread::spawn(move || try_take(&reader_a));
                let t3 = thread::spawn(move || try_take(&reader_b));
                assert_eq!(t1.join().unwrap(), Ok(()));
                let taken: Vec<i32> = t2
                    .join()
                    .unwrap()
                    .into_iter()
                    .chain(t3.join().unwrap())
                    .collect();
                assert!(taken.len() <= 1);
                assert!(taken.iter().all(|&v| v == 42));
                let remaining = once_take.take();
                assert!(remaining.iter().all(|&v| v == 42));
                assert_eq!(taken.len() + usize::from(remaining.is_some()), 1);
            },
            SHUTTLE_ITERS,
        );
    }

    /// A value taken from the cell is dropped exactly once, by its new owner.
    #[test]
    fn drop_consistency() {
        shuttle::check_random(
            || {
                let once_take = Arc::new(TakeOnce::new());
                let once_take2 = Arc::clone(&once_take);
                #[derive(Debug, PartialEq)]
                struct DropTest(i32);
                static DROPPED: shuttle::sync::Once = shuttle::sync::Once::new();
                impl Drop for DropTest {
                    fn drop(&mut self) {
                        // Fails if this is not the first drop in this execution.
                        let mut called = false;
                        DROPPED.call_once(|| called = true);
                        assert!(called);
                    }
                }
                once_take.store(DropTest(42)).unwrap();
                let t = thread::spawn(move || once_take2.take());
                let val = t
                    .join()
                    .unwrap()
                    .expect("value was stored before the taker was spawned");
                assert_eq!(val.0, 42);
                drop(val);
                assert!(DROPPED.is_completed());
            },
            SHUTTLE_ITERS,
        );
    }

    /// Dropping the cell releases a value that was stored but never taken.
    #[test]
    fn drop_releases_untaken_value() {
        use std::cell::Cell;
        use std::rc::Rc;

        shuttle::check_random(
            || {
                struct Counted(Rc<Cell<usize>>);
                impl Drop for Counted {
                    fn drop(&mut self) {
                        self.0.set(self.0.get() + 1);
                    }
                }
                let drops = Rc::new(Cell::new(0));
                let cell = TakeOnce::new();
                assert!(cell.store(Counted(Rc::clone(&drops))).is_ok());
                assert_eq!(drops.get(), 0);
                drop(cell);
                assert_eq!(drops.get(), 1);
            },
            SHUTTLE_ITERS,
        );
    }
}