//! Caching CUDA memory allocator.
//!
//! [`CudaAllocator`] wraps a [`GpuDevice`] and provides a PyTorch-style caching
//! memory allocator with:
//!
//! - **Block splitting**: oversized free blocks are split, remainder returned to
//!   the pool.
//! - **Block coalescing**: adjacent freed blocks are merged to reduce
//!   fragmentation.
//! - **Stream-aware reuse**: blocks track which CUDA streams have used them;
//!   a block is only reused when all recorded stream work is complete.
//! - **Dual pools**: small (<= 1 MiB) and large (> 1 MiB) allocations are kept
//!   in separate pools to avoid small allocations fragmenting large contiguous
//!   regions.
//! - **Statistics**: `memory_allocated`, `max_memory_allocated`,
//!   `memory_reserved`, allocation/free counters.
//!
//! # Design
//!
//! This is a CPU-side data structure that manages block metadata. Actual GPU
//! memory allocation/deallocation is delegated to the [`GpuDevice`] (cudarc).
//! The caching layer sits between callers and the driver, intercepting frees
//! to retain memory for reuse and serving allocs from the cache when possible.
//!
//! The design follows PyTorch's `CUDACachingAllocator` (c10/cuda/). Key
//! constants match PyTorch:
//! - `MIN_BLOCK_SIZE` = 512 bytes
//! - `SMALL_SIZE` = 1 MiB (threshold between small/large pools)
//! - `SMALL_BUFFER` = 2 MiB (small pool segment size)
//! - `MIN_LARGE_ALLOC` = 10 MiB
//! - `ROUND_LARGE` = 2 MiB (rounding for large allocations)
//!
//! # Thread safety
//!
//! `CudaAllocator` is `Send + Sync`. Internal state is protected by a `Mutex`.
//! The critical section is short (BTreeSet lookup + pointer bookkeeping).
//!
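//! # Example
//!
//! A minimal end-to-end sketch (assumes the `cuda` feature, a visible CUDA
//! device, and that the crate exposes these items as `ferrotorch_gpu::...`):
//!
//! ```no_run
//! use std::sync::Arc;
//! use ferrotorch_gpu::allocator::CudaAllocator;
//! use ferrotorch_gpu::device::GpuDevice;
//!
//! # fn main() -> ferrotorch_gpu::error::GpuResult<()> {
//! let alloc = CudaAllocator::new(Arc::new(GpuDevice::new(0)?));
//!
//! // 1024 f32s = 4096 bytes, tracked by the statistics counters.
//! let buf = alloc.alloc_zeros::<f32>(1024)?;
//! assert_eq!(alloc.memory_allocated(), 4096);
//!
//! // Return the buffer through the allocator so the counters stay accurate.
//! alloc.free(buf);
//! assert_eq!(alloc.memory_allocated(), 0);
//! # Ok(())
//! # }
//! ```
//!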
//! # CL-323

use std::collections::{BTreeSet, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::GpuResult;

// ---------------------------------------------------------------------------
// Constants (matching PyTorch's CUDACachingAllocator)
// ---------------------------------------------------------------------------

/// Minimum block size — all allocations are rounded up to at least this.
pub const MIN_BLOCK_SIZE: usize = 512;

/// Largest allocation that goes into the small pool.
pub const SMALL_SIZE: usize = 1 << 20; // 1 MiB

/// Segment size for small pool allocations from the driver.
pub const SMALL_BUFFER: usize = 2 << 20; // 2 MiB

/// Allocations between `SMALL_SIZE` and `MIN_LARGE_ALLOC` use a 20 MiB
/// segment from the driver (to reduce the number of driver calls).
pub const MIN_LARGE_ALLOC: usize = 10 << 20; // 10 MiB

/// Large pool segment size for allocations between 1 MiB and 10 MiB.
pub const LARGE_BUFFER: usize = 20 << 20; // 20 MiB

/// Round up large allocations to this granularity.
pub const ROUND_LARGE: usize = 2 << 20; // 2 MiB

// ---------------------------------------------------------------------------
// StreamId — lightweight stream identifier for tracking cross-stream usage
// ---------------------------------------------------------------------------

/// Opaque identifier for a CUDA stream.
///
/// We use a `usize` derived from the stream's pointer/handle so that stream
/// tracking works without holding `Arc<CudaStream>` references (which would
/// prevent the stream from being dropped).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub usize);

// ---------------------------------------------------------------------------
// Block — metadata for a cached GPU memory region
// ---------------------------------------------------------------------------

/// Unique identifier for a [`Block`]. Monotonically increasing.
static NEXT_BLOCK_ID: AtomicUsize = AtomicUsize::new(0);

/// Metadata for a contiguous region of GPU memory.
///
/// Blocks form a doubly-linked list within a segment (a single driver
/// allocation). Splitting a block creates a new block for the remainder and
/// links them. Coalescing merges adjacent free blocks by adjusting the linked
/// list and summing sizes.
///
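/// A sketch of one segment's block list before and after freeing the middle
/// block (`free_block` first calls `try_merge` on each neighbor):
///
/// ```text
/// [ A: free ][ B: allocated ][ C: free ]
///                 free_block(B)
/// [ B': free, size = A.size + B.size + C.size ]
/// ```
///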
/// # CL-323
#[derive(Debug)]
pub struct Block {
    /// Unique id, used to make `BTreeSet` ordering deterministic when sizes tie.
    pub(crate) id: usize,
    /// Device ordinal that owns this memory.
    pub device: usize,
    /// Size of this block in bytes.
    pub size: usize,
    /// Virtual address of the block's start (byte offset within its segment).
    /// Used for ordering during coalescing. For blocks that own real GPU memory,
    /// this is derived from the `CudaSlice` pointer. For the remainder block
    /// created by a split, this is `parent.ptr + parent.size`, where
    /// `parent.size` is the parent's size after it has been shrunk to the
    /// requested size.
    pub ptr: usize,
    /// The CUDA stream on which this block was originally allocated.
    pub stream: StreamId,
    /// Set of streams that have used this block (via `record_stream`).
    /// A block cannot be reused until all recorded stream work is complete.
    pub stream_uses: HashSet<StreamId>,
    /// Whether this block is currently handed out to a caller.
    pub allocated: bool,
    /// Index of the previous block in the segment's linked list, or `None`.
    pub prev: Option<usize>,
    /// Index of the next block in the segment's linked list, or `None`.
    pub next: Option<usize>,
    /// Whether this block is in the small pool.
    pub in_small_pool: bool,
}

impl Block {
    /// Create a new block with the given parameters.
    pub fn new(
        device: usize,
        size: usize,
        ptr: usize,
        stream: StreamId,
        in_small_pool: bool,
    ) -> Self {
        Self {
            id: NEXT_BLOCK_ID.fetch_add(1, Ordering::Relaxed),
            device,
            size,
            ptr,
            stream,
            stream_uses: HashSet::new(),
            allocated: false,
            prev: None,
            next: None,
            in_small_pool,
        }
    }

    /// Whether this block was created by splitting (has neighbors).
    pub fn is_split(&self) -> bool {
        self.prev.is_some() || self.next.is_some()
    }
}

// ---------------------------------------------------------------------------
// BlockKey — ordered key for BTreeSet lookups
// ---------------------------------------------------------------------------

/// Key used for ordering blocks in a [`BlockPool`]'s free set.
///
/// Ordered by `(stream, size, ptr, id)` so that a `range` scan starting at a
/// search key finds the smallest block >= the requested size on the given stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct BlockKey {
    stream: StreamId,
    size: usize,
    ptr: usize,
    id: usize,
}

impl BlockKey {
    fn from_block(b: &Block) -> Self {
        Self {
            stream: b.stream,
            size: b.size,
            ptr: b.ptr,
            id: b.id,
        }
    }

    /// Create a search key: finds the smallest block >= `size` on `stream`.
    fn search(stream: StreamId, size: usize) -> Self {
        Self {
            stream,
            size,
            ptr: 0,
            id: 0,
        }
    }
}

// ---------------------------------------------------------------------------
// BlockPool — set of free blocks for one size class
// ---------------------------------------------------------------------------

/// A pool of free [`Block`]s ordered by `(stream, size, ptr, id)`.
///
/// Two pools exist per allocator: one for small blocks (<= 1 MiB) and one for
/// large blocks (> 1 MiB).
///
/// # CL-323
pub(crate) struct BlockPool {
    /// Free (non-allocated) blocks, ordered by [`BlockKey`].
    free_blocks: BTreeSet<(BlockKey, usize)>, // (key, block_index)
    /// Whether this is the small pool.
    pub is_small: bool,
}

impl BlockPool {
    /// Create a new empty block pool.
    pub fn new(is_small: bool) -> Self {
        Self {
            free_blocks: BTreeSet::new(),
            is_small,
        }
    }

    /// Insert a block into the free set.
    #[cfg(test)]
    pub fn insert(&mut self, block_idx: usize, block: &Block) {
        self.free_blocks
            .insert((BlockKey::from_block(block), block_idx));
    }

    /// Insert a block into the free set using a precomputed key.
    pub fn insert_key(&mut self, block_idx: usize, key: BlockKey) {
        self.free_blocks.insert((key, block_idx));
    }

    /// Remove a block from the free set using a precomputed key.
    pub fn remove_key(&mut self, block_idx: usize, key: BlockKey) {
        self.free_blocks.remove(&(key, block_idx));
    }

    /// Find the smallest free block >= `size` on `stream`.
    pub fn find_free_block(&self, stream: StreamId, size: usize) -> Option<usize> {
        let search = (BlockKey::search(stream, size), 0);
        if let Some(&(key, idx)) = self.free_blocks.range(search..).next() {
            if key.stream == stream {
                // Found a block on the same stream that is >= requested size.
                return Some(idx);
            }
        }
        None
    }

    /// Number of free blocks in this pool.
    pub fn len(&self) -> usize {
        self.free_blocks.len()
    }

    /// Clear all free blocks from the pool.
    pub fn clear(&mut self) {
        self.free_blocks.clear();
    }
}

// ---------------------------------------------------------------------------
// AllocatorState — the caching allocator's inner mutable state
// ---------------------------------------------------------------------------

/// All mutable state protected by the allocator's mutex.
///
/// # CL-323
pub(crate) struct AllocatorState {
    /// Arena of all blocks (allocated + free). Indexed by `usize`.
    pub(crate) blocks: Vec<Block>,
    /// Small pool: blocks <= `SMALL_SIZE`.
    pub(crate) small_pool: BlockPool,
    /// Large pool: blocks > `SMALL_SIZE`.
    pub(crate) large_pool: BlockPool,
    /// Total bytes reserved from the driver (cached + in-use).
    pub(crate) reserved_bytes: usize,
    /// Total bytes currently handed out to callers.
    pub(crate) allocated_bytes: usize,
    /// Peak allocated bytes.
    pub(crate) peak_bytes: usize,
    /// Number of successful cache hits.
    pub(crate) hits: usize,
    /// Number of cache misses (driver allocs).
    pub(crate) misses: usize,
}

impl AllocatorState {
    fn new() -> Self {
        Self {
            blocks: Vec::new(),
            small_pool: BlockPool::new(true),
            large_pool: BlockPool::new(false),
            reserved_bytes: 0,
            allocated_bytes: 0,
            peak_bytes: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Get the pool for a given size class.
    pub(crate) fn get_pool_mut(&mut self, is_small: bool) -> &mut BlockPool {
        let pool = if is_small {
            &mut self.small_pool
        } else {
            &mut self.large_pool
        };
        debug_assert_eq!(pool.is_small, is_small, "pool size-class mismatch");
        pool
    }

    /// Allocate a new block index in the arena.
    pub(crate) fn add_block(&mut self, block: Block) -> usize {
        let idx = self.blocks.len();
        self.blocks.push(block);
        idx
    }

    /// Whether a free block should be split for a request of `size` bytes.
    pub(crate) fn should_split(&self, block_idx: usize, size: usize) -> bool {
        let block = &self.blocks[block_idx];
        let remaining = block.size - size;
        if block.in_small_pool {
            // Small pool: split if remainder >= MIN_BLOCK_SIZE.
            remaining >= MIN_BLOCK_SIZE
        } else {
            // Large pool: split if remainder > SMALL_SIZE (avoid small
            // fragments in the large pool).
            remaining > SMALL_SIZE
        }
    }

    /// Split `block_idx` to satisfy a `size`-byte request. The block at
    /// `block_idx` is resized to `size` and a new remainder block is created
    /// and inserted into the appropriate pool.
    pub(crate) fn split_block(&mut self, block_idx: usize, size: usize) {
        let remaining_size = self.blocks[block_idx].size - size;
        let remaining_ptr = self.blocks[block_idx].ptr + size;
        let stream = self.blocks[block_idx].stream;
        let device = self.blocks[block_idx].device;
        let is_small = self.blocks[block_idx].in_small_pool;
        let old_next = self.blocks[block_idx].next;

        // Create the remainder block.
        let mut remainder = Block::new(device, remaining_size, remaining_ptr, stream, is_small);
        remainder.prev = Some(block_idx);
        remainder.next = old_next;

        let rem_idx = self.add_block(remainder);

        // Update the original block.
        self.blocks[block_idx].size = size;
        self.blocks[block_idx].next = Some(rem_idx);

        // Update the block that was after the original, if any.
        if let Some(old_next_idx) = old_next {
            self.blocks[old_next_idx].prev = Some(rem_idx);
        }

        // Insert remainder into the free pool.
        let rem_key = BlockKey::from_block(&self.blocks[rem_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(rem_idx, rem_key);
    }

    /// Try to merge `block_idx` with an adjacent block `neighbor_idx`.
    /// Returns the size of the subsumed neighbor, or 0 if merge failed.
    pub(crate) fn try_merge(&mut self, block_idx: usize, neighbor_idx: Option<usize>) -> usize {
        let Some(nbr_idx) = neighbor_idx else {
            return 0;
        };

        // Cannot merge if neighbor is allocated or has pending stream uses.
        if self.blocks[nbr_idx].allocated || !self.blocks[nbr_idx].stream_uses.is_empty() {
            return 0;
        }

        let is_small = self.blocks[nbr_idx].in_small_pool;
        let subsumed_size = self.blocks[nbr_idx].size;

        // Remove neighbor from its free pool.
        let nbr_key = BlockKey::from_block(&self.blocks[nbr_idx]);
        {
            let pool = self.get_pool_mut(is_small);
            pool.remove_key(nbr_idx, nbr_key);
        }

        // Determine merge direction.
        if self.blocks[block_idx].prev == Some(nbr_idx) {
            // [neighbor] [block] — neighbor is before block.
            let nbr_prev = self.blocks[nbr_idx].prev;
            self.blocks[block_idx].ptr = self.blocks[nbr_idx].ptr;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].prev = nbr_prev;
            if let Some(pp) = nbr_prev {
                self.blocks[pp].next = Some(block_idx);
            }
        } else {
            // [block] [neighbor] — neighbor is after block.
            let nbr_next = self.blocks[nbr_idx].next;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].next = nbr_next;
            if let Some(nn) = nbr_next {
                self.blocks[nn].prev = Some(block_idx);
            }
        }

        // Mark the subsumed block as dead (size 0, no links). We do not
        // reclaim arena slots — the Vec grows monotonically. This is fine
        // because the number of live blocks is bounded by the number of
        // driver allocations (typically <10k even for large models).
        self.blocks[nbr_idx].size = 0;
        self.blocks[nbr_idx].prev = None;
        self.blocks[nbr_idx].next = None;

        subsumed_size
    }

    /// Free a block: mark as not-allocated, try to coalesce with neighbors,
    /// then return to the appropriate pool.
    pub(crate) fn free_block(&mut self, block_idx: usize) {
        self.blocks[block_idx].allocated = false;
        self.blocks[block_idx].stream_uses.clear();
        let size = self.blocks[block_idx].size;
        self.allocated_bytes = self.allocated_bytes.saturating_sub(size);

        // Try coalescing with prev and next.
        let prev = self.blocks[block_idx].prev;
        let next = self.blocks[block_idx].next;
        self.try_merge(block_idx, prev);
        self.try_merge(block_idx, next);

        // Insert merged block into free pool.
        let is_small = self.blocks[block_idx].in_small_pool;
        let merged_key = BlockKey::from_block(&self.blocks[block_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(block_idx, merged_key);
    }

    /// Total cached bytes: reserved minus allocated, i.e. the bytes held in
    /// the free blocks of both pools.
    pub(crate) fn cached_bytes(&self) -> usize {
        self.reserved_bytes.saturating_sub(self.allocated_bytes)
    }
}

// ---------------------------------------------------------------------------
// Round size — PyTorch-compatible size rounding
// ---------------------------------------------------------------------------

/// Round `size` up to an allocation-friendly boundary.
///
/// - Sizes below `MIN_BLOCK_SIZE` (512) are rounded up to `MIN_BLOCK_SIZE`.
/// - Otherwise, rounded up to the next multiple of `MIN_BLOCK_SIZE`.
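///
/// A doctest-style sketch (values follow directly from the two rules above;
/// the public path `ferrotorch_gpu::allocator` is assumed):
///
/// ```
/// use ferrotorch_gpu::allocator::{round_size, MIN_BLOCK_SIZE};
///
/// assert_eq!(round_size(1), MIN_BLOCK_SIZE); // clamped up to the minimum
/// assert_eq!(round_size(513), 1024);         // next 512-byte multiple
/// assert_eq!(round_size(1024), 1024);        // already aligned
/// ```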
pub fn round_size(size: usize) -> usize {
    if size < MIN_BLOCK_SIZE {
        return MIN_BLOCK_SIZE;
    }
    // Round up to the next multiple of MIN_BLOCK_SIZE (a power of two, so
    // the add-and-mask trick is exact).
    (size + MIN_BLOCK_SIZE - 1) & !(MIN_BLOCK_SIZE - 1)
}

/// Determine how many bytes to request from the driver for a given request
/// size (after rounding). Small allocations are packed into `SMALL_BUFFER`
/// segments; mid-range ones into `LARGE_BUFFER`; large ones are rounded up
/// to a multiple of `ROUND_LARGE`.
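///
/// A sketch of the three regimes (the arithmetic follows from the constants
/// above; the public path `ferrotorch_gpu::allocator` is assumed):
///
/// ```
/// use ferrotorch_gpu::allocator::*;
///
/// assert_eq!(get_allocation_size(4096), SMALL_BUFFER);     // small: 2 MiB segment
/// assert_eq!(get_allocation_size(5 << 20), LARGE_BUFFER);  // mid: 20 MiB segment
/// assert_eq!(
///     get_allocation_size((10 << 20) + 1),                 // large: round to 2 MiB
///     (10 << 20) + ROUND_LARGE
/// );
/// ```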
pub fn get_allocation_size(size: usize) -> usize {
    if size <= SMALL_SIZE {
        SMALL_BUFFER
    } else if size < MIN_LARGE_ALLOC {
        LARGE_BUFFER
    } else {
        // Round up to next multiple of ROUND_LARGE.
        (size + ROUND_LARGE - 1) & !(ROUND_LARGE - 1)
    }
}

// ---------------------------------------------------------------------------
// CudaAllocator — the public API
// ---------------------------------------------------------------------------

/// A caching GPU memory allocator with block pools, splitting, coalescing,
/// and stream-aware reuse.
///
/// Wraps a [`GpuDevice`] and maintains two block pools (small and large).
/// Allocation requests are served from cached free blocks when possible;
/// only on cache miss does the allocator call through to the CUDA driver.
/// Freed blocks are returned to the pool and coalesced with neighbors to
/// reduce fragmentation.
///
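/// # Example
///
/// A CPU-side sketch of the caching protocol used by the pool layer (no real
/// GPU memory is touched: the pointer below is a made-up address, and the
/// `ferrotorch_gpu` paths are assumptions):
///
/// ```no_run
/// use std::sync::Arc;
/// use ferrotorch_gpu::allocator::{CudaAllocator, StreamId};
/// use ferrotorch_gpu::device::GpuDevice;
///
/// # fn main() -> ferrotorch_gpu::error::GpuResult<()> {
/// let alloc = CudaAllocator::new(Arc::new(GpuDevice::new(0)?));
/// let stream = StreamId(0);
///
/// // Miss: nothing is cached yet, so ask how much to request from the
/// // driver, then register that (hypothetical) driver allocation.
/// assert!(alloc.cache_find(4096, stream).is_none());
/// let need = CudaAllocator::driver_alloc_size(4096);
/// let (idx, _actual) = alloc.cache_insert(4096, need, 0xdead_0000, stream);
///
/// // Freeing the block makes the next same-stream request a cache hit.
/// alloc.cache_free(idx);
/// assert!(alloc.cache_find(4096, stream).is_some());
/// # Ok(())
/// # }
/// ```
///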
/// # CL-323
pub struct CudaAllocator {
    device: Arc<GpuDevice>,
    pub(crate) state: Mutex<AllocatorState>,
    /// Total bytes currently in use (atomic mirror for lock-free reads).
    allocated_bytes_atomic: AtomicUsize,
    /// Peak bytes ever allocated.
    peak_bytes_atomic: AtomicUsize,
}

impl CudaAllocator {
    /// Create a new caching allocator for the given device.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        Self {
            device,
            state: Mutex::new(AllocatorState::new()),
            allocated_bytes_atomic: AtomicUsize::new(0),
            peak_bytes_atomic: AtomicUsize::new(0),
        }
    }

    /// Allocate `count` zero-initialized elements of type `T` on the device.
    ///
    /// The returned [`CudaBuffer`] is tracked by this allocator. When you are
    /// done with it, pass it to [`free`](CudaAllocator::free) so the
    /// statistics stay accurate. (Dropping the buffer directly still frees
    /// GPU memory, but the `allocated_bytes` counter will be too high.)
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Driver`] if the underlying CUDA allocation fails.
    #[cfg(feature = "cuda")]
    pub fn alloc_zeros<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let bytes = count.saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().alloc_zeros::<T>(count)?;

        // Update statistics after the allocation succeeds.
        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            alloc_len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Copy a host slice to device memory, tracking the allocation.
    ///
    /// This is the allocator-aware equivalent of [`crate::transfer::cpu_to_gpu`].
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::Driver`] if the CUDA memcpy or allocation fails.
    #[cfg(feature = "cuda")]
    pub fn alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let bytes = data.len().saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().clone_htod(data)?;

        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            alloc_len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Return a buffer to the allocator, freeing the GPU memory and updating
    /// the statistics.
    ///
    /// This is preferred over simply dropping the buffer so that
    /// [`memory_allocated`](CudaAllocator::memory_allocated) stays accurate.
    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
        let bytes = buffer
            .len()
            .checked_mul(std::mem::size_of::<T>())
            .unwrap_or(0);
        self.allocated_bytes_atomic
            .fetch_sub(bytes, Ordering::Relaxed);
        drop(buffer);
    }

    // ------------------------------------------------------------------
    // Statistics queries
    // ------------------------------------------------------------------

    /// Bytes currently allocated (live) on the device through this allocator.
    #[inline]
    pub fn memory_allocated(&self) -> usize {
        self.allocated_bytes_atomic.load(Ordering::Relaxed)
    }

    /// Peak bytes ever allocated since creation or the last
    /// [`reset_peak_stats`](CudaAllocator::reset_peak_stats).
    #[inline]
    pub fn max_memory_allocated(&self) -> usize {
        self.peak_bytes_atomic.load(Ordering::Relaxed)
    }

    /// Total bytes reserved from the CUDA driver (cached + in-use).
    pub fn memory_reserved(&self) -> usize {
        self.state.lock().map(|s| s.reserved_bytes).unwrap_or(0)
    }

    /// Reset the peak counter to the current allocation level.
    pub fn reset_peak_stats(&self) {
        let current = self.allocated_bytes_atomic.load(Ordering::Relaxed);
        self.peak_bytes_atomic.store(current, Ordering::Relaxed);
    }

    /// Release all cached (free) blocks back to the CUDA driver.
    ///
    /// After this call, `memory_reserved()` drops to `memory_allocated()`
    /// (only blocks currently in use remain). This is useful when another
    /// component needs GPU memory and the cache is holding onto freed blocks.
    ///
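    /// A minimal sketch (statistics only; `cached_bytes` reads the same
    /// counters this method resets):
    ///
    /// ```no_run
    /// # use ferrotorch_gpu::allocator::CudaAllocator;
    /// # fn demo(alloc: &CudaAllocator) {
    /// alloc.empty_cache();
    /// assert_eq!(alloc.cached_bytes(), 0); // free pools are now empty
    /// # }
    /// ```
    ///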
    /// # CL-323
    pub fn empty_cache(&self) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        // Clear both free pools. The actual GPU memory is freed when the
        // pool.rs layer drops its CudaSlice holders.
        state.small_pool.clear();
        state.large_pool.clear();

        // Recalculate reserved to only include allocated blocks.
        state.reserved_bytes = state.allocated_bytes;
    }

    /// The underlying device.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    // ------------------------------------------------------------------
    // Pool-level cache operations (used by pool.rs)
    // ------------------------------------------------------------------

    /// Record that a block was used on `stream`, preventing reuse until
    /// work on that stream completes.
    ///
    /// This is the Rust equivalent of PyTorch's `recordStream()`.
    ///
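    /// A sketch of the intended call pattern (the block index would come from
    /// [`cache_find`](CudaAllocator::cache_find) or
    /// [`cache_insert`](CudaAllocator::cache_insert); the stream id here is
    /// made up):
    ///
    /// ```no_run
    /// # use ferrotorch_gpu::allocator::{CudaAllocator, StreamId};
    /// # fn after_cross_stream_use(alloc: &CudaAllocator, block_idx: usize) {
    /// // A kernel on another stream (id 7) used this block's memory; record
    /// // it so the allocator treats the block as busy on that stream too.
    /// alloc.record_stream_on_block(block_idx, StreamId(7));
    /// # }
    /// ```
    ///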
    /// # CL-323
    pub fn record_stream_on_block(&self, block_idx: usize, stream: StreamId) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() {
            state.blocks[block_idx].stream_uses.insert(stream);
        }
    }

    /// Number of blocks in the arena (for debugging/testing).
    pub fn block_count(&self) -> usize {
        self.state.lock().map(|s| s.blocks.len()).unwrap_or(0)
    }

    /// Number of free blocks in both pools (for debugging/testing).
    pub fn free_block_count(&self) -> usize {
        self.state
            .lock()
            .map(|s| s.small_pool.len() + s.large_pool.len())
            .unwrap_or(0)
    }

    /// (hits, misses) cache statistics.
    pub fn cache_stats(&self) -> (usize, usize) {
        self.state
            .lock()
            .map(|s| (s.hits, s.misses))
            .unwrap_or((0, 0))
    }

    /// Total cached (free, reusable) bytes.
    pub fn cached_bytes(&self) -> usize {
        self.state.lock().map(|s| s.cached_bytes()).unwrap_or(0)
    }

    // ------------------------------------------------------------------
    // Cache-aware allocation (used by pool.rs)
    // ------------------------------------------------------------------

    /// Try to find a cached block of at least `size` bytes on `stream`.
    ///
    /// Returns the block index and its actual size if found.
    /// The block is marked as allocated and removed from the free pool.
    /// If the block is significantly larger than needed, it is split.
    ///
    /// # CL-323
    pub fn cache_find(&self, size: usize, stream: StreamId) -> Option<(usize, usize)> {
        let rounded = round_size(size);
        let is_small = rounded <= SMALL_SIZE;

        let Ok(mut state) = self.state.lock() else {
            return None;
        };

        let block_idx = {
            let pool = state.get_pool_mut(is_small);
            pool.find_free_block(stream, rounded)?
        };

        // Remove from free pool (key extracted before mutable pool borrow).
        let key = BlockKey::from_block(&state.blocks[block_idx]);
        state.get_pool_mut(is_small).remove_key(block_idx, key);

        // Split if block is oversized.
        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }

        // Mark as allocated.
        state.blocks[block_idx].allocated = true;
        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.hits += 1;

        Some((block_idx, actual_size))
    }

    /// Register a new block from a fresh driver allocation.
    ///
    /// Called when `cache_find` returns `None` and the caller has obtained
    /// memory from the CUDA driver. The full driver allocation is registered
    /// as a block; if it's larger than the requested size, the remainder is
    /// split off and placed in the free pool.
    ///
    /// Returns `(block_idx, actual_block_size)`.
    ///
    /// # CL-323
    pub fn cache_insert(
        &self,
        requested_size: usize,
        driver_alloc_size: usize,
        ptr: usize,
        stream: StreamId,
    ) -> (usize, usize) {
        let rounded = round_size(requested_size);
        let is_small = rounded <= SMALL_SIZE;

        let Ok(mut state) = self.state.lock() else {
            // Fallback: return a dummy index. Should never happen in practice.
            return (0, driver_alloc_size);
        };

        let mut block = Block::new(
            self.device.ordinal(),
            driver_alloc_size,
            ptr,
            stream,
            is_small,
        );
        block.allocated = true;
        let block_idx = state.add_block(block);

        state.reserved_bytes += driver_alloc_size;

        // Split if the driver allocation is larger than requested.
        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }

        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.misses += 1;

        (block_idx, actual_size)
    }

    /// Return a block to the cache (free it back to a pool).
    ///
    /// The block is coalesced with any adjacent free blocks and inserted
    /// into the appropriate pool for future reuse.
    ///
    /// # CL-323
    pub fn cache_free(&self, block_idx: usize) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() && state.blocks[block_idx].allocated {
            state.free_block(block_idx);
        }
    }

    /// Get the driver allocation size for a given request size.
    ///
    /// Callers use this to know how many bytes to request from the driver
    /// when `cache_find` misses.
    pub fn driver_alloc_size(size: usize) -> usize {
        get_allocation_size(round_size(size))
    }
}

impl std::fmt::Debug for CudaAllocator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaAllocator")
            .field("device_ordinal", &self.device.ordinal())
            .field(
                "allocated_bytes",
                &self.allocated_bytes_atomic.load(Ordering::Relaxed),
            )
            .field(
                "peak_bytes",
                &self.peak_bytes_atomic.load(Ordering::Relaxed),
            )
            .field("cached_bytes", &self.cached_bytes())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Stub when `cuda` feature is disabled
// ---------------------------------------------------------------------------

#[cfg(not(feature = "cuda"))]
impl CudaAllocator {
    /// Stub: the `cuda` feature is disabled, so this allocates nothing and
    /// always returns [`GpuError::NoCudaFeature`](crate::error::GpuError::NoCudaFeature).
    pub fn alloc_zeros<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }

    /// Stub: the `cuda` feature is disabled, so this allocates nothing and
    /// always returns [`GpuError::NoCudaFeature`](crate::error::GpuError::NoCudaFeature).
    pub fn alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Unit tests for round_size
    // ------------------------------------------------------------------

    #[test]
    fn round_size_minimum() {
        assert_eq!(round_size(0), MIN_BLOCK_SIZE);
        assert_eq!(round_size(1), MIN_BLOCK_SIZE);
        assert_eq!(round_size(511), MIN_BLOCK_SIZE);
        assert_eq!(round_size(512), MIN_BLOCK_SIZE);
    }

    #[test]
    fn round_size_multiples() {
        assert_eq!(round_size(513), 1024);
        assert_eq!(round_size(1024), 1024);
        assert_eq!(round_size(1025), 1536);
    }

    #[test]
    fn alloc_size_small() {
        // Anything <= 1 MiB uses a 2 MiB segment.
        assert_eq!(get_allocation_size(512), SMALL_BUFFER);
        assert_eq!(get_allocation_size(SMALL_SIZE), SMALL_BUFFER);
    }

    #[test]
    fn alloc_size_mid() {
        // Between 1 MiB and 10 MiB uses a 20 MiB segment.
        assert_eq!(get_allocation_size(SMALL_SIZE + 1), LARGE_BUFFER);
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC - 1), LARGE_BUFFER);
    }

    #[test]
    fn alloc_size_large() {
        // >= 10 MiB rounds up to a 2 MiB boundary.
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC), MIN_LARGE_ALLOC);
        assert_eq!(
            get_allocation_size(MIN_LARGE_ALLOC + 1),
            MIN_LARGE_ALLOC + ROUND_LARGE
        );
    }

    // ------------------------------------------------------------------
    // Unit tests for Block / BlockPool / AllocatorState
    // ------------------------------------------------------------------

    fn make_stream() -> StreamId {
        StreamId(42)
    }

    #[test]
    fn block_pool_insert_find() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        // Create a 4096-byte free block.
        let block = Block::new(0, 4096, 0x1000, stream, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);

        // Should find it for a 512-byte request.
        let found = state.small_pool.find_free_block(stream, 512);
        assert_eq!(found, Some(idx));
    }

    #[test]
    fn block_pool_respects_stream() {
        let mut state = AllocatorState::new();
        let stream_a = StreamId(1);
        let stream_b = StreamId(2);

        let block = Block::new(0, 4096, 0x1000, stream_a, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);

        // Should NOT find it for a different stream.
        assert!(state.small_pool.find_free_block(stream_b, 512).is_none());

        // Should find it for the correct stream.
        assert_eq!(state.small_pool.find_free_block(stream_a, 512), Some(idx));
    }

    #[test]
    fn block_pool_finds_smallest_fit() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        // Add blocks of sizes 4096 and 1024.
        let b1 = Block::new(0, 4096, 0x1000, stream, true);
        let i1 = state.add_block(b1);
        state.small_pool.insert(i1, &state.blocks[i1]);

        let b2 = Block::new(0, 1024, 0x2000, stream, true);
        let i2 = state.add_block(b2);
        state.small_pool.insert(i2, &state.blocks[i2]);

        // Request 768 bytes — should return the 1024 block (smallest fit).
        let found = state.small_pool.find_free_block(stream, 768);
        assert_eq!(found, Some(i2));
    }

    #[test]
    fn split_block_creates_remainder() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 8192, 0x1000, stream, true);
        let idx = state.add_block(block);

        // Split: request 1024 from an 8192 block.
        state.split_block(idx, 1024);

        assert_eq!(state.blocks[idx].size, 1024);
        let rem_idx = state.blocks[idx].next.unwrap();
        assert_eq!(state.blocks[rem_idx].size, 8192 - 1024);
        assert_eq!(state.blocks[rem_idx].ptr, 0x1000 + 1024);
        assert_eq!(state.blocks[rem_idx].prev, Some(idx));

        // Remainder should be in the free pool.
        let found = state.small_pool.find_free_block(stream, 1024);
        assert_eq!(found, Some(rem_idx));
    }

    #[test]
    fn coalesce_merges_adjacent_blocks() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        // Simulate: a segment of 3 blocks: [A=2048] [B=2048] [C=4096]
        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);

        let b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        let b_idx = state.add_block(b);

        let c = Block::new(0, 4096, 0x1000 + 4096, stream, true);
        let c_idx = state.add_block(c);

        // Link them.
        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);
        state.blocks[b_idx].next = Some(c_idx);
        state.blocks[c_idx].prev = Some(b_idx);

        // A and C are free (in pool), B is allocated.
        state.blocks[b_idx].allocated = true;
        state.blocks[b_idx].size = 2048;
        state.allocated_bytes = 2048;

        state.small_pool.insert(a_idx, &state.blocks[a_idx]);
        state.small_pool.insert(c_idx, &state.blocks[c_idx]);

        // Free B — should coalesce with A and C.
        state.free_block(b_idx);

        // B should now be the merged block spanning all 8192 bytes.
        assert_eq!(state.blocks[b_idx].size, 2048 + 2048 + 4096);
        assert_eq!(state.blocks[b_idx].ptr, 0x1000);
        assert!(!state.blocks[b_idx].allocated);
    }

    #[test]
    fn should_split_small_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 2048, 0x1000, stream, true);
        let idx = state.add_block(block);

        // Remainder = 2048 - 1024 = 1024 >= MIN_BLOCK_SIZE (512) => split.
        assert!(state.should_split(idx, 1024));

        // Remainder = 2048 - 1800 = 248 < MIN_BLOCK_SIZE (512) => no split.
        assert!(!state.should_split(idx, 1800));
    }

    #[test]
    fn should_split_large_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 4 * 1024 * 1024, 0x1000, stream, false);
        let idx = state.add_block(block);

        // Remainder = 4 MiB - 2 MiB = 2 MiB > SMALL_SIZE (1 MiB) => split.
        assert!(state.should_split(idx, 2 * 1024 * 1024));

        // Remainder = 4 MiB - 3.5 MiB = 0.5 MiB <= SMALL_SIZE (1 MiB) => no split.
        assert!(!state.should_split(idx, 3 * 1024 * 1024 + 512 * 1024));
    }

    #[test]
    fn stream_uses_prevent_reuse() {
        let stream = make_stream();
        let mut block = Block::new(0, 4096, 0x1000, stream, true);

        assert!(block.stream_uses.is_empty());
        block.stream_uses.insert(StreamId(99));

        // Block has pending stream uses — merge should be prevented.
        assert!(!block.stream_uses.is_empty());
    }

    #[test]
    fn stream_uses_prevent_merge() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        // Two adjacent blocks: [A=2048] [B=2048]
        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);

        let mut b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        b.stream_uses.insert(StreamId(99)); // pending work
        let b_idx = state.add_block(b);

        // Link.
        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);

        // B is in the pool but has pending stream uses.
        state.small_pool.insert(b_idx, &state.blocks[b_idx]);

        // Try to merge A with B — should fail because B has stream_uses.
        let merged = state.try_merge(a_idx, Some(b_idx));
        assert_eq!(merged, 0);
        assert_eq!(state.blocks[a_idx].size, 2048); // unchanged
    }

    #[test]
    fn cache_find_and_insert_roundtrip() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return, // no GPU, skip
        });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);

        // Insert a 4096-byte block.
        let (idx, actual) = alloc.cache_insert(2048, 4096, 0x1000, stream);
        // The block should be split: 2048 for the request, remainder free.
        assert!(actual <= 4096);
        assert_eq!(alloc.cache_stats().1, 1); // 1 miss

        // Free it.
        alloc.cache_free(idx);

        // Find it again — should be a hit.
        let found = alloc.cache_find(512, stream);
        assert!(found.is_some());
        assert_eq!(alloc.cache_stats().0, 1); // 1 hit
    }

    #[test]
    fn empty_cache_clears_pools() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return,
        });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);

        alloc.cache_insert(1024, 4096, 0x1000, stream);
        {
            let state = alloc.state.lock().unwrap();
            // Should have blocks in the arena.
            assert!(!state.blocks.is_empty());
        }

        // Free block 0 back to pool.
        alloc.cache_free(0);
        assert!(alloc.free_block_count() > 0);

        alloc.empty_cache();
        assert_eq!(alloc.free_block_count(), 0);
    }

    // ------------------------------------------------------------------
    // CUDA integration tests
    // ------------------------------------------------------------------

    #[cfg(feature = "cuda")]
    mod cuda_tests {
        use super::*;

        fn make_allocator() -> CudaAllocator {
            let device = GpuDevice::new(0).expect("CUDA device 0");
            CudaAllocator::new(Arc::new(device))
        }

        #[test]
        fn new_allocator_starts_at_zero() {
            let alloc = make_allocator();
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn empty_cache_is_harmless() {
            let alloc = make_allocator();
            alloc.empty_cache();
        }

        #[test]
        fn debug_impl() {
            let alloc = make_allocator();
            let s = format!("{alloc:?}");
            assert!(s.contains("CudaAllocator"));
            assert!(s.contains("allocated_bytes"));
        }

        #[test]
        fn alloc_increases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(256).expect("alloc_zeros");
            assert_eq!(alloc.memory_allocated(), 256 * std::mem::size_of::<f32>());
            assert_eq!(
                alloc.max_memory_allocated(),
                256 * std::mem::size_of::<f32>()
            );
            alloc.free(buf);
        }

        #[test]
        fn free_decreases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(128).expect("alloc_zeros");
            let expected = 128 * std::mem::size_of::<f32>();
            assert_eq!(alloc.memory_allocated(), expected);

            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        #[test]
        fn peak_tracks_maximum() {
            let alloc = make_allocator();

            let buf1 = alloc.alloc_zeros::<f32>(100).expect("alloc 1");
            let buf2 = alloc.alloc_zeros::<f32>(200).expect("alloc 2");
            let peak_after_two = alloc.max_memory_allocated();

            alloc.free(buf1);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
            assert!(alloc.memory_allocated() < peak_after_two);

            alloc.free(buf2);
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
        }

        #[test]
        fn reset_peak_stats_lowers_peak() {
            let alloc = make_allocator();

            let buf = alloc.alloc_zeros::<f32>(512).expect("alloc");
            let high = alloc.max_memory_allocated();
            alloc.free(buf);

            assert_eq!(alloc.max_memory_allocated(), high);

            alloc.reset_peak_stats();
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn alloc_copy_tracks_bytes() {
            let alloc = make_allocator();
            let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
            let buf = alloc.alloc_copy(&data).expect("alloc_copy");
            assert_eq!(alloc.memory_allocated(), 4 * std::mem::size_of::<f64>());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        #[test]
        fn zero_element_alloc() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(0).expect("alloc_zeros empty");
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(buf.len(), 0);
            assert!(buf.is_empty());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }
    }
}