1use core::ffi::c_void;
13use core::mem::size_of;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::types::{
17 CUarray_format, CUDA_ARRAY_DESCRIPTOR, CUDA_RESOURCE_DESC, CUDA_TEXTURE_DESC,
18};
19use baracuda_cuda_sys::{driver, CUarray, CUsurfObject, CUtexObject};
20use baracuda_types::DeviceRepr;
21
22use crate::context::Context;
23use crate::error::{check, Result};
24
25pub struct Array {
28 inner: Arc<ArrayInner>,
29}
30
31struct ArrayInner {
32 handle: CUarray,
33 width: usize,
34 height: usize,
35 format: u32,
36 num_channels: u32,
37 #[allow(dead_code)]
38 context: Context,
39}
40
41unsafe impl Send for ArrayInner {}
42unsafe impl Sync for ArrayInner {}
43
44impl core::fmt::Debug for ArrayInner {
45 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46 f.debug_struct("Array")
47 .field("width", &self.width)
48 .field("height", &self.height)
49 .field("format", &self.format)
50 .field("channels", &self.num_channels)
51 .finish_non_exhaustive()
52 }
53}
54
55impl core::fmt::Debug for Array {
56 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
57 self.inner.fmt(f)
58 }
59}
60
61impl Clone for Array {
62 fn clone(&self) -> Self {
63 Self {
64 inner: self.inner.clone(),
65 }
66 }
67}
68
/// Per-channel element format of a CUDA array.
///
/// Derives `Hash` (in addition to `Eq`) so formats can key hash maps and sets.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ArrayFormat {
    /// 8-bit unsigned integer channels.
    U8,
    /// 16-bit unsigned integer channels.
    U16,
    /// 32-bit unsigned integer channels.
    U32,
    /// 8-bit signed integer channels.
    I8,
    /// 16-bit signed integer channels.
    I16,
    /// 32-bit signed integer channels.
    I32,
    /// 16-bit half-precision float channels.
    F16,
    /// 32-bit single-precision float channels.
    F32,
}
83
84impl ArrayFormat {
85 #[inline]
86 fn raw(self) -> u32 {
87 match self {
88 ArrayFormat::U8 => CUarray_format::UNSIGNED_INT8,
89 ArrayFormat::U16 => CUarray_format::UNSIGNED_INT16,
90 ArrayFormat::U32 => CUarray_format::UNSIGNED_INT32,
91 ArrayFormat::I8 => CUarray_format::SIGNED_INT8,
92 ArrayFormat::I16 => CUarray_format::SIGNED_INT16,
93 ArrayFormat::I32 => CUarray_format::SIGNED_INT32,
94 ArrayFormat::F16 => CUarray_format::HALF,
95 ArrayFormat::F32 => CUarray_format::FLOAT,
96 }
97 }
98
99 #[inline]
101 pub fn bytes_per_channel(self) -> usize {
102 match self {
103 ArrayFormat::U8 | ArrayFormat::I8 => 1,
104 ArrayFormat::U16 | ArrayFormat::I16 | ArrayFormat::F16 => 2,
105 ArrayFormat::U32 | ArrayFormat::I32 | ArrayFormat::F32 => 4,
106 }
107 }
108}
109
110impl Array {
111 pub fn new(
114 context: &Context,
115 width: usize,
116 height: usize,
117 format: ArrayFormat,
118 num_channels: u32,
119 ) -> Result<Self> {
120 assert!(
121 matches!(num_channels, 1 | 2 | 4),
122 "CUDA arrays require 1, 2, or 4 channels (got {num_channels})",
123 );
124 context.set_current()?;
125 let d = driver()?;
126 let cu = d.cu_array_create()?;
127 let desc = CUDA_ARRAY_DESCRIPTOR {
128 width,
129 height,
130 format: format.raw(),
131 num_channels,
132 };
133 let mut handle: CUarray = core::ptr::null_mut();
134 check(unsafe { cu(&mut handle, &desc) })?;
135 Ok(Self {
136 inner: Arc::new(ArrayInner {
137 handle,
138 width,
139 height,
140 format: format.raw(),
141 num_channels,
142 context: context.clone(),
143 }),
144 })
145 }
146
147 #[inline]
149 pub fn as_raw(&self) -> CUarray {
150 self.inner.handle
151 }
152 #[inline]
153 pub fn width(&self) -> usize {
154 self.inner.width
155 }
156 #[inline]
157 pub fn height(&self) -> usize {
158 self.inner.height
159 }
160 #[inline]
161 pub fn num_channels(&self) -> u32 {
162 self.inner.num_channels
163 }
164 pub fn bytes_per_element(&self) -> usize {
166 let ch_size = match self.inner.format {
167 CUarray_format::UNSIGNED_INT8 | CUarray_format::SIGNED_INT8 => 1,
168 CUarray_format::UNSIGNED_INT16
169 | CUarray_format::SIGNED_INT16
170 | CUarray_format::HALF => 2,
171 _ => 4,
172 };
173 ch_size * (self.inner.num_channels as usize)
174 }
175
176 pub fn copy_from_host<T: DeviceRepr>(&self, host: &[T]) -> Result<()> {
180 assert_eq!(
181 size_of::<T>(),
182 self.bytes_per_element(),
183 "host element type must match array texel size",
184 );
185 let h = self.inner.height.max(1);
186 assert_eq!(host.len(), self.inner.width * h);
187 let d = driver()?;
188 let cu = d.cu_memcpy_2d()?;
189 let p = baracuda_cuda_sys::types::CUDA_MEMCPY2D {
190 src_memory_type: baracuda_cuda_sys::types::CUmemorytype::HOST,
191 src_host: host.as_ptr() as *const c_void,
192 src_pitch: self.inner.width * self.bytes_per_element(),
193 dst_memory_type: baracuda_cuda_sys::types::CUmemorytype::ARRAY,
194 dst_array: self.inner.handle,
195 width_in_bytes: self.inner.width * self.bytes_per_element(),
196 height: h,
197 ..Default::default()
198 };
199 check(unsafe { cu(&p) })
200 }
201
202 pub fn descriptor(&self) -> Result<CUDA_ARRAY_DESCRIPTOR> {
205 let d = driver()?;
206 let cu = d.cu_array_get_descriptor()?;
207 let mut desc = CUDA_ARRAY_DESCRIPTOR::default();
208 check(unsafe { cu(&mut desc, self.inner.handle) })?;
209 Ok(desc)
210 }
211
212 pub fn memory_requirements(
215 &self,
216 device: &crate::Device,
217 ) -> Result<baracuda_cuda_sys::types::CUDA_ARRAY_MEMORY_REQUIREMENTS> {
218 let d = driver()?;
219 let cu = d.cu_array_get_memory_requirements()?;
220 let mut req = baracuda_cuda_sys::types::CUDA_ARRAY_MEMORY_REQUIREMENTS::default();
221 check(unsafe { cu(&mut req, self.inner.handle, device.as_raw()) })?;
222 Ok(req)
223 }
224
225 pub fn sparse_properties(
228 &self,
229 ) -> Result<baracuda_cuda_sys::types::CUDA_ARRAY_SPARSE_PROPERTIES> {
230 let d = driver()?;
231 let cu = d.cu_array_get_sparse_properties()?;
232 let mut sp = baracuda_cuda_sys::types::CUDA_ARRAY_SPARSE_PROPERTIES::default();
233 check(unsafe { cu(&mut sp, self.inner.handle) })?;
234 Ok(sp)
235 }
236
237 pub fn plane_raw(&self, plane_idx: u32) -> Result<CUarray> {
242 let d = driver()?;
243 let cu = d.cu_array_get_plane()?;
244 let mut out: CUarray = core::ptr::null_mut();
245 check(unsafe { cu(&mut out, self.inner.handle, plane_idx) })?;
246 Ok(out)
247 }
248
249 pub unsafe fn descriptor_of_raw(handle: CUarray) -> Result<CUDA_ARRAY_DESCRIPTOR> { unsafe {
256 let d = driver()?;
257 let cu = d.cu_array_get_descriptor()?;
258 let mut desc = CUDA_ARRAY_DESCRIPTOR::default();
259 check(cu(&mut desc, handle))?;
260 Ok(desc)
261 }}
262
263 pub fn copy_to_host<T: DeviceRepr>(&self, host: &mut [T]) -> Result<()> {
265 assert_eq!(
266 size_of::<T>(),
267 self.bytes_per_element(),
268 "host element type must match array texel size",
269 );
270 let h = self.inner.height.max(1);
271 assert_eq!(host.len(), self.inner.width * h);
272 let d = driver()?;
273 let cu = d.cu_memcpy_2d()?;
274 let p = baracuda_cuda_sys::types::CUDA_MEMCPY2D {
275 src_memory_type: baracuda_cuda_sys::types::CUmemorytype::ARRAY,
276 src_array: self.inner.handle,
277 dst_memory_type: baracuda_cuda_sys::types::CUmemorytype::HOST,
278 dst_host: host.as_mut_ptr() as *mut c_void,
279 dst_pitch: self.inner.width * self.bytes_per_element(),
280 width_in_bytes: self.inner.width * self.bytes_per_element(),
281 height: h,
282 ..Default::default()
283 };
284 check(unsafe { cu(&p) })
285 }
286}
287
288impl Drop for ArrayInner {
289 fn drop(&mut self) {
290 if self.handle.is_null() {
291 return;
292 }
293 if let Ok(d) = driver() {
294 if let Ok(cu) = d.cu_array_destroy() {
295 let _ = unsafe { cu(self.handle) };
296 }
297 }
298 }
299}
300
301pub struct TextureObject {
303 handle: CUtexObject,
304 _array: Array,
305}
306
307impl core::fmt::Debug for TextureObject {
308 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
309 f.debug_struct("TextureObject")
310 .field("handle", &self.handle)
311 .finish_non_exhaustive()
312 }
313}
314
315unsafe impl Send for TextureObject {}
316unsafe impl Sync for TextureObject {}
317
318#[derive(Copy, Clone, Debug)]
320pub struct TextureDesc {
321 pub address_mode: [TextureAddressMode; 3],
322 pub filter_mode: TextureFilterMode,
323 pub read_normalized: bool,
324 pub normalized_coords: bool,
325}
326
327impl Default for TextureDesc {
328 fn default() -> Self {
329 Self {
330 address_mode: [TextureAddressMode::Clamp; 3],
331 filter_mode: TextureFilterMode::Point,
332 read_normalized: false,
333 normalized_coords: false,
334 }
335 }
336}
337
/// How out-of-range texture coordinates are resolved (maps to `CUaddress_mode`).
///
/// Derives `Hash` so modes can key hash maps and sets.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TextureAddressMode {
    /// Coordinates wrap around (repeat).
    Wrap,
    /// Coordinates are clamped to the edge.
    Clamp,
    /// Coordinates are mirrored at the edges.
    Mirror,
    /// Out-of-range reads return the border color.
    Border,
}
345
/// Texture filtering mode (maps to `CUfilter_mode`).
///
/// Derives `Hash` so modes can key hash maps and sets.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TextureFilterMode {
    /// Nearest-element sampling.
    Point,
    /// Linear interpolation between neighbouring elements.
    Linear,
}
351
352impl TextureAddressMode {
353 fn raw(self) -> u32 {
354 use baracuda_cuda_sys::types::CUaddress_mode;
355 match self {
356 TextureAddressMode::Wrap => CUaddress_mode::WRAP,
357 TextureAddressMode::Clamp => CUaddress_mode::CLAMP,
358 TextureAddressMode::Mirror => CUaddress_mode::MIRROR,
359 TextureAddressMode::Border => CUaddress_mode::BORDER,
360 }
361 }
362}
363
364impl TextureFilterMode {
365 fn raw(self) -> u32 {
366 use baracuda_cuda_sys::types::CUfilter_mode;
367 match self {
368 TextureFilterMode::Point => CUfilter_mode::POINT,
369 TextureFilterMode::Linear => CUfilter_mode::LINEAR,
370 }
371 }
372}
373
374impl TextureObject {
375 pub fn new(array: &Array) -> Result<Self> {
378 Self::with_desc(array, TextureDesc::default())
379 }
380
381 pub fn with_desc(array: &Array, desc: TextureDesc) -> Result<Self> {
382 let d = driver()?;
383 let cu = d.cu_tex_object_create()?;
384 let res_desc = CUDA_RESOURCE_DESC::from_array(array.as_raw());
385 let mut flags: core::ffi::c_uint = 0;
386 const CU_TRSF_READ_AS_INTEGER: core::ffi::c_uint = 0x01;
387 const CU_TRSF_NORMALIZED_COORDINATES: core::ffi::c_uint = 0x02;
388 if !desc.read_normalized {
389 flags |= CU_TRSF_READ_AS_INTEGER;
390 }
391 if desc.normalized_coords {
392 flags |= CU_TRSF_NORMALIZED_COORDINATES;
393 }
394 let tex_desc = CUDA_TEXTURE_DESC {
395 address_mode: [
396 desc.address_mode[0].raw(),
397 desc.address_mode[1].raw(),
398 desc.address_mode[2].raw(),
399 ],
400 filter_mode: desc.filter_mode.raw(),
401 flags,
402 ..Default::default()
403 };
404 let mut handle: CUtexObject = 0;
405 check(unsafe { cu(&mut handle, &res_desc, &tex_desc, core::ptr::null()) })?;
406 Ok(Self {
407 handle,
408 _array: array.clone(),
409 })
410 }
411
412 #[inline]
413 pub fn as_raw(&self) -> CUtexObject {
414 self.handle
415 }
416}
417
418impl Drop for TextureObject {
419 fn drop(&mut self) {
420 if self.handle == 0 {
421 return;
422 }
423 if let Ok(d) = driver() {
424 if let Ok(cu) = d.cu_tex_object_destroy() {
425 let _ = unsafe { cu(self.handle) };
426 }
427 }
428 }
429}
430
431pub struct SurfaceObject {
433 handle: CUsurfObject,
434 _array: Array,
435}
436
437impl core::fmt::Debug for SurfaceObject {
438 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
439 f.debug_struct("SurfaceObject")
440 .field("handle", &self.handle)
441 .finish_non_exhaustive()
442 }
443}
444
445unsafe impl Send for SurfaceObject {}
446unsafe impl Sync for SurfaceObject {}
447
448impl SurfaceObject {
449 pub fn new(array: &Array) -> Result<Self> {
450 let d = driver()?;
451 let cu = d.cu_surf_object_create()?;
452 let res_desc = CUDA_RESOURCE_DESC::from_array(array.as_raw());
453 let mut handle: CUsurfObject = 0;
454 check(unsafe { cu(&mut handle, &res_desc) })?;
455 Ok(Self {
456 handle,
457 _array: array.clone(),
458 })
459 }
460
461 #[inline]
462 pub fn as_raw(&self) -> CUsurfObject {
463 self.handle
464 }
465}
466
467impl Drop for SurfaceObject {
468 fn drop(&mut self) {
469 if self.handle == 0 {
470 return;
471 }
472 if let Ok(d) = driver() {
473 if let Ok(cu) = d.cu_surf_object_destroy() {
474 let _ = unsafe { cu(self.handle) };
475 }
476 }
477 }
478}