// cuda_rust_wasm/runtime/dynamic_parallelism.rs

//! Dynamic parallelism support for child kernel launches.
//!
//! Allows kernels to launch child kernels, emulating CUDA's dynamic
//! parallelism feature. In the CPU emulation backend, child kernels
//! are executed synchronously: every thread of a child launch runs to
//! completion before the launch call returns.

use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use crate::runtime::grid::{Block, Dim3, Grid};
use crate::runtime::kernel::ThreadContext;
use crate::{runtime_error, Result};

12/// A child kernel that can be launched from within a parent kernel
13pub trait ChildKernel: Send + Sync {
14    /// Execute the child kernel for a single thread
15    fn execute(&self, ctx: ThreadContext);
16
17    /// Get kernel name
18    fn name(&self) -> &str;
19}
20
21/// Child kernel launch record
22#[derive(Debug, Clone)]
23pub struct ChildLaunch {
24    /// Kernel name
25    pub kernel_name: String,
26    /// Grid dimensions
27    pub grid: Dim3,
28    /// Block dimensions
29    pub block: Dim3,
30    /// Shared memory size
31    pub shared_mem_bytes: usize,
32    /// Whether execution is complete
33    pub completed: bool,
34}
35
36/// Dynamic parallelism context for managing child kernel launches
37pub struct DynamicParallelismContext {
38    /// Maximum nesting depth for child launches
39    max_depth: u32,
40    /// Current nesting depth
41    current_depth: u32,
42    /// Record of child launches
43    launch_history: Arc<Mutex<Vec<ChildLaunch>>>,
44    /// Maximum number of concurrent child kernels
45    max_pending: usize,
46}
47
48impl DynamicParallelismContext {
49    /// Create a new dynamic parallelism context
50    pub fn new() -> Self {
51        Self {
52            max_depth: 24, // CUDA default max depth
53            current_depth: 0,
54            launch_history: Arc::new(Mutex::new(Vec::new())),
55            max_pending: 2048,
56        }
57    }
58
59    /// Create with custom nesting depth limit
60    pub fn with_max_depth(mut self, depth: u32) -> Self {
61        self.max_depth = depth;
62        self
63    }
64
65    /// Create with custom max pending limit
66    pub fn with_max_pending(mut self, max: usize) -> Self {
67        self.max_pending = max;
68        self
69    }
70
71    /// Launch a child kernel (synchronous execution in CPU backend)
72    pub fn launch_child<K: ChildKernel>(
73        &mut self,
74        kernel: &K,
75        grid: Grid,
76        block: Block,
77        shared_mem_bytes: usize,
78    ) -> Result<()> {
79        // Check nesting depth
80        if self.current_depth >= self.max_depth {
81            return Err(runtime_error!(
82                "Maximum kernel nesting depth {} exceeded",
83                self.max_depth
84            ));
85        }
86
87        // Check pending limit
88        {
89            let history = self.launch_history.lock().unwrap();
90            let pending = history.iter().filter(|l| !l.completed).count();
91            if pending >= self.max_pending {
92                return Err(runtime_error!(
93                    "Maximum pending child kernels {} exceeded",
94                    self.max_pending
95                ));
96            }
97        }
98
99        // Validate block config
100        block.validate()?;
101
102        // Record the launch
103        let launch_record = ChildLaunch {
104            kernel_name: kernel.name().to_string(),
105            grid: grid.dim,
106            block: block.dim,
107            shared_mem_bytes,
108            completed: false,
109        };
110
111        {
112            let mut history = self.launch_history.lock().unwrap();
113            history.push(launch_record);
114        }
115
116        // Execute child kernel (CPU emulation: synchronous)
117        self.current_depth += 1;
118
119        let total_blocks = grid.num_blocks();
120        let threads_per_block = block.num_threads();
121
122        for block_id in 0..total_blocks {
123            let block_idx = Dim3 {
124                x: block_id % grid.dim.x,
125                y: (block_id / grid.dim.x) % grid.dim.y,
126                z: block_id / (grid.dim.x * grid.dim.y),
127            };
128
129            for thread_id in 0..threads_per_block {
130                let thread_idx = Dim3 {
131                    x: thread_id % block.dim.x,
132                    y: (thread_id / block.dim.x) % block.dim.y,
133                    z: thread_id / (block.dim.x * block.dim.y),
134                };
135
136                let ctx = ThreadContext {
137                    thread_idx,
138                    block_idx,
139                    block_dim: block.dim,
140                    grid_dim: grid.dim,
141                };
142
143                kernel.execute(ctx);
144            }
145        }
146
147        self.current_depth -= 1;
148
149        // Mark as completed
150        {
151            let mut history = self.launch_history.lock().unwrap();
152            if let Some(last) = history.last_mut() {
153                last.completed = true;
154            }
155        }
156
157        Ok(())
158    }
159
160    /// Synchronize all pending child kernels (no-op in synchronous mode)
161    pub fn device_synchronize(&self) -> Result<()> {
162        // In CPU emulation, all launches are synchronous, so this is a no-op
163        Ok(())
164    }
165
166    /// Get the number of completed child launches
167    pub fn completed_launches(&self) -> usize {
168        self.launch_history
169            .lock()
170            .unwrap()
171            .iter()
172            .filter(|l| l.completed)
173            .count()
174    }
175
176    /// Get launch history
177    pub fn launch_history(&self) -> Vec<ChildLaunch> {
178        self.launch_history.lock().unwrap().clone()
179    }
180
181    /// Get current nesting depth
182    pub fn current_depth(&self) -> u32 {
183        self.current_depth
184    }
185
186    /// Get maximum nesting depth
187    pub fn max_depth(&self) -> u32 {
188        self.max_depth
189    }
190
191    /// Reset the context
192    pub fn reset(&mut self) {
193        self.current_depth = 0;
194        self.launch_history.lock().unwrap().clear();
195    }
196}
197
198impl Default for DynamicParallelismContext {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared, lock-protected buffer of `len` zeros for kernels to mutate.
    fn zeroed(len: usize) -> Arc<Mutex<Vec<f32>>> {
        Arc::new(Mutex::new(vec![0.0f32; len]))
    }

    /// Adds 1.0 to the element at the thread's flat global id, if in range.
    struct AddOneKernel {
        data: Arc<Mutex<Vec<f32>>>,
    }

    impl ChildKernel for AddOneKernel {
        fn execute(&self, ctx: ThreadContext) {
            let index = ctx.global_thread_id();
            let mut buffer = self.data.lock().unwrap();
            if let Some(slot) = buffer.get_mut(index) {
                *slot += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    #[test]
    fn test_dynamic_parallelism_basic() {
        let buffer = zeroed(16);
        let kernel = AddOneKernel { data: Arc::clone(&buffer) };
        let mut parent = DynamicParallelismContext::new();

        parent
            .launch_child(&kernel, Grid::new(1u32), Block::new(16u32), 0)
            .unwrap();

        assert!(buffer.lock().unwrap().iter().all(|&v| v == 1.0));
        assert_eq!(parent.completed_launches(), 1);
    }

    #[test]
    fn test_dynamic_parallelism_multiple_launches() {
        let buffer = zeroed(8);
        let kernel = AddOneKernel { data: Arc::clone(&buffer) };
        let mut parent = DynamicParallelismContext::new();

        (0..3)
            .try_for_each(|_| parent.launch_child(&kernel, Grid::new(1u32), Block::new(8u32), 0))
            .unwrap();

        assert!(buffer.lock().unwrap().iter().all(|&v| v == 3.0));
        assert_eq!(parent.completed_launches(), 3);
    }

    #[test]
    fn test_dynamic_parallelism_max_depth() {
        let kernel = AddOneKernel { data: zeroed(4) };
        let mut parent = DynamicParallelismContext::new().with_max_depth(0);

        let outcome = parent.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_dynamic_parallelism_device_sync() {
        assert!(DynamicParallelismContext::new().device_synchronize().is_ok());
    }

    #[test]
    fn test_dynamic_parallelism_reset() {
        let kernel = AddOneKernel { data: zeroed(4) };
        let mut parent = DynamicParallelismContext::new();

        parent
            .launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();
        assert_eq!(parent.completed_launches(), 1);

        parent.reset();

        assert_eq!(parent.completed_launches(), 0);
        assert_eq!(parent.current_depth(), 0);
    }

    /// Adds 1.0 to the row-major element at the thread's 2-D global position.
    struct AddOne2DKernel {
        data: Arc<Mutex<Vec<f32>>>,
        width: usize,
    }

    impl ChildKernel for AddOne2DKernel {
        fn execute(&self, ctx: ThreadContext) {
            let (col, row) = ctx.global_thread_id_2d();
            let mut buffer = self.data.lock().unwrap();
            if let Some(slot) = buffer.get_mut(row * self.width + col) {
                *slot += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one_2d"
        }
    }

    #[test]
    fn test_dynamic_parallelism_2d_grid() {
        // A 2x2 grid of 4x4 blocks covers an 8x8 = 64-element square.
        let (grid_side, block_side) = (2usize, 4usize);
        let width = grid_side * block_side;
        let height = grid_side * block_side;
        let buffer = zeroed(width * height);
        let kernel = AddOne2DKernel { data: Arc::clone(&buffer), width };
        let mut parent = DynamicParallelismContext::new();

        parent
            .launch_child(
                &kernel,
                Grid::new((2u32, 2u32)),
                Block::new((4u32, 4u32)),
                0,
            )
            .unwrap();

        let values = buffer.lock().unwrap();
        assert!(values.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_launch_history() {
        let kernel = AddOneKernel { data: zeroed(4) };
        let mut parent = DynamicParallelismContext::new();

        parent
            .launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();

        let history = parent.launch_history();
        assert_eq!(history.len(), 1);
        let entry = &history[0];
        assert_eq!(entry.kernel_name, "add_one");
        assert!(entry.completed);
    }
}