use std::fmt;
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex, PoisonError};
use crate::OnceArc;
/// A write-once `Arc<T>` cell whose initialization is serialized through a
/// mutex, so a user-supplied initializer closure runs at most once even when
/// several threads race to initialize (see `init` / `try_init`).
pub struct InitOnceArc<T> {
// Underlying storage for the value. Presumably set-at-most-once and readable
// without the mutex — TODO confirm against `OnceArc`'s definition.
inner: OnceArc<T>,
// Serializes `store`/`init`/`try_init`. Poisoned if an initializer panics,
// after which the mutating methods report the poison to callers.
init_mutex: Mutex<()>,
}
// SAFETY: the `T: Send + Sync` bounds mirror what `Arc<T>: Send` requires, and
// `Mutex<()>` is itself `Send + Sync`. NOTE(review): soundness ultimately
// depends on `OnceArc<T>`'s internals (not visible in this file) — confirm that
// it upholds these guarantees, and whether these manual impls are needed at all.
unsafe impl<T: Send + Sync> Send for InitOnceArc<T> {}
// SAFETY: same reasoning as the `Send` impl directly above.
unsafe impl<T: Send + Sync> Sync for InitOnceArc<T> {}
impl<T> InitOnceArc<T> {
    /// Creates an empty cell holding no value.
    pub const fn new() -> Self {
        Self {
            inner: OnceArc::new(),
            init_mutex: Mutex::new(()),
        }
    }

    /// Stores `value` if the cell is still empty, serialized against
    /// concurrent `store`/`init`/`try_init` calls via the init mutex.
    ///
    /// # Errors
    ///
    /// * `Err(Ok(value))` — the cell was already set; the rejected value is
    ///   handed back to the caller.
    /// * `Err(Err(_))` — the init mutex was poisoned by a panicking initializer.
    pub fn store(&self, value: Arc<T>) -> Result<(), Result<Arc<T>, PoisonError<()>>> {
        match self.init_mutex.lock() {
            // The guard stays alive for the duration of the inner store.
            Ok(_guard) => self.inner.store(value, Ordering::SeqCst).map_err(Ok),
            Err(_) => Err(Err(PoisonError::new(()))),
        }
    }

    /// Runs `f` and stores its result, unless the cell already holds a value
    /// (in which case `f` is never invoked).
    ///
    /// Returns `Ok(true)` when this call performed the initialization and
    /// `Ok(false)` when a value was already present.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the init mutex is poisoned.
    pub fn init(&self, f: impl FnOnce() -> Arc<T>) -> Result<bool, PoisonError<()>> {
        let _guard = match self.init_mutex.lock() {
            Ok(g) => g,
            Err(_) => return Err(PoisonError::new(())),
        };
        if self.inner.is_set(Ordering::SeqCst) {
            Ok(false)
        } else {
            // Holding the mutex guarantees no other thread can set the cell
            // between the emptiness check above and this store.
            let produced = f();
            self.inner
                .store(produced, Ordering::SeqCst)
                .unwrap_or_else(|_| unreachable!("store failed while holding init mutex"));
            Ok(true)
        }
    }

    /// Fallible variant of [`Self::init`]: runs `f` and stores its `Ok` value.
    /// If `f` fails, the cell is left empty so a later attempt may retry.
    ///
    /// Returns `Ok(true)` when this call performed the initialization and
    /// `Ok(false)` when a value was already present (`f` is not invoked).
    ///
    /// # Errors
    ///
    /// * `Err(Ok(e))` — `f` returned the error `e`.
    /// * `Err(Err(_))` — the init mutex was poisoned.
    pub fn try_init<E>(&self, f: impl FnOnce() -> Result<Arc<T>, E>) -> Result<bool, Result<E, PoisonError<()>>> {
        let _guard = match self.init_mutex.lock() {
            Ok(g) => g,
            Err(_) => return Err(Err(PoisonError::new(()))),
        };
        if self.inner.is_set(Ordering::SeqCst) {
            return Ok(false);
        }
        // Propagate an initializer failure as `Err(Ok(e))` without storing.
        let produced = f().map_err(Ok)?;
        // As in `init`: the held mutex makes a failed store impossible here.
        self.inner
            .store(produced, Ordering::SeqCst)
            .unwrap_or_else(|_| unreachable!("store failed while holding init mutex"));
        Ok(true)
    }

    /// Borrows the stored value, or `None` if the cell is empty.
    pub fn get(&self, ordering: Ordering) -> Option<&T> {
        self.inner.get(ordering)
    }

    /// Returns a clone of the stored `Arc`, or `None` if the cell is empty.
    pub fn load(&self, ordering: Ordering) -> Option<Arc<T>> {
        self.inner.load(ordering)
    }

    /// Reports whether a value has been stored.
    pub fn is_set(&self, ordering: Ordering) -> bool {
        self.inner.is_set(ordering)
    }

    /// Consumes the cell, returning the stored `Arc` if any.
    pub fn into_inner(self) -> Option<Arc<T>> {
        self.inner.into_inner()
    }

