Skip to main content

oxicuda_memory/
buffer_view.rs

//! Type-safe buffer reinterpretation for device memory.
//!
//! This module provides [`BufferView`] and [`BufferViewMut`], which allow
//! reinterpreting a [`DeviceBuffer<T>`] as a different element type `U`
//! without copying data. This is useful for viewing a buffer of `f32`
//! values as `u32` (e.g., for bitwise operations in a kernel), or for
//! interpreting raw byte buffers as structured types.
//!
//! # Size constraints
//!
//! The total byte size of the original buffer must be evenly divisible
//! by `std::mem::size_of::<U>()`, and `U` must not be zero-sized. If either
//! condition fails, view creation returns [`CudaError::InvalidValue`].
//!
//! # Example
//!
//! ```rust,no_run
//! # use oxicuda_memory::DeviceBuffer;
//! # use oxicuda_memory::buffer_view::BufferView;
//! let buf = DeviceBuffer::<f32>::alloc(256)?;
//! // Reinterpret as u32 (same size, different type)
//! let view: BufferView<'_, u32> = buf.view_as::<u32>()?;
//! assert_eq!(view.len(), 256);
//! # Ok::<(), oxicuda_driver::error::CudaError>(())
//! ```
26
27use std::marker::PhantomData;
28
29use oxicuda_driver::error::{CudaError, CudaResult};
30use oxicuda_driver::ffi::CUdeviceptr;
31
32use crate::device_buffer::DeviceBuffer;
33
// ---------------------------------------------------------------------------
// BufferView<'a, U>
// ---------------------------------------------------------------------------
37
38/// An immutable, type-reinterpreted view into a [`DeviceBuffer`].
39///
40/// This struct borrows the underlying device allocation and exposes it
41/// as a different element type `U`. No data is copied; only the
42/// pointer arithmetic changes.
43///
44/// The view is lifetime-bound to the original buffer, preventing use
45/// after the buffer is freed.
46pub struct BufferView<'a, U: Copy> {
47    /// Device pointer to the start of the buffer.
48    ptr: CUdeviceptr,
49    /// Number of `U` elements in the reinterpreted view.
50    len: usize,
51    /// Ties the lifetime to the parent buffer.
52    _phantom: PhantomData<&'a U>,
53}
54
55impl<U: Copy> BufferView<'_, U> {
56    /// Returns the number of `U` elements in this view.
57    #[inline]
58    pub fn len(&self) -> usize {
59        self.len
60    }
61
62    /// Returns `true` if the view contains zero elements.
63    #[inline]
64    pub fn is_empty(&self) -> bool {
65        self.len == 0
66    }
67
68    /// Returns the total byte size of this view.
69    #[inline]
70    pub fn byte_size(&self) -> usize {
71        self.len * std::mem::size_of::<U>()
72    }
73
74    /// Returns the raw [`CUdeviceptr`] for this view.
75    #[inline]
76    pub fn as_device_ptr(&self) -> CUdeviceptr {
77        self.ptr
78    }
79}
80
81impl<U: Copy> std::fmt::Debug for BufferView<'_, U> {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("BufferView")
84            .field("ptr", &self.ptr)
85            .field("len", &self.len)
86            .field("elem_size", &std::mem::size_of::<U>())
87            .finish()
88    }
89}
90
// ---------------------------------------------------------------------------
// BufferViewMut<'a, U>
// ---------------------------------------------------------------------------
94
95/// A mutable, type-reinterpreted view into a [`DeviceBuffer`].
96///
97/// Like [`BufferView`] but allows mutable operations (e.g., passing
98/// to a kernel that writes through this reinterpreted pointer).
99pub struct BufferViewMut<'a, U: Copy> {
100    /// Device pointer to the start of the buffer.
101    ptr: CUdeviceptr,
102    /// Number of `U` elements in the reinterpreted view.
103    len: usize,
104    /// Ties the lifetime to the parent buffer (mutable borrow).
105    _phantom: PhantomData<&'a mut U>,
106}
107
108impl<U: Copy> BufferViewMut<'_, U> {
109    /// Returns the number of `U` elements in this view.
110    #[inline]
111    pub fn len(&self) -> usize {
112        self.len
113    }
114
115    /// Returns `true` if the view contains zero elements.
116    #[inline]
117    pub fn is_empty(&self) -> bool {
118        self.len == 0
119    }
120
121    /// Returns the total byte size of this view.
122    #[inline]
123    pub fn byte_size(&self) -> usize {
124        self.len * std::mem::size_of::<U>()
125    }
126
127    /// Returns the raw [`CUdeviceptr`] for this view.
128    #[inline]
129    pub fn as_device_ptr(&self) -> CUdeviceptr {
130        self.ptr
131    }
132}
133
134impl<U: Copy> std::fmt::Debug for BufferViewMut<'_, U> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("BufferViewMut")
137            .field("ptr", &self.ptr)
138            .field("len", &self.len)
139            .field("elem_size", &std::mem::size_of::<U>())
140            .finish()
141    }
142}
143
// ---------------------------------------------------------------------------
// DeviceBuffer extensions
// ---------------------------------------------------------------------------
147
148impl<T: Copy> DeviceBuffer<T> {
149    /// Reinterprets this buffer as a different element type `U` (immutable).
150    ///
151    /// The total byte size of the buffer must be evenly divisible by
152    /// `size_of::<U>()`. The resulting view has `byte_size / size_of::<U>()`
153    /// elements.
154    ///
155    /// # Errors
156    ///
157    /// Returns [`CudaError::InvalidValue`] if:
158    /// - `size_of::<U>()` is zero (ZST).
159    /// - The buffer's byte size is not divisible by `size_of::<U>()`.
160    pub fn view_as<U: Copy>(&self) -> CudaResult<BufferView<'_, U>> {
161        let u_size = std::mem::size_of::<U>();
162        if u_size == 0 {
163            return Err(CudaError::InvalidValue);
164        }
165        let byte_size = self.byte_size();
166        if byte_size % u_size != 0 {
167            return Err(CudaError::InvalidValue);
168        }
169        Ok(BufferView {
170            ptr: self.as_device_ptr(),
171            len: byte_size / u_size,
172            _phantom: PhantomData,
173        })
174    }
175
176    /// Reinterprets this buffer as a different element type `U` (mutable).
177    ///
178    /// The total byte size of the buffer must be evenly divisible by
179    /// `size_of::<U>()`. The resulting view has `byte_size / size_of::<U>()`
180    /// elements.
181    ///
182    /// # Errors
183    ///
184    /// Returns [`CudaError::InvalidValue`] if:
185    /// - `size_of::<U>()` is zero (ZST).
186    /// - The buffer's byte size is not divisible by `size_of::<U>()`.
187    pub fn view_as_mut<U: Copy>(&mut self) -> CudaResult<BufferViewMut<'_, U>> {
188        let u_size = std::mem::size_of::<U>();
189        if u_size == 0 {
190            return Err(CudaError::InvalidValue);
191        }
192        let byte_size = self.byte_size();
193        if byte_size % u_size != 0 {
194            return Err(CudaError::InvalidValue);
195        }
196        Ok(BufferViewMut {
197            ptr: self.as_device_ptr(),
198            len: byte_size / u_size,
199            _phantom: PhantomData,
200        })
201    }
202}
203
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an immutable view over a fake device address (never dereferenced).
    fn fake_view<U: Copy>(ptr: CUdeviceptr, len: usize) -> BufferView<'static, U> {
        BufferView {
            ptr,
            len,
            _phantom: PhantomData,
        }
    }

    /// Builds a mutable view over a fake device address (never dereferenced).
    fn fake_view_mut<U: Copy>(ptr: CUdeviceptr, len: usize) -> BufferViewMut<'static, U> {
        BufferViewMut {
            ptr,
            len,
            _phantom: PhantomData,
        }
    }

    #[test]
    fn buffer_view_debug() {
        let rendered = format!("{:?}", fake_view::<u32>(0x1000, 64));
        assert!(rendered.contains("BufferView"));
        assert!(rendered.contains("64"));
    }

    #[test]
    fn buffer_view_mut_debug() {
        let rendered = format!("{:?}", fake_view_mut::<f32>(0x2000, 128));
        assert!(rendered.contains("BufferViewMut"));
        assert!(rendered.contains("128"));
    }

    #[test]
    fn buffer_view_len_and_byte_size() {
        let view = fake_view::<u64>(0x3000, 32);
        assert_eq!(view.len(), 32);
        assert_eq!(view.byte_size(), 32 * 8);
        assert!(!view.is_empty());
        assert_eq!(view.as_device_ptr(), 0x3000);
    }

    #[test]
    fn buffer_view_mut_len_and_byte_size() {
        let view = fake_view_mut::<u16>(0x4000, 100);
        assert_eq!(view.len(), 100);
        assert_eq!(view.byte_size(), 200);
        assert!(!view.is_empty());
        assert_eq!(view.as_device_ptr(), 0x4000);
    }

    #[test]
    fn buffer_view_empty() {
        let view = fake_view::<f64>(0, 0);
        assert!(view.is_empty());
        assert_eq!(view.byte_size(), 0);
    }

    #[test]
    fn view_as_signature_compiles() {
        let _: fn(&DeviceBuffer<f32>) -> CudaResult<BufferView<'_, u32>> = DeviceBuffer::view_as;
    }

    #[test]
    fn view_as_mut_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>) -> CudaResult<BufferViewMut<'_, u32>> =
            DeviceBuffer::view_as_mut;
    }
}