Skip to main content

any_intern/
typed.rs

1use bumpalo::Bump;
2use std::{
3    alloc::Layout,
4    cell::Cell,
5    marker::PhantomData,
6    mem,
7    ptr::{self, NonNull},
8    slice,
9};
10
11pub struct TypedArena<T> {
12    bump: Bump,
13    len: Cell<usize>,
14    _marker: PhantomData<T>,
15}
16
17impl<T> TypedArena<T> {
18    /// Returns number of elements in this arena.
19    pub fn len(&self) -> usize {
20        self.len.get()
21    }
22
23    pub fn is_empty(&self) -> bool {
24        self.len() == 0
25    }
26
27    pub fn alloc(&self, value: T) -> &mut T {
28        self.len.set(self.len() + 1);
29        self.bump.alloc(value)
30    }
31
32    pub fn clear(&mut self) {
33        self.drop_all();
34        self.bump.reset();
35        self.len.set(0);
36    }
37
38    fn drop_all(&mut self) {
39        if mem::needs_drop::<T>() {
40            if mem::size_of::<T>() > 0 {
41                let stride = Layout::new::<T>().pad_to_align().size();
42                unsafe {
43                    for (ptr, len) in self.bump.iter_allocated_chunks_raw() {
44                        // Chunk would not be divisible by the `stride` especially when the stride
45                        // is greater than 16. In that case, we should ignore the remainder.
46                        let num_elems = len / stride;
47                        let ptr = ptr.cast::<T>();
48                        let slice = slice::from_raw_parts_mut(ptr, num_elems);
49                        ptr::drop_in_place(slice);
50                    }
51                }
52            } else {
53                let ptr = NonNull::<T>::dangling().as_ptr();
54                unsafe {
55                    let slice = slice::from_raw_parts_mut(ptr, self.len());
56                    ptr::drop_in_place(slice);
57                }
58            }
59        }
60    }
61}
62
63impl<T> Default for TypedArena<T> {
64    fn default() -> Self {
65        Self {
66            bump: Bump::new(),
67            len: Cell::new(0),
68            _marker: PhantomData,
69        }
70    }
71}
72
73impl<T> Drop for TypedArena<T> {
74    fn drop(&mut self) {
75        self.drop_all();
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocated values are readable through the returned references and counted by `len`.
    #[test]
    fn test_arena_alloc() {
        const START: u32 = 0;
        const END: u32 = 100;
        const EXPECTED: u32 = (END + START) * (END - START + 1) / 2;

        let arena = TypedArena::default();
        assert!(arena.is_empty());

        let mut refs = Vec::new();
        for i in START..=END {
            let ref_ = arena.alloc(i);
            refs.push(ref_);
        }
        assert_eq!(arena.len(), (END - START + 1) as usize);

        let acc = refs.into_iter().map(|ref_| *ref_).sum::<u32>();
        assert_eq!(acc, EXPECTED);
    }

    /// Dropping the arena runs every element's destructor exactly once, for a range of
    /// sizes and alignments and for zero-sized types.
    #[test]
    fn test_arena_drop() {
        macro_rules! test {
            ($arr_len:literal, $align:literal) => {{
                thread_local! {
                    static SUM: Cell<u32> = Cell::new(0);
                    static CNT: Cell<u32> = Cell::new(0);
                }

                #[repr(align($align))]
                struct A([u8; $arr_len]);

                // Restricted by `u8` and `A::new()`.
                const _: () = const { assert!($arr_len < 256) };

                impl A {
                    fn new() -> Self {
                        Self(std::array::from_fn(|i| i as u8))
                    }

                    fn sum() -> u32 {
                        ($arr_len - 1) * $arr_len / 2
                    }
                }

                impl Drop for A {
                    fn drop(&mut self) {
                        let sum = self.0.iter().map(|n| *n as u32).sum::<u32>();
                        SUM.set(SUM.get() + sum);
                        CNT.set(CNT.get() + 1);
                    }
                }

                struct Zst;

                impl Drop for Zst {
                    fn drop(&mut self) {
                        CNT.set(CNT.get() + 1);
                    }
                }

                const REPEAT: u32 = 10;

                // === Non-ZST type ===

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(A::new());
                }
                drop(arena);

                assert_eq!(SUM.get(), A::sum() * REPEAT);
                assert_eq!(CNT.get(), REPEAT);
                SUM.set(0);
                CNT.set(0);

                // === ZST type ===

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(Zst);
                }
                drop(arena);

                assert_eq!(CNT.get(), REPEAT);
            }};
        }

        // Array len, align
        test!(1, 1);
        test!(1, 2);
        test!(1, 4);
        test!(1, 8);
        test!(1, 16);
        test!(1, 32);
        test!(1, 64);
        test!(1, 128);
        test!(1, 256);

        test!(100, 1);
        test!(100, 2);
        test!(100, 4);
        test!(100, 8);
        test!(100, 16);
        test!(100, 32);
        test!(100, 64);
        test!(100, 128);
        test!(100, 256);
    }

    /// `clear` drops the stored elements once, empties the arena, and leaves it reusable.
    #[test]
    fn test_arena_clear() {
        thread_local! {
            static DROPPED: Cell<u32> = Cell::new(0);
            static ZST_DROPPED: Cell<u32> = Cell::new(0);
        }

        /// Adds its payload to `DROPPED` when dropped.
        struct Counted(u32);

        impl Drop for Counted {
            fn drop(&mut self) {
                DROPPED.set(DROPPED.get() + self.0);
            }
        }

        struct Zst;

        impl Drop for Zst {
            fn drop(&mut self) {
                ZST_DROPPED.set(ZST_DROPPED.get() + 1);
            }
        }

        // === Non-ZST type ===

        let mut arena = TypedArena::default();
        assert!(arena.is_empty());
        for _ in 0..5 {
            arena.alloc(Counted(1));
        }
        assert_eq!(arena.len(), 5);

        arena.clear();
        assert!(arena.is_empty());
        assert_eq!(DROPPED.get(), 5);

        // Only the elements allocated after `clear` are dropped with the arena.
        for _ in 0..3 {
            arena.alloc(Counted(1));
        }
        assert_eq!(arena.len(), 3);
        drop(arena);
        assert_eq!(DROPPED.get(), 8);

        // === ZST type ===

        let mut arena = TypedArena::default();
        for _ in 0..4 {
            arena.alloc(Zst);
        }
        assert_eq!(arena.len(), 4);

        arena.clear();
        assert!(arena.is_empty());
        assert_eq!(ZST_DROPPED.get(), 4);

        arena.alloc(Zst);
        drop(arena);
        assert_eq!(ZST_DROPPED.get(), 5);
    }
}