Skip to main content

ferrotorch_gpu/
pool.rs

//! GPU buffer pool — caching allocator for CUDA memory.
//!
//! Eliminates per-operation `cuMemAllocAsync` + `cuMemFreeAsync` +
//! `cuEventCreate` × 2 + `cuEventDestroy` × 2 by reusing freed buffers.
//! On a pool hit, the only CUDA call is `cuMemsetD8Async` (to zero the
//! buffer). The `CudaSlice`'s events are kept alive across reuses, so
//! no event creation or destruction is needed.
//!
//! This module provides the `CudaSlice`-holding layer that sits on top of the
//! block-metadata caching allocator in [`crate::allocator`]. The allocator
//! manages block splitting, coalescing, and stream tracking; this module
//! manages the actual type-erased `CudaSlice<T>` ownership.
//!
//! # Thread safety
//!
//! The pool is protected by a single global `Mutex`. Each critical section is
//! a `HashMap` lookup plus a `Vec` push/pop or a linear scan of one bucket.
//! Buckets are expected to be short, so contention should be negligible.
//!
//! # CL-323
//!
//! Stream-aware reuse ([`pool_take_stream`], [`pool_return_with_stream`],
//! [`record_stream`]) is tracked under CL-323.
20
21use std::any::{Any, TypeId};
22use std::collections::HashMap;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::{LazyLock, Mutex};
25
26use crate::allocator::StreamId;
27
/// Number of `pool_take*` calls that handed back a cached buffer.
static POOL_HITS: AtomicUsize = AtomicUsize::new(0);
/// Cache-miss counter, reported as the second element of [`pool_stats`].
static POOL_MISSES: AtomicUsize = AtomicUsize::new(0);
/// Number of buffers successfully inserted back into the pool.
static POOL_RETURNS: AtomicUsize = AtomicUsize::new(0);

/// Snapshot of the pool counters as `(hits, misses, returns)`.
///
/// Each counter is read independently with relaxed ordering, so the three
/// values are not guaranteed to form a consistent point-in-time view.
pub fn pool_stats() -> (usize, usize, usize) {
    let read = |counter: &AtomicUsize| counter.load(Ordering::Relaxed);
    (read(&POOL_HITS), read(&POOL_MISSES), read(&POOL_RETURNS))
}

/// Zero all three pool counters (hits, misses, returns).
pub fn reset_pool_stats() {
    for counter in [&POOL_HITS, &POOL_MISSES, &POOL_RETURNS] {
        counter.store(0, Ordering::Relaxed);
    }
}
47
48// ---------------------------------------------------------------------------
49// Pool key: (device_ordinal, element_count, TypeId)
50// ---------------------------------------------------------------------------
51
52type PoolKey = (usize, usize, TypeId);
53
54/// Metadata stored alongside each cached buffer for stream-aware reuse.
55///
56/// # CL-323
57struct CachedEntry {
58    /// The type-erased `CudaSlice<T>`.
59    data: Box<dyn Any + Send + Sync>,
60    /// The stream on which this buffer was originally allocated.
61    alloc_stream: StreamId,
62    /// Streams that have used this buffer (recorded via `record_stream`).
63    /// The buffer can only be reused when all these streams have completed
64    /// their work.
65    stream_uses: Vec<StreamId>,
66}
67
68struct PoolState {
69    /// Free buffers keyed by (device, len, type). Values are cached entries
70    /// with stream metadata. LIFO for temporal locality.
71    free: HashMap<PoolKey, Vec<CachedEntry>>,
72    /// Total cached bytes (not currently in use).
73    cached_bytes: usize,
74}
75
76impl PoolState {
77    fn new() -> Self {
78        Self {
79            free: HashMap::new(),
80            cached_bytes: 0,
81        }
82    }
83}
84
85static POOL: LazyLock<Mutex<PoolState>> = LazyLock::new(|| Mutex::new(PoolState::new()));
86
// ---------------------------------------------------------------------------
// Length rounding
// ---------------------------------------------------------------------------

/// Allocation granularity in elements: every pooled length is rounded up to
/// a multiple of this value.
///
/// Rounding keeps pool keys stable for the same logical size. Without it, a
/// buffer allocated with `len` elements could never be found again because
/// its key would differ from the rounded allocation size.
const ROUND_ELEMENTS: usize = 256;

