1use std::ffi::c_void;
33use std::fmt;
34
35use crate::error::{CudaError, CudaResult};
36use crate::ffi::{CUfunction, CUstream};
37
38#[cfg(any(not(target_os = "macos"), test))]
39use crate::ffi::CUdevice_attribute;
40
41use crate::device::Device;
42use crate::module::Function;
43
/// Occupancy-query flag value `0x0`: passed to the driver when no special
/// behaviour is requested.
#[cfg(any(not(target_os = "macos"), test))]
const CU_OCCUPANCY_DEFAULT: u32 = 0x0;

/// Occupancy-query flag value `0x1`: passed to the driver when the caller asks
/// for the caching override to be disabled.
#[cfg(any(not(target_os = "macos"), test))]
const CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE: u32 = 0x1;
56
/// Launch parameters for a single-device cooperative kernel launch.
///
/// Consumed by [`cooperative_launch`], which forwards every field to the
/// driver's cooperative-launch entry point.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CooperativeLaunchConfig {
    /// Grid dimensions `(x, y, z)`, counted in blocks.
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions `(x, y, z)`, counted in threads.
    pub block_dim: (u32, u32, u32),
    /// Shared-memory size in bytes, forwarded to the driver unchanged.
    pub shared_mem_bytes: u32,
    /// Stream to launch on; `None` resolves to `CUstream::default()` at launch time.
    pub stream: Option<CUstream>,
}
75
76impl CooperativeLaunchConfig {
77 pub fn new(grid_dim: (u32, u32, u32), block_dim: (u32, u32, u32)) -> Self {
81 Self {
82 grid_dim,
83 block_dim,
84 shared_mem_bytes: 0,
85 stream: None,
86 }
87 }
88
89 #[must_use]
91 pub fn with_shared_mem(mut self, bytes: u32) -> Self {
92 self.shared_mem_bytes = bytes;
93 self
94 }
95
96 #[must_use]
98 pub fn with_stream(mut self, stream: CUstream) -> Self {
99 self.stream = Some(stream);
100 self
101 }
102
103 pub fn validate(&self) -> CudaResult<()> {
111 if self.grid_dim.0 == 0 || self.grid_dim.1 == 0 || self.grid_dim.2 == 0 {
112 return Err(CudaError::InvalidValue);
113 }
114 if self.block_dim.0 == 0 || self.block_dim.1 == 0 || self.block_dim.2 == 0 {
115 return Err(CudaError::InvalidValue);
116 }
117 Ok(())
118 }
119
120 pub fn threads_per_block(&self) -> u32 {
122 self.block_dim.0 * self.block_dim.1 * self.block_dim.2
123 }
124
125 pub fn total_blocks(&self) -> u64 {
127 u64::from(self.grid_dim.0) * u64::from(self.grid_dim.1) * u64::from(self.grid_dim.2)
128 }
129
130 #[cfg(any(not(target_os = "macos"), test))]
132 fn resolved_stream(&self) -> CUstream {
133 self.stream.unwrap_or_default()
134 }
135}
136
137impl Default for CooperativeLaunchConfig {
138 fn default() -> Self {
139 Self {
140 grid_dim: (1, 1, 1),
141 block_dim: (1, 1, 1),
142 shared_mem_bytes: 0,
143 stream: None,
144 }
145 }
146}
147
/// Launch parameters for one device taking part in a multi-device
/// cooperative launch.
#[derive(Clone)]
pub struct DeviceLaunchConfig {
    /// Ordinal of the target device; validation rejects negative values.
    pub device_ordinal: i32,
    /// Raw handle of the kernel function to launch.
    pub function: CUfunction,
    /// Grid dimensions `(x, y, z)`.
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions `(x, y, z)`.
    pub block_dim: (u32, u32, u32),
    /// Shared-memory size in bytes, forwarded to the driver unchanged.
    pub shared_mem_bytes: u32,
    /// Raw stream handle the launch is issued on.
    pub stream: CUstream,
    /// Kernel argument pointers; a copy of this list is handed to the driver
    /// as the kernel-parameter array.
    pub args: Vec<*mut c_void>,
}

