1use oxicuda_driver::error::{CudaError, CudaResult};
50
51use crate::device_buffer::DeviceBuffer;
52
/// Geometry of a pitched 2D memory copy.
///
/// All quantities are measured in bytes: the copy helpers in this module
/// compare the derived extents against buffer/slice byte sizes. A copy moves
/// `height` rows of `width` bytes each; consecutive rows start `src_pitch`
/// bytes apart in the source and `dst_pitch` bytes apart in the destination.
///
/// Call [`Memcpy2DParams::validate`] before use to reject degenerate or
/// inconsistent geometry. `Hash` is derived alongside `Eq` so parameter sets
/// can be used as map/set keys (e.g. for caching per-geometry state).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Memcpy2DParams {
    /// Distance in bytes between the starts of consecutive source rows.
    pub src_pitch: usize,
    /// Distance in bytes between the starts of consecutive destination rows.
    pub dst_pitch: usize,
    /// Number of bytes copied from each row.
    pub width: usize,
    /// Number of rows copied.
    pub height: usize,
}
77
78impl Memcpy2DParams {
79 pub fn new(src_pitch: usize, dst_pitch: usize, width: usize, height: usize) -> Self {
88 Self {
89 src_pitch,
90 dst_pitch,
91 width,
92 height,
93 }
94 }
95
96 pub fn validate(&self) -> CudaResult<()> {
104 if self.width == 0 || self.height == 0 {
105 return Err(CudaError::InvalidValue);
106 }
107 if self.width > self.src_pitch {
108 return Err(CudaError::InvalidValue);
109 }
110 if self.width > self.dst_pitch {
111 return Err(CudaError::InvalidValue);
112 }
113 Ok(())
114 }
115
116 pub fn src_byte_extent(&self) -> usize {
121 if self.height == 0 {
122 return 0;
123 }
124 self.height
125 .saturating_sub(1)
126 .saturating_mul(self.src_pitch)
127 .saturating_add(self.width)
128 }
129
130 pub fn dst_byte_extent(&self) -> usize {
132 if self.height == 0 {
133 return 0;
134 }
135 self.height
136 .saturating_sub(1)
137 .saturating_mul(self.dst_pitch)
138 .saturating_add(self.width)
139 }
140}
141
142impl std::fmt::Display for Memcpy2DParams {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 write!(
145 f,
146 "2D[{}x{}, src_pitch={}, dst_pitch={}]",
147 self.width, self.height, self.src_pitch, self.dst_pitch,
148 )
149 }
150}
151
/// Geometry of a pitched 3D memory copy.
///
/// `*_pitch` and `width` are in bytes. `height`, `src_height` and
/// `dst_height` are in rows, and `depth` is in slices. The copy moves `depth`
/// slices of `height` rows, each row `width` bytes long. Within a slice, rows
/// are one pitch apart. Consecutive slices are one slice stride apart, where
/// the stride is `pitch * <allocated rows per slice>` (see
/// [`Memcpy3DParams::src_slice_stride`]).
///
/// `Hash` is derived alongside `Eq` so parameter sets can be used as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Memcpy3DParams {
    /// Distance in bytes between the starts of consecutive source rows.
    pub src_pitch: usize,
    /// Distance in bytes between the starts of consecutive destination rows.
    pub dst_pitch: usize,
    /// Number of bytes copied from each row.
    pub width: usize,
    /// Number of rows copied per slice.
    pub height: usize,
    /// Number of slices copied.
    pub depth: usize,
    /// Rows allocated per source slice; validation requires it to be at least `height`.
    pub src_height: usize,
    /// Rows allocated per destination slice; validation requires it to be at least `height`.
    pub dst_height: usize,
}
180
181impl Memcpy3DParams {
182 #[allow(clippy::too_many_arguments)]
184 pub fn new(
185 src_pitch: usize,
186 dst_pitch: usize,
187 width: usize,
188 height: usize,
189 depth: usize,
190 src_height: usize,
191 dst_height: usize,
192 ) -> Self {
193 Self {
194 src_pitch,
195 dst_pitch,
196 width,
197 height,
198 depth,
199 src_height,
200 dst_height,
201 }
202 }
203
204 pub fn validate(&self) -> CudaResult<()> {
213 if self.width == 0 || self.height == 0 || self.depth == 0 {
214 return Err(CudaError::InvalidValue);
215 }
216 if self.width > self.src_pitch {
217 return Err(CudaError::InvalidValue);
218 }
219 if self.width > self.dst_pitch {
220 return Err(CudaError::InvalidValue);
221 }
222 if self.height > self.src_height {
223 return Err(CudaError::InvalidValue);
224 }
225 if self.height > self.dst_height {
226 return Err(CudaError::InvalidValue);
227 }
228 Ok(())
229 }
230
231 pub fn src_slice_stride(&self) -> usize {
233 self.src_pitch.saturating_mul(self.src_height)
234 }
235
236 pub fn dst_slice_stride(&self) -> usize {
238 self.dst_pitch.saturating_mul(self.dst_height)
239 }
240
241 pub fn src_byte_extent(&self) -> usize {
243 if self.depth == 0 || self.height == 0 {
244 return 0;
245 }
246 let slice_stride = self.src_slice_stride();
247 self.depth
248 .saturating_sub(1)
249 .saturating_mul(slice_stride)
250 .saturating_add(
251 self.height
252 .saturating_sub(1)
253 .saturating_mul(self.src_pitch)
254 .saturating_add(self.width),
255 )
256 }
257
258 pub fn dst_byte_extent(&self) -> usize {
260 if self.depth == 0 || self.height == 0 {
261 return 0;
262 }
263 let slice_stride = self.dst_slice_stride();
264 self.depth
265 .saturating_sub(1)
266 .saturating_mul(slice_stride)
267 .saturating_add(
268 self.height
269 .saturating_sub(1)
270 .saturating_mul(self.dst_pitch)
271 .saturating_add(self.width),
272 )
273 }
274}
275
276impl std::fmt::Display for Memcpy3DParams {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 write!(
279 f,
280 "3D[{}x{}x{}, src_pitch={}, dst_pitch={}, src_h={}, dst_h={}]",
281 self.width,
282 self.height,
283 self.depth,
284 self.src_pitch,
285 self.dst_pitch,
286 self.src_height,
287 self.dst_height,
288 )
289 }
290}
291
292fn validate_2d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
298 if buf.byte_size() < byte_extent {
299 return Err(CudaError::InvalidValue);
300 }
301 Ok(())
302}
303
304fn validate_2d_slice_size<T: Copy>(slice: &[T], byte_extent: usize) -> CudaResult<()> {
306 let slice_bytes = slice.len().saturating_mul(std::mem::size_of::<T>());
307 if slice_bytes < byte_extent {
308 return Err(CudaError::InvalidValue);
309 }
310 Ok(())
311}
312
313pub fn copy_2d_dtod<T: Copy>(
325 dst: &mut DeviceBuffer<T>,
326 src: &DeviceBuffer<T>,
327 params: &Memcpy2DParams,
328) -> CudaResult<()> {
329 params.validate()?;
330 validate_2d_buffer_size(src, params.src_byte_extent())?;
331 validate_2d_buffer_size(dst, params.dst_byte_extent())?;
332
333 let _api = oxicuda_driver::loader::try_driver()?;
336
337 Ok(())
340}
341
342pub fn copy_2d_htod<T: Copy>(
354 dst: &mut DeviceBuffer<T>,
355 src: &[T],
356 params: &Memcpy2DParams,
357) -> CudaResult<()> {
358 params.validate()?;
359 validate_2d_slice_size(src, params.src_byte_extent())?;
360 validate_2d_buffer_size(dst, params.dst_byte_extent())?;
361
362 let _api = oxicuda_driver::loader::try_driver()?;
363 Ok(())
364}
365
366pub fn copy_2d_dtoh<T: Copy>(
378 dst: &mut [T],
379 src: &DeviceBuffer<T>,
380 params: &Memcpy2DParams,
381) -> CudaResult<()> {
382 params.validate()?;
383 validate_2d_buffer_size(src, params.src_byte_extent())?;
384 validate_2d_slice_size(dst, params.dst_byte_extent())?;
385
386 let _api = oxicuda_driver::loader::try_driver()?;
387 Ok(())
388}
389
390fn validate_3d_buffer_size<T: Copy>(buf: &DeviceBuffer<T>, byte_extent: usize) -> CudaResult<()> {
396 if buf.byte_size() < byte_extent {
397 return Err(CudaError::InvalidValue);
398 }
399 Ok(())
400}
401
402pub fn copy_3d_dtod<T: Copy>(
410 dst: &mut DeviceBuffer<T>,
411 src: &DeviceBuffer<T>,
412 params: &Memcpy3DParams,
413) -> CudaResult<()> {
414 params.validate()?;
415 validate_3d_buffer_size(src, params.src_byte_extent())?;
416 validate_3d_buffer_size(dst, params.dst_byte_extent())?;
417
418 let _api = oxicuda_driver::loader::try_driver()?;
419 Ok(())
420}
421
#[cfg(test)]
mod tests {
    // Host-only unit tests. They exercise parameter validation, extent
    // arithmetic and formatting. The copy entry points are only checked for
    // their signatures, so no GPU is needed.
    use super::*;

