Skip to main content

ferrotorch_gpu/
buffer.rs

//! GPU memory buffer with pool-aware `Drop`.
//!
//! [`CudaBuffer`] owns a region of device memory via `cudarc::driver::CudaSlice`
//! and tracks its length and originating device ordinal. When dropped, pooled
//! buffers are returned to the global GPU memory pool for reuse instead of
//! being freed back to the CUDA driver.
7
8#[cfg(feature = "cuda")]
9use cudarc::driver::CudaSlice;
10
/// Optional hook that hands a `CudaSlice<T>` back to the global pool.
///
/// When present, it is called with the owning device ordinal, the length used
/// as the pool key, and the slice itself. `None` opts out of pooling: the
/// slice is dropped normally instead.
#[cfg(feature = "cuda")]
type PoolReturnFn<T> = Option<fn(device: usize, alloc_len: usize, slice: CudaSlice<T>)>;
15
/// Pool-return hook for `f32` buffers: gives `buf` back to the global pool
/// under the `(ordinal, key_len)` key.
#[cfg(feature = "cuda")]
fn return_f32(ordinal: usize, key_len: usize, buf: CudaSlice<f32>) {
    // `4` equals `size_of::<f32>()`; presumably the per-element byte width the
    // pool expects — TODO confirm against `pool_return`'s signature.
    crate::pool::pool_return::<CudaSlice<f32>>(ordinal, key_len, 4, buf);
}
21
/// Pool-return hook for `f64` buffers: gives `buf` back to the global pool
/// under the `(ordinal, key_len)` key.
#[cfg(feature = "cuda")]
fn return_f64(ordinal: usize, key_len: usize, buf: CudaSlice<f64>) {
    // `8` equals `size_of::<f64>()`; presumably the per-element byte width the
    // pool expects — TODO confirm against `pool_return`'s signature.
    crate::pool::pool_return::<CudaSlice<f64>>(ordinal, key_len, 8, buf);
}
27
/// Owned GPU memory buffer holding `len` elements of type `T`.
///
/// Two sizes are tracked:
/// - `len`: the logical element count callers see;
/// - `alloc_len`: the rounded allocation size, used as the pool key.
///
/// If `pool_fn` is set, dropping the buffer hands its `CudaSlice` back to the
/// global pool ([`crate::pool`]) rather than releasing the device memory.
#[cfg(feature = "cuda")]
pub struct CudaBuffer<T> {
    /// Backing device slice. Kept in an `Option` so `Drop` can move it out
    /// with `take()` and the memory is never released twice.
    pub(crate) data: Option<CudaSlice<T>>,
    /// Logical element count reported by `len()`.
    pub(crate) len: usize,
    /// Rounded allocation length. It is the pool key, so a returned buffer can
    /// be found again by a later lookup for the same rounded length.
    pub(crate) alloc_len: usize,
    /// Ordinal of the device this memory originates from.
    pub(crate) device_ordinal: usize,
    /// Pool hand-back hook invoked from `Drop`. With `None`, the slice's own
    /// `Drop` releases the memory.
    pub(crate) pool_fn: PoolReturnFn<T>,
}
49
/// Pooled construction for `f32` buffers.
#[cfg(feature = "cuda")]
impl CudaBuffer<f32> {
    /// Wraps `slice` in a buffer that is handed back to the global pool when
    /// dropped.
    ///
    /// * `len`: logical element count reported to callers.
    /// * `alloc_len`: rounded allocation size, used as the pool key.
    /// * `device`: ordinal of the owning device.
    pub(crate) fn new_pooled(
        slice: CudaSlice<f32>,
        len: usize,
        alloc_len: usize,
        device: usize,
    ) -> Self {
        let pool_fn: PoolReturnFn<f32> = Some(return_f32);
        Self {
            device_ordinal: device,
            pool_fn,
            data: Some(slice),
            alloc_len,
            len,
        }
    }
}
72
/// Pooled construction for `f64` buffers.
#[cfg(feature = "cuda")]
impl CudaBuffer<f64> {
    /// Wraps `slice` in a buffer that is handed back to the global pool when
    /// dropped.
    ///
    /// * `len`: logical element count reported to callers.
    /// * `alloc_len`: rounded allocation size, used as the pool key.
    /// * `device`: ordinal of the owning device.
    pub(crate) fn new_pooled(
        slice: CudaSlice<f64>,
        len: usize,
        alloc_len: usize,
        device: usize,
    ) -> Self {
        let pool_fn: PoolReturnFn<f64> = Some(return_f64);
        Self {
            device_ordinal: device,
            pool_fn,
            data: Some(slice),
            alloc_len,
            len,
        }
    }
}
94
/// Releases the backing slice, if one is still held.
///
/// Pooled buffers pass the slice to `pool_fn`, together with the device
/// ordinal and `alloc_len`. Unpooled buffers drop the slice here, so
/// `CudaSlice`'s own `Drop` frees the memory.
#[cfg(feature = "cuda")]
impl<T> Drop for CudaBuffer<T> {
    fn drop(&mut self) {
        match (self.data.take(), self.pool_fn) {
            // Keyed by the rounded length (not `len`) so the next pool lookup
            // for the same rounded size can find this slice.
            (Some(slice), Some(give_back)) => {
                give_back(self.device_ordinal, self.alloc_len, slice);
            }
            // Not pooled: dropping the slice frees it via `CudaSlice`'s Drop.
            (Some(slice), None) => std::mem::drop(slice),
            // Slice already moved out; nothing left to release.
            (None, _) => {}
        }
    }
}
108
/// Read-only accessors and raw-slice borrowing for [`CudaBuffer`].
#[cfg(feature = "cuda")]
impl<T> CudaBuffer<T> {
    /// Logical number of elements the buffer exposes to callers.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Rounded element count that serves as this buffer's pool key.
    ///
    /// Pooled buffers report `round_len(len)` here; non-pooled buffers report
    /// the same value as [`len`](Self::len). User-facing allocator statistics
    /// (hits, misses, returns) are expressed in terms of `len`. `alloc_len` is
    /// an internal detail that keeps pool keys stable.
    #[inline]
    pub fn alloc_len(&self) -> usize {
        self.alloc_len
    }

