use oxicuda_driver::error::{CudaError, CudaResult};
use crate::device_buffer::DeviceBuffer;
/// Geometry of a pitched two-dimensional copy.
///
/// `width` and both pitches are summed by the `*_byte_extent` helpers, so
/// they are expected to be byte counts; `height` is the number of rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Memcpy2DParams {
    /// Distance from the start of one source row to the start of the next.
    pub src_pitch: usize,
    /// Distance from the start of one destination row to the start of the next.
    pub dst_pitch: usize,
    /// Length of the region taken from each row.
    pub width: usize,
    /// Number of rows in the region.
    pub height: usize,
}

impl Memcpy2DParams {
    /// Packs the four values into a parameter set.
    ///
    /// Nothing is checked here; call [`Memcpy2DParams::validate`] for that.
    pub fn new(src_pitch: usize, dst_pitch: usize, width: usize, height: usize) -> Self {
        Memcpy2DParams {
            height,
            width,
            dst_pitch,
            src_pitch,
        }
    }

    /// Checks that the region is non-empty and that a row fits in both pitches.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::InvalidValue`] when `width` or `height` is zero,
    /// or when `width` is larger than `src_pitch` or `dst_pitch`.
    pub fn validate(&self) -> CudaResult<()> {
        let region_is_empty = self.width == 0 || self.height == 0;
        let row_exceeds_pitch = self.width > self.src_pitch || self.width > self.dst_pitch;

        // Every rejection maps to the same error, so the order of the
        // individual checks is not observable.
        if region_is_empty || row_exceeds_pitch {
            Err(CudaError::InvalidValue)
        } else {
            Ok(())
        }
    }

    /// Bytes spanned on the source side: `(height - 1) * src_pitch + width`.
    ///
    /// Returns `0` when `height` is zero and saturates at `usize::MAX`
    /// instead of overflowing.
    pub fn src_byte_extent(&self) -> usize {
        Self::span_bytes(self.src_pitch, self.width, self.height)
    }

    /// Bytes spanned on the destination side: `(height - 1) * dst_pitch + width`.
    ///
    /// Returns `0` when `height` is zero and saturates at `usize::MAX`
    /// instead of overflowing.
    pub fn dst_byte_extent(&self) -> usize {
        Self::span_bytes(self.dst_pitch, self.width, self.height)
    }

    /// Shared extent formula for one side of the copy.
    fn span_bytes(pitch: usize, width: usize, height: usize) -> usize {
        match height.checked_sub(1) {
            // Only the rows before the last one contribute a full pitch;
            // the last row contributes just `width`.
            Some(full_rows) => full_rows.saturating_mul(pitch).saturating_add(width),
            None => 0,
        }
    }
}

/// Renders as `2D[<width>x<height>, src_pitch=<n>, dst_pitch=<n>]`.
impl std::fmt::Display for Memcpy2DParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            src_pitch,
            dst_pitch,
            width,
            height,
        } = self;
        f.write_fmt(format_args!(
            "2D[{}x{}, src_pitch={}, dst_pitch={}]",
            width, height, src_pitch, dst_pitch,
        ))
    }
}
/// Geometry of a pitched three-dimensional copy.
///
/// `width` and both pitches are summed by the `*_byte_extent` helpers, so
/// they are expected to be byte counts; `height`, `src_height` and
/// `dst_height` count rows and `depth` counts slices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Memcpy3DParams {
    /// Distance from the start of one source row to the start of the next.
    pub src_pitch: usize,
    /// Distance from the start of one destination row to the start of the next.
    pub dst_pitch: usize,
    /// Length of the region taken from each row.
    pub width: usize,
    /// Number of rows copied per slice.
    pub height: usize,
    /// Number of slices copied.
    pub depth: usize,
    /// Rows per source slice; multiplied by `src_pitch` to get the slice stride.
    pub src_height: usize,
    /// Rows per destination slice; multiplied by `dst_pitch` to get the slice stride.
    pub dst_height: usize,
}

