1use oxicuda_driver::error::{CudaError, CudaResult, check};
52use oxicuda_driver::ffi::{CUDA_MEMCPY2D, CUmemorytype};
53
54use crate::device_buffer::DeviceBuffer;
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub struct Memcpy2DParams {
71 pub src_pitch: usize,
73 pub dst_pitch: usize,
75 pub width: usize,
77 pub height: usize,
79}
80
81impl Memcpy2DParams {
82 pub fn new(src_pitch: usize, dst_pitch: usize, width: usize, height: usize) -> Self {
91 Self {
92 src_pitch,
93 dst_pitch,
94 width,
95 height,
96 }
97 }
98
99 pub fn validate(&self) -> CudaResult<()> {
107 if self.width == 0 || self.height == 0 {
108 return Err(CudaError::InvalidValue);
109 }
110 if self.width > self.src_pitch {
111 return Err(CudaError::InvalidValue);
112 }
113 if self.width > self.dst_pitch {
114 return Err(CudaError::InvalidValue);
115 }
116 Ok(())
117 }
118
119 pub fn src_byte_extent(&self) -> usize {
124 if self.height == 0 {
125 return 0;
126 }
127 self.height
128 .saturating_sub(1)
129 .saturating_mul(self.src_pitch)
130 .saturating_add(self.width)
131 }
132
133 pub fn dst_byte_extent(&self) -> usize {
135 if self.height == 0 {
136 return 0;
137 }
138 self.height
139 .saturating_sub(1)
140 .saturating_mul(self.dst_pitch)
141 .saturating_add(self.width)
142 }
143}
144
145impl std::fmt::Display for Memcpy2DParams {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 write!(
148 f,
149 "2D[{}x{}, src_pitch={}, dst_pitch={}]",
150 self.width, self.height, self.src_pitch, self.dst_pitch,
151 )
152 }
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Eq)]
166pub struct Memcpy3DParams {
167 pub src_pitch: usize,
169 pub dst_pitch: usize,
171 pub width: usize,
173 pub height: usize,
175 pub depth: usize,
177 pub src_height: usize,
180 pub dst_height: usize,
182}
183
184impl Memcpy3DParams {
185 #[allow(clippy::too_many_arguments)]
187 pub fn new(
188 src_pitch: usize,
189 dst_pitch: usize,
190 width: usize,
191 height: usize,
192 depth: usize,
193 src_height: usize,
194 dst_height: usize,
195 ) -> Self {
196 Self {
197 src_pitch,
198 dst_pitch,
199 width,
200 height,
201 depth,
202 src_height,
203 dst_height,
204 }
205 }
206
207 pub fn validate(&self) -> CudaResult<()> {
216 if self.width == 0 || self.height == 0 || self.depth == 0 {
217 return Err(CudaError::InvalidValue);
218 }
219 if self.width > self.src_pitch {
220 return Err(CudaError::InvalidValue);
221 }
222 if self.width > self.dst_pitch {
223 return Err(CudaError::InvalidValue);
224 }
225 if self.height > self.src_height {
226 return Err(CudaError::InvalidValue);
227 }
228 if self.height > self.dst_height {
229 return Err(CudaError::InvalidValue);
230 }
231 Ok(())
232 }
233
234 pub fn src_slice_stride(&self) -> usize {
236 self.src_pitch.saturating_mul(self.src_height)
237 }
238
239 pub fn dst_slice_stride(&self) -> usize {
241 self.dst_pitch.saturating_mul(self.dst_height)
242 }
243
244 pub fn src_byte_extent(&self) -> usize {
246 if self.depth == 0 || self.height == 0 {
247 return 0;
248 }
249 let slice_stride = self.src_slice_stride();
250 self.depth
251 .saturating_sub(1)
252 .saturating_mul(slice_stride)
253 .saturating_add(
254 self.height
255 .saturating_sub(1)
256 .saturating_mul(self.src_pitch)
257 .saturating_add(self.width),
258 )
259 }
260
261 pub fn dst_byte_extent(&self) -> usize {
263 if self.depth == 0 || self.height == 0 {
264 return 0;
265 }
266 let slice_stride = self.dst_slice_stride();
267 self.depth
268 .saturating_sub(1)
269 .saturating_mul(slice_stride)
270 .saturating_add(
271 self.height
272 .saturating_sub(1)
273 .saturating_mul(self.dst_pitch)
274 .saturating_add(self.width),
275 )
276 }
277}
278
279impl std::fmt::Display for Memcpy3DParams {
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 write!(
282 f,
283 "3D[{}x{}x{}, src_pitch={}, dst_pitch={}, src_h={}, dst_h={}]",
284 self.width,
285 self.height,
286 self.depth,
287 self.src_pitch,
288 self.dst_pitch,
289 self.src_height,
290 self.dst_height,
291 )
292 }
293}
294
295fn validate_2d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
301 if buf.byte_size() < byte_extent {
302 return Err(CudaError::InvalidValue);
303 }
304 Ok(())
305}
306
307fn validate_2d_slice_size<T: Copy>(slice: &[T], byte_extent: usize) -> CudaResult<()> {
309 let slice_bytes = slice.len().saturating_mul(std::mem::size_of::<T>());
310 if slice_bytes < byte_extent {
311 return Err(CudaError::InvalidValue);
312 }
313 Ok(())
314}
315
316pub fn copy_2d_dtod<T: Copy>(
328 dst: &mut DeviceBuffer<T>,
329 src: &DeviceBuffer<T>,
330 params: &Memcpy2DParams,
331) -> CudaResult<()> {
332 params.validate()?;
333 validate_2d_buffer_size(src, params.src_byte_extent())?;
334 validate_2d_buffer_size(dst, params.dst_byte_extent())?;
335
336 let api = oxicuda_driver::loader::try_driver()?;
337 let f = api.cu_memcpy_2d.ok_or(CudaError::NotSupported)?;
338
339 let m = CUDA_MEMCPY2D {
340 src_memory_type: CUmemorytype::Device as u32,
341 src_device: src.as_device_ptr(),
342 src_pitch: params.src_pitch,
343 dst_memory_type: CUmemorytype::Device as u32,
344 dst_device: dst.as_device_ptr(),
345 dst_pitch: params.dst_pitch,
346 width_in_bytes: params.width,
347 height: params.height,
348 ..CUDA_MEMCPY2D::default()
349 };
350
351 check(unsafe { f(&m) })
352}
353
354pub fn copy_2d_htod<T: Copy>(
366 dst: &mut DeviceBuffer<T>,
367 src: &[T],
368 params: &Memcpy2DParams,
369) -> CudaResult<()> {
370 params.validate()?;
371 validate_2d_slice_size(src, params.src_byte_extent())?;
372 validate_2d_buffer_size(dst, params.dst_byte_extent())?;
373
374 let api = oxicuda_driver::loader::try_driver()?;
375 let f = api.cu_memcpy_2d.ok_or(CudaError::NotSupported)?;
376
377 let m = CUDA_MEMCPY2D {
378 src_memory_type: CUmemorytype::Host as u32,
379 src_host: src.as_ptr().cast::<std::ffi::c_void>(),
380 src_pitch: params.src_pitch,
381 dst_memory_type: CUmemorytype::Device as u32,
382 dst_device: dst.as_device_ptr(),
383 dst_pitch: params.dst_pitch,
384 width_in_bytes: params.width,
385 height: params.height,
386 ..CUDA_MEMCPY2D::default()
387 };
388
389 check(unsafe { f(&m) })
390}
391
392pub fn copy_2d_dtoh<T: Copy>(
404 dst: &mut [T],
405 src: &DeviceBuffer<T>,
406 params: &Memcpy2DParams,
407) -> CudaResult<()> {
408 params.validate()?;
409 validate_2d_buffer_size(src, params.src_byte_extent())?;
410 validate_2d_slice_size(dst, params.dst_byte_extent())?;
411
412 let api = oxicuda_driver::loader::try_driver()?;
413 let f = api.cu_memcpy_2d.ok_or(CudaError::NotSupported)?;
414
415 let m = CUDA_MEMCPY2D {
416 src_memory_type: CUmemorytype::Device as u32,
417 src_device: src.as_device_ptr(),
418 src_pitch: params.src_pitch,
419 dst_memory_type: CUmemorytype::Host as u32,
420 dst_host: dst.as_mut_ptr().cast::<std::ffi::c_void>(),
421 dst_pitch: params.dst_pitch,
422 width_in_bytes: params.width,
423 height: params.height,
424 ..CUDA_MEMCPY2D::default()
425 };
426
427 check(unsafe { f(&m) })
428}
429
430fn validate_3d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
436 if buf.byte_size() < byte_extent {
437 return Err(CudaError::InvalidValue);
438 }
439 Ok(())
440}
441
442pub fn copy_3d_dtod<T: Copy>(
450 dst: &mut DeviceBuffer<T>,
451 src: &DeviceBuffer<T>,
452 params: &Memcpy3DParams,
453) -> CudaResult<()> {
454 params.validate()?;
455 validate_3d_buffer_size(src, params.src_byte_extent())?;
456 validate_3d_buffer_size(dst, params.dst_byte_extent())?;
457
458 let _api = oxicuda_driver::loader::try_driver()?;
459 Ok(())
460}
461
/// Unit tests for the parameter types and host-side validation helpers.
///
/// Nothing here touches a GPU: the copy functions are only checked for their
/// signatures.
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Memcpy2DParams -----

    #[test]
    fn params_2d_new() {
        let p = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!(p.src_pitch, 512);
        assert_eq!(p.dst_pitch, 512);
        assert_eq!(p.width, 480);
        assert_eq!(p.height, 256);
    }

    #[test]
    fn params_2d_validate_ok() {
        let p = Memcpy2DParams::new(512, 512, 480, 256);
        assert!(p.validate().is_ok());
    }

    #[test]
    fn params_2d_validate_zero_width() {
        let p = Memcpy2DParams::new(512, 512, 0, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_zero_height() {
        let p = Memcpy2DParams::new(512, 512, 480, 0);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_width_exceeds_src_pitch() {
        let p = Memcpy2DParams::new(256, 512, 480, 100);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_width_exceeds_dst_pitch() {
        let p = Memcpy2DParams::new(512, 256, 480, 100);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_byte_extent() {
        let p = Memcpy2DParams::new(512, 256, 480, 3);
        // Two full rows at the pitch plus `width` bytes of the last row.
        assert_eq!(p.src_byte_extent(), 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), 2 * 256 + 480);
    }

    #[test]
    fn params_2d_byte_extent_single_row() {
        let p = Memcpy2DParams::new(512, 512, 480, 1);
        assert_eq!(p.src_byte_extent(), 480);
        assert_eq!(p.dst_byte_extent(), 480);
    }

    #[test]
    fn params_2d_byte_extent_zero_height() {
        let p = Memcpy2DParams::new(512, 512, 480, 0);
        assert_eq!(p.src_byte_extent(), 0);
        assert_eq!(p.dst_byte_extent(), 0);
    }

    #[test]
    fn params_2d_byte_extent_saturates_on_overflow() {
        // `2 * usize::MAX` cannot be represented; the extent must clamp
        // instead of wrapping to a small value.
        let p = Memcpy2DParams::new(usize::MAX, 1, 1, 3);
        assert_eq!(p.src_byte_extent(), usize::MAX);
        assert_eq!(p.dst_byte_extent(), 2 + 1);
    }

    #[test]
    fn params_2d_display() {
        let p = Memcpy2DParams::new(512, 256, 480, 100);
        let disp = format!("{p}");
        assert!(disp.contains("480x100"));
        assert!(disp.contains("src_pitch=512"));
        assert!(disp.contains("dst_pitch=256"));
    }

    #[test]
    fn params_2d_eq() {
        let a = Memcpy2DParams::new(512, 512, 480, 256);
        let b = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!(a, b);
    }

    // ----- Memcpy3DParams -----

    #[test]
    fn params_3d_new() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256);
        assert_eq!(p.depth, 10);
        assert_eq!(p.src_height, 256);
        assert_eq!(p.dst_height, 256);
    }

    #[test]
    fn params_3d_validate_ok() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256);
        assert!(p.validate().is_ok());
    }

    #[test]
    fn params_3d_validate_zero_depth() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 0, 256, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_zero_width_or_height() {
        let zero_width = Memcpy3DParams::new(512, 512, 0, 256, 10, 256, 256);
        assert_eq!(zero_width.validate(), Err(CudaError::InvalidValue));

        let zero_height = Memcpy3DParams::new(512, 512, 480, 0, 10, 256, 256);
        assert_eq!(zero_height.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_width_exceeds_pitch() {
        let src_too_narrow = Memcpy3DParams::new(256, 512, 480, 3, 2, 4, 4);
        assert_eq!(src_too_narrow.validate(), Err(CudaError::InvalidValue));

        let dst_too_narrow = Memcpy3DParams::new(512, 256, 480, 3, 2, 4, 4);
        assert_eq!(dst_too_narrow.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_height_exceeds_src_height() {
        let p = Memcpy3DParams::new(512, 512, 480, 300, 10, 256, 300);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_height_exceeds_dst_height() {
        let p = Memcpy3DParams::new(512, 512, 480, 300, 10, 300, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_slice_stride() {
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        assert_eq!(p.src_slice_stride(), 512 * 128);
        assert_eq!(p.dst_slice_stride(), 256 * 128);
    }

    #[test]
    fn params_3d_byte_extent() {
        let p = Memcpy3DParams::new(512, 512, 480, 3, 2, 4, 4);
        // One full slice stride, then the 2D extent inside the last slice.
        assert_eq!(p.src_byte_extent(), (512 * 4) + 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), (512 * 4) + 2 * 512 + 480);
    }

    #[test]
    fn params_3d_byte_extent_asymmetric() {
        // Different pitch and slice height per side, so a mix-up between the
        // source and destination fields would change the result.
        let p = Memcpy3DParams::new(512, 256, 200, 3, 2, 4, 5);
        assert!(p.validate().is_ok());
        assert_eq!(p.src_byte_extent(), (512 * 4) + 2 * 512 + 200);
        assert_eq!(p.dst_byte_extent(), (256 * 5) + 2 * 256 + 200);
    }

    #[test]
    fn params_3d_byte_extent_single_slice() {
        let p = Memcpy3DParams::new(512, 512, 480, 3, 1, 4, 4);
        assert_eq!(p.src_byte_extent(), 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), 2 * 512 + 480);
    }

    #[test]
    fn params_3d_byte_extent_zero_depth_or_height() {
        let no_depth = Memcpy3DParams::new(512, 512, 480, 3, 0, 4, 4);
        assert_eq!(no_depth.src_byte_extent(), 0);
        assert_eq!(no_depth.dst_byte_extent(), 0);

        let no_rows = Memcpy3DParams::new(512, 512, 480, 0, 2, 4, 4);
        assert_eq!(no_rows.src_byte_extent(), 0);
        assert_eq!(no_rows.dst_byte_extent(), 0);
    }

    #[test]
    fn params_3d_display() {
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        let disp = format!("{p}");
        assert!(disp.contains("480x100x10"));
    }

    // ----- Host-side size validation -----

    #[test]
    fn slice_size_validation() {
        let words = [0u32; 4];
        assert!(validate_2d_slice_size(&words, 16).is_ok());
        assert_eq!(
            validate_2d_slice_size(&words, 17),
            Err(CudaError::InvalidValue)
        );

        let empty: [u8; 0] = [];
        assert!(validate_2d_slice_size(&empty, 0).is_ok());
        assert_eq!(
            validate_2d_slice_size(&empty, 1),
            Err(CudaError::InvalidValue)
        );
    }

    // ----- Copy function signatures (no GPU required) -----

    #[test]
    fn copy_2d_dtod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()> =
            copy_2d_dtod;
    }

    #[test]
    fn copy_2d_htod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &[f32], &Memcpy2DParams) -> CudaResult<()> = copy_2d_htod;
    }

    #[test]
    fn copy_2d_dtoh_signature_compiles() {
        let _: fn(&mut [f32], &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()> = copy_2d_dtoh;
    }

    #[test]
    fn copy_3d_dtod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy3DParams) -> CudaResult<()> =
            copy_3d_dtod;
    }

    #[test]
    fn params_2d_equal_pitch() {
        // `width == pitch` is the tightest geometry that still validates.
        let p = Memcpy2DParams::new(100, 100, 100, 50);
        assert!(p.validate().is_ok());
        assert_eq!(p.src_byte_extent(), 49 * 100 + 100);
        assert_eq!(p.dst_byte_extent(), 49 * 100 + 100);
    }
}