Skip to main content

cuda_rust_wasm/kernel/
warp.rs

1//! Warp-level primitive emulation
2//!
3//! Emulates CUDA warp-level operations (shuffle, vote, ballot) on the CPU
4//! using shared memory buffers. This enables transpiled CUDA kernels that use
5//! warp intrinsics to execute correctly on CPU fallback paths.
6//!
//! The emulation assumes `WARP_SIZE = 32`. Each operation takes the caller's
//! lane id as an explicit argument; this module does not track lane identity.
9
10use std::sync::atomic::{AtomicU32, Ordering};
11
/// The number of threads in a warp (matches CUDA).
pub const WARP_SIZE: u32 = 32;

/// Per-warp shared state used to emulate warp-level operations.
///
/// In a real GPU each warp executes in lock-step and has hardware support for
/// cross-lane communication. On the CPU we emulate this by having all threads
/// in a "warp" share a `WarpState`. The type itself contains no barrier: every
/// operation writes the caller's slot and immediately reads the peer slots, so
/// callers are responsible for making sure the peer lanes have already written
/// (for example by pre-populating all lanes in a single-threaded driver).
#[derive(Debug)]
pub struct WarpState {
    /// Shared data buffer for shuffle operations.
    /// Each lane writes its value, then reads from the target lane.
    shuffle_buf: [AtomicU32; WARP_SIZE as usize],

    /// Bitmask of active lanes. Bit `i` is set if lane `i` is participating.
    active_mask: AtomicU32,

    /// Predicate buffer for vote/ballot operations.
    /// Each lane writes 1 (true) or 0 (false).
    predicate_buf: [AtomicU32; WARP_SIZE as usize],
}

impl WarpState {
    /// Create a new warp state with all lanes active and every shuffle and
    /// predicate slot zeroed.
    pub fn new() -> Self {
        // `AtomicU32` is not `Copy`; array-repeat syntax accepts a non-`Copy`
        // element only when it is named through a `const` item.
        const ZERO: AtomicU32 = AtomicU32::new(0);
        Self {
            shuffle_buf: [ZERO; WARP_SIZE as usize],
            active_mask: AtomicU32::new(0xFFFF_FFFF),
            predicate_buf: [ZERO; WARP_SIZE as usize],
        }
    }

    // -----------------------------------------------------------------------
    // Active mask management
    // -----------------------------------------------------------------------

    /// Set a lane as active.
    ///
    /// `lane_id` must be below [`WARP_SIZE`] (checked in debug builds).
    pub fn set_lane_active(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask.fetch_or(1u32 << lane_id, Ordering::SeqCst);
    }

    /// Set a lane as inactive.
    ///
    /// `lane_id` must be below [`WARP_SIZE`] (checked in debug builds).
    pub fn set_lane_inactive(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask
            .fetch_and(!(1u32 << lane_id), Ordering::SeqCst);
    }

    /// Get the current active mask.
    pub fn active_mask(&self) -> u32 {
        self.active_mask.load(Ordering::SeqCst)
    }

    /// Returns true if the specified lane is currently active.
    ///
    /// A `lane_id` outside `0..WARP_SIZE` names no lane and is reported as
    /// inactive. (Shifting by such a value would otherwise panic in debug
    /// builds and silently alias another lane in release builds.)
    pub fn is_lane_active(&self, lane_id: u32) -> bool {
        lane_id < WARP_SIZE && (self.active_mask() >> lane_id) & 1 == 1
    }

    // -----------------------------------------------------------------------
    // Private helpers
    // -----------------------------------------------------------------------

    /// Store `value` into the shuffle slot owned by `lane_id`.
    fn publish_shuffle(&self, lane_id: u32, value: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.shuffle_buf[lane_id as usize].store(value, Ordering::SeqCst);
    }

    /// Store `predicate` (as 1 or 0) into the predicate slot owned by `lane_id`.
    fn publish_predicate(&self, lane_id: u32, predicate: bool) {
        debug_assert!(lane_id < WARP_SIZE);
        self.predicate_buf[lane_id as usize].store(u32::from(predicate), Ordering::SeqCst);
    }

    /// Read back the predicate most recently stored for `lane`.
    fn predicate_at(&self, lane: u32) -> bool {
        self.predicate_buf[lane as usize].load(Ordering::SeqCst) != 0
    }

