1use bumpalo::Bump;
2use std::{alloc::Layout, cell::Cell, marker::PhantomData, mem, ptr, slice};
3
4pub struct TypedArena<T> {
5 bump: Bump,
6 len: Cell<usize>,
7 _marker: PhantomData<T>,
8}
9
10impl<T> TypedArena<T> {
11 pub const fn len(&self) -> usize {
13 self.len.get()
14 }
15
16 pub const fn is_empty(&self) -> bool {
17 self.len() == 0
18 }
19
20 pub fn alloc(&self, value: T) -> &mut T {
21 self.len.set(self.len() + 1);
22 self.bump.alloc(value)
23 }
24
25 pub fn clear(&mut self) {
26 self.drop_all();
27 self.bump.reset();
28 self.len.set(0);
29 }
30
31 fn drop_all(&mut self) {
32 if mem::needs_drop::<T>() {
33 if size_of::<T>() > 0 {
34 let stride = Layout::new::<T>().pad_to_align().size();
35 for chunk in self.bump.iter_allocated_chunks() {
36 let num_elems = chunk.len() / stride;
39 let ptr = chunk.as_ptr().cast::<T>().cast_mut();
40 unsafe {
41 let slice = slice::from_raw_parts_mut(ptr, num_elems);
42 ptr::drop_in_place(slice);
43 }
44 }
45 } else {
46 let ptr = ptr::dangling_mut::<T>();
47 unsafe {
48 let slice = slice::from_raw_parts_mut(ptr, self.len());
49 ptr::drop_in_place(slice);
50 }
51 }
52 }
53 }
54}
55
56impl<T> Default for TypedArena<T> {
57 fn default() -> Self {
58 Self {
59 bump: Bump::new(),
60 len: Cell::new(0),
61 _marker: PhantomData,
62 }
63 }
64}
65
66impl<T> Drop for TypedArena<T> {
67 fn drop(&mut self) {
68 self.drop_all();
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arena() {
        test_arena_alloc();
        test_arena_drop();
    }

    /// Allocates a run of integers, then checks the length bookkeeping and
    /// that every returned reference still reads back its value.
    fn test_arena_alloc() {
        const START: u32 = 0;
        const END: u32 = 100;
        const EXPECTED: u32 = (END + START) * (END - START + 1) / 2;

        let arena = TypedArena::default();
        assert!(arena.is_empty());

        let mut refs = Vec::new();
        for i in START..=END {
            let ref_ = arena.alloc(i);
            refs.push(ref_);
        }
        assert_eq!(arena.len(), (END - START + 1) as usize);
        assert!(!arena.is_empty());

        let acc = refs.into_iter().map(|ref_| *ref_).sum::<u32>();
        assert_eq!(acc, EXPECTED);
    }

    /// Checks that dropping the arena runs each stored value's destructor
    /// exactly once, across many alignments and for zero-sized values.
    fn test_arena_drop() {
        macro_rules! test {
            ($arr_len:literal, $align:literal) => {{
                thread_local! {
                    static SUM: Cell<u32> = Cell::new(0);
                    static CNT: Cell<u32> = Cell::new(0);
                }

                #[repr(align($align))]
                struct A([u8; $arr_len]);

                // `A::new` stores indices as `u8`, so they must not wrap.
                const _: () = assert!($arr_len < 256);

                impl A {
                    fn new() -> Self {
                        Self(std::array::from_fn(|i| i as u8))
                    }

                    fn sum() -> u32 {
                        ($arr_len - 1) * $arr_len / 2
                    }
                }

                impl Drop for A {
                    fn drop(&mut self) {
                        let sum = self.0.iter().map(|n| *n as u32).sum::<u32>();
                        SUM.set(SUM.get() + sum);
                        CNT.set(CNT.get() + 1);
                    }
                }

                struct Zst;

                impl Drop for Zst {
                    fn drop(&mut self) {
                        CNT.set(CNT.get() + 1);
                    }
                }

                const REPEAT: u32 = 10;

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(A::new());
                }
                drop(arena);

                assert_eq!(SUM.get(), A::sum() * REPEAT);
                assert_eq!(CNT.get(), REPEAT);
                SUM.set(0);
                CNT.set(0);

                let arena = TypedArena::default();
                for _ in 0..REPEAT {
                    arena.alloc(Zst);
                }
                drop(arena);

                assert_eq!(CNT.get(), REPEAT);
            }};
        }

        test!(1, 1);
        test!(1, 2);
        test!(1, 4);
        test!(1, 8);
        test!(1, 16);
        test!(1, 32);
        test!(1, 64);
        test!(1, 128);
        test!(1, 256);

        test!(100, 1);
        test!(100, 2);
        test!(100, 4);
        test!(100, 8);
        test!(100, 16);
        test!(100, 32);
        test!(100, 64);
        test!(100, 128);
        test!(100, 256);
    }

    /// `clear` must drop every value exactly once, reset the length, and leave
    /// the arena usable; values allocated afterwards are dropped with the arena.
    #[test]
    fn test_arena_clear() {
        thread_local! {
            static SUM: Cell<u64> = Cell::new(0);
            static CNT: Cell<u64> = Cell::new(0);
        }

        struct Tracked(u64);

        impl Drop for Tracked {
            fn drop(&mut self) {
                SUM.set(SUM.get() + self.0);
                CNT.set(CNT.get() + 1);
            }
        }

        struct TrackedZst;

        impl Drop for TrackedZst {
            fn drop(&mut self) {
                CNT.set(CNT.get() + 1);
            }
        }

        let mut arena = TypedArena::default();
        for i in 0..10 {
            arena.alloc(Tracked(i));
        }
        assert_eq!(arena.len(), 10);

        arena.clear();
        assert!(arena.is_empty());
        assert_eq!(CNT.get(), 10);
        assert_eq!(SUM.get(), 45);

        for i in 0..5 {
            arena.alloc(Tracked(i));
        }
        assert_eq!(arena.len(), 5);
        drop(arena);
        assert_eq!(CNT.get(), 15);
        assert_eq!(SUM.get(), 55);

        // Zero-sized values are tracked purely by the count.
        let mut zsts = TypedArena::default();
        for _ in 0..3 {
            zsts.alloc(TrackedZst);
        }
        zsts.clear();
        assert!(zsts.is_empty());
        assert_eq!(CNT.get(), 18);

        for _ in 0..2 {
            zsts.alloc(TrackedZst);
        }
        drop(zsts);
        assert_eq!(CNT.get(), 20);
    }
}