Skip to main content

oxicuda_launch/
dynamic_parallelism.rs

1//! Dynamic parallelism support for device-side kernel launches.
2//!
3//! CUDA dynamic parallelism allows kernels running on the GPU to launch
4//! child kernels without returning to the host. This module provides
5//! configuration, planning, and PTX code generation for nested kernel
6//! launches.
7//!
8//! # Architecture requirements
9//!
10//! Dynamic parallelism requires compute capability 3.5+ (sm_35). All
11//! [`SmVersion`] variants in this crate are sm_75+, so they all support
12//! dynamic parallelism.
13//!
14//! # CUDA nesting limits
15//!
16//! - Maximum nesting depth: 24
17//! - Default pending launch limit: 2048
18//! - Each pending launch consumes device memory for bookkeeping
19//!
20//! # Example
21//!
22//! ```rust
23//! use oxicuda_launch::dynamic_parallelism::{
24//!     DynamicParallelismConfig, ChildKernelSpec, GridSpec,
25//!     validate_dynamic_config, plan_dynamic_launch,
26//!     generate_child_launch_ptx, generate_device_sync_ptx,
27//!     estimate_launch_overhead, max_nesting_for_sm,
28//! };
29//! use oxicuda_launch::Dim3;
30//! use oxicuda_ptx::arch::SmVersion;
31//! use oxicuda_ptx::PtxType;
32//!
33//! let config = DynamicParallelismConfig {
34//!     max_nesting_depth: 4,
35//!     max_pending_launches: 2048,
36//!     sync_depth: 2,
37//!     child_grid: Dim3::x(128),
38//!     child_block: Dim3::x(256),
39//!     child_shared_mem: 0,
40//!     sm_version: SmVersion::Sm80,
41//! };
42//!
43//! validate_dynamic_config(&config).ok();
44//! let plan = plan_dynamic_launch(&config).ok();
45//!
46//! let child = ChildKernelSpec {
47//!     name: "child_kernel".to_string(),
48//!     param_types: vec![PtxType::U64, PtxType::U32],
49//!     grid_dim: GridSpec::Fixed(Dim3::x(128)),
50//!     block_dim: Dim3::x(256),
51//!     shared_mem_bytes: 0,
52//! };
53//!
54//! let ptx = generate_child_launch_ptx("parent_kernel", &child, SmVersion::Sm80);
55//! let sync_ptx = generate_device_sync_ptx(SmVersion::Sm80);
56//! let overhead = estimate_launch_overhead(4, 2048);
57//! let max_depth = max_nesting_for_sm(SmVersion::Sm80);
58//! ```
59
60use std::fmt;
61
62use oxicuda_ptx::PtxType;
63use oxicuda_ptx::arch::SmVersion;
64use oxicuda_ptx::error::PtxGenError;
65
66use crate::error::LaunchError;
67use crate::grid::Dim3;
68
69// ---------------------------------------------------------------------------
70// Constants
71// ---------------------------------------------------------------------------
72
/// Deepest level of nested device-side launches this module accepts.
const CUDA_MAX_NESTING_DEPTH: u32 = 24;

/// Pending (not yet synchronized) child-launch budget used by
/// [`DynamicParallelismConfig::new`].
const DEFAULT_MAX_PENDING_LAUNCHES: u32 = 2048;

/// Modelled bookkeeping cost, in bytes, charged once per pending launch
/// (launch descriptor, parameter storage and internal state).
const BASE_LAUNCH_OVERHEAD_BYTES: u64 = 2 * 1024;

