Skip to main content

radiate_core/domain/sync/
cell.rs

1use std::{
2    cell::UnsafeCell,
3    fmt::{Debug, Formatter},
4    ops::Deref,
5    sync::atomic::{AtomicUsize, Ordering},
6};
/// Heap payload shared by every `MutCell` handle that points at it.
///
/// The allocation is produced with `Box::into_raw` in `MutCell::new` and is turned back into a
/// `Box` (and thereby freed) by whichever handle releases the last reference.
struct ArcInner<T> {
    /// Number of live `MutCell` handles referring to this allocation.
    ref_count: AtomicUsize,
    /// The wrapped value. `UnsafeCell` is what permits handing out `&mut T` through the
    /// `*const` pointer stored in each handle.
    value: UnsafeCell<T>,
}
11
12pub struct MutCell<T> {
13    inner: *const ArcInner<T>,
14    consumed: bool,
15}
16
17// Ensure MutCell<T> is safe to send/sync if T is
18unsafe impl<T: Send> Send for MutCell<T> {}
19unsafe impl<T: Sync> Sync for MutCell<T> {}
20
21impl<T> MutCell<T> {
22    pub fn new(value: T) -> Self {
23        Self {
24            inner: Box::into_raw(Box::new(ArcInner {
25                value: UnsafeCell::new(value),
26                ref_count: AtomicUsize::new(1),
27            })),
28            consumed: false,
29        }
30    }
31
32    pub fn is_unique(&self) -> bool {
33        // SAFETY: We're only reading the ref_count
34        unsafe { (*self.inner).ref_count.load(Ordering::Acquire) == 1 }
35    }
36
37    pub fn is_shared(&self) -> bool {
38        !self.is_unique()
39    }
40
41    pub fn strong_count(&self) -> usize {
42        // SAFETY: We're only reading the ref_count
43        unsafe { (*self.inner).ref_count.load(Ordering::Acquire) }
44    }
45
46    pub fn borrow(&self) -> &T {
47        // SAFETY: This is inherently unsafe because we don't know if there exists a mutable
48        // reference to the inner value elsewhere.
49        //
50        // We assume that the caller has ensured that there are no mutable references
51        // to the inner value when calling this method. So straight up - make sure that you don't have
52        // any mutable references to the inner value when calling this method.
53        assert!(!self.consumed, "Cannot access consumed MutCell");
54        unsafe { &*(*self.inner).value.get() }
55    }
56
57    pub fn borrow_mut(&mut self) -> &mut T {
58        assert!(self.is_unique(), "Cannot mutably borrow shared MutCell");
59        unsafe { &mut *(*self.inner).value.get() }
60    }
61
62    pub fn into_inner(mut self) -> T
63    where
64        T: Clone,
65    {
66        // SAFETY: If there is more than one reference to the
67        // inner value, we will clone it and decrement the ref count.
68        // If there is only one reference, we will consume the inner value and
69        // drop the inner box.
70        unsafe {
71            if (*self.inner).ref_count.load(Ordering::Acquire) == 1 {
72                self.consumed = true;
73                std::sync::atomic::fence(Ordering::SeqCst);
74                let boxed = Box::from_raw(self.inner as *mut ArcInner<T>);
75                boxed.value.into_inner()
76            } else {
77                let clone = (*(*self.inner).value.get()).clone();
78                (*self.inner).ref_count.fetch_sub(1, Ordering::Release);
79                clone
80            }
81        }
82    }
83}
84
85impl<T> Clone for MutCell<T> {
86    fn clone(&self) -> Self {
87        // SAFETY: We are only incrementing the ref_count here - no mutable access is done.
88        unsafe {
89            (*self.inner).ref_count.fetch_add(1, Ordering::Relaxed);
90        }
91        Self {
92            inner: self.inner,
93            consumed: false,
94        }
95    }
96}
97
98impl<T> Drop for MutCell<T> {
99    fn drop(&mut self) {
100        if self.consumed {
101            return;
102        }
103
104        // SAFETY: We are decrementing the ref_count here - no mutable access is done.
105        // If the ref_count reaches zero, we can safely drop the inner value.
106        unsafe {
107            if (*self.inner).ref_count.fetch_sub(1, Ordering::Release) == 1 {
108                std::sync::atomic::fence(Ordering::Acquire);
109                drop(Box::from_raw(self.inner as *mut ArcInner<T>));
110            }
111        }
112    }
113}
114
115impl<T> Deref for MutCell<T> {
116    type Target = T;
117    fn deref(&self) -> &Self::Target {
118        self.borrow()
119    }
120}
121
122impl<T: PartialEq> PartialEq for MutCell<T> {
123    fn eq(&self, other: &Self) -> bool {
124        self.borrow() == other.borrow()
125    }
126}
127
128impl<T: PartialOrd> PartialOrd for MutCell<T> {
129    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
130        self.borrow().partial_cmp(other.borrow())
131    }
132}
133
134impl<T> From<T> for MutCell<T> {
135    fn from(value: T) -> Self {
136        Self::new(value)
137    }
138}
139
140impl<T: Debug> Debug for MutCell<T> {
141    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
142        write!(f, "{:?}", self.borrow())
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    // A unique cell can be mutated in place, and a clone taken afterwards shares the update.
    #[test]
    fn mutcell_basic_clone_and_mutation_updated() {
        let mut cell = MutCell::new(5);
        assert_eq!(*cell, 5);

        // Mutate the cell (unique, so this is allowed)
        *cell.borrow_mut() = 10;
        assert_eq!(*cell, 10);

        // Cloning happens only after mutation; both views must see the update.
        let cell2 = cell.clone();
        assert_eq!(*cell, 10);
        assert_eq!(*cell2, 10);
        assert_eq!(cell.strong_count(), 2);
        assert_eq!(cell2.strong_count(), 2);
    }

    // Once a second handle exists, mutable access through either one must be refused.
    #[test]
    #[should_panic(expected = "Cannot mutably borrow shared MutCell")]
    fn mutcell_borrow_mut_panics_when_shared() {
        let mut cell = MutCell::new(5);
        let _cell2 = cell.clone();
        let _ = cell.borrow_mut();
    }

    // A unique cell hands its value over without requiring another handle.
    #[test]
    fn mutcell_into_inner_unique() {
        let cell = MutCell::new(String::from("hello"));
        assert_eq!(cell.strong_count(), 1);

        let inner = cell.into_inner();
        assert_eq!(inner, "hello");
    }

    // A shared cell yields the value while another handle is still around.
    #[test]
    fn mutcell_into_inner_clone_when_multiple() {
        let cell = MutCell::new(String::from("hello"));
        let cell2 = cell.clone();
        assert_eq!(cell.strong_count(), 2);

        let inner = cell.into_inner();
        assert_eq!(inner, "hello");

        // Drop cell2 to avoid leak
        drop(cell2);
    }

    // Comparisons look at the wrapped values, not at handle identity.
    #[test]
    fn mutcell_partial_eq_and_ord() {
        let cell1 = MutCell::new(10);
        let cell2 = MutCell::new(20);
        let cell3 = MutCell::new(10);

        assert!(cell1 == cell3);
        assert!(cell1 != cell2);
        assert!(cell1 < cell2);
        assert!(cell2 > cell3);
        assert_eq!(cell1.partial_cmp(&cell3), Some(std::cmp::Ordering::Equal));

        // A clone compares equal to the handle it was cloned from.
        let alias = cell1.clone();
        assert!(alias == cell1);
    }

    #[test]
    fn mutcell_is_unique_and_shared() {
        let cell = MutCell::new(42);
        assert!(cell.is_unique());
        assert!(!cell.is_shared());
        assert_eq!(cell.strong_count(), 1);

        let cell2 = cell.clone();

        assert!(cell.is_shared());
        assert!(cell2.is_shared());
        assert!(!cell.is_unique());
        assert!(!cell2.is_unique());
        assert_eq!(cell.strong_count(), 2);
        assert_eq!(*cell, 42);
        assert_eq!(*cell2, 42);
        assert!(cell.borrow() == cell2.borrow());
    }

    #[test]
    fn mut_cell_drop() {
        let cell = MutCell::new(42);
        {
            let _cell2 = cell.clone();
            assert!(cell.is_shared());
        } // _cell2 goes out of scope, ref count should decrease

        assert!(cell.is_unique());
        assert_eq!(cell.strong_count(), 1);
        drop(cell); // Should not panic
    }

    #[test]
    fn mut_cell_deref() {
        let mut cell = MutCell::new(42);
        assert_eq!(*cell, 42);
        let mut_ref = cell.borrow_mut();
        *mut_ref = 100;
        assert_eq!(*cell, 100);
    }

    // `From` builds a unique cell and `Debug` prints the wrapped value.
    #[test]
    fn mutcell_from_and_debug() {
        let cell: MutCell<i32> = 7.into();
        assert!(cell.is_unique());
        assert_eq!(*cell, 7);
        assert_eq!(format!("{:?}", cell), "7");
    }
}
229        let mut_ref = cell.borrow_mut();
230        *mut_ref = 100;
231        assert_eq!(*cell, 100);
232    }
233}