circ/
utils.rs

1use std::cell::Cell;
2use std::sync::atomic::Ordering;
3use std::{mem::ManuallyDrop, sync::atomic::AtomicU64};
4
5use crate::ebr_impl::{cs, global_epoch, Guard, Tagged, HIGH_TAG_WIDTH};
6use crate::RcObject;
7
8/// Raw pointer to a reference counted object. Allows tagging.
9pub(crate) type Raw<T> = Tagged<RcInner<T>>;
10
11trait Deferable {
12    unsafe fn defer_with_inner<T, F>(&self, ptr: *mut RcInner<T>, f: F)
13    where
14        F: FnOnce(*mut RcInner<T>);
15}
16
17impl Deferable for Guard {
18    unsafe fn defer_with_inner<T, F>(&self, ptr: *mut RcInner<T>, f: F)
19    where
20        F: FnOnce(*mut RcInner<T>),
21    {
22        debug_assert!(!ptr.is_null());
23        self.defer_unchecked(move || f(ptr));
24    }
25}
26
27impl Deferable for Option<&Guard> {
28    unsafe fn defer_with_inner<T, F>(&self, ptr: *mut RcInner<T>, f: F)
29    where
30        F: FnOnce(*mut RcInner<T>),
31    {
32        if let Some(guard) = self {
33            guard.defer_with_inner(ptr, f)
34        } else {
35            cs().defer_with_inner(ptr, f)
36        }
37    }
38}
39
40const EPOCH_WIDTH: u32 = HIGH_TAG_WIDTH;
41const EPOCH_MASK_HEIGHT: u32 = u64::BITS - EPOCH_WIDTH;
42const EPOCH: u64 = ((1 << EPOCH_WIDTH) - 1) << EPOCH_MASK_HEIGHT;
43const DESTRUCTED: u64 = 1 << (EPOCH_MASK_HEIGHT - 1);
44const WEAKED: u64 = 1 << (EPOCH_MASK_HEIGHT - 2);
45const TOTAL_COUNT_WIDTH: u32 = u64::BITS - EPOCH_WIDTH - 2;
46const WEAK_WIDTH: u32 = TOTAL_COUNT_WIDTH / 2;
47const STRONG_WIDTH: u32 = TOTAL_COUNT_WIDTH - WEAK_WIDTH;
48const STRONG: u64 = (1 << STRONG_WIDTH) - 1;
49const WEAK: u64 = ((1 << WEAK_WIDTH) - 1) << STRONG_WIDTH;
50const COUNT: u64 = 1;
51const WEAK_COUNT: u64 = 1 << STRONG_WIDTH;
52
53thread_local! {
54    static DISPOSE_COUNTER: Cell<usize> = const { Cell::new(0) };
55}
56
57/// Effectively wraps the presence of epoch and destruction bits.
58#[derive(Clone, Copy)]
59struct State {
60    inner: u64,
61}
62
63impl State {
64    fn from_raw(inner: u64) -> Self {
65        Self { inner }
66    }
67
68    fn epoch(self) -> u32 {
69        ((self.inner & EPOCH) >> EPOCH_MASK_HEIGHT) as u32
70    }
71
72    fn strong(self) -> u32 {
73        ((self.inner & STRONG) / COUNT) as u32
74    }
75
76    fn weak(self) -> u32 {
77        ((self.inner & WEAK) / WEAK_COUNT) as u32
78    }
79
80    fn destructed(self) -> bool {
81        (self.inner & DESTRUCTED) != 0
82    }
83
84    fn weaked(&self) -> bool {
85        (self.inner & WEAKED) != 0
86    }
87
88    fn with_epoch(self, epoch: usize) -> Self {
89        Self::from_raw((self.inner & !EPOCH) | (((epoch as u64) << EPOCH_MASK_HEIGHT) & EPOCH))
90    }
91
92    fn add_strong(self, val: u32) -> Self {
93        Self::from_raw(self.inner + (val as u64) * COUNT)
94    }
95
96    fn sub_strong(self, val: u32) -> Self {
97        debug_assert!(self.strong() >= val);
98        Self::from_raw(self.inner - (val as u64) * COUNT)
99    }
100
101    fn add_weak(self, val: u32) -> Self {
102        Self::from_raw(self.inner + (val as u64) * WEAK_COUNT)
103    }
104
105    fn with_destructed(self, dest: bool) -> Self {
106        Self::from_raw((self.inner & !DESTRUCTED) | if dest { DESTRUCTED } else { 0 })
107    }
108
109    fn with_weaked(self, weaked: bool) -> Self {
110        Self::from_raw((self.inner & !WEAKED) | if weaked { WEAKED } else { 0 })
111    }
112
113    fn as_raw(self) -> u64 {
114        self.inner
115    }
116}
117
/// Orders values that are only known modulo `2^WIDTH` (e.g. truncated epochs).
///
/// Values are interpreted relative to a reference point `max` (the newest value). `trans` maps
/// any `v` onto `[-2^WIDTH, -1]` by how far `v` lies behind `max` modulo `2^WIDTH`: `max` itself
/// maps to `-1` (newest) and `max - (2^WIDTH - 1)` maps to `-2^WIDTH` (oldest).
struct Modular<const WIDTH: u32> {
    max: isize,
}

