Skip to main content

oxicuda_driver/
cooperative_launch.rs

1//! Cooperative kernel launch support (CUDA 9.0+).
2//!
3//! Cooperative launches allow thread blocks within a kernel (and optionally
4//! across multiple GPUs) to synchronise with each other via
5//! `cooperative_groups`. This module wraps:
6//!
7//! * `cuLaunchCooperativeKernel` — single-device cooperative launch.
8//! * `cuLaunchCooperativeKernelMultiDevice` — multi-device cooperative launch.
//! * `cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags` — per-SM
//!   occupancy query used to size cooperative grids.
11//!
12//! # Platform behaviour
13//!
14//! On macOS (where NVIDIA dropped CUDA support), query functions return
15//! synthetic data suitable for unit tests, while actual launch functions
16//! return `Err(CudaError::NotSupported)`.
17//!
18//! # Example
19//!
20//! ```rust,no_run
21//! use oxicuda_driver::cooperative_launch::*;
22//! use oxicuda_driver::device::Device;
23//!
24//! oxicuda_driver::init()?;
25//! let dev = Device::get(0)?;
26//!
27//! let supported = CooperativeLaunchSupport::is_cooperative_supported(&dev)?;
28//! println!("Cooperative launch supported: {supported}");
29//! # Ok::<(), oxicuda_driver::CudaError>(())
30//! ```
31
32use std::ffi::c_void;
33use std::fmt;
34
35use crate::error::{CudaError, CudaResult};
36use crate::ffi::{CUfunction, CUstream};
37
38#[cfg(any(not(target_os = "macos"), test))]
39use crate::ffi::CUdevice_attribute;
40
41use crate::device::Device;
42use crate::module::Function;
43
44// ---------------------------------------------------------------------------
45// Constants
46// ---------------------------------------------------------------------------
47
/// Flag for `cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags`:
/// default occupancy calculation, no special behaviour (`CU_OCCUPANCY_DEFAULT`).
#[cfg(any(not(target_os = "macos"), test))]
const CU_OCCUPANCY_DEFAULT: u32 = 0x0;