    /// Iterate, in ascending order, over the lane indices whose bit is set in `mask`.
    fn lanes_in(mask: u32) -> impl Iterator<Item = u32> {
        (0..WARP_SIZE).filter(move |&lane| (mask >> lane) & 1 == 1)
    }

    /// Shared driver for the `reduce_*_f32` helpers: halve `delta` from
    /// `WARP_SIZE / 2` down to 1, folding the value fetched by
    /// `shuffle_down_f32` into the accumulator with `combine(acc, other)`.
    fn reduce_f32(&self, lane_id: u32, value: f32, combine: impl Fn(f32, f32) -> f32) -> f32 {
        let mut acc = value;
        let mut delta = WARP_SIZE / 2;
        while delta >= 1 {
            let other = self.shuffle_down_f32(lane_id, acc, delta);
            acc = combine(acc, other);
            delta /= 2;
        }
        acc
    }

    // -----------------------------------------------------------------------
    // Warp shuffle emulation
    // -----------------------------------------------------------------------

    /// Emulate `__shfl_sync`: read the value from `src_lane`.
    ///
    /// The caller (at `lane_id`) first writes its own value and then reads the
    /// slot of `src_lane`. No barrier is performed between the write and the
    /// read; in a single-threaded emulation context the caller can
    /// pre-populate all lanes and then read.
    ///
    /// # Arguments
    /// * `lane_id` - The calling thread's lane within the warp (0..31)
    /// * `value` - The value this lane contributes
    /// * `src_lane` - The lane to read from
    ///
    /// Returns the value stored for lane `src_lane % WARP_SIZE`; an
    /// out-of-range `src_lane` therefore wraps around rather than falling back
    /// to the caller's own value.
    pub fn shuffle(&self, lane_id: u32, value: u32, src_lane: u32) -> u32 {
        self.publish_shuffle(lane_id, value);

        // In a multi-threaded scenario a barrier would go here.
        // For single-threaded emulation we assume all lanes have written.

        let source = (src_lane % WARP_SIZE) as usize;
        self.shuffle_buf[source].load(Ordering::SeqCst)
    }

    /// Emulate `__shfl_xor_sync`: read from `lane_id ^ lane_mask`.
    ///
    /// The XOR-ed lane index is passed to [`WarpState::shuffle`], so it wraps
    /// modulo `WARP_SIZE` if `lane_mask` has bits above the lane range.
    pub fn shuffle_xor(&self, lane_id: u32, value: u32, lane_mask: u32) -> u32 {
        self.shuffle(lane_id, value, lane_id ^ lane_mask)
    }

    /// Emulate `__shfl_up_sync`: read from `lane_id - delta`.
    /// If the source lane would be negative, return the caller's own value.
    pub fn shuffle_up(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        self.publish_shuffle(lane_id, value);

        match lane_id.checked_sub(delta) {
            Some(source) => self.shuffle_buf[source as usize].load(Ordering::SeqCst),
            // Out-of-range: return own value
            None => value,
        }
    }

    /// Emulate `__shfl_down_sync`: read from `lane_id + delta`.
    /// If the source lane would be >= WARP_SIZE, return the caller's own value.
    pub fn shuffle_down(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        self.publish_shuffle(lane_id, value);

        // `checked_add` keeps a huge `delta` from overflowing `u32`; an
        // overflowing sum is by definition past the end of the warp.
        match lane_id.checked_add(delta) {
            Some(source) if source < WARP_SIZE => {
                self.shuffle_buf[source as usize].load(Ordering::SeqCst)
            }
            _ => value,
        }
    }

    // -----------------------------------------------------------------------
    // Warp shuffle with f32 values
    // -----------------------------------------------------------------------

    /// Shuffle an f32 value (reinterpret bits through u32).
    pub fn shuffle_f32(&self, lane_id: u32, value: f32, src_lane: u32) -> f32 {
        f32::from_bits(self.shuffle(lane_id, value.to_bits(), src_lane))
    }

    /// Shuffle XOR with f32.
    pub fn shuffle_xor_f32(&self, lane_id: u32, value: f32, lane_mask: u32) -> f32 {
        f32::from_bits(self.shuffle_xor(lane_id, value.to_bits(), lane_mask))
    }

