Skip to main content

any_intern/
typed.rs

1use bumpalo::Bump;
2use std::{alloc::Layout, cell::Cell, marker::PhantomData, mem, ptr, slice};
3
4pub struct TypedArena<T> {
5    bump: Bump,
6    len: Cell<usize>,
7    _marker: PhantomData<T>,
8}
9
10impl<T> TypedArena<T> {
11    /// Returns number of elements in this arena.
12    pub const fn len(&self) -> usize {
13        self.len.get()
14    }
15
16    pub const fn is_empty(&self) -> bool {
17        self.len() == 0
18    }
19
20    pub fn alloc(&self, value: T) -> &mut T {
21        self.len.set(self.len() + 1);
22        self.bump.alloc(value)
23    }
24
25    pub fn clear(&mut self) {
26        self.drop_all();
27        self.bump.reset();
28        self.len.set(0);
29    }
30
31    fn drop_all(&mut self) {
32        if mem::needs_drop::<T>() {
33            if size_of::<T>() > 0 {
34                let stride = Layout::new::<T>().pad_to_align().size();
35                for chunk in self.bump.iter_allocated_chunks() {
36                    // Chunk would not be divisible by the `stride` especially when the stride is
37                    // greater than 16. In that case, we should ignore the remainder.
38                    let num_elems = chunk.len() / stride;
39                    let ptr = chunk.as_ptr().cast::<T>().cast_mut();
40                    unsafe {
41                        let slice = slice::from_raw_parts_mut(ptr, num_elems);
42                        ptr::drop_in_place(slice);
43                    }
44                }
45            } else {
46                let ptr = ptr::dangling_mut::<T>();
47                unsafe {
48                    let slice = slice::from_raw_parts_mut(ptr, self.len());
49                    ptr::drop_in_place(slice);
50                }
51            }
52        }
53    }
54}
55
56impl<T> Default for TypedArena<T> {
57    fn default() -> Self {
58        Self {
59            bump: Bump::new(),
60            len: Cell::new(0),
61            _marker: PhantomData,
62        }
63    }
64}
65
66impl<T> Drop for TypedArena<T> {
67    fn drop(&mut self) {
68        self.drop_all();
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena() {
        test_arena_alloc();
        test_arena_drop();
        test_arena_clear();
    }

    /// Allocated values stay intact and are all counted by `len`.
    fn test_arena_alloc() {
        const START: u32 = 0;
        const END: u32 = 100;
        const EXPECTED: u32 = (END + START) * (END - START + 1) / 2;

        let arena = TypedArena::default();
        let mut refs = Vec::new();
        for i in START..=END {
            let ref_ = arena.alloc(i);
            refs.push(ref_);
        }
        assert_eq!(arena.len(), (END - START + 1) as usize);
        let acc = refs.into_iter().map(|ref_| *ref_).sum::<u32>();
        assert_eq!(acc, EXPECTED);
    }

    /// Dropping the arena drops every element exactly once, across sizes and alignments.
    fn test_arena_drop() {
        macro_rules! test {
            ($arr_len:literal, $align:literal) => {{
                thread_local! {
                    static SUM: Cell<u32> = Cell::new(0);
                    static CNT: Cell<u32> = Cell::new(0);
                }

                #[repr(align($align))]
                struct A([u8; $arr_len]);

                // Restricted by `u8` and `A::new()`.
                const _: () = const { assert!($arr_len < 256) };

                impl A {
                    fn new() -> Self {
                        Self(std::array::from_fn(|i| i as u8))
                    }

                    fn sum() -> u32 {
                        ($arr_len - 1) * $arr_len / 2
                    }
                }

                impl Drop for A {
                    fn drop(&mut self) {
                        let sum = self.0.iter().map(|n| *n as u32).sum::<u32>();
                        SUM.set(SUM.get() + sum);
                        CNT.set(CNT.get() + 1);
                    }
                }

                struct Zst;

                impl Drop for Zst {
                    fn drop(&mut self) {
                        CNT.set(CNT.get() + 1);
                    }
                }

                const REPEAT: u32 = 10;

                // === Non-ZST type ===

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(A::new());
                }
                drop(arena);

                assert_eq!(SUM.get(), A::sum() * REPEAT);
                assert_eq!(CNT.get(), REPEAT);
                SUM.set(0);
                CNT.set(0);

                // === ZST type ===

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(Zst);
                }
                drop(arena);

                assert_eq!(CNT.get(), REPEAT);
            }};
        }

        // Array len, align
        test!(1, 1);
        test!(1, 2);
        test!(1, 4);
        test!(1, 8);
        test!(1, 16);
        test!(1, 32);
        test!(1, 64);
        test!(1, 128);
        test!(1, 256);

        test!(100, 1);
        test!(100, 2);
        test!(100, 4);
        test!(100, 8);
        test!(100, 16);
        test!(100, 32);
        test!(100, 64);
        test!(100, 128);
        test!(100, 256);
    }

    /// `clear` drops every element once, resets `len`, and leaves the arena reusable.
    fn test_arena_clear() {
        thread_local! {
            static DROPPED: Cell<u32> = Cell::new(0);
        }

        struct Counted(u32);

        impl Drop for Counted {
            fn drop(&mut self) {
                DROPPED.set(DROPPED.get() + self.0);
            }
        }

        struct Zst;

        impl Drop for Zst {
            fn drop(&mut self) {
                DROPPED.set(DROPPED.get() + 1);
            }
        }

        // === Non-ZST type ===

        let mut arena = TypedArena::<Counted>::default();
        assert!(arena.is_empty());
        for _ in 0..10 {
            arena.alloc(Counted(1));
        }
        assert_eq!(arena.len(), 10);
        assert!(!arena.is_empty());

        arena.clear();
        assert_eq!(DROPPED.get(), 10);
        assert_eq!(arena.len(), 0);
        assert!(arena.is_empty());

        // The arena is reusable, and values dropped by `clear` are not dropped again.
        arena.alloc(Counted(1));
        assert_eq!(arena.len(), 1);
        drop(arena);
        assert_eq!(DROPPED.get(), 11);
        DROPPED.set(0);

        // === ZST type ===

        let mut arena = TypedArena::<Zst>::default();
        for _ in 0..10 {
            arena.alloc(Zst);
        }
        assert_eq!(arena.len(), 10);

        arena.clear();
        assert_eq!(DROPPED.get(), 10);
        assert!(arena.is_empty());

        drop(arena);
        assert_eq!(DROPPED.get(), 10);
    }
}