// cubecl_common — arena.rs

1use std::{
2    cell::UnsafeCell,
3    marker::PhantomData,
4    sync::{
5        Arc,
6        atomic::{AtomicU32, Ordering},
7    },
8    vec::Vec,
9};
10
/// The raw storage for an item, potentially uninitialized.
///
/// Aligned to 64 bytes (typical cache-line size) to prevent false sharing
/// when different threads access adjacent slots concurrently.
#[repr(C, align(64))]
pub struct Bytes<const MAX_ITEM_SIZE: usize> {
    /// Raw byte storage; reinterpreted as the stored object once a slot has
    /// been initialized via `UninitReservedMemory::init`.
    bytes: [u8; MAX_ITEM_SIZE],
}
19
/// A circular, allocation-free arena for reusable memory blocks.
///
/// `Arena` manages a fixed-capacity pool of [`Bytes`] buffers, each up to
/// `MAX_ITEM_SIZE` bytes. After the pool is lazily initialized, subsequent
/// allocations scan from an internal cursor to find a free slot, avoiding
/// further heap allocation.
///
/// # Const Parameters
///
/// - `MAX_ITEM_COUNT` — maximum number of buffers in the pool.
/// - `MAX_ITEM_SIZE` — capacity of each individual buffer in bytes.
///
/// # How It Works
///
/// The arena maintains a vector of reference-counted buffer slots. When a
/// caller requests memory, the arena advances its cursor through the pool
/// looking for a slot whose reference count is zero, then hands back a
/// [`Bytes`] handle to that slot. The cursor wraps around, giving the
/// allocation pattern its circular behavior.
///
/// Because a `Bytes` handle can outlive the `Arena` itself (e.g. when the
/// owning thread exits but the handle was sent elsewhere), each slot is
/// wrapped in an `Arc` to keep the underlying storage alive. A separate
/// `AtomicU32` reference count tracks logical ownership independently of
/// the `Arc` strong count, so the arena can reliably detect which slots
/// are free.
///
/// # Use Case
///
/// This is useful as a replacement for repeated `Arc<dyn Trait>` allocations.
pub struct Arena<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> {
    /// Backing storage for each slot. Wrapped in `Arc` so that a [`Bytes`]
    /// handle remains valid even after the `Arena` (and its owning thread)
    /// is dropped. Empty until the first `reserve` call lazily fills it.
    buffer: Vec<Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>>,
    /// Logical reference counts, one per slot. Tracked separately from the
    /// `Arc` strong count because the arena may be dropped while outstanding
    /// `Bytes` handles still exist — the `Arc` keeps memory alive, but this
    /// counter tells the arena whether a slot can be reclaimed.
    ref_counts: Vec<Arc<AtomicU32>>,
    /// Current scan position in the circular pool. Starts at 0, advances
    /// past each allocated slot, and wraps at `MAX_ITEM_COUNT`.
    cursor: usize,
}
64
/// An initialized, immutable handle to a slot in the arena.
///
/// This type is `Send + Sync` and can be cheaply cloned. Each clone
/// increments a logical reference count; when the last clone is dropped,
/// the stored object's destructor runs and the slot becomes available for
/// reuse by the arena.
pub struct ReservedMemory<const MAX_ITEM_SIZE: usize> {
    /// Backing storage slot, shared with the owning arena.
    data: Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>,
    /// Logical count of live `ReservedMemory` handles for this slot.
    ref_count: Arc<AtomicU32>,
    /// Type-erased destructor for the stored object, installed by
    /// `UninitReservedMemory::init`.
    drop_fn: fn(&mut Bytes<MAX_ITEM_SIZE>),
}
76
/// An uninitialized handle to a reserved arena slot.
///
/// Obtained from [`Arena::reserve`]. Must be initialized via [`init`](Self::init)
/// to produce a usable [`ReservedMemory`].
///
/// This type is intentionally `!Send` and `!Sync` — it must be initialized on
/// the same thread that reserved it.
pub struct UninitReservedMemory<const MAX_ITEM_SIZE: usize> {
    /// Backing storage slot, shared with the owning arena.
    data: Arc<UnsafeCell<Bytes<MAX_ITEM_SIZE>>>,
    /// Logical reference count for the slot, pre-set to 1 by `Arena::reserve`.
    ref_count: Arc<AtomicU32>,
    /// Used to assert the position in the arena.
    #[cfg(test)]
    index: usize,
    // Raw-pointer `PhantomData` makes this type `!Send` and `!Sync`.
    not_sync: PhantomData<*const ()>,
}
93
94impl<const MAX_ITEM_SIZE: usize> UninitReservedMemory<MAX_ITEM_SIZE> {
95    /// Initialize the reserved memory.
96    ///
97    /// # Panics
98    ///
99    /// If the given object isn't safe to store in this arena.
100    pub fn init<O>(self, obj: O) -> ReservedMemory<MAX_ITEM_SIZE> {
101        assert!(
102            accept_obj::<O, MAX_ITEM_SIZE>(),
103            "Object isn't safe to store in this arena"
104        );
105
106        self.init_with_func(
107            |bytes| {
108                let ptr = core::ptr::from_mut(bytes);
109                unsafe {
110                    core::ptr::write(ptr as *mut O, obj);
111                };
112            },
113            |bytes| {
114                let ptr = core::ptr::from_mut(bytes);
115                unsafe {
116                    core::ptr::drop_in_place(ptr as *mut O);
117                }
118            },
119        )
120    }
121
122    /// Writes to the reserved slot using `init_data` and attaches `drop_fn`
123    /// as the destructor to run when the last [`ReservedMemory`] clone is dropped.
124    fn init_with_func<F>(
125        self,
126        init_data: F,
127        drop_fn: fn(&mut Bytes<MAX_ITEM_SIZE>),
128    ) -> ReservedMemory<MAX_ITEM_SIZE>
129    where
130        F: FnOnce(&mut Bytes<MAX_ITEM_SIZE>),
131    {
132        // SAFETY: We access the `UnsafeCell` contents mutably. This is sound
133        // because strong_count == 2 means only two owners exist: the arena's
134        // buffer slot and this `UninitReservedMemory`. The arena never reads
135        // through the `UnsafeCell` — only the holder of `UninitReservedMemory`
136        // writes, so there is no data race.
137        assert_eq!(
138            Arc::strong_count(&self.data),
139            2,
140            "Slot must be held by exactly two owners (the arena and this \
141             UninitReservedMemory) to guarantee exclusive write access."
142        );
143
144        let bytes_mut = unsafe { self.data.as_ref().get().as_mut().unwrap() };
145        init_data(bytes_mut);
146
147        ReservedMemory {
148            data: self.data,
149            ref_count: self.ref_count,
150            drop_fn,
151        }
152    }
153}
154
155impl<const MAX_ITEM_SIZE: usize> core::fmt::Debug for ReservedMemory<MAX_ITEM_SIZE> {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("ReservedMemory")
158            .field("data", &self.data)
159            .field("drop_fn", &self.drop_fn)
160            .finish()
161    }
162}
163
164impl<const MAX_ITEM_SIZE: usize> Clone for ReservedMemory<MAX_ITEM_SIZE> {
165    fn clone(&self) -> Self {
166        self.ref_count.fetch_add(1, Ordering::Release);
167
168        Self {
169            data: self.data.clone(),
170            ref_count: self.ref_count.clone(),
171            drop_fn: self.drop_fn,
172        }
173    }
174}
175
176impl<const MAX_ITEM_SIZE: usize> Drop for ReservedMemory<MAX_ITEM_SIZE> {
177    fn drop(&mut self) {
178        // Ref-count lifecycle:
179        //   reserve()  → stores 1   (arena holds the slot)
180        //   init()     → consumes UninitReservedMemory, inherits count 1
181        //   clone()    → fetch_add  (count grows with each clone: 2, 3, …)
182        //   drop()     → fetch_sub  (count shrinks)
183        //
184        // When `fetch_sub` returns 2 it means the count just moved from 2→1,
185        // and we are the last `ReservedMemory` clone — the remaining "1" is
186        // the arena's own ref-count baseline. At this point no other clone
187        // can access the data, so we can safely run the destructor.
188        //
189        // If the arena has already been dropped, the same logic applies: the
190        // last clone still sees previous == 2 because the arena never
191        // decrements the logical ref_count — only `ReservedMemory::drop` does.
192        let drop_fn = || {
193            // SAFETY: We are the last user of this slot. The data pointer is valid,
194            // initialized, and no other `ReservedMemory` clone exists.
195            let bytes_mut = unsafe { self.data.get().as_mut().unwrap() };
196            (self.drop_fn)(bytes_mut);
197        };
198
199        let previous = self.ref_count.fetch_sub(1, Ordering::Release);
200
201        if previous == 2 {
202            drop_fn();
203        }
204    }
205}
206
// SAFETY: After initialization, the data behind `ReservedMemory` is immutable
// (no `&mut` access is possible while any clone exists; the only mutable
// access is the destructor, which runs with exclusive access in the last
// handle's `Drop`). The logical ref_count is an `AtomicU32` with proper
// ordering, and the backing `Arc` guarantees the storage outlives all
// handles. These together satisfy the `Send` and `Sync` contracts.
unsafe impl<const MAX_ITEM_SIZE: usize> Send for ReservedMemory<MAX_ITEM_SIZE> {}
unsafe impl<const MAX_ITEM_SIZE: usize> Sync for ReservedMemory<MAX_ITEM_SIZE> {}
214
215impl<const MAX_ITEM_SIZE: usize> std::convert::AsRef<Bytes<MAX_ITEM_SIZE>>
216    for ReservedMemory<MAX_ITEM_SIZE>
217{
218    /// Gets the reserved bytes.
219    fn as_ref(&self) -> &Bytes<MAX_ITEM_SIZE> {
220        // The pointer is valid and the data is readonly.
221        unsafe { self.data.as_ref().get().as_ref().unwrap() }
222    }
223}
224
impl<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> Default
    for Arena<MAX_ITEM_COUNT, MAX_ITEM_SIZE>
{
    /// Equivalent to [`Arena::new`]: an empty arena with no pool allocated yet.
    fn default() -> Self {
        Self::new()
    }
}
232
233impl<const MAX_ITEM_COUNT: usize, const MAX_ITEM_SIZE: usize> Arena<MAX_ITEM_COUNT, MAX_ITEM_SIZE> {
234    /// Creates a new, empty `Arena`.
235    ///
236    /// The internal buffer is not allocated until the first call to [`reserve`](Self::reserve).
237    pub const fn new() -> Self {
238        Self {
239            buffer: Vec::new(),
240            ref_counts: Vec::new(),
241            cursor: 0,
242        }
243    }
244
245    /// Returns `true` if an object of type `O` fits within a single slot.
246    ///
247    /// Checks that both the size and alignment of `O` are compatible with
248    /// [`Bytes<MAX_ITEM_SIZE>`].
249    pub const fn accept<O>() -> bool {
250        accept_obj::<O, MAX_ITEM_SIZE>()
251    }
252
253    /// Attempts to reserve an uninitialized slot in the arena.
254    ///
255    /// On the first call, the internal buffer is lazily allocated to
256    /// `MAX_ITEM_COUNT` slots. Subsequent calls scan from the current cursor
257    /// position, wrapping around circularly, looking for a slot whose backing
258    /// `Arc` has a strong count of 1 (meaning no outstanding
259    /// [`ReservedMemory`] handles reference it).
260    ///
261    /// # Returns
262    ///
263    /// - `Some(UninitReservedMemory)` — a handle to the reserved slot, ready
264    ///   to be initialized via [`UninitReservedMemory::init`].
265    /// - `None` — all slots are currently in use.
266    pub fn reserve(&mut self) -> Option<UninitReservedMemory<MAX_ITEM_SIZE>> {
267        if self.buffer.is_empty() {
268            for _ in 0..MAX_ITEM_COUNT {
269                self.ref_counts.push(Arc::new(AtomicU32::new(0)));
270
271                // Here we need to disable the clippy warning since we manually ensure the type is
272                // send sync and we need to wrap it in an Arc because the bytes might outlive the
273                // current arena.
274                #[allow(clippy::arc_with_non_send_sync)]
275                self.buffer.push(Arc::new(UnsafeCell::new(Bytes {
276                    bytes: [0; MAX_ITEM_SIZE],
277                })));
278            }
279        }
280
281        for i in 0..MAX_ITEM_COUNT {
282            let i = (i + self.cursor) % MAX_ITEM_COUNT;
283            let item = &self.buffer[i];
284
285            // SAFETY: `Arc::strong_count` is not synchronized, but this is safe
286            // because `reserve` takes `&mut self`, guaranteeing single-threaded
287            // access to the arena side. The only concurrent mutation is a
288            // `ReservedMemory` being dropped on another thread, which performs a
289            // `Release`-ordered `Arc::drop` before the strong count decrements.
290            // A stale (too-high) read here is harmless — we simply skip a slot
291            // that is actually free, and will find it on the next call.
292            if Arc::strong_count(item) == 1 {
293                self.cursor = (i + 1) % MAX_ITEM_COUNT;
294                let data = item.clone();
295                let ref_count = self.ref_counts[i].clone();
296                ref_count.store(1, Ordering::Release);
297
298                return Some(UninitReservedMemory {
299                    data,
300                    ref_count,
301                    #[cfg(test)]
302                    index: i,
303                    not_sync: PhantomData,
304                });
305            }
306        }
307
308        None
309    }
310}
311
312const fn accept_obj<O, const MAX_ITEM_SIZE: usize>() -> bool {
313    size_of::<O>() <= size_of::<Bytes<MAX_ITEM_SIZE>>()
314        && align_of::<O>() <= align_of::<Bytes<MAX_ITEM_SIZE>>()
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_ITEM_SIZE: usize = 2048;

    #[test]
    fn test_lazy_initialization() {
        let mut arena = Arena::<10, MAX_ITEM_SIZE>::new();
        assert_eq!(
            arena.buffer.len(),
            0,
            "Buffer should be empty before first reservation"
        );

        arena.reserve();

        assert_eq!(
            arena.buffer.len(),
            10,
            "Buffer should be initialized to size"
        );
    }

    #[test]
    fn test_sequential_allocation_moves_cursor() {
        let mut arena = Arena::<3, MAX_ITEM_SIZE>::new();

        // First allocation claims slot 0, advancing the cursor to 1.
        let _ = arena.reserve().expect("Should allocate");
        assert_eq!(arena.cursor, 1);

        // Second allocation claims slot 1, advancing the cursor to 2.
        let _ = arena.reserve().expect("Should allocate");
        assert_eq!(arena.cursor, 2);
    }

    #[test]
    fn test_reuse_of_freed_data() {
        let mut arena = Arena::<2, MAX_ITEM_SIZE>::new();

        // Fill the arena.
        let data0 = arena.reserve().unwrap();
        let _data1 = arena.reserve().unwrap();

        // Arena is now full (every slot's Arc strong count is 2).
        assert!(arena.reserve().is_none(), "Should be full");

        // Free index 0 by dropping its handle; the slot's Arc strong count
        // falls back to 1, marking the slot reusable.
        let data0_index = data0.index;
        core::mem::drop(data0);

        // Should now be able to reserve again, and it should pick up index 0
        let data2 = arena.reserve().expect("Should reuse index 0");
        assert_eq!(data0_index, data2.index);
    }

    #[test]
    fn test_circular_cursor_search() {
        let mut arena = Arena::<3, MAX_ITEM_SIZE>::new();

        // Fill 0, 1, 2. The cursor wraps back to 0 after the third reservation.
        let _d0 = arena.reserve().unwrap();
        let d1 = arena.reserve().unwrap();
        let _d2 = arena.reserve().unwrap();

        // Free index 1 (the middle)
        core::mem::drop(d1);

        // The search starts at (i + cursor) % size with cursor == 0: it skips
        // the busy slot 0, finds the hole at index 1, and leaves the cursor
        // just past it, at 2.
        let _ = arena.reserve().expect("Should find the hole at index 1");
        assert_eq!(arena.cursor, 2);
    }

    #[test]
    fn test_full_arena_returns_none() {
        let mut arena = Arena::<5, MAX_ITEM_SIZE>::new();

        let mut reserved = Vec::new();

        for _ in 0..5 {
            let item = arena.reserve();
            assert!(item.is_some());
            reserved.push(item);
        }

        // Next one should fail: every slot is still held by a live handle.
        assert!(arena.reserve().is_none());
    }
}
408
#[cfg(test)]
mod concurrent_tests {
    use super::*;
    use std::sync::{Arc, Barrier, Mutex};
    use std::{thread, vec};

