use super::spinlock::SpinRwLock;
use super::IntoOptionArc;
use std::mem;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
/// A thread-safe holder for an optional `Arc<T>` that can be atomically
/// loaded, stored, and swapped.
///
/// The value is kept as a raw pointer obtained from `Arc::into_raw`; a null
/// pointer represents `None`. A spin read/write lock protects the window in
/// `load` between reading the pointer and bumping its reference count, so a
/// concurrent `swap` cannot free the value out from under a reader.
pub struct AtomicOptionRef<T> {
    // Raw pointer from `Arc::into_raw`; null means "no value stored".
    ptr: AtomicPtr<T>,
    // Readers hold the read side during `load`; writers hold the write side
    // during `swap` so the pointed-to Arc cannot be dropped mid-load.
    lock: SpinRwLock,
}
impl<T> AtomicOptionRef<T> {
    /// Creates a new, empty `AtomicOptionRef` (holds `None`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an `AtomicOptionRef` initialized with `value`
    /// (which may convert to `None`).
    pub fn from(value: impl IntoOptionArc<T>) -> Self {
        Self {
            ptr: AtomicPtr::new(option_arc_to_ptr(value)),
            lock: SpinRwLock::new(),
        }
    }

    /// Returns `true` if a value is currently stored.
    ///
    /// BUG FIX: this previously returned `is_null()` directly, which inverted
    /// the result — it reported `true` when the reference was *empty*. A null
    /// pointer encodes `None` (see `ptr_to_option_arc`), so "some" is the
    /// non-null case.
    pub fn is_some(&self) -> bool {
        !self.ptr.load(Ordering::SeqCst).is_null()
    }

    /// Loads and returns a clone of the stored `Arc<T>`, or `None` if empty.
    pub fn load(&self) -> Option<Arc<T>> {
        // Hold the read lock so a concurrent `swap` cannot drop the value
        // between our pointer load and the refcount increment performed by
        // `ptr_to_option_arc(.., true)`.
        let _guard = self.lock.read();
        ptr_to_option_arc(self.ptr.load(Ordering::SeqCst), true)
    }

    /// Stores `value`, dropping any previously stored value.
    pub fn store(&self, value: impl IntoOptionArc<T>) {
        // The returned previous value is dropped here, releasing its refcount.
        self.swap(value);
    }

    /// Swaps in `value` and returns the previously stored value, if any.
    pub fn swap(&self, value: impl IntoOptionArc<T>) -> Option<Arc<T>> {
        // The write lock excludes concurrent readers. Ownership of the old
        // pointer transfers to the returned `Arc`, so no refcount increment
        // is needed (`increment = false`).
        let _guard = self.lock.write();
        ptr_to_option_arc(
            self.ptr.swap(option_arc_to_ptr(value), Ordering::SeqCst),
            false,
        )
    }
}
impl<T> Default for AtomicOptionRef<T> {
fn default() -> Self {
Self::from(None)
}
}
impl<T> Drop for AtomicOptionRef<T> {
fn drop(&mut self) {
let ptr = self.ptr.swap(null_mut(), Ordering::SeqCst);
if !ptr.is_null() {
unsafe {
let _ = Arc::from_raw(ptr);
}
}
}
}
fn option_arc_to_ptr<T>(value: impl IntoOptionArc<T>) -> *mut T {
if let Some(value) = value.into_option_arc() {
Arc::into_raw(value) as *mut _
} else {
null_mut()
}
}
/// Reconstructs an `Option<Arc<T>>` from a raw pointer previously produced
/// by `option_arc_to_ptr`; null maps to `None`.
///
/// When `increment` is true, one extra strong count is leaked so that the
/// ownership originally held by the raw pointer remains outstanding (used by
/// `load`, which must not consume the stored pointer's count). When false,
/// the returned `Arc` takes over the pointer's count (used by `swap`/`drop`).
fn ptr_to_option_arc<T>(ptr: *mut T, increment: bool) -> Option<Arc<T>> {
    if ptr.is_null() {
        return None;
    }
    // SAFETY: all non-null pointers passed here come from `Arc::into_raw`.
    let arc = unsafe { Arc::from_raw(ptr) };
    if increment {
        // Forget a clone to offset the count `from_raw` reclaimed, keeping
        // the raw pointer stored elsewhere valid.
        mem::forget(Arc::clone(&arc));
    }
    Some(arc)
}
#[cfg(test)]
mod tests {
    use super::AtomicOptionRef;

    // Storing a value and loading it back yields the same contents.
    #[test]
    fn test_store_load() {
        let m = AtomicOptionRef::<String>::new();
        m.store(String::from("2"));
        assert_eq!(m.load().unwrap().as_ref(), "2");
    }

    // An Arc obtained before an overwrite keeps the old contents alive,
    // while fresh loads observe the replacement value.
    #[test]
    fn test_overwrite() {
        let m = AtomicOptionRef::<String>::new();
        m.store(String::from("Hello World"));
        let m0 = m.load();
        m.store(String::from("Goodbye World"));
        assert_eq!(m0.unwrap().as_ref(), "Hello World");
        assert_eq!(m.load().unwrap().as_ref(), "Goodbye World");
    }

    // Each swap drops exactly the displaced value: after two swaps, one Foo
    // has been dropped (the second is still stored inside `m`).
    #[test]
    fn test_drop() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Foo;
        impl Drop for Foo {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }
        let m = AtomicOptionRef::<Foo>::new();
        m.swap(Foo);
        m.swap(Foo);
        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
    }

    // Stress test: reader threads must never observe a dropped value while
    // writer threads continually swap in fresh ones.
    // NOTE(review): uses the two-argument `gen_range(low, high)` form, which
    // implies rand < 0.8 — confirm against Cargo.toml before bumping rand.
    #[test]
    fn test_threads() {
        use rand::{thread_rng, Rng};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::thread;
        use std::time::Duration;
        const THREADS: usize = 100;
        const ITERATIONS: usize = 100;
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        // `dropped` is per-instance so a reader can detect use-after-drop;
        // DROPS counts total destructions across all instances.
        #[derive(Default)]
        struct Foo {
            dropped: AtomicUsize,
        };
        impl Drop for Foo {
            fn drop(&mut self) {
                self.dropped.fetch_add(1, Ordering::SeqCst);
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }
        let m = Arc::new(AtomicOptionRef::<Foo>::new());
        m.store(Foo::default());
        let mut threads = Vec::new();
        for _ in 0..THREADS {
            // Reader: the loaded Arc must keep its Foo alive (never dropped).
            let m0 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    let value = m0.load().unwrap();
                    assert_eq!(value.dropped.load(Ordering::SeqCst), 0);
                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));
            // Writer: each swap displaces (and eventually drops) one Foo.
            let m1 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    m1.swap(Foo::default());
                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));
        }
        for thread in threads {
            let _ = thread.join();
        }
        // One Foo was displaced per swap; the final stored Foo is dropped
        // only when `m` itself is dropped, so it is not counted here.
        assert_eq!(DROPS.load(Ordering::SeqCst), (THREADS * ITERATIONS));
    }
}