cuda_rust_wasm/runtime/
dynamic_parallelism.rs1use crate::{Result, runtime_error};
8use crate::runtime::grid::{Grid, Block, Dim3};
9use crate::runtime::kernel::ThreadContext;
10use std::sync::{Arc, Mutex};
11
12pub trait ChildKernel: Send + Sync {
14 fn execute(&self, ctx: ThreadContext);
16
17 fn name(&self) -> &str;
19}
20
21#[derive(Debug, Clone)]
23pub struct ChildLaunch {
24 pub kernel_name: String,
26 pub grid: Dim3,
28 pub block: Dim3,
30 pub shared_mem_bytes: usize,
32 pub completed: bool,
34}
35
36pub struct DynamicParallelismContext {
38 max_depth: u32,
40 current_depth: u32,
42 launch_history: Arc<Mutex<Vec<ChildLaunch>>>,
44 max_pending: usize,
46}
47
48impl DynamicParallelismContext {
49 pub fn new() -> Self {
51 Self {
52 max_depth: 24, current_depth: 0,
54 launch_history: Arc::new(Mutex::new(Vec::new())),
55 max_pending: 2048,
56 }
57 }
58
59 pub fn with_max_depth(mut self, depth: u32) -> Self {
61 self.max_depth = depth;
62 self
63 }
64
65 pub fn with_max_pending(mut self, max: usize) -> Self {
67 self.max_pending = max;
68 self
69 }
70
71 pub fn launch_child<K: ChildKernel>(
73 &mut self,
74 kernel: &K,
75 grid: Grid,
76 block: Block,
77 shared_mem_bytes: usize,
78 ) -> Result<()> {
79 if self.current_depth >= self.max_depth {
81 return Err(runtime_error!(
82 "Maximum kernel nesting depth {} exceeded",
83 self.max_depth
84 ));
85 }
86
87 {
89 let history = self.launch_history.lock().unwrap();
90 let pending = history.iter().filter(|l| !l.completed).count();
91 if pending >= self.max_pending {
92 return Err(runtime_error!(
93 "Maximum pending child kernels {} exceeded",
94 self.max_pending
95 ));
96 }
97 }
98
99 block.validate()?;
101
102 let launch_record = ChildLaunch {
104 kernel_name: kernel.name().to_string(),
105 grid: grid.dim,
106 block: block.dim,
107 shared_mem_bytes,
108 completed: false,
109 };
110
111 {
112 let mut history = self.launch_history.lock().unwrap();
113 history.push(launch_record);
114 }
115
116 self.current_depth += 1;
118
119 let total_blocks = grid.num_blocks();
120 let threads_per_block = block.num_threads();
121
122 for block_id in 0..total_blocks {
123 let block_idx = Dim3 {
124 x: block_id % grid.dim.x,
125 y: (block_id / grid.dim.x) % grid.dim.y,
126 z: block_id / (grid.dim.x * grid.dim.y),
127 };
128
129 for thread_id in 0..threads_per_block {
130 let thread_idx = Dim3 {
131 x: thread_id % block.dim.x,
132 y: (thread_id / block.dim.x) % block.dim.y,
133 z: thread_id / (block.dim.x * block.dim.y),
134 };
135
136 let ctx = ThreadContext {
137 thread_idx,
138 block_idx,
139 block_dim: block.dim,
140 grid_dim: grid.dim,
141 };
142
143 kernel.execute(ctx);
144 }
145 }
146
147 self.current_depth -= 1;
148
149 {
151 let mut history = self.launch_history.lock().unwrap();
152 if let Some(last) = history.last_mut() {
153 last.completed = true;
154 }
155 }
156
157 Ok(())
158 }
159
160 pub fn device_synchronize(&self) -> Result<()> {
162 Ok(())
164 }
165
166 pub fn completed_launches(&self) -> usize {
168 self.launch_history
169 .lock()
170 .unwrap()
171 .iter()
172 .filter(|l| l.completed)
173 .count()
174 }
175
176 pub fn launch_history(&self) -> Vec<ChildLaunch> {
178 self.launch_history.lock().unwrap().clone()
179 }
180
181 pub fn current_depth(&self) -> u32 {
183 self.current_depth
184 }
185
186 pub fn max_depth(&self) -> u32 {
188 self.max_depth
189 }
190
191 pub fn reset(&mut self) {
193 self.current_depth = 0;
194 self.launch_history.lock().unwrap().clear();
195 }
196}
197
198impl Default for DynamicParallelismContext {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds 1.0 to the element at each thread's global 1D index.
    struct AddOneKernel {
        data: Arc<Mutex<Vec<f32>>>,
    }

    impl ChildKernel for AddOneKernel {
        fn execute(&self, ctx: ThreadContext) {
            let tid = ctx.global_thread_id();
            let mut data = self.data.lock().unwrap();
            if tid < data.len() {
                data[tid] += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    #[test]
    fn test_dynamic_parallelism_basic() {
        let mut dp = DynamicParallelismContext::new();
        let data = Arc::new(Mutex::new(vec![0.0f32; 16]));
        let kernel = AddOneKernel { data: data.clone() };

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(16u32), 0)
            .unwrap();

        let result = data.lock().unwrap();
        assert!(result.iter().all(|&v| v == 1.0));
        assert_eq!(dp.completed_launches(), 1);
        assert_eq!(dp.current_depth(), 0);
    }

    #[test]
    fn test_dynamic_parallelism_multiple_launches() {
        let mut dp = DynamicParallelismContext::new();
        let data = Arc::new(Mutex::new(vec![0.0f32; 8]));
        let kernel = AddOneKernel { data: data.clone() };

        for _ in 0..3 {
            dp.launch_child(&kernel, Grid::new(1u32), Block::new(8u32), 0)
                .unwrap();
        }

        let result = data.lock().unwrap();
        assert!(result.iter().all(|&v| v == 3.0));
        assert_eq!(dp.completed_launches(), 3);
    }

    #[test]
    fn test_dynamic_parallelism_max_depth() {
        let mut dp = DynamicParallelismContext::new().with_max_depth(0);
        let data = Arc::new(Mutex::new(vec![0.0f32; 4]));
        let kernel = AddOneKernel { data: data.clone() };

        let result = dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0);
        assert!(result.is_err());
        // A rejected launch must neither run the kernel nor be recorded.
        assert!(data.lock().unwrap().iter().all(|&v| v == 0.0));
        assert!(dp.launch_history().is_empty());
        assert_eq!(dp.current_depth(), 0);
    }

    #[test]
    fn test_dynamic_parallelism_max_pending() {
        let mut dp = DynamicParallelismContext::new().with_max_pending(0);
        let data = Arc::new(Mutex::new(vec![0.0f32; 4]));
        let kernel = AddOneKernel { data: data.clone() };

        let result = dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0);
        assert!(result.is_err());
        assert!(data.lock().unwrap().iter().all(|&v| v == 0.0));
        assert!(dp.launch_history().is_empty());
    }

    #[test]
    fn test_max_depth_configuration() {
        assert_eq!(DynamicParallelismContext::new().max_depth(), 24);
        assert_eq!(DynamicParallelismContext::default().max_depth(), 24);
        assert_eq!(
            DynamicParallelismContext::new().with_max_depth(5).max_depth(),
            5
        );
    }

    #[test]
    fn test_dynamic_parallelism_device_sync() {
        let dp = DynamicParallelismContext::new();
        assert!(dp.device_synchronize().is_ok());
    }

    #[test]
    fn test_dynamic_parallelism_reset() {
        let mut dp = DynamicParallelismContext::new();
        let data = Arc::new(Mutex::new(vec![0.0f32; 4]));
        let kernel = AddOneKernel { data };

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();
        assert_eq!(dp.completed_launches(), 1);

        dp.reset();
        assert_eq!(dp.completed_launches(), 0);
        assert_eq!(dp.current_depth(), 0);
    }

    /// Adds 1.0 to the element at each thread's global 2D index in a
    /// row-major buffer of the given width.
    struct AddOne2DKernel {
        data: Arc<Mutex<Vec<f32>>>,
        width: usize,
    }

    impl ChildKernel for AddOne2DKernel {
        fn execute(&self, ctx: ThreadContext) {
            let (x, y) = ctx.global_thread_id_2d();
            let idx = y * self.width + x;
            let mut data = self.data.lock().unwrap();
            if idx < data.len() {
                data[idx] += 1.0;
            }
        }

        fn name(&self) -> &str {
            "add_one_2d"
        }
    }

    #[test]
    fn test_dynamic_parallelism_2d_grid() {
        let mut dp = DynamicParallelismContext::new();
        // 2x2 grid of 4x4 blocks covers an 8x8 buffer exactly once.
        let width = 2 * 4;
        let height = 2 * 4;
        let data = Arc::new(Mutex::new(vec![0.0f32; width * height]));
        let kernel = AddOne2DKernel { data: data.clone(), width };

        dp.launch_child(
            &kernel,
            Grid::new((2u32, 2u32)),
            Block::new((4u32, 4u32)),
            0,
        )
        .unwrap();

        let result = data.lock().unwrap();
        assert!(result.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_launch_history() {
        let mut dp = DynamicParallelismContext::new();
        let data = Arc::new(Mutex::new(vec![0.0f32; 4]));
        let kernel = AddOneKernel { data };

        dp.launch_child(&kernel, Grid::new(1u32), Block::new(4u32), 0)
            .unwrap();

        let history = dp.launch_history();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].kernel_name, "add_one");
        assert!(history[0].completed);
    }

    #[test]
    fn test_launch_history_records_launch_parameters() {
        let mut dp = DynamicParallelismContext::new();
        let data = Arc::new(Mutex::new(vec![0.0f32; 64]));
        let kernel = AddOne2DKernel { data, width: 8 };

        dp.launch_child(
            &kernel,
            Grid::new((2u32, 2u32)),
            Block::new((4u32, 4u32)),
            128,
        )
        .unwrap();

        let history = dp.launch_history();
        assert_eq!(history.len(), 1);
        let record = &history[0];
        assert_eq!(record.kernel_name, "add_one_2d");
        assert_eq!((record.grid.x, record.grid.y), (2, 2));
        assert_eq!((record.block.x, record.block.y), (4, 4));
        assert_eq!(record.shared_mem_bytes, 128);
        assert!(record.completed);
    }
}