oxicuda_memory/buffer_view.rs
1//! Type-safe buffer reinterpretation for device memory.
2//!
3//! This module provides [`BufferView`] and [`BufferViewMut`], which allow
4//! reinterpreting a [`DeviceBuffer<T>`] as a different element type `U`
5//! without copying data. This is useful for viewing a buffer of `f32`
6//! values as `u32` (e.g., for bitwise operations in a kernel), or for
7//! interpreting raw byte buffers as structured types.
8//!
9//! # Size constraints
10//!
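//! The total byte size of the original buffer must be evenly divisible
//! by `std::mem::size_of::<U>()`, and `U` must not be a zero-sized type.
//! View creation returns [`CudaError::InvalidValue`] when `U` is zero-sized
//! or when the byte size is not a multiple of `size_of::<U>()`.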
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! # use oxicuda_memory::DeviceBuffer;
19//! # use oxicuda_memory::buffer_view::BufferView;
20//! let buf = DeviceBuffer::<f32>::alloc(256)?;
21//! // Reinterpret as u32 (same size, different type)
22//! let view: BufferView<'_, u32> = buf.view_as::<u32>()?;
23//! assert_eq!(view.len(), 256);
24//! # Ok::<(), oxicuda_driver::error::CudaError>(())
25//! ```
26
27use std::marker::PhantomData;
28
29use oxicuda_driver::error::{CudaError, CudaResult};
30use oxicuda_driver::ffi::CUdeviceptr;
31
32use crate::device_buffer::DeviceBuffer;
33
34// ---------------------------------------------------------------------------
35// BufferView<'a, U>
36// ---------------------------------------------------------------------------
37
38/// An immutable, type-reinterpreted view into a [`DeviceBuffer`].
39///
40/// This struct borrows the underlying device allocation and exposes it
41/// as a different element type `U`. No data is copied; only the
42/// pointer arithmetic changes.
43///
44/// The view is lifetime-bound to the original buffer, preventing use
45/// after the buffer is freed.
46pub struct BufferView<'a, U: Copy> {
47 /// Device pointer to the start of the buffer.
48 ptr: CUdeviceptr,
49 /// Number of `U` elements in the reinterpreted view.
50 len: usize,
51 /// Ties the lifetime to the parent buffer.
52 _phantom: PhantomData<&'a U>,
53}
54
55impl<U: Copy> BufferView<'_, U> {
56 /// Returns the number of `U` elements in this view.
57 #[inline]
58 pub fn len(&self) -> usize {
59 self.len
60 }
61
62 /// Returns `true` if the view contains zero elements.
63 #[inline]
64 pub fn is_empty(&self) -> bool {
65 self.len == 0
66 }
67
68 /// Returns the total byte size of this view.
69 #[inline]
70 pub fn byte_size(&self) -> usize {
71 self.len * std::mem::size_of::<U>()
72 }
73
74 /// Returns the raw [`CUdeviceptr`] for this view.
75 #[inline]
76 pub fn as_device_ptr(&self) -> CUdeviceptr {
77 self.ptr
78 }
79}
80
81impl<U: Copy> std::fmt::Debug for BufferView<'_, U> {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("BufferView")
84 .field("ptr", &self.ptr)
85 .field("len", &self.len)
86 .field("elem_size", &std::mem::size_of::<U>())
87 .finish()
88 }
89}
90
91// ---------------------------------------------------------------------------
92// BufferViewMut<'a, U>
93// ---------------------------------------------------------------------------
94
95/// A mutable, type-reinterpreted view into a [`DeviceBuffer`].
96///
97/// Like [`BufferView`] but allows mutable operations (e.g., passing
98/// to a kernel that writes through this reinterpreted pointer).
99pub struct BufferViewMut<'a, U: Copy> {
100 /// Device pointer to the start of the buffer.
101 ptr: CUdeviceptr,
102 /// Number of `U` elements in the reinterpreted view.
103 len: usize,
104 /// Ties the lifetime to the parent buffer (mutable borrow).
105 _phantom: PhantomData<&'a mut U>,
106}
107
108impl<U: Copy> BufferViewMut<'_, U> {
109 /// Returns the number of `U` elements in this view.
110 #[inline]
111 pub fn len(&self) -> usize {
112 self.len
113 }
114
115 /// Returns `true` if the view contains zero elements.
116 #[inline]
117 pub fn is_empty(&self) -> bool {
118 self.len == 0
119 }
120
121 /// Returns the total byte size of this view.
122 #[inline]
123 pub fn byte_size(&self) -> usize {
124 self.len * std::mem::size_of::<U>()
125 }
126
127 /// Returns the raw [`CUdeviceptr`] for this view.
128 #[inline]
129 pub fn as_device_ptr(&self) -> CUdeviceptr {
130 self.ptr
131 }
132}
133
134impl<U: Copy> std::fmt::Debug for BufferViewMut<'_, U> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("BufferViewMut")
137 .field("ptr", &self.ptr)
138 .field("len", &self.len)
139 .field("elem_size", &std::mem::size_of::<U>())
140 .finish()
141 }
142}
143
144// ---------------------------------------------------------------------------
145// DeviceBuffer extensions
146// ---------------------------------------------------------------------------
147
148impl<T: Copy> DeviceBuffer<T> {
149 /// Reinterprets this buffer as a different element type `U` (immutable).
150 ///
151 /// The total byte size of the buffer must be evenly divisible by
152 /// `size_of::<U>()`. The resulting view has `byte_size / size_of::<U>()`
153 /// elements.
154 ///
155 /// # Errors
156 ///
157 /// Returns [`CudaError::InvalidValue`] if:
158 /// - `size_of::<U>()` is zero (ZST).
159 /// - The buffer's byte size is not divisible by `size_of::<U>()`.
160 pub fn view_as<U: Copy>(&self) -> CudaResult<BufferView<'_, U>> {
161 let u_size = std::mem::size_of::<U>();
162 if u_size == 0 {
163 return Err(CudaError::InvalidValue);
164 }
165 let byte_size = self.byte_size();
166 if byte_size % u_size != 0 {
167 return Err(CudaError::InvalidValue);
168 }
169 Ok(BufferView {
170 ptr: self.as_device_ptr(),
171 len: byte_size / u_size,
172 _phantom: PhantomData,
173 })
174 }
175
176 /// Reinterprets this buffer as a different element type `U` (mutable).
177 ///
178 /// The total byte size of the buffer must be evenly divisible by
179 /// `size_of::<U>()`. The resulting view has `byte_size / size_of::<U>()`
180 /// elements.
181 ///
182 /// # Errors
183 ///
184 /// Returns [`CudaError::InvalidValue`] if:
185 /// - `size_of::<U>()` is zero (ZST).
186 /// - The buffer's byte size is not divisible by `size_of::<U>()`.
187 pub fn view_as_mut<U: Copy>(&mut self) -> CudaResult<BufferViewMut<'_, U>> {
188 let u_size = std::mem::size_of::<U>();
189 if u_size == 0 {
190 return Err(CudaError::InvalidValue);
191 }
192 let byte_size = self.byte_size();
193 if byte_size % u_size != 0 {
194 return Err(CudaError::InvalidValue);
195 }
196 Ok(BufferViewMut {
197 ptr: self.as_device_ptr(),
198 len: byte_size / u_size,
199 _phantom: PhantomData,
200 })
201 }
202}
203
204// ---------------------------------------------------------------------------
205// Tests
206// ---------------------------------------------------------------------------
207
208#[cfg(test)]
209mod tests {
210 use super::*;
211
212 #[test]
213 fn buffer_view_debug() {
214 let view: BufferView<'_, u32> = BufferView {
215 ptr: 0x1000,
216 len: 64,
217 _phantom: PhantomData,
218 };
219 let dbg = format!("{view:?}");
220 assert!(dbg.contains("BufferView"));
221 assert!(dbg.contains("64"));
222 }
223
224 #[test]
225 fn buffer_view_mut_debug() {
226 let view: BufferViewMut<'_, f32> = BufferViewMut {
227 ptr: 0x2000,
228 len: 128,
229 _phantom: PhantomData,
230 };
231 let dbg = format!("{view:?}");
232 assert!(dbg.contains("BufferViewMut"));
233 assert!(dbg.contains("128"));
234 }
235
236 #[test]
237 fn buffer_view_len_and_byte_size() {
238 let view: BufferView<'_, u64> = BufferView {
239 ptr: 0x3000,
240 len: 32,
241 _phantom: PhantomData,
242 };
243 assert_eq!(view.len(), 32);
244 assert_eq!(view.byte_size(), 32 * 8);
245 assert!(!view.is_empty());
246 assert_eq!(view.as_device_ptr(), 0x3000);
247 }
248
249 #[test]
250 fn buffer_view_mut_len_and_byte_size() {
251 let view: BufferViewMut<'_, u16> = BufferViewMut {
252 ptr: 0x4000,
253 len: 100,
254 _phantom: PhantomData,
255 };
256 assert_eq!(view.len(), 100);
257 assert_eq!(view.byte_size(), 200);
258 assert!(!view.is_empty());
259 assert_eq!(view.as_device_ptr(), 0x4000);
260 }
261
262 #[test]
263 fn buffer_view_empty() {
264 let view: BufferView<'_, f64> = BufferView {
265 ptr: 0,
266 len: 0,
267 _phantom: PhantomData,
268 };
269 assert!(view.is_empty());
270 assert_eq!(view.byte_size(), 0);
271 }
272
273 #[test]
274 fn view_as_signature_compiles() {
275 let _: fn(&DeviceBuffer<f32>) -> CudaResult<BufferView<'_, u32>> = DeviceBuffer::view_as;
276 }
277
278 #[test]
279 fn view_as_mut_signature_compiles() {
280 let _: fn(&mut DeviceBuffer<f32>) -> CudaResult<BufferViewMut<'_, u32>> =
281 DeviceBuffer::view_as_mut;
282 }
283}