1use core::ffi::c_void;
13use core::mem::size_of;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::types::{
17 CUarrayMapInfo, CUmemorytype, CUDA_ARRAY3D_DESCRIPTOR, CUDA_MEMCPY3D,
18};
19use baracuda_cuda_sys::{driver, CUarray, CUmipmappedArray};
20use baracuda_types::DeviceRepr;
21
22use crate::array::ArrayFormat;
23use crate::context::Context;
24use crate::error::{check, Result};
25use crate::stream::Stream;
26
27pub struct Array3D {
29 inner: Arc<Array3DInner>,
30}
31
32struct Array3DInner {
33 handle: CUarray,
34 owned: bool,
35 width: usize,
36 height: usize,
37 depth: usize,
38 format: u32,
39 num_channels: u32,
40 #[allow(dead_code)]
41 context: Context,
42}
43
44unsafe impl Send for Array3DInner {}
45unsafe impl Sync for Array3DInner {}
46
47impl core::fmt::Debug for Array3DInner {
48 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
49 f.debug_struct("Array3D")
50 .field("w", &self.width)
51 .field("h", &self.height)
52 .field("d", &self.depth)
53 .field("channels", &self.num_channels)
54 .field("owned", &self.owned)
55 .finish_non_exhaustive()
56 }
57}
58
59impl core::fmt::Debug for Array3D {
60 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61 self.inner.fmt(f)
62 }
63}
64
65impl Clone for Array3D {
66 fn clone(&self) -> Self {
67 Self {
68 inner: self.inner.clone(),
69 }
70 }
71}
72
73impl Array3D {
74 pub fn new(
76 context: &Context,
77 width: usize,
78 height: usize,
79 depth: usize,
80 format: ArrayFormat,
81 num_channels: u32,
82 ) -> Result<Self> {
83 Self::with_flags(context, width, height, depth, format, num_channels, 0)
84 }
85
86 pub fn with_flags(
89 context: &Context,
90 width: usize,
91 height: usize,
92 depth: usize,
93 format: ArrayFormat,
94 num_channels: u32,
95 flags: u32,
96 ) -> Result<Self> {
97 assert!(
98 matches!(num_channels, 1 | 2 | 4),
99 "CUDA arrays require 1, 2, or 4 channels (got {num_channels})"
100 );
101 context.set_current()?;
102 let d = driver()?;
103 let cu = d.cu_array_3d_create()?;
104 let desc = CUDA_ARRAY3D_DESCRIPTOR {
105 width,
106 height,
107 depth,
108 format: format_raw(format),
109 num_channels,
110 flags,
111 };
112 let mut handle: CUarray = core::ptr::null_mut();
113 check(unsafe { cu(&mut handle, &desc) })?;
114 Ok(Self {
115 inner: Arc::new(Array3DInner {
116 handle,
117 owned: true,
118 width,
119 height,
120 depth,
121 format: format_raw(format),
122 num_channels,
123 context: context.clone(),
124 }),
125 })
126 }
127
128 pub unsafe fn from_borrowed(
135 context: &Context,
136 handle: CUarray,
137 width: usize,
138 height: usize,
139 depth: usize,
140 format: ArrayFormat,
141 num_channels: u32,
142 ) -> Self {
143 Self {
144 inner: Arc::new(Array3DInner {
145 handle,
146 owned: false,
147 width,
148 height,
149 depth,
150 format: format_raw(format),
151 num_channels,
152 context: context.clone(),
153 }),
154 }
155 }
156
157 #[inline]
158 pub fn as_raw(&self) -> CUarray {
159 self.inner.handle
160 }
161 #[inline]
162 pub fn width(&self) -> usize {
163 self.inner.width
164 }
165 #[inline]
166 pub fn height(&self) -> usize {
167 self.inner.height
168 }
169 #[inline]
170 pub fn depth(&self) -> usize {
171 self.inner.depth
172 }
173 pub fn bytes_per_element(&self) -> usize {
175 let ch_size = match self.inner.format {
176 baracuda_cuda_sys::types::CUarray_format::UNSIGNED_INT8
177 | baracuda_cuda_sys::types::CUarray_format::SIGNED_INT8 => 1,
178 baracuda_cuda_sys::types::CUarray_format::UNSIGNED_INT16
179 | baracuda_cuda_sys::types::CUarray_format::SIGNED_INT16
180 | baracuda_cuda_sys::types::CUarray_format::HALF => 2,
181 _ => 4,
182 };
183 ch_size * (self.inner.num_channels as usize)
184 }
185
186 fn slice_count(&self) -> usize {
187 self.inner.height.max(1) * self.inner.depth.max(1)
188 }
189
190 pub fn copy_from_host<T: DeviceRepr>(&self, host: &[T]) -> Result<()> {
193 assert_eq!(
194 size_of::<T>(),
195 self.bytes_per_element(),
196 "host element type must match array texel size"
197 );
198 assert_eq!(host.len(), self.inner.width * self.slice_count());
199 let d = driver()?;
200 let cu = d.cu_memcpy_3d()?;
201 let mut p = CUDA_MEMCPY3D::default();
202 p.src_memory_type = CUmemorytype::HOST;
203 p.src_host = host.as_ptr() as *const c_void;
204 p.src_pitch = self.inner.width * self.bytes_per_element();
205 p.src_height = self.inner.height.max(1);
206 p.dst_memory_type = CUmemorytype::ARRAY;
207 p.dst_array = self.inner.handle;
208 p.width_in_bytes = self.inner.width * self.bytes_per_element();
209 p.height = self.inner.height.max(1);
210 p.depth = self.inner.depth.max(1);
211 check(unsafe { cu(&p) })
212 }
213
214 pub fn copy_to_host<T: DeviceRepr>(&self, host: &mut [T]) -> Result<()> {
216 assert_eq!(
217 size_of::<T>(),
218 self.bytes_per_element(),
219 "host element type must match array texel size"
220 );
221 assert_eq!(host.len(), self.inner.width * self.slice_count());
222 let d = driver()?;
223 let cu = d.cu_memcpy_3d()?;
224 let mut p = CUDA_MEMCPY3D::default();
225 p.src_memory_type = CUmemorytype::ARRAY;
226 p.src_array = self.inner.handle;
227 p.dst_memory_type = CUmemorytype::HOST;
228 p.dst_host = host.as_mut_ptr() as *mut c_void;
229 p.dst_pitch = self.inner.width * self.bytes_per_element();
230 p.dst_height = self.inner.height.max(1);
231 p.width_in_bytes = self.inner.width * self.bytes_per_element();
232 p.height = self.inner.height.max(1);
233 p.depth = self.inner.depth.max(1);
234 check(unsafe { cu(&p) })
235 }
236
237 pub fn copy_from_host_async<T: DeviceRepr>(&self, host: &[T], stream: &Stream) -> Result<()> {
241 assert_eq!(size_of::<T>(), self.bytes_per_element());
242 assert_eq!(host.len(), self.inner.width * self.slice_count());
243 let d = driver()?;
244 let cu = d.cu_memcpy_3d_async()?;
245 let mut p = CUDA_MEMCPY3D::default();
246 p.src_memory_type = CUmemorytype::HOST;
247 p.src_host = host.as_ptr() as *const c_void;
248 p.src_pitch = self.inner.width * self.bytes_per_element();
249 p.src_height = self.inner.height.max(1);
250 p.dst_memory_type = CUmemorytype::ARRAY;
251 p.dst_array = self.inner.handle;
252 p.width_in_bytes = self.inner.width * self.bytes_per_element();
253 p.height = self.inner.height.max(1);
254 p.depth = self.inner.depth.max(1);
255 check(unsafe { cu(&p, stream.as_raw()) })
256 }
257}
258
259impl Drop for Array3DInner {
260 fn drop(&mut self) {
261 if !self.owned || self.handle.is_null() {
262 return;
263 }
264 if let Ok(d) = driver() {
265 if let Ok(cu) = d.cu_array_destroy() {
266 let _ = unsafe { cu(self.handle) };
267 }
268 }
269 }
270}
271
272fn format_raw(format: ArrayFormat) -> u32 {
273 use baracuda_cuda_sys::types::CUarray_format;
274 match format {
275 ArrayFormat::U8 => CUarray_format::UNSIGNED_INT8,
276 ArrayFormat::U16 => CUarray_format::UNSIGNED_INT16,
277 ArrayFormat::U32 => CUarray_format::UNSIGNED_INT32,
278 ArrayFormat::I8 => CUarray_format::SIGNED_INT8,
279 ArrayFormat::I16 => CUarray_format::SIGNED_INT16,
280 ArrayFormat::I32 => CUarray_format::SIGNED_INT32,
281 ArrayFormat::F16 => CUarray_format::HALF,
282 ArrayFormat::F32 => CUarray_format::FLOAT,
283 }
284}
285
286pub struct MipmappedArray {
288 handle: CUmipmappedArray,
289 base_width: usize,
290 base_height: usize,
291 base_depth: usize,
292 num_levels: u32,
293 format: ArrayFormat,
294 num_channels: u32,
295 context: Context,
296}
297
298unsafe impl Send for MipmappedArray {}
299unsafe impl Sync for MipmappedArray {}
300
301impl core::fmt::Debug for MipmappedArray {
302 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
303 f.debug_struct("MipmappedArray")
304 .field("levels", &self.num_levels)
305 .field("base_w", &self.base_width)
306 .field("base_h", &self.base_height)
307 .field("base_d", &self.base_depth)
308 .finish_non_exhaustive()
309 }
310}
311
312impl MipmappedArray {
313 #[allow(clippy::too_many_arguments)]
317 pub fn new(
318 context: &Context,
319 width: usize,
320 height: usize,
321 depth: usize,
322 format: ArrayFormat,
323 num_channels: u32,
324 num_levels: u32,
325 flags: u32,
326 ) -> Result<Self> {
327 assert!(
328 matches!(num_channels, 1 | 2 | 4),
329 "CUDA arrays require 1, 2, or 4 channels (got {num_channels})"
330 );
331 assert!(num_levels >= 1, "mipmap must have at least 1 level");
332 context.set_current()?;
333 let d = driver()?;
334 let cu = d.cu_mipmapped_array_create()?;
335 let desc = CUDA_ARRAY3D_DESCRIPTOR {
336 width,
337 height,
338 depth,
339 format: format_raw(format),
340 num_channels,
341 flags,
342 };
343 let mut handle: CUmipmappedArray = core::ptr::null_mut();
344 check(unsafe { cu(&mut handle, &desc, num_levels) })?;
345 Ok(Self {
346 handle,
347 base_width: width,
348 base_height: height,
349 base_depth: depth,
350 num_levels,
351 format,
352 num_channels,
353 context: context.clone(),
354 })
355 }
356
357 pub fn level(&self, level: u32) -> Result<Array3D> {
360 assert!(
361 level < self.num_levels,
362 "mipmap level {level} out of range (0..{})",
363 self.num_levels
364 );
365 let d = driver()?;
366 let cu = d.cu_mipmapped_array_get_level()?;
367 let mut arr: CUarray = core::ptr::null_mut();
368 check(unsafe { cu(&mut arr, self.handle, level) })?;
369 let shift = level as usize;
370 let w = (self.base_width >> shift).max(1);
371 let h = (self.base_height >> shift).max(1);
372 let depth = (self.base_depth >> shift).max(self.base_depth.min(1));
373 let view = unsafe {
376 Array3D::from_borrowed(
377 &self.context,
378 arr,
379 w,
380 h,
381 depth,
382 self.format,
383 self.num_channels,
384 )
385 };
386 Ok(view)
387 }
388
389 #[inline]
390 pub fn as_raw(&self) -> CUmipmappedArray {
391 self.handle
392 }
393 #[inline]
394 pub fn num_levels(&self) -> u32 {
395 self.num_levels
396 }
397}
398
399impl Drop for MipmappedArray {
400 fn drop(&mut self) {
401 if self.handle.is_null() {
402 return;
403 }
404 if let Ok(d) = driver() {
405 if let Ok(cu) = d.cu_mipmapped_array_destroy() {
406 let _ = unsafe { cu(self.handle) };
407 }
408 }
409 }
410}
411
412pub fn map_array_async(info: &mut [CUarrayMapInfo], stream: &Stream) -> Result<()> {
425 if info.is_empty() {
426 return Ok(());
427 }
428 let d = driver()?;
429 let cu = d.cu_mem_map_array_async()?;
430 check(unsafe {
431 cu(
432 info.as_mut_ptr(),
433 info.len() as core::ffi::c_uint,
434 stream.as_raw(),
435 )
436 })
437}