impl Memcpy3DParams {
    /// Packs the seven values into a parameter set.
    ///
    /// Nothing is checked here; call [`Memcpy3DParams::validate`] for that.
    // One argument per public field; the constructor mirrors the struct.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        src_pitch: usize,
        dst_pitch: usize,
        width: usize,
        height: usize,
        depth: usize,
        src_height: usize,
        dst_height: usize,
    ) -> Self {
        Memcpy3DParams {
            dst_height,
            src_height,
            depth,
            height,
            width,
            dst_pitch,
            src_pitch,
        }
    }

    /// Checks that the region is non-empty and fits the declared row and
    /// slice layout on both sides.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::InvalidValue`] when any of `width`, `height` or
    /// `depth` is zero, when `width` is larger than either pitch, or when
    /// `height` is larger than `src_height` or `dst_height`.
    pub fn validate(&self) -> CudaResult<()> {
        let region_is_empty = self.width == 0 || self.height == 0 || self.depth == 0;
        let row_exceeds_pitch = self.width > self.src_pitch || self.width > self.dst_pitch;
        let rows_exceed_slice = self.height > self.src_height || self.height > self.dst_height;

        // Every rejection maps to the same error, so the order of the
        // individual checks is not observable.
        if region_is_empty || row_exceeds_pitch || rows_exceed_slice {
            Err(CudaError::InvalidValue)
        } else {
            Ok(())
        }
    }

    /// Distance between consecutive source slices: `src_pitch * src_height`,
    /// saturating at `usize::MAX`.
    pub fn src_slice_stride(&self) -> usize {
        usize::saturating_mul(self.src_pitch, self.src_height)
    }

    /// Distance between consecutive destination slices:
    /// `dst_pitch * dst_height`, saturating at `usize::MAX`.
    pub fn dst_slice_stride(&self) -> usize {
        usize::saturating_mul(self.dst_pitch, self.dst_height)
    }

    /// Bytes spanned on the source side:
    /// `(depth - 1) * src_slice_stride + (height - 1) * src_pitch + width`.
    ///
    /// Returns `0` when `depth` or `height` is zero and saturates at
    /// `usize::MAX` instead of overflowing.
    pub fn src_byte_extent(&self) -> usize {
        self.span_bytes(self.src_pitch, self.src_slice_stride())
    }

    /// Bytes spanned on the destination side:
    /// `(depth - 1) * dst_slice_stride + (height - 1) * dst_pitch + width`.
    ///
    /// Returns `0` when `depth` or `height` is zero and saturates at
    /// `usize::MAX` instead of overflowing.
    pub fn dst_byte_extent(&self) -> usize {
        self.span_bytes(self.dst_pitch, self.dst_slice_stride())
    }

    /// Shared extent formula for one side of the copy.
    fn span_bytes(&self, pitch: usize, slice_stride: usize) -> usize {
        match (self.depth.checked_sub(1), self.height.checked_sub(1)) {
            (Some(full_slices), Some(full_rows)) => {
                // The final slice only reaches to the end of its last copied row.
                let last_slice = full_rows.saturating_mul(pitch).saturating_add(self.width);
                full_slices
                    .saturating_mul(slice_stride)
                    .saturating_add(last_slice)
            }
            _ => 0,
        }
    }
}

