use core::marker::PhantomData;
use core::ptr::NonNull;
use allocator_api2::alloc::{Allocator, Global};
use widestring::Utf16Str;
use crate::internal::chunk_ref::ChunkRef;
use crate::strings::utf16_str_common::impl_utf16_str_common;
/// Atomically reference-counted handle to a UTF-16 string stored in an arena chunk.
///
/// The handle is a single thin pointer. Clones share one strong counter, which is
/// located from `ptr` via `thin_dst::strong_ref::<[u16]>`. The handle whose drop
/// takes that counter to zero reclaims the chunk through `ChunkRef<A>`.
pub struct ArcUtf16Str<A: Allocator + Clone = Global> {
/// Value pointer into the chunk: the same pointer that is handed to
/// `ChunkRef::from_value_ptr` on last drop. Presumably it addresses the first
/// `u16` of the string data — confirm against the allocation path.
ptr: NonNull<u16>,
/// `*const Utf16Str` makes the type `!Send`/`!Sync` by default (the manual
/// `unsafe impl`s re-add them under explicit bounds) and marks the pointee as a
/// `Utf16Str`. `A` records the allocator type without storing a value of it.
_phantom: PhantomData<(*const Utf16Str, A)>,
}
// SAFETY: clones of a handle share one strong counter, and every update to it is an
// atomic RMW (`fetch_add` in `clone`, `fetch_sub` in `drop`). Whichever thread drops
// the last handle rebuilds a `ChunkRef<A>` and drops it there, so the allocator type
// must itself be safe to move and share across threads. This also assumes the
// accessors generated by `impl_utf16_str_common!` only expose shared (`&`) access
// to the string data — TODO confirm.
unsafe impl<A> Send for ArcUtf16Str<A> where A: Allocator + Clone + Send + Sync {}

// SAFETY: a `&ArcUtf16Str` on another thread can only read the pointer or clone the
// handle (an atomic increment); see the `Send` rationale above for the allocator bound.
unsafe impl<A> Sync for ArcUtf16Str<A> where A: Allocator + Clone + Send + Sync {}
impl<A: Allocator + Clone> ArcUtf16Str<A> {
    /// Wraps a raw value pointer as an `ArcUtf16Str`, adopting one strong reference.
    ///
    /// No reference count is touched here. The returned handle takes over a strong
    /// reference the caller already owns, and gives it back (via `Drop`) when it
    /// goes away.
    ///
    /// # Safety
    ///
    /// - `ptr` must be the value pointer of a live arena chunk. The chunk must be
    ///   laid out so that `thin_dst::strong_ref::<[u16]>` (with `u16` alignment) can
    ///   locate its strong counter and `ChunkRef::<A>::from_value_ptr` can reclaim it.
    /// - The caller must own one strong reference to that chunk and transfer it to
    ///   the returned value. Otherwise the counter is decremented once too often
    ///   when the handle drops, freeing the chunk while other handles still use it.
    /// - The data is exposed as a `Utf16Str`, so it presumably must be well-formed
    ///   UTF-16 — confirm against `impl_utf16_str_common!`.
    #[inline]
    pub(crate) unsafe fn from_raw(ptr: NonNull<u16>) -> Self {
        Self {
            ptr,
            _phantom: PhantomData,
        }
    }
}
// Pulls in the shared UTF-16 string surface for this handle. The macro lives in
// `crate::strings::utf16_str_common`; its expansion is not visible here (presumably
// it reads through `self.ptr` — confirm before changing the field layout).
impl_utf16_str_common!(ArcUtf16Str);
impl<A: Allocator + Clone> Clone for ArcUtf16Str<A> {
    /// Produces another handle to the same string by bumping the shared strong count.
    ///
    /// As with `std::sync::Arc`, the increment can be `Relaxed` because the caller
    /// already holds a reference, so the chunk cannot be freed concurrently.
    /// A pre-increment count above `u32::MAX >> 1` is treated as overflow and routed
    /// to `refcount_overflow_abort`, keeping the count far from wrap-around.
    #[inline]
    fn clone(&self) -> Self {
        use core::sync::atomic::Ordering;

        // Largest pre-increment count for which cloning is still permitted.
        const MAX_REFCOUNT: u32 = u32::MAX >> 1;

        // SAFETY: `self` is a live handle, so `self.ptr` still points into a chunk
        // whose strong counter is >= 1 and laid out for `strong_ref::<[u16]>`.
        let counter = unsafe {
            crate::internal::thin_dst::strong_ref::<[u16]>(
                self.ptr.cast::<u8>(),
                core::mem::align_of::<u16>(),
            )
        };

        let old_count = counter.fetch_add(1, Ordering::Relaxed);
        if old_count > MAX_REFCOUNT {
            crate::internal::constants::refcount_overflow_abort();
        }

        Self {
            _phantom: PhantomData,
            ptr: self.ptr,
        }
    }
}
impl<A: Allocator + Clone> Drop for ArcUtf16Str<A> {
    /// Releases this handle's strong reference; the last handle reclaims the chunk.
    ///
    /// This is the usual release-decrement / acquire-fence pairing: every other
    /// handle's accesses happen-before the final owner tears the chunk down.
    #[inline]
    fn drop(&mut self) {
        use core::sync::atomic::{fence, Ordering};

        // SAFETY: this handle still owns a strong reference, so the chunk and its
        // counter are alive at least until the decrement below.
        let counter = unsafe {
            crate::internal::thin_dst::strong_ref::<[u16]>(
                self.ptr.cast::<u8>(),
                core::mem::align_of::<u16>(),
            )
        };

        let was_last = counter.fetch_sub(1, Ordering::Release) == 1;
        if was_last {
            fence(Ordering::Acquire);
            // SAFETY: the count just reached zero, so no other handle can reach the
            // chunk. Ownership is handed to a `ChunkRef`, which is dropped right away
            // (presumably returning the chunk to `A`).
            let chunk: ChunkRef<A> = unsafe { ChunkRef::from_value_ptr(self.ptr) };
            drop(chunk);
        }
    }
}
impl<A: Allocator + Clone> From<ArcUtf16Str<A>> for crate::Arc<[u16], A> {
    /// Reinterprets the string handle as a `crate::Arc<[u16], A>` over the same chunk.
    ///
    /// The strong reference held by `s` is moved rather than duplicated. `s`'s
    /// destructor is suppressed, so the conversion neither bumps nor drops the count.
    #[inline]
    fn from(s: ArcUtf16Str<A>) -> Self {
        // Disarm `s`'s `Drop` and keep only its pointer, as bytes.
        let raw = core::mem::ManuallyDrop::new(s).ptr.cast::<u8>();
        // SAFETY: `raw` addresses the chunk `s` owned and carries its strong reference.
        // Assumes `crate::Arc<[u16], A>::from_raw` expects this same chunk layout —
        // confirm against its definition.
        unsafe { Self::from_raw(raw) }
    }
}
#[cfg(test)]
mod tests {
    use core::sync::atomic::{AtomicU32, Ordering};