// SAFETY: `args` holds raw pointers, which are neither `Send` nor `Sync`, so
// these impls are asserted manually. Nothing in this file dereferences the
// pointers; they are only copied and passed to the driver.
// NOTE(review): keeping the pointees valid and free of data races across
// threads is presumably the caller's responsibility — confirm.
unsafe impl Send for DeviceLaunchConfig {}
unsafe impl Sync for DeviceLaunchConfig {}
179
180impl fmt::Debug for DeviceLaunchConfig {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 f.debug_struct("DeviceLaunchConfig")
183 .field("device_ordinal", &self.device_ordinal)
184 .field("function", &self.function)
185 .field("grid_dim", &self.grid_dim)
186 .field("block_dim", &self.block_dim)
187 .field("shared_mem_bytes", &self.shared_mem_bytes)
188 .field("stream", &self.stream)
189 .field("args_count", &self.args.len())
190 .finish()
191 }
192}
193
194impl DeviceLaunchConfig {
195 pub fn new(
197 device_ordinal: i32,
198 function: CUfunction,
199 grid_dim: (u32, u32, u32),
200 block_dim: (u32, u32, u32),
201 stream: CUstream,
202 ) -> Self {
203 Self {
204 device_ordinal,
205 function,
206 grid_dim,
207 block_dim,
208 shared_mem_bytes: 0,
209 stream,
210 args: Vec::new(),
211 }
212 }
213
214 #[must_use]
216 pub fn with_shared_mem(mut self, bytes: u32) -> Self {
217 self.shared_mem_bytes = bytes;
218 self
219 }
220
221 #[must_use]
223 pub fn with_args(mut self, args: Vec<*mut c_void>) -> Self {
224 self.args = args;
225 self
226 }
227}
228
/// A set of per-device launch configurations that has passed validation.
///
/// Build it with [`MultiDeviceCooperativeLaunchConfig::new`], which rejects
/// empty or inconsistent sets.
#[derive(Debug, Clone)]
pub struct MultiDeviceCooperativeLaunchConfig {
    /// One entry per participating device.
    pub per_device: Vec<DeviceLaunchConfig>,
}
242
243impl MultiDeviceCooperativeLaunchConfig {
244 pub fn new(per_device: Vec<DeviceLaunchConfig>) -> CudaResult<Self> {
254 Self::validate_configs(&per_device)?;
255 Ok(Self { per_device })
256 }
257
258 fn validate_configs(configs: &[DeviceLaunchConfig]) -> CudaResult<()> {
260 if configs.is_empty() {
261 return Err(CudaError::InvalidValue);
262 }
263
264 let first = &configs[0];
265
266 if first.grid_dim.0 == 0
268 || first.grid_dim.1 == 0
269 || first.grid_dim.2 == 0
270 || first.block_dim.0 == 0
271 || first.block_dim.1 == 0
272 || first.block_dim.2 == 0
273 {
274 return Err(CudaError::InvalidValue);
275 }
276
277 if first.device_ordinal < 0 {
278 return Err(CudaError::InvalidValue);
279 }
280
281 for cfg in &configs[1..] {
282 if cfg.grid_dim != first.grid_dim {
283 return Err(CudaError::InvalidValue);
284 }
285 if cfg.block_dim != first.block_dim {
286 return Err(CudaError::InvalidValue);
287 }
288 if cfg.device_ordinal < 0 {
289 return Err(CudaError::InvalidValue);
290 }
291 }
292
293 Ok(())
294 }
295
296 pub fn device_count(&self) -> usize {
298 self.per_device.len()
299 }
300}
301
302pub struct CooperativeLaunchSupport;
308
309impl CooperativeLaunchSupport {
310 pub fn is_cooperative_supported(device: &Device) -> CudaResult<bool> {
318 #[cfg(not(target_os = "macos"))]
319 {
320 let driver = crate::loader::try_driver()?;
321 let mut value: i32 = 0;
322 crate::error::check(unsafe {
323 (driver.cu_device_get_attribute)(
324 &mut value,
325 CUdevice_attribute::CooperativeLaunch,
326 device.raw(),
327 )
328 })?;
329 Ok(value != 0)
330 }
331 #[cfg(target_os = "macos")]
332 {
333 let _ = device;
334 Ok(true)
336 }
337 }
338
339 pub fn is_multi_device_supported(device: &Device) -> CudaResult<bool> {
347 #[cfg(not(target_os = "macos"))]
348 {
349 let driver = crate::loader::try_driver()?;
350 let mut value: i32 = 0;
351 crate::error::check(unsafe {
352 (driver.cu_device_get_attribute)(
353 &mut value,
354 CUdevice_attribute::CooperativeMultiDeviceLaunch,
355 device.raw(),
356 )
357 })?;
358 Ok(value != 0)
359 }
360 #[cfg(target_os = "macos")]
361 {
362 let _ = device;
363 Ok(true)
365 }
366 }
367
368 pub fn max_cooperative_grid_blocks(
385 func: &Function,
386 block_size: u32,
387 shared_mem: u32,
388 ) -> CudaResult<u32> {
389 #[cfg(not(target_os = "macos"))]
390 {
391 let driver = crate::loader::try_driver()?;
392 let mut num_blocks: i32 = 0;
393 crate::error::check(unsafe {
394 (driver.cu_occupancy_max_active_blocks_per_multiprocessor_with_flags)(
395 &mut num_blocks,
396 func.raw(),
397 block_size as i32,
398 shared_mem as usize,
399 CU_OCCUPANCY_DEFAULT,
400 )
401 })?;
402 Ok(num_blocks as u32)
403 }
404 #[cfg(target_os = "macos")]
405 {
406 let _ = (func, block_size, shared_mem);
407 Ok(16)
409 }
410 }
411
412 pub fn max_cooperative_grid_blocks_with_flags(
422 func: &Function,
423 block_size: u32,
424 shared_mem: u32,
425 disable_caching_override: bool,
426 ) -> CudaResult<u32> {
427 #[cfg(not(target_os = "macos"))]
428 {
429 let driver = crate::loader::try_driver()?;
430 let flags = if disable_caching_override {
431 CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE
432 } else {
433 CU_OCCUPANCY_DEFAULT
434 };
435 let mut num_blocks: i32 = 0;
436 crate::error::check(unsafe {
437 (driver.cu_occupancy_max_active_blocks_per_multiprocessor_with_flags)(
438 &mut num_blocks,
439 func.raw(),
440 block_size as i32,
441 shared_mem as usize,
442 flags,
443 )
444 })?;
445 Ok(num_blocks as u32)
446 }
447 #[cfg(target_os = "macos")]
448 {
449 let _ = (func, block_size, shared_mem, disable_caching_override);
450 Ok(16)
451 }
452 }
453}
454
455pub fn cooperative_launch(
479 func: &Function,
480 config: &CooperativeLaunchConfig,
481 args: &[*mut c_void],
482) -> CudaResult<()> {
483 config.validate()?;
484
485 #[cfg(not(target_os = "macos"))]
486 {
487 let driver = crate::loader::try_driver()?;
488 let mut kernel_params: Vec<*mut c_void> = args.to_vec();
489 let params_ptr = if kernel_params.is_empty() {
490 std::ptr::null_mut()
491 } else {
492 kernel_params.as_mut_ptr()
493 };
494
495 crate::error::check(unsafe {
496 (driver.cu_launch_cooperative_kernel)(
497 func.raw(),
498 config.grid_dim.0,
499 config.grid_dim.1,
500 config.grid_dim.2,
501 config.block_dim.0,
502 config.block_dim.1,
503 config.block_dim.2,
504 config.shared_mem_bytes,
505 config.resolved_stream(),
506 params_ptr,
507 )
508 })
509 }
510 #[cfg(target_os = "macos")]
511 {
512 let _ = (func, args);
513 Err(CudaError::NotSupported)
514 }
515}
516
/// One element of the parameter array handed to
/// `cu_launch_cooperative_kernel_multi_device`.
///
/// `#[repr(C)]` fixes the field order and layout, because the array reaches
/// the driver as an untyped pointer.
/// NOTE(review): presumably mirrors the CUDA driver's `CUDA_LAUNCH_PARAMS`
/// struct — confirm the layout against `cuda.h`.
#[cfg(not(target_os = "macos"))]
#[repr(C)]
#[allow(non_camel_case_types)]
struct CUDA_LAUNCH_PARAMS {
    /// Kernel function handle.
    function: CUfunction,
    /// Grid dimension along x.
    grid_dim_x: u32,
    /// Grid dimension along y.
    grid_dim_y: u32,
    /// Grid dimension along z.
    grid_dim_z: u32,
    /// Block dimension along x.
    block_dim_x: u32,
    /// Block dimension along y.
    block_dim_y: u32,
    /// Block dimension along z.
    block_dim_z: u32,
    /// Shared-memory size in bytes.
    shared_mem_bytes: u32,
    /// Stream the kernel is launched on.
    h_stream: CUstream,
    /// Pointer to the array of kernel argument pointers, or null when there are none.
    kernel_params: *mut *mut c_void,
}
539
540pub fn cooperative_launch_multi_device(configs: &[DeviceLaunchConfig]) -> CudaResult<()> {
557 if configs.is_empty() {
558 return Err(CudaError::InvalidValue);
559 }
560
561 MultiDeviceCooperativeLaunchConfig::validate_configs(configs)?;
563
564 #[cfg(not(target_os = "macos"))]
565 {
566 let driver = crate::loader::try_driver()?;
567
568 let mut args_storage: Vec<Vec<*mut c_void>> =
571 configs.iter().map(|c| c.args.clone()).collect();
572
573 let mut launch_params: Vec<CUDA_LAUNCH_PARAMS> = configs
574 .iter()
575 .enumerate()
576 .map(|(i, cfg)| CUDA_LAUNCH_PARAMS {
577 function: cfg.function,
578 grid_dim_x: cfg.grid_dim.0,
579 grid_dim_y: cfg.grid_dim.1,
580 grid_dim_z: cfg.grid_dim.2,
581 block_dim_x: cfg.block_dim.0,
582 block_dim_y: cfg.block_dim.1,
583 block_dim_z: cfg.block_dim.2,
584 shared_mem_bytes: cfg.shared_mem_bytes,
585 h_stream: cfg.stream,
586 kernel_params: if args_storage[i].is_empty() {
587 std::ptr::null_mut()
588 } else {
589 args_storage[i].as_mut_ptr()
590 },
591 })
592 .collect();
593
594 let num_devices = launch_params.len() as u32;
595 crate::error::check(unsafe {
596 (driver.cu_launch_cooperative_kernel_multi_device)(
597 launch_params.as_mut_ptr().cast::<c_void>(),
598 num_devices,
599 0, )
601 })
602 }
603 #[cfg(target_os = "macos")]
604 {
605 let _ = configs;
606 Err(CudaError::NotSupported)
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    type Dim3 = (u32, u32, u32);

    /// Per-device config with a default function handle and default stream.
    fn device_cfg(ordinal: i32, grid: Dim3, block: Dim3) -> DeviceLaunchConfig {
        DeviceLaunchConfig::new(
            ordinal,
            CUfunction::default(),
            grid,
            block,
            CUstream::default(),
        )
    }

    /// Error returned when building a multi-device config, if any.
    fn multi_error(configs: Vec<DeviceLaunchConfig>) -> Option<CudaError> {
        MultiDeviceCooperativeLaunchConfig::new(configs).err()
    }

    /// Device count of a successfully built multi-device config, if any.
    fn multi_count(configs: Vec<DeviceLaunchConfig>) -> Option<usize> {
        MultiDeviceCooperativeLaunchConfig::new(configs)
            .ok()
            .map(|multi| multi.device_count())
    }

    /// Result of validating a single-device config with the given dimensions.
    fn validation_of(grid: Dim3, block: Dim3) -> CudaResult<()> {
        CooperativeLaunchConfig::new(grid, block).validate()
    }

    // ---- CooperativeLaunchConfig ----

    #[test]
    fn test_config_new() {
        let config = CooperativeLaunchConfig::new((4, 2, 1), (256, 1, 1));
        assert_eq!(
            (config.grid_dim, config.block_dim, config.shared_mem_bytes),
            ((4, 2, 1), (256, 1, 1), 0)
        );
        assert!(config.stream.is_none());
    }

    #[test]
    fn test_config_default() {
        let config = CooperativeLaunchConfig::default();
        assert_eq!(
            (config.grid_dim, config.block_dim, config.shared_mem_bytes),
            ((1, 1, 1), (1, 1, 1), 0)
        );
        assert!(config.stream.is_none());
    }

    #[test]
    fn test_config_builder_methods() {
        let config = CooperativeLaunchConfig::new((8, 1, 1), (128, 1, 1))
            .with_shared_mem(4096)
            .with_stream(CUstream::default());
        assert_eq!(config.shared_mem_bytes, 4096);
        assert!(config.stream.is_some());
    }

    #[test]
    fn test_config_validate_valid() {
        assert!(validation_of((1, 1, 1), (32, 1, 1)).is_ok());
    }

    #[test]
    fn test_config_validate_zero_grid_x() {
        assert_eq!(
            validation_of((0, 1, 1), (32, 1, 1)),
            Err(CudaError::InvalidValue)
        );
    }

    #[test]
    fn test_config_validate_zero_grid_y() {
        assert_eq!(
            validation_of((1, 0, 1), (32, 1, 1)),
            Err(CudaError::InvalidValue)
        );
    }

    #[test]
    fn test_config_validate_zero_block_z() {
        assert_eq!(
            validation_of((1, 1, 1), (32, 1, 0)),
            Err(CudaError::InvalidValue)
        );
    }

    #[test]
    fn test_config_threads_per_block() {
        let threads = CooperativeLaunchConfig::new((1, 1, 1), (16, 8, 2)).threads_per_block();
        assert_eq!(threads, 256);
    }

    #[test]
    fn test_config_total_blocks() {
        let blocks = CooperativeLaunchConfig::new((4, 3, 2), (1, 1, 1)).total_blocks();
        assert_eq!(blocks, 24);
    }

    #[test]
    fn test_config_resolved_stream_default() {
        let resolved = CooperativeLaunchConfig::default().resolved_stream();
        assert!(resolved.is_null());
    }

    // ---- DeviceLaunchConfig ----

    #[test]
    fn test_device_launch_config_new() {
        let cfg = device_cfg(0, (4, 1, 1), (256, 1, 1));
        assert_eq!(cfg.device_ordinal, 0);
        assert_eq!(cfg.grid_dim, (4, 1, 1));
        assert_eq!(cfg.block_dim, (256, 1, 1));
        assert_eq!(cfg.shared_mem_bytes, 0);
        assert!(cfg.args.is_empty());
    }

    #[test]
    fn test_device_launch_config_builder() {
        let mut val: u32 = 42;
        let arg_ptr = (&mut val as *mut u32).cast::<c_void>();
        let cfg = device_cfg(1, (2, 2, 1), (128, 1, 1))
            .with_shared_mem(2048)
            .with_args(vec![arg_ptr]);
        assert_eq!(cfg.shared_mem_bytes, 2048);
        assert_eq!(cfg.args.len(), 1);
    }

    #[test]
    fn test_device_launch_config_debug() {
        let cfg = device_cfg(0, (1, 1, 1), (32, 1, 1));
        let rendered = format!("{cfg:?}");
        for needle in ["DeviceLaunchConfig", "args_count"] {
            assert!(rendered.contains(needle));
        }
    }

    // ---- MultiDeviceCooperativeLaunchConfig ----

    #[test]
    fn test_multi_device_config_empty() {
        assert_eq!(multi_error(vec![]), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_single() {
        let configs = vec![device_cfg(0, (4, 1, 1), (256, 1, 1))];
        assert_eq!(multi_count(configs), Some(1));
    }

    #[test]
    fn test_multi_device_config_mismatched_grid() {
        let configs = vec![
            device_cfg(0, (4, 1, 1), (256, 1, 1)),
            device_cfg(1, (8, 1, 1), (256, 1, 1)),
        ];
        assert_eq!(multi_error(configs), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_mismatched_block() {
        let configs = vec![
            device_cfg(0, (4, 1, 1), (256, 1, 1)),
            device_cfg(1, (4, 1, 1), (128, 1, 1)),
        ];
        assert_eq!(multi_error(configs), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_negative_ordinal() {
        let configs = vec![device_cfg(-1, (4, 1, 1), (256, 1, 1))];
        assert_eq!(multi_error(configs), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_zero_dim() {
        let configs = vec![device_cfg(0, (0, 1, 1), (256, 1, 1))];
        assert_eq!(multi_error(configs), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_consistent_pair() {
        let configs = vec![
            device_cfg(0, (4, 2, 1), (128, 2, 1)),
            device_cfg(1, (4, 2, 1), (128, 2, 1)),
        ];
        assert_eq!(multi_count(configs), Some(2));
    }

    // ---- macOS-only checks ----

    // Only exercises config validation; no launch is attempted here.
    #[cfg(target_os = "macos")]
    #[test]
    fn test_cooperative_launch_returns_not_supported_on_macos() {
        let _func_handle = CUfunction::default();
        assert!(validation_of((1, 1, 1), (32, 1, 1)).is_ok());
    }

    // An empty slice is rejected by validation before the platform check.
    #[cfg(target_os = "macos")]
    #[test]
    fn test_multi_device_launch_returns_not_supported_on_macos() {
        let no_configs: &[DeviceLaunchConfig] = &[];
        assert_eq!(
            cooperative_launch_multi_device(no_configs),
            Err(CudaError::InvalidValue)
        );
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_cooperative_support_query_macos() {
        assert_eq!(CUdevice_attribute::CooperativeLaunch as i32, 95);
        assert_eq!(CUdevice_attribute::CooperativeMultiDeviceLaunch as i32, 96);
    }

    // ---- Constants and overflow ----

    #[test]
    fn test_occupancy_constants() {
        assert_eq!(CU_OCCUPANCY_DEFAULT, 0x0);
        assert_eq!(CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, 0x1);
    }

    #[test]
    fn test_config_large_total_blocks_no_overflow() {
        let blocks = CooperativeLaunchConfig::new((65535, 65535, 64), (1, 1, 1)).total_blocks();
        assert_eq!(blocks, 65535u64 * 65535 * 64);
    }
}