Skip to main content

baracuda_driver/
memcpy2d.rs

1//! Strided (2-D) memory copies via `cuMemcpy2D`, and pitched device
2//! allocations via `cuMemAllocPitch`.
3//!
4//! 2-D memcpy is CUDA's tool for copying rectangular sub-regions of
5//! images / matrices / tensors where one or both sides have a **pitch**
6//! (row stride) different from the logical row width.
7
8use core::ffi::c_void;
9use core::mem::size_of;
10
11use baracuda_cuda_sys::driver;
12use baracuda_cuda_sys::types::{CUmemorytype, CUDA_MEMCPY2D};
13use baracuda_cuda_sys::CUdeviceptr;
14use baracuda_types::DeviceRepr;
15
16use crate::context::Context;
17use crate::error::{check, Result};
18use crate::stream::Stream;
19
20/// A pitched device allocation — a 2-D `height × width_in_bytes` block
21/// where each row is stored at `pitch` bytes apart (`pitch >= width_in_bytes`).
22/// Pitch is chosen by the driver to satisfy hardware alignment requirements.
23pub struct PitchedBuffer<T: DeviceRepr> {
24    ptr: CUdeviceptr,
25    pitch_bytes: usize,
26    width_elems: usize,
27    height: usize,
28    context: Context,
29    _marker: core::marker::PhantomData<T>,
30}
31
32unsafe impl<T: DeviceRepr + Send> Send for PitchedBuffer<T> {}
33
34impl<T: DeviceRepr> core::fmt::Debug for PitchedBuffer<T> {
35    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36        f.debug_struct("PitchedBuffer")
37            .field("ptr", &format_args!("{:#x}", self.ptr.0))
38            .field("width_elems", &self.width_elems)
39            .field("height", &self.height)
40            .field("pitch_bytes", &self.pitch_bytes)
41            .field("type", &core::any::type_name::<T>())
42            .finish()
43    }
44}
45
46impl<T: DeviceRepr> PitchedBuffer<T> {
47    /// Allocate a `height × width_elems` grid of `T`s with driver-chosen
48    /// pitch. The element size hint steers the alignment.
49    pub fn new(context: &Context, width_elems: usize, height: usize) -> Result<Self> {
50        context.set_current()?;
51        let d = driver()?;
52        let cu = d.cu_mem_alloc_pitch()?;
53        let mut ptr = CUdeviceptr(0);
54        let mut pitch: usize = 0;
55        let width_bytes = width_elems
56            .checked_mul(size_of::<T>())
57            .expect("overflow in 2D allocation width");
58        check(unsafe {
59            cu(
60                &mut ptr,
61                &mut pitch,
62                width_bytes,
63                height,
64                size_of::<T>() as core::ffi::c_uint,
65            )
66        })?;
67        Ok(Self {
68            ptr,
69            pitch_bytes: pitch,
70            width_elems,
71            height,
72            context: context.clone(),
73            _marker: core::marker::PhantomData,
74        })
75    }
76
77    #[inline]
78    pub fn width_elems(&self) -> usize {
79        self.width_elems
80    }
81    #[inline]
82    pub fn height(&self) -> usize {
83        self.height
84    }
85    /// Row stride in bytes as chosen by the driver.
86    #[inline]
87    pub fn pitch_bytes(&self) -> usize {
88        self.pitch_bytes
89    }
90    #[inline]
91    pub fn as_raw(&self) -> CUdeviceptr {
92        self.ptr
93    }
94    #[inline]
95    pub fn context(&self) -> &Context {
96        &self.context
97    }
98}
99
100impl<T: DeviceRepr> Drop for PitchedBuffer<T> {
101    fn drop(&mut self) {
102        if self.ptr.0 == 0 {
103            return;
104        }
105        if let Ok(d) = driver() {
106            if let Ok(cu) = d.cu_mem_free() {
107                let _ = unsafe { cu(self.ptr) };
108            }
109        }
110    }
111}
112
113/// Synchronous `height`-row / `width_elems`-column 2-D copy from a host
114/// slice (`src_host_pitch` bytes between row starts) into a pitched
115/// device buffer.
116///
117/// `src` must hold at least `(height - 1) * src_host_pitch + width_elems * size_of::<T>()` bytes.
118pub fn copy_h_to_d_2d<T: DeviceRepr>(
119    src: &[T],
120    src_host_pitch_bytes: usize,
121    dst: &PitchedBuffer<T>,
122    width_elems: usize,
123    height: usize,
124) -> Result<()> {
125    assert!(width_elems <= dst.width_elems);
126    assert!(height <= dst.height);
127    let d = driver()?;
128    let cu = d.cu_memcpy_2d()?;
129    let p = CUDA_MEMCPY2D {
130        src_memory_type: CUmemorytype::HOST,
131        src_host: src.as_ptr() as *const c_void,
132        src_pitch: src_host_pitch_bytes,
133        dst_memory_type: CUmemorytype::DEVICE,
134        dst_device: dst.ptr,
135        dst_pitch: dst.pitch_bytes,
136        width_in_bytes: width_elems * size_of::<T>(),
137        height,
138        ..Default::default()
139    };
140    check(unsafe { cu(&p) })
141}
142
143/// Synchronous 2-D copy from a pitched device buffer back into a host slice.
144pub fn copy_d_to_h_2d<T: DeviceRepr>(
145    src: &PitchedBuffer<T>,
146    dst: &mut [T],
147    dst_host_pitch_bytes: usize,
148    width_elems: usize,
149    height: usize,
150) -> Result<()> {
151    assert!(width_elems <= src.width_elems);
152    assert!(height <= src.height);
153    let d = driver()?;
154    let cu = d.cu_memcpy_2d()?;
155    let p = CUDA_MEMCPY2D {
156        src_memory_type: CUmemorytype::DEVICE,
157        src_device: src.ptr,
158        src_pitch: src.pitch_bytes,
159        dst_memory_type: CUmemorytype::HOST,
160        dst_host: dst.as_mut_ptr() as *mut c_void,
161        dst_pitch: dst_host_pitch_bytes,
162        width_in_bytes: width_elems * size_of::<T>(),
163        height,
164        ..Default::default()
165    };
166    check(unsafe { cu(&p) })
167}
168
169/// Asynchronous variant of [`copy_h_to_d_2d`] — issues on the given stream.
170pub fn copy_h_to_d_2d_async<T: DeviceRepr>(
171    src: &[T],
172    src_host_pitch_bytes: usize,
173    dst: &PitchedBuffer<T>,
174    width_elems: usize,
175    height: usize,
176    stream: &Stream,
177) -> Result<()> {
178    let d = driver()?;
179    let cu = d.cu_memcpy_2d_async()?;
180    let p = CUDA_MEMCPY2D {
181        src_memory_type: CUmemorytype::HOST,
182        src_host: src.as_ptr() as *const c_void,
183        src_pitch: src_host_pitch_bytes,
184        dst_memory_type: CUmemorytype::DEVICE,
185        dst_device: dst.ptr,
186        dst_pitch: dst.pitch_bytes,
187        width_in_bytes: width_elems * size_of::<T>(),
188        height,
189        ..Default::default()
190    };
191    check(unsafe { cu(&p, stream.as_raw()) })
192}