    /// Shuffle up with f32.
    pub fn shuffle_up_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_up(lane_id, value.to_bits(), delta))
    }

    /// Shuffle down with f32.
    pub fn shuffle_down_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_down(lane_id, value.to_bits(), delta))
    }

    // -----------------------------------------------------------------------
    // Warp vote operations
    // -----------------------------------------------------------------------

    /// Emulate `__all_sync`: returns true if all active lanes have `predicate == true`.
    ///
    /// The caller's predicate is stored first; the predicates of the other
    /// lanes are whatever they last stored. Inactive lanes are ignored.
    pub fn vote_all(&self, lane_id: u32, predicate: bool) -> bool {
        self.publish_predicate(lane_id, predicate);
        Self::lanes_in(self.active_mask()).all(|lane| self.predicate_at(lane))
    }

    /// Emulate `__any_sync`: returns true if any active lane has `predicate == true`.
    ///
    /// The caller's predicate is stored first; the predicates of the other
    /// lanes are whatever they last stored. Inactive lanes are ignored.
    pub fn vote_any(&self, lane_id: u32, predicate: bool) -> bool {
        self.publish_predicate(lane_id, predicate);
        Self::lanes_in(self.active_mask()).any(|lane| self.predicate_at(lane))
    }

    /// Emulate `__ballot_sync`: returns a bitmask where bit `i` is set if
    /// lane `i` is active and its predicate is true.
    pub fn ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.publish_predicate(lane_id, predicate);
        Self::lanes_in(self.active_mask())
            .filter(|&lane| self.predicate_at(lane))
            .fold(0u32, |bits, lane| bits | (1u32 << lane))
    }

    // -----------------------------------------------------------------------
    // Utility: warp-level reduction (common pattern)
    // -----------------------------------------------------------------------

    /// Warp-level sum reduction using `shuffle_down` with delta = 16, 8, 4, 2, 1.
    ///
    /// Assumes all 32 lanes call this with their value. Returns the sum at
    /// lane 0; other lanes get a partial result. Because no barrier is
    /// performed, a lone single-threaded call only combines the caller's
    /// running value with whatever the peer slots currently hold.
    pub fn reduce_sum_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| acc + other)
    }

    /// Warp-level max reduction (same shuffle-down schedule as the sum).
    pub fn reduce_max_f32(&self, lane_id: u32, value: f32) -> f32 {
        // Explicit comparison (not `f32::max`) so a NaN `other` never replaces `acc`.
        self.reduce_f32(lane_id, value, |acc, other| if other > acc { other } else { acc })
    }

    /// Warp-level min reduction (same shuffle-down schedule as the sum).
    pub fn reduce_min_f32(&self, lane_id: u32, value: f32) -> f32 {
        // Explicit comparison (not `f32::min`) so a NaN `other` never replaces `acc`.
        self.reduce_f32(lane_id, value, |acc, other| if other < acc { other } else { acc })
    }

    /// Count the number of active lanes with a true predicate (popcount of ballot).
    pub fn popc_ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.ballot(lane_id, predicate).count_ones()
    }
}
276
277impl Default for WarpState {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
283// ---------------------------------------------------------------------------
284// Tests
285// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Pre-populate every lane's shuffle slot with `f(lane)`, standing in for
    /// the other lanes of the warp in these single-threaded tests.
    fn fill_shuffle(ws: &WarpState, f: impl Fn(u32) -> u32) {
        for lane in 0..WARP_SIZE {
            ws.shuffle_buf[lane as usize].store(f(lane), Ordering::SeqCst);
        }
    }

    /// Pre-populate every lane's predicate slot with `flag`.
    fn fill_predicates(ws: &WarpState, flag: bool) {
        for lane in 0..WARP_SIZE {
            ws.predicate_buf[lane as usize].store(u32::from(flag), Ordering::SeqCst);
        }
    }

    #[test]
    fn test_new_warp_state() {
        let ws = WarpState::new();
        assert_eq!(ws.active_mask(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_set_lane_active_inactive() {
        let ws = WarpState::new();
        ws.set_lane_inactive(5);
        assert!(!ws.is_lane_active(5));
        assert!(ws.is_lane_active(0));
        assert_eq!(ws.active_mask(), !(1u32 << 5));

        ws.set_lane_active(5);
        assert!(ws.is_lane_active(5));
        assert_eq!(ws.active_mask(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_shuffle_basic() {
        let ws = WarpState::new();
        // Lanes 0..31 hold 100..131.
        fill_shuffle(&ws, |lane| 100 + lane);

        // Lane 5 shuffles from lane 10.
        assert_eq!(ws.shuffle(5, 105, 10), 110);
        // A source lane past the warp wraps modulo WARP_SIZE (42 % 32 == 10).
        assert_eq!(ws.shuffle(5, 105, 42), 110);
    }

    #[test]
    fn test_shuffle_xor() {
        let ws = WarpState::new();
        fill_shuffle(&ws, |lane| lane * 10);

        // Lane 3 XOR 1 -> lane 2
        assert_eq!(ws.shuffle_xor(3, 30, 1), 20);
    }

    #[test]
    fn test_shuffle_up() {
        let ws = WarpState::new();
        fill_shuffle(&ws, |lane| lane);

        // Lane 5 shuffle up by 2 -> reads from lane 3
        assert_eq!(ws.shuffle_up(5, 5, 2), 3);

        // Lane 0 shuffle up by 1 -> out of range, returns own value. A value
        // that appears nowhere else in the buffer proves it was not a read.
        assert_eq!(ws.shuffle_up(0, 99, 1), 99);
    }

    #[test]
    fn test_shuffle_down() {
        let ws = WarpState::new();
        fill_shuffle(&ws, |lane| lane);

        // Lane 5 shuffle down by 3 -> reads from lane 8
        assert_eq!(ws.shuffle_down(5, 5, 3), 8);

        // Lane 31 shuffle down by 1 -> out of range, returns own value.
        assert_eq!(ws.shuffle_down(31, 99, 1), 99);
    }

    #[test]
    fn test_shuffle_f32() {
        let ws = WarpState::new();
        fill_shuffle(&ws, |lane| (lane as f32 * 1.5).to_bits());

        // The value travels as raw bits, so the round trip is exact.
        assert_eq!(ws.shuffle_f32(0, 0.0, 10), 15.0);
    }

    #[test]
    fn test_vote_all_true() {
        let ws = WarpState::new();
        fill_predicates(&ws, true);
        assert!(ws.vote_all(0, true));
    }

    #[test]
    fn test_vote_all_one_false() {
        let ws = WarpState::new();
        fill_predicates(&ws, true);
        // Lane 15 votes false.
        ws.predicate_buf[15].store(0, Ordering::SeqCst);
        assert!(!ws.vote_all(0, true));

        // Once lane 15 is inactive its vote is ignored.
        ws.set_lane_inactive(15);
        assert!(ws.vote_all(0, true));
    }

    #[test]
    fn test_vote_any() {
        let ws = WarpState::new();
        fill_predicates(&ws, false);

        // Nobody has voted true yet.
        assert!(!ws.vote_any(7, false));
        // Lane 7 votes true.
        assert!(ws.vote_any(7, true));
    }

    #[test]
    fn test_ballot() {
        let ws = WarpState::new();
        fill_predicates(&ws, false);

        // Lanes 0, 1, 2 vote true.
        for lane in 0..3 {
            ws.predicate_buf[lane].store(1, Ordering::SeqCst);
        }

        // Exactly the first three bits are set; lane 3 (the caller) voted false.
        assert_eq!(ws.ballot(3, false), 0b111);
    }

    #[test]
    fn test_popc_ballot() {
        let ws = WarpState::new();
        fill_predicates(&ws, false);

        // Five lanes vote true.
        for lane in 0..5 {
            ws.predicate_buf[lane].store(1, Ordering::SeqCst);
        }

        assert_eq!(ws.popc_ballot(10, false), 5);
    }

    #[test]
    fn test_reduce_sum_simple() {
        let ws = WarpState::new();
        fill_shuffle(&ws, |_| 1.0f32.to_bits());

        // A real warp would produce 32.0 at lane 0. Here only lane 0 runs, so
        // each of the five steps (delta = 16, 8, 4, 2, 1) adds one untouched
        // pre-populated 1.0 to the running value: 1 + 5 = 6.
        assert_eq!(ws.reduce_sum_f32(0, 1.0), 6.0);
    }
}
464}