#![allow(clippy::needless_range_loop, clippy::new_without_default)]
use std::marker::PhantomData;
/// Compile-time description of which lanes of a warp are active.
///
/// The set is carried entirely by the associated constants; implementors only
/// need to be `Copy + 'static` so they can be used as type-level tags.
pub trait ActiveSet: Copy + 'static {
    /// Bitmask of the active lanes, one bit per lane.
    const MASK: u32;

    /// Human-readable name of the set.
    const NAME: &'static str;
}
/// Marker relation stating that `Self` and `Other` are complementary active sets.
///
/// The trait has no items, so the relation is asserted purely by writing an
/// `impl`; nothing in the trait itself inspects the masks.
pub trait ComplementOf<Other>: ActiveSet
where
    Other: ActiveSet,
{
}
/// Handle for a warp whose active lanes are described by the type parameter `S`.
///
/// The handle stores no runtime data; `S` exists only at the type level.
///
/// NOTE(review): the handle is `Copy`, so passing it by value does not
/// invalidate the caller's copy.
#[derive(Copy, Clone)]
pub struct Warp<S: ActiveSet> {
    /// Ties the handle to `S` without storing a value of that type.
    _phantom: PhantomData<S>,
}
impl<S> Warp<S>
where
    S: ActiveSet,
{
    /// Creates a warp handle tagged with the active set `S`.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    /// Returns the lane bitmask declared by `S`.
    pub fn active_mask(&self) -> u32 {
        <S as ActiveSet>::MASK
    }
}
/// One value of type `T` for each of 32 lanes, stored in a public array
/// that is indexed by lane number.
#[derive(Copy, Clone, Debug)]
pub struct PerLane<T>(pub [T; 32]);
/// Type-level tag for the active set in which every lane is active
/// (its `ActiveSet::MASK` is `0xFFFFFFFF`).
#[derive(Copy, Clone)]
pub struct All;
/// Type-level tag for the active set containing only lane 0
/// (its `ActiveSet::MASK` is `0x00000001`).
#[derive(Copy, Clone)]
pub struct Lane0;
/// Type-level tag for the active set containing every lane except lane 0
/// (its `ActiveSet::MASK` is `0xFFFFFFFE`).
#[derive(Copy, Clone)]
pub struct NotLane0;
/// Every one of the 32 mask bits is set.
impl ActiveSet for All {
    const MASK: u32 = u32::MAX;
    const NAME: &'static str = "All";
}
/// Only bit 0 of the mask is set.
impl ActiveSet for Lane0 {
    const MASK: u32 = 0x1;
    const NAME: &'static str = "Lane0";
}
/// Every mask bit except bit 0 is set.
impl ActiveSet for NotLane0 {
    // Written as the bitwise complement of the lane-0 bit; equals 0xFFFFFFFE.
    const MASK: u32 = !0x0000_0001;
    const NAME: &'static str = "NotLane0";
}
// Declares `Lane0` as the complement of `NotLane0`: their masks (0x00000001 and
// 0xFFFFFFFE) share no bit and together set all 32 bits.
impl ComplementOf<NotLane0> for Lane0 {}
// The same relation in the other direction, so `merge` accepts the two halves
// in either argument order.
impl ComplementOf<Lane0> for NotLane0 {}
/// Operations that are only available while every lane is active.
impl Warp<All> {
    /// Splits the full warp into a handle for lane 0 and a handle for the
    /// remaining lanes.
    pub fn diverge_lane0(self) -> (Warp<Lane0>, Warp<NotLane0>) {
        (Warp::new(), Warp::new())
    }

