use std::marker::PhantomData;
use std::ptr::{eq, null, null_mut, NonNull};
use std::sync::atomic::AtomicPtr;
use std::sync::atomic::Ordering::SeqCst;
use crate::smart_ptrs::{find_inner_ptr, ArcInner, Guard, CTX};
use crate::Arc;
/// An atomically updatable, nullable slot that refers to a reference-counted `ArcInner<T>`.
///
/// A null pointer represents the empty state. While the slot holds a non-null pointer it
/// accounts for one unit of that allocation's reference count.
pub struct AtomicArc<T: 'static> {
    /// Pointer to the shared allocation; null when the slot is empty.
    ptr: AtomicPtr<ArcInner<T>>,
    /// Marks the type as logically holding an `ArcInner<T>` without storing one inline.
    phantom: PhantomData<ArcInner<T>>,
}

// Hand-written instead of `#[derive(Default)]`: the derive adds a `T: Default` bound, yet an
// empty slot never constructs a `T`, so that bound would only exclude types needlessly.
impl<T: 'static> Default for AtomicArc<T> {
    /// Creates an empty slot (null pointer).
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::default(),
            phantom: PhantomData,
        }
    }
}
impl<T: 'static> AtomicArc<T> {
pub fn new<D: Into<Option<T>>>(data: D) -> Self {
let ptr = data.into().map_or(null_mut(), ArcInner::new);
Self {
ptr: AtomicPtr::new(ptr),
phantom: PhantomData,
}
}
pub fn load(&self) -> Option<Guard<'static, T>> {
let guard = CTX.with_borrow(|ctx| ctx.load(&self.ptr, 1))?;
Some(Guard { guard })
}
pub fn swap<N: Into<NonNull<T>>>(&self, new: Option<N>) -> Option<Arc<T>> {
unsafe {
let n = new.map_or(null_mut(), |n| find_inner_ptr(n.into().as_ptr()).cast_mut());
if !n.is_null() {
ArcInner::increment(n);
}
let before = NonNull::new(self.ptr.swap(n, SeqCst))?;
Some(Arc {
ptr: before,
phantom: PhantomData,
})
}
}
pub fn store<N: Into<NonNull<T>>>(&self, new: Option<N>) {
_ = self.swap(new)
}
}
/// Compare-and-exchange on an atomic slot, generic over how the replacement value `N` is passed.
///
/// This file implements it for `AtomicArc<T>` with `N = &Guard<'static, T>` and `N = &Arc<T>`.
pub trait CompareExchange<T, N> {
/// Replaces the stored value with `new` if the slot currently holds `current`
/// (`None` stands for the empty slot on either side).
///
/// Returns `Ok(())` when the exchange took place. On a mismatch returns `Err`, carrying a
/// guard for the value observed instead when one is available, or `None` otherwise.
fn compare_exchange<C: Into<NonNull<T>>>(
&self,
current: Option<C>,
new: Option<N>,
) -> Result<(), Option<Guard<'static, T>>>;
}
// Compare-and-exchange where the replacement is supplied as a borrowed `Guard`.
impl<T: 'static> CompareExchange<T, &Guard<'static, T>> for AtomicArc<T> {
// Performs a pointer-level CAS on the slot and then fixes up reference counts.
//
// On success the new value (if non-null) gains one count and the displaced value (if
// non-null) is released through `delayed_decrement`. On failure the observed pointer is
// protected via `CTX` and handed back as a guard when protection succeeds.
fn compare_exchange<C: Into<NonNull<T>>>(
&self,
current: Option<C>,
new: Option<&Guard<'static, T>>,
) -> Result<(), Option<Guard<'static, T>>> {
unsafe {
// Translate both arguments to raw `ArcInner` pointers; `None` maps to null.
let c = current.map_or(null_mut(), |c| find_inner_ptr(c.into().as_ptr()).cast_mut());
let n = new.map_or(null(), Guard::inner_ptr).cast_mut();
match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
Ok(before) => {
// If the slot already held `n`, the pointer did not change and the slot's
// existing count remains valid, so no adjustment is made.
if !eq(before, n) {
// NOTE(review): the increment happens after `n` has been published by the CAS
// (unlike `swap`, which increments first). Presumably the caller's guard keeps
// the allocation alive across this window — confirm.
if !n.is_null() {
ArcInner::increment(n);
}
// Give up the count the slot held for the displaced value.
if !before.is_null() {
ArcInner::delayed_decrement(before);
}
}
Ok(())
}
Err(actual) => {
if let Some(ptr) = NonNull::new(actual) {
let mut opt = None;
// `protect` may fail to produce a guard; in that case `Err(None)` is returned
// even though a non-null pointer was observed.
let loaded = CTX.with_borrow(|ctx| ctx.protect(&self.ptr, ptr, 1));
if let Some(guard) = loaded {
opt = Some(Guard { guard })
}
Err(opt)
} else {
// The slot was empty at the time of the failed CAS.
Err(None)
}
}
}
}
}
}
impl<T: 'static> CompareExchange<T, &Arc<T>> for AtomicArc<T> {
    /// Compare-and-exchange where the replacement is supplied as a borrowed [`Arc`].
    ///
    /// The `Arc` is converted to a temporary [`Guard`] and the call is forwarded to the
    /// `&Guard` implementation; the temporary guard is dropped once that call returns.
    fn compare_exchange<C: Into<NonNull<T>>>(
        &self,
        current: Option<C>,
        new: Option<&Arc<T>>,
    ) -> Result<(), Option<Guard<'static, T>>> {
        match new {
            Some(arc) => {
                let guard = Guard::from(arc);
                <Self as CompareExchange<T, &Guard<'static, T>>>::compare_exchange(
                    self,
                    current,
                    Some(&guard),
                )
            }
            None => <Self as CompareExchange<T, &Guard<'static, T>>>::compare_exchange(
                self, current, None,
            ),
        }
    }
}
impl<T: 'static> Clone for AtomicArc<T> {
fn clone(&self) -> Self {
let ptr = if let Some(guard) = self.load() {
unsafe {
let ptr = guard.guard.as_ptr();
_ = (*ptr).ref_count.fetch_add(1, SeqCst);
ptr
}
} else {
null_mut()
};
Self {
ptr: AtomicPtr::new(ptr.cast_mut()),
phantom: PhantomData,
}
}
}
impl<T: 'static> Drop for AtomicArc<T> {
    /// Releases the count held for the stored value, if any, via `delayed_decrement`.
    fn drop(&mut self) {
        let raw = self.ptr.load(SeqCst);
        if raw.is_null() {
            // Empty slot: nothing to release.
            return;
        }
        unsafe {
            ArcInner::delayed_decrement(raw);
        }
    }
}
// SAFETY: NOTE(review) — rationale inferred, confirm. The only stored state is an `AtomicPtr`
// to a shared allocation; moving the slot to another thread can cause `T` to be accessed and
// eventually released there, hence the `T: Send + Sync` requirement.
unsafe impl<T: 'static + Send + Sync> Send for AtomicArc<T> {}
// SAFETY: NOTE(review) — rationale inferred, confirm. All methods taking `&self` go through
// atomic operations on `ptr`; sharing the slot lets several threads obtain access to the same
// `T`, hence the `T: Send + Sync` requirement.
unsafe impl<T: 'static + Send + Sync> Sync for AtomicArc<T> {}
impl<T: 'static, P: Into<NonNull<T>>> From<P> for AtomicArc<T> {
    /// Creates a slot that shares the allocation `value` points into.
    ///
    /// The owning `ArcInner` is located with `find_inner_ptr` and its reference count is
    /// incremented by one on behalf of the new slot.
    fn from(value: P) -> Self {
        let data: NonNull<T> = value.into();
        unsafe {
            let inner = find_inner_ptr(data.as_ptr());
            let _ = (*inner).ref_count.fetch_add(1, SeqCst);
            Self {
                ptr: AtomicPtr::new(inner.cast_mut()),
                phantom: PhantomData,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{Arc, AtomicArc, CompareExchange};

    /// A slot built from a bare value loads that value.
    #[test]
    fn test_new_with_value() {
        let atomic = AtomicArc::new(42);
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
    }

    /// A slot built from `None` is empty.
    #[test]
    fn test_new_with_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        assert!(atomic.load().is_none());
    }

    /// `swap` returns the previous value and installs the new one.
    #[test]
    fn test_swap() {
        let atomic = AtomicArc::new(10);
        let arc = Arc::new(20);
        let old = atomic.swap(Some(&arc));
        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    /// Swapping in `None` empties the slot and still returns the previous value.
    #[test]
    fn test_swap_none() {
        let atomic = AtomicArc::new(10);
        let old = atomic.swap::<&Arc<i32>>(None);
        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);
        assert!(atomic.load().is_none());
    }

    /// A clone observes the same value as the original.
    #[test]
    fn test_clone() {
        let atomic = AtomicArc::new(42);
        let cloned = atomic.clone();
        let guard1 = atomic.load().unwrap();
        let guard2 = cloned.load().unwrap();
        assert_eq!(*guard1, 42);
        assert_eq!(*guard2, 42);
    }

    /// Cloning an empty slot yields another empty slot.
    #[test]
    fn test_clone_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        let cloned = atomic.clone();
        assert!(atomic.load().is_none());
        assert!(cloned.load().is_none());
    }

    /// The exchange succeeds when `current` matches the stored `Arc`.
    #[test]
    fn test_compare_exchange_success_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));
        let result = atomic.compare_exchange(Some(&arc1), Some(&arc2));
        assert!(result.is_ok());
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    /// The exchange fails and leaves the slot untouched when `current` does not match.
    #[test]
    fn test_compare_exchange_failure_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let arc3 = Arc::new(30);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));
        let result = atomic.compare_exchange(Some(&arc2), Some(&arc3));
        assert!(result.is_err());
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 10);
    }

    /// A guard obtained from `load` can be used as the `current` argument.
    #[test]
    fn test_compare_exchange_with_guard() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));
        let guard = atomic.load().unwrap();
        let result = atomic.compare_exchange(Some(&guard), Some(&arc2));
        assert!(result.is_ok());
        let new_guard = atomic.load().unwrap();
        assert_eq!(*new_guard, 20);
    }

    /// Building a slot with `From<&Arc<T>>` shares the value with the source `Arc`.
    ///
    /// This goes through the `From` impl directly; constructing via `new` + `store` would
    /// leave that conversion untested.
    #[test]
    fn test_from_arc() {
        let arc = Arc::new(42);
        let atomic: AtomicArc<i32> = AtomicArc::from(&arc);
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(*arc, 42);
    }
}