1use std::sync::atomic::{AtomicU32, Ordering};
11
/// Number of lanes in one emulated warp.
pub const WARP_SIZE: u32 = 32;

/// Shared state for a software emulation of warp-level primitives
/// (shuffle, vote, ballot, reduce) across [`WARP_SIZE`] lanes.
///
/// Every primitive follows the same "publish, then read" shape: the calling
/// lane stores its contribution into its own slot and then immediately reads
/// the slot(s) of other lanes. All accesses are `SeqCst` atomics, but no
/// barrier is performed between the store and the load, so a caller that
/// wants every lane to observe the values of the *current* round must
/// arrange for all lanes to have published before any lane reads.
pub struct WarpState {
    /// Per-lane slot holding the value last published by a shuffle call.
    shuffle_buf: [AtomicU32; WARP_SIZE as usize],

    /// Bit `i` is set while lane `i` is active. All lanes start active.
    active_mask: AtomicU32,

    /// Per-lane slot holding the predicate (0 or 1) last published by a
    /// vote/ballot call.
    predicate_buf: [AtomicU32; WARP_SIZE as usize],
}

impl WarpState {
    /// Creates a warp with every lane active and all exchange slots zeroed.
    pub fn new() -> Self {
        // `AtomicU32` is not `Copy`, so the array-repeat expression needs a
        // `const` item to instantiate a fresh atomic for each slot.
        const INIT: AtomicU32 = AtomicU32::new(0);
        Self {
            shuffle_buf: [INIT; WARP_SIZE as usize],
            active_mask: AtomicU32::new(0xFFFF_FFFF),
            predicate_buf: [INIT; WARP_SIZE as usize],
        }
    }

    /// Returns the active-mask bit for `lane_id`, or `0` when `lane_id` is
    /// not a lane of this warp.
    ///
    /// Centralising the range check keeps an out-of-range id from turning
    /// into a shift overflow (debug panic / release wrap-around onto another
    /// lane's bit).
    fn lane_bit(lane_id: u32) -> u32 {
        if lane_id < WARP_SIZE {
            1u32 << lane_id
        } else {
            0
        }
    }

    /// Marks `lane_id` as active.
    ///
    /// `lane_id` must be below [`WARP_SIZE`]; this is checked with a
    /// `debug_assert!`. In release builds an out-of-range id leaves the mask
    /// untouched.
    pub fn set_lane_active(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask
            .fetch_or(Self::lane_bit(lane_id), Ordering::SeqCst);
    }

    /// Marks `lane_id` as inactive.
    ///
    /// `lane_id` must be below [`WARP_SIZE`]; this is checked with a
    /// `debug_assert!`. In release builds an out-of-range id leaves the mask
    /// untouched.
    pub fn set_lane_inactive(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask
            .fetch_and(!Self::lane_bit(lane_id), Ordering::SeqCst);
    }

    /// Returns the current active-lane bitmask (bit `i` corresponds to lane `i`).
    pub fn active_mask(&self) -> u32 {
        self.active_mask.load(Ordering::SeqCst)
    }

    /// Returns `true` when `lane_id` is currently active.
    ///
    /// Lane ids outside the warp are reported as inactive.
    pub fn is_lane_active(&self, lane_id: u32) -> bool {
        self.active_mask() & Self::lane_bit(lane_id) != 0
    }

    /// Publishes `value` into `lane_id`'s shuffle slot, then reads the slot
    /// of `src_lane`.
    ///
    /// Returns the caller's own `value` when `src_lane` is `None` or lies
    /// outside the warp. Panics if `lane_id` is not below [`WARP_SIZE`].
    fn exchange(&self, lane_id: u32, value: u32, src_lane: Option<u32>) -> u32 {
        debug_assert!(lane_id < WARP_SIZE);
        self.shuffle_buf[lane_id as usize].store(value, Ordering::SeqCst);

        match src_lane {
            Some(src) if src < WARP_SIZE => self.shuffle_buf[src as usize].load(Ordering::SeqCst),
            _ => value,
        }
    }

