Skip to main content

cuda_rust_wasm/runtime/
cooperative_groups.rs

1//! Cooperative groups for cross-block synchronization
2//!
3//! Provides a software emulation of CUDA cooperative groups, enabling
4//! flexible thread grouping and synchronization patterns beyond the
5//! traditional block-level `__syncthreads()`.
6
7use crate::{Result, runtime_error};
8use std::sync::{Arc, Barrier, Mutex};
9
10/// Thread group abstraction for cooperative kernel execution
11#[derive(Debug, Clone)]
12pub struct CooperativeGroup {
13    /// Number of threads in this group
14    size: u32,
15    /// Rank of this thread within the group
16    rank: u32,
17    /// Synchronization barrier shared among group members
18    barrier: Arc<Barrier>,
19}
20
21impl CooperativeGroup {
22    /// Create a new cooperative group
23    pub fn new(size: u32, rank: u32) -> Result<Self> {
24        if rank >= size {
25            return Err(runtime_error!(
26                "Thread rank {} exceeds group size {}",
27                rank, size
28            ));
29        }
30        Ok(Self {
31            size,
32            rank,
33            barrier: Arc::new(Barrier::new(size as usize)),
34        })
35    }
36
37    /// Create a group with a shared barrier
38    pub fn with_barrier(size: u32, rank: u32, barrier: Arc<Barrier>) -> Result<Self> {
39        if rank >= size {
40            return Err(runtime_error!(
41                "Thread rank {} exceeds group size {}",
42                rank, size
43            ));
44        }
45        Ok(Self { size, rank, barrier })
46    }
47
48    /// Get the number of threads in this group
49    pub fn size(&self) -> u32 {
50        self.size
51    }
52
53    /// Get this thread's rank within the group
54    pub fn thread_rank(&self) -> u32 {
55        self.rank
56    }
57
58    /// Synchronize all threads in this group
59    pub fn sync(&self) {
60        self.barrier.wait();
61    }
62
63    /// Check if this thread is the leader (rank 0)
64    pub fn is_leader(&self) -> bool {
65        self.rank == 0
66    }
67}
68
69/// Thread block group (all threads in a block)
70pub struct ThreadBlockGroup {
71    inner: CooperativeGroup,
72    block_idx: [u32; 3],
73    block_dim: [u32; 3],
74}
75
76impl ThreadBlockGroup {
77    /// Create a thread block group
78    pub fn new(block_dim: [u32; 3], thread_idx: [u32; 3], barrier: Arc<Barrier>) -> Result<Self> {
79        let size = block_dim[0] * block_dim[1] * block_dim[2];
80        let rank = thread_idx[2] * block_dim[0] * block_dim[1]
81            + thread_idx[1] * block_dim[0]
82            + thread_idx[0];
83        let inner = CooperativeGroup::with_barrier(size, rank, barrier)?;
84        Ok(Self {
85            inner,
86            block_idx: [0, 0, 0],
87            block_dim,
88        })
89    }
90
91    /// Set the block index
92    pub fn with_block_idx(mut self, idx: [u32; 3]) -> Self {
93        self.block_idx = idx;
94        self
95    }
96
97    /// Synchronize the thread block
98    pub fn sync(&self) {
99        self.inner.sync();
100    }
101
102    /// Get block dimensions
103    pub fn dim_threads(&self) -> [u32; 3] {
104        self.block_dim
105    }
106
107    /// Get group size (total threads in block)
108    pub fn size(&self) -> u32 {
109        self.inner.size()
110    }
111
112    /// Get thread rank within block
113    pub fn thread_rank(&self) -> u32 {
114        self.inner.thread_rank()
115    }
116
117    /// Get block index
118    pub fn block_index(&self) -> [u32; 3] {
119        self.block_idx
120    }
121}
122
123/// Grid group (all threads across all blocks)
124pub struct GridGroup {
125    /// Total number of threads across all blocks
126    total_threads: u32,
127    /// Global rank of this thread
128    global_rank: u32,
129    /// Grid dimensions
130    grid_dim: [u32; 3],
131    /// Block dimensions
132    block_dim: [u32; 3],
133    /// Optional grid-level barrier for cooperative launch
134    barrier: Option<Arc<Barrier>>,
135}
136
137impl GridGroup {
138    /// Create a grid group
139    pub fn new(
140        grid_dim: [u32; 3],
141        block_dim: [u32; 3],
142        block_idx: [u32; 3],
143        thread_idx: [u32; 3],
144    ) -> Self {
145        let threads_per_block = block_dim[0] * block_dim[1] * block_dim[2];
146        let total_blocks = grid_dim[0] * grid_dim[1] * grid_dim[2];
147        let total_threads = total_blocks * threads_per_block;
148
149        let block_linear = block_idx[2] * grid_dim[0] * grid_dim[1]
150            + block_idx[1] * grid_dim[0]
151            + block_idx[0];
152        let thread_linear = thread_idx[2] * block_dim[0] * block_dim[1]
153            + thread_idx[1] * block_dim[0]
154            + thread_idx[0];
155        let global_rank = block_linear * threads_per_block + thread_linear;
156
157        Self {
158            total_threads,
159            global_rank,
160            grid_dim,
161            block_dim,
162            barrier: None,
163        }
164    }
165
166    /// Create with a grid-level barrier for cooperative launch synchronization
167    pub fn with_barrier(mut self, barrier: Arc<Barrier>) -> Self {
168        self.barrier = Some(barrier);
169        self
170    }
171
172    /// Get total number of threads in the grid
173    pub fn size(&self) -> u32 {
174        self.total_threads
175    }
176
177    /// Get this thread's global rank
178    pub fn thread_rank(&self) -> u32 {
179        self.global_rank
180    }
181
182    /// Get grid dimensions
183    pub fn dim_blocks(&self) -> [u32; 3] {
184        self.grid_dim
185    }
186
187    /// Get block dimensions
188    pub fn dim_threads(&self) -> [u32; 3] {
189        self.block_dim
190    }
191
192    /// Check if this thread is the leader
193    pub fn is_leader(&self) -> bool {
194        self.global_rank == 0
195    }
196
197    /// Synchronize all threads in the grid (cooperative launch only)
198    pub fn sync(&self) -> Result<()> {
199        match &self.barrier {
200            Some(b) => {
201                b.wait();
202                Ok(())
203            }
204            None => Err(runtime_error!(
205                "Grid sync requires cooperative launch with a shared barrier"
206            )),
207        }
208    }
209}
210
211/// Tiled partition: a subdivision of a thread group
212pub struct TiledPartition {
213    /// Tile size (must be power of 2, max 32 for warp-level)
214    tile_size: u32,
215    /// Rank within the tile
216    rank: u32,
217    /// Barrier for the tile
218    barrier: Arc<Barrier>,
219    /// Shared data buffer for shuffle operations
220    shared_data: Arc<Mutex<Vec<f32>>>,
221}
222
223impl TiledPartition {
224    /// Create a tiled partition
225    pub fn new(tile_size: u32, rank: u32) -> Result<Self> {
226        if !tile_size.is_power_of_two() || tile_size > 32 {
227            return Err(runtime_error!(
228                "Tile size must be a power of 2 and <= 32, got {}",
229                tile_size
230            ));
231        }
232        if rank >= tile_size {
233            return Err(runtime_error!(
234                "Rank {} exceeds tile size {}",
235                rank, tile_size
236            ));
237        }
238        Ok(Self {
239            tile_size,
240            rank,
241            barrier: Arc::new(Barrier::new(tile_size as usize)),
242            shared_data: Arc::new(Mutex::new(vec![0.0; tile_size as usize])),
243        })
244    }
245
246    /// Create with shared state
247    pub fn with_shared(
248        tile_size: u32,
249        rank: u32,
250        barrier: Arc<Barrier>,
251        shared_data: Arc<Mutex<Vec<f32>>>,
252    ) -> Result<Self> {
253        if rank >= tile_size {
254            return Err(runtime_error!(
255                "Rank {} exceeds tile size {}",
256                rank, tile_size
257            ));
258        }
259        Ok(Self {
260            tile_size,
261            rank,
262            barrier,
263            shared_data,
264        })
265    }
266
267    /// Get tile size
268    pub fn size(&self) -> u32 {
269        self.tile_size
270    }
271
272    /// Get thread rank within tile
273    pub fn thread_rank(&self) -> u32 {
274        self.rank
275    }
276
277    /// Synchronize threads within the tile
278    pub fn sync(&self) {
279        self.barrier.wait();
280    }
281
282    /// Shuffle: get value from thread with given rank
283    pub fn shfl(&self, value: f32, src_rank: u32) -> f32 {
284        {
285            let mut data = self.shared_data.lock().unwrap();
286            data[self.rank as usize] = value;
287        }
288        self.sync();
289        let result = {
290            let data = self.shared_data.lock().unwrap();
291            let idx = (src_rank % self.tile_size) as usize;
292            data[idx]
293        };
294        self.sync();
295        result
296    }
297
298    /// Shuffle down: get value from thread rank + delta
299    pub fn shfl_down(&self, value: f32, delta: u32) -> f32 {
300        let src = self.rank + delta;
301        if src >= self.tile_size {
302            value // Return own value if source is out of range
303        } else {
304            self.shfl(value, src)
305        }
306    }
307
308    /// Shuffle up: get value from thread rank - delta
309    pub fn shfl_up(&self, value: f32, delta: u32) -> f32 {
310        if self.rank < delta {
311            value
312        } else {
313            self.shfl(value, self.rank - delta)
314        }
315    }
316
317    /// Shuffle XOR: get value from thread rank ^ mask
318    pub fn shfl_xor(&self, value: f32, mask: u32) -> f32 {
319        self.shfl(value, self.rank ^ mask)
320    }
321}
322
323/// Create a cooperative group for the current thread block
324pub fn this_thread_block(
325    block_dim: [u32; 3],
326    thread_idx: [u32; 3],
327    barrier: Arc<Barrier>,
328) -> Result<ThreadBlockGroup> {
329    ThreadBlockGroup::new(block_dim, thread_idx, barrier)
330}
331
332/// Create a grid group for cooperative kernel launch
333pub fn this_grid(
334    grid_dim: [u32; 3],
335    block_dim: [u32; 3],
336    block_idx: [u32; 3],
337    thread_idx: [u32; 3],
338) -> GridGroup {
339    GridGroup::new(grid_dim, block_dim, block_idx, thread_idx)
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cooperative_group_creation() {
        let leader = CooperativeGroup::new(32, 0).expect("rank 0 is valid in a group of 32");
        assert_eq!(leader.size(), 32);
        assert_eq!(leader.thread_rank(), 0);
        assert!(leader.is_leader());
    }

    #[test]
    fn test_cooperative_group_invalid_rank() {
        // Ranks are zero-based, so rank == size is one past the end.
        assert!(CooperativeGroup::new(32, 32).is_err());
    }

    #[test]
    fn test_thread_block_group() {
        let block = ThreadBlockGroup::new([4, 4, 1], [2, 1, 0], Arc::new(Barrier::new(1)))
            .expect("thread [2, 1, 0] lies inside a 4x4x1 block");
        assert_eq!(block.size(), 16);
        // Linear rank = x + y * dim_x = 2 + 1 * 4.
        assert_eq!(block.thread_rank(), 6);
        assert_eq!(block.dim_threads(), [4, 4, 1]);
    }

    #[test]
    fn test_grid_group() {
        let grid = GridGroup::new([2, 2, 1], [4, 4, 1], [1, 0, 0], [2, 1, 0]);
        // 4 blocks of 16 threads each.
        assert_eq!(grid.size(), 64);
        assert_eq!(grid.dim_blocks(), [2, 2, 1]);
        assert_eq!(grid.dim_threads(), [4, 4, 1]);
        // Block 1 (linear) * 16 + thread 6 (linear) = 22.
        assert_eq!(grid.thread_rank(), 22);
    }

    #[test]
    fn test_grid_group_leader() {
        let solo = GridGroup::new([1; 3], [1; 3], [0; 3], [0; 3]);
        assert!(solo.is_leader());

        let follower = GridGroup::new([2, 1, 1], [4, 1, 1], [1, 0, 0], [2, 0, 0]);
        assert!(!follower.is_leader());
    }

    #[test]
    fn test_grid_sync_without_barrier() {
        let unbarriered = GridGroup::new([1; 3], [1; 3], [0; 3], [0; 3]);
        assert!(unbarriered.sync().is_err());
    }

    #[test]
    fn test_grid_sync_with_barrier() {
        let grid = GridGroup::new([1; 3], [1; 3], [0; 3], [0; 3])
            .with_barrier(Arc::new(Barrier::new(1)));
        assert!(grid.sync().is_ok());
    }

    #[test]
    fn test_tiled_partition_creation() {
        let tile = TiledPartition::new(4, 0).expect("4 is a valid tile size");
        assert_eq!(tile.size(), 4);
        assert_eq!(tile.thread_rank(), 0);
    }

    #[test]
    fn test_tiled_partition_invalid_size() {
        // 3 is not a power of two; 64 exceeds the warp-sized maximum.
        for &bad_size in &[3_u32, 64] {
            assert!(TiledPartition::new(bad_size, 0).is_err());
        }
    }

    #[test]
    fn test_cooperative_group_sync() {
        // A single-member barrier releases immediately, so this must not hang.
        CooperativeGroup::new(1, 0).expect("rank 0 of 1 is valid").sync();
    }

    #[test]
    fn test_multi_thread_cooperative_sync() {
        const MEMBERS: u32 = 4;
        let shared = Arc::new(Barrier::new(MEMBERS as usize));

        let mut workers = Vec::new();
        for rank in 0..MEMBERS {
            let barrier = Arc::clone(&shared);
            workers.push(std::thread::spawn(move || {
                let member = CooperativeGroup::with_barrier(MEMBERS, rank, barrier)
                    .expect("rank is below the group size");
                member.sync();
                member.thread_rank()
            }));
        }

        let mut seen = Vec::new();
        for worker in workers {
            seen.push(worker.join().expect("worker thread panicked"));
        }
        seen.sort();
        assert_eq!(seen, vec![0, 1, 2, 3]);
    }
}
442}