// koi_gc/lib.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::{HashMap, HashSet},
4    marker::PhantomData,
5    ops::Deref,
6    ptr::NonNull,
7};
8
// `VisitGc` implementations for standard-library types (e.g. the `i32` and
// `Option`/`RefCell` visits used in the tests below).
// NOTE(review): module contents are not visible here — confirm coverage.
mod std_impls;
10
/// The integer type used by the Gc for object ids and reference counts. It is the hard limit on
/// the number of live objects in the Gc and on the number of clones of each object; thus the
/// total number of live `Gc` handles is bounded by this value squared.
/// This is set by the following feature flags:
///
/// - default (unset): u16 (65,535)
/// - small: u8 (255)
/// - medium: u16 (65,535)
/// - large: usize (32 bit: 4,294,967,295 , 64 bit: 18,446,744,073,709,551,615)
#[cfg(all(
    not(feature = "large"),
    not(feature = "medium"),
    not(feature = "small")
))]
pub type GcSize = u16; // Default
#[cfg(all(not(feature = "large"), not(feature = "medium"), feature = "small"))]
pub type GcSize = u8; // Small
#[cfg(all(not(feature = "large"), feature = "medium"))]
pub type GcSize = u16; // Medium
#[cfg(all(feature = "large"))]
pub type GcSize = usize; // Large
31
32// todo replace RefCell with UnsafeCell in release mode
33thread_local! {
34    static CYCLE_CHECKS: Cell<bool> = Cell::new(true);
35    static GC_REGISTRY: RefCell<HashMap<GcSize, NonNull<dyn GcInterface + 'static>>> = RefCell::new(HashMap::with_capacity(AVAILABLE_ID_INITIAL_CAPACITY));
36    static AVAILABLE_IDS: RefCell<Vec<GcSize>> = RefCell::new(Vec::with_capacity(AVAILABLE_ID_INITIAL_CAPACITY));
37    static GC_ID_COUNTER: Cell<GcSize> = Cell::new(0);
38    // Track allocations since last collection
39    static ALLOCATIONS_SINCE_LAST_CYCLE_CHECK: Cell<usize> = Cell::new(0);
40    static REMAINING_ALLOCATIONS_AFTER_LAST_CYCLE_CHECK: Cell<usize> = Cell::new(0);
41}
42
// Initial capacity for the registry map and the free-id pool, scaled with the
// chosen GcSize feature so small builds don't over-allocate up front.
#[cfg(all(
    not(feature = "large"),
    not(feature = "medium"),
    not(feature = "small")
))]
const AVAILABLE_ID_INITIAL_CAPACITY: usize = 256; // Default
#[cfg(all(not(feature = "large"), not(feature = "medium"), feature = "small"))]
const AVAILABLE_ID_INITIAL_CAPACITY: usize = 32; // Small
#[cfg(all(not(feature = "large"), feature = "medium"))]
const AVAILABLE_ID_INITIAL_CAPACITY: usize = 256; // Medium
#[cfg(all(feature = "large"))]
const AVAILABLE_ID_INITIAL_CAPACITY: usize = 1024; // Large

