// khal_std/sync/subgroup.rs
/// Sums `val` across the lanes of the current subgroup and returns the total.
///
/// The backend is chosen at compile time:
/// - `spirv`: forwards to `spirv_std::arch::subgroup_f_add`.
/// - `nvptx64`: an XOR-butterfly shuffle reduction over all 32 lanes of the
///   warp (full `0xffffffff` member mask), so every lane of the warp is
///   expected to reach this call.
/// - any other target: there is no subgroup hardware, so `val` is returned
///   unchanged (the sum over a single lane).
///
/// NOTE(review): the two GPU backends may accumulate in different orders, so
/// results can differ in the last bits of rounding — confirm callers tolerate this.
#[inline(always)]
pub fn subgroup_f_add(val: f32) -> f32 {
    #[cfg(target_arch = "spirv")]
    let total = spirv_std::arch::subgroup_f_add(val);

    #[cfg(target_arch = "nvptx64")]
    let total = warp_reduce_add(val);

    // A "subgroup" of one lane: the reduction is the input itself.
    #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
    let total = val;

    total
}
23
/// Returns the maximum of `val` across the lanes of the current subgroup.
///
/// The backend is chosen at compile time:
/// - `spirv`: forwards to `spirv_std::arch::subgroup_f_max`.
/// - `nvptx64`: delegates to a warp-level XOR-butterfly reduction over all 32
///   lanes (full `0xffffffff` member mask), so every lane of the warp is
///   expected to reach this call.
/// - any other target: there is no subgroup hardware, so `val` is returned
///   unchanged (the maximum over a single lane).
///
/// NOTE(review): NaN handling is whatever each backend's reduction does —
/// confirm the GPU paths agree before relying on NaN inputs.
#[inline(always)]
pub fn subgroup_f_max(val: f32) -> f32 {
    #[cfg(target_arch = "spirv")]
    let reduced = spirv_std::arch::subgroup_f_max(val);

    #[cfg(target_arch = "nvptx64")]
    let reduced = warp_reduce_max(val);

    // A "subgroup" of one lane: the reduction is the input itself.
    #[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
    let reduced = val;

    reduced
}
46
/// Butterfly (XOR) warp shuffle: returns the `val` passed in by lane
/// `laneid ^ lane_mask`.
///
/// Emits PTX `shfl.sync.bfly.b32` with:
/// - clamp/segment operand `0x1f`: the whole 32-lane warp is a single segment,
///   so every XOR partner is in range;
/// - member mask `0xffffffff`: all 32 lanes of the warp participate.
///
/// The `f32` travels through a 32-bit register (`reg32`) and is exchanged
/// bit-for-bit; no arithmetic is performed on it.
///
/// NOTE(review): with a full member mask, PTX expects every non-exited lane of
/// the warp to execute this instruction. Nothing here checks that, so callers
/// should only use it from warp-uniform control flow — confirm at call sites.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn shfl_xor_sync(val: f32, lane_mask: u32) -> f32 {
    let result: f32;
    // SAFETY: the instruction reads only the `val` and `lane_mask` registers and
    // writes only `result`; it does not touch Rust-visible memory or the stack.
    // Its remaining precondition (all lanes named by the member mask reach this
    // shuffle) is not checked here and must be upheld by callers.
    unsafe {
        core::arch::asm!(
            "shfl.sync.bfly.b32 {result}, {val}, {lane_mask}, 0x1f, 0xffffffff;",
            result = out(reg32) result,
            val = in(reg32) val,
            lane_mask = in(reg32) lane_mask,
        );
    }
    result
}
61
/// Warp-wide `f32` sum: every lane receives the total over all 32 lanes.
///
/// Runs five XOR-butterfly rounds with lane strides 16, 8, 4, 2, 1. In each
/// round a lane adds its partner's running partial sum to its own. Because the
/// two partners add the same pair of values (and IEEE addition is commutative),
/// all lanes finish holding the same total.
///
/// Built on `shfl_xor_sync`, which uses a full-warp member mask, so every lane
/// of the warp is expected to call this.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn warp_reduce_add(val: f32) -> f32 {
    let mut partial = val;
    for stride in [16u32, 8, 4, 2, 1] {
        partial += shfl_xor_sync(partial, stride);
    }
    partial
}
72
/// Warp-wide `f32` maximum: every lane receives the maximum over all 32 lanes.
///
/// Runs five XOR-butterfly rounds with lane strides 16, 8, 4, 2, 1. In each
/// round a lane combines its running value with its partner's using
/// [`f32::max`].
///
/// NaN handling: `f32::max` returns the non-NaN operand when exactly one input
/// is NaN. NaN lanes are therefore ignored, and the result is NaN only if every
/// lane held NaN. Both partners of a round combine the same pair of values, so
/// all lanes end with the same result (up to the sign of a zero result). This
/// matches the SPIR-V `OpGroupNonUniformFMax` rule used by the `spirv` backend:
/// when one of two values is NaN, the other is chosen.
///
/// A hand-written `if other > val` comparison is deliberately avoided: with it,
/// a lane starting with NaN would keep NaN while its partners discarded it,
/// leaving lanes with different answers.
///
/// Built on `shfl_xor_sync`, which uses a full-warp member mask, so every lane
/// of the warp is expected to call this.
#[cfg(target_arch = "nvptx64")]
#[inline(always)]
fn warp_reduce_max(val: f32) -> f32 {
    let mut acc = val;
    for stride in [16u32, 8, 4, 2, 1] {
        acc = acc.max(shfl_xor_sync(acc, stride));
    }
    acc
}