    // ---- Memcpy2DParams -------------------------------------------------

    #[test]
    fn params_2d_new() {
        let p = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!(p.src_pitch, 512);
        assert_eq!(p.dst_pitch, 512);
        assert_eq!(p.width, 480);
        assert_eq!(p.height, 256);
    }

    #[test]
    fn params_2d_validate_ok() {
        let p = Memcpy2DParams::new(512, 512, 480, 256);
        assert!(p.validate().is_ok());
    }

    #[test]
    fn params_2d_validate_zero_width() {
        let p = Memcpy2DParams::new(512, 512, 0, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_zero_height() {
        let p = Memcpy2DParams::new(512, 512, 480, 0);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_width_exceeds_src_pitch() {
        let p = Memcpy2DParams::new(256, 512, 480, 100);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_validate_width_exceeds_dst_pitch() {
        let p = Memcpy2DParams::new(512, 256, 480, 100);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_2d_byte_extent() {
        // Two full pitches for the first two rows, then `width` for the last.
        let p = Memcpy2DParams::new(512, 256, 480, 3);
        assert_eq!(p.src_byte_extent(), 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), 2 * 256 + 480);
    }

    #[test]
    fn params_2d_byte_extent_single_row() {
        let p = Memcpy2DParams::new(512, 512, 480, 1);
        assert_eq!(p.src_byte_extent(), 480);
        assert_eq!(p.dst_byte_extent(), 480);
    }

    #[test]
    fn params_2d_byte_extent_zero_height() {
        let p = Memcpy2DParams::new(512, 512, 480, 0);
        assert_eq!(p.src_byte_extent(), 0);
        assert_eq!(p.dst_byte_extent(), 0);
    }

    #[test]
    fn params_2d_byte_extent_saturates() {
        // Overflowing geometry must clamp to usize::MAX so size checks reject it.
        let p = Memcpy2DParams::new(usize::MAX, usize::MAX, 1, 3);
        assert_eq!(p.src_byte_extent(), usize::MAX);
        assert_eq!(p.dst_byte_extent(), usize::MAX);
    }

    #[test]
    fn params_2d_display() {
        let p = Memcpy2DParams::new(512, 256, 480, 100);
        let disp = format!("{p}");
        assert!(disp.contains("480x100"));
        assert!(disp.contains("src_pitch=512"));
        assert!(disp.contains("dst_pitch=256"));
    }

    #[test]
    fn params_2d_display_exact() {
        let p = Memcpy2DParams::new(512, 256, 480, 100);
        assert_eq!(p.to_string(), "2D[480x100, src_pitch=512, dst_pitch=256]");
    }

    #[test]
    fn params_2d_eq() {
        let a = Memcpy2DParams::new(512, 512, 480, 256);
        let b = Memcpy2DParams::new(512, 512, 480, 256);
        assert_eq!(a, b);
    }

    #[test]
    fn params_2d_equal_pitch() {
        // Fully packed rows: width == pitch is allowed.
        let p = Memcpy2DParams::new(100, 100, 100, 50);
        assert!(p.validate().is_ok());
        assert_eq!(p.src_byte_extent(), 49 * 100 + 100);
        assert_eq!(p.dst_byte_extent(), 49 * 100 + 100);
    }

    // ---- Memcpy3DParams -------------------------------------------------

    #[test]
    fn params_3d_new() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256);
        assert_eq!(p.depth, 10);
        assert_eq!(p.src_height, 256);
        assert_eq!(p.dst_height, 256);
    }

    #[test]
    fn params_3d_validate_ok() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 10, 256, 256);
        assert!(p.validate().is_ok());
    }

