use std::{ptr::NonNull, sync::atomic::Ordering};
use crate::{heaped_object::GCHeapedObject, traceable::GCTraceable};
/// Manual reference-count manipulation shared by the GC handle types.
///
/// Every method works on the counters of the shared heap object a handle
/// points to, not on the handle value itself.
// NOTE(review): `dead_code` is silenced presumably because some methods are
// not called anywhere yet — confirm before removing the attribute.
#[allow(dead_code)]
pub trait GCRef {
    /// Current strong reference count.
    fn strong_ref(&self) -> usize;
    /// Current weak reference count.
    fn weak_ref(&self) -> usize;

    /// Adds one strong reference without creating a handle.
    fn inc_ref(&self);
    /// Removes one strong reference without consuming a handle.
    fn dec_ref(&self);

    /// Adds one weak reference without creating a handle.
    fn inc_weak_ref(&self);
    /// Removes one weak reference without consuming a handle.
    fn dec_weak_ref(&self);
}
/// Strong, atomically reference-counted handle to a garbage-collected value.
///
/// Each live `GCArc` holds one count in both `strong_rc` and `weak_rc` of the
/// shared heap object (taken by `Clone`, released by `Drop`). The value is
/// dropped when `strong_rc` reaches zero. The allocation is freed when
/// `weak_rc` reaches zero.
pub struct GCArc<T>
where
    T: GCTraceable + 'static,
{
    // Heap object holding the value, its counters and its GC mark.
    obj: NonNull<GCHeapedObject<T>>,
}
#[allow(dead_code)]
impl<T> GCArc<T>
where
    T: GCTraceable + 'static,
{
    /// Moves `obj` into a new heap allocation and returns the first strong
    /// handle to it.
    ///
    /// The initial counter values are whatever `GCHeapedObject::new` sets.
    pub fn new(obj: T) -> Self {
        let heaped_obj = Box::new(GCHeapedObject::new(obj));
        // `Box::leak` yields a non-null `&mut`, so this conversion cannot fail.
        Self {
            obj: NonNull::from(Box::leak(heaped_obj)),
        }
    }

    /// Increments the strong count without creating a new handle.
    ///
    /// Unlike [`GCRef::inc_ref`], this does not reject a zero strong count.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file. Presumably
    /// each call must be balanced by a later [`GCArc::dec_ref`] — confirm.
    pub unsafe fn inc_ref(&self) {
        // SAFETY: `self` is a live handle. Its own `weak_rc` count keeps the
        // allocation from being freed.
        unsafe {
            self.obj.as_ref().strong_rc.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Decrements the strong count without consuming a handle.
    ///
    /// When the strong count reaches zero, the contained value is dropped in
    /// place, exactly as `Drop` does for the last strong handle. The heap
    /// allocation is *not* freed here. Every outstanding `GCArc`/`GCArcWeak`
    /// (including `self`) still holds a `weak_rc` count and touches the
    /// allocation when it is dropped, so freeing is left to the final
    /// `weak_rc` release.
    ///
    /// # Panics
    ///
    /// Panics if the strong count is already zero.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably each call must balance an earlier
    /// [`GCArc::inc_ref`] or intentionally release a strong count — confirm.
    pub unsafe fn dec_ref(&self) {
        // SAFETY: see `inc_ref`; the allocation is live while `self` exists.
        let heaped = unsafe { self.obj.as_ref() };
        if heaped.strong_rc.load(Ordering::SeqCst) == 0 {
            panic!("Attempted to decrement a GCArc with 0 strong references");
        }
        if heaped.strong_rc.fetch_sub(1, Ordering::SeqCst) == 1 {
            let mut ptr = self.obj;
            // SAFETY: the last strong count is gone. This is the same
            // `drop_value` call `Drop` makes in this situation.
            unsafe {
                ptr.as_mut().drop_value();
            }
        }
    }

    /// Creates a weak handle to the same heap object, adding one `weak_rc` count.
    pub fn as_weak(&self) -> GCArcWeak<T> {
        // SAFETY: see `inc_ref`.
        unsafe {
            self.obj.as_ref().weak_rc.fetch_add(1, Ordering::SeqCst);
        }
        GCArcWeak { obj: self.obj }
    }

    /// Returns whether the heap object currently carries the GC mark.
    pub fn is_marked(&self) -> bool {
        unsafe { self.obj.as_ref().is_marked() }
    }

    /// Marks the heap object and visits the contained value, unless it was
    /// already marked.
    ///
    /// The early return on an already-marked object is presumably what stops
    /// traversal from looping on reference cycles.
    pub fn mark_and_visit(&self) {
        unsafe {
            if self.obj.as_ref().is_marked() {
                return;
            }
            self.obj.as_ref().mark();
        }
        self.visit();
    }

    /// Clears the GC mark on the heap object.
    pub fn unmark(&self) {
        unsafe {
            self.obj.as_ref().unmark();
        }
    }

    /// Borrows the contained value.
    pub fn as_ref(&self) -> &T {
        unsafe { self.obj.as_ref().as_ref() }
    }

    /// Returns a mutable reference to the value.
    ///
    /// # Panics
    ///
    /// Panics if this handle is not unique (see [`GCArc::try_as_mut`]).
    pub fn get_mut(&mut self) -> &mut T {
        self.try_as_mut().expect(
            "Cannot get mutable reference: GCArc is not unique. \
             Strong count > 1 or weak references exist. \
             Consider using interior mutability (RefCell, Mutex, etc.) instead.",
        )
    }

    /// Returns a mutable reference to the value if this is the only handle to it.
    ///
    /// "Only handle" means exactly one strong reference and no weak references,
    /// as reported by the heap object's `strong_ref()` / `weak_ref()`. These are
    /// the same numbers [`GCRef`] exposes.
    ///
    /// The raw `weak_rc` field cannot be used for this check. Every `GCArc` also
    /// holds one `weak_rc` count (taken in `Clone`, released in `Drop`), so
    /// `weak_rc` is never zero while `self` exists.
    pub fn try_as_mut(&mut self) -> Option<&mut T> {
        // SAFETY: `self` is alive, so the heap object is still allocated.
        let (strong_count, weak_count) = unsafe {
            let heaped = self.obj.as_ref();
            (heaped.strong_ref(), heaped.weak_ref())
        };
        if strong_count == 1 && weak_count == 0 {
            // SAFETY: no other handle exists and we hold `&mut self`, so nothing
            // else can reach the value while the returned borrow is live.
            Some(unsafe { self.obj.as_mut().as_mut() })
        } else {
            None
        }
    }

    /// Forwards to the contained value's `GCTraceable::visit`.
    fn visit(&self) {
        unsafe {
            self.obj.as_ref().as_ref().visit();
        }
    }

    /// Returns `true` if both handles point to the same heap object.
    pub(crate) fn ptr_eq(a: &GCArc<T>, b: &GCArc<T>) -> bool {
        // `NonNull` compares by address, so no dereference (and no `unsafe`)
        // is needed.
        a.obj == b.obj
    }
}
impl<T> Clone for GCArc<T>
where
T: GCTraceable + 'static,
{
fn clone(&self) -> Self {
unsafe {
self.obj
.as_ref()
.strong_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.obj
.as_ref()
.weak_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Self { obj: self.obj }
}
}
impl<T> GCRef for GCArc<T>
where
T: GCTraceable + 'static,
{
fn strong_ref(&self) -> usize {
unsafe { self.obj.as_ref().strong_ref() }
}
fn weak_ref(&self) -> usize {
unsafe { self.obj.as_ref().weak_ref() }
}
fn inc_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.strong_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to increment a GCArc with 0 strong references");
}
self.obj
.as_ref()
.strong_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
fn dec_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.strong_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to decrement a GCArc with 0 strong references");
}
if self
.obj
.as_ref()
.strong_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
fn inc_weak_ref(&self) {
unsafe {
self.obj
.as_ref()
.weak_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
fn dec_weak_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.weak_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to decrement a GCArc with 0 weak references");
}
if self
.obj
.as_ref()
.weak_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
}
impl<T> Drop for GCArc<T>
where
T: GCTraceable + 'static,
{
fn drop(&mut self) {
unsafe {
if self
.obj
.as_ref()
.strong_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to drop a GCArc with 0 strong references");
}
if self
.obj
.as_mut()
.strong_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
self.obj.as_mut().drop_value();
}
if self
.obj
.as_ref()
.weak_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to drop a GCArc with 0 weak references");
}
if self
.obj
.as_ref()
.weak_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
}
// SAFETY: the shared reference counts are atomics.
// NOTE(review): unlike `std::sync::Arc`, these impls do not require
// `T: Send + Sync`. A `GCArc` around a non-thread-safe `T` (e.g. one holding
// `Cell` or `Rc`) could therefore be shared across threads. Confirm that
// `GCTraceable` implies thread safety, or add the bounds.
unsafe impl<T: GCTraceable + 'static> Send for GCArc<T> {}
unsafe impl<T: GCTraceable + 'static> Sync for GCArc<T> {}
/// Weak handle to a garbage-collected value; see `GCArcWeak::upgrade`.
///
/// Holds one `weak_rc` count on the shared heap object (released by `Drop`)
/// but no strong count. It keeps the allocation alive but not the value.
pub struct GCArcWeak<T>
where
    T: GCTraceable + 'static,
{
    // Heap object shared with the strong `GCArc` handles.
    obj: NonNull<GCHeapedObject<T>>,
}
#[allow(dead_code)]
impl<T> GCArcWeak<T>
where
T: GCTraceable + 'static,
{
pub unsafe fn from_raw(obj: NonNull<GCHeapedObject<T>>) -> Self {
Self { obj }
}
pub fn upgrade(&self) -> Option<GCArc<T>> {
#[inline]
fn checked_increment(n: usize) -> Option<usize> {
if n == 0 {
return None;
}
if n >= usize::MAX / 2 {
panic!("Reference count overflow");
}
Some(n + 1)
}
unsafe {
if self.obj.as_ref().is_dropped() {
return None;
}
if self
.obj
.as_ref()
.strong_rc
.fetch_update(Ordering::SeqCst, Ordering::Relaxed, checked_increment)
.is_ok()
{
if self.obj.as_ref().is_dropped() {
self.obj.as_ref().strong_rc.fetch_sub(1, Ordering::SeqCst);
return None;
}
self.obj
.as_ref()
.weak_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Some(GCArc { obj: self.obj })
} else {
None
}
}
}
pub fn is_valid(&self) -> bool {
unsafe { self.obj.as_ref().strong_ref() > 0 }
}
}
impl<T> Clone for GCArcWeak<T>
where
T: GCTraceable + 'static,
{
fn clone(&self) -> Self {
unsafe {
self.obj
.as_ref()
.weak_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Self { obj: self.obj }
}
}
impl<T> GCRef for GCArcWeak<T>
where
T: GCTraceable + 'static,
{
fn strong_ref(&self) -> usize {
unsafe { self.obj.as_ref().strong_ref() }
}
fn weak_ref(&self) -> usize {
unsafe { self.obj.as_ref().weak_ref() }
}
fn inc_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.strong_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to increment a GCArcWeak with 0 strong references");
}
self.obj
.as_ref()
.strong_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
fn dec_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.strong_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to decrement a GCArcWeak with 0 strong references");
}
if self
.obj
.as_ref()
.strong_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
fn inc_weak_ref(&self) {
unsafe {
self.obj
.as_ref()
.weak_rc
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
fn dec_weak_ref(&self) {
unsafe {
if self
.obj
.as_ref()
.weak_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to decrement a GCArcWeak with 0 weak references");
}
if self
.obj
.as_ref()
.weak_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
}
impl<T> Drop for GCArcWeak<T>
where
T: GCTraceable + 'static,
{
fn drop(&mut self) {
unsafe {
if self
.obj
.as_ref()
.weak_rc
.load(std::sync::atomic::Ordering::SeqCst)
== 0
{
panic!("Attempted to drop a GCArcWeak with 0 weak references");
}
if self
.obj
.as_ref()
.weak_rc
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
{
drop(Box::from_raw(self.obj.as_ptr()));
}
}
}
}
// SAFETY: the shared reference counts are atomics.
// NOTE(review): as with `GCArc`, these impls do not require `T: Send + Sync`
// (which `std::sync::Weak` does). Confirm that `GCTraceable` implies thread
// safety, or add the bounds.
unsafe impl<T: GCTraceable + 'static> Send for GCArcWeak<T> {}
unsafe impl<T: GCTraceable + 'static> Sync for GCArcWeak<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traceable::GCTraceable;

    /// Minimal payload with no outgoing GC references.
    #[derive(Debug, PartialEq)]
    struct TestValue {
        value: i32,
    }

    impl GCTraceable for TestValue {
        // Nothing to visit: `TestValue` holds no GC handles.
        fn visit(&self) {}
    }

    /// Fresh, unique handle wrapping `TestValue { value: 1 }`.
    fn fresh() -> GCArc<TestValue> {
        GCArc::new(TestValue { value: 1 })
    }

    #[test]
    fn test_no_mutable_reference_with_multiple_strong_refs() {
        let mut arc = fresh();
        let _other = arc.clone();
        assert!(arc.try_as_mut().is_none());
    }

    #[test]
    fn test_no_mutable_reference_with_weak_refs() {
        let mut arc = fresh();
        let _weak = arc.as_weak();
        assert!(arc.try_as_mut().is_none());
    }

    #[test]
    fn test_mutable_reference_when_unique() {
        let mut arc = fresh();
        let slot = arc
            .try_as_mut()
            .expect("a freshly created GCArc should be unique");
        slot.value = 42;
        assert_eq!(slot.value, 42);
    }

    #[test]
    #[should_panic(expected = "Cannot get mutable reference: GCArc is not unique")]
    fn test_get_mut_panics_when_not_unique() {
        let mut arc = fresh();
        let _other = arc.clone();
        let _ = arc.get_mut();
    }

    #[test]
    fn test_mutable_reference_after_clone_dropped() {
        let mut arc = fresh();
        let other = arc.clone();
        assert!(arc.try_as_mut().is_none());
        drop(other);
        assert!(arc.try_as_mut().is_some());
    }

    #[test]
    fn test_mutable_reference_after_weak_dropped() {
        let mut arc = fresh();
        let weak = arc.as_weak();
        assert!(arc.try_as_mut().is_none());
        drop(weak);
        assert!(arc.try_as_mut().is_some());
    }

    #[test]
    fn test_original_ub_scenario_now_safe() {
        let mut arc = fresh();
        let other = arc.clone();
        assert!(
            arc.try_as_mut().is_none(),
            "Should not be able to get mutable reference when multiple strong refs exist"
        );
        assert_eq!(other.as_ref().value, 1);
    }

    #[test]
    fn test_reference_counts() {
        let arc = fresh();
        let counts = |handle: &GCArc<TestValue>| (handle.strong_ref(), handle.weak_ref());
        assert_eq!(counts(&arc), (1, 0));
        let other = arc.clone();
        assert_eq!(counts(&arc), (2, 0));
        let weak = arc.as_weak();
        assert_eq!(counts(&arc), (2, 1));
        drop(other);
        assert_eq!(counts(&arc), (1, 1));
        drop(weak);
        assert_eq!(counts(&arc), (1, 0));
    }

    #[test]
    fn test_complex_reference_scenario() {
        let mut arc = fresh();
        let second = arc.clone();
        let third = arc.clone();
        assert!(arc.try_as_mut().is_none());
        drop(second);
        assert!(arc.try_as_mut().is_none());
        drop(third);
        assert!(arc.try_as_mut().is_some());

        let weak_a = arc.as_weak();
        let weak_b = arc.as_weak();
        assert!(arc.try_as_mut().is_none());
        drop(weak_a);
        assert!(arc.try_as_mut().is_none());
        drop(weak_b);
        assert!(arc.try_as_mut().is_some());
    }

    #[test]
    fn test_demonstrate_original_ub_prevention() {
        let mut first = fresh();
        let mut second = first.clone();
        assert!(first.try_as_mut().is_none());
        assert!(second.try_as_mut().is_none());
        drop(second);
        assert!(first.try_as_mut().is_some());
    }

    #[test]
    fn test_weak_upgrade_race_condition_fix() {
        let arc = fresh();
        let weak = arc.as_weak();
        let upgraded = weak.upgrade();
        assert!(upgraded.is_some());
        drop(arc);
        drop(upgraded);
        assert!(weak.upgrade().is_none());
    }
}