    /// Mutably borrows the stored value; `&mut self` proves exclusive access,
    /// so no locking or atomics are needed.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        self.inner.get_mut()
    }
}
impl<T> Default for InitOnceArc<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: fmt::Debug> fmt::Debug for InitOnceArc<T> {
    /// Formats as `InitOnceArc { value: … }`, showing `None` while unset.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.inner.get(Ordering::SeqCst);
        f.debug_struct("InitOnceArc")
            .field("value", &value)
            .finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::Ordering;
// The first `init` runs its closure and returns true; a second call returns
// false and must leave the originally stored value untouched.
#[test]
fn once_arc_init() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
assert!(cell.init(|| Arc::new(42)).unwrap());
assert!(!cell.init(|| Arc::new(99)).unwrap());
assert_eq!(cell.get(Ordering::Acquire), Some(&42));
}
// `get` on a never-initialized cell yields None.
#[test]
fn once_arc_get_empty() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
assert!(cell.get(Ordering::Acquire).is_none());
}
// A failed `try_init` (Err(Ok(e))) must leave the cell empty so a later
// attempt can still succeed; once set, further attempts return Ok(false).
#[test]
fn once_arc_try_init_err_then_ok() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
let err = cell.try_init(|| Err("fail")).unwrap_err().unwrap();
assert_eq!(err, "fail");
assert!(cell.get(Ordering::Acquire).is_none());
assert!(cell.try_init(|| Ok::<_, &str>(Arc::new(42))).unwrap());
assert_eq!(cell.get(Ordering::Acquire), Some(&42));
assert!(!cell.try_init(|| Ok::<_, &str>(Arc::new(99))).unwrap());
}
// `is_set` and `load` both track the empty -> set transition.
#[test]
fn once_arc_load_and_is_set() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
assert!(!cell.is_set(Ordering::Relaxed));
assert!(cell.load(Ordering::Acquire).is_none());
cell.init(|| Arc::new(7)).unwrap();
assert!(cell.is_set(Ordering::Relaxed));
assert_eq!(*cell.load(Ordering::Acquire).unwrap(), 7);
}
// `into_inner` hands back ownership of the stored Arc.
#[test]
fn once_arc_into_inner() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
cell.init(|| Arc::new(42)).unwrap();
let arc = cell.into_inner().unwrap();
assert_eq!(*arc, 42);
}
// `into_inner` on an empty cell yields None rather than panicking.
#[test]
fn once_arc_into_inner_empty() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
assert!(cell.into_inner().is_none());
}
// Ten threads race through the barrier into `init`: exactly one closure may
// run (init_count == 1), exactly one thread observes true, and all threads
// afterwards see the single winner's value.
#[test]
fn once_arc_concurrent_init() {
use std::sync::Barrier;
use std::sync::atomic::AtomicUsize;
use std::thread;
let cell = Arc::new(InitOnceArc::<i32>::new());
let init_count = Arc::new(AtomicUsize::new(0));
let barrier = Arc::new(Barrier::new(10));
let mut handles = Vec::new();
for _ in 0..10 {
let cell = cell.clone();
let init_count = init_count.clone();
let barrier = barrier.clone();
handles.push(thread::spawn(move || {
barrier.wait();
cell
.init(|| {
init_count.fetch_add(1, Ordering::Relaxed);
Arc::new(42)
})
.unwrap()
}));
}
let results: Vec<bool> = handles.into_iter().map(|h| h.join().unwrap()).collect();
assert_eq!(results.iter().filter(|&&b| b).count(), 1);
assert_eq!(cell.get(Ordering::Acquire), Some(&42));
assert_eq!(init_count.load(Ordering::Relaxed), 1);
}
// Debug output must include the stored value.
#[test]
fn once_arc_debug_fmt() {
let cell: InitOnceArc<i32> = InitOnceArc::new();
cell.init(|| Arc::new(42)).unwrap();
let dbg = format!("{:?}", cell);
assert!(dbg.contains("42"));
}
// Helper: deliberately panic inside `init` on another thread so the init
// mutex is poisoned; the join result is discarded because the panic is expected.
fn poisoned_cell() -> Arc<InitOnceArc<i32>> {
use std::thread;
let cell = Arc::new(InitOnceArc::<i32>::new());
let c = cell.clone();
let _ = thread::spawn(move || {
let _ = c.init(|| panic!("deliberate poison"));
})
.join();
cell
}
// After poisoning, `store` reports the poison via the Err(Err(_)) arm.
#[test]
fn store_returns_poison_error() {
let cell = poisoned_cell();
let err = cell.store(Arc::new(1)).unwrap_err();
assert!(err.is_err()); }
// After poisoning, `init` surfaces the PoisonError directly.
#[test]
fn init_returns_poison_error() {
let cell = poisoned_cell();
assert!(cell.init(|| Arc::new(1)).is_err());
}
// After poisoning, `try_init` reports the poison via the Err(Err(_)) arm.
#[test]
fn try_init_returns_poison_error() {
let cell = poisoned_cell();
let err = cell.try_init(|| Ok::<_, &str>(Arc::new(1))).unwrap_err();
assert!(err.is_err()); }
}