use std::fmt;
use oxicuda_ptx::PtxType;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::error::PtxGenError;
use crate::error::LaunchError;
use crate::grid::Dim3;
// Hard ceiling on kernel nesting depth accepted by `validate_dynamic_config`
// (CUDA documents 24 as the device-runtime nesting limit).
const CUDA_MAX_NESTING_DEPTH: u32 = 24;
// Default cap on outstanding device-side child launches (see `DynamicParallelismConfig::new`).
const DEFAULT_MAX_PENDING_LAUNCHES: u32 = 2048;
// Estimated bookkeeping bytes charged per pending launch in `estimate_launch_overhead`.
const BASE_LAUNCH_OVERHEAD_BYTES: u64 = 2048;
// Estimated bookkeeping bytes charged per nesting level in `estimate_launch_overhead`.
const PER_DEPTH_OVERHEAD_BYTES: u64 = 4096;
/// Configuration for device-side (dynamic parallelism) kernel launches.
///
/// Constraints are enforced by [`validate_dynamic_config`]; construct with
/// [`DynamicParallelismConfig::new`] for safe defaults.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DynamicParallelismConfig {
    /// Maximum kernel nesting depth; validated to be in `1..=CUDA_MAX_NESTING_DEPTH`.
    pub max_nesting_depth: u32,
    /// Maximum outstanding child launches; validated to be non-zero.
    pub max_pending_launches: u32,
    /// Depth up to which device-side synchronization is allowed;
    /// validated to be `<= max_nesting_depth`.
    pub sync_depth: u32,
    /// Grid dimensions for child launches; every component must be non-zero.
    pub child_grid: Dim3,
    /// Block dimensions for child launches; every component must be non-zero
    /// and the total thread count must fit the SM's per-block limit.
    pub child_block: Dim3,
    /// Dynamic shared memory per child block, in bytes; validated against
    /// the SM's per-block shared-memory limit.
    pub child_shared_mem: u32,
    /// Target SM architecture used to look up hardware limits.
    pub sm_version: SmVersion,
}
impl DynamicParallelismConfig {
    /// Builds a configuration with conservative defaults for the given
    /// SM architecture: depth 4, 2048 pending launches, sync depth 2,
    /// a 128-block 1-D grid of 256-thread blocks, and no shared memory.
    #[must_use]
    pub fn new(sm_version: SmVersion) -> Self {
        Self {
            sm_version,
            max_nesting_depth: 4,
            sync_depth: 2,
            max_pending_launches: DEFAULT_MAX_PENDING_LAUNCHES,
            child_grid: Dim3::x(128),
            child_block: Dim3::x(256),
            child_shared_mem: 0,
        }
    }
}
impl fmt::Display for DynamicParallelismConfig {
    /// Renders a compact one-line summary of the configuration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure once so the format string can use inline captures.
        let Self {
            max_nesting_depth,
            max_pending_launches,
            sync_depth,
            child_grid,
            child_block,
            child_shared_mem,
            sm_version,
        } = self;
        write!(
            f,
            "DynParallelism(depth={max_nesting_depth}, pending={max_pending_launches}, \
             sync@{sync_depth}, grid={child_grid}, block={child_block}, \
             smem={child_shared_mem}, {sm_version})"
        )
    }
}
/// A validated plan for a parent kernel launching child kernels,
/// produced by [`plan_dynamic_launch`].
#[derive(Debug, Clone)]
pub struct DynamicLaunchPlan {
    /// The (already validated) configuration this plan was derived from.
    pub config: DynamicParallelismConfig,
    /// Name of the parent (launching) kernel.
    pub parent_kernel_name: String,
    /// Name of the child (launched) kernel.
    pub child_kernel_name: String,
    /// Estimated number of device-side child launches
    /// (child grid blocks x child block threads).
    pub estimated_child_launches: u64,
    /// Estimated device memory reserved for launch bookkeeping, in bytes.
    pub memory_overhead_bytes: u64,
}
impl fmt::Display for DynamicLaunchPlan {
    /// Renders the plan as a single human-readable line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit in two pieces: kernel identities first, then the numbers.
        write!(
            f,
            "DynamicLaunchPlan {{ parent: '{}', child: '{}', ",
            self.parent_kernel_name, self.child_kernel_name
        )?;
        write!(
            f,
            "est_launches: {}, overhead: {} bytes, config: {} }}",
            self.estimated_child_launches, self.memory_overhead_bytes, self.config
        )
    }
}
/// Description of a child kernel to be launched from device code,
/// consumed by [`generate_child_launch_ptx`].
#[derive(Debug, Clone)]
pub struct ChildKernelSpec {
    /// PTX entry name of the child kernel; must be non-empty.
    pub name: String,
    /// Parameter types in declaration order; indices are referenced by
    /// `GridSpec::DataDependent { param_index }`.
    pub param_types: Vec<PtxType>,
    /// How the child grid dimensions are determined at launch time.
    pub grid_dim: GridSpec,
    /// Child block dimensions; every component must be non-zero.
    pub block_dim: Dim3,
    /// Dynamic shared memory per child block, in bytes.
    pub shared_mem_bytes: u32,
}
/// How a child kernel's grid dimensions are computed at launch time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GridSpec {
    /// Grid dimensions fixed at code-generation time.
    Fixed(Dim3),
    /// 1-D grid sized from a kernel parameter:
    /// `grid.x = ceil(param[param_index] / block.x)`.
    DataDependent {
        /// Index of the (u32) kernel parameter holding the element count.
        param_index: u32,
    },
    /// Each parent thread launches its own single-block child grid.
    ThreadDependent,
}
impl fmt::Display for GridSpec {
    /// Renders the spec in a short debug-friendly form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ThreadDependent => f.write_str("ThreadDependent"),
            Self::Fixed(dim) => write!(f, "Fixed({dim})"),
            Self::DataDependent { param_index } => {
                write!(f, "DataDependent(param[{param_index}])")
            }
        }
    }
}
/// Validates a dynamic-parallelism configuration against CUDA limits.
///
/// Checks, in order:
/// 1. `max_nesting_depth` in `1..=CUDA_MAX_NESTING_DEPTH`
/// 2. `max_pending_launches` non-zero
/// 3. `sync_depth <= max_nesting_depth`
/// 4. every `child_grid` / `child_block` component non-zero
/// 5. child block thread count within the SM's per-block thread limit
/// 6. child shared memory within the SM's per-block shared-memory limit
///
/// # Errors
///
/// Returns the first violated constraint as a [`LaunchError`].
pub fn validate_dynamic_config(config: &DynamicParallelismConfig) -> Result<(), LaunchError> {
    if config.max_nesting_depth == 0 || config.max_nesting_depth > CUDA_MAX_NESTING_DEPTH {
        return Err(LaunchError::InvalidDimension {
            dim: "max_nesting_depth",
            value: config.max_nesting_depth,
        });
    }
    if config.max_pending_launches == 0 {
        return Err(LaunchError::InvalidDimension {
            dim: "max_pending_launches",
            value: 0,
        });
    }
    if config.sync_depth > config.max_nesting_depth {
        return Err(LaunchError::InvalidDimension {
            dim: "sync_depth",
            value: config.sync_depth,
        });
    }
    // Table-driven zero check for every grid/block component, replacing six
    // copy-pasted `if` blocks. Order matches the original error priority:
    // grid x/y/z first, then block x/y/z.
    let components = [
        ("child_grid.x", config.child_grid.x),
        ("child_grid.y", config.child_grid.y),
        ("child_grid.z", config.child_grid.z),
        ("child_block.x", config.child_block.x),
        ("child_block.y", config.child_block.y),
        ("child_block.z", config.child_block.z),
    ];
    for (dim, value) in components {
        if value == 0 {
            return Err(LaunchError::InvalidDimension { dim, value: 0 });
        }
    }
    // Hardware limits depend on the target SM version.
    let max_threads = config.sm_version.max_threads_per_block();
    let block_total = config.child_block.total();
    if block_total > max_threads {
        return Err(LaunchError::BlockSizeExceedsLimit {
            requested: block_total,
            max: max_threads,
        });
    }
    let max_smem = config.sm_version.max_shared_mem_per_block();
    if config.child_shared_mem > max_smem {
        return Err(LaunchError::SharedMemoryExceedsLimit {
            requested: config.child_shared_mem,
            max: max_smem,
        });
    }
    Ok(())
}
/// Validates `config` and derives a launch plan for it.
///
/// The launch estimate is an upper bound: child grid blocks multiplied by
/// child block threads (saturating at `u64::MAX`), as if every child thread
/// slot performed a launch. Kernel names are fixed placeholders.
///
/// # Errors
///
/// Propagates any [`LaunchError`] from [`validate_dynamic_config`].
pub fn plan_dynamic_launch(
    config: &DynamicParallelismConfig,
) -> Result<DynamicLaunchPlan, LaunchError> {
    validate_dynamic_config(config)?;
    // NOTE: computed from the *child* grid/block settings in `config`.
    let grid_blocks = config.child_grid.total() as u64;
    let threads_per_block = config.child_block.total() as u64;
    let estimated_child_launches = grid_blocks.saturating_mul(threads_per_block);
    let memory_overhead_bytes =
        estimate_launch_overhead(config.max_nesting_depth, config.max_pending_launches);
    Ok(DynamicLaunchPlan {
        parent_kernel_name: "parent_kernel".to_string(),
        child_kernel_name: "child_kernel".to_string(),
        estimated_child_launches,
        memory_overhead_bytes,
        config: config.clone(),
    })
}
/// Estimates device-memory bookkeeping overhead, in bytes, as a
/// per-pending-launch cost plus a per-nesting-level cost.
/// All arithmetic saturates instead of wrapping on overflow.
pub fn estimate_launch_overhead(depth: u32, pending: u32) -> u64 {
    let pending_cost = u64::from(pending).saturating_mul(BASE_LAUNCH_OVERHEAD_BYTES);
    let depth_cost = u64::from(depth).saturating_mul(PER_DEPTH_OVERHEAD_BYTES);
    pending_cost.saturating_add(depth_cost)
}
/// Returns the maximum dynamic-parallelism nesting depth for an SM version.
///
/// Every currently supported architecture shares the same ceiling; the
/// exhaustive match (no `_` arm) forces a compile error when a new
/// `SmVersion` variant is added, so this gets revisited deliberately.
pub fn max_nesting_for_sm(sm: SmVersion) -> u32 {
    match sm {
        SmVersion::Sm75
        | SmVersion::Sm80
        | SmVersion::Sm86
        | SmVersion::Sm89
        | SmVersion::Sm90
        | SmVersion::Sm90a
        | SmVersion::Sm100
        | SmVersion::Sm120 => CUDA_MAX_NESTING_DEPTH,
    }
}
/// Generates PTX text for launching `child` from device code inside
/// `parent_name`.
///
/// The output contains, in order:
/// 1. a module header (`.version` / `.target` / `.address_size`) derived
///    from `sm`,
/// 2. an `.extern .entry` declaration of the child kernel, and
/// 3. a `.func` launch helper `__<parent>_launch_<child>` that materializes
///    grid/block/shared-mem/stream values into registers.
///
/// NOTE(review): the actual device-side launch (`cudaLaunchDeviceV2`) is
/// emitted only as comments and the helper always returns 0 (cudaSuccess) —
/// the call sequence itself is a placeholder, per the inline note below.
///
/// # Errors
///
/// Returns [`PtxGenError::GenerationFailed`] when the child name is empty
/// or any child block dimension is zero.
pub fn generate_child_launch_ptx(
    parent_name: &str,
    child: &ChildKernelSpec,
    sm: SmVersion,
) -> Result<String, PtxGenError> {
    // Reject specs that could never produce a launchable kernel.
    if child.name.is_empty() {
        return Err(PtxGenError::GenerationFailed(
            "child kernel name must not be empty".to_string(),
        ));
    }
    if child.block_dim.x == 0 || child.block_dim.y == 0 || child.block_dim.z == 0 {
        return Err(PtxGenError::GenerationFailed(
            "child block dimensions must be non-zero".to_string(),
        ));
    }
    // Module header values come from the target SM version.
    let (isa_major, isa_minor) = sm.ptx_isa_version();
    let target = sm.as_ptx_str();
    let mut ptx = String::with_capacity(2048);
    ptx.push_str(&format!(
        "// Dynamic parallelism: {parent_name} -> {child_name}\n",
        child_name = child.name,
    ));
    ptx.push_str(&format!(
        ".version {isa_major}.{isa_minor}\n\
        .target {target}\n\
        .address_size 64\n\n"
    ));
    // Forward declaration so the launch helper can reference the child.
    ptx.push_str(&format!(
        "// Child kernel declaration\n\
        .extern .entry {child_name}(\n",
        child_name = child.name,
    ));
    // Declared parameter list: commas separate all entries but the last.
    for (i, ptype) in child.param_types.iter().enumerate() {
        let comma = if i + 1 < child.param_types.len() {
            ","
        } else {
            ""
        };
        ptx.push_str(&format!(
            " .param {ty} _param_{i}{comma}\n",
            ty = ptype.as_ptx_str(),
        ));
    }
    ptx.push_str(")\n\n");
    // Launch helper is named after both the parent and the child kernel.
    let func_name = format!(
        "__{parent_name}_launch_{child_name}",
        child_name = child.name
    );
    ptx.push_str("// Device-side launch helper\n");
    ptx.push_str(&format!(".func (.param .s32 _retval) {func_name}(\n"));
    // Helper mirrors the child's parameter list, but as `arg_<i>` names.
    for (i, ptype) in child.param_types.iter().enumerate() {
        let comma = if i + 1 < child.param_types.len() {
            ","
        } else {
            ""
        };
        ptx.push_str(&format!(
            " .param {ty} arg_{i}{comma}\n",
            ty = ptype.as_ptx_str(),
        ));
    }
    ptx.push_str(")\n{\n");
    // Registers shared by every grid-spec variant.
    ptx.push_str(" // Register declarations\n");
    ptx.push_str(" .reg .s32 %retval;\n");
    ptx.push_str(" .reg .u32 %grid_x, %grid_y, %grid_z;\n");
    ptx.push_str(" .reg .u32 %block_x, %block_y, %block_z;\n");
    ptx.push_str(" .reg .u32 %shared_mem;\n");
    ptx.push_str(" .reg .u64 %stream;\n");
    // Extra registers only for the variants that need them.
    if let GridSpec::DataDependent { .. } = &child.grid_dim {
        ptx.push_str(" .reg .u32 %n_elements, %block_size;\n");
    }
    if matches!(&child.grid_dim, GridSpec::ThreadDependent) {
        ptx.push_str(" .reg .u32 %tid_x, %ntid_x, %ctaid_x;\n");
    }
    ptx.push('\n');
    // Grid-dimension computation depends on the grid spec variant.
    match &child.grid_dim {
        GridSpec::Fixed(dim) => {
            // Constants baked in at code-generation time.
            ptx.push_str(&format!(
                " // Fixed grid dimensions\n\
                mov.u32 %grid_x, {gx};\n\
                mov.u32 %grid_y, {gy};\n\
                mov.u32 %grid_z, {gz};\n",
                gx = dim.x,
                gy = dim.y,
                gz = dim.z,
            ));
        }
        GridSpec::DataDependent { param_index } => {
            // grid_x = ceil(n_elements / block.x), computed as
            // (n + block - 1) / block; y and z stay 1.
            ptx.push_str(&format!(
                " // Data-dependent grid: ceil(param[{param_index}] / block.x)\n\
                ld.param.u32 %n_elements, [arg_{param_index}];\n\
                mov.u32 %block_size, {bx};\n\
                add.u32 %grid_x, %n_elements, %block_size;\n\
                sub.u32 %grid_x, %grid_x, 1;\n\
                div.u32 %grid_x, %grid_x, %block_size;\n\
                mov.u32 %grid_y, 1;\n\
                mov.u32 %grid_z, 1;\n",
                bx = child.block_dim.x,
            ));
        }
        GridSpec::ThreadDependent => {
            // Each parent thread launches a 1x1x1-block child grid.
            ptx.push_str(
                " // Thread-dependent: one child launch per parent thread\n\
                mov.u32 %tid_x, %tid.x;\n\
                mov.u32 %ntid_x, %ntid.x;\n\
                mov.u32 %ctaid_x, %ctaid.x;\n\
                // Each thread launches a 1-block child grid\n\
                mov.u32 %grid_x, 1;\n\
                mov.u32 %grid_y, 1;\n\
                mov.u32 %grid_z, 1;\n",
            );
        }
    }
    // Block dimensions are always compile-time constants from the spec.
    ptx.push_str(&format!(
        "\n // Block dimensions\n\
        mov.u32 %block_x, {bx};\n\
        mov.u32 %block_y, {by};\n\
        mov.u32 %block_z, {bz};\n",
        bx = child.block_dim.x,
        by = child.block_dim.y,
        bz = child.block_dim.z,
    ));
    ptx.push_str(&format!(
        "\n // Shared memory and stream (NULL = default stream)\n\
        mov.u32 %shared_mem, {smem};\n\
        mov.u64 %stream, 0;\n",
        smem = child.shared_mem_bytes,
    ));
    // Placeholder launch sequence: the runtime call is documented in
    // comments and the helper unconditionally reports success.
    ptx.push_str(&format!(
        "\n // Launch child kernel: {child_name}\n\
        // cudaLaunchDevice(\n\
        // &{child_name},\n\
        // param_buffer,\n\
        // dim3(grid_x, grid_y, grid_z),\n\
        // dim3(block_x, block_y, block_z),\n\
        // shared_mem, stream\n\
        // )\n\
        // Note: actual device-side launch uses cudaLaunchDeviceV2\n\
        // which takes a pre-formatted parameter buffer.\n\
        mov.s32 %retval, 0; // cudaSuccess\n",
        child_name = child.name,
    ));
    // Epilogue: store the status into the return param and return.
    ptx.push_str(
        "\n st.param.s32 [_retval], %retval;\n\
        ret;\n\
        }\n",
    );
    Ok(ptx)
}
/// Generates PTX for a device-side synchronization helper
/// (`__device_synchronize`) targeting the given SM version.
///
/// NOTE(review): the emitted body is a placeholder — the runtime call to
/// `cudaDeviceSynchronize` appears only in comments and the helper always
/// returns 0 (cudaSuccess).
///
/// # Errors
///
/// Currently never fails; the `Result` keeps the signature uniform with
/// [`generate_child_launch_ptx`].
pub fn generate_device_sync_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
    // Module header values come from the target SM version.
    let (isa_major, isa_minor) = sm.ptx_isa_version();
    let target = sm.as_ptx_str();
    let ptx = format!(
        "// Device-side synchronization\n\
        .version {isa_major}.{isa_minor}\n\
        .target {target}\n\
        .address_size 64\n\
        \n\
        // cudaDeviceSynchronize() from device code\n\
        // Synchronizes all pending child kernel launches.\n\
        .func (.param .s32 _retval) __device_synchronize()\n\
        {{\n\
        .reg .s32 %retval;\n\
        \n\
        // Device-side cudaDeviceSynchronize is a runtime call\n\
        // that blocks until all child kernels complete.\n\
        // In PTX, this maps to a system call:\n\
        // call.uni cudaDeviceSynchronize;\n\
        // For code generation, we emit the call pattern.\n\
        mov.s32 %retval, 0; // cudaSuccess (placeholder)\n\
        \n\
        st.param.s32 [_retval], %retval;\n\
        ret;\n\
        }}\n"
    );
    Ok(ptx)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline valid configuration shared by most tests.
    fn default_config() -> DynamicParallelismConfig {
        DynamicParallelismConfig::new(SmVersion::Sm80)
    }

    #[test]
    fn validate_default_config_ok() {
        assert!(validate_dynamic_config(&default_config()).is_ok());
    }

    #[test]
    fn validate_zero_nesting_depth_fails() {
        let mut config = default_config();
        config.max_nesting_depth = 0;
        // The error must name the offending field.
        assert!(matches!(
            validate_dynamic_config(&config),
            Err(LaunchError::InvalidDimension {
                dim: "max_nesting_depth",
                ..
            })
        ));
    }

    #[test]
    fn validate_excessive_nesting_depth_fails() {
        let mut config = default_config();
        config.max_nesting_depth = CUDA_MAX_NESTING_DEPTH + 1;
        assert!(validate_dynamic_config(&config).is_err());
    }

    #[test]
    fn validate_max_nesting_depth_boundary() {
        // The maximum depth itself is valid (the bound is inclusive).
        let mut config = default_config();
        config.max_nesting_depth = CUDA_MAX_NESTING_DEPTH;
        config.sync_depth = CUDA_MAX_NESTING_DEPTH;
        assert!(validate_dynamic_config(&config).is_ok());
    }

    #[test]
    fn validate_zero_pending_launches_fails() {
        let mut config = default_config();
        config.max_pending_launches = 0;
        assert!(validate_dynamic_config(&config).is_err());
    }

    #[test]
    fn validate_sync_depth_exceeds_nesting_fails() {
        let mut config = default_config();
        config.max_nesting_depth = 4;
        config.sync_depth = 5;
        assert!(validate_dynamic_config(&config).is_err());
    }

    #[test]
    fn validate_zero_child_block_fails() {
        let mut config = default_config();
        config.child_block = Dim3::new(0, 256, 1);
        assert!(validate_dynamic_config(&config).is_err());
    }

    #[test]
    fn validate_zero_child_grid_fails() {
        let mut config = default_config();
        config.child_grid = Dim3::new(128, 0, 1);
        assert!(validate_dynamic_config(&config).is_err());
    }

    #[test]
    fn validate_block_size_exceeds_limit() {
        // 32 * 32 * 2 = 2048 threads, beyond the per-block limit for Sm80.
        let mut config = default_config();
        config.child_block = Dim3::new(32, 32, 2);
        assert!(matches!(
            validate_dynamic_config(&config),
            Err(LaunchError::BlockSizeExceedsLimit { .. })
        ));
    }

    #[test]
    fn validate_shared_mem_exceeds_limit() {
        let mut config = default_config();
        config.child_shared_mem = 500_000;
        assert!(matches!(
            validate_dynamic_config(&config),
            Err(LaunchError::SharedMemoryExceedsLimit { .. })
        ));
    }

    #[test]
    fn plan_dynamic_launch_ok() {
        let plan = plan_dynamic_launch(&default_config()).expect("default config must plan");
        assert!(plan.estimated_child_launches > 0);
        assert!(plan.memory_overhead_bytes > 0);
        assert_eq!(plan.parent_kernel_name, "parent_kernel");
        assert_eq!(plan.child_kernel_name, "child_kernel");
    }

    #[test]
    fn plan_dynamic_launch_invalid_config_fails() {
        let mut config = default_config();
        config.max_nesting_depth = 0;
        assert!(plan_dynamic_launch(&config).is_err());
    }

    #[test]
    fn plan_display() {
        // `expect` (rather than `if let Ok`) so a planning failure fails the
        // test instead of silently skipping the assertions.
        let plan = plan_dynamic_launch(&default_config()).expect("default config must plan");
        let display = format!("{plan}");
        assert!(display.contains("parent_kernel"));
        assert!(display.contains("child_kernel"));
        assert!(display.contains("bytes"));
    }

    #[test]
    fn estimate_overhead_basic() {
        assert_eq!(
            estimate_launch_overhead(1, 1),
            BASE_LAUNCH_OVERHEAD_BYTES + PER_DEPTH_OVERHEAD_BYTES
        );
    }

    #[test]
    fn estimate_overhead_default() {
        // Matches the defaults from DynamicParallelismConfig::new.
        let expected = BASE_LAUNCH_OVERHEAD_BYTES * 2048 + PER_DEPTH_OVERHEAD_BYTES * 4;
        assert_eq!(estimate_launch_overhead(4, 2048), expected);
    }

    #[test]
    fn estimate_overhead_zero() {
        assert_eq!(estimate_launch_overhead(0, 0), 0);
    }

    #[test]
    fn max_nesting_all_sm_versions() {
        // Every supported SM version shares the same CUDA nesting limit.
        for sm in [
            SmVersion::Sm75,
            SmVersion::Sm80,
            SmVersion::Sm86,
            SmVersion::Sm89,
            SmVersion::Sm90,
            SmVersion::Sm90a,
            SmVersion::Sm100,
            SmVersion::Sm120,
        ] {
            assert_eq!(max_nesting_for_sm(sm), 24);
        }
    }

    #[test]
    fn generate_child_launch_ptx_basic() {
        let child = ChildKernelSpec {
            name: "child_add".to_string(),
            param_types: vec![PtxType::U64, PtxType::U64, PtxType::U32],
            grid_dim: GridSpec::Fixed(Dim3::x(64)),
            block_dim: Dim3::x(256),
            shared_mem_bytes: 0,
        };
        let ptx = generate_child_launch_ptx("parent_add", &child, SmVersion::Sm80)
            .expect("valid spec must generate PTX");
        assert!(ptx.contains("child_add"));
        assert!(ptx.contains("parent_add"));
        assert!(ptx.contains(".version 7.0"));
        assert!(ptx.contains("sm_80"));
        assert!(ptx.contains("mov.u32 %grid_x, 64"));
        assert!(ptx.contains(".u64"));
        assert!(ptx.contains(".u32"));
    }

    #[test]
    fn generate_child_launch_ptx_data_dependent() {
        let child = ChildKernelSpec {
            name: "child_scale".to_string(),
            param_types: vec![PtxType::U64, PtxType::U32],
            grid_dim: GridSpec::DataDependent { param_index: 1 },
            block_dim: Dim3::x(128),
            shared_mem_bytes: 1024,
        };
        let ptx = generate_child_launch_ptx("parent_scale", &child, SmVersion::Sm90)
            .expect("valid spec must generate PTX");
        assert!(ptx.contains("Data-dependent"));
        assert!(ptx.contains("arg_1"));
        assert!(ptx.contains("div.u32"));
    }

    #[test]
    fn generate_child_launch_ptx_thread_dependent() {
        let child = ChildKernelSpec {
            name: "child_per_thread".to_string(),
            param_types: vec![PtxType::U64],
            grid_dim: GridSpec::ThreadDependent,
            block_dim: Dim3::x(32),
            shared_mem_bytes: 0,
        };
        let ptx = generate_child_launch_ptx("parent", &child, SmVersion::Sm80)
            .expect("valid spec must generate PTX");
        assert!(ptx.contains("Thread-dependent"));
        assert!(ptx.contains("%tid.x"));
    }

    #[test]
    fn generate_child_launch_ptx_empty_name_fails() {
        let child = ChildKernelSpec {
            name: String::new(),
            param_types: vec![],
            grid_dim: GridSpec::Fixed(Dim3::x(1)),
            block_dim: Dim3::x(1),
            shared_mem_bytes: 0,
        };
        assert!(generate_child_launch_ptx("parent", &child, SmVersion::Sm80).is_err());
    }

    #[test]
    fn generate_child_launch_ptx_zero_block_fails() {
        let child = ChildKernelSpec {
            name: "child".to_string(),
            param_types: vec![],
            grid_dim: GridSpec::Fixed(Dim3::x(1)),
            block_dim: Dim3::new(0, 1, 1),
            shared_mem_bytes: 0,
        };
        assert!(generate_child_launch_ptx("parent", &child, SmVersion::Sm80).is_err());
    }

    #[test]
    fn generate_device_sync_ptx_basic() {
        let ptx = generate_device_sync_ptx(SmVersion::Sm80).expect("sync PTX must generate");
        assert!(ptx.contains("__device_synchronize"));
        assert!(ptx.contains(".version 7.0"));
        assert!(ptx.contains("sm_80"));
        assert!(ptx.contains("cudaDeviceSynchronize"));
    }

    #[test]
    fn generate_device_sync_ptx_hopper() {
        let ptx = generate_device_sync_ptx(SmVersion::Sm90).expect("sync PTX must generate");
        assert!(ptx.contains(".version 8.0"));
        assert!(ptx.contains("sm_90"));
    }

    #[test]
    fn config_display() {
        let display = format!("{}", default_config());
        assert!(display.contains("depth=4"));
        assert!(display.contains("pending=2048"));
        assert!(display.contains("sync@2"));
        assert!(display.contains("sm_80"));
    }

    #[test]
    fn grid_spec_display() {
        assert_eq!(format!("{}", GridSpec::Fixed(Dim3::x(64))), "Fixed(64)");
        assert_eq!(
            format!("{}", GridSpec::DataDependent { param_index: 2 }),
            "DataDependent(param[2])"
        );
        assert_eq!(format!("{}", GridSpec::ThreadDependent), "ThreadDependent");
    }
}