    /// Returns a `PerLane` in which every lane holds the value that `data`
    /// has at `src_lane`.
    ///
    /// `src_lane` is reduced modulo the lane count, mirroring how CUDA's
    /// `__shfl_sync` treats an out-of-range source lane, so this never panics
    /// on a large `src_lane`.
    pub fn shuffle_broadcast(&self, data: &PerLane<i32>, src_lane: u32) -> PerLane<i32> {
        // The array length is fixed at 32 by `PerLane`, so the cast is lossless.
        let width = data.0.len() as u32;
        // Wrapping first keeps the index inside the array for any `u32` input.
        let source = (src_lane % width) as usize;
        PerLane([data.0[source]; 32])
    }
}
/// Builds a per-lane array by choosing, for each lane, between two scalars.
///
/// Lane `i` receives `_lane0_data` when bit `i` of `lane0_mask` is set and
/// `_notlane0_data` otherwise.
pub fn merge_data(_lane0_data: i32, _notlane0_data: i32, lane0_mask: u32) -> PerLane<i32> {
    PerLane(std::array::from_fn(|lane| {
        // `lane` ranges over 0..32, so the shift amount is always in range.
        let selected = (lane0_mask >> lane) & 1 != 0;
        if selected {
            _lane0_data
        } else {
            _notlane0_data
        }
    }))
}
/// Reconverges two diverged warp handles into a handle for the full warp.
///
/// The `ComplementOf` bound is only a marker, so the masks are also checked
/// here: the two sets must share no lane and together cover all 32 lanes.
/// Both checks compare associated constants only.
///
/// # Panics
///
/// Panics if `S1::MASK` and `S2::MASK` overlap, or if their union is not
/// `0xFFFFFFFF` — that is, if a `ComplementOf` impl does not match the masks.
pub fn merge<S1, S2>(_left: Warp<S1>, _right: Warp<S2>) -> Warp<All>
where
    S1: ComplementOf<S2>,
    S2: ActiveSet,
{
    let left_mask = <S1 as ActiveSet>::MASK;
    let right_mask = <S2 as ActiveSet>::MASK;
    assert_eq!(
        left_mask & right_mask,
        0,
        "merge: active sets {} and {} overlap",
        <S1 as ActiveSet>::NAME,
        <S2 as ActiveSet>::NAME
    );
    // A full warp is all 32 mask bits.
    assert_eq!(
        left_mask | right_mask,
        u32::MAX,
        "merge: active sets {} and {} do not cover the full warp",
        <S1 as ActiveSet>::NAME,
        <S2 as ActiveSet>::NAME
    );
    Warp::new()
}
// Empty placeholder with no body and no visible callers.
// NOTE(review): the name suggests it once anchored a doc-test showing the buggy
// pattern; confirm whether that documentation should be restored or this removed.
fn _buggy_version_for_doctest() {}
/// Models "lane 0 bumps a shared counter, then every lane gets its own row".
///
/// The counter's current value is taken as the base and the counter is advanced
/// by 16 exactly once. The base is then broadcast from lane 0 to all lanes, and
/// each lane adds its own index. Returns the resulting per-lane values.
fn correct_atomic_broadcast(warp: Warp<All>, counter: &mut i32) -> PerLane<i32> {
    let (leader, followers) = warp.diverge_lane0();

    // Fetch-then-add on the counter: remember the old value, store old + 16.
    let base = *counter;
    *counter = base + 16;

    // Give every lane a defined value before reconverging: the base on lane 0
    // and zero on the others.
    let staged = merge_data(base, 0, Lane0::MASK);

    // The shuffle is only available on a full warp, so merge the halves first.
    let reconverged: Warp<All> = merge(leader, followers);
    let mut rows = reconverged.shuffle_broadcast(&staged, 0).0;

    for (lane, row) in rows.iter_mut().enumerate() {
        *row += lane as i32;
    }
    PerLane(rows)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One call advances the counter by 16 and hands lane `i` the old counter plus `i`.
    #[test]
    fn test_correct_atomic_broadcast() {
        let mut counter = 100;
        let lanes = correct_atomic_broadcast(Warp::<All>::new(), &mut counter).0;
        assert_eq!(counter, 116);
        for (lane, expected) in [(0usize, 100), (1, 101), (31, 131)] {
            assert_eq!(lanes[lane], expected);
        }
    }

    /// After diverging, the lane-0 handle reports only bit 0 as active.
    #[test]
    fn test_type_prevents_diverged_shuffle() {
        let (lane0, _rest) = Warp::<All>::new().diverge_lane0();
        assert_eq!(lane0.active_mask(), 0x00000001);
    }

    /// Merging data gives every lane a defined value: 42 on lane 0, 0 elsewhere.
    #[test]
    fn test_why_merge_prevents_ub() {
        let combined = merge_data(42, 0, Lane0::MASK);
        assert_eq!(combined.0[0], 42);
        for idle in &combined.0[1..] {
            assert_eq!(*idle, 0);
        }
    }

    /// Each call moves the counter forward by exactly 16.
    #[test]
    fn test_counter_advances_once() {
        let mut counter = 0;
        for expected in [16, 32] {
            let _result = correct_atomic_broadcast(Warp::<All>::new(), &mut counter);
            assert_eq!(counter, expected);
        }
    }
}
/// Prints a short write-up of the issue and then runs the typestate-based
/// version once, showing the resulting counter and a few lane values.
fn main() {
    // Headline.
    println!("LLVM #155682: shfl_sync After Conditional Eliminates Branch");
    println!("=============================================================\n");

    // What goes wrong.
    println!("The Bug (CUDA/Clang):");
    println!(" if (laneId == 0) {{ row = atomicAdd(counter, 16); }}");
    println!(" row = __shfl_sync(0xffffffff, row, 0) + laneId;");
    println!(" LLVM eliminates the if — atomicAdd runs on all 32 lanes.");
    println!(" Counter advances 32x too fast. NVCC handles it correctly.\n");

    // Why the sync mask is not the problem.
    println!("Why __shfl_sync Doesn't Help:");
    println!(" The mask (0xffffffff) is correct — all lanes ARE active.");
    println!(" The bug is uninitialized `row` on lanes 1-31 reaching the shuffle.");
    println!(" LLVM sees UB → assumes branch always taken → eliminates it.\n");

    // How the typed warp handle addresses it.
    println!("Why Warp Typestate Catches It:");
    println!(" After if: Warp<Lane0>, not Warp<All>. Must merge before shuffle.");
    println!(" Merge forces all lanes to provide data. No uninitialized values.");
    println!(" No UB → no branch elimination.\n");

    // Run the correct version once, starting from a zeroed counter.
    let mut counter = 0;
    let lanes = correct_atomic_broadcast(Warp::<All>::new(), &mut counter).0;
    println!("Counter after one call: {} (correct: 16, buggy: 512)", counter);

    let (first, second, last) = (lanes[0], lanes[1], lanes[31]);
    println!("Lane results: [0]={}, [1]={}, [31]={}", first, second, last);
}