cubecl_common/arena.rs
1use std::{
2 cell::UnsafeCell,
3 marker::PhantomData,
4 sync::{
5 Arc,
6 atomic::{AtomicU32, Ordering},
7 },
8 vec::Vec,
9};
10
/// The raw storage for an item, potentially uninitialized.
///
/// Aligned to 64 bytes (a typical cache-line size) so that two adjacent
/// slots never share a cache line, which prevents false sharing when
/// different threads access neighboring slots concurrently.
///
/// NOTE: because of the 64-byte alignment, `size_of::<Bytes<N>>()` is `N`
/// rounded up to a multiple of 64 — objects are admitted by size against
/// that rounded-up figure (see `accept_obj`).
#[repr(C, align(64))]
pub struct Bytes<const MAX_ITEM_SIZE: usize> {
    // Uninterpreted byte storage. The type of the object written here is
    // erased; its destructor is tracked externally by the `drop_fn` field of
    // `ReservedMemory`.
    bytes: [u8; MAX_ITEM_SIZE],
}
19
/// A circular, allocation-free arena for reusable memory blocks.
///
/// `Arena` manages a fixed-capacity pool of [`Bytes`] buffers, each up to
/// `MAX_ITEM_SIZE` bytes. After the pool is lazily initialized, subsequent
/// allocations scan from an internal cursor to find a free slot, avoiding
/// further heap allocation.
///
/// # Const Parameters
///
/// - `MAX_ITEM_COUNT` — maximum number of buffers in the pool.
/// - `MAX_ITEM_SIZE` — capacity of each individual buffer in bytes.
///
/// # How It Works
///
/// The arena maintains a vector of reference-counted buffer slots. When a
/// caller requests memory, the arena advances its cursor through the pool
/// looking for a slot whose backing `Arc` strong count is 1 (i.e. the arena
/// itself is the only remaining owner — no outstanding handle references
/// it), then hands back a handle to that slot. The cursor wraps around,
/// giving the allocation pattern its circular behavior.
///
/// Because a handle can outlive the `Arena` itself (e.g. when the owning
/// thread exits but the handle was sent elsewhere), each slot is wrapped in
/// an `Arc` to keep the underlying storage alive. A separate `AtomicU32`
/// reference count tracks logical ownership of the *stored object*
/// independently of the `Arc` strong count, so the last handle knows when to
/// run the object's destructor.
///
/// # Use Case
///
/// This is useful as a replacement for repeated `Arc<dyn Trait>` allocations.
pub struct Arena<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> {
    /// Backing storage for each slot. Wrapped in `Arc` so that a [`Bytes`]
    /// handle remains valid even after the `Arena` (and its owning thread)
    /// is dropped. Empty until the first `reserve` call.
    buffer: Vec<Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>>,
    /// Logical reference counts, one per slot. Tracked separately from the
    /// `Arc` strong count because the arena may be dropped while outstanding
    /// handles still exist — the `Arc` keeps memory alive, but this counter
    /// lets the last handle detect when the stored object's destructor must
    /// run.
    ref_counts: Vec<Arc<AtomicU32>>,
    /// Current scan position in the circular pool. Advanced on each
    /// successful allocation and wraps at `MAX_ITEM_COUNT`.
    cursor: usize,
}
64
/// An initialized, immutable handle to a slot in the arena.
///
/// This type is `Send + Sync` and can be cheaply cloned. Each clone
/// increments a logical reference count; when the last clone is dropped,
/// the stored object's destructor runs and the slot becomes available for
/// reuse by the arena.
pub struct ReservedMemory<const MAX_ITEM_SIZE: usize> {
    /// Shared storage slot. The `Arc` keeps the bytes alive even if the
    /// arena (and its owning thread) is gone.
    data: Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>,
    /// Logical count of live handles to this slot, shared with the arena's
    /// `ref_counts` entry.
    ref_count: Arc<AtomicU32>,
    /// Type-erased destructor for the stored object, installed by
    /// `UninitReservedMemory::init`.
    drop_fn: fn(&mut Bytes<MAX_ITEM_SIZE>),
}
76
/// An uninitialized handle to a reserved arena slot.
///
/// Obtained from [`Arena::reserve`]. Must be initialized via [`init`](Self::init)
/// to produce a usable [`ReservedMemory`]. Dropping it without calling
/// `init` simply releases the slot back to the arena — no destructor runs,
/// since nothing was stored.
///
/// This type is intentionally `!Send` and `!Sync` — it must be initialized on
/// the same thread that reserved it.
pub struct UninitReservedMemory<const MAX_ITEM_SIZE: usize> {
    /// The reserved (still uninitialized) storage slot.
    data: Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>,
    /// Logical reference count for the slot, pre-set to 1 by `reserve`.
    ref_count: Arc<AtomicU32>,
    /// Used to assert the position in the arena.
    #[cfg(test)]
    index: usize,
    // `*const ()` is neither `Send` nor `Sync`, which makes this type
    // `!Send`/`!Sync` without any runtime cost.
    not_sync: PhantomData<*const ()>,
}
93
94impl<const MAX_ITEM_SIZE: usize> UninitReservedMemory<MAX_ITEM_SIZE> {
95 /// Initialize the reserved memory.
96 ///
97 /// # Panics
98 ///
99 /// If the given object isn't safe to store in this arena.
100 pub fn init<O>(self, obj: O) -> ReservedMemory<MAX_ITEM_SIZE> {
101 assert!(
102 accept_obj::<O, MAX_ITEM_SIZE>(),
103 "Object isn't safe to store in this arena"
104 );
105
106 self.init_with_func(
107 |bytes| {
108 let ptr = core::ptr::from_mut(bytes);
109 unsafe {
110 core::ptr::write(ptr as *mut O, obj);
111 };
112 },
113 |bytes| {
114 let ptr = core::ptr::from_mut(bytes);
115 unsafe {
116 core::ptr::drop_in_place(ptr as *mut O);
117 }
118 },
119 )
120 }
121
122 /// Writes to the reserved slot using `init_data` and attaches `drop_fn`
123 /// as the destructor to run when the last [`ReservedMemory`] clone is dropped.
124 fn init_with_func<F>(
125 self,
126 init_data: F,
127 drop_fn: fn(&mut Bytes<MAX_ITEM_SIZE>),
128 ) -> ReservedMemory<MAX_ITEM_SIZE>
129 where
130 F: FnOnce(&mut Bytes<MAX_ITEM_SIZE>),
131 {
132 // SAFETY: We access the `UnsafeCell` contents mutably. This is sound
133 // because strong_count == 2 means only two owners exist: the arena's
134 // buffer slot and this `UninitReservedMemory`. The arena never reads
135 // through the `UnsafeCell` — only the holder of `UninitReservedMemory`
136 // writes, so there is no data race.
137 assert_eq!(
138 Arc::strong_count(&self.data),
139 2,
140 "Slot must be held by exactly two owners (the arena and this \
141 UninitReservedMemory) to guarantee exclusive write access."
142 );
143
144 let bytes_mut = unsafe { self.data.as_ref().get().as_mut().unwrap() };
145 init_data(bytes_mut);
146
147 ReservedMemory {
148 data: self.data,
149 ref_count: self.ref_count,
150 drop_fn,
151 }
152 }
153}
154
155impl<const MAX_ITEM_SIZE: usize> core::fmt::Debug for ReservedMemory<MAX_ITEM_SIZE> {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("ReservedMemory")
158 .field("data", &self.data)
159 .field("drop_fn", &self.drop_fn)
160 .finish()
161 }
162}
163
164impl<const MAX_ITEM_SIZE: usize> Clone for ReservedMemory<MAX_ITEM_SIZE> {
165 fn clone(&self) -> Self {
166 self.ref_count.fetch_add(1, Ordering::Release);
167
168 Self {
169 data: self.data.clone(),
170 ref_count: self.ref_count.clone(),
171 drop_fn: self.drop_fn,
172 }
173 }
174}
175
176impl<const MAX_ITEM_SIZE: usize> Drop for ReservedMemory<MAX_ITEM_SIZE> {
177 fn drop(&mut self) {
178 // Ref-count lifecycle:
179 // reserve() → stores 1 (arena holds the slot)
180 // init() → consumes UninitReservedMemory, inherits count 1
181 // clone() → fetch_add (count grows with each clone: 2, 3, …)
182 // drop() → fetch_sub (count shrinks)
183 //
184 // When `fetch_sub` returns 2 it means the count just moved from 2→1,
185 // and we are the last `ReservedMemory` clone — the remaining "1" is
186 // the arena's own ref-count baseline. At this point no other clone
187 // can access the data, so we can safely run the destructor.
188 //
189 // If the arena has already been dropped, the same logic applies: the
190 // last clone still sees previous == 2 because the arena never
191 // decrements the logical ref_count — only `ReservedMemory::drop` does.
192 let drop_fn = || {
193 // SAFETY: We are the last user of this slot. The data pointer is valid,
194 // initialized, and no other `ReservedMemory` clone exists.
195 let bytes_mut = unsafe { self.data.get().as_mut().unwrap() };
196 (self.drop_fn)(bytes_mut);
197 };
198
199 let previous = self.ref_count.fetch_sub(1, Ordering::Release);
200
201 if previous == 2 {
202 drop_fn();
203 }
204 }
205}
206
// SAFETY: After initialization, the data behind `ReservedMemory` is never
// mutated again until the last handle runs the destructor — no `&mut`
// access is reachable while any clone exists — so shared reads from several
// threads cannot race. The logical ref_count is an `AtomicU32` (so clone and
// drop bookkeeping is thread-safe), and the backing `Arc` guarantees the
// storage outlives all handles. Together these satisfy the `Send` and `Sync`
// contracts.
unsafe impl<const MAX_ITEM_SIZE: usize> Send for ReservedMemory<MAX_ITEM_SIZE> {}
unsafe impl<const MAX_ITEM_SIZE: usize> Sync for ReservedMemory<MAX_ITEM_SIZE> {}
214
215impl<const MAX_ITEM_SIZE: usize> std::convert::AsRef<Bytes<MAX_ITEM_SIZE>>
216 for ReservedMemory<MAX_ITEM_SIZE>
217{
218 /// Gets the reserved bytes.
219 fn as_ref(&self) -> &Bytes<MAX_ITEM_SIZE> {
220 // The pointer is valid and the data is readonly.
221 unsafe { self.data.as_ref().get().as_ref().unwrap() }
222 }
223}
224
impl<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> Default
    for Arena<MAX_ITEM_COUNT, MAX_ITEM_SIZE>
{
    /// Equivalent to [`Arena::new`]: an empty arena whose slot pool is
    /// allocated lazily on the first reservation.
    fn default() -> Self {
        Self::new()
    }
}
232
233impl<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> Arena<MAX_ITEM_COUNT, MAX_ITEM_SIZE> {
234 /// Creates a new, empty `Arena`.
235 ///
236 /// The internal buffer is not allocated until the first call to [`reserve`](Self::reserve).
237 pub const fn new() -> Self {
238 Self {
239 buffer: Vec::new(),
240 ref_counts: Vec::new(),
241 cursor: 0,
242 }
243 }
244
245 /// Returns `true` if an object of type `O` fits within a single slot.
246 ///
247 /// Checks that both the size and alignment of `O` are compatible with
248 /// [`Bytes<MAX_ITEM_SIZE>`].
249 pub const fn accept<O>() -> bool {
250 accept_obj::<O, MAX_ITEM_SIZE>()
251 }
252
253 /// Attempts to reserve an uninitialized slot in the arena.
254 ///
255 /// On the first call, the internal buffer is lazily allocated to
256 /// `MAX_ITEM_COUNT` slots. Subsequent calls scan from the current cursor
257 /// position, wrapping around circularly, looking for a slot whose backing
258 /// `Arc` has a strong count of 1 (meaning no outstanding
259 /// [`ReservedMemory`] handles reference it).
260 ///
261 /// # Returns
262 ///
263 /// - `Some(UninitReservedMemory)` — a handle to the reserved slot, ready
264 /// to be initialized via [`UninitReservedMemory::init`].
265 /// - `None` — all slots are currently in use.
266 pub fn reserve(&mut self) -> Option<UninitReservedMemory<MAX_ITEM_SIZE>> {
267 if self.buffer.is_empty() {
268 for _ in 0..MAX_ITEM_COUNT {
269 self.ref_counts.push(Arc::new(AtomicU32::new(0)));
270
271 // Here we need to disable the clippy warning since we manually ensure the type is
272 // send sync and we need to wrap it in an Arc because the bytes might outlive the
273 // current arena.
274 #[allow(clippy::arc_with_non_send_sync)]
275 self.buffer.push(Arc::new(UnsafeCell::new(Bytes {
276 bytes: [0; MAX_ITEM_SIZE],
277 })));
278 }
279 }
280
281 for i in 0..MAX_ITEM_COUNT {
282 let i = (i + self.cursor) % MAX_ITEM_COUNT;
283 let item = &self.buffer[i];
284
285 // SAFETY: `Arc::strong_count` is not synchronized, but this is safe
286 // because `reserve` takes `&mut self`, guaranteeing single-threaded
287 // access to the arena side. The only concurrent mutation is a
288 // `ReservedMemory` being dropped on another thread, which performs a
289 // `Release`-ordered `Arc::drop` before the strong count decrements.
290 // A stale (too-high) read here is harmless — we simply skip a slot
291 // that is actually free, and will find it on the next call.
292 if Arc::strong_count(item) == 1 {
293 self.cursor = (i + 1) % MAX_ITEM_COUNT;
294 let data = item.clone();
295 let ref_count = self.ref_counts[i].clone();
296 ref_count.store(1, Ordering::Release);
297
298 return Some(UninitReservedMemory {
299 data,
300 ref_count,
301 #[cfg(test)]
302 index: i,
303 not_sync: PhantomData,
304 });
305 }
306 }
307
308 None
309 }
310}
311
312const fn accept_obj<O, const MAX_ITEM_SIZE: usize>() -> bool {
313 size_of::<O>() <= size_of::<Bytes<MAX_ITEM_SIZE>>()
314 && align_of::<O>() <= align_of::<Bytes<MAX_ITEM_SIZE>>()
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_ITEM_SIZE: usize = 2048;

    #[test]
    fn test_lazy_initialization() {
        let mut arena = Arena::<10, MAX_ITEM_SIZE>::new();
        assert_eq!(
            arena.buffer.len(),
            0,
            "Buffer should be empty before first reservation"
        );

        arena.reserve();

        assert_eq!(
            arena.buffer.len(),
            10,
            "Buffer should be initialized to size"
        );
    }

    #[test]
    fn test_sequential_allocation_moves_cursor() {
        let mut arena = Arena::<3, MAX_ITEM_SIZE>::new();

        // First allocation (also triggers the lazy pool initialization).
        // NOTE: `let _ = …` drops the handle immediately, freeing the slot
        // again — only the cursor movement is being tested here.
        let _ = arena.reserve().expect("Should allocate");
        assert_eq!(arena.cursor, 1);

        // Second allocation: the scan starts at the cursor (index 1).
        let _ = arena.reserve().expect("Should allocate");
        assert_eq!(arena.cursor, 2);
    }

    #[test]
    fn test_reuse_of_freed_data() {
        let mut arena = Arena::<2, MAX_ITEM_SIZE>::new();

        // Fill the arena
        let data0 = arena.reserve().unwrap();
        let _data1 = arena.reserve().unwrap();

        // Arena is now full: each slot's `Arc` strong count is 2
        // (one reference held by the arena, one by the handle).
        assert!(arena.reserve().is_none(), "Should be full");

        // Free index 0 by dropping its handle: the slot's `Arc` strong count
        // falls back to 1, which is exactly how `reserve` detects free slots.
        let data0_index = data0.index;
        core::mem::drop(data0);

        // Should now be able to reserve again, and it should pick up index 0
        let data2 = arena.reserve().expect("Should reuse index 0");
        assert_eq!(data0_index, data2.index);
    }

    #[test]
    fn test_circular_cursor_search() {
        let mut arena = Arena::<3, MAX_ITEM_SIZE>::new();

        // Fill 0, 1, 2
        let _d0 = arena.reserve().unwrap();
        let d1 = arena.reserve().unwrap();
        let _d2 = arena.reserve().unwrap();

        // Free index 1 (the middle)
        core::mem::drop(d1);

        // After three reservations the cursor wrapped back around to 0. The
        // search scans (cursor + i) % size: index 0 is still held, so it
        // finds the hole at index 1 and leaves the cursor at 2.
        let _ = arena.reserve().expect("Should find the hole at index 1");
        assert_eq!(arena.cursor, 2);
    }

    #[test]
    fn test_full_arena_returns_none() {
        let mut arena = Arena::<5, MAX_ITEM_SIZE>::new();

        let mut reserved = Vec::new();

        for _ in 0..5 {
            let item = arena.reserve();
            assert!(item.is_some());
            reserved.push(item);
        }

        // Next one should fail: every slot is still held by `reserved`.
        assert!(arena.reserve().is_none());
    }
}
408
#[cfg(test)]
mod concurrent_tests {
    use super::*;
    use std::sync::{Arc, Barrier, Mutex};
    use std::{thread, vec};

