1use crate::{Result, runtime_error};
8use std::sync::{Arc, Barrier, Mutex};
9
10#[derive(Debug, Clone)]
12pub struct CooperativeGroup {
13 size: u32,
15 rank: u32,
17 barrier: Arc<Barrier>,
19}
20
21impl CooperativeGroup {
22 pub fn new(size: u32, rank: u32) -> Result<Self> {
24 if rank >= size {
25 return Err(runtime_error!(
26 "Thread rank {} exceeds group size {}",
27 rank, size
28 ));
29 }
30 Ok(Self {
31 size,
32 rank,
33 barrier: Arc::new(Barrier::new(size as usize)),
34 })
35 }
36
37 pub fn with_barrier(size: u32, rank: u32, barrier: Arc<Barrier>) -> Result<Self> {
39 if rank >= size {
40 return Err(runtime_error!(
41 "Thread rank {} exceeds group size {}",
42 rank, size
43 ));
44 }
45 Ok(Self { size, rank, barrier })
46 }
47
48 pub fn size(&self) -> u32 {
50 self.size
51 }
52
53 pub fn thread_rank(&self) -> u32 {
55 self.rank
56 }
57
58 pub fn sync(&self) {
60 self.barrier.wait();
61 }
62
63 pub fn is_leader(&self) -> bool {
65 self.rank == 0
66 }
67}
68
69pub struct ThreadBlockGroup {
71 inner: CooperativeGroup,
72 block_idx: [u32; 3],
73 block_dim: [u32; 3],
74}
75
76impl ThreadBlockGroup {
77 pub fn new(block_dim: [u32; 3], thread_idx: [u32; 3], barrier: Arc<Barrier>) -> Result<Self> {
79 let size = block_dim[0] * block_dim[1] * block_dim[2];
80 let rank = thread_idx[2] * block_dim[0] * block_dim[1]
81 + thread_idx[1] * block_dim[0]
82 + thread_idx[0];
83 let inner = CooperativeGroup::with_barrier(size, rank, barrier)?;
84 Ok(Self {
85 inner,
86 block_idx: [0, 0, 0],
87 block_dim,
88 })
89 }
90
91 pub fn with_block_idx(mut self, idx: [u32; 3]) -> Self {
93 self.block_idx = idx;
94 self
95 }
96
97 pub fn sync(&self) {
99 self.inner.sync();
100 }
101
102 pub fn dim_threads(&self) -> [u32; 3] {
104 self.block_dim
105 }
106
107 pub fn size(&self) -> u32 {
109 self.inner.size()
110 }
111
112 pub fn thread_rank(&self) -> u32 {
114 self.inner.thread_rank()
115 }
116
117 pub fn block_index(&self) -> [u32; 3] {
119 self.block_idx
120 }
121}
122
123pub struct GridGroup {
125 total_threads: u32,
127 global_rank: u32,
129 grid_dim: [u32; 3],
131 block_dim: [u32; 3],
133 barrier: Option<Arc<Barrier>>,
135}
136
137impl GridGroup {
138 pub fn new(
140 grid_dim: [u32; 3],
141 block_dim: [u32; 3],
142 block_idx: [u32; 3],
143 thread_idx: [u32; 3],
144 ) -> Self {
145 let threads_per_block = block_dim[0] * block_dim[1] * block_dim[2];
146 let total_blocks = grid_dim[0] * grid_dim[1] * grid_dim[2];
147 let total_threads = total_blocks * threads_per_block;
148
149 let block_linear = block_idx[2] * grid_dim[0] * grid_dim[1]
150 + block_idx[1] * grid_dim[0]
151 + block_idx[0];
152 let thread_linear = thread_idx[2] * block_dim[0] * block_dim[1]
153 + thread_idx[1] * block_dim[0]
154 + thread_idx[0];
155 let global_rank = block_linear * threads_per_block + thread_linear;
156
157 Self {
158 total_threads,
159 global_rank,
160 grid_dim,
161 block_dim,
162 barrier: None,
163 }
164 }
165
166 pub fn with_barrier(mut self, barrier: Arc<Barrier>) -> Self {
168 self.barrier = Some(barrier);
169 self
170 }
171
172 pub fn size(&self) -> u32 {
174 self.total_threads
175 }
176
177 pub fn thread_rank(&self) -> u32 {
179 self.global_rank
180 }
181
182 pub fn dim_blocks(&self) -> [u32; 3] {
184 self.grid_dim
185 }
186
187 pub fn dim_threads(&self) -> [u32; 3] {
189 self.block_dim
190 }
191
192 pub fn is_leader(&self) -> bool {
194 self.global_rank == 0
195 }
196
197 pub fn sync(&self) -> Result<()> {
199 match &self.barrier {
200 Some(b) => {
201 b.wait();
202 Ok(())
203 }
204 None => Err(runtime_error!(
205 "Grid sync requires cooperative launch with a shared barrier"
206 )),
207 }
208 }
209}
210
211pub struct TiledPartition {
213 tile_size: u32,
215 rank: u32,
217 barrier: Arc<Barrier>,
219 shared_data: Arc<Mutex<Vec<f32>>>,
221}
222
223impl TiledPartition {
224 pub fn new(tile_size: u32, rank: u32) -> Result<Self> {
226 if !tile_size.is_power_of_two() || tile_size > 32 {
227 return Err(runtime_error!(
228 "Tile size must be a power of 2 and <= 32, got {}",
229 tile_size
230 ));
231 }
232 if rank >= tile_size {
233 return Err(runtime_error!(
234 "Rank {} exceeds tile size {}",
235 rank, tile_size
236 ));
237 }
238 Ok(Self {
239 tile_size,
240 rank,
241 barrier: Arc::new(Barrier::new(tile_size as usize)),
242 shared_data: Arc::new(Mutex::new(vec![0.0; tile_size as usize])),
243 })
244 }
245
246 pub fn with_shared(
248 tile_size: u32,
249 rank: u32,
250 barrier: Arc<Barrier>,
251 shared_data: Arc<Mutex<Vec<f32>>>,
252 ) -> Result<Self> {
253 if rank >= tile_size {
254 return Err(runtime_error!(
255 "Rank {} exceeds tile size {}",
256 rank, tile_size
257 ));
258 }
259 Ok(Self {
260 tile_size,
261 rank,
262 barrier,
263 shared_data,
264 })
265 }
266
267 pub fn size(&self) -> u32 {
269 self.tile_size
270 }
271
272 pub fn thread_rank(&self) -> u32 {
274 self.rank
275 }
276
277 pub fn sync(&self) {
279 self.barrier.wait();
280 }
281
282 pub fn shfl(&self, value: f32, src_rank: u32) -> f32 {
284 {
285 let mut data = self.shared_data.lock().unwrap();
286 data[self.rank as usize] = value;
287 }
288 self.sync();
289 let result = {
290 let data = self.shared_data.lock().unwrap();
291 let idx = (src_rank % self.tile_size) as usize;
292 data[idx]
293 };
294 self.sync();
295 result
296 }
297
298 pub fn shfl_down(&self, value: f32, delta: u32) -> f32 {
300 let src = self.rank + delta;
301 if src >= self.tile_size {
302 value } else {
304 self.shfl(value, src)
305 }
306 }
307
308 pub fn shfl_up(&self, value: f32, delta: u32) -> f32 {
310 if self.rank < delta {
311 value
312 } else {
313 self.shfl(value, self.rank - delta)
314 }
315 }
316
317 pub fn shfl_xor(&self, value: f32, mask: u32) -> f32 {
319 self.shfl(value, self.rank ^ mask)
320 }
321}
322
323pub fn this_thread_block(
325 block_dim: [u32; 3],
326 thread_idx: [u32; 3],
327 barrier: Arc<Barrier>,
328) -> Result<ThreadBlockGroup> {
329 ThreadBlockGroup::new(block_dim, thread_idx, barrier)
330}
331
332pub fn this_grid(
334 grid_dim: [u32; 3],
335 block_dim: [u32; 3],
336 block_idx: [u32; 3],
337 thread_idx: [u32; 3],
338) -> GridGroup {
339 GridGroup::new(grid_dim, block_dim, block_idx, thread_idx)
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cooperative_group_creation() {
        let group = CooperativeGroup::new(32, 0).unwrap();
        assert_eq!(group.size(), 32);
        assert_eq!(group.thread_rank(), 0);
        assert!(group.is_leader());
    }

    #[test]
    fn test_cooperative_group_invalid_rank() {
        let result = CooperativeGroup::new(32, 32);
        assert!(result.is_err());
    }

    #[test]
    fn test_thread_block_group() {
        let barrier = Arc::new(Barrier::new(1));
        let group = ThreadBlockGroup::new([4, 4, 1], [2, 1, 0], barrier).unwrap();
        assert_eq!(group.size(), 16);
        // rank = y * dim_x + x = 1 * 4 + 2
        assert_eq!(group.thread_rank(), 6);
        assert_eq!(group.dim_threads(), [4, 4, 1]);
        assert_eq!(group.block_index(), [0, 0, 0]);

        let group = group.with_block_idx([1, 2, 3]);
        assert_eq!(group.block_index(), [1, 2, 3]);
    }

    #[test]
    fn test_grid_group() {
        let gg = GridGroup::new([2, 2, 1], [4, 4, 1], [1, 0, 0], [2, 1, 0]);
        assert_eq!(gg.size(), 4 * 16);
        assert_eq!(gg.dim_blocks(), [2, 2, 1]);
        assert_eq!(gg.dim_threads(), [4, 4, 1]);
        // block_linear = 1, thread_linear = 1 * 4 + 2 = 6, rank = 1 * 16 + 6
        assert_eq!(gg.thread_rank(), 22);
    }

    #[test]
    fn test_grid_group_leader() {
        let gg = GridGroup::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0]);
        assert!(gg.is_leader());

        let gg2 = GridGroup::new([2, 1, 1], [4, 1, 1], [1, 0, 0], [2, 0, 0]);
        assert!(!gg2.is_leader());
    }

    #[test]
    fn test_grid_sync_without_barrier() {
        let gg = GridGroup::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0]);
        assert!(gg.sync().is_err());
    }

    #[test]
    fn test_grid_sync_with_barrier() {
        let barrier = Arc::new(Barrier::new(1));
        let gg = GridGroup::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0])
            .with_barrier(barrier);
        assert!(gg.sync().is_ok());
    }

    #[test]
    fn test_tiled_partition_creation() {
        let tile = TiledPartition::new(4, 0).unwrap();
        assert_eq!(tile.size(), 4);
        assert_eq!(tile.thread_rank(), 0);
    }

    #[test]
    fn test_tiled_partition_invalid_size() {
        // Not a power of two.
        assert!(TiledPartition::new(3, 0).is_err());
        // Larger than 32.
        assert!(TiledPartition::new(64, 0).is_err());
        // Zero is not a power of two.
        assert!(TiledPartition::new(0, 0).is_err());
    }

    #[test]
    fn test_tiled_partition_invalid_rank() {
        assert!(TiledPartition::new(4, 4).is_err());
    }

    #[test]
    fn test_single_lane_shuffles_return_own_value() {
        let tile = TiledPartition::new(1, 0).unwrap();
        assert_eq!(tile.shfl(2.5, 0), 2.5);
        assert_eq!(tile.shfl_down(2.5, 1), 2.5);
        assert_eq!(tile.shfl_up(2.5, 1), 2.5);
        assert_eq!(tile.shfl_xor(2.5, 1), 2.5);
    }

    #[test]
    fn test_multi_thread_shfl_broadcast_and_xor() {
        let barrier = Arc::new(Barrier::new(4));
        let data = Arc::new(Mutex::new(vec![0.0; 4]));
        let handles: Vec<_> = (0..4u32)
            .map(|rank| {
                let b = Arc::clone(&barrier);
                let d = Arc::clone(&data);
                std::thread::spawn(move || {
                    let tile = TiledPartition::with_shared(4, rank, b, d).unwrap();
                    let broadcast = tile.shfl(rank as f32 + 1.0, 0);
                    let swapped = tile.shfl_xor(rank as f32, 1);
                    (rank, broadcast, swapped)
                })
            })
            .collect();

        for handle in handles {
            let (rank, broadcast, swapped) = handle.join().unwrap();
            // Every lane sees lane 0's value.
            assert_eq!(broadcast, 1.0);
            // Lanes exchange pairwise: 0<->1, 2<->3.
            assert_eq!(swapped, (rank ^ 1) as f32);
        }
    }

    #[test]
    fn test_cooperative_group_sync() {
        let group = CooperativeGroup::new(1, 0).unwrap();
        // A single-participant barrier returns immediately.
        group.sync();
    }

    #[test]
    fn test_multi_thread_cooperative_sync() {
        let barrier = Arc::new(Barrier::new(4));
        let handles: Vec<_> = (0..4)
            .map(|rank| {
                let b = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    let group = CooperativeGroup::with_barrier(4, rank, b).unwrap();
                    group.sync();
                    group.thread_rank()
                })
            })
            .collect();

        let mut ranks: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        ranks.sort_unstable();
        assert_eq!(ranks, vec![0, 1, 2, 3]);
    }
}