/// Flag for `cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags`
/// (`CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE`).
///
/// The occupancy calculation assumes global caching stays enabled, i.e. the
/// driver may not turn caching off to raise occupancy. Despite living in
/// this module, the flag is not specific to cooperative launches.
#[cfg(any(not(target_os = "macos"), test))]
const CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE: u32 = 0x1;
56
57// ---------------------------------------------------------------------------
58// CooperativeLaunchConfig
59// ---------------------------------------------------------------------------
60
61/// Configuration for a single-device cooperative kernel launch.
62///
63/// This mirrors the parameters accepted by `cuLaunchCooperativeKernel`.
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
65pub struct CooperativeLaunchConfig {
66    /// Grid dimensions `(x, y, z)` in blocks.
67    pub grid_dim: (u32, u32, u32),
68    /// Block dimensions `(x, y, z)` in threads.
69    pub block_dim: (u32, u32, u32),
70    /// Dynamic shared memory in bytes.
71    pub shared_mem_bytes: u32,
72    /// Stream handle. `None` means the default (null) stream.
73    pub stream: Option<CUstream>,
74}
75
76impl CooperativeLaunchConfig {
77    /// Create a new configuration with the given grid and block dimensions.
78    ///
79    /// Shared memory defaults to 0 and the default stream is used.
80    pub fn new(grid_dim: (u32, u32, u32), block_dim: (u32, u32, u32)) -> Self {
81        Self {
82            grid_dim,
83            block_dim,
84            shared_mem_bytes: 0,
85            stream: None,
86        }
87    }
88
89    /// Set dynamic shared memory in bytes.
90    #[must_use]
91    pub fn with_shared_mem(mut self, bytes: u32) -> Self {
92        self.shared_mem_bytes = bytes;
93        self
94    }
95
96    /// Set the stream for the launch.
97    #[must_use]
98    pub fn with_stream(mut self, stream: CUstream) -> Self {
99        self.stream = Some(stream);
100        self
101    }
102
103    /// Validate the configuration.
104    ///
105    /// All dimensions must be non-zero.
106    ///
107    /// # Errors
108    ///
109    /// Returns [`CudaError::InvalidValue`] if any dimension component is zero.
110    pub fn validate(&self) -> CudaResult<()> {
111        if self.grid_dim.0 == 0 || self.grid_dim.1 == 0 || self.grid_dim.2 == 0 {
112            return Err(CudaError::InvalidValue);
113        }
114        if self.block_dim.0 == 0 || self.block_dim.1 == 0 || self.block_dim.2 == 0 {
115            return Err(CudaError::InvalidValue);
116        }
117        Ok(())
118    }
119
120    /// Total number of threads per block.
121    pub fn threads_per_block(&self) -> u32 {
122        self.block_dim.0 * self.block_dim.1 * self.block_dim.2
123    }
124
125    /// Total number of blocks in the grid.
126    pub fn total_blocks(&self) -> u64 {
127        u64::from(self.grid_dim.0) * u64::from(self.grid_dim.1) * u64::from(self.grid_dim.2)
128    }
129
130    /// Resolve the stream handle (null pointer for the default stream).
131    #[cfg(any(not(target_os = "macos"), test))]
132    fn resolved_stream(&self) -> CUstream {
133        self.stream.unwrap_or_default()
134    }
135}
136
137impl Default for CooperativeLaunchConfig {
138    fn default() -> Self {
139        Self {
140            grid_dim: (1, 1, 1),
141            block_dim: (1, 1, 1),
142            shared_mem_bytes: 0,
143            stream: None,
144        }
145    }
146}
147
148// ---------------------------------------------------------------------------
149// DeviceLaunchConfig
150// ---------------------------------------------------------------------------
151
/// Per-device configuration for a multi-device cooperative launch.
///
/// Each entry describes one device's contribution to the cooperative kernel.
/// [`cooperative_launch_multi_device`] copies its fields into one
/// `CUDA_LAUNCH_PARAMS` record. All entries in a multi-device launch must use
/// identical grid and block dimensions; [`MultiDeviceCooperativeLaunchConfig`]
/// validation rejects mismatches.
#[derive(Clone)]
pub struct DeviceLaunchConfig {
    /// Device ordinal (0-based).
    ///
    /// Validated as non-negative, but not forwarded to the driver. The launch
    /// record only carries `function` and `stream`, which presumably must
    /// belong to this device's context (confirm against callers).
    pub device_ordinal: i32,
    /// Raw CUDA function handle for this device's kernel.
    pub function: CUfunction,
    /// Grid dimensions `(x, y, z)` in blocks.
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions `(x, y, z)` in threads.
    pub block_dim: (u32, u32, u32),
    /// Dynamic shared memory in bytes.
    pub shared_mem_bytes: u32,
    /// Stream handle for this device.
    pub stream: CUstream,
    /// Kernel arguments: one pointer per kernel parameter, each pointing at
    /// that parameter's value. The values must remain valid until the launch
    /// call returns.
    pub args: Vec<*mut c_void>,
}

