baracuda_driver/
memcpy2d.rs1use core::ffi::c_void;
9use core::mem::size_of;
10
11use baracuda_cuda_sys::driver;
12use baracuda_cuda_sys::types::{CUmemorytype, CUDA_MEMCPY2D};
13use baracuda_cuda_sys::CUdeviceptr;
14use baracuda_types::DeviceRepr;
15
16use crate::context::Context;
17use crate::error::{check, Result};
18use crate::stream::Stream;
19
20pub struct PitchedBuffer<T: DeviceRepr> {
24 ptr: CUdeviceptr,
25 pitch_bytes: usize,
26 width_elems: usize,
27 height: usize,
28 context: Context,
29 _marker: core::marker::PhantomData<T>,
30}
31
32unsafe impl<T: DeviceRepr + Send> Send for PitchedBuffer<T> {}
33
34impl<T: DeviceRepr> core::fmt::Debug for PitchedBuffer<T> {
35 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36 f.debug_struct("PitchedBuffer")
37 .field("ptr", &format_args!("{:#x}", self.ptr.0))
38 .field("width_elems", &self.width_elems)
39 .field("height", &self.height)
40 .field("pitch_bytes", &self.pitch_bytes)
41 .field("type", &core::any::type_name::<T>())
42 .finish()
43 }
44}
45
46impl<T: DeviceRepr> PitchedBuffer<T> {
47 pub fn new(context: &Context, width_elems: usize, height: usize) -> Result<Self> {
50 context.set_current()?;
51 let d = driver()?;
52 let cu = d.cu_mem_alloc_pitch()?;
53 let mut ptr = CUdeviceptr(0);
54 let mut pitch: usize = 0;
55 let width_bytes = width_elems
56 .checked_mul(size_of::<T>())
57 .expect("overflow in 2D allocation width");
58 check(unsafe {
59 cu(
60 &mut ptr,
61 &mut pitch,
62 width_bytes,
63 height,
64 size_of::<T>() as core::ffi::c_uint,
65 )
66 })?;
67 Ok(Self {
68 ptr,
69 pitch_bytes: pitch,
70 width_elems,
71 height,
72 context: context.clone(),
73 _marker: core::marker::PhantomData,
74 })
75 }
76
77 #[inline]
78 pub fn width_elems(&self) -> usize {
79 self.width_elems
80 }
81 #[inline]
82 pub fn height(&self) -> usize {
83 self.height
84 }
85 #[inline]
87 pub fn pitch_bytes(&self) -> usize {
88 self.pitch_bytes
89 }
90 #[inline]
91 pub fn as_raw(&self) -> CUdeviceptr {
92 self.ptr
93 }
94 #[inline]
95 pub fn context(&self) -> &Context {
96 &self.context
97 }
98}
99
100impl<T: DeviceRepr> Drop for PitchedBuffer<T> {
101 fn drop(&mut self) {
102 if self.ptr.0 == 0 {
103 return;
104 }
105 if let Ok(d) = driver() {
106 if let Ok(cu) = d.cu_mem_free() {
107 let _ = unsafe { cu(self.ptr) };
108 }
109 }
110 }
111}
112
113pub fn copy_h_to_d_2d<T: DeviceRepr>(
119 src: &[T],
120 src_host_pitch_bytes: usize,
121 dst: &PitchedBuffer<T>,
122 width_elems: usize,
123 height: usize,
124) -> Result<()> {
125 assert!(width_elems <= dst.width_elems);
126 assert!(height <= dst.height);
127 let d = driver()?;
128 let cu = d.cu_memcpy_2d()?;
129 let p = CUDA_MEMCPY2D {
130 src_memory_type: CUmemorytype::HOST,
131 src_host: src.as_ptr() as *const c_void,
132 src_pitch: src_host_pitch_bytes,
133 dst_memory_type: CUmemorytype::DEVICE,
134 dst_device: dst.ptr,
135 dst_pitch: dst.pitch_bytes,
136 width_in_bytes: width_elems * size_of::<T>(),
137 height,
138 ..Default::default()
139 };
140 check(unsafe { cu(&p) })
141}
142
143pub fn copy_d_to_h_2d<T: DeviceRepr>(
145 src: &PitchedBuffer<T>,
146 dst: &mut [T],
147 dst_host_pitch_bytes: usize,
148 width_elems: usize,
149 height: usize,
150) -> Result<()> {
151 assert!(width_elems <= src.width_elems);
152 assert!(height <= src.height);
153 let d = driver()?;
154 let cu = d.cu_memcpy_2d()?;
155 let p = CUDA_MEMCPY2D {
156 src_memory_type: CUmemorytype::DEVICE,
157 src_device: src.ptr,
158 src_pitch: src.pitch_bytes,
159 dst_memory_type: CUmemorytype::HOST,
160 dst_host: dst.as_mut_ptr() as *mut c_void,
161 dst_pitch: dst_host_pitch_bytes,
162 width_in_bytes: width_elems * size_of::<T>(),
163 height,
164 ..Default::default()
165 };
166 check(unsafe { cu(&p) })
167}
168
169pub fn copy_h_to_d_2d_async<T: DeviceRepr>(
171 src: &[T],
172 src_host_pitch_bytes: usize,
173 dst: &PitchedBuffer<T>,
174 width_elems: usize,
175 height: usize,
176 stream: &Stream,
177) -> Result<()> {
178 let d = driver()?;
179 let cu = d.cu_memcpy_2d_async()?;
180 let p = CUDA_MEMCPY2D {
181 src_memory_type: CUmemorytype::HOST,
182 src_host: src.as_ptr() as *const c_void,
183 src_pitch: src_host_pitch_bytes,
184 dst_memory_type: CUmemorytype::DEVICE,
185 dst_device: dst.ptr,
186 dst_pitch: dst.pitch_bytes,
187 width_in_bytes: width_elems * size_of::<T>(),
188 height,
189 ..Default::default()
190 };
191 check(unsafe { cu(&p, stream.as_raw()) })
192}