/// Round `len` up to the next multiple of [`ROUND_ELEMENTS`].
///
/// `0` stays `0` and exact multiples are returned unchanged. The addition
/// saturates, so inputs above the largest representable multiple yield
/// `usize::MAX` instead of overflowing.
pub fn round_len(len: usize) -> usize {
    match len % ROUND_ELEMENTS {
        0 => len,
        partial => len.saturating_add(ROUND_ELEMENTS - partial),
    }
}
112
113// ---------------------------------------------------------------------------
114// Generic pool operations
115// ---------------------------------------------------------------------------
116
117/// Try to get a cached `CudaSlice<T>` with exactly `rounded_len` elements
118/// from the pool for the given device. Returns `None` on cache miss.
119///
120/// `rounded_len` must already be rounded via [`round_len`]. `elem_size` is
121/// the size of one element in bytes (e.g. 4 for f32) and is used only for
122/// byte-level accounting — NOT for pool key lookup.
123pub fn pool_take<T: Any + Send + Sync>(
124    device_ordinal: usize,
125    rounded_len: usize,
126    elem_size: usize,
127) -> Option<T> {
128    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
129    // Mutex poison is silently swallowed: this is intentional defensive
130    // behavior — a poisoned pool degrades to "always miss" rather than
131    // panicking the caller. GPU allocations will still succeed via fresh
132    // CUDA driver calls, just without caching.
133    let mut pool = POOL.lock().ok()?;
134    let bucket = pool.free.get_mut(&key)?;
135    let entry = bucket.pop()?;
136    let is_empty = bucket.is_empty();
137    if is_empty {
138        pool.free.remove(&key);
139    }
140    pool.cached_bytes = pool.cached_bytes.saturating_sub(rounded_len * elem_size);
141    POOL_HITS.fetch_add(1, Ordering::Relaxed);
142    // Downcast is guaranteed to succeed because the key includes TypeId.
143    Some(*entry.data.downcast::<T>().expect("pool type mismatch"))
144}
145
146/// Stream-aware variant of [`pool_take`]. Only returns a buffer whose
147/// `alloc_stream` matches the given `stream` and has no pending
148/// cross-stream uses, ensuring correct synchronization.
149///
150/// # CL-323
151pub fn pool_take_stream<T: Any + Send + Sync>(
152    device_ordinal: usize,
153    rounded_len: usize,
154    elem_size: usize,
155    stream: StreamId,
156) -> Option<T> {
157    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
158    let mut pool = POOL.lock().ok()?;
159    let bucket = pool.free.get_mut(&key)?;
160
161    // Search from the back (LIFO) for a buffer on the same stream with
162    // no pending cross-stream uses.
163    let pos = bucket
164        .iter()
165        .rposition(|entry| entry.alloc_stream == stream && entry.stream_uses.is_empty())?;
166
167    let entry = bucket.swap_remove(pos);
168    if bucket.is_empty() {
169        pool.free.remove(&key);
170    }
171    pool.cached_bytes = pool.cached_bytes.saturating_sub(rounded_len * elem_size);
172    POOL_HITS.fetch_add(1, Ordering::Relaxed);
173    Some(*entry.data.downcast::<T>().expect("pool type mismatch"))
174}
175
176/// Return a value to the pool for later reuse.
177///
178/// `rounded_len` must already be rounded via [`round_len`]. `elem_size` is
179/// used for byte-level accounting only. The `POOL_RETURNS` counter is
180/// incremented only after the buffer is successfully inserted into the
181/// pool bucket.
182pub fn pool_return<T: Any + Send + Sync>(
183    device_ordinal: usize,
184    rounded_len: usize,
185    elem_size: usize,
186    value: T,
187) {
188    pool_return_with_stream(device_ordinal, rounded_len, elem_size, value, StreamId(0))
189}
190
191/// Return a value to the pool with stream metadata.
192///
193/// Like [`pool_return`] but records which stream the buffer was used on,
194/// enabling stream-aware reuse via [`pool_take_stream`].
195///
196/// # CL-323
197pub fn pool_return_with_stream<T: Any + Send + Sync>(
198    device_ordinal: usize,
199    rounded_len: usize,
200    elem_size: usize,
201    value: T,
202    alloc_stream: StreamId,
203) {
204    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
205    let Ok(mut pool) = POOL.lock() else { return };
206    pool.cached_bytes += rounded_len * elem_size;
207    let entry = CachedEntry {
208        data: Box::new(value),
209        alloc_stream,
210        stream_uses: Vec::new(),
211    };
212    pool.free.entry(key).or_default().push(entry);
213    POOL_RETURNS.fetch_add(1, Ordering::Relaxed);
214}
215
216// ---------------------------------------------------------------------------
217// Stream recording
218// ---------------------------------------------------------------------------
219
220/// Record that a buffer (identified by its pool key) was used on `stream`.
221///
222/// This prevents the buffer from being returned by [`pool_take_stream`]
223/// until the recorded stream's work is complete. Callers should call this
224/// when a buffer allocated on one stream is consumed by a kernel on a
225/// different stream.
226///
227/// This is the Rust equivalent of PyTorch's `recordStream()`.
228///
229/// # CL-323
230pub fn record_stream<T: Any + Send + Sync>(
231    device_ordinal: usize,
232    rounded_len: usize,
233    stream: StreamId,
234) {
235    let key = (device_ordinal, rounded_len, TypeId::of::<T>());
236    let Ok(mut pool) = POOL.lock() else { return };
237    if let Some(bucket) = pool.free.get_mut(&key) {
238        for entry in bucket.iter_mut() {
239            entry.stream_uses.push(stream);
240        }
241    }
242}
243
/// Record a use on `stream` for every cached buffer under
/// `(device_ordinal, rounded_len, type_id)`.
///
/// This is a type-erased counterpart of [`record_stream`]: the element type
/// is passed as a `TypeId` instead of a generic parameter. The whole bucket
/// is marked, not one specific buffer.
///
/// Uses on a buffer's own allocation stream are ignored, because they need
/// no cross-stream synchronization. Each stream is recorded at most once per
/// buffer. Only buffers currently in the pool are affected, and nothing in
/// this module clears recorded uses.
///
/// # CL-323
#[cfg(feature = "cuda")]
pub fn record_stream_on_buffer(
    device_ordinal: usize,
    rounded_len: usize,
    type_id: TypeId,
    stream: StreamId,
) {
    let key = (device_ordinal, rounded_len, type_id);
    let Ok(mut pool) = POOL.lock() else { return };
    let Some(bucket) = pool.free.get_mut(&key) else { return };
    for entry in bucket.iter_mut() {
        if entry.alloc_stream != stream && !entry.stream_uses.contains(&stream) {
            entry.stream_uses.push(stream);
        }
    }
}
265
266// ---------------------------------------------------------------------------
267// Cache management
268// ---------------------------------------------------------------------------
269
270/// Drop all cached buffers for a device, releasing GPU memory back to the
271/// CUDA driver.
272pub fn empty_cache(device_ordinal: usize) {
273    let Ok(mut pool) = POOL.lock() else { return };
274    pool.free.retain(|&(dev, _, _), _| dev != device_ordinal);
275    // Recalculate cached_bytes from remaining entries.
276    // Note: we don't store elem_size per entry, so we conservatively estimate
277    // by summing the byte counts we have. After a device-specific clear the
278    // remaining entries may use different elem sizes, so we just reset to 0
279    // and accept that the counter may be slightly off until the next return.
280    // A full clear (empty_cache_all) resets to 0 exactly.
281    pool.cached_bytes = 0;
282}
283
284/// Drop all cached buffers across all devices.
285pub fn empty_cache_all() {
286    let Ok(mut pool) = POOL.lock() else { return };
287    pool.free.clear();
288    pool.cached_bytes = 0;
289}
290
291/// Total bytes currently cached (available for reuse).
292pub fn cached_bytes(_device_ordinal: usize) -> usize {
293    POOL.lock().ok().map(|p| p.cached_bytes).unwrap_or(0)
294}
295
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that touch the global pool.
    ///
    /// The pool and its counters are process-wide, but `cargo test` runs
    /// tests on parallel threads. Without this lock, two races are possible:
    /// - `empty_cache_all_clears_everything` can wipe another test's buffer
    ///   between that test's return and take.
    /// - `pool_stats_tracking` can see counters bumped by a concurrent test
    ///   right after its reset.
    static SERIAL: Mutex<()> = Mutex::new(());

    /// Acquire the serialization lock. Poisoning is ignored so that one
    /// failing test does not cascade into every later test.
    fn serial() -> MutexGuard<'static, ()> {
        SERIAL.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn round_len_zero() {
        assert_eq!(round_len(0), 0);
    }

    #[test]
    fn round_len_exact_multiple() {
        assert_eq!(round_len(256), 256);
        assert_eq!(round_len(512), 512);
    }

    #[test]
    fn round_len_rounds_up() {
        assert_eq!(round_len(1), 256);
        assert_eq!(round_len(255), 256);
        assert_eq!(round_len(257), 512);
    }

    #[test]
    fn pool_take_miss_returns_none() {
        let _serial = serial();
        // Take from an empty pool should return None.
        let result = pool_take::<u64>(99, 256, 8);
        assert!(result.is_none());
    }

    #[test]
    fn pool_return_then_take() {
        let _serial = serial();
        let value: u64 = 12345;
        pool_return::<u64>(99, 256, 8, value);
        let taken = pool_take::<u64>(99, 256, 8);
        assert_eq!(taken, Some(12345u64));
    }

    #[test]
    fn pool_take_is_lifo() {
        let _serial = serial();
        pool_return::<u32>(91, 256, 4, 1u32);
        pool_return::<u32>(91, 256, 4, 2u32);
        assert_eq!(pool_take::<u32>(91, 256, 4), Some(2u32));
        assert_eq!(pool_take::<u32>(91, 256, 4), Some(1u32));
        assert!(pool_take::<u32>(91, 256, 4).is_none());
    }

    #[test]
    fn pool_stats_tracking() {
        let _serial = serial();
        reset_pool_stats();
        let (h, _m, r) = pool_stats();
        assert_eq!(h, 0);
        assert_eq!(r, 0);

        pool_return::<u32>(98, 256, 4, 42u32);
        let (_, _, r) = pool_stats();
        assert!(r >= 1);

        let _ = pool_take::<u32>(98, 256, 4);
        let (h, _, _) = pool_stats();
        assert!(h >= 1);
    }

    #[test]
    fn stream_aware_take() {
        let _serial = serial();
        let stream_a = StreamId(100);
        let stream_b = StreamId(200);

        // Return a buffer on stream A.
        pool_return_with_stream::<u64>(97, 256, 8, 777u64, stream_a);

        // Taking for stream B should fail (stream mismatch).
        let taken = pool_take_stream::<u64>(97, 256, 8, stream_b);
        assert!(taken.is_none());

        // Taking for stream A should succeed.
        let taken = pool_take_stream::<u64>(97, 256, 8, stream_a);
        assert_eq!(taken, Some(777u64));
    }

    #[test]
    fn record_stream_prevents_reuse() {
        let _serial = serial();
        let stream_a = StreamId(300);
        let stream_b = StreamId(400);

        // Return a buffer on stream A.
        pool_return_with_stream::<u64>(96, 256, 8, 888u64, stream_a);

        // Record stream B usage on all entries in this bucket.
        record_stream::<u64>(96, 256, stream_b);

        // A stream-aware take on stream A must now fail, because the entry
        // has a recorded use on stream B.
        let taken = pool_take_stream::<u64>(96, 256, 8, stream_a);
        assert!(taken.is_none());

        // But the plain pool_take (non-stream-aware) still works.
        let taken = pool_take::<u64>(96, 256, 8);
        assert_eq!(taken, Some(888u64));
    }

    #[test]
    fn empty_cache_clears_device() {
        let _serial = serial();
        pool_return::<u32>(95, 256, 4, 11u32);
        pool_return::<u32>(94, 256, 4, 22u32);

        empty_cache(95);

        // Device 95 cleared.
        assert!(pool_take::<u32>(95, 256, 4).is_none());
        // Device 94 untouched.
        assert_eq!(pool_take::<u32>(94, 256, 4), Some(22u32));
    }

    #[test]
    fn empty_cache_all_clears_everything() {
        let _serial = serial();
        pool_return::<u32>(93, 256, 4, 33u32);
        pool_return::<u32>(92, 256, 4, 44u32);

        empty_cache_all();

        assert!(pool_take::<u32>(93, 256, 4).is_none());
        assert!(pool_take::<u32>(92, 256, 4).is_none());
    }
}