// SAFETY: `DeviceLaunchConfig` only stores raw handles and argument pointers.
// Within this module they are cloned, formatted by address in `Debug`, or
// handed to the driver during the launch call; they are never dereferenced
// here. The caller is responsible for ensuring the pointees are valid (and
// not concurrently mutated) when the launch happens. The struct has no
// interior mutability, so `&DeviceLaunchConfig` cannot mutate anything.
unsafe impl Send for DeviceLaunchConfig {}
unsafe impl Sync for DeviceLaunchConfig {}
179
180impl fmt::Debug for DeviceLaunchConfig {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        f.debug_struct("DeviceLaunchConfig")
183            .field("device_ordinal", &self.device_ordinal)
184            .field("function", &self.function)
185            .field("grid_dim", &self.grid_dim)
186            .field("block_dim", &self.block_dim)
187            .field("shared_mem_bytes", &self.shared_mem_bytes)
188            .field("stream", &self.stream)
189            .field("args_count", &self.args.len())
190            .finish()
191    }
192}
193
194impl DeviceLaunchConfig {
195    /// Create a new per-device launch configuration.
196    pub fn new(
197        device_ordinal: i32,
198        function: CUfunction,
199        grid_dim: (u32, u32, u32),
200        block_dim: (u32, u32, u32),
201        stream: CUstream,
202    ) -> Self {
203        Self {
204            device_ordinal,
205            function,
206            grid_dim,
207            block_dim,
208            shared_mem_bytes: 0,
209            stream,
210            args: Vec::new(),
211        }
212    }
213
214    /// Set dynamic shared memory in bytes.
215    #[must_use]
216    pub fn with_shared_mem(mut self, bytes: u32) -> Self {
217        self.shared_mem_bytes = bytes;
218        self
219    }
220
221    /// Set kernel arguments.
222    #[must_use]
223    pub fn with_args(mut self, args: Vec<*mut c_void>) -> Self {
224        self.args = args;
225        self
226    }
227}
228
229// ---------------------------------------------------------------------------
230// MultiDeviceCooperativeLaunchConfig
231// ---------------------------------------------------------------------------
232
233/// Configuration for a multi-device cooperative kernel launch.
234///
235/// Wraps a collection of [`DeviceLaunchConfig`] entries, one per participating
236/// device. All entries must use the same grid and block dimensions.
237#[derive(Debug, Clone)]
238pub struct MultiDeviceCooperativeLaunchConfig {
239    /// Per-device configurations.
240    pub per_device: Vec<DeviceLaunchConfig>,
241}
242
243impl MultiDeviceCooperativeLaunchConfig {
244    /// Create a new multi-device configuration from per-device entries.
245    ///
246    /// # Errors
247    ///
248    /// Returns [`CudaError::InvalidValue`] if:
249    /// - The list is empty.
250    /// - Grid or block dimensions differ across devices.
251    /// - Any device ordinal is negative.
252    /// - Any dimension component is zero.
253    pub fn new(per_device: Vec<DeviceLaunchConfig>) -> CudaResult<Self> {
254        Self::validate_configs(&per_device)?;
255        Ok(Self { per_device })
256    }
257
258    /// Validate that all device configurations are consistent.
259    fn validate_configs(configs: &[DeviceLaunchConfig]) -> CudaResult<()> {
260        if configs.is_empty() {
261            return Err(CudaError::InvalidValue);
262        }
263
264        let first = &configs[0];
265
266        // All dimensions must be non-zero.
267        if first.grid_dim.0 == 0
268            || first.grid_dim.1 == 0
269            || first.grid_dim.2 == 0
270            || first.block_dim.0 == 0
271            || first.block_dim.1 == 0
272            || first.block_dim.2 == 0
273        {
274            return Err(CudaError::InvalidValue);
275        }
276
277        if first.device_ordinal < 0 {
278            return Err(CudaError::InvalidValue);
279        }
280
281        for cfg in &configs[1..] {
282            if cfg.grid_dim != first.grid_dim {
283                return Err(CudaError::InvalidValue);
284            }
285            if cfg.block_dim != first.block_dim {
286                return Err(CudaError::InvalidValue);
287            }
288            if cfg.device_ordinal < 0 {
289                return Err(CudaError::InvalidValue);
290            }
291        }
292
293        Ok(())
294    }
295
296    /// Number of devices participating in the launch.
297    pub fn device_count(&self) -> usize {
298        self.per_device.len()
299    }
300}
301
302// ---------------------------------------------------------------------------
303// CooperativeLaunchSupport — query helpers
304// ---------------------------------------------------------------------------
305
306/// Query helpers for cooperative launch capabilities.
307pub struct CooperativeLaunchSupport;
308
309impl CooperativeLaunchSupport {
310    /// Query whether the device supports cooperative kernel launches.
311    ///
312    /// Checks `CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH` (attribute 95).
313    ///
314    /// # Errors
315    ///
316    /// Returns a [`CudaError`] if the driver call fails.
317    pub fn is_cooperative_supported(device: &Device) -> CudaResult<bool> {
318        #[cfg(not(target_os = "macos"))]
319        {
320            let driver = crate::loader::try_driver()?;
321            let mut value: i32 = 0;
322            crate::error::check(unsafe {
323                (driver.cu_device_get_attribute)(
324                    &mut value,
325                    CUdevice_attribute::CooperativeLaunch,
326                    device.raw(),
327                )
328            })?;
329            Ok(value != 0)
330        }
331        #[cfg(target_os = "macos")]
332        {
333            let _ = device;
334            // Synthetic: report cooperative launch as supported for testing.
335            Ok(true)
336        }
337    }
338
339    /// Query whether the device supports multi-device cooperative launches.
340    ///
341    /// Checks `CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH` (attribute 96).
342    ///
343    /// # Errors
344    ///
345    /// Returns a [`CudaError`] if the driver call fails.
346    pub fn is_multi_device_supported(device: &Device) -> CudaResult<bool> {
347        #[cfg(not(target_os = "macos"))]
348        {
349            let driver = crate::loader::try_driver()?;
350            let mut value: i32 = 0;
351            crate::error::check(unsafe {
352                (driver.cu_device_get_attribute)(
353                    &mut value,
354                    CUdevice_attribute::CooperativeMultiDeviceLaunch,
355                    device.raw(),
356                )
357            })?;
358            Ok(value != 0)
359        }
360        #[cfg(target_os = "macos")]
361        {
362            let _ = device;
363            // Synthetic: report multi-device cooperative launch as supported.
364            Ok(true)
365        }
366    }
367
368    /// Returns the maximum number of active blocks per SM for a cooperative
369    /// kernel launch.
370    ///
371    /// This is a cooperative-launch-aware variant of the standard occupancy
372    /// query. For cooperative launches, the hardware may limit the number of
373    /// blocks more tightly than for regular launches.
374    ///
375    /// # Parameters
376    ///
377    /// * `func` — the kernel function handle.
378    /// * `block_size` — number of threads per block.
379    /// * `shared_mem` — dynamic shared memory per block in bytes.
380    ///
381    /// # Errors
382    ///
383    /// Returns a [`CudaError`] if the driver call fails.
384    pub fn max_cooperative_grid_blocks(
385        func: &Function,
386        block_size: u32,
387        shared_mem: u32,
388    ) -> CudaResult<u32> {
389        #[cfg(not(target_os = "macos"))]
390        {
391            let driver = crate::loader::try_driver()?;
392            let mut num_blocks: i32 = 0;
393            crate::error::check(unsafe {
394                (driver.cu_occupancy_max_active_blocks_per_multiprocessor_with_flags)(
395                    &mut num_blocks,
396                    func.raw(),
397                    block_size as i32,
398                    shared_mem as usize,
399                    CU_OCCUPANCY_DEFAULT,
400                )
401            })?;
402            Ok(num_blocks as u32)
403        }
404        #[cfg(target_os = "macos")]
405        {
406            let _ = (func, block_size, shared_mem);
407            // Synthetic: return a reasonable occupancy value for testing.
408            Ok(16)
409        }
410    }
411
412    /// Returns the maximum number of active blocks per SM for a cooperative
413    /// launch, with the caching override disabled.
414    ///
415    /// When `disable_caching_override` is `true`, the driver will not use
416    /// the L1/texture cache to increase occupancy.
417    ///
418    /// # Errors
419    ///
420    /// Returns a [`CudaError`] if the driver call fails.
421    pub fn max_cooperative_grid_blocks_with_flags(
422        func: &Function,
423        block_size: u32,
424        shared_mem: u32,
425        disable_caching_override: bool,
426    ) -> CudaResult<u32> {
427        #[cfg(not(target_os = "macos"))]
428        {
429            let driver = crate::loader::try_driver()?;
430            let flags = if disable_caching_override {
431                CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE
432            } else {
433                CU_OCCUPANCY_DEFAULT
434            };
435            let mut num_blocks: i32 = 0;
436            crate::error::check(unsafe {
437                (driver.cu_occupancy_max_active_blocks_per_multiprocessor_with_flags)(
438                    &mut num_blocks,
439                    func.raw(),
440                    block_size as i32,
441                    shared_mem as usize,
442                    flags,
443                )
444            })?;
445            Ok(num_blocks as u32)
446        }
447        #[cfg(target_os = "macos")]
448        {
449            let _ = (func, block_size, shared_mem, disable_caching_override);
450            Ok(16)
451        }
452    }
453}
454
455// ---------------------------------------------------------------------------
456// cooperative_launch — single-device
457// ---------------------------------------------------------------------------
458
459/// Launch a cooperative kernel on a single device.
460///
461/// Cooperative launches enable thread blocks to synchronise with each other
462/// via `cooperative_groups::this_grid().sync()`. The grid must be small
463/// enough that all blocks can be simultaneously resident on the GPU; use
464/// [`CooperativeLaunchSupport::max_cooperative_grid_blocks`] to query the
465/// maximum.
466///
467/// # Safety (caller-side)
468///
469/// The `args` slice must contain pointers to valid kernel argument values
470/// with correct types and alignment. This is the same contract as
471/// `cuLaunchKernel`.
472///
473/// # Errors
474///
475/// * [`CudaError::InvalidValue`] if the config fails validation.
476/// * [`CudaError::CooperativeLaunchTooLarge`] if the grid is too large.
477/// * [`CudaError::NotSupported`] on macOS.
478pub fn cooperative_launch(
479    func: &Function,
480    config: &CooperativeLaunchConfig,
481    args: &[*mut c_void],
482) -> CudaResult<()> {
483    config.validate()?;
484
485    #[cfg(not(target_os = "macos"))]
486    {
487        let driver = crate::loader::try_driver()?;
488        let mut kernel_params: Vec<*mut c_void> = args.to_vec();
489        let params_ptr = if kernel_params.is_empty() {
490            std::ptr::null_mut()
491        } else {
492            kernel_params.as_mut_ptr()
493        };
494
495        crate::error::check(unsafe {
496            (driver.cu_launch_cooperative_kernel)(
497                func.raw(),
498                config.grid_dim.0,
499                config.grid_dim.1,
500                config.grid_dim.2,
501                config.block_dim.0,
502                config.block_dim.1,
503                config.block_dim.2,
504                config.shared_mem_bytes,
505                config.resolved_stream(),
506                params_ptr,
507            )
508        })
509    }
510    #[cfg(target_os = "macos")]
511    {
512        let _ = (func, args);
513        Err(CudaError::NotSupported)
514    }
515}
516
517// ---------------------------------------------------------------------------
518// cooperative_launch_multi_device
519// ---------------------------------------------------------------------------
520
/// CUDA internal structure for `cuLaunchCooperativeKernelMultiDevice`.
///
/// This matches `CUDA_LAUNCH_PARAMS` from the CUDA Driver API. It is
/// `#[repr(C)]` and passed to the driver by pointer, so the field order and
/// types below must not be changed.
#[cfg(not(target_os = "macos"))]
#[repr(C)]
#[allow(non_camel_case_types)]
struct CUDA_LAUNCH_PARAMS {
    // Kernel to launch on this entry's device.
    function: CUfunction,
    // Grid dimensions in blocks.
    grid_dim_x: u32,
    grid_dim_y: u32,
    grid_dim_z: u32,
    // Block dimensions in threads.
    block_dim_x: u32,
    block_dim_y: u32,
    block_dim_z: u32,
    // Dynamic shared memory per block, in bytes.
    shared_mem_bytes: u32,
    // Stream the kernel is enqueued on.
    h_stream: CUstream,
    // Null, or an array of pointers to kernel argument values.
    kernel_params: *mut *mut c_void,
}
539
540/// Launch a cooperative kernel across multiple devices simultaneously.
541///
542/// All devices execute the same grid/block configuration and can synchronise
543/// via `cooperative_groups::this_multi_grid().sync()`.
544///
545/// # Safety (caller-side)
546///
547/// Each device's `args` must contain valid pointers to kernel argument values.
548/// The appropriate CUDA context must be current for each device when building
549/// the configuration.
550///
551/// # Errors
552///
553/// * [`CudaError::InvalidValue`] if configurations are inconsistent.
554/// * [`CudaError::CooperativeLaunchTooLarge`] if the grid is too large.
555/// * [`CudaError::NotSupported`] on macOS.
556pub fn cooperative_launch_multi_device(configs: &[DeviceLaunchConfig]) -> CudaResult<()> {
557    if configs.is_empty() {
558        return Err(CudaError::InvalidValue);
559    }
560
561    // Validate consistency.
562    MultiDeviceCooperativeLaunchConfig::validate_configs(configs)?;
563
564    #[cfg(not(target_os = "macos"))]
565    {
566        let driver = crate::loader::try_driver()?;
567
568        // Build the CUDA_LAUNCH_PARAMS array. We need mutable copies of the
569        // args vectors so we can take pointers into them.
570        let mut args_storage: Vec<Vec<*mut c_void>> =
571            configs.iter().map(|c| c.args.clone()).collect();
572
573        let mut launch_params: Vec<CUDA_LAUNCH_PARAMS> = configs
574            .iter()
575            .enumerate()
576            .map(|(i, cfg)| CUDA_LAUNCH_PARAMS {
577                function: cfg.function,
578                grid_dim_x: cfg.grid_dim.0,
579                grid_dim_y: cfg.grid_dim.1,
580                grid_dim_z: cfg.grid_dim.2,
581                block_dim_x: cfg.block_dim.0,
582                block_dim_y: cfg.block_dim.1,
583                block_dim_z: cfg.block_dim.2,
584                shared_mem_bytes: cfg.shared_mem_bytes,
585                h_stream: cfg.stream,
586                kernel_params: if args_storage[i].is_empty() {
587                    std::ptr::null_mut()
588                } else {
589                    args_storage[i].as_mut_ptr()
590                },
591            })
592            .collect();
593
594        let num_devices = launch_params.len() as u32;
595        crate::error::check(unsafe {
596            (driver.cu_launch_cooperative_kernel_multi_device)(
597                launch_params.as_mut_ptr().cast::<c_void>(),
598                num_devices,
599                0, // flags
600            )
601        })
602    }
603    #[cfg(target_os = "macos")]
604    {
605        let _ = configs;
606        Err(CudaError::NotSupported)
607    }
608}
609
610// ---------------------------------------------------------------------------
611// Tests
612// ---------------------------------------------------------------------------
613
#[cfg(test)]
mod tests {
    use super::*;

