use crate::{Result, runtime_error};
use crate::runtime::grid::{Grid, Block, Dim3};
use crate::runtime::kernel::ThreadContext;
use std::sync::{Arc, Mutex};
/// A kernel that can be launched as a child through
/// [`DynamicParallelismContext::launch_child`].
///
/// Implementors must be `Send + Sync`.
pub trait ChildKernel: Send + Sync {
    /// Runs the kernel body for one thread.
    ///
    /// `launch_child` calls this once for every (block, thread) pair of the
    /// launch, passing the indices and dimensions of that thread in `ctx`.
    fn execute(&self, ctx: ThreadContext);

    /// Returns the kernel's name; `launch_child` copies it into
    /// [`ChildLaunch::kernel_name`].
    fn name(&self) -> &str;
}
/// Bookkeeping record for one child-kernel launch, as stored in the launch
/// history of a `DynamicParallelismContext`.
#[derive(Clone, Debug)]
pub struct ChildLaunch {
    /// Name reported by the launched kernel.
    pub kernel_name: String,
    /// Grid dimensions the launch was issued with.
    pub grid: Dim3,
    /// Block dimensions the launch was issued with.
    pub block: Dim3,
    /// Shared-memory size requested by the caller, in bytes.
    pub shared_mem_bytes: usize,
    /// `false` when the record is created; set to `true` once the launch has
    /// finished executing.
    pub completed: bool,
}
/// Tracks nesting depth and launch history for child kernels launched from
/// within a running kernel.
pub struct DynamicParallelismContext {
    /// Launches are refused once `current_depth` reaches this value.
    max_depth: u32,
    /// Number of child launches currently executing through this context.
    current_depth: u32,
    /// Every launch recorded since construction or the last `reset`.
    launch_history: Arc<Mutex<Vec<ChildLaunch>>>,
    /// Launches are refused once this many recorded launches are not yet
    /// marked completed.
    max_pending: usize,
}
impl DynamicParallelismContext {
    /// Creates a context with a maximum nesting depth of 24, a pending-launch
    /// limit of 2048, a current depth of 0 and an empty launch history.
    pub fn new() -> Self {
        Self {
            max_depth: 24,
            current_depth: 0,
            launch_history: Arc::new(Mutex::new(Vec::new())),
            max_pending: 2048,
        }
    }

    /// Builder-style setter: returns the context with its maximum nesting
    /// depth replaced by `depth`.
    #[must_use]
    pub fn with_max_depth(mut self, depth: u32) -> Self {
        self.max_depth = depth;
        self
    }

    /// Builder-style setter: returns the context with its pending-launch
    /// limit replaced by `max`.
    #[must_use]
    pub fn with_max_pending(mut self, max: usize) -> Self {
        self.max_pending = max;
        self
    }