    use super::*;
    use crate::internal::thin_dst::strong_ref;
    use crate::Arena;

    /// Largest pre-increment count at which `clone` must still succeed; mirrors the
    /// threshold checked in `Clone for ArcUtf16Str`.
    const MAX_REFCOUNT: u32 = u32::MAX >> 1;

    /// Borrows the strong counter backing `s`.
    fn strong_of<A: Allocator + Clone>(s: &ArcUtf16Str<A>) -> &AtomicU32 {
        // SAFETY: `s` is a live handle, so its chunk and counter outlive this borrow.
        unsafe { strong_ref::<[u16]>(s.ptr.cast::<u8>(), core::mem::align_of::<u16>()) }
    }

    #[test]
    fn drop_decrements_strong_count() {
        let arena = Arena::new();
        let s = arena.alloc_utf16_str_arc_from_str("hi");
        let strong = strong_of(&s);
        let base = strong.load(Ordering::Relaxed);
        let s2 = s.clone();
        assert_eq!(strong.load(Ordering::Relaxed), base + 1, "clone must bump the strong count");
        drop(s2);
        assert_eq!(strong.load(Ordering::Relaxed), base, "drop must decrement the strong count");
    }

    #[test]
    fn clone_at_max_refcount_threshold_does_not_abort() {
        let arena = Arena::new();
        let s = arena.alloc_utf16_str_arc_from_str("hi");
        let strong = strong_of(&s);
        strong.store(MAX_REFCOUNT, Ordering::Relaxed);
        let clone = s.clone();
        let after_clone = strong.load(Ordering::Relaxed);
        // Restore a count that matches the live handles (`s` + `clone`) before
        // asserting, so a failure still lets `s` free its chunk during unwinding.
        strong.store(2, Ordering::Relaxed);
        drop(clone);
        assert_eq!(
            after_clone,
            MAX_REFCOUNT + 1,
            "clone at the threshold must still bump the strong count"
        );
    }

    #[test]
    #[should_panic(expected = "refcount overflow")]
    fn clone_above_max_refcount_threshold_aborts() {
        let arena = Arena::new();
        let s = arena.alloc_utf16_str_arc_from_str("hi");
        let strong = strong_of(&s);
        strong.store(MAX_REFCOUNT + 1, Ordering::Relaxed);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _c = s.clone();
        }));
        // `clone` increments before checking, so reset the count to `s` alone
        // before re-raising; otherwise `s` would never free its chunk.
        strong.store(1, Ordering::Relaxed);
        std::panic::resume_unwind(result.expect_err("clone past the threshold must panic"));
    }
}