impl<const WIDTH: u32> Modular<WIDTH> {
    /// Size of the modular space, `2^WIDTH`.
    const MODULUS: isize = 1 << WIDTH;

    /// Creates a modular space where `max` is the maximum (newest) value.
    pub fn new(max: isize) -> Self {
        Self { max }
    }

    /// Sends a number into the modular space, i.e. onto `[-2^WIDTH, -1]`.
    ///
    /// `rem_euclid` is required here. With `%`, an exact multiple of `2^WIDTH` yields `0`
    /// instead of `-2^WIDTH`. That would make the oldest value in the window compare as newer
    /// than `max` itself.
    fn trans(&self, val: isize) -> isize {
        debug_assert!(val <= self.max);
        (val - (self.max + 1)).rem_euclid(Self::MODULUS) - Self::MODULUS
    }

    /// Receives a number from the modular space, returning it in `[0, 2^WIDTH)`.
    fn inver(&self, val: isize) -> isize {
        (val + (self.max + 1)).rem_euclid(Self::MODULUS)
    }

    /// Returns the newest of `nums` (each reduced modulo `2^WIDTH` first), in `[0, 2^WIDTH)`.
    ///
    /// `nums` is expected to be non-empty; for an empty slice the result carries no meaning.
    pub fn max(&self, nums: &[isize]) -> isize {
        let newest = nums
            .iter()
            .map(|&val| self.trans(val % Self::MODULUS))
            .fold(isize::MIN, isize::max);
        self.inver(newest)
    }

    /// Checks if `a` is less than or equal to (i.e. not newer than) `b` in the modular space.
    pub fn le(&self, a: isize, b: isize) -> bool {
        self.trans(a) <= self.trans(b)
    }
}
150
/// Heap cell holding one object of type `T` together with its atomic reference-count word.
///
/// `storage` is wrapped in [`ManuallyDrop`], so dropping the cell itself never runs `T`'s
/// destructor; the object is dropped explicitly during disposal.
pub(crate) struct RcInner<T> {
    /// The managed object.
    storage: ManuallyDrop<T>,
    /// Packed epoch, `DESTRUCTED`/`WEAKED` flags and weak/strong counts (decoded by `State`).
    state: AtomicU64,
}
156
impl<T> RcInner<T> {
    /// Moves `obj` to the heap and returns a raw pointer to the new cell.
    ///
    /// The state starts with strong count `init_strong` and weak count 1 (epoch 0 and no
    /// flags, assuming `init_strong` fits the strong field). The initial weak count is not tied
    /// to any weak handle: during disposal it is released through `decrement_weak` when
    /// `WEAKED` is set; otherwise the cell is freed directly with `dealloc`.
    #[inline(always)]
    pub(crate) fn alloc(obj: T, init_strong: u32) -> *mut Self {
        let obj = Self {
            storage: ManuallyDrop::new(obj),
            state: AtomicU64::new((init_strong as u64) * COUNT + WEAK_COUNT),
        };
        Box::into_raw(Box::new(obj))
    }