    /// Locks the launch history and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn lock_history(&self) -> std::sync::MutexGuard<'_, Vec<ChildLaunch>> {
        self.launch_history
            .lock()
            .expect("launch history mutex poisoned")
    }

    /// Launches `kernel` as a child kernel and runs it to completion on the
    /// calling thread.
    ///
    /// The launch is recorded in the history, then `kernel.execute` is called
    /// once per thread of every block. Blocks and threads are visited in
    /// linear order and each linear id is decomposed with `x` varying fastest,
    /// then `y`, then `z`. After the last thread returns, the record is marked
    /// completed.
    ///
    /// `shared_mem_bytes` is only stored in the launch record.
    ///
    /// # Errors
    ///
    /// Returns an error, without recording a launch, when
    /// * the current nesting depth has already reached the maximum depth,
    /// * the number of not-yet-completed launches in the history has reached
    ///   the pending limit, or
    /// * `block.validate()` fails (its error is propagated unchanged).
    ///
    /// # Panics
    ///
    /// Panics if the launch-history mutex is poisoned. A panic raised by
    /// `kernel.execute` propagates to the caller; the nesting depth is
    /// restored while unwinding and the launch record stays not-completed.
    pub fn launch_child<K: ChildKernel>(
        &mut self,
        kernel: &K,
        grid: Grid,
        block: Block,
        shared_mem_bytes: usize,
    ) -> Result<()> {
        /// Decrements the borrowed depth counter when dropped, so the depth
        /// is restored on normal return and on unwinding alike.
        struct DepthGuard<'a>(&'a mut u32);

        impl Drop for DepthGuard<'_> {
            fn drop(&mut self) {
                *self.0 -= 1;
            }
        }

        if self.current_depth >= self.max_depth {
            return Err(runtime_error!(
                "Maximum kernel nesting depth {} exceeded",
                self.max_depth
            ));
        }

        let pending = self.lock_history().iter().filter(|l| !l.completed).count();
        if pending >= self.max_pending {
            return Err(runtime_error!(
                "Maximum pending child kernels {} exceeded",
                self.max_pending
            ));
        }

        block.validate()?;

        // Remember where this launch's record lives so that exactly this
        // record is marked completed afterwards.
        let record_index = {
            let mut history = self.lock_history();
            let index = history.len();
            history.push(ChildLaunch {
                kernel_name: kernel.name().to_string(),
                grid: grid.dim,
                block: block.dim,
                shared_mem_bytes,
                completed: false,
            });
            index
        };

        self.current_depth += 1;
        {
            // The history lock is not held while user kernel code runs.
            let _depth_guard = DepthGuard(&mut self.current_depth);

            let total_blocks = grid.num_blocks();
            let threads_per_block = block.num_threads();
            for block_id in 0..total_blocks {
                let block_idx = Dim3 {
                    x: block_id % grid.dim.x,
                    y: (block_id / grid.dim.x) % grid.dim.y,
                    z: block_id / (grid.dim.x * grid.dim.y),
                };
                for thread_id in 0..threads_per_block {
                    let thread_idx = Dim3 {
                        x: thread_id % block.dim.x,
                        y: (thread_id / block.dim.x) % block.dim.y,
                        z: thread_id / (block.dim.x * block.dim.y),
                    };
                    kernel.execute(ThreadContext {
                        thread_idx,
                        block_idx,
                        block_dim: block.dim,
                        grid_dim: grid.dim,
                    });
                }
            }
        }

        if let Some(record) = self.lock_history().get_mut(record_index) {
            record.completed = true;
        }
        Ok(())
    }

    /// Always returns `Ok(())`: `launch_child` only returns after every
    /// thread of the launch has run, so there is nothing left to wait for.
    pub fn device_synchronize(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the number of recorded launches that are marked completed.
    ///
    /// # Panics
    ///
    /// Panics if the launch-history mutex is poisoned.
    pub fn completed_launches(&self) -> usize {
        self.lock_history().iter().filter(|l| l.completed).count()
    }

    /// Returns a snapshot (clone) of the launch history, oldest record first.
    ///
    /// # Panics
    ///
    /// Panics if the launch-history mutex is poisoned.
    pub fn launch_history(&self) -> Vec<ChildLaunch> {
        self.lock_history().clone()
    }

    /// Returns the current nesting depth.
    pub fn current_depth(&self) -> u32 {
        self.current_depth
    }

    /// Returns the configured maximum nesting depth.
    pub fn max_depth(&self) -> u32 {
        self.max_depth
    }

    /// Sets the nesting depth back to 0 and clears the launch history.
    /// The configured limits are left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the launch-history mutex is poisoned.
    pub fn reset(&mut self) {
        self.current_depth = 0;
        self.lock_history().clear();
    }
}
impl Default for DynamicParallelismContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared buffer of `len` zeros for the test kernels to write into.
    fn zeroed_buffer(len: usize) -> Arc<Mutex<Vec<f32>>> {
        Arc::new(Mutex::new(vec![0.0f32; len]))
    }

    /// Returns `true` when every element of `buffer` equals `expected`.
    fn all_equal(buffer: &Arc<Mutex<Vec<f32>>>, expected: f32) -> bool {
        buffer.lock().unwrap().iter().all(|&v| v == expected)
    }

    /// Adds 1.0 to the element selected by the linear global thread id;
    /// ids past the end of the buffer are ignored.
    struct AddOneKernel {
        data: Arc<Mutex<Vec<f32>>>,
    }

    impl ChildKernel for AddOneKernel {
        fn execute(&self, ctx: ThreadContext) {
            let index = ctx.global_thread_id();
            if let Some(slot) = self.data.lock().unwrap().get_mut(index) {
                *slot += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    /// Adds 1.0 to the element at `y * width + x` of a row-major buffer;
    /// positions past the end of the buffer are ignored.
    struct AddOne2DKernel {
        data: Arc<Mutex<Vec<f32>>>,
        width: usize,
    }

    impl ChildKernel for AddOne2DKernel {
        fn execute(&self, ctx: ThreadContext) {
            let (col, row) = ctx.global_thread_id_2d();
            let index = row * self.width + col;
            if let Some(slot) = self.data.lock().unwrap().get_mut(index) {
                *slot += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one_2d"
        }
    }

    #[test]
    fn test_dynamic_parallelism_basic() {
        let buffer = zeroed_buffer(16);
        let kernel = AddOneKernel {
            data: Arc::clone(&buffer),
        };
        let mut dp = DynamicParallelismContext::new();

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(16u32), 0)
            .unwrap();

        assert!(all_equal(&buffer, 1.0));
        assert_eq!(dp.completed_launches(), 1);
    }

    #[test]
    fn test_dynamic_parallelism_multiple_launches() {
        let buffer = zeroed_buffer(8);
        let kernel = AddOneKernel {
            data: Arc::clone(&buffer),
        };
        let mut dp = DynamicParallelismContext::new();

        // Three back-to-back launches over the same buffer.
        (0..3).for_each(|_| {
            dp.launch_child(&kernel, Grid::new(1u32), Block::new(8u32), 0)
                .unwrap();
        });

        assert!(all_equal(&buffer, 3.0));
        assert_eq!(dp.completed_launches(), 3);
    }

    #[test]
    fn test_dynamic_parallelism_max_depth() {
        // With a maximum depth of zero, even the first launch must be refused.
        let mut dp = DynamicParallelismContext::new().with_max_depth(0);
        let kernel = AddOneKernel {
            data: zeroed_buffer(4),
        };

        let outcome = dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_dynamic_parallelism_device_sync() {
        let dp = DynamicParallelismContext::new();
        assert!(dp.device_synchronize().is_ok());
    }

    #[test]
    fn test_dynamic_parallelism_reset() {
        let kernel = AddOneKernel {
            data: zeroed_buffer(4),
        };
        let mut dp = DynamicParallelismContext::new();

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();
        assert_eq!(dp.completed_launches(), 1);

        dp.reset();
        assert_eq!(dp.completed_launches(), 0);
        assert_eq!(dp.current_depth(), 0);
    }

    #[test]
    fn test_dynamic_parallelism_2d_grid() {
        // A 2x2 grid of 4x4 blocks covers an 8x8 buffer exactly once.
        let width = 2 * 4;
        let height = 2 * 4;
        let buffer = zeroed_buffer(width * height);
        let kernel = AddOne2DKernel {
            data: Arc::clone(&buffer),
            width,
        };
        let mut dp = DynamicParallelismContext::new();

        dp.launch_child(
            &kernel,
            Grid::new((2u32, 2u32)),
            Block::new((4u32, 4u32)),
            0,
        )
        .unwrap();

        assert!(all_equal(&buffer, 1.0));
    }

    #[test]
    fn test_launch_history() {
        let kernel = AddOneKernel {
            data: zeroed_buffer(4),
        };
        let mut dp = DynamicParallelismContext::new();

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();

        let records = dp.launch_history();
        assert_eq!(records.len(), 1);
        let first = &records[0];
        assert_eq!(first.kernel_name, "add_one");
        assert!(first.completed);
    }
}