    // -- CooperativeLaunchConfig tests --

    #[test]
    fn test_config_new() {
        let config = CooperativeLaunchConfig::new((4, 2, 1), (256, 1, 1));
        assert_eq!(config.grid_dim, (4, 2, 1));
        assert_eq!(config.block_dim, (256, 1, 1));
        assert_eq!(config.shared_mem_bytes, 0);
        assert!(config.stream.is_none());
    }

    #[test]
    fn test_config_default() {
        let config = CooperativeLaunchConfig::default();
        assert_eq!(config.grid_dim, (1, 1, 1));
        assert_eq!(config.block_dim, (1, 1, 1));
        assert_eq!(config.shared_mem_bytes, 0);
        assert!(config.stream.is_none());
    }

    #[test]
    fn test_config_builder_methods() {
        let stream = CUstream::default();
        let config = CooperativeLaunchConfig::new((8, 1, 1), (128, 1, 1))
            .with_shared_mem(4096)
            .with_stream(stream);
        assert_eq!(config.shared_mem_bytes, 4096);
        assert!(config.stream.is_some());
    }

    #[test]
    fn test_config_validate_valid() {
        let config = CooperativeLaunchConfig::new((1, 1, 1), (32, 1, 1));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validate_zero_grid_x() {
        let config = CooperativeLaunchConfig::new((0, 1, 1), (32, 1, 1));
        assert_eq!(config.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn test_config_validate_zero_grid_y() {
        let config = CooperativeLaunchConfig::new((1, 0, 1), (32, 1, 1));
        assert_eq!(config.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn test_config_validate_zero_block_z() {
        let config = CooperativeLaunchConfig::new((1, 1, 1), (32, 1, 0));
        assert_eq!(config.validate(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn test_config_threads_per_block() {
        let config = CooperativeLaunchConfig::new((1, 1, 1), (16, 8, 2));
        assert_eq!(config.threads_per_block(), 256);
    }

    #[test]
    fn test_config_total_blocks() {
        let config = CooperativeLaunchConfig::new((4, 3, 2), (1, 1, 1));
        assert_eq!(config.total_blocks(), 24);
    }

    #[test]
    fn test_config_resolved_stream_default() {
        let config = CooperativeLaunchConfig::default();
        let stream = config.resolved_stream();
        assert!(stream.is_null());
    }

    // -- DeviceLaunchConfig tests --

    #[test]
    fn test_device_launch_config_new() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg = DeviceLaunchConfig::new(0, func, (4, 1, 1), (256, 1, 1), stream);
        assert_eq!(cfg.device_ordinal, 0);
        assert_eq!(cfg.grid_dim, (4, 1, 1));
        assert_eq!(cfg.block_dim, (256, 1, 1));
        assert_eq!(cfg.shared_mem_bytes, 0);
        assert!(cfg.args.is_empty());
    }

    #[test]
    fn test_device_launch_config_builder() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let mut val: u32 = 42;
        let arg_ptr = &mut val as *mut u32 as *mut c_void;
        let cfg = DeviceLaunchConfig::new(1, func, (2, 2, 1), (128, 1, 1), stream)
            .with_shared_mem(2048)
            .with_args(vec![arg_ptr]);
        assert_eq!(cfg.shared_mem_bytes, 2048);
        assert_eq!(cfg.args.len(), 1);
    }

    #[test]
    fn test_device_launch_config_debug() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg = DeviceLaunchConfig::new(0, func, (1, 1, 1), (32, 1, 1), stream);
        let debug_str = format!("{cfg:?}");
        assert!(debug_str.contains("DeviceLaunchConfig"));
        assert!(debug_str.contains("args_count"));
    }

    // -- MultiDeviceCooperativeLaunchConfig tests --

    #[test]
    fn test_multi_device_config_empty() {
        let result = MultiDeviceCooperativeLaunchConfig::new(vec![]);
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_single() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg = DeviceLaunchConfig::new(0, func, (4, 1, 1), (256, 1, 1), stream);
        let count = MultiDeviceCooperativeLaunchConfig::new(vec![cfg]).map(|m| m.device_count());
        assert_eq!(count, Ok(1));
    }

    #[test]
    fn test_multi_device_config_mismatched_grid() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg0 = DeviceLaunchConfig::new(0, func, (4, 1, 1), (256, 1, 1), stream);
        let cfg1 = DeviceLaunchConfig::new(1, func, (8, 1, 1), (256, 1, 1), stream);
        let result = MultiDeviceCooperativeLaunchConfig::new(vec![cfg0, cfg1]);
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_mismatched_block() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg0 = DeviceLaunchConfig::new(0, func, (4, 1, 1), (256, 1, 1), stream);
        let cfg1 = DeviceLaunchConfig::new(1, func, (4, 1, 1), (128, 1, 1), stream);
        let result = MultiDeviceCooperativeLaunchConfig::new(vec![cfg0, cfg1]);
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_negative_ordinal() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg = DeviceLaunchConfig::new(-1, func, (4, 1, 1), (256, 1, 1), stream);
        let result = MultiDeviceCooperativeLaunchConfig::new(vec![cfg]);
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_zero_dim() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg = DeviceLaunchConfig::new(0, func, (0, 1, 1), (256, 1, 1), stream);
        let result = MultiDeviceCooperativeLaunchConfig::new(vec![cfg]);
        assert_eq!(result.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn test_multi_device_config_consistent_pair() {
        let func = CUfunction::default();
        let stream = CUstream::default();
        let cfg0 = DeviceLaunchConfig::new(0, func, (4, 2, 1), (128, 2, 1), stream);
        let cfg1 = DeviceLaunchConfig::new(1, func, (4, 2, 1), (128, 2, 1), stream);
        let count =
            MultiDeviceCooperativeLaunchConfig::new(vec![cfg0, cfg1]).map(|m| m.device_count());
        assert_eq!(count, Ok(2));
    }

    // -- cooperative_launch_multi_device (all platforms) --

    #[test]
    fn test_multi_device_launch_empty_is_invalid() {
        // Validation runs before any platform or driver check, so this holds
        // everywhere.
        let configs: &[DeviceLaunchConfig] = &[];
        let result = cooperative_launch_multi_device(configs);
        assert_eq!(result, Err(CudaError::InvalidValue));
    }

    // -- cooperative_launch on macOS --

    #[cfg(target_os = "macos")]
    #[test]
    fn test_cooperative_launch_config_validates_on_macos() {
        // A `Function` cannot be constructed without a loaded module, so the
        // launch itself is not exercised. Only the config validation that
        // precedes the platform check is covered.
        let config = CooperativeLaunchConfig::new((1, 1, 1), (32, 1, 1));
        assert!(config.validate().is_ok());
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_multi_device_launch_returns_not_supported_on_macos() {
        // A valid single-device configuration passes validation and then hits
        // the macOS stub.
        let cfg = DeviceLaunchConfig::new(
            0,
            CUfunction::default(),
            (1, 1, 1),
            (32, 1, 1),
            CUstream::default(),
        );
        let result = cooperative_launch_multi_device(&[cfg]);
        assert_eq!(result, Err(CudaError::NotSupported));
    }

    // -- CooperativeLaunchSupport tests (macOS synthetic) --

    #[cfg(target_os = "macos")]
    #[test]
    fn test_cooperative_support_query_macos() {
        // On macOS, Device::get will fail because the driver is not loaded.
        // Instead, check that the attribute discriminants match the CUDA ABI.
        assert_eq!(CUdevice_attribute::CooperativeLaunch as i32, 95);
        assert_eq!(CUdevice_attribute::CooperativeMultiDeviceLaunch as i32, 96);
    }

    #[test]
    fn test_occupancy_constants() {
        assert_eq!(CU_OCCUPANCY_DEFAULT, 0x0);
        assert_eq!(CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, 0x1);
    }

    #[test]
    fn test_config_large_total_blocks_no_overflow() {
        // Large grids must not overflow (u64 arithmetic).
        let config = CooperativeLaunchConfig::new((65535, 65535, 64), (1, 1, 1));
        let total = config.total_blocks();
        assert_eq!(total, 65535u64 * 65535 * 64);
    }
}