Skip to main content

khal_std/sync/
barriers.rs

/// Workgroup memory barrier with group synchronization.
///
/// Every invocation of the workgroup must reach this call in uniform control
/// flow; none proceeds past it until all have arrived.
///
/// Backend selected at compile time:
/// - SPIR-V: calls `spirv_std::arch::workgroup_memory_barrier_with_group_sync`.
/// - CUDA (`nvptx64`): calls the `llvm.nvvm.barrier0` intrinsic (the barrier
///   behind CUDA's `__syncthreads()`).
/// - CPU (`cpu` feature, non-GPU target): waits on the thread-local workgroup
///   barrier set by CPU dispatch (`crate::arch::cpu::barrier_wait`).
/// - Any other configuration (non-GPU target without the `cpu` feature):
///   compiles to a no-op.
#[inline(always)]
pub fn workgroup_memory_barrier_with_group_sync() {
    #[cfg(target_arch = "spirv")]
    {
        spirv_std::arch::workgroup_memory_barrier_with_group_sync();
    }

    #[cfg(target_arch = "nvptx64")]
    {
        // Call the LLVM intrinsic directly instead of `cuda_std::thread::sync_threads()`
        // so that LLVM sees the `convergent` attribute during optimization passes.
        // Without it, LLVM may tail-duplicate the barrier into both arms of a
        // divergent `if`/`else`, making threads wait on different `bar.sync`
        // instructions and deadlocking the block. Kernels with barriers that hung
        // when using `cuda_std::thread::sync_threads()` work with this call.
        //
        // NOTE(review): binding an `llvm.*` symbol through `link_name` requires the
        // nightly `link_llvm_intrinsics` feature; presumably enabled at the crate
        // root for nvptx builds — confirm.
        unsafe extern "C" {
            #[link_name = "llvm.nvvm.barrier0"]
            fn nvvm_barrier0();
        }
        // SAFETY: the intrinsic takes no arguments, so there are no pointer or
        // aliasing preconditions to uphold. Its remaining requirement — that all
        // threads of the block reach it in uniform control flow, otherwise the
        // block can hang — is the contract documented on this function.
        unsafe {
            nvvm_barrier0();
        }
    }

    #[cfg(all(
        feature = "cpu",
        not(any(target_arch = "spirv", target_arch = "nvptx64"))
    ))]
    {
        crate::arch::cpu::barrier_wait();
    }
}
36
37/// Control barrier with explicit execution scope, memory scope, and semantics.
38///
39/// On GPU (SPIR-V): calls `spirv_std::arch::control_barrier`.
40/// On GPU (CUDA): calls `__syncthreads()`.
41/// On CPU: waits on the thread-local workgroup barrier.
42#[inline(always)]
43pub fn control_barrier<const EXECUTION: u32, const MEMORY: u32, const SEMANTICS: u32>() {
44    #[cfg(target_arch = "spirv")]
45    {
46        spirv_std::arch::control_barrier::<EXECUTION, MEMORY, SEMANTICS>();
47    }
48    #[cfg(not(target_arch = "spirv"))]
49    {
50        // handle CUDA and CPU backends
51        workgroup_memory_barrier_with_group_sync();
52    }
53}