/// The threshold number of allocations to consider performing a cycle collection
/// (past it, collection only runs if live objects have also quadrupled).
const COLLECTION_SOFT_THRESHOLD: usize = 100;
/// The threshold number of allocations past which a cycle collection is
/// performed unconditionally (when automatic checks are enabled).
const COLLECTION_HARD_THRESHOLD: usize = 10000;
60
/// A single-threaded, reference-counted, garbage-collected smart pointer.
/// Cloning bumps a strong count; reference cycles can additionally be
/// reclaimed by [`Gc::collect_cycles`].
pub struct Gc<T>
where
    T: VisitGc,
{
    // Points at the leaked `GcInner` allocation; the same pointer is also
    // registered (type-erased) in the thread-local GC_REGISTRY.
    ptr: NonNull<GcInner<T>>,
    // Declares logical ownership of the `GcInner<T>` for drop-check purposes.
    phantom: PhantomData<GcInner<T>>,
}
68
/// Heap payload shared by all clones of one `Gc`: registry id, strong count,
/// and the user data.
struct GcInner<T>
where
    T: VisitGc,
{
    /// Unique id under which this allocation is registered in GC_REGISTRY.
    id: GcSize,
    /// Number of live `Gc` handles. 0 is reserved as the "already reclaimed
    /// by the collector" sentinel (live objects always start at 1).
    strong: Cell<GcSize>,
    /// The user value; its `VisitGc` impl reports outgoing `Gc` edges.
    data: T,
}
77
/// Per-object scratch state that exists only while [`Gc::collect_cycles`]
/// runs; never stored alongside the object itself.
struct GcInnerCheck {
    /// Temporary ref count used during cycle detection
    gc_refs: GcSize,
    /// Mark for reachability during cycle detection
    marked: bool,
}
84
85impl<T> Gc<T>
86where
87    T: VisitGc,
88{
89    pub fn new(data: T) -> Self {
90        if CYCLE_CHECKS.get() {
91            Gc::maybe_collect_cycles();
92        }
93        ALLOCATIONS_SINCE_LAST_CYCLE_CHECK.with(|e| e.update(|e| e + 1));
94        let id = AVAILABLE_IDS.with(|available_ids| {
95            let mut available_ids = available_ids.borrow_mut();
96            let next_id = if let Some(free_id) = available_ids.pop() {
97                free_id
98            } else {
99                #[cfg(all(
100                    not(feature = "large"),
101                    not(feature = "medium"),
102                    not(feature = "small")
103                ))]
104                const EXPECT_MSG: &str = "More than GcSize::MAX objects created and alive. Try setting feature flag 'large'"; // Default
105                #[cfg(all(not(feature = "large"), not(feature = "medium"), feature = "small"))]
106                const EXPECT_MSG: &str = "More than GcSize::MAX objects created and alive. Try setting feature flag 'medium' or 'large'"; // Small
107                #[cfg(all(not(feature = "large"), feature = "medium"))]
108                const EXPECT_MSG: &str = "More than GcSize::MAX objects created and alive. Try setting feature flag 'large'"; // Medium
109                #[cfg(all(feature = "large"))]
110                const EXPECT_MSG: &str = "More than GcSize::MAX objects created and alive."; // Large
111                let next_id = GC_ID_COUNTER.get().checked_add(1).expect(EXPECT_MSG);
112                GC_ID_COUNTER.set(next_id);
113                next_id
114            };
115            next_id
116        });
117        let inner = Box::leak(Box::new(GcInner {
118            id,
119            strong: Cell::new(1),
120            data,
121        }));
122        let ptr = NonNull::from(inner);
123        GC_REGISTRY.with(|registry| {
124            let mut registry = registry.borrow_mut();
125            let current = registry.insert(id, ptr);
126            debug_assert!(current.is_none());
127        });
128        Gc {
129            ptr,
130            phantom: PhantomData,
131        }
132    }
133
134    fn inner(&self) -> &GcInner<T> {
135        unsafe { self.ptr.as_ref() }
136    }
137}
138
139impl Gc<()> {
140    pub fn collect_cycles() -> usize {
141        let mut ids_to_reclaim = Vec::new();
142        GC_REGISTRY.with(|registry| {
143            let registry = registry.borrow_mut();
144            let mut gc_cycle_check = HashMap::with_capacity(registry.len());
145            // Copy reference counts to gc_refs and clear marks
146            for obj_ptr in registry.values() {
147                let obj = unsafe { obj_ptr.as_ref() };
148                let id = obj.id();
149                let gc_refs = obj.strong();
150                let obj_check = GcInnerCheck {
151                    gc_refs,
152                    marked: false,
153                };
154                gc_cycle_check.insert(id, obj_check);
155            }
156
157            for obj_ptr in registry.values() {
158                let obj = unsafe { obj_ptr.as_ref() };
159                obj.visit_edges(&mut |edge| {
160                    let id = edge.id();
161                    let obj_check_edge = gc_cycle_check.get_mut(&id).unwrap();
162                    let current = obj_check_edge.gc_refs;
163                    if current > 0 {
164                        obj_check_edge.gc_refs = current - 1;
165                    }
166                });
167            }
168
169            let unaccounted_refs = &mut ids_to_reclaim; // reuse Vec allocation
170            for obj_ptr in registry.values() {
171                let obj = unsafe { obj_ptr.as_ref() };
172                let id = obj.id();
173                let obj_check = gc_cycle_check.get_mut(&id).unwrap();
174                if obj_check.gc_refs > 0 {
175                    unaccounted_refs.push(id);
176                }
177            }
178
179            let to_visit = unaccounted_refs;
180            let mut visited = HashSet::new();
181
182            while let Some(id) = to_visit.pop() {
183                if !visited.insert(id) {
184                    continue; // Already visited
185                }
186
187                if let Some(obj_ptr) = registry.get(&id) {
188                    let obj = unsafe { obj_ptr.as_ref() };
189                    let id = obj.id();
190                    let obj_check = gc_cycle_check.get_mut(&id).unwrap();
191                    obj_check.marked = true;
192
193                    // Add all edges to visit queue
194                    obj.visit_edges(&mut |edge| {
195                        let id = edge.id();
196                        if !visited.contains(&id) {
197                            to_visit.push(id);
198                        }
199                    });
200                }
201            }
202
203            ids_to_reclaim.clear(); // reuse Vec allocation
204            for (_, obj_ptr) in registry.iter() {
205                let obj = unsafe { obj_ptr.as_ref() };
206                let id = obj.id();
207                let obj_check = gc_cycle_check.get_mut(&id).unwrap();
208                if obj_check.marked {
209                    obj_check.marked = false;
210                    continue;
211                }
212                ids_to_reclaim.push(id);
213            }
214        });
215
216        let collected = ids_to_reclaim.len();
217        let num_of_allocations_remaining = GC_REGISTRY.with(|registry| {
218            AVAILABLE_IDS.with(|ids| {
219                for id in ids_to_reclaim {
220                    let removed = registry.borrow_mut().remove(&id);
221                    if let Some(removed) = removed {
222                        ids.borrow_mut().push(id);
223                        let temp = unsafe { removed.as_ref() };
224                        // Set mark flag so does not try to remove from registry for within drop
225                        temp.mark_gc_deleted();
226                        unsafe {
227                            drop(Box::from_raw(removed.as_ptr()));
228                        }
229                    }
230                }
231            });
232            registry.borrow().len()
233        });
234
235        // Reset allocation counter
236        ALLOCATIONS_SINCE_LAST_CYCLE_CHECK.set(0);
237        REMAINING_ALLOCATIONS_AFTER_LAST_CYCLE_CHECK.set(num_of_allocations_remaining);
238
239        collected
240    }
241
242    /// Check if [`Gc::collect_cycles`] should run based on the internal heuristic and if so runs it.
243    /// This should only really ever be called externally if [`Gc::enable_acc`] was just called.
244    fn maybe_collect_cycles() -> Option<usize> {
245        let allocations_since_last_collection = ALLOCATIONS_SINCE_LAST_CYCLE_CHECK.get();
246        if allocations_since_last_collection > COLLECTION_HARD_THRESHOLD {
247            return Some(Gc::collect_cycles());
248        } else if allocations_since_last_collection > COLLECTION_SOFT_THRESHOLD {
249            // If allocations are exploding, do a cycle collection, otherwise things are likely being reclaimed as normal
250            let current_allocations = GC_REGISTRY.with(|e| e.borrow().len());
251            let has_allocations_quadrupled =
252                current_allocations > 4 * REMAINING_ALLOCATIONS_AFTER_LAST_CYCLE_CHECK.get();
253            if has_allocations_quadrupled {
254                return Some(Gc::collect_cycles());
255            }
256        }
257        None
258    }
259
260    /// Cycle collection is disabled by default. If called, enables automatic cycle collection -
261    /// [`Gc::collect_cycles`] is automatically called based the
262    /// internal allocation heuristic.
263    /// Consider pairing this with [`Gc::maybe_collect_cycles`].
264    pub fn enable_acc(&self) {
265        CYCLE_CHECKS.set(true);
266    }
267
268    /// Cycle collection is disabled by default. If called, disables automatic cycle collection -
269    /// and [`Gc::collect_cycles`] needs to be called manually
270    pub fn disable_acc(&self) {
271        CYCLE_CHECKS.set(false);
272    }
273}
274
275impl<T> Clone for Gc<T>
276where
277    T: VisitGc,
278{
279    fn clone(&self) -> Self {
280        let inner = self.inner();
281        let current = inner.strong.get();
282        inner.strong.set(current + 1);
283
284        Gc {
285            ptr: self.ptr,
286            phantom: PhantomData,
287        }
288    }
289}
290
impl<T> Drop for Gc<T>
where
    T: VisitGc,
{
    /// Decrements the strong count; the last handle unregisters and frees the
    /// allocation. Order matters: the registry entry and id are released
    /// before the payload is dropped, so a payload Drop that itself drops
    /// `Gc` handles sees a consistent registry.
    fn drop(&mut self) {
        let inner = self.inner();
        let current = inner.strong.get();
        // Deleted by the gc and already handled. Initial value of this type is always 1 so this is a special flag.
        // (collect_cycles sets strong to 0 via mark_gc_deleted before freeing,
        // so handles dropped from inside a dying cycle must do nothing here.)
        if current == 0 {
            return;
        }
        if current == 1 {
            // Last live handle: remove from the registry first …
            GC_REGISTRY.with(|registry| {
                let mut registry = registry.borrow_mut();
                let removed = registry.remove(&inner.id);
                debug_assert!(removed.is_some());
            });
            // … recycle the id for future allocations …
            AVAILABLE_IDS.with(|ids| {
                ids.borrow_mut().push(inner.id);
            });
            // Deallocate the inner box
            // SAFETY: strong count was 1, so this is the sole remaining handle;
            // the pointer originated from Box::leak in Gc::new.
            unsafe {
                drop(Box::from_raw(self.ptr.as_ptr()));
            }
        } else {
            inner.strong.set(current - 1);
        }
    }
}
320
321impl<T> Deref for Gc<T>
322where
323    T: VisitGc,
324{
325    type Target = T;
326
327    fn deref(&self) -> &Self::Target {
328        &self.inner().data
329    }
330}
331
332/// Dyn interface for Gc. Marked as unsafe since should not be implemented.
333pub unsafe trait GcInterface {
334    fn id(&self) -> GcSize;
335    fn strong(&self) -> GcSize;
336    fn mark_gc_deleted(&self);
337    fn visit_edges(&self, visitor: &mut dyn FnMut(&dyn GcInterface));
338}
339
340unsafe impl<T> GcInterface for GcInner<T>
341where
342    T: VisitGc,
343{
344    fn id(&self) -> GcSize {
345        self.id
346    }
347
348    fn strong(&self) -> GcSize {
349        self.strong.get()
350    }
351
352    fn mark_gc_deleted(&self) {
353        self.strong.set(0);
354    }
355
356    fn visit_edges(&self, visitor: &mut dyn FnMut(&dyn GcInterface)) {
357        self.data.visit(visitor);
358    }
359}
360
// todo verify this
/// SAFETY: Implementing this trait is marked as `unsafe` since an incorrect
/// implementation — one that fails to report every reachable `Gc` edge — can
/// cause the garbage collector to free data that is still pointed at by
/// accessible data.
pub unsafe trait VisitGc: 'static {
    /// Calls `visitor` on every gc-managed object directly owned by `self`.
    fn visit(&self, visitor: &mut dyn FnMut(&dyn GcInterface));
}
368
369unsafe impl<T> VisitGc for Gc<T>
370where
371    T: VisitGc,
372{
373    fn visit(&self, visitor: &mut dyn FnMut(&dyn GcInterface)) {
374        visitor(self.inner());
375    }
376}
377
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use super::*;

    /// Simple linked-list node used to build reference cycles.
    #[derive(Clone)]
    struct GcNode {
        value: i32,
        next: Option<Gc<GcNode>>,
    }

    unsafe impl VisitGc for GcNode {
        fn visit(&self, visitor: &mut dyn FnMut(&dyn GcInterface)) {
            self.value.visit(visitor);
            self.next.visit(visitor);
        }
    }

    #[test]
    fn test_cycle_collection() {
        // Create a cycle: A -> B -> A
        let a = Gc::new(GcNode {
            value: 1,
            next: None,
        });

        let b = Gc::new(GcNode {
            value: 2,
            next: Some(a.clone()),
        });

        // complete the cycle by mutating A's payload in place
        unsafe {
            let a_inner = a.ptr.as_ref();
            let data_ptr = &a_inner.data as *const GcNode as *mut GcNode;
            (*data_ptr).next = Some(b.clone());
        }

        // no cycle: D -> C, reclaimed by plain refcounting on drop
        let c = Gc::new(GcNode {
            value: 3,
            next: None,
        });

        let d = Gc::new(GcNode {
            value: 4,
            next: Some(c.clone()),
        });

        drop(a);
        drop(b);
        drop(c);
        drop(d);

        // Run collection - should collect the cycle
        let collected = Gc::collect_cycles();
        assert_eq!(collected, 2, "Should collect the 2 nodes in the cycle");
    }

    struct Leak<T> {
        cycle: RefCell<Option<Rc<Rc<Leak<T>>>>>,
        data: T,
    }

    /// Leaks `data` inside an `Rc` reference cycle, which is invisible to the Gc.
    fn rc_cycle_forget<T>(data: T) {
        let e = Rc::new(Leak {
            cycle: RefCell::new(None),
            data,
        });
        *e.cycle.borrow_mut() = Some(Rc::new(e.clone()));
    }

    struct GcLeak<T>
    where
        T: VisitGc,
    {
        cycle: RefCell<Option<Gc<Gc<GcLeak<T>>>>>,
        data: T,
    }

    unsafe impl<T> VisitGc for GcLeak<T>
    where
        T: VisitGc,
    {
        fn visit(&self, visitor: &mut dyn FnMut(&dyn GcInterface)) {
            self.cycle.visit(visitor);
            self.data.visit(visitor);
        }
    }

    /// Leaks `data` inside a `Gc` reference cycle, recoverable by the collector.
    fn gc_cycle_forget<T: VisitGc>(data: T) {
        let e = Gc::new(GcLeak {
            cycle: RefCell::new(None),
            data,
        });
        *e.cycle.borrow_mut() = Some(Gc::new(e.clone()));
    }

    #[test]
    fn not_leaked() {
        thread_local! {
            static INCREMENT_ON_DROP: Cell<usize> = Cell::new(0);
        }

        struct TestData {}

        unsafe impl VisitGc for TestData {
            fn visit(&self, _visitor: &mut dyn FnMut(&dyn GcInterface)) {}
        }

        impl Drop for TestData {
            fn drop(&mut self) {
                INCREMENT_ON_DROP.set(INCREMENT_ON_DROP.get() + 1);
            }
        }

        // Plain drop frees immediately via refcounting.
        let temp = Gc::new(TestData {});
        drop(temp);
        assert_eq!(INCREMENT_ON_DROP.get(), 1);

        // A Gc trapped in an Rc cycle is invisible to the collector.
        let temp = Gc::new(TestData {});
        rc_cycle_forget(temp);
        Gc::collect_cycles();
        assert_eq!(
            INCREMENT_ON_DROP.get(),
            1,
            "Gc::collect_cycles should do nothing since it does not know about rc cycles"
        );

        // A Gc cycle is only reclaimed once the collector actually runs.
        let temp = Gc::new(TestData {});
        gc_cycle_forget(temp);
        let temp = Gc::new(TestData {});
        assert_eq!(
            INCREMENT_ON_DROP.get(),
            1,
            "Gc::collect_cycles has not run yet"
        );
        Gc::collect_cycles();
        assert_eq!(
            INCREMENT_ON_DROP.get(),
            2,
            "Gc::collect_cycles should detect the cycle and drop only the one leaked TestData"
        );
        drop(temp);
        assert_eq!(INCREMENT_ON_DROP.get(), 3);

        // mem::forget leaks are acyclic and therefore never collected.
        let temp = Gc::new(TestData {});
        std::mem::forget(temp);
        assert_eq!(INCREMENT_ON_DROP.get(), 3);
        Gc::collect_cycles();
        assert_eq!(
            INCREMENT_ON_DROP.get(),
            3,
            "Gc::collect_cycles only works for cycles. mem::forget is unfortunately marked as safe"
        );
    }
}
532}