    const MAX_ITEM_SIZE: usize = 2048;

    /// Wraps an arena in a Mutex for shared cross-thread access.
    ///
    /// NOTE: the arena itself is only ever locked from the test's main
    /// thread; only the (`Send + Sync`) `ReservedMemory` handles cross
    /// thread boundaries.
    fn shared_arena<const N: usize>() -> Arc<Mutex<Arena<N, MAX_ITEM_SIZE>>> {
        #[allow(clippy::arc_with_non_send_sync)]
        Arc::new(Mutex::new(Arena::<N, MAX_ITEM_SIZE>::new()))
    }

    /// Verifies that `drop_fn` is called exactly once even when multiple threads
    /// hold clones and release them concurrently.
    #[test]
    fn test_drop_called_exactly_once_under_contention() {
        let drop_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let arena = shared_arena::<4>();

        let uninit = arena.lock().unwrap().reserve().unwrap();

        // Counts how many times its own destructor runs.
        struct Probe(Arc<std::sync::atomic::AtomicUsize>);
        impl Drop for Probe {
            fn drop(&mut self) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }

        let reserved = uninit.init(Probe(drop_count.clone()));

        // Spawn 32 threads, each clones and drops ReservedMemory concurrently.
        let barrier = Arc::new(Barrier::new(32));
        let mut handles = vec![];

        for _ in 0..32 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait(); // all threads drop at the same time
                drop(r);
            }));
        }

        drop(reserved); // drop the original too
        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            drop_count.load(std::sync::atomic::Ordering::Relaxed),
            1,
            "drop_fn must be called exactly once"
        );
    }

    /// Verifies that a slot becomes available for reuse after all `ReservedMemory`
    /// clones are dropped across threads.
    #[test]
    fn test_slot_reuse_after_concurrent_drop() {
        // A single-slot arena: reuse is only possible if the slot was freed.
        let arena = shared_arena::<1>();
        let uninit = arena.lock().unwrap().reserve().unwrap();
        let reserved = uninit.init(42u64);

        let barrier = Arc::new(Barrier::new(8));
        let mut handles = vec![];

        for _ in 0..8 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                drop(r);
            }));
        }

        drop(reserved);
        for h in handles {
            h.join().unwrap();
        }

        // All clones dropped — the single slot should be free again.
        assert!(
            arena.lock().unwrap().reserve().is_some(),
            "Slot should be available after all clones are dropped"
        );
    }

    /// Verifies that `ReservedMemory` clones dropped after the arena is dropped
    /// still correctly run `drop_fn` exactly once — the backing `Arc` keeps
    /// the slot storage alive past the arena's lifetime.
    #[test]
    fn test_drop_after_arena_dropped() {
        let drop_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        // Counts how many times its own destructor runs.
        struct Probe(Arc<std::sync::atomic::AtomicUsize>);
        impl Drop for Probe {
            fn drop(&mut self) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }

        let reserved = {
            let mut arena = Arena::<4, MAX_ITEM_SIZE>::new();
            let uninit = arena.reserve().unwrap();
            uninit.init(Probe(drop_count.clone()))
            // arena drops here
        };

        // Spawn threads that hold clones past the arena's lifetime.
        let barrier = Arc::new(Barrier::new(8));
        let mut handles = vec![];

        for _ in 0..8 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                drop(r);
            }));
        }

        drop(reserved);
        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            drop_count.load(std::sync::atomic::Ordering::Relaxed),
            1,
            "drop_fn must fire exactly once even when arena is dropped first"
        );
    }
}