    /// Frees the allocation behind `ptr`.
    ///
    /// Only the cell is freed. `storage` is `ManuallyDrop`, so `T`'s destructor is *not* run
    /// here; the caller is responsible for having dropped it (or deliberately not).
    ///
    /// # Safety
    ///
    /// `ptr` must have been returned by [`RcInner::alloc`] (it is rebuilt with
    /// `Box::from_raw`) and must not be used afterwards.
    /// The given `ptr` must not be shared across more than one thread.
    pub(crate) unsafe fn dealloc(ptr: *mut Self) {
        drop(Box::from_raw(ptr));
    }

    /// Returns an immutable reference to the object.
    pub fn data(&self) -> &T {
        &self.storage
    }

    /// Returns a mutable reference to the object.
    pub fn data_mut(&mut self) -> &mut T {
        &mut self.storage
    }

    /// Tries to take one strong reference; returns `false` if the object is destructed.
    ///
    /// The count is bumped by `fetch_add` *before* the `DESTRUCTED` check. On the `false` path,
    /// the destructed object's strong count has therefore still been incremented. If the
    /// previous strong count was 0, a second count is added (see the comment below).
    #[inline]
    pub(crate) fn increment_strong(&self) -> bool {
        let val = State::from_raw(self.state.fetch_add(COUNT, Ordering::SeqCst));
        if val.destructed() {
            return false;
        }
        if val.strong() == 0 {
            // The previous fetch_add created a permission to run decrement again.
            // Now create an actual reference.
            self.state.fetch_add(COUNT, Ordering::SeqCst);
        }
        true
    }

    /// Deferred final step, scheduled by `decrement_weak` once the weak count reached zero.
    ///
    /// The weak count may have become non-zero again in the meantime (e.g. through the extra
    /// count `increment_weak` adds when it sees zero). In that case one weak count is released
    /// again; otherwise the cell is freed with `dealloc`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live cell from `alloc`. NOTE(review): `storage` is presumably
    /// already dropped by this point, since `dealloc` does not drop it — confirm.
    #[inline]
    unsafe fn try_dealloc(ptr: *mut Self) {
        if State::from_raw((*ptr).state.load(Ordering::SeqCst)).weak() > 0 {
            Self::decrement_weak(ptr, None);
        } else {
            Self::dealloc(ptr);
        }
    }

