1#![cfg_attr(feature = "allocator_api", feature(allocator_api))]
2
3extern crate alloc;
46
47use std::{
48 cell::UnsafeCell,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53};
54
55use thread_local::ThreadLocal;
56
57mod error;
58pub use error::ResetError;
59
60#[cfg(any(feature = "allocator_api", feature = "allocator-api2"))]
61mod alloc_api;
62
63#[cfg(any(feature = "allocator_api", feature = "allocator-api2"))]
64pub use alloc_api::Allocator;
65
/// Per-thread liveness marker.
///
/// Holds a shared flag that starts out `true` and is flipped to `false` when
/// the guard is dropped, letting holders of a clone of the flag find out that
/// the guard has gone away.
struct ThreadGuard {
    /// Shared flag: `true` while this guard exists, `false` once it is dropped.
    alive: Arc<AtomicBool>,
}
69
70impl ThreadGuard {
71 fn new() -> Self {
72 Self {
73 alive: Arc::new(AtomicBool::new(true)),
74 }
75 }
76}
77
78impl Drop for ThreadGuard {
79 fn drop(&mut self) {
80 self.alive.store(false, Ordering::Release);
81 }
82}
83
84thread_local! {
85 static THREAD_GUARD: ThreadGuard = ThreadGuard::new();
86}
87
88#[derive(Default, Clone)]
92pub struct Bump {
93 inner: Arc<BumpInner>,
94}
95
96impl Bump {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn builder() -> BumpBuilder {
115 BumpBuilder::new()
116 }
117
118 #[inline]
126 pub fn local(&self) -> &BumpLocal {
127 self.inner.local()
128 }
129
130 #[inline]
139 pub fn reset_all(&mut self) -> Result<(), ResetError> {
140 match Arc::get_mut(&mut self.inner) {
141 Some(inner) => {
142 inner.reset_all();
143 Ok(())
144 }
145 None => Err(ResetError),
146 }
147 }
148}
149
/// Configuration collected before constructing a [`Bump`].
///
/// The derived `Default` leaves both optional settings unset and the arena
/// capacity at `0`.
#[derive(Default)]
pub struct BumpBuilder {
    /// Capacity passed to `ThreadLocal::with_capacity` when set.
    threads_capacity: Option<usize>,
    /// Allocation limit applied to every per-thread arena (`None` = unset).
    bump_alloc_limit: Option<usize>,
    /// Initial capacity requested for every per-thread arena.
    bump_capacity: usize,
}
157
158impl BumpBuilder {
159 pub fn new() -> Self {
161 Self::default()
162 }
163
164 pub fn threads_capacity(mut self, capacity: usize) -> Self {
169 self.threads_capacity = Some(capacity);
170 self
171 }
172
173 pub fn bump_allocation_limit(mut self, limit: usize) -> Self {
177 self.bump_alloc_limit = Some(limit);
178 self
179 }
180
181 pub fn bump_capacity(mut self, capacity: usize) -> Self {
186 self.bump_capacity = capacity;
187 self
188 }
189
190 pub fn build(self) -> Bump {
192 Bump {
193 inner: Arc::new(BumpInner {
194 locals: match self.threads_capacity {
195 Some(cap) => ThreadLocal::with_capacity(cap),
196 None => ThreadLocal::new(),
197 },
198 capacity: self.bump_capacity,
199 alloc_limit: self.bump_alloc_limit,
200 }),
201 }
202 }
203}
204
205pub struct BumpLocal {
207 inner: UnsafeCell<Option<BumpLocalInner>>,
208}
209
210impl BumpLocal {
211 fn new(capacity: usize, limit: Option<usize>, thread_alive: Arc<AtomicBool>) -> Self {
212 let bump = bumpalo::Bump::with_capacity(capacity);
213 bump.set_allocation_limit(limit);
214
215 Self {
216 inner: UnsafeCell::new(Some(BumpLocalInner {
217 inner: bump,
218 thread_alive,
219 })),
220 }
221 }
222
223 #[inline]
227 pub fn as_inner(&self) -> &bumpalo::Bump {
228 unsafe { &(*self.inner.get()).as_ref().unwrap().inner }
234 }
235
236 #[inline]
244 pub fn reset(&self) {
245 unsafe {
247 (*self.inner.get()).as_mut().unwrap().inner.reset();
248 }
249 }
250
251 #[inline]
252 fn needs_init(&self) -> bool {
253 unsafe { (*self.inner.get()).is_none() }
255 }
256
257 #[cold]
258 fn init(&self, capacity: usize, limit: Option<usize>, thread_alive: Arc<AtomicBool>) {
259 let bump = bumpalo::Bump::with_capacity(capacity);
260 bump.set_allocation_limit(limit);
261
262 unsafe {
264 *self.inner.get() = Some(BumpLocalInner {
265 inner: bump,
266 thread_alive,
267 })
268 }
269 }
270
271 #[cold]
272 fn clear(&mut self) {
273 #[cold]
274 fn drop_inner(bump: &mut BumpLocal) {
275 unsafe {
277 let _ = (*bump.inner.get()).take();
278 }
279 }
280
281 let inner = unsafe { &*self.inner.get() };
283 let Some(inner) = inner.as_ref() else {
284 return;
285 };
286
287 if inner.thread_alive.load(Ordering::Acquire) {
288 self.reset();
289 } else {
290 drop_inner(self);
291 }
292 }
293}
294
295struct BumpLocalInner {
296 inner: bumpalo::Bump,
297 thread_alive: Arc<AtomicBool>,
298}
299
300#[derive(Default)]
302struct BumpInner {
303 locals: ThreadLocal<BumpLocal>,
304 capacity: usize,
305 alloc_limit: Option<usize>,
306}
307
308impl BumpInner {
309 #[inline]
310 fn local(&self) -> &BumpLocal {
311 let bump = self.locals.get_or(|| {
312 let thread_alive = THREAD_GUARD.with(|guard| guard.alive.clone());
313 BumpLocal::new(self.capacity, self.alloc_limit, thread_alive)
314 });
315
316 if bump.needs_init() {
317 self.reinit_local(bump);
318 }
319
320 bump
321 }
322
323 #[cold]
324 fn reinit_local(&self, bump: &BumpLocal) {
325 let thread_alive = THREAD_GUARD.with(|guard| guard.alive.clone());
326 bump.init(self.capacity, self.alloc_limit, thread_alive);
327 }
328
329 #[inline]
330 fn reset_all(&mut self) {
331 for local in self.locals.iter_mut() {
332 local.clear();
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// After a thread has been joined, the flag cloned from its
    /// `THREAD_GUARD` must read `false`.
    #[test]
    fn thread_guard_sets_alive_false_on_drop() {
        let handle = thread::spawn(move || THREAD_GUARD.with(|g| g.alive.clone()));

        let alive = handle.join().unwrap();
        assert!(!alive.load(Ordering::Acquire));
    }

    /// `reset_all` must reset (not drop) the arena of a thread that is still
    /// running.
    #[test]
    fn reset_resets_alive_thread() {
        let mut bump = Bump::builder().bump_capacity(100).build();

        let (capacity_tx, capacity_rx) = std::sync::mpsc::channel();
        // The worker is kept alive by blocking on this channel.
        // `thread::park` is not suitable: it may return spuriously, which
        // would let the worker exit early and make the test flaky.
        let (release_tx, release_rx) = std::sync::mpsc::channel::<()>();

        let handle = {
            let bump = bump.clone();
            thread::spawn(move || {
                let _ = bump.local().as_inner().alloc(1_u8);
                let capacity_before = bump.local().as_inner().chunk_capacity();
                // Drop the clone so the main thread holds the only handle and
                // `reset_all` can obtain exclusive access.
                drop(bump);

                capacity_tx.send(capacity_before).unwrap();

                // Blocks until the main thread sends the release signal or
                // drops the sender (e.g. because an assertion failed).
                let _ = release_rx.recv();
            })
        };

        let capacity_before = capacity_rx.recv().unwrap();

        bump.reset_all().unwrap();

        let inner = Arc::get_mut(&mut bump.inner).unwrap();
        let locals: Vec<_> = inner.locals.iter_mut().collect();
        assert_eq!(locals.len(), 1);
        let local = locals.first().unwrap();
        assert!(!local.needs_init());
        // One byte was allocated before the reset, so the remaining chunk
        // capacity must have grown again afterwards.
        assert!(local.as_inner().chunk_capacity() > capacity_before);

        release_tx.send(()).unwrap();
        handle.join().unwrap();
    }

    /// `reset_all` must drop the arena of a thread that has already exited,
    /// leaving its slot empty.
    #[test]
    fn reset_drops_dead_thread_bump() {
        let mut bump = Bump::builder().bump_capacity(100).build();

        let handle = {
            let bump = bump.clone();
            thread::spawn(move || {
                let _ = bump.local().as_inner().alloc(1_u8);
            })
        };

        handle.join().unwrap();

        bump.reset_all().unwrap();

        let inner = Arc::get_mut(&mut bump.inner).unwrap();
        let locals: Vec<_> = inner.locals.iter_mut().collect();
        assert_eq!(locals.len(), 1);
        let local = locals.first().unwrap();
        assert!(local.needs_init());
    }
}