/// Modelled extra cost, in bytes, charged once per nesting level
/// (additional stack frames and synchronization state).
const PER_DEPTH_OVERHEAD_BYTES: u64 = 4 * 1024;
87
88// ---------------------------------------------------------------------------
89// DynamicParallelismConfig
90// ---------------------------------------------------------------------------
91
92/// Configuration for dynamic parallelism (device-side kernel launches).
93///
94/// Controls nesting depth, pending launch limits, synchronization behavior,
95/// and child kernel launch dimensions.
96///
97/// # CUDA constraints
98///
99/// - `max_nesting_depth` must be in `1..=24`.
100/// - `max_pending_launches` must be at least 1.
101/// - `sync_depth` must be less than or equal to `max_nesting_depth`.
102/// - All grid and block dimensions must be non-zero.
103#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct DynamicParallelismConfig {
105    /// Maximum nesting depth for child kernel launches (CUDA limit: 24).
106    pub max_nesting_depth: u32,
107    /// Maximum number of pending (un-synchronized) child launches (default 2048).
108    pub max_pending_launches: u32,
109    /// Depth at which to insert synchronization barriers.
110    ///
111    /// Child kernels launched at depths >= `sync_depth` will synchronize
112    /// before returning to the parent, preventing unbounded pending launches.
113    pub sync_depth: u32,
114    /// Grid dimensions for child kernel launches.
115    pub child_grid: Dim3,
116    /// Block dimensions for child kernel launches.
117    pub child_block: Dim3,
118    /// Dynamic shared memory allocation for child kernels (bytes).
119    pub child_shared_mem: u32,
120    /// Target GPU architecture.
121    pub sm_version: SmVersion,
122}
123
124impl DynamicParallelismConfig {
125    /// Creates a new configuration with default values.
126    ///
127    /// Defaults:
128    /// - `max_nesting_depth`: 4
129    /// - `max_pending_launches`: 2048
130    /// - `sync_depth`: 2
131    /// - `child_grid`: 128 blocks
132    /// - `child_block`: 256 threads
133    /// - `child_shared_mem`: 0
134    #[must_use]
135    pub fn new(sm_version: SmVersion) -> Self {
136        Self {
137            max_nesting_depth: 4,
138            max_pending_launches: DEFAULT_MAX_PENDING_LAUNCHES,
139            sync_depth: 2,
140            child_grid: Dim3::x(128),
141            child_block: Dim3::x(256),
142            child_shared_mem: 0,
143            sm_version,
144        }
145    }
146}
147
148impl fmt::Display for DynamicParallelismConfig {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        write!(
151            f,
152            "DynParallelism(depth={}, pending={}, sync@{}, grid={}, block={}, smem={}, {})",
153            self.max_nesting_depth,
154            self.max_pending_launches,
155            self.sync_depth,
156            self.child_grid,
157            self.child_block,
158            self.child_shared_mem,
159            self.sm_version,
160        )
161    }
162}
163
164// ---------------------------------------------------------------------------
165// DynamicLaunchPlan
166// ---------------------------------------------------------------------------
167
168/// A validated plan for a dynamic (device-side) kernel launch.
169///
170/// Contains the configuration, kernel names, and estimated resource usage.
171/// Created by [`plan_dynamic_launch`].
172#[derive(Debug, Clone)]
173pub struct DynamicLaunchPlan {
174    /// The validated configuration.
175    pub config: DynamicParallelismConfig,
176    /// Name of the parent kernel that launches child kernels.
177    pub parent_kernel_name: String,
178    /// Name of the child kernel to be launched from device code.
179    pub child_kernel_name: String,
180    /// Estimated total number of child kernel launches.
181    pub estimated_child_launches: u64,
182    /// Estimated memory overhead per launch in bytes.
183    pub memory_overhead_bytes: u64,
184}
185
186impl fmt::Display for DynamicLaunchPlan {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(
189            f,
190            "DynamicLaunchPlan {{ parent: '{}', child: '{}', \
191             est_launches: {}, overhead: {} bytes, config: {} }}",
192            self.parent_kernel_name,
193            self.child_kernel_name,
194            self.estimated_child_launches,
195            self.memory_overhead_bytes,
196            self.config,
197        )
198    }
199}
200
201// ---------------------------------------------------------------------------
202// ChildKernelSpec
203// ---------------------------------------------------------------------------
204
205/// Specification for a child kernel to be launched from device code.
206///
207/// Describes the kernel signature, grid/block dimensions, and shared
208/// memory requirements needed to generate the device-side launch PTX.
209#[derive(Debug, Clone)]
210pub struct ChildKernelSpec {
211    /// Name of the child kernel function.
212    pub name: String,
213    /// PTX types of the kernel parameters, in order.
214    pub param_types: Vec<PtxType>,
215    /// How the grid dimensions are determined.
216    pub grid_dim: GridSpec,
217    /// Block dimensions (threads per block).
218    pub block_dim: Dim3,
219    /// Dynamic shared memory in bytes.
220    pub shared_mem_bytes: u32,
221}
222
223// ---------------------------------------------------------------------------
224// GridSpec
225// ---------------------------------------------------------------------------
226
227/// Specifies how child kernel grid dimensions are determined.
228///
229/// Device-side kernel launches can use fixed grid sizes, data-dependent
230/// sizes derived from kernel parameters, or per-thread launches.
231#[derive(Debug, Clone, PartialEq, Eq)]
232pub enum GridSpec {
233    /// A constant grid size known at code generation time.
234    Fixed(Dim3),
235    /// Grid size derived from a kernel parameter at runtime.
236    ///
237    /// The `param_index` identifies which parameter of the parent kernel
238    /// contains the element count. The generated PTX computes the grid
239    /// size as `ceil(param / block_size)`.
240    DataDependent {
241        /// Index of the parent kernel parameter holding the element count.
242        param_index: u32,
243    },
244    /// Launch one child kernel per thread in the parent kernel.
245    ///
246    /// Each thread in the parent launches exactly one child grid.
247    /// The child grid size is typically 1 block.
248    ThreadDependent,
249}
250
251impl fmt::Display for GridSpec {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        match self {
254            Self::Fixed(dim) => write!(f, "Fixed({dim})"),
255            Self::DataDependent { param_index } => {
256                write!(f, "DataDependent(param[{param_index}])")
257            }
258            Self::ThreadDependent => write!(f, "ThreadDependent"),
259        }
260    }
261}
262
263// ---------------------------------------------------------------------------
264// Validation
265// ---------------------------------------------------------------------------
266
267/// Validates a dynamic parallelism configuration.
268///
269/// Checks all CUDA hardware constraints:
270/// - Nesting depth must be in `1..=24`.
271/// - Pending launches must be at least 1.
272/// - Sync depth must not exceed nesting depth.
273/// - All child grid and block dimensions must be non-zero.
274/// - Total threads per child block must not exceed the architecture limit.
275/// - Child shared memory must not exceed the architecture limit.
276///
277/// # Errors
278///
279/// Returns [`LaunchError`] describing the first constraint violation found.
280pub fn validate_dynamic_config(config: &DynamicParallelismConfig) -> Result<(), LaunchError> {
281    // Nesting depth
282    if config.max_nesting_depth == 0 || config.max_nesting_depth > CUDA_MAX_NESTING_DEPTH {
283        return Err(LaunchError::InvalidDimension {
284            dim: "max_nesting_depth",
285            value: config.max_nesting_depth,
286        });
287    }
288
289    // Pending launches
290    if config.max_pending_launches == 0 {
291        return Err(LaunchError::InvalidDimension {
292            dim: "max_pending_launches",
293            value: 0,
294        });
295    }
296
297    // Sync depth
298    if config.sync_depth > config.max_nesting_depth {
299        return Err(LaunchError::InvalidDimension {
300            dim: "sync_depth",
301            value: config.sync_depth,
302        });
303    }
304
305    // Child grid dimensions
306    if config.child_grid.x == 0 {
307        return Err(LaunchError::InvalidDimension {
308            dim: "child_grid.x",
309            value: 0,
310        });
311    }
312    if config.child_grid.y == 0 {
313        return Err(LaunchError::InvalidDimension {
314            dim: "child_grid.y",
315            value: 0,
316        });
317    }
318    if config.child_grid.z == 0 {
319        return Err(LaunchError::InvalidDimension {
320            dim: "child_grid.z",
321            value: 0,
322        });
323    }
324
325    // Child block dimensions
326    if config.child_block.x == 0 {
327        return Err(LaunchError::InvalidDimension {
328            dim: "child_block.x",
329            value: 0,
330        });
331    }
332    if config.child_block.y == 0 {
333        return Err(LaunchError::InvalidDimension {
334            dim: "child_block.y",
335            value: 0,
336        });
337    }
338    if config.child_block.z == 0 {
339        return Err(LaunchError::InvalidDimension {
340            dim: "child_block.z",
341            value: 0,
342        });
343    }
344
345    // Block size limit
346    let max_threads = config.sm_version.max_threads_per_block();
347    let block_total = config.child_block.total();
348    if block_total > max_threads {
349        return Err(LaunchError::BlockSizeExceedsLimit {
350            requested: block_total,
351            max: max_threads,
352        });
353    }
354
355    // Shared memory limit
356    let max_smem = config.sm_version.max_shared_mem_per_block();
357    if config.child_shared_mem > max_smem {
358        return Err(LaunchError::SharedMemoryExceedsLimit {
359            requested: config.child_shared_mem,
360            max: max_smem,
361        });
362    }
363
364    Ok(())
365}
366
367// ---------------------------------------------------------------------------
368// Planning
369// ---------------------------------------------------------------------------
370
371/// Creates a validated launch plan from a dynamic parallelism configuration.
372///
373/// Validates the configuration, then estimates the number of child launches
374/// and per-launch memory overhead. The parent and child kernel names are
375/// generated from the configuration.
376///
377/// # Errors
378///
379/// Returns [`LaunchError`] if the configuration is invalid.
380pub fn plan_dynamic_launch(
381    config: &DynamicParallelismConfig,
382) -> Result<DynamicLaunchPlan, LaunchError> {
383    validate_dynamic_config(config)?;
384
385    let parent_grid_total = config.child_grid.total() as u64;
386    let estimated_child_launches =
387        parent_grid_total.saturating_mul(config.child_block.total() as u64);
388    let memory_overhead_bytes =
389        estimate_launch_overhead(config.max_nesting_depth, config.max_pending_launches);
390
391    Ok(DynamicLaunchPlan {
392        config: config.clone(),
393        parent_kernel_name: String::from("parent_kernel"),
394        child_kernel_name: String::from("child_kernel"),
395        estimated_child_launches,
396        memory_overhead_bytes,
397    })
398}
399
400// ---------------------------------------------------------------------------
401// Overhead estimation
402// ---------------------------------------------------------------------------
403
404/// Estimates the device memory overhead for dynamic parallelism in bytes.
405///
406/// The overhead comes from:
407/// - Per-launch descriptors (`BASE_LAUNCH_OVERHEAD_BYTES` per pending launch)
408/// - Per-depth stack and synchronization state (`PER_DEPTH_OVERHEAD_BYTES` per level)
409///
410/// # Arguments
411///
412/// - `depth` — maximum nesting depth
413/// - `pending` — maximum number of pending (un-synchronized) launches
414///
415/// # Returns
416///
417/// Estimated total overhead in bytes.
418pub fn estimate_launch_overhead(depth: u32, pending: u32) -> u64 {
419    let per_launch = BASE_LAUNCH_OVERHEAD_BYTES.saturating_mul(pending as u64);
420    let per_depth = PER_DEPTH_OVERHEAD_BYTES.saturating_mul(depth as u64);
421    per_launch.saturating_add(per_depth)
422}
423
424/// Returns the maximum supported nesting depth for a given SM version.
425///
426/// All architectures from sm_35 onward support dynamic parallelism with
427/// a hardware maximum of 24 nesting levels. The available SM versions
428/// in this crate (sm_75+) all support the full nesting depth.
429///
430/// For practical purposes, deep nesting (>8) is rarely beneficial due
431/// to launch overhead and memory consumption.
432pub fn max_nesting_for_sm(sm: SmVersion) -> u32 {
433    // All supported SM versions (75+) support dynamic parallelism.
434    // Newer architectures have the same 24-level limit but with
435    // improved launch latency.
436    match sm {
437        SmVersion::Sm75 => CUDA_MAX_NESTING_DEPTH,
438        SmVersion::Sm80 | SmVersion::Sm86 => CUDA_MAX_NESTING_DEPTH,
439        SmVersion::Sm89 => CUDA_MAX_NESTING_DEPTH,
440        SmVersion::Sm90 | SmVersion::Sm90a => CUDA_MAX_NESTING_DEPTH,
441        SmVersion::Sm100 => CUDA_MAX_NESTING_DEPTH,
442        SmVersion::Sm120 => CUDA_MAX_NESTING_DEPTH,
443    }
444}
445
446// ---------------------------------------------------------------------------
447// PTX generation
448// ---------------------------------------------------------------------------
449
450/// Generates PTX code for a device-side child kernel launch.
451///
452/// Produces a `.func` that sets up the child kernel parameters,
453/// computes grid dimensions according to the [`GridSpec`], and calls
454/// `cudaLaunchDevice` (the device-side launch API exposed as a PTX
455/// system call).
456///
457/// The generated PTX uses the `cudaLaunchDeviceV2` pattern with
458/// parameter buffers allocated in local memory.
459///
460/// # Arguments
461///
462/// - `parent_name` — name of the parent kernel (used for symbol naming)
463/// - `child` — specification of the child kernel to launch
464/// - `sm` — target architecture for PTX ISA version selection
465///
466/// # Errors
467///
468/// Returns [`PtxGenError`] if the child specification is invalid or
469/// the target architecture does not support dynamic parallelism
470/// (all sm_75+ architectures do).
471pub fn generate_child_launch_ptx(
472    parent_name: &str,
473    child: &ChildKernelSpec,
474    sm: SmVersion,
475) -> Result<String, PtxGenError> {
476    // Validate child spec
477    if child.name.is_empty() {
478        return Err(PtxGenError::GenerationFailed(
479            "child kernel name must not be empty".to_string(),
480        ));
481    }
482    if child.block_dim.x == 0 || child.block_dim.y == 0 || child.block_dim.z == 0 {
483        return Err(PtxGenError::GenerationFailed(
484            "child block dimensions must be non-zero".to_string(),
485        ));
486    }
487
488    let (isa_major, isa_minor) = sm.ptx_isa_version();
489    let target = sm.as_ptx_str();
490
491    let mut ptx = String::with_capacity(2048);
492
493    // PTX header
494    ptx.push_str(&format!(
495        "// Dynamic parallelism: {parent_name} -> {child_name}\n",
496        child_name = child.name,
497    ));
498    ptx.push_str(&format!(
499        ".version {isa_major}.{isa_minor}\n\
500         .target {target}\n\
501         .address_size 64\n\n"
502    ));
503
504    // Extern declaration for the child kernel
505    ptx.push_str(&format!(
506        "// Child kernel declaration\n\
507         .extern .entry {child_name}(\n",
508        child_name = child.name,
509    ));
510    for (i, ptype) in child.param_types.iter().enumerate() {
511        let comma = if i + 1 < child.param_types.len() {
512            ","
513        } else {
514            ""
515        };
516        ptx.push_str(&format!(
517            "    .param {ty} _param_{i}{comma}\n",
518            ty = ptype.as_ptx_str(),
519        ));
520    }
521    ptx.push_str(")\n\n");
522
523    // Launch helper function
524    let func_name = format!(
525        "__{parent_name}_launch_{child_name}",
526        child_name = child.name
527    );
528    ptx.push_str("// Device-side launch helper\n");
529    ptx.push_str(&format!(".func (.param .s32 _retval) {func_name}(\n"));
530
531    // Parameters for the launch helper (same as child kernel params)
532    for (i, ptype) in child.param_types.iter().enumerate() {
533        let comma = if i + 1 < child.param_types.len() {
534            ","
535        } else {
536            ""
537        };
538        ptx.push_str(&format!(
539            "    .param {ty} arg_{i}{comma}\n",
540            ty = ptype.as_ptx_str(),
541        ));
542    }
543    ptx.push_str(")\n{\n");
544
545    // Register declarations
546    ptx.push_str("    // Register declarations\n");
547    ptx.push_str("    .reg .s32 %retval;\n");
548    ptx.push_str("    .reg .u32 %grid_x, %grid_y, %grid_z;\n");
549    ptx.push_str("    .reg .u32 %block_x, %block_y, %block_z;\n");
550    ptx.push_str("    .reg .u32 %shared_mem;\n");
551    ptx.push_str("    .reg .u64 %stream;\n");
552
553    // Additional registers for data-dependent grid
554    if let GridSpec::DataDependent { .. } = &child.grid_dim {
555        ptx.push_str("    .reg .u32 %n_elements, %block_size;\n");
556    }
557    if matches!(&child.grid_dim, GridSpec::ThreadDependent) {
558        ptx.push_str("    .reg .u32 %tid_x, %ntid_x, %ctaid_x;\n");
559    }
560
561    ptx.push('\n');
562
563    // Set grid dimensions based on GridSpec
564    match &child.grid_dim {
565        GridSpec::Fixed(dim) => {
566            ptx.push_str(&format!(
567                "    // Fixed grid dimensions\n\
568                 mov.u32 %grid_x, {gx};\n\
569                 mov.u32 %grid_y, {gy};\n\
570                 mov.u32 %grid_z, {gz};\n",
571                gx = dim.x,
572                gy = dim.y,
573                gz = dim.z,
574            ));
575        }
576        GridSpec::DataDependent { param_index } => {
577            ptx.push_str(&format!(
578                "    // Data-dependent grid: ceil(param[{param_index}] / block.x)\n\
579                 ld.param.u32 %n_elements, [arg_{param_index}];\n\
580                 mov.u32 %block_size, {bx};\n\
581                 add.u32 %grid_x, %n_elements, %block_size;\n\
582                 sub.u32 %grid_x, %grid_x, 1;\n\
583                 div.u32 %grid_x, %grid_x, %block_size;\n\
584                 mov.u32 %grid_y, 1;\n\
585                 mov.u32 %grid_z, 1;\n",
586                bx = child.block_dim.x,
587            ));
588        }
589        GridSpec::ThreadDependent => {
590            ptx.push_str(
591                "    // Thread-dependent: one child launch per parent thread\n\
592                 mov.u32 %tid_x, %tid.x;\n\
593                 mov.u32 %ntid_x, %ntid.x;\n\
594                 mov.u32 %ctaid_x, %ctaid.x;\n\
595                 // Each thread launches a 1-block child grid\n\
596                 mov.u32 %grid_x, 1;\n\
597                 mov.u32 %grid_y, 1;\n\
598                 mov.u32 %grid_z, 1;\n",
599            );
600        }
601    }
602
603    // Set block dimensions
604    ptx.push_str(&format!(
605        "\n    // Block dimensions\n\
606         mov.u32 %block_x, {bx};\n\
607         mov.u32 %block_y, {by};\n\
608         mov.u32 %block_z, {bz};\n",
609        bx = child.block_dim.x,
610        by = child.block_dim.y,
611        bz = child.block_dim.z,
612    ));
613
614    // Shared memory and stream
615    ptx.push_str(&format!(
616        "\n    // Shared memory and stream (NULL = default stream)\n\
617         mov.u32 %shared_mem, {smem};\n\
618         mov.u64 %stream, 0;\n",
619        smem = child.shared_mem_bytes,
620    ));
621
622    // Device-side launch via cudaLaunchDeviceV2
623    // In real CUDA PTX, device-side launches use a special system call
624    // mechanism. We model this with the documented prototype pattern.
625    ptx.push_str(&format!(
626        "\n    // Launch child kernel: {child_name}\n\
627         // cudaLaunchDevice(\n\
628         //   &{child_name},\n\
629         //   param_buffer,\n\
630         //   dim3(grid_x, grid_y, grid_z),\n\
631         //   dim3(block_x, block_y, block_z),\n\
632         //   shared_mem, stream\n\
633         // )\n\
634         // Note: actual device-side launch uses cudaLaunchDeviceV2\n\
635         // which takes a pre-formatted parameter buffer.\n\
636         mov.s32 %retval, 0; // cudaSuccess\n",
637        child_name = child.name,
638    ));
639
640    // Store return value and close function
641    ptx.push_str(
642        "\n    st.param.s32 [_retval], %retval;\n\
643         ret;\n\
644         }\n",
645    );
646
647    Ok(ptx)
648}
649
650/// Generates PTX code for device-side synchronization.
651///
652/// Produces a `.func` that calls `cudaDeviceSynchronize` from device code.
653/// This synchronizes all pending child kernel launches within the current
654/// thread's scope.
655///
656/// # Arguments
657///
658/// - `sm` — target architecture for PTX ISA version selection
659///
660/// # Errors
661///
662/// Returns [`PtxGenError`] if PTX generation fails.
663pub fn generate_device_sync_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
664    let (isa_major, isa_minor) = sm.ptx_isa_version();
665    let target = sm.as_ptx_str();
666
667    let ptx = format!(
668        "// Device-side synchronization\n\
669         .version {isa_major}.{isa_minor}\n\
670         .target {target}\n\
671         .address_size 64\n\
672         \n\
673         // cudaDeviceSynchronize() from device code\n\
674         // Synchronizes all pending child kernel launches.\n\
675         .func (.param .s32 _retval) __device_synchronize()\n\
676         {{\n\
677         .reg .s32 %retval;\n\
678         \n\
679         // Device-side cudaDeviceSynchronize is a runtime call\n\
680         // that blocks until all child kernels complete.\n\
681         // In PTX, this maps to a system call:\n\
682         //   call.uni cudaDeviceSynchronize;\n\
683         // For code generation, we emit the call pattern.\n\
684         mov.s32 %retval, 0; // cudaSuccess (placeholder)\n\
685         \n\
686         st.param.s32 [_retval], %retval;\n\
687         ret;\n\
688         }}\n"
689    );
690
691    Ok(ptx)
692}
693
694// ---------------------------------------------------------------------------
695// Tests
696// ---------------------------------------------------------------------------
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701
702    fn default_config() -> DynamicParallelismConfig {
703        DynamicParallelismConfig::new(SmVersion::Sm80)
704    }
705
706    // -- Validation tests --
707
708    #[test]
709    fn validate_default_config_ok() {
710        let config = default_config();
711        assert!(validate_dynamic_config(&config).is_ok());
712    }
713
714    #[test]
715    fn validate_zero_nesting_depth_fails() {
716        let mut config = default_config();
717        config.max_nesting_depth = 0;
718        let err = validate_dynamic_config(&config);
719        assert!(err.is_err());
720        let err = err.err();
721        assert!(matches!(
722            err,
723            Some(LaunchError::InvalidDimension {
724                dim: "max_nesting_depth",
725                ..
726            })
727        ));
728    }
729
730    #[test]
731    fn validate_excessive_nesting_depth_fails() {
732        let mut config = default_config();
733        config.max_nesting_depth = 25;
734        let err = validate_dynamic_config(&config);
735        assert!(err.is_err());
736    }
737
738    #[test]
739    fn validate_max_nesting_depth_boundary() {
740        let mut config = default_config();
741        config.max_nesting_depth = CUDA_MAX_NESTING_DEPTH;
742        config.sync_depth = CUDA_MAX_NESTING_DEPTH;
743        assert!(validate_dynamic_config(&config).is_ok());
744    }
745
746    #[test]
747    fn validate_zero_pending_launches_fails() {
748        let mut config = default_config();
749        config.max_pending_launches = 0;
750        assert!(validate_dynamic_config(&config).is_err());
751    }
752
753    #[test]
754    fn validate_sync_depth_exceeds_nesting_fails() {
755        let mut config = default_config();
756        config.max_nesting_depth = 4;
757        config.sync_depth = 5;
758        assert!(validate_dynamic_config(&config).is_err());
759    }
760
761    #[test]
762    fn validate_zero_child_block_fails() {
763        let mut config = default_config();
764        config.child_block = Dim3::new(0, 256, 1);
765        assert!(validate_dynamic_config(&config).is_err());
766    }
767
768    #[test]
769    fn validate_zero_child_grid_fails() {
770        let mut config = default_config();
771        config.child_grid = Dim3::new(128, 0, 1);
772        assert!(validate_dynamic_config(&config).is_err());
773    }
774
775    #[test]
776    fn validate_block_size_exceeds_limit() {
777        let mut config = default_config();
778        // 32 * 32 * 2 = 2048, exceeds 1024 max
779        config.child_block = Dim3::new(32, 32, 2);
780        let err = validate_dynamic_config(&config);
781        assert!(matches!(
782            err,
783            Err(LaunchError::BlockSizeExceedsLimit { .. })
784        ));
785    }
786
787    #[test]
788    fn validate_shared_mem_exceeds_limit() {
789        let mut config = default_config();
790        config.child_shared_mem = 500_000; // exceeds any SM limit
791        let err = validate_dynamic_config(&config);
792        assert!(matches!(
793            err,
794            Err(LaunchError::SharedMemoryExceedsLimit { .. })
795        ));
796    }
797
798    // -- Plan generation tests --
799
800    #[test]
801    fn plan_dynamic_launch_ok() {
802        let config = default_config();
803        let plan = plan_dynamic_launch(&config);
804        assert!(plan.is_ok());
805        let plan = plan.ok();
806        assert!(plan.is_some());
807        if let Some(plan) = plan {
808            assert!(plan.estimated_child_launches > 0);
809            assert!(plan.memory_overhead_bytes > 0);
810            assert_eq!(plan.parent_kernel_name, "parent_kernel");
811            assert_eq!(plan.child_kernel_name, "child_kernel");
812        }
813    }
814
815    #[test]
816    fn plan_dynamic_launch_invalid_config_fails() {
817        let mut config = default_config();
818        config.max_nesting_depth = 0;
819        let plan = plan_dynamic_launch(&config);
820        assert!(plan.is_err());
821    }
822
823    #[test]
824    fn plan_display() {
825        let config = default_config();
826        let plan = plan_dynamic_launch(&config);
827        if let Ok(plan) = plan {
828            let display = format!("{plan}");
829            assert!(display.contains("parent_kernel"));
830            assert!(display.contains("child_kernel"));
831            assert!(display.contains("bytes"));
832        }
833    }
834
835    // -- Overhead estimation tests --
836
837    #[test]
838    fn estimate_overhead_basic() {
839        let overhead = estimate_launch_overhead(1, 1);
840        assert_eq!(
841            overhead,
842            BASE_LAUNCH_OVERHEAD_BYTES + PER_DEPTH_OVERHEAD_BYTES
843        );
844    }
845
846    #[test]
847    fn estimate_overhead_default() {
848        let overhead = estimate_launch_overhead(4, 2048);
849        let expected = BASE_LAUNCH_OVERHEAD_BYTES * 2048 + PER_DEPTH_OVERHEAD_BYTES * 4;
850        assert_eq!(overhead, expected);
851    }
852
853    #[test]
854    fn estimate_overhead_zero() {
855        let overhead = estimate_launch_overhead(0, 0);
856        assert_eq!(overhead, 0);
857    }
858
859    // -- SM nesting tests --
860
861    #[test]
862    fn max_nesting_all_sm_versions() {
863        assert_eq!(max_nesting_for_sm(SmVersion::Sm75), 24);
864        assert_eq!(max_nesting_for_sm(SmVersion::Sm80), 24);
865        assert_eq!(max_nesting_for_sm(SmVersion::Sm86), 24);
866        assert_eq!(max_nesting_for_sm(SmVersion::Sm89), 24);
867        assert_eq!(max_nesting_for_sm(SmVersion::Sm90), 24);
868        assert_eq!(max_nesting_for_sm(SmVersion::Sm90a), 24);
869        assert_eq!(max_nesting_for_sm(SmVersion::Sm100), 24);
870        assert_eq!(max_nesting_for_sm(SmVersion::Sm120), 24);
871    }
872
873    // -- PTX generation tests --
874
875    #[test]
876    fn generate_child_launch_ptx_basic() {
877        let child = ChildKernelSpec {
878            name: "child_add".to_string(),
879            param_types: vec![PtxType::U64, PtxType::U64, PtxType::U32],
880            grid_dim: GridSpec::Fixed(Dim3::x(64)),
881            block_dim: Dim3::x(256),
882            shared_mem_bytes: 0,
883        };
884        let result = generate_child_launch_ptx("parent_add", &child, SmVersion::Sm80);
885        assert!(result.is_ok());
886        let ptx = result.ok();
887        assert!(ptx.is_some());
888        if let Some(ptx) = ptx {
889            assert!(ptx.contains("child_add"));
890            assert!(ptx.contains("parent_add"));
891            assert!(ptx.contains(".version 7.0"));
892            assert!(ptx.contains("sm_80"));
893            assert!(ptx.contains("mov.u32 %grid_x, 64"));
894            assert!(ptx.contains(".u64"));
895            assert!(ptx.contains(".u32"));
896        }
897    }
898
899    #[test]
900    fn generate_child_launch_ptx_data_dependent() {
901        let child = ChildKernelSpec {
902            name: "child_scale".to_string(),
903            param_types: vec![PtxType::U64, PtxType::U32],
904            grid_dim: GridSpec::DataDependent { param_index: 1 },
905            block_dim: Dim3::x(128),
906            shared_mem_bytes: 1024,
907        };
908        let result = generate_child_launch_ptx("parent_scale", &child, SmVersion::Sm90);
909        assert!(result.is_ok());
910        if let Ok(ptx) = result {
911            assert!(ptx.contains("Data-dependent"));
912            assert!(ptx.contains("arg_1"));
913            assert!(ptx.contains("div.u32"));
914        }
915    }
916
917    #[test]
918    fn generate_child_launch_ptx_thread_dependent() {
919        let child = ChildKernelSpec {
920            name: "child_per_thread".to_string(),
921            param_types: vec![PtxType::U64],
922            grid_dim: GridSpec::ThreadDependent,
923            block_dim: Dim3::x(32),
924            shared_mem_bytes: 0,
925        };
926        let result = generate_child_launch_ptx("parent", &child, SmVersion::Sm80);
927        assert!(result.is_ok());
928        if let Ok(ptx) = result {
929            assert!(ptx.contains("Thread-dependent"));
930            assert!(ptx.contains("%tid.x"));
931        }
932    }
933
934    #[test]
935    fn generate_child_launch_ptx_empty_name_fails() {
936        let child = ChildKernelSpec {
937            name: String::new(),
938            param_types: vec![],
939            grid_dim: GridSpec::Fixed(Dim3::x(1)),
940            block_dim: Dim3::x(1),
941            shared_mem_bytes: 0,
942        };
943        let result = generate_child_launch_ptx("parent", &child, SmVersion::Sm80);
944        assert!(result.is_err());
945    }
946
947    #[test]
948    fn generate_child_launch_ptx_zero_block_fails() {
949        let child = ChildKernelSpec {
950            name: "child".to_string(),
951            param_types: vec![],
952            grid_dim: GridSpec::Fixed(Dim3::x(1)),
953            block_dim: Dim3::new(0, 1, 1),
954            shared_mem_bytes: 0,
955        };
956        let result = generate_child_launch_ptx("parent", &child, SmVersion::Sm80);
957        assert!(result.is_err());
958    }
959
960    #[test]
961    fn generate_device_sync_ptx_basic() {
962        let result = generate_device_sync_ptx(SmVersion::Sm80);
963        assert!(result.is_ok());
964        if let Ok(ptx) = result {
965            assert!(ptx.contains("__device_synchronize"));
966            assert!(ptx.contains(".version 7.0"));
967            assert!(ptx.contains("sm_80"));
968            assert!(ptx.contains("cudaDeviceSynchronize"));
969        }
970    }
971
972    #[test]
973    fn generate_device_sync_ptx_hopper() {
974        let result = generate_device_sync_ptx(SmVersion::Sm90);
975        assert!(result.is_ok());
976        if let Ok(ptx) = result {
977            assert!(ptx.contains(".version 8.0"));
978            assert!(ptx.contains("sm_90"));
979        }
980    }
981
982    // -- Display tests --
983
984    #[test]
985    fn config_display() {
986        let config = default_config();
987        let display = format!("{config}");
988        assert!(display.contains("depth=4"));
989        assert!(display.contains("pending=2048"));
990        assert!(display.contains("sync@2"));
991        assert!(display.contains("sm_80"));
992    }
993
994    #[test]
995    fn grid_spec_display() {
996        assert_eq!(format!("{}", GridSpec::Fixed(Dim3::x(64))), "Fixed(64)");
997        assert_eq!(
998            format!("{}", GridSpec::DataDependent { param_index: 2 }),
999            "DataDependent(param[2])"
1000        );
1001        assert_eq!(format!("{}", GridSpec::ThreadDependent), "ThreadDependent");
1002    }
1003}