use std::fmt;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, Ordering};
/// A write-once, lock-free slot holding an optional [`Arc<T>`].
///
/// The slot starts empty (or pre-filled via `From`). The first successful
/// [`store`](Self::store) installs a value; every later `store` is rejected and
/// hands the offered `Arc` back to the caller. An installed value is never
/// replaced or removed while the slot is shared, which is what lets
/// [`get`](Self::get) return plain `&T` borrows tied to `&self`.
///
/// Internally the slot owns exactly one strong reference, kept as the raw
/// pointer produced by [`Arc::into_raw`]; a null pointer means "empty".
pub struct OnceArc<T> {
    ptr: AtomicPtr<T>,
}
// SAFETY: the slot logically owns an `Arc<T>`, and `Arc<T>` is `Send` only when
// `T: Send + Sync`; this mirrors that bound.
unsafe impl<T: Send + Sync> Send for OnceArc<T> {}
// SAFETY: through `&OnceArc<T>` other threads can obtain `&T` (`get`), clone an
// `Arc<T>` out (`load`) and move an `Arc<T>` in (`store`). All of that is sound
// exactly when `T: Send + Sync`, the same bound under which `Arc<T>: Sync`.
unsafe impl<T: Send + Sync> Sync for OnceArc<T> {}
impl<T> OnceArc<T> {
    /// Creates an empty slot. Usable in `const`/`static` initialisers.
    pub const fn new() -> Self {
        Self {
            ptr: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Strengthens a caller-supplied ordering for the publishing CAS in `store`.
    ///
    /// Readers dereference the pointer they observe, so the write that installs
    /// it must be at least `Release`; otherwise the pointee's initialisation
    /// would not happen-before a reader's access (a data race from safe code).
    fn publish_ordering(ordering: Ordering) -> Ordering {
        match ordering {
            Ordering::Relaxed => Ordering::Release,
            Ordering::Acquire => Ordering::AcqRel,
            other => other,
        }
    }

    /// Strengthens a caller-supplied ordering for loads whose result is dereferenced.
    ///
    /// `Relaxed` is upgraded to `Acquire` so it pairs with the (at least)
    /// `Release` publish in `store`. Every other ordering passes through
    /// unchanged, so the `Release`/`AcqRel` panics of [`AtomicPtr::load`] remain.
    fn observe_ordering(ordering: Ordering) -> Ordering {
        match ordering {
            Ordering::Relaxed => Ordering::Acquire,
            other => other,
        }
    }

    /// Installs `value` if the slot is still empty.
    ///
    /// `ordering` is the success ordering of the internal compare-exchange. It is
    /// upgraded to at least `Release`, because other threads dereference the
    /// published pointer.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, the very `Arc` that was passed in, if the slot was
    /// already set.
    pub fn store(&self, value: Arc<T>, ordering: Ordering) -> Result<(), Arc<T>> {
        let raw = Arc::into_raw(value) as *mut T;
        // The failure ordering can stay `Relaxed`: on failure the observed
        // pointer is discarded, never dereferenced.
        match self.ptr.compare_exchange(
            ptr::null_mut(),
            raw,
            Self::publish_ordering(ordering),
            Ordering::Relaxed,
        ) {
            Ok(_) => Ok(()),
            // SAFETY: the CAS failed, so `raw` was never published; we still own
            // the strong reference `Arc::into_raw` produced and reclaim it here.
            Err(_) => Err(unsafe { Arc::from_raw(raw) }),
        }
    }

    /// Borrows the stored value, or returns `None` if the slot is empty.
    ///
    /// The borrow is tied to `&self`. A stored value is never replaced, and it
    /// is only released by consuming or dropping the slot (both need exclusive
    /// access). `Relaxed` is upgraded to `Acquire`.
    ///
    /// # Panics
    ///
    /// Panics if `ordering` is `Release` or `AcqRel`, as [`AtomicPtr::load`] does.
    pub fn get(&self, ordering: Ordering) -> Option<&T> {
        let raw = self.ptr.load(Self::observe_ordering(ordering));
        // SAFETY: `raw` is either null or a pointer from `Arc::into_raw` whose
        // strong reference the slot keeps alive for at least as long as `&self`.
        // The acquire (or stronger) load makes the pointee's initialisation visible.
        unsafe { raw.as_ref() }
    }

    /// Returns a new strong reference to the stored value, or `None` if empty.
    ///
    /// `Relaxed` is upgraded to `Acquire`.
    ///
    /// # Panics
    ///
    /// Panics if `ordering` is `Release` or `AcqRel`, as [`AtomicPtr::load`] does.
    pub fn load(&self, ordering: Ordering) -> Option<Arc<T>> {
        let raw = self.ptr.load(Self::observe_ordering(ordering));
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` came from `Arc::into_raw`, and the slot keeps that strong
        // reference alive for the duration of `&self`. Incrementing first and then
        // calling `from_raw` yields a fresh owned `Arc` and leaves the slot's own
        // reference untouched.
        unsafe {
            Arc::increment_strong_count(raw);
            Some(Arc::from_raw(raw))
        }
    }

    /// Reports whether a value has been stored.
    ///
    /// The pointer is only compared against null, never dereferenced, so any
    /// ordering valid for a load (including `Relaxed`) is sound here.
    pub fn is_set(&self, ordering: Ordering) -> bool {
        !self.ptr.load(ordering).is_null()
    }

    /// Consumes the slot and returns the stored `Arc`, if any, without changing
    /// its reference count.
    pub fn into_inner(mut self) -> Option<Arc<T>> {
        // Leave null behind so that dropping `self` afterwards releases nothing.
        let raw = std::mem::replace(self.ptr.get_mut(), ptr::null_mut());
        if raw.is_null() {
            None
        } else {
            // SAFETY: the slot owned this strong reference; ownership moves to the caller.
            Some(unsafe { Arc::from_raw(raw) })
        }
    }

    /// Mutably borrows the stored value if the slot holds the *only* handle to it.
    ///
    /// Returns `None` if the slot is empty, or if any other `Arc` or `Weak`
    /// handle to the value exists. Such handles include one obtained from
    /// `load` or a clone kept before `store`. This mirrors [`Arc::get_mut`];
    /// handing out `&mut T` while another handle can read the value would be
    /// undefined behaviour.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        let raw = *self.ptr.get_mut();
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` came from `Arc::into_raw` and the slot owns that strong
        // reference; `ManuallyDrop` stops this temporary `Arc` from releasing it.
        let mut arc = std::mem::ManuallyDrop::new(unsafe { Arc::from_raw(raw) });
        if Arc::get_mut(&mut *arc).is_none() {
            return None;
        }
        // SAFETY: `Arc::get_mut` confirmed no other strong or weak handles exist,
        // and `&mut self` stops `load`/`get` from creating new ones while the
        // returned borrow is alive.
        Some(unsafe { &mut *raw })
    }
}
impl<T> Default for OnceArc<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: fmt::Debug> fmt::Debug for OnceArc<T> {
    /// Renders the slot as a struct with a single `value` field holding the
    /// current `Option<&T>`, read with `SeqCst` ordering.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self.get(Ordering::SeqCst);
        f.debug_struct("OnceArc").field("value", &current).finish()
    }
}
impl<T> Drop for OnceArc<T> {
fn drop(&mut self) {
let ptr = *self.ptr.get_mut();
if !ptr.is_null() {
unsafe { drop(Arc::from_raw(ptr)) };
}
}
}
impl<T> From<Arc<T>> for OnceArc<T> {
    /// Builds an already-filled slot that takes over `value`'s strong
    /// reference (the reference count is not changed).
    fn from(value: Arc<T>) -> Self {
        let raw: *mut T = Arc::into_raw(value) as *mut T;
        let ptr = AtomicPtr::new(raw);
        Self { ptr }
    }
}
impl<T> From<Option<Arc<T>>> for OnceArc<T> {
    /// Builds a filled slot from `Some(arc)` and an empty slot from `None`.
    fn from(value: Option<Arc<T>>) -> Self {
        value.map_or_else(Self::new, <Self as From<Arc<T>>>::from)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `OnceArc`: single assignment, reference-count
    //! bookkeeping, conversions, formatting and basic multi-threaded use.

    use super::*;
    use std::sync::atomic::Ordering;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Returns a slot already holding `value`, installed through `store`.
    fn filled(value: i32) -> OnceArc<i32> {
        let slot = OnceArc::new();
        slot.store(Arc::new(value), Ordering::Release).unwrap();
        slot
    }

    #[test]
    fn empty_loads_none() {
        let empty = OnceArc::<i32>::new();
        assert!(empty.get(Ordering::Acquire).is_none());
        assert!(empty.load(Ordering::Acquire).is_none());
        assert!(!empty.is_set(Ordering::Relaxed));
    }

    #[test]
    fn set_once_and_load() {
        let slot = filled(42);
        assert_eq!(slot.get(Ordering::Acquire).copied(), Some(42));
        assert!(slot.is_set(Ordering::Relaxed));
    }

    #[test]
    fn set_twice_fails() {
        let slot = filled(1);
        let rejected = slot.store(Arc::new(2), Ordering::Release).unwrap_err();
        assert_eq!(*rejected, 2);
        assert_eq!(slot.get(Ordering::Acquire).copied(), Some(1));
    }

    #[test]
    fn load_returns_arc() {
        let slot = OnceArc::<&str>::new();
        let original = Arc::new("hello");
        slot.store(Arc::clone(&original), Ordering::Release).unwrap();
        let loaded = slot.load(Ordering::Acquire).unwrap();
        assert!(Arc::ptr_eq(&original, &loaded));
        // `original` + the slot's own reference + `loaded`.
        assert_eq!(Arc::strong_count(&original), 3);
    }

    #[test]
    fn into_inner_returns_arc() {
        let original = Arc::new(100);
        let slot = OnceArc::<i32>::new();
        slot.store(Arc::clone(&original), Ordering::Release).unwrap();
        let inner = slot.into_inner().unwrap();
        assert!(Arc::ptr_eq(&original, &inner));
        // `original` + `inner`; the consumed slot holds nothing any more.
        assert_eq!(Arc::strong_count(&original), 2);
    }

    #[test]
    fn into_inner_empty() {
        assert!(OnceArc::<i32>::new().into_inner().is_none());
    }

    #[test]
    fn drop_decrements_refcount() {
        let arc = Arc::new(42);
        assert_eq!(Arc::strong_count(&arc), 1);
        let slot = OnceArc::<i32>::new();
        slot.store(Arc::clone(&arc), Ordering::Release).unwrap();
        assert_eq!(Arc::strong_count(&arc), 2);
        drop(slot);
        assert_eq!(Arc::strong_count(&arc), 1);
    }

    #[test]
    fn from_arc() {
        let slot = OnceArc::from(Arc::new(7));
        assert_eq!(slot.get(Ordering::Acquire).copied(), Some(7));
    }

    #[test]
    fn from_none() {
        let slot = OnceArc::<i32>::from(None);
        assert!(slot.get(Ordering::Acquire).is_none());
    }

    #[test]
    fn from_some() {
        let slot = OnceArc::from(Some(Arc::new(55)));
        assert_eq!(slot.get(Ordering::Acquire).copied(), Some(55));
    }

    #[test]
    fn debug_fmt() {
        let rendered = format!("{:?}", filled(42));
        assert!(rendered.contains("42"));
    }

    #[test]
    fn concurrent_set_exactly_one_wins() {
        let shared = Arc::new(OnceArc::<i32>::new());
        let gate = Arc::new(Barrier::new(10));
        let workers: Vec<_> = (0..10)
            .map(|value| {
                let shared = Arc::clone(&shared);
                let gate = Arc::clone(&gate);
                thread::spawn(move || {
                    gate.wait();
                    shared.store(Arc::new(value), Ordering::Release).is_ok()
                })
            })
            .collect();
        // `count` drains the iterator, so every worker is joined.
        let winners = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter(|&won| won)
            .count();
        assert_eq!(winners, 1);
        assert!(shared.is_set(Ordering::Relaxed));
    }

    #[test]
    fn concurrent_loads_after_set() {
        let shared = Arc::new(OnceArc::from(Arc::new(99)));
        let gate = Arc::new(Barrier::new(10));
        let readers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&shared);
                let gate = Arc::clone(&gate);
                thread::spawn(move || {
                    gate.wait();
                    *shared.get(Ordering::Acquire).unwrap()
                })
            })
            .collect();
        for reader in readers {
            assert_eq!(reader.join().unwrap(), 99);
        }
    }

    #[test]
    fn get_mut_empty() {
        let mut slot = OnceArc::<i32>::new();
        assert!(slot.get_mut().is_none());
    }

    #[test]
    fn get_mut_modifies_value() {
        let mut slot = filled(10);
        *slot.get_mut().unwrap() = 20;
        assert_eq!(slot.get(Ordering::Acquire).copied(), Some(20));
    }
}