    /// Publishes `value` for `lane_id` and returns the value currently held
    /// in the slot of `src_lane`.
    ///
    /// `src_lane` is reduced modulo [`WARP_SIZE`], so any `u32` is accepted.
    /// Panics if `lane_id` is not below [`WARP_SIZE`].
    pub fn shuffle(&self, lane_id: u32, value: u32, src_lane: u32) -> u32 {
        self.exchange(lane_id, value, Some(src_lane % WARP_SIZE))
    }

    /// Shuffle variant whose source lane is `lane_id ^ lane_mask`
    /// (reduced modulo [`WARP_SIZE`] by [`WarpState::shuffle`]).
    pub fn shuffle_xor(&self, lane_id: u32, value: u32, lane_mask: u32) -> u32 {
        let src_lane = lane_id ^ lane_mask;
        self.shuffle(lane_id, value, src_lane)
    }

    /// Publishes `value` for `lane_id` and returns the value held by the lane
    /// `delta` positions below it.
    ///
    /// When `delta` exceeds `lane_id` there is no such lane and the caller's
    /// own `value` is returned. Panics if `lane_id` is not below
    /// [`WARP_SIZE`].
    pub fn shuffle_up(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        self.exchange(lane_id, value, lane_id.checked_sub(delta))
    }

    /// Publishes `value` for `lane_id` and returns the value held by the lane
    /// `delta` positions above it.
    ///
    /// When `lane_id + delta` falls outside the warp (including when the sum
    /// would overflow `u32`) the caller's own `value` is returned. Panics if
    /// `lane_id` is not below [`WARP_SIZE`].
    pub fn shuffle_down(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        // `checked_add` keeps a huge `delta` from wrapping back onto a valid
        // lane index.
        self.exchange(lane_id, value, lane_id.checked_add(delta))
    }

    /// `f32` flavour of [`WarpState::shuffle`]; the value travels as its raw
    /// bit pattern, so it round-trips exactly.
    pub fn shuffle_f32(&self, lane_id: u32, value: f32, src_lane: u32) -> f32 {
        f32::from_bits(self.shuffle(lane_id, value.to_bits(), src_lane))
    }

    /// `f32` flavour of [`WarpState::shuffle_xor`].
    pub fn shuffle_xor_f32(&self, lane_id: u32, value: f32, lane_mask: u32) -> f32 {
        f32::from_bits(self.shuffle_xor(lane_id, value.to_bits(), lane_mask))
    }

