Skip to main content

baracuda_driver/
array.rs

1//! CUDA arrays + texture / surface objects.
2//!
//! An [`Array`] is an opaque on-device layout optimized for texture fetches
4//! and surface loads/stores. Host↔Array copies go through the 2-D memcpy
5//! path (see [`crate::memcpy2d`]), with `CUmemorytype::ARRAY` on whichever
6//! side the array sits on.
7//!
8//! Texture and surface objects are created *from* an array (or a pitched
9//! device pointer, for linear textures). This module exposes the modern
10//! "object" API (CUDA 5+); the legacy reference-based API is not wrapped.
11
12use core::ffi::c_void;
13use core::mem::size_of;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::types::{
17    CUarray_format, CUDA_ARRAY_DESCRIPTOR, CUDA_RESOURCE_DESC, CUDA_TEXTURE_DESC,
18};
19use baracuda_cuda_sys::{driver, CUarray, CUsurfObject, CUtexObject};
20use baracuda_types::DeviceRepr;
21
22use crate::context::Context;
23use crate::error::{check, Result};
24
25/// A 2-D CUDA array. Element format is chosen at creation; channels are
26/// typically 1, 2, or 4.
27pub struct Array {
28    inner: Arc<ArrayInner>,
29}
30
31struct ArrayInner {
32    handle: CUarray,
33    width: usize,
34    height: usize,
35    format: u32,
36    num_channels: u32,
37    #[allow(dead_code)]
38    context: Context,
39}
40
41unsafe impl Send for ArrayInner {}
42unsafe impl Sync for ArrayInner {}
43
44impl core::fmt::Debug for ArrayInner {
45    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46        f.debug_struct("Array")
47            .field("width", &self.width)
48            .field("height", &self.height)
49            .field("format", &self.format)
50            .field("channels", &self.num_channels)
51            .finish_non_exhaustive()
52    }
53}
54
55impl core::fmt::Debug for Array {
56    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57        self.inner.fmt(f)
58    }
59}
60
61impl Clone for Array {
62    fn clone(&self) -> Self {
63        Self {
64            inner: self.inner.clone(),
65        }
66    }
67}
68
/// Per-channel element format of an [`Array`].
///
/// The same format is used for every channel. For a multi-channel array,
/// pass one of these together with a channel count (1, 2, or 4) to
/// [`Array::new`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ArrayFormat {
    /// Unsigned 8-bit integer channels.
    U8,
    /// Unsigned 16-bit integer channels.
    U16,
    /// Unsigned 32-bit integer channels.
    U32,
    /// Signed 8-bit integer channels.
    I8,
    /// Signed 16-bit integer channels.
    I16,
    /// Signed 32-bit integer channels.
    I32,
    /// 16-bit (half-precision) floating-point channels.
    F16,
    /// 32-bit floating-point channels.
    F32,
}
83
84impl ArrayFormat {
85    #[inline]
86    fn raw(self) -> u32 {
87        match self {
88            ArrayFormat::U8 => CUarray_format::UNSIGNED_INT8,
89            ArrayFormat::U16 => CUarray_format::UNSIGNED_INT16,
90            ArrayFormat::U32 => CUarray_format::UNSIGNED_INT32,
91            ArrayFormat::I8 => CUarray_format::SIGNED_INT8,
92            ArrayFormat::I16 => CUarray_format::SIGNED_INT16,
93            ArrayFormat::I32 => CUarray_format::SIGNED_INT32,
94            ArrayFormat::F16 => CUarray_format::HALF,
95            ArrayFormat::F32 => CUarray_format::FLOAT,
96        }
97    }
98
99    /// Width of a single texel (one channel) in bytes.
100    #[inline]
101    pub fn bytes_per_channel(self) -> usize {
102        match self {
103            ArrayFormat::U8 | ArrayFormat::I8 => 1,
104            ArrayFormat::U16 | ArrayFormat::I16 | ArrayFormat::F16 => 2,
105            ArrayFormat::U32 | ArrayFormat::I32 | ArrayFormat::F32 => 4,
106        }
107    }
108}
109
110impl Array {
111    /// Allocate a `width × height` 2-D array with `num_channels` elements of
112    /// `format` per texel. Set `height = 0` for a 1-D array.
113    pub fn new(
114        context: &Context,
115        width: usize,
116        height: usize,
117        format: ArrayFormat,
118        num_channels: u32,
119    ) -> Result<Self> {
120        assert!(
121            matches!(num_channels, 1 | 2 | 4),
122            "CUDA arrays require 1, 2, or 4 channels (got {num_channels})",
123        );
124        context.set_current()?;
125        let d = driver()?;
126        let cu = d.cu_array_create()?;
127        let desc = CUDA_ARRAY_DESCRIPTOR {
128            width,
129            height,
130            format: format.raw(),
131            num_channels,
132        };
133        let mut handle: CUarray = core::ptr::null_mut();
134        check(unsafe { cu(&mut handle, &desc) })?;
135        Ok(Self {
136            inner: Arc::new(ArrayInner {
137                handle,
138                width,
139                height,
140                format: format.raw(),
141                num_channels,
142                context: context.clone(),
143            }),
144        })
145    }
146
147    /// Raw `CUarray`. Use with care.
148    #[inline]
149    pub fn as_raw(&self) -> CUarray {
150        self.inner.handle
151    }
152    #[inline]
153    pub fn width(&self) -> usize {
154        self.inner.width
155    }
156    #[inline]
157    pub fn height(&self) -> usize {
158        self.inner.height
159    }
160    #[inline]
161    pub fn num_channels(&self) -> u32 {
162        self.inner.num_channels
163    }
164    /// Element width in bytes (channel size × channel count).
165    pub fn bytes_per_element(&self) -> usize {
166        let ch_size = match self.inner.format {
167            CUarray_format::UNSIGNED_INT8 | CUarray_format::SIGNED_INT8 => 1,
168            CUarray_format::UNSIGNED_INT16
169            | CUarray_format::SIGNED_INT16
170            | CUarray_format::HALF => 2,
171            _ => 4,
172        };
173        ch_size * (self.inner.num_channels as usize)
174    }
175
176    /// Synchronous host→array 2-D copy. `host` must contain exactly
177    /// `width × height` elements of type `T`; `T` size must match
178    /// `self.bytes_per_element()`.
179    pub fn copy_from_host<T: DeviceRepr>(&self, host: &[T]) -> Result<()> {
180        assert_eq!(
181            size_of::<T>(),
182            self.bytes_per_element(),
183            "host element type must match array texel size",
184        );
185        let h = self.inner.height.max(1);
186        assert_eq!(host.len(), self.inner.width * h);
187        let d = driver()?;
188        let cu = d.cu_memcpy_2d()?;
189        let p = baracuda_cuda_sys::types::CUDA_MEMCPY2D {
190            src_memory_type: baracuda_cuda_sys::types::CUmemorytype::HOST,
191            src_host: host.as_ptr() as *const c_void,
192            src_pitch: self.inner.width * self.bytes_per_element(),
193            dst_memory_type: baracuda_cuda_sys::types::CUmemorytype::ARRAY,
194            dst_array: self.inner.handle,
195            width_in_bytes: self.inner.width * self.bytes_per_element(),
196            height: h,
197            ..Default::default()
198        };
199        check(unsafe { cu(&p) })
200    }
201
202    /// Query this array's descriptor back from CUDA. Useful for arrays
203    /// you received from an external source and don't have shape info for.
204    pub fn descriptor(&self) -> Result<CUDA_ARRAY_DESCRIPTOR> {
205        let d = driver()?;
206        let cu = d.cu_array_get_descriptor()?;
207        let mut desc = CUDA_ARRAY_DESCRIPTOR::default();
208        check(unsafe { cu(&mut desc, self.inner.handle) })?;
209        Ok(desc)
210    }
211
212    /// Query the array's memory-allocation size + alignment requirements
213    /// on `device`. Useful when backing an array with a VMM allocation.
214    pub fn memory_requirements(
215        &self,
216        device: &crate::Device,
217    ) -> Result<baracuda_cuda_sys::types::CUDA_ARRAY_MEMORY_REQUIREMENTS> {
218        let d = driver()?;
219        let cu = d.cu_array_get_memory_requirements()?;
220        let mut req = baracuda_cuda_sys::types::CUDA_ARRAY_MEMORY_REQUIREMENTS::default();
221        check(unsafe { cu(&mut req, self.inner.handle, device.as_raw()) })?;
222        Ok(req)
223    }
224
225    /// Query the array's sparse-tile properties. Meaningful on sparse /
226    /// tiled arrays created with the `SPARSE` flag.
227    pub fn sparse_properties(
228        &self,
229    ) -> Result<baracuda_cuda_sys::types::CUDA_ARRAY_SPARSE_PROPERTIES> {
230        let d = driver()?;
231        let cu = d.cu_array_get_sparse_properties()?;
232        let mut sp = baracuda_cuda_sys::types::CUDA_ARRAY_SPARSE_PROPERTIES::default();
233        check(unsafe { cu(&mut sp, self.inner.handle) })?;
234        Ok(sp)
235    }
236
237    /// Return the raw `CUarray` handle of plane `plane_idx` of a
238    /// multi-planar array (YUV / NV12). The plane is owned by `self` —
239    /// the raw handle must NOT be passed to `cuArrayDestroy`. Use
240    /// together with [`Array::descriptor_of_raw`] if you need shape info.
241    pub fn plane_raw(&self, plane_idx: u32) -> Result<CUarray> {
242        let d = driver()?;
243        let cu = d.cu_array_get_plane()?;
244        let mut out: CUarray = core::ptr::null_mut();
245        check(unsafe { cu(&mut out, self.inner.handle, plane_idx) })?;
246        Ok(out)
247    }
248
249    /// Helper: query the `CUDA_ARRAY_DESCRIPTOR` of a raw array handle
250    /// (e.g. a plane returned by [`Array::plane_raw`]).
251    ///
252    /// # Safety
253    ///
254    /// `handle` must be a live `CUarray`.
255    pub unsafe fn descriptor_of_raw(handle: CUarray) -> Result<CUDA_ARRAY_DESCRIPTOR> { unsafe {
256        let d = driver()?;
257        let cu = d.cu_array_get_descriptor()?;
258        let mut desc = CUDA_ARRAY_DESCRIPTOR::default();
259        check(cu(&mut desc, handle))?;
260        Ok(desc)
261    }}
262
263    /// Synchronous array→host 2-D copy.
264    pub fn copy_to_host<T: DeviceRepr>(&self, host: &mut [T]) -> Result<()> {
265        assert_eq!(
266            size_of::<T>(),
267            self.bytes_per_element(),
268            "host element type must match array texel size",
269        );
270        let h = self.inner.height.max(1);
271        assert_eq!(host.len(), self.inner.width * h);
272        let d = driver()?;
273        let cu = d.cu_memcpy_2d()?;
274        let p = baracuda_cuda_sys::types::CUDA_MEMCPY2D {
275            src_memory_type: baracuda_cuda_sys::types::CUmemorytype::ARRAY,
276            src_array: self.inner.handle,
277            dst_memory_type: baracuda_cuda_sys::types::CUmemorytype::HOST,
278            dst_host: host.as_mut_ptr() as *mut c_void,
279            dst_pitch: self.inner.width * self.bytes_per_element(),
280            width_in_bytes: self.inner.width * self.bytes_per_element(),
281            height: h,
282            ..Default::default()
283        };
284        check(unsafe { cu(&p) })
285    }
286}
287
288impl Drop for ArrayInner {
289    fn drop(&mut self) {
290        if self.handle.is_null() {
291            return;
292        }
293        if let Ok(d) = driver() {
294            if let Ok(cu) = d.cu_array_destroy() {
295                let _ = unsafe { cu(self.handle) };
296            }
297        }
298    }
299}
300
301/// A texture object — a read-only, filtered view onto a CUDA array.
302pub struct TextureObject {
303    handle: CUtexObject,
304    _array: Array,
305}
306
307impl core::fmt::Debug for TextureObject {
308    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
309        f.debug_struct("TextureObject")
310            .field("handle", &self.handle)
311            .finish_non_exhaustive()
312    }
313}
314
315unsafe impl Send for TextureObject {}
316unsafe impl Sync for TextureObject {}
317
318/// Configuration for [`TextureObject::new`].
319#[derive(Copy, Clone, Debug)]
320pub struct TextureDesc {
321    pub address_mode: [TextureAddressMode; 3],
322    pub filter_mode: TextureFilterMode,
323    pub read_normalized: bool,
324    pub normalized_coords: bool,
325}
326
327impl Default for TextureDesc {
328    fn default() -> Self {
329        Self {
330            address_mode: [TextureAddressMode::Clamp; 3],
331            filter_mode: TextureFilterMode::Point,
332            read_normalized: false,
333            normalized_coords: false,
334        }
335    }
336}
337
/// How out-of-range texture coordinates are resolved along one axis.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TextureAddressMode {
    /// Maps to `CUaddress_mode::WRAP`.
    Wrap,
    /// Maps to `CUaddress_mode::CLAMP`.
    Clamp,
    /// Maps to `CUaddress_mode::MIRROR`.
    Mirror,
    /// Maps to `CUaddress_mode::BORDER`.
    Border,
}
345
/// Filtering used when a texture is sampled.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TextureFilterMode {
    /// Maps to `CUfilter_mode::POINT`.
    Point,
    /// Maps to `CUfilter_mode::LINEAR`.
    Linear,
}
351
352impl TextureAddressMode {
353    fn raw(self) -> u32 {
354        use baracuda_cuda_sys::types::CUaddress_mode;
355        match self {
356            TextureAddressMode::Wrap => CUaddress_mode::WRAP,
357            TextureAddressMode::Clamp => CUaddress_mode::CLAMP,
358            TextureAddressMode::Mirror => CUaddress_mode::MIRROR,
359            TextureAddressMode::Border => CUaddress_mode::BORDER,
360        }
361    }
362}
363
364impl TextureFilterMode {
365    fn raw(self) -> u32 {
366        use baracuda_cuda_sys::types::CUfilter_mode;
367        match self {
368            TextureFilterMode::Point => CUfilter_mode::POINT,
369            TextureFilterMode::Linear => CUfilter_mode::LINEAR,
370        }
371    }
372}
373
374impl TextureObject {
375    /// Create a texture object that reads from `array`. Uses point filtering
376    /// and clamp addressing by default; override with [`TextureObject::with_desc`].
377    pub fn new(array: &Array) -> Result<Self> {
378        Self::with_desc(array, TextureDesc::default())
379    }
380
381    pub fn with_desc(array: &Array, desc: TextureDesc) -> Result<Self> {
382        let d = driver()?;
383        let cu = d.cu_tex_object_create()?;
384        let res_desc = CUDA_RESOURCE_DESC::from_array(array.as_raw());
385        let mut flags: core::ffi::c_uint = 0;
386        const CU_TRSF_READ_AS_INTEGER: core::ffi::c_uint = 0x01;
387        const CU_TRSF_NORMALIZED_COORDINATES: core::ffi::c_uint = 0x02;
388        if !desc.read_normalized {
389            flags |= CU_TRSF_READ_AS_INTEGER;
390        }
391        if desc.normalized_coords {
392            flags |= CU_TRSF_NORMALIZED_COORDINATES;
393        }
394        let tex_desc = CUDA_TEXTURE_DESC {
395            address_mode: [
396                desc.address_mode[0].raw(),
397                desc.address_mode[1].raw(),
398                desc.address_mode[2].raw(),
399            ],
400            filter_mode: desc.filter_mode.raw(),
401            flags,
402            ..Default::default()
403        };
404        let mut handle: CUtexObject = 0;
405        check(unsafe { cu(&mut handle, &res_desc, &tex_desc, core::ptr::null()) })?;
406        Ok(Self {
407            handle,
408            _array: array.clone(),
409        })
410    }
411
412    #[inline]
413    pub fn as_raw(&self) -> CUtexObject {
414        self.handle
415    }
416}
417
418impl Drop for TextureObject {
419    fn drop(&mut self) {
420        if self.handle == 0 {
421            return;
422        }
423        if let Ok(d) = driver() {
424            if let Ok(cu) = d.cu_tex_object_destroy() {
425                let _ = unsafe { cu(self.handle) };
426            }
427        }
428    }
429}
430
431/// A surface object — a read/write view onto a CUDA array.
432pub struct SurfaceObject {
433    handle: CUsurfObject,
434    _array: Array,
435}
436
437impl core::fmt::Debug for SurfaceObject {
438    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
439        f.debug_struct("SurfaceObject")
440            .field("handle", &self.handle)
441            .finish_non_exhaustive()
442    }
443}
444
445unsafe impl Send for SurfaceObject {}
446unsafe impl Sync for SurfaceObject {}
447
448impl SurfaceObject {
449    pub fn new(array: &Array) -> Result<Self> {
450        let d = driver()?;
451        let cu = d.cu_surf_object_create()?;
452        let res_desc = CUDA_RESOURCE_DESC::from_array(array.as_raw());
453        let mut handle: CUsurfObject = 0;
454        check(unsafe { cu(&mut handle, &res_desc) })?;
455        Ok(Self {
456            handle,
457            _array: array.clone(),
458        })
459    }
460
461    #[inline]
462    pub fn as_raw(&self) -> CUsurfObject {
463        self.handle
464    }
465}
466
467impl Drop for SurfaceObject {
468    fn drop(&mut self) {
469        if self.handle == 0 {
470            return;
471        }
472        if let Ok(d) = driver() {
473            if let Ok(cu) = d.cu_surf_object_destroy() {
474                let _ = unsafe { cu(self.handle) };
475            }
476        }
477    }
478}