    /// Returns `true` when the buffer holds no logical elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Ordinal of the device that owns this memory.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Shared borrow of the backing `CudaSlice`, for use with cudarc APIs.
    ///
    /// # Panics
    ///
    /// Panics if the slice has already been moved out of the buffer.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        let slot = &self.data;
        slot.as_ref().expect("CudaBuffer: inner slice already taken")
    }

    /// Exclusive borrow of the backing `CudaSlice`.
    ///
    /// # Panics
    ///
    /// Panics if the slice has already been moved out of the buffer.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
        let slot = &mut self.data;
        slot.as_mut().expect("CudaBuffer: inner slice already taken")
    }
}
156
/// Debug output lists the length, device ordinal and pooling flag. The device
/// slice itself is elided (`..`).
#[cfg(feature = "cuda")]
impl<T> std::fmt::Debug for CudaBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pooled = self.pool_fn.is_some();
        let mut out = f.debug_struct("CudaBuffer");
        out.field("len", &self.len);
        out.field("device_ordinal", &self.device_ordinal);
        out.field("pooled", &pooled);
        out.finish_non_exhaustive()
    }
}
167
// ---------------------------------------------------------------------------
// Stub used when the `cuda` feature is disabled
// ---------------------------------------------------------------------------
171
/// Stub `CudaBuffer` used when the `cuda` feature is not enabled.
///
/// Holds only bookkeeping (`len`, `device_ordinal`). No device memory is
/// owned, so there is nothing to pool or free on drop.
#[cfg(not(feature = "cuda"))]
pub struct CudaBuffer<T> {
    pub(crate) _phantom: std::marker::PhantomData<T>,
    pub(crate) len: usize,
    pub(crate) device_ordinal: usize,
}

/// Hand-written rather than derived, so that `CudaBuffer<T>: Debug` holds for
/// every `T`, exactly as with the CUDA build's impl. `#[derive(Debug)]` would
/// require `T: Debug` and would also print the `_phantom` marker. The field
/// set mirrors the CUDA build's output; the stub never pools.
#[cfg(not(feature = "cuda"))]
impl<T> std::fmt::Debug for CudaBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaBuffer")
            .field("len", &self.len)
            .field("device_ordinal", &self.device_ordinal)
            .field("pooled", &false)
            .finish_non_exhaustive()
    }
}

#[cfg(not(feature = "cuda"))]
impl<T> CudaBuffer<T> {
    /// Number of elements in this buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Allocation length used as the pool key.
    ///
    /// The stub never pools or rounds, so this always equals
    /// [`len`](Self::len). It exists so that code calling `alloc_len()`
    /// builds with or without the `cuda` feature.
    #[inline]
    pub fn alloc_len(&self) -> usize {
        self.len
    }

    /// Whether the buffer is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The ordinal of the device that owns this memory.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }
}