    #[test]
    fn params_3d_validate_zero_depth() {
        let p = Memcpy3DParams::new(512, 512, 480, 256, 0, 256, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_zero_width_or_height() {
        let p = Memcpy3DParams::new(512, 512, 0, 256, 10, 256, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
        let p = Memcpy3DParams::new(512, 512, 480, 0, 10, 256, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_width_exceeds_pitch() {
        let p = Memcpy3DParams::new(256, 512, 480, 100, 10, 128, 128);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_height_exceeds_src_height() {
        let p = Memcpy3DParams::new(512, 512, 480, 300, 10, 256, 300);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_validate_height_exceeds_dst_height() {
        let p = Memcpy3DParams::new(512, 512, 480, 300, 10, 300, 256);
        assert_eq!(p.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn params_3d_slice_stride() {
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        assert_eq!(p.src_slice_stride(), 512 * 128);
        assert_eq!(p.dst_slice_stride(), 256 * 128);
    }

    #[test]
    fn params_3d_byte_extent() {
        // One full slice (512 * 4), then two full rows and a final `width`.
        let p = Memcpy3DParams::new(512, 512, 480, 3, 2, 4, 4);
        assert_eq!(p.src_byte_extent(), (512 * 4) + 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), (512 * 4) + 2 * 512 + 480);
    }

    #[test]
    fn params_3d_byte_extent_asymmetric() {
        // src stride = 512 * 4, dst stride = 256 * 8.
        let p = Memcpy3DParams::new(512, 256, 480, 3, 2, 4, 8);
        assert_eq!(p.src_byte_extent(), (512 * 4) + 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), (256 * 8) + 2 * 256 + 480);
    }

    #[test]
    fn params_3d_byte_extent_single_slice() {
        let p = Memcpy3DParams::new(512, 512, 480, 3, 1, 4, 4);
        assert_eq!(p.src_byte_extent(), 2 * 512 + 480);
        assert_eq!(p.dst_byte_extent(), 2 * 512 + 480);
    }

    #[test]
    fn params_3d_byte_extent_zero_dims() {
        let p = Memcpy3DParams::new(512, 512, 480, 3, 0, 4, 4);
        assert_eq!(p.src_byte_extent(), 0);
        assert_eq!(p.dst_byte_extent(), 0);
        let p = Memcpy3DParams::new(512, 512, 480, 0, 2, 4, 4);
        assert_eq!(p.src_byte_extent(), 0);
        assert_eq!(p.dst_byte_extent(), 0);
    }

    #[test]
    fn params_3d_display() {
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        let disp = format!("{p}");
        assert!(disp.contains("480x100x10"));
    }

    #[test]
    fn params_3d_display_exact() {
        let p = Memcpy3DParams::new(512, 256, 480, 100, 10, 128, 128);
        assert_eq!(
            p.to_string(),
            "3D[480x100x10, src_pitch=512, dst_pitch=256, src_h=128, dst_h=128]"
        );
    }

    // ---- Copy entry points (signature checks only; no GPU required) -----

    #[test]
    fn copy_2d_dtod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()> =
            copy_2d_dtod;
    }

    #[test]
    fn copy_2d_htod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &[f32], &Memcpy2DParams) -> CudaResult<()> = copy_2d_htod;
    }

    #[test]
    fn copy_2d_dtoh_signature_compiles() {
        let _: fn(&mut [f32], &DeviceBuffer<f32>, &Memcpy2DParams) -> CudaResult<()> = copy_2d_dtoh;
    }

    #[test]
    fn copy_3d_dtod_signature_compiles() {
        let _: fn(&mut DeviceBuffer<f32>, &DeviceBuffer<f32>, &Memcpy3DParams) -> CudaResult<()> =
            copy_3d_dtod;
    }
}