    /// `f32` flavour of [`WarpState::shuffle_up`].
    pub fn shuffle_up_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_up(lane_id, value.to_bits(), delta))
    }

    /// `f32` flavour of [`WarpState::shuffle_down`].
    pub fn shuffle_down_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_down(lane_id, value.to_bits(), delta))
    }

    /// Publishes `predicate` for `lane_id`, then snapshots the active mask and
    /// collects the predicates of all active lanes.
    ///
    /// Returns `(mask, votes)` where bit `i` of `votes` is set iff lane `i`
    /// is active in `mask` and its predicate slot is non-zero. `votes` is
    /// therefore always a subset of `mask`.
    fn gather_votes(&self, lane_id: u32, predicate: bool) -> (u32, u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.predicate_buf[lane_id as usize].store(u32::from(predicate), Ordering::SeqCst);

        let mask = self.active_mask();
        let votes = self
            .predicate_buf
            .iter()
            .enumerate()
            .filter(|&(lane, slot)| (mask >> lane) & 1 == 1 && slot.load(Ordering::SeqCst) != 0)
            .fold(0u32, |acc, (lane, _)| acc | (1u32 << lane));
        (mask, votes)
    }

    /// Publishes `predicate` for `lane_id` and returns `true` when every
    /// active lane's published predicate is set.
    ///
    /// Vacuously `true` when no lane is active. Panics if `lane_id` is not
    /// below [`WARP_SIZE`].
    pub fn vote_all(&self, lane_id: u32, predicate: bool) -> bool {
        let (mask, votes) = self.gather_votes(lane_id, predicate);
        votes == mask
    }

    /// Publishes `predicate` for `lane_id` and returns `true` when at least
    /// one active lane's published predicate is set.
    ///
    /// Panics if `lane_id` is not below [`WARP_SIZE`].
    pub fn vote_any(&self, lane_id: u32, predicate: bool) -> bool {
        let (_, votes) = self.gather_votes(lane_id, predicate);
        votes != 0
    }

    /// Publishes `predicate` for `lane_id` and returns a bitmask with bit `i`
    /// set for every active lane `i` whose published predicate is set.
    ///
    /// Panics if `lane_id` is not below [`WARP_SIZE`].
    pub fn ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.gather_votes(lane_id, predicate).1
    }

    /// Shared driver for the `reduce_*_f32` family.
    ///
    /// Runs `shuffle_down_f32` with offsets `WARP_SIZE / 2`, halving down to
    /// `1`, folding each fetched value into the accumulator with
    /// `combine(acc, other)`. When the source lane is out of range the shuffle
    /// hands back the accumulator itself, which is then combined with itself.
    fn reduce_f32(&self, lane_id: u32, value: f32, combine: impl Fn(f32, f32) -> f32) -> f32 {
        let mut acc = value;
        let mut delta = WARP_SIZE / 2;
        while delta >= 1 {
            let other = self.shuffle_down_f32(lane_id, acc, delta);
            acc = combine(acc, other);
            delta /= 2;
        }
        acc
    }

    /// Shuffle-down sum reduction as seen from `lane_id`, starting from
    /// `value`.
    ///
    /// NOTE(review): the result is a full-warp sum only for lane 0 and only
    /// if all lanes run the rounds in lock-step; this type does not enforce
    /// that.
    pub fn reduce_sum_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| acc + other)
    }

    /// Shuffle-down maximum reduction as seen from `lane_id`, starting from
    /// `value`.
    ///
    /// A fetched value replaces the accumulator only when it compares
    /// strictly greater, so a NaN fetched from another lane is ignored.
    pub fn reduce_max_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| if other > acc { other } else { acc })
    }

    /// Shuffle-down minimum reduction as seen from `lane_id`, starting from
    /// `value`.
    ///
    /// A fetched value replaces the accumulator only when it compares
    /// strictly less, so a NaN fetched from another lane is ignored.
    pub fn reduce_min_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| if other < acc { other } else { acc })
    }

    /// Publishes `predicate` for `lane_id` and returns how many active lanes
    /// currently have their predicate set (population count of
    /// [`WarpState::ballot`]).
    pub fn popc_ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.ballot(lane_id, predicate).count_ones()
    }
}

