// vyre_reference/subgroup.rs
/// Sequential CPU reference model of a subgroup of `width` lanes.
///
/// The simulator stores only the subgroup width. Every operation takes per-lane data as a
/// slice indexed by lane id and evaluates the collective on the calling thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubgroupSimulator {
    /// Number of lanes in the subgroup; always at least 1 (see [`SubgroupSimulator::new`]).
    width: usize,
}

impl Default for SubgroupSimulator {
    /// Creates a 32-lane simulator.
    fn default() -> Self {
        Self { width: 32 }
    }
}

impl SubgroupSimulator {
    /// Creates a simulator with `width` lanes.
    ///
    /// A `width` of 0 is clamped to 1 so that lane arithmetic never divides by zero.
    #[must_use]
    pub fn new(width: usize) -> Self {
        Self {
            width: width.max(1),
        }
    }

    /// Returns the number of lanes in the subgroup.
    #[must_use]
    pub const fn width(&self) -> usize {
        self.width
    }

    /// Fixed-size convenience wrapper around [`Self::ballot_slice`].
    #[must_use]
    pub fn ballot<const N: usize>(&self, mask: &[bool; N]) -> u32 {
        self.ballot_slice(mask)
    }

    /// Packs per-lane predicates into a bitmask: bit `i` is set iff lane `i` voted `true`.
    ///
    /// Only the first `min(mask.len(), width, 32)` lanes are considered. The cap of 32 keeps
    /// every shift inside the `u32` result, so lanes beyond 32 are silently dropped.
    #[must_use]
    pub fn ballot_slice(&self, mask: &[bool]) -> u32 {
        let active = mask.len().min(self.width).min(32);
        let mut bits = 0u32;
        for (lane, &flag) in mask.iter().take(active).enumerate() {
            if flag {
                bits |= 1u32 << lane;
            }
        }
        bits
    }

    /// For each lane, reads the value held by the lane named in `src_lanes`.
    ///
    /// Lane `i` of the result is `values[src_lanes[i]]`. A source lane outside this subgroup
    /// (at or beyond `min(values.len(), width)`) reads as `0`; it never reaches into the
    /// values of lanes past `width`. The result has
    /// `min(values.len(), src_lanes.len(), width)` lanes.
    #[must_use]
    pub fn shuffle(&self, values: &[u32], src_lanes: &[u32]) -> Vec<u32> {
        // Only the first `width` entries of `values` belong to this subgroup; reading past
        // them would be cross-subgroup communication.
        let subgroup = &values[..values.len().min(self.width)];
        let active = subgroup.len().min(src_lanes.len());
        src_lanes
            .iter()
            .take(active)
            // u32 -> usize is lossless on the 32/64-bit targets this model is meant for.
            .map(|&src| subgroup.get(src as usize).copied().unwrap_or(0))
            .collect()
    }

    /// Returns the wrapping sum of the first `width` values; an empty slice sums to 0.
    #[must_use]
    pub fn add(&self, values: &[u32]) -> u32 {
        values
            .iter()
            .take(self.width)
            .copied()
            .fold(0u32, u32::wrapping_add)
    }

    /// Returns the half-open range `[start, end)` of the subgroup containing `lane_index`
    /// among `lane_count` lanes.
    ///
    /// `start` is `lane_index` rounded down to a multiple of `width`. `end` is the smaller of
    /// the next multiple and `lane_count`, so the final subgroup may be partial.
    ///
    /// The result always satisfies `start <= end <= lane_count`:
    /// - if the subgroup would begin at or after `lane_count`, the empty range
    ///   `(lane_count, lane_count)` is returned;
    /// - the upper bound saturates instead of overflowing for very large `lane_index`.
    #[must_use]
    pub fn subgroup_bounds(&self, lane_count: usize, lane_index: usize) -> (usize, usize) {
        let start = (lane_index / self.width) * self.width;
        let end = lane_count.min(start.saturating_add(self.width));
        // Clamp so callers can always use the pair as a valid slice range.
        (start.min(end), end)
    }
}
83
#[cfg(test)]
mod tests {
    // Unit and property tests for the subgroup reference simulator.
    use super::SubgroupSimulator;
    use proptest::prelude::*;
    use rayon::prelude::*;

    // Lanes 0, 2 and 3 vote `true`, so exactly bits 0, 2 and 3 must be set.
    #[test]
    fn ballot_sets_expected_bits() {
        let votes = [true, false, true, true];
        let bits = SubgroupSimulator::default().ballot(&votes);
        assert_eq!(bits, 0b1101);
    }

    // Source lane 5 does not exist in a four-lane subgroup, so its slot reads back as 0.
    #[test]
    fn shuffle_zeroes_out_of_range_lanes() {
        let sim = SubgroupSimulator::new(4);
        let shuffled = sim.shuffle(&[10, 20, 30, 40], &[0, 2, 5, 1]);
        assert_eq!(shuffled, vec![10, 30, 0, 20]);
    }

    proptest! {
        // A subgroup wide enough to hold every value must agree with a wrapping sum
        // computed independently through rayon.
        #[test]
        fn subgroup_add_matches_parallel_wrapping_sum(values in prop::collection::vec(any::<u32>(), 0..128)) {
            let lanes = values.len().max(1);
            let reference = values.par_iter().copied().reduce(|| 0u32, u32::wrapping_add);
            prop_assert_eq!(SubgroupSimulator::new(lanes).add(&values), reference);
        }
    }
}