1use bumpalo::Bump;
2use std::{
3 alloc::Layout,
4 cell::Cell,
5 marker::PhantomData,
6 mem,
7 ptr::{self, NonNull},
8 slice,
9};
10
11pub struct TypedArena<T> {
12 bump: Bump,
13 len: Cell<usize>,
14 _marker: PhantomData<T>,
15}
16
17impl<T> TypedArena<T> {
18 pub fn len(&self) -> usize {
20 self.len.get()
21 }
22
23 pub fn is_empty(&self) -> bool {
24 self.len() == 0
25 }
26
27 pub fn alloc(&self, value: T) -> &mut T {
28 self.len.set(self.len() + 1);
29 self.bump.alloc(value)
30 }
31
32 pub fn clear(&mut self) {
33 self.drop_all();
34 self.bump.reset();
35 self.len.set(0);
36 }
37
38 fn drop_all(&mut self) {
39 if mem::needs_drop::<T>() {
40 if mem::size_of::<T>() > 0 {
41 let stride = Layout::new::<T>().pad_to_align().size();
42 unsafe {
43 for (ptr, len) in self.bump.iter_allocated_chunks_raw() {
44 let num_elems = len / stride;
47 let ptr = ptr.cast::<T>();
48 let slice = slice::from_raw_parts_mut(ptr, num_elems);
49 ptr::drop_in_place(slice);
50 }
51 }
52 } else {
53 let ptr = NonNull::<T>::dangling().as_ptr();
54 unsafe {
55 let slice = slice::from_raw_parts_mut(ptr, self.len());
56 ptr::drop_in_place(slice);
57 }
58 }
59 }
60 }
61}
62
63impl<T> Default for TypedArena<T> {
64 fn default() -> Self {
65 Self {
66 bump: Bump::new(),
67 len: Cell::new(0),
68 _marker: PhantomData,
69 }
70 }
71}
72
73impl<T> Drop for TypedArena<T> {
74 fn drop(&mut self) {
75 self.drop_all();
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocated values are readable through the returned references and
    /// `len`/`is_empty` track the number of allocations.
    #[test]
    fn test_arena_alloc() {
        const START: u32 = 0;
        const END: u32 = 100;
        const EXPECTED: u32 = (END + START) * (END - START + 1) / 2;

        let arena: TypedArena<u32> = TypedArena::default();
        assert!(arena.is_empty());

        let mut refs = Vec::new();
        for i in START..=END {
            refs.push(arena.alloc(i));
        }
        assert_eq!(arena.len(), (END - START + 1) as usize);
        assert!(!arena.is_empty());

        let acc = refs.into_iter().map(|r| *r).sum::<u32>();
        assert_eq!(acc, EXPECTED);
    }

    /// Every stored value is dropped exactly once, for a range of sizes and
    /// alignments.
    #[test]
    fn test_arena_drop() {
        macro_rules! test {
            ($arr_len:literal, $align:literal) => {{
                thread_local! {
                    static SUM: Cell<u32> = const { Cell::new(0) };
                    static CNT: Cell<u32> = const { Cell::new(0) };
                }

                #[repr(align($align))]
                struct A([u8; $arr_len]);

                // `new` stores the bytes 0..$arr_len, which must all fit in a `u8`.
                const _: () = assert!($arr_len <= 256);

                impl A {
                    fn new() -> Self {
                        Self(std::array::from_fn(|i| i as u8))
                    }

                    /// Sum of the bytes stored by `new`: 0 + 1 + ... + ($arr_len - 1).
                    fn sum() -> u32 {
                        ($arr_len - 1) * $arr_len / 2
                    }
                }

                impl Drop for A {
                    fn drop(&mut self) {
                        let sum = self.0.iter().map(|&n| u32::from(n)).sum::<u32>();
                        SUM.set(SUM.get() + sum);
                        CNT.set(CNT.get() + 1);
                    }
                }

                // Large enough that, for the bigger layouts, the values likely
                // spread over more than one bump chunk.
                const REPEAT: u32 = 1000;

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(A::new());
                }
                assert_eq!(arena.len(), REPEAT as usize);
                drop(arena);

                assert_eq!(SUM.get(), A::sum() * REPEAT);
                assert_eq!(CNT.get(), REPEAT);
            }};
        }

        test!(1, 1);
        test!(1, 2);
        test!(1, 4);
        test!(1, 8);
        test!(1, 16);
        test!(1, 32);
        test!(1, 64);
        test!(1, 128);
        test!(1, 256);

        test!(100, 1);
        test!(100, 2);
        test!(100, 4);
        test!(100, 8);
        test!(100, 16);
        test!(100, 32);
        test!(100, 64);
        test!(100, 128);
        test!(100, 256);
    }

    /// Zero-sized values are counted via `len` and dropped exactly once, both
    /// by `clear` and by dropping the arena.
    #[test]
    fn test_arena_drop_zst() {
        thread_local! {
            static CNT: Cell<u32> = const { Cell::new(0) };
        }

        struct Zst;

        impl Drop for Zst {
            fn drop(&mut self) {
                CNT.set(CNT.get() + 1);
            }
        }

        const REPEAT: u32 = 10;

        let mut arena = TypedArena::default();
        for _ in 0..REPEAT {
            arena.alloc(Zst);
        }
        assert_eq!(arena.len(), REPEAT as usize);

        arena.clear();
        assert_eq!(CNT.get(), REPEAT);
        assert!(arena.is_empty());

        for _ in 0..REPEAT {
            arena.alloc(Zst);
        }
        drop(arena);
        assert_eq!(CNT.get(), 2 * REPEAT);
    }

    /// `clear` drops every value once, empties the arena, and leaves it
    /// usable; values dropped by `clear` are not dropped again later.
    #[test]
    fn test_arena_clear() {
        thread_local! {
            static DROPPED: Cell<u32> = const { Cell::new(0) };
        }

        struct Counted(u32);

        impl Drop for Counted {
            fn drop(&mut self) {
                DROPPED.set(DROPPED.get() + self.0);
            }
        }

        let mut arena = TypedArena::default();
        for i in 1..=100 {
            arena.alloc(Counted(i));
        }
        arena.clear();
        assert_eq!(DROPPED.get(), 5050);
        assert!(arena.is_empty());

        arena.alloc(Counted(7));
        assert_eq!(arena.len(), 1);
        drop(arena);
        assert_eq!(DROPPED.get(), 5050 + 7);
    }
}