    /// Adds `count` weak references.
    ///
    /// The first call sets `WEAKED` and adds `count` in a single CAS. Once `WEAKED` is set,
    /// the count is added with `fetch_add`. If the previous weak count was 0, one more weak
    /// count is added, mirroring the extra strong count in `increment_strong`.
    #[inline]
    pub(crate) fn increment_weak(&self, count: u32) {
        let mut old = State::from_raw(self.state.load(Ordering::SeqCst));
        while !old.weaked() {
            // In this case, `increment_weak` must have been called from `Rc::downgrade`,
            // guaranteeing weak > 0, so it can't be incremented from 0.
            debug_assert!(old.weak() != 0);
            match self.state.compare_exchange(
                old.as_raw(),
                old.with_weaked(true).add_weak(count).as_raw(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return,
                Err(curr) => old = State::from_raw(curr),
            }
        }
        // `WEAKED` is already set: plain fetch_add, plus one extra count if it was at zero.
        if State::from_raw(
            self.state
                .fetch_add(count as u64 * WEAK_COUNT, Ordering::SeqCst),
        )
        .weak()
            == 0
        {
            self.state.fetch_add(WEAK_COUNT, Ordering::SeqCst);
        }
    }

    /// Releases one weak reference.
    ///
    /// If this was the last one (previous weak count 1), `try_dealloc` is deferred on `guard`.
    /// When `guard` is `None`, the guard returned by `cs()` is used instead.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live cell from `alloc` whose weak count is at least 1 (checked
    /// only by `debug_assert!`).
    #[inline]
    pub(crate) unsafe fn decrement_weak(ptr: *mut Self, guard: Option<&Guard>) {
        debug_assert!(State::from_raw((*ptr).state.load(Ordering::SeqCst)).weak() >= 1);
        if State::from_raw((*ptr).state.fetch_sub(WEAK_COUNT, Ordering::SeqCst)).weak() == 1 {
            guard.defer_with_inner(ptr, |inner| Self::try_dealloc(inner));
        }
    }

    /// Returns `true` if the object has not been destructed.
    ///
    /// Side effect: if the strong count is 0 and the object is not destructed, the strong count
    /// is raised to 1 by CAS before `true` is returned. NOTE(review): this looks like the
    /// same "permission" count that `increment_strong` adds on a zero count — confirm with the
    /// callers.
    #[inline]
    pub(crate) fn is_not_destructed(&self) -> bool {
        let mut old = State::from_raw(self.state.load(Ordering::SeqCst));
        while !old.destructed() && old.strong() == 0 {
            match self.state.compare_exchange(
                old.as_raw(),
                old.add_strong(1).as_raw(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return true,
                Err(curr) => old = State::from_raw(curr),
            }
        }
        !old.destructed()
    }
}
260
impl<T: RcObject> RcInner<T> {
    /// Releases `count` strong references.
    ///
    /// The decrement happens in a CAS loop that also overwrites the state's epoch field with
    /// the current `global_epoch()`. If this brought the strong count to zero, `try_destruct`
    /// is deferred on the guard. In every case, `incr_manual_collection` is called on the guard.
    /// When `guard` is `None`, the guard returned by `cs()` is used.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live cell from `alloc` holding at least `count` strong references
    /// (checked only by `debug_assert!`).
    #[inline]
    pub(crate) unsafe fn decrement_strong(ptr: *mut Self, count: u32, guard: Option<&Guard>) {
        let epoch = global_epoch();
        // Should mark the current epoch on the strong count with CAS.
        let hit_zero = loop {
            let curr = State::from_raw((*ptr).state.load(Ordering::SeqCst));
            debug_assert!(curr.strong() >= count);
            if (*ptr)
                .state
                .compare_exchange(
                    curr.as_raw(),
                    curr.with_epoch(epoch).sub_strong(count).as_raw(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                // `curr` is the value before the decrement, so equality means it is now zero.
                break curr.strong() == count;
            }
        };

        let trigger_recl = |guard: &Guard| {
            if hit_zero {
                guard.defer_with_inner(ptr, |inner| Self::try_destruct(inner));
            }
            // Periodically triggers a collection.
            guard.incr_manual_collection();
        };

        if let Some(guard) = guard {
            trigger_recl(guard)
        } else {
            trigger_recl(&cs())
        }
    }

    /// Deferred destruction step, scheduled by `decrement_strong` when the strong count hit zero.
    ///
    /// The strong count may be non-zero again by now, e.g. through the extra count added by
    /// `increment_strong` or `is_not_destructed`. In that case one strong reference is released
    /// through `decrement_strong`. Otherwise `DESTRUCTED` is set by CAS and the node is handed
    /// to `dispose`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live cell from `alloc` that is not yet destructed (checked only by
    /// `debug_assert!`).
    #[inline]
    unsafe fn try_destruct(ptr: *mut Self) {
        let mut old = State::from_raw((*ptr).state.load(Ordering::SeqCst));
        debug_assert!(!old.destructed());
        loop {
            if old.strong() > 0 {
                Self::decrement_strong(ptr, 1, None);
                return;
            }
            match (*ptr).state.compare_exchange(
                old.as_raw(),
                old.with_destructed(true).as_raw(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                // Note that `decrement_weak` will be called in `dispose`.
                Ok(_) => return dispose(ptr),
                // Lost a race: re-check the strong count against the fresh state.
                Err(curr) => old = State::from_raw(curr),
            }
        }
    }
}
320
321#[inline]
322unsafe fn dispose<T: RcObject>(inner: *mut RcInner<T>) {
323    DISPOSE_COUNTER.with(|counter| {
324        let guard = &cs();
325        dispose_general_node(inner, 0, counter, guard);
326    });
327}
328
329#[inline]
330unsafe fn dispose_general_node<T: RcObject>(
331    ptr: *mut RcInner<T>,
332    depth: usize,
333    counter: &Cell<usize>,
334    guard: &Guard,
335) {
336    let rc = match ptr.as_mut() {
337        Some(rc) => rc,
338        None => return,
339    };
340
341    let count = counter.get();
342    counter.set(count + 1);
343    if count % 128 == 0 {
344        if let Some(local) = guard.local.as_ref() {
345            local.repin_without_collect();
346        }
347    }
348
349    if depth >= 1024 {
350        // Prevent a potential stack overflow.
351        guard.defer_with_inner(rc, |rc| RcInner::try_destruct(rc));
352        return;
353    }
354
355    let state = State::from_raw(rc.state.load(Ordering::SeqCst));
356    let node_epoch = state.epoch();
357    debug_assert_eq!(state.strong(), 0);
358
359    let curr_epoch = global_epoch();
360    let modu: Modular<EPOCH_WIDTH> = Modular::new(curr_epoch as isize + 1);
361    let mut outgoings = Vec::new();
362
363    // Note that checking whether it is a root is necessary, because if `node_epoch` is
364    // old enough, `modu.le` may return false.
365    if depth == 0 || modu.le(node_epoch as _, curr_epoch as isize - 3) {
366        // The current node is immediately reclaimable.
367        rc.data_mut().pop_edges(&mut outgoings);
368        unsafe {
369            ManuallyDrop::drop(&mut rc.storage);
370            if State::from_raw(rc.state.load(Ordering::SeqCst)).weaked() {
371                RcInner::decrement_weak(rc, Some(guard));
372            } else {
373                RcInner::dealloc(rc);
374            }
375        }
376        for next in outgoings.drain(..) {
377            if next.is_null() {
378                continue;
379            }
380
381            let next_ptr = next.into_raw();
382            let next_ref = next_ptr.deref();
383            let link_epoch = next_ptr.high_tag() as u32;
384
385            // Decrement next node's strong count and update its epoch.
386            let next_cnt = loop {
387                let cnt_curr = State::from_raw(next_ref.state.load(Ordering::SeqCst));
388                let next_epoch =
389                    modu.max(&[node_epoch as _, link_epoch as _, cnt_curr.epoch() as _]);
390                let cnt_next = cnt_curr.sub_strong(1).with_epoch(next_epoch as _);
391
392                if next_ref
393                    .state
394                    .compare_exchange(
395                        cnt_curr.as_raw(),
396                        cnt_next.as_raw(),
397                        Ordering::SeqCst,
398                        Ordering::SeqCst,
399                    )
400                    .is_ok()
401                {
402                    break cnt_next;
403                }
404            };
405
406            // If the reference count hit zero, try dispose it recursively.
407            if next_cnt.strong() == 0 {
408                dispose_general_node(next_ptr.as_raw(), depth + 1, counter, guard);
409            }
410        }
411    } else {
412        // It is likely to be unsafe to reclaim right now.
413        guard.defer_with_inner(rc, |rc| RcInner::try_destruct(rc));
414    }
415}