ferrotorch_gpu/buffer.rs
//! GPU memory buffer with pool-aware Drop.
//!
//! [`CudaBuffer`] owns a region of device memory via `cudarc::driver::CudaSlice`
//! and tracks its length and originating device ordinal. When dropped, pooled
//! buffers are returned to the global GPU memory pool for reuse instead of
//! being freed back to the CUDA driver.
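//!
//! A minimal usage sketch (hedged: the allocator that produces the `CudaSlice`
//! and the exact rounding of `alloc_len` live elsewhere in this crate and are
//! assumed here):
//!
//! ```ignore
//! // `slice` is a CudaSlice<f32> obtained from the crate's pool-aware allocator,
//! // with an allocation rounded up to `alloc_len` elements.
//! let buf = CudaBuffer::new_pooled(slice, 1_000, alloc_len, /* device */ 0);
//! assert_eq!(buf.len(), 1_000);
//! assert_eq!(buf.device_ordinal(), 0);
//! // Dropping the buffer hands the slice back to the global pool instead of
//! // freeing the device memory via the CUDA driver.
//! drop(buf);
//! ```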

#[cfg(feature = "cuda")]
use cudarc::driver::CudaSlice;

/// Function pointer that returns a `CudaSlice<T>` to the pool.
/// Stored as an `Option`; `None` means "don't pool, just drop normally."
#[cfg(feature = "cuda")]
type PoolReturnFn<T> = Option<fn(usize, usize, CudaSlice<T>)>;

/// Return a `CudaSlice<f32>` to the global pool.
#[cfg(feature = "cuda")]
fn return_f32(device: usize, len: usize, slice: CudaSlice<f32>) {
    crate::pool::pool_return::<CudaSlice<f32>>(device, len, 4, slice);
}

/// Return a `CudaSlice<f64>` to the global pool.
#[cfg(feature = "cuda")]
fn return_f64(device: usize, len: usize, slice: CudaSlice<f64>) {
    crate::pool::pool_return::<CudaSlice<f64>>(device, len, 8, slice);
}

/// Owned GPU memory buffer holding `len` elements of type `T`.
///
/// When `pool_fn` is `Some`, dropping returns the inner `CudaSlice` to the
/// global pool ([`crate::pool`]) instead of freeing GPU memory.
///
/// `alloc_len` is the rounded allocation size used as the pool key.
/// `len` is the logical element count visible to callers.
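///
/// A sketch of the non-pooled path (hedged: constructing the buffer with a
/// struct literal here is purely illustrative; only `new_pooled` is defined
/// in this file):
///
/// ```ignore
/// // With pool_fn == None, Drop does nothing special and the memory is
/// // released by CudaSlice's own Drop.
/// let unpooled = CudaBuffer {
///     data: Some(slice),
///     len,
///     alloc_len: len,
///     device_ordinal: 0,
///     pool_fn: None,
/// };
/// drop(unpooled);
/// ```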
#[cfg(feature = "cuda")]
pub struct CudaBuffer<T> {
    /// The underlying CUDA device memory. Wrapped in `Option` so
    /// `Drop` can `take()` it without double-free.
    pub(crate) data: Option<CudaSlice<T>>,
    pub(crate) len: usize,
    /// Rounded allocation length, used as the pool key so that
    /// buffers are always findable on pool lookup.
    pub(crate) alloc_len: usize,
    pub(crate) device_ordinal: usize,
    /// If `Some`, this function is called in `Drop` to return the slice
    /// to the pool. If `None`, `CudaSlice`'s own `Drop` frees normally.
    pub(crate) pool_fn: PoolReturnFn<T>,
}

/// Pooled constructor for `f32` buffers.
#[cfg(feature = "cuda")]
impl CudaBuffer<f32> {
    /// Create a pooled f32 buffer that returns to the global pool on drop.
    ///
    /// `alloc_len` is the rounded allocation size used as the pool key.
    /// `len` is the logical element count visible to callers.
    pub(crate) fn new_pooled(
        slice: CudaSlice<f32>,
        len: usize,
        alloc_len: usize,
        device: usize,
    ) -> Self {
        Self {
            data: Some(slice),
            len,
            alloc_len,
            device_ordinal: device,
            pool_fn: Some(return_f32),
        }
    }
}

#[cfg(feature = "cuda")]
impl CudaBuffer<f64> {
    /// Create a pooled f64 buffer that returns to the global pool on drop.
    ///
    /// `alloc_len` is the rounded allocation size used as the pool key.
    /// `len` is the logical element count visible to callers.
    pub(crate) fn new_pooled(
        slice: CudaSlice<f64>,
        len: usize,
        alloc_len: usize,
        device: usize,
    ) -> Self {
        Self {
            data: Some(slice),
            len,
            alloc_len,
            device_ordinal: device,
            pool_fn: Some(return_f64),
        }
    }
}

#[cfg(feature = "cuda")]
impl<T> Drop for CudaBuffer<T> {
    fn drop(&mut self) {
        if let Some(slice) = self.data.take() {
            if let Some(return_fn) = self.pool_fn {
                // Use alloc_len (rounded) as the pool key so the buffer is
                // findable on the next pool_take with the same rounded len.
                return_fn(self.device_ordinal, self.alloc_len, slice);
            }
            // else: CudaSlice's own Drop fires naturally (cuMemFreeAsync)
        }
    }
}

#[cfg(feature = "cuda")]
impl<T> CudaBuffer<T> {
    /// Number of logical elements in this buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Rounded allocation length used as the pool key.
    ///
    /// For pooled buffers, this is `round_len(len)`. For non-pooled
    /// buffers, this equals `len`. Stats (hits, misses, returns) use
    /// `len` consistently within the allocator for user-facing reporting;
    /// `alloc_len` is an internal detail for pool-key stability.
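    ///
    /// A hedged sketch of the relationship (the rounding policy lives in the
    /// allocator and is assumed here, so the numbers are illustrative):
    ///
    /// ```ignore
    /// // A request for 1_000 elements may be rounded up to a larger bucket.
    /// assert_eq!(buf.len(), 1_000);          // logical element count
    /// assert!(buf.alloc_len() >= buf.len()); // rounded pool key
    /// ```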
    #[inline]
    pub fn alloc_len(&self) -> usize {
        self.alloc_len
    }

    /// Whether the buffer is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The ordinal of the device that owns this memory.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Borrow the underlying `CudaSlice` for use with cudarc APIs.
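    ///
    /// For example, copying the contents back to the host (a sketch; assumes a
    /// cudarc `CudaDevice` handle named `device` is in scope, and the exact
    /// method name may differ between cudarc versions):
    ///
    /// ```ignore
    /// let host: Vec<f32> = device.dtoh_sync_copy(buf.inner())?;
    /// ```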
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        self.data
            .as_ref()
            .expect("CudaBuffer: inner slice already taken")
    }

    /// Mutably borrow the underlying `CudaSlice`.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
        self.data
            .as_mut()
            .expect("CudaBuffer: inner slice already taken")
    }
}

#[cfg(feature = "cuda")]
impl<T> std::fmt::Debug for CudaBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaBuffer")
            .field("len", &self.len)
            .field("device_ordinal", &self.device_ordinal)
            .field("pooled", &self.pool_fn.is_some())
            .finish_non_exhaustive()
    }
}

// ---------------------------------------------------------------------------
// Stub when `cuda` feature is disabled
// ---------------------------------------------------------------------------

/// Stub `CudaBuffer` when the `cuda` feature is not enabled.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaBuffer<T> {
    pub(crate) _phantom: std::marker::PhantomData<T>,
    pub(crate) len: usize,
    pub(crate) device_ordinal: usize,
}

#[cfg(not(feature = "cuda"))]
impl<T> CudaBuffer<T> {
    /// Number of elements in this buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the buffer is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// The ordinal of the device that owns this memory.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }
}