Skip to main content

khal_std/sync/
subgroup.rs

/// Sums an `f32` across the invocations of the current subgroup (wave/warp).
///
/// Every participating invocation passes its own `val` and receives the
/// subgroup-wide total.
///
/// Per-target behaviour:
/// - SPIR-V: forwards to `spirv_std::arch::subgroup_f_add`.
/// - CUDA (`nvptx64`): delegates to `warp_reduce_add`, a butterfly reduction
///   built on warp shuffles.
/// - Any other target (CPU): the subgroup is treated as a single invocation,
///   so `val` is returned as-is.
#[inline(always)]
pub fn subgroup_f_add(val: f32) -> f32 {
    // Exactly one of the bindings below survives `cfg` evaluation.
    #[cfg(target_arch = "spirv")]
    let total = spirv_std::arch::subgroup_f_add(val);

    #[cfg(target_arch = "nvptx64")]
    let total = warp_reduce_add(val);

    #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
    let total = val;

    total
}
23
/// Finds the largest `f32` across the invocations of the current subgroup
/// (wave/warp).
///
/// Every participating invocation passes its own `val` and receives the
/// subgroup-wide maximum.
///
/// Per-target behaviour:
/// - SPIR-V: forwards to `spirv_std::arch::subgroup_f_max`.
/// - CUDA (`nvptx64`): delegates to `warp_reduce_max`, a butterfly reduction
///   built on warp shuffles.
/// - Any other target (CPU): the subgroup is treated as a single invocation,
///   so `val` is returned as-is.
#[inline(always)]
pub fn subgroup_f_max(val: f32) -> f32 {
    // Exactly one of the bindings below survives `cfg` evaluation.
    #[cfg(target_arch = "spirv")]
    let largest = spirv_std::arch::subgroup_f_max(val);

    #[cfg(target_arch = "nvptx64")]
    let largest = warp_reduce_max(val);

    #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
    let largest = val;

    largest
}
46
/// Butterfly warp shuffle of a single `f32`.
///
/// Issues the PTX instruction `shfl.sync.bfly.b32` with `val` as the source
/// value and `lane_mask` as the butterfly (XOR) operand, and returns the
/// value the instruction writes back for this lane.
///
/// The clamp operand is fixed at `0x1f` and the member mask at `0xffffffff`,
/// i.e. the shuffle is issued on behalf of all 32 lanes of the warp.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn shfl_xor_sync(val: f32, lane_mask: u32) -> f32 {
    let shuffled: f32;
    // SAFETY: the instruction only reads the two input registers and writes
    // the output register; no memory operands are involved.
    // NOTE(review): the hard-coded full member mask presumably requires every
    // lane of the warp to reach this shuffle together — confirm that callers
    // only use it from warp-uniform control flow.
    unsafe {
        core::arch::asm!(
            "shfl.sync.bfly.b32 {result}, {val}, {lane_mask}, 0x1f, 0xffffffff;",
            result = out(reg32) shuffled,
            val = in(reg32) val,
            lane_mask = in(reg32) lane_mask,
        );
    }
    shuffled
}
61
/// Warp-wide `f32` sum via a butterfly (XOR) shuffle reduction.
///
/// Runs five exchange rounds with XOR strides 16, 8, 4, 2 and 1. In each
/// round the lane adds the value obtained from `shfl_xor_sync` to its own
/// running sum, and the final running sum is returned.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn warp_reduce_add(mut val: f32) -> f32 {
    // Halve the XOR stride each round: 16 -> 8 -> 4 -> 2 -> 1.
    let mut stride: u32 = 16;
    while stride != 0 {
        let partner = shfl_xor_sync(val, stride);
        val += partner;
        stride >>= 1;
    }
    val
}
72
/// Warp-wide `f32` maximum via a butterfly (XOR) shuffle reduction.
///
/// Runs five exchange rounds with XOR strides 16, 8, 4, 2 and 1. In each
/// round the lane combines its running maximum with the value obtained from
/// `shfl_xor_sync`, and the final running maximum is returned.
///
/// The combine step uses [`f32::max`], which is symmetric in its operands and
/// returns the non-NaN operand when exactly one side is NaN. A one-sided
/// `other > val` test is not symmetric: a lane holding NaN would keep NaN
/// while its exchange partner keeps its own number, so lanes could finish
/// the reduction with different results. With `f32::max` both partners of an
/// exchange compute the same value, and NaN is only returned when every
/// contributed value is NaN.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn warp_reduce_max(mut val: f32) -> f32 {
    // Halve the XOR stride each round: 16 -> 8 -> 4 -> 2 -> 1.
    let mut stride: u32 = 16;
    while stride != 0 {
        val = val.max(shfl_xor_sync(val, stride));
        stride >>= 1;
    }
    val
}