/// Renders as
/// `3D[<width>x<height>x<depth>, src_pitch=<n>, dst_pitch=<n>, src_h=<n>, dst_h=<n>]`.
impl std::fmt::Display for Memcpy3DParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            src_pitch,
            dst_pitch,
            width,
            height,
            depth,
            src_height,
            dst_height,
        } = self;
        f.write_fmt(format_args!(
            "3D[{}x{}x{}, src_pitch={}, dst_pitch={}, src_h={}, dst_h={}]",
            width, height, depth, src_pitch, dst_pitch, src_height, dst_height,
        ))
    }
}
/// Confirms that `buf` holds at least `byte_extent` bytes.
///
/// # Errors
///
/// Returns [`CudaError::InvalidValue`] when `buf.byte_size()` is smaller
/// than `byte_extent`.
fn validate_2d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
    let capacity = buf.byte_size();
    if capacity >= byte_extent {
        Ok(())
    } else {
        Err(CudaError::InvalidValue)
    }
}
/// Confirms that the host `slice` covers at least `byte_extent` bytes.
///
/// # Errors
///
/// Returns [`CudaError::InvalidValue`] when the slice's total size in bytes
/// is smaller than `byte_extent`.
fn validate_2d_slice_size<T: Copy>(slice: &[T], byte_extent: usize) -> CudaResult<()> {
    // Total bytes occupied by the slice, i.e. `len * size_of::<T>()`.
    let available = std::mem::size_of_val(slice);
    if available >= byte_extent {
        Ok(())
    } else {
        Err(CudaError::InvalidValue)
    }
}
/// Validated entry point for a 2D copy from one device buffer to another.
///
/// Checks `params`, then checks that `src` and `dst` are large enough for
/// the source and destination byte extents, then obtains the driver handle.
///
/// NOTE(review): as written, nothing is transferred after the driver handle
/// is obtained; the function returns `Ok(())` once the checks pass.
///
/// # Errors
///
/// Propagates the error from `params.validate()`, from either buffer size
/// check, or from `oxicuda_driver::loader::try_driver()`.
pub fn copy_2d_dtod<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    src: &DeviceBuffer<T>,
    params: &Memcpy2DParams,
) -> CudaResult<()> {
    params.validate()?;

    let src_bytes = params.src_byte_extent();
    let dst_bytes = params.dst_byte_extent();
    validate_2d_buffer_size(src, src_bytes)?;
    validate_2d_buffer_size(dst, dst_bytes)?;

    let _driver = oxicuda_driver::loader::try_driver()?;
    Ok(())
}
/// Validated entry point for a 2D copy from a host slice to a device buffer.
///
/// Checks `params`, then checks that the host slice `src` and the device
/// buffer `dst` are large enough for the source and destination byte
/// extents, then obtains the driver handle.
///
/// NOTE(review): as written, nothing is transferred after the driver handle
/// is obtained; the function returns `Ok(())` once the checks pass.
///
/// # Errors
///
/// Propagates the error from `params.validate()`, from either size check,
/// or from `oxicuda_driver::loader::try_driver()`.
pub fn copy_2d_htod<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    src: &[T],
    params: &Memcpy2DParams,
) -> CudaResult<()> {
    params.validate()?;

    let src_bytes = params.src_byte_extent();
    let dst_bytes = params.dst_byte_extent();
    validate_2d_slice_size(src, src_bytes)?;
    validate_2d_buffer_size(dst, dst_bytes)?;

    let _driver = oxicuda_driver::loader::try_driver()?;
    Ok(())
}
/// Validated entry point for a 2D copy from a device buffer to a host slice.
///
/// Checks `params`, then checks that the device buffer `src` and the host
/// slice `dst` are large enough for the source and destination byte
/// extents, then obtains the driver handle.
///
/// NOTE(review): as written, nothing is transferred after the driver handle
/// is obtained; the function returns `Ok(())` once the checks pass.
///
/// # Errors
///
/// Propagates the error from `params.validate()`, from either size check,
/// or from `oxicuda_driver::loader::try_driver()`.
pub fn copy_2d_dtoh<T: Copy>(
    dst: &mut [T],
    src: &DeviceBuffer<T>,
    params: &Memcpy2DParams,
) -> CudaResult<()> {
    params.validate()?;

    let src_bytes = params.src_byte_extent();
    let dst_bytes = params.dst_byte_extent();
    validate_2d_buffer_size(src, src_bytes)?;
    validate_2d_slice_size(dst, dst_bytes)?;

    let _driver = oxicuda_driver::loader::try_driver()?;
    Ok(())
}
/// Confirms that `buf` holds at least `byte_extent` bytes.
///
/// # Errors
///
/// Returns [`CudaError::InvalidValue`] when `buf.byte_size()` is smaller
/// than `byte_extent`.
fn validate_3d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
    let large_enough = buf.byte_size() >= byte_extent;
    if !large_enough {
        Err(CudaError::InvalidValue)
    } else {
        Ok(())
    }
}
/// Validated entry point for a 3D copy from one device buffer to another.
///
/// Checks `params`, then checks that `src` and `dst` are large enough for
/// the source and destination byte extents, then obtains the driver handle.
///
/// NOTE(review): as written, nothing is transferred after the driver handle
/// is obtained; the function returns `Ok(())` once the checks pass.
///
/// # Errors
///
/// Propagates the error from `params.validate()`, from either buffer size
/// check, or from `oxicuda_driver::loader::try_driver()`.
pub fn copy_3d_dtod<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    src: &DeviceBuffer<T>,
    params: &Memcpy3DParams,
) -> CudaResult<()> {
    params.validate()?;

    let src_bytes = params.src_byte_extent();
    let dst_bytes = params.dst_byte_extent();
    validate_3d_buffer_size(src, src_bytes)?;
    validate_3d_buffer_size(dst, dst_bytes)?;

    let _driver = oxicuda_driver::loader::try_driver()?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `outcome` is `Err(CudaError::InvalidValue)`.
    #[track_caller]
    fn assert_invalid_value(outcome: CudaResult<()>) {
        assert_eq!(outcome, Err(CudaError::InvalidValue));
    }

    // ---- Memcpy2DParams ----

    #[test]
    fn params_2d_new() {
        let Memcpy2DParams {
            src_pitch,
            dst_pitch,
            width,
            height,
        } = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!((src_pitch, dst_pitch), (512, 512));
        assert_eq!((width, height), (480, 256));
    }

    #[test]
    fn params_2d_validate_ok() {
        let outcome = Memcpy2DParams::new(512, 512, 480, 256).validate();
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn params_2d_validate_zero_width() {
        assert_invalid_value(Memcpy2DParams::new(512, 512, 0, 256).validate());
    }

    #[test]
    fn params_2d_validate_zero_height() {
        assert_invalid_value(Memcpy2DParams::new(512, 512, 480, 0).validate());
    }

    #[test]
    fn params_2d_validate_width_exceeds_src_pitch() {
        assert_invalid_value(Memcpy2DParams::new(256, 512, 480, 100).validate());
    }

    #[test]
    fn params_2d_validate_width_exceeds_dst_pitch() {
        assert_invalid_value(Memcpy2DParams::new(512, 256, 480, 100).validate());
    }

    #[test]
    fn params_2d_byte_extent() {
        // Three rows: two full pitches plus the final row's width.
        let params = Memcpy2DParams::new(512, 256, 480, 3);
        assert_eq!(params.src_byte_extent(), (3 - 1) * 512 + 480);
        assert_eq!(params.dst_byte_extent(), (3 - 1) * 256 + 480);
    }

    #[test]
    fn params_2d_byte_extent_single_row() {
        let params = Memcpy2DParams::new(512, 512, 480, 1);
        assert_eq!(
            (params.src_byte_extent(), params.dst_byte_extent()),
            (480, 480)
        );
    }

    #[test]
    fn params_2d_byte_extent_zero_height() {
        let params = Memcpy2DParams::new(512, 512, 480, 0);
        assert_eq!(
            (params.src_byte_extent(), params.dst_byte_extent()),
            (0, 0)
        );
    }

    #[test]
    fn params_2d_display() {
        let rendered = Memcpy2DParams::new(512, 256, 480, 100).to_string();
        for fragment in ["480x100", "src_pitch=512", "dst_pitch=256"].iter() {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn params_2d_eq() {
        let first = Memcpy2DParams::new(512, 512, 480, 256);
        let second = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!(first, second);
    }

    // ---- Memcpy3DParams ----

    #[test]
    fn params_3d_new() {
        let params = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256);
        assert_eq!(
            (params.depth, params.src_height, params.dst_height),
            (10, 256, 256)
        );
    }

    #[test]
    fn params_3d_validate_ok() {
        let outcome = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256).validate();
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn params_3d_validate_zero_depth() {
        assert_invalid_value(Memcpy3DParams::new(512, 512, 480, 256, 0, 256, 256).validate());
    }

    #[test]
    fn params_3d_validate_height_exceeds_src_height() {
        assert_invalid_value(Memcpy3DParams::new(512, 512, 480, 300, 10, 256, 300).validate());
    }

    #[test]
    fn params_3d_validate_height_exceeds_dst_height() {
        assert_invalid_value(Memcpy3DParams::new(512, 512, 480, 300, 10, 300, 256).validate());
    }

    #[test]
    fn params_3d_slice_stride() {
        let params = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        assert_eq!(
            (params.src_slice_stride(), params.dst_slice_stride()),
            (512 * 128, 256 * 128)
        );
    }

    #[test]
    fn params_3d_byte_extent() {
        // Two slices: one full slice stride plus the partial last slice.
        let params = Memcpy3DParams::new(512, 512, 480, 3, 2, 4, 4);
        let full_slice = 512 * 4;
        let last_slice = 2 * 512 + 480;
        assert_eq!(params.src_byte_extent(), full_slice + last_slice);
    }

    #[test]
    fn params_3d_byte_extent_single_slice() {
        let params = Memcpy3DParams::new(512, 512, 480, 3, 1, 4, 4);
        assert_eq!(params.src_byte_extent(), 2 * 512 + 480);
    }

    #[test]
    fn params_3d_display() {
        let rendered = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128).to_string();
        assert!(rendered.contains("480x100x10"));
    }

    // ---- Entry-point signatures (compile-time checks only) ----

    #[test]
    fn copy_2d_dtod_signature_compiles() {
        type Expected =
            fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()>;
        let _entry: Expected = copy_2d_dtod;
    }

    #[test]
    fn copy_2d_htod_signature_compiles() {
        type Expected = fn(&mut DeviceBuffer<f32>, &[f32], &Memcpy2DParams) -> CudaResult<()>;
        let _entry: Expected = copy_2d_htod;
    }

    #[test]
    fn copy_2d_dtoh_signature_compiles() {
        type Expected = fn(&mut [f32], &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()>;
        let _entry: Expected = copy_2d_dtoh;
    }

    #[test]
    fn copy_3d_dtod_signature_compiles() {
        type Expected =
            fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy3DParams) -> CudaResult<()>;
        let _entry: Expected = copy_3d_dtod;
    }

    // ---- Boundary: width equal to pitch ----

    #[test]
    fn params_2d_equal_pitch() {
        let params = Memcpy2DParams::new(100, 100, 100, 50);
        assert_eq!(params.validate(), Ok(()));
        let expected = 49 * 100 + 100;
        assert_eq!(params.src_byte_extent(), expected);
        assert_eq!(params.dst_byte_extent(), expected);
    }
}