impl Default for WarpState {
    /// Equivalent to [`WarpState::new`].
    fn default() -> Self {
        Self::new()
    }
}
282
// Unit tests for `WarpState`.
//
// The tests run on a single thread: they seed the per-lane buffers directly
// to stand in for the other lanes, then invoke one primitive from one lane
// and check what that lane observes.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh warp starts with all 32 lanes active.
    #[test]
    fn test_new_warp_state() {
        let ws = WarpState::new();
        assert_eq!(ws.active_mask(), 0xFFFF_FFFF);
    }

    /// Deactivating and reactivating one lane leaves the others untouched.
    #[test]
    fn test_set_lane_active_inactive() {
        let ws = WarpState::new();
        ws.set_lane_inactive(5);
        assert!(!ws.is_lane_active(5));
        assert!(ws.is_lane_active(0));

        ws.set_lane_active(5);
        assert!(ws.is_lane_active(5));
    }

    /// Lane 5 reads the value seeded for lane 10.
    #[test]
    fn test_shuffle_basic() {
        let ws = WarpState::new();

        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(100 + lane, Ordering::SeqCst);
        }

        let result = ws.shuffle(5, 105, 10);
        assert_eq!(result, 110);
    }

    /// Lane 3 with mask 1 reads from lane 3 ^ 1 = 2.
    #[test]
    fn test_shuffle_xor() {
        let ws = WarpState::new();

        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(lane * 10, Ordering::SeqCst);
        }

        let result = ws.shuffle_xor(3, 30, 1);
        assert_eq!(result, 20);
    }

    /// In-range shuffle-up reads the lower lane; lane 0 keeps its own value.
    #[test]
    fn test_shuffle_up() {
        let ws = WarpState::new();

        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(lane, Ordering::SeqCst);
        }

        let result = ws.shuffle_up(5, 5, 2);
        assert_eq!(result, 3);

        let result = ws.shuffle_up(0, 0, 1);
        assert_eq!(result, 0);
    }

    /// In-range shuffle-down reads the higher lane; lane 31 keeps its own value.
    #[test]
    fn test_shuffle_down() {
        let ws = WarpState::new();

        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(lane, Ordering::SeqCst);
        }

        let result = ws.shuffle_down(5, 5, 3);
        assert_eq!(result, 8);

        let result = ws.shuffle_down(31, 31, 1);
        assert_eq!(result, 31);
    }

    /// Float shuffles carry the value through its bit pattern.
    #[test]
    fn test_shuffle_f32() {
        let ws = WarpState::new();

        for lane in 0..WARP_SIZE {
            let val = lane as f32 * 1.5;
            ws.shuffle_buf[lane as usize].store(val.to_bits(), Ordering::SeqCst);
        }

        let result = ws.shuffle_f32(0, 0.0, 10);
        let expected = 10.0 * 1.5;
        assert!((result - expected).abs() < 1e-6);
    }

    /// Every lane voting true yields a unanimous result.
    #[test]
    fn test_vote_all_true() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(1, Ordering::SeqCst);
        }
        assert!(ws.vote_all(0, true));
    }

    /// A single dissenting active lane breaks unanimity.
    #[test]
    fn test_vote_all_one_false() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(1, Ordering::SeqCst);
        }
        ws.predicate_buf[15].store(0, Ordering::SeqCst);
        assert!(!ws.vote_all(0, true));
    }

    /// The caller's own true predicate is enough for `vote_any`.
    #[test]
    fn test_vote_any() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(0, Ordering::SeqCst);
        }

        assert!(ws.vote_any(7, true));
    }

    /// Ballot reports exactly the lanes whose predicate is set.
    #[test]
    fn test_ballot() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(0, Ordering::SeqCst);
        }

        ws.predicate_buf[0].store(1, Ordering::SeqCst);
        ws.predicate_buf[1].store(1, Ordering::SeqCst);
        ws.predicate_buf[2].store(1, Ordering::SeqCst);

        // Lane 3 votes false, so only lanes 0..=2 appear in the result.
        let result = ws.ballot(3, false);
        assert_eq!(result & 0b111, 0b111);
        assert_eq!(result & (1 << 3), 0);
        assert_eq!(result, 0b111);
    }

    /// `popc_ballot` counts the lanes whose predicate is set.
    #[test]
    fn test_popc_ballot() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(0, Ordering::SeqCst);
        }

        for lane in 0..5 {
            ws.predicate_buf[lane as usize].store(1, Ordering::SeqCst);
        }

        let count = ws.popc_ballot(10, false);
        assert_eq!(count, 5);
    }

    /// Sum reduction driven from lane 0 alone.
    #[test]
    fn test_reduce_sum_simple() {
        let ws = WarpState::new();
        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(1.0f32.to_bits(), Ordering::SeqCst);
        }
        let result = ws.reduce_sum_f32(0, 1.0);

        // Only lane 0 runs here, so the five rounds (offsets 16, 8, 4, 2, 1)
        // each add the seeded 1.0 of one other lane to lane 0's own 1.0.
        // The sum is exactly representable, so exact comparison is safe.
        assert_eq!(result, 6.0);
    }
}