    const MAX_ITEM_SIZE: usize = 2048;

    /// Wraps an arena in a Mutex for shared cross-thread access.
    /// (The arena itself is only ever locked from the main thread in these
    /// tests; only the `Send + Sync` `ReservedMemory` handles cross threads.)
    fn shared_arena<const N: usize>() -> Arc<Mutex<Arena<N, MAX_ITEM_SIZE>>> {
        #[allow(clippy::arc_with_non_send_sync)]
        Arc::new(Mutex::new(Arena::<N, MAX_ITEM_SIZE>::new()))
    }

    /// Verifies that `drop_fn` is called exactly once even when multiple threads
    /// hold clones and release them concurrently.
    #[test]
    fn test_drop_called_exactly_once_under_contention() {
        let drop_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let arena = shared_arena::<4>();

        let uninit = arena.lock().unwrap().reserve().unwrap();

        // Counts how many times its destructor runs.
        struct Probe(Arc<std::sync::atomic::AtomicUsize>);
        impl Drop for Probe {
            fn drop(&mut self) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }

        let reserved = uninit.init(Probe(drop_count.clone()));

        // Spawn 32 threads, each clones and drops ReservedMemory concurrently.
        let barrier = Arc::new(Barrier::new(32));
        let mut handles = vec![];

        for _ in 0..32 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait(); // all threads drop at the same time
                drop(r);
            }));
        }

        drop(reserved); // drop the original too
        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            drop_count.load(std::sync::atomic::Ordering::Relaxed),
            1,
            "drop_fn must be called exactly once"
        );
    }

    /// Verifies that a slot becomes available for reuse after all `ReservedMemory`
    /// clones are dropped across threads.
    #[test]
    fn test_slot_reuse_after_concurrent_drop() {
        let arena = shared_arena::<1>();
        let uninit = arena.lock().unwrap().reserve().unwrap();
        let reserved = uninit.init(42u64);

        let barrier = Arc::new(Barrier::new(8));
        let mut handles = vec![];

        for _ in 0..8 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                drop(r);
            }));
        }

        drop(reserved);
        for h in handles {
            h.join().unwrap();
        }

        // All clones dropped — the single slot should be free again.
        assert!(
            arena.lock().unwrap().reserve().is_some(),
            "Slot should be available after all clones are dropped"
        );
    }

    /// Verifies that `ReservedMemory` clones dropped after the arena is dropped
    /// still correctly run `drop_fn` (the backing `Arc` keeps the storage alive
    /// past the arena's lifetime).
    #[test]
    fn test_drop_after_arena_dropped() {
        let drop_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        // Counts how many times its destructor runs.
        struct Probe(Arc<std::sync::atomic::AtomicUsize>);
        impl Drop for Probe {
            fn drop(&mut self) {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }

        let reserved = {
            let mut arena = Arena::<4, MAX_ITEM_SIZE>::new();
            let uninit = arena.reserve().unwrap();
            uninit.init(Probe(drop_count.clone()))
            // arena drops here
        };

        // Spawn threads that hold clones past the arena's lifetime.
        let barrier = Arc::new(Barrier::new(8));
        let mut handles = vec![];

        for _ in 0..8 {
            let r = reserved.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                drop(r);
            }));
        }

        drop(reserved);
        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            drop_count.load(std::sync::atomic::Ordering::Relaxed),
            1,
            "drop_fn must fire exactly once even when arena is dropped first"
        );
    }
}