use std::sync::atomic::{AtomicU32, Ordering};
/// Number of lanes in a warp; every per-lane buffer in [`WarpState`] has this many slots.
pub const WARP_SIZE: u32 = 32;

/// Shared per-warp scratch state backing the shuffle, vote and reduction helpers.
///
/// Every field is an atomic accessed with `Ordering::SeqCst`, which is why all
/// methods take `&self`. The methods do not wait for other lanes: each call
/// publishes the caller's value and then reads whatever the other slots hold
/// at that moment.
#[derive(Debug)]
pub struct WarpState {
    /// Last value each lane published through a shuffle call, indexed by lane id.
    shuffle_buf: [AtomicU32; WARP_SIZE as usize],
    /// Bit `i` is set while lane `i` is considered active.
    active_mask: AtomicU32,
    /// Last predicate (stored as 0 or 1) each lane published through a vote call.
    predicate_buf: [AtomicU32; WARP_SIZE as usize],
}

impl WarpState {
    /// Creates a warp with every lane active and all buffers zeroed.
    pub fn new() -> Self {
        // `AtomicU32` is not `Copy`; a `const` item is what lets the array
        // repeat expression below create a fresh atomic per slot.
        const INIT: AtomicU32 = AtomicU32::new(0);
        Self {
            shuffle_buf: [INIT; WARP_SIZE as usize],
            active_mask: AtomicU32::new(0xFFFF_FFFF),
            predicate_buf: [INIT; WARP_SIZE as usize],
        }
    }

    /// Marks `lane_id` as active.
    ///
    /// `lane_id` must be below [`WARP_SIZE`]; this is checked in debug builds.
    pub fn set_lane_active(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask.fetch_or(1u32 << lane_id, Ordering::SeqCst);
    }

    /// Marks `lane_id` as inactive.
    ///
    /// `lane_id` must be below [`WARP_SIZE`]; this is checked in debug builds.
    pub fn set_lane_inactive(&self, lane_id: u32) {
        debug_assert!(lane_id < WARP_SIZE);
        self.active_mask
            .fetch_and(!(1u32 << lane_id), Ordering::SeqCst);
    }

    /// Returns the current active-lane bitmask (bit `i` corresponds to lane `i`).
    pub fn active_mask(&self) -> u32 {
        self.active_mask.load(Ordering::SeqCst)
    }

    /// Returns `true` when `lane_id` is currently active.
    ///
    /// Lane ids at or above [`WARP_SIZE`] do not exist and are reported as
    /// inactive instead of overflowing the shift.
    pub fn is_lane_active(&self, lane_id: u32) -> bool {
        lane_id < WARP_SIZE && Self::lane_in_mask(self.active_mask(), lane_id)
    }

    /// Publishes `value` for `lane_id`, then returns the value currently
    /// published by lane `src_lane % WARP_SIZE`.
    pub fn shuffle(&self, lane_id: u32, value: u32, src_lane: u32) -> u32 {
        debug_assert!(lane_id < WARP_SIZE);
        self.publish(lane_id, value);
        self.published(src_lane % WARP_SIZE)
    }

    /// Shuffle whose source lane is `lane_id ^ lane_mask` (wrapped into the warp
    /// by [`WarpState::shuffle`]).
    pub fn shuffle_xor(&self, lane_id: u32, value: u32, lane_mask: u32) -> u32 {
        self.shuffle(lane_id, value, lane_id ^ lane_mask)
    }

    /// Publishes `value` for `lane_id`, then reads from lane `lane_id - delta`.
    ///
    /// Returns `value` unchanged when that source lane would be below lane 0.
    pub fn shuffle_up(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        debug_assert!(lane_id < WARP_SIZE);
        self.publish(lane_id, value);
        match lane_id.checked_sub(delta) {
            Some(src_lane) => self.published(src_lane),
            None => value,
        }
    }

    /// Publishes `value` for `lane_id`, then reads from lane `lane_id + delta`.
    ///
    /// Returns `value` unchanged when that source lane would be past the last
    /// lane, including when `lane_id + delta` does not fit in a `u32`.
    pub fn shuffle_down(&self, lane_id: u32, value: u32, delta: u32) -> u32 {
        debug_assert!(lane_id < WARP_SIZE);
        self.publish(lane_id, value);
        // `checked_add` keeps a huge `delta` from wrapping around to a valid
        // (but unrelated) lane index.
        match lane_id.checked_add(delta) {
            Some(src_lane) if src_lane < WARP_SIZE => self.published(src_lane),
            _ => value,
        }
    }

    /// `f32` variant of [`WarpState::shuffle`]; the value travels as its raw bits.
    pub fn shuffle_f32(&self, lane_id: u32, value: f32, src_lane: u32) -> f32 {
        f32::from_bits(self.shuffle(lane_id, value.to_bits(), src_lane))
    }

    /// `f32` variant of [`WarpState::shuffle_xor`]; the value travels as its raw bits.
    pub fn shuffle_xor_f32(&self, lane_id: u32, value: f32, lane_mask: u32) -> f32 {
        f32::from_bits(self.shuffle_xor(lane_id, value.to_bits(), lane_mask))
    }

    /// `f32` variant of [`WarpState::shuffle_up`]; the value travels as its raw bits.
    pub fn shuffle_up_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_up(lane_id, value.to_bits(), delta))
    }

    /// `f32` variant of [`WarpState::shuffle_down`]; the value travels as its raw bits.
    pub fn shuffle_down_f32(&self, lane_id: u32, value: f32, delta: u32) -> f32 {
        f32::from_bits(self.shuffle_down(lane_id, value.to_bits(), delta))
    }

    /// Publishes `predicate` for `lane_id` and returns `true` when every active
    /// lane's published predicate is non-zero.
    pub fn vote_all(&self, lane_id: u32, predicate: bool) -> bool {
        self.publish_predicate(lane_id, predicate);
        let mask = self.active_mask();
        (0..WARP_SIZE)
            .filter(|&lane| Self::lane_in_mask(mask, lane))
            .all(|lane| self.predicate_of(lane))
    }

    /// Publishes `predicate` for `lane_id` and returns `true` when at least one
    /// active lane's published predicate is non-zero.
    pub fn vote_any(&self, lane_id: u32, predicate: bool) -> bool {
        self.publish_predicate(lane_id, predicate);
        let mask = self.active_mask();
        (0..WARP_SIZE)
            .filter(|&lane| Self::lane_in_mask(mask, lane))
            .any(|lane| self.predicate_of(lane))
    }

    /// Publishes `predicate` for `lane_id` and returns a bitmask with bit `i`
    /// set for every active lane `i` whose published predicate is non-zero.
    pub fn ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.publish_predicate(lane_id, predicate);
        let mask = self.active_mask();
        (0..WARP_SIZE)
            .filter(|&lane| Self::lane_in_mask(mask, lane) && self.predicate_of(lane))
            .fold(0u32, |bits, lane| bits | (1u32 << lane))
    }

    /// Folds `value` with the values read via `shuffle_down_f32` at deltas
    /// 16, 8, 4, 2, 1 using addition, and returns the accumulated result.
    pub fn reduce_sum_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| acc + other)
    }

    /// Like [`WarpState::reduce_sum_f32`], but keeps the larger value at each step.
    pub fn reduce_max_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| if other > acc { other } else { acc })
    }

    /// Like [`WarpState::reduce_sum_f32`], but keeps the smaller value at each step.
    pub fn reduce_min_f32(&self, lane_id: u32, value: f32) -> f32 {
        self.reduce_f32(lane_id, value, |acc, other| if other < acc { other } else { acc })
    }

    /// Number of bits set in the result of [`WarpState::ballot`].
    pub fn popc_ballot(&self, lane_id: u32, predicate: bool) -> u32 {
        self.ballot(lane_id, predicate).count_ones()
    }

    /// Tests bit `lane` of `mask`; callers guarantee `lane < WARP_SIZE`.
    fn lane_in_mask(mask: u32, lane: u32) -> bool {
        (mask >> lane) & 1 == 1
    }

    /// Stores `value` in the shuffle slot of `lane_id`.
    fn publish(&self, lane_id: u32, value: u32) {
        self.shuffle_buf[lane_id as usize].store(value, Ordering::SeqCst);
    }

    /// Loads the value currently held in the shuffle slot of `lane`.
    fn published(&self, lane: u32) -> u32 {
        self.shuffle_buf[lane as usize].load(Ordering::SeqCst)
    }

    /// Stores `predicate` (as 0 or 1) in the predicate slot of `lane_id`.
    fn publish_predicate(&self, lane_id: u32, predicate: bool) {
        debug_assert!(lane_id < WARP_SIZE);
        self.predicate_buf[lane_id as usize].store(u32::from(predicate), Ordering::SeqCst);
    }

    /// Returns whether the predicate slot of `lane` currently holds a non-zero value.
    fn predicate_of(&self, lane: u32) -> bool {
        self.predicate_buf[lane as usize].load(Ordering::SeqCst) != 0
    }

    /// Shared halving loop behind the `reduce_*_f32` methods: for each delta in
    /// 16, 8, 4, 2, 1 it shuffles the running value down and merges the value
    /// read back with `combine(accumulator, other)`.
    fn reduce_f32<F>(&self, lane_id: u32, value: f32, combine: F) -> f32
    where
        F: Fn(f32, f32) -> f32,
    {
        let mut acc = value;
        let mut delta = WARP_SIZE / 2;
        while delta >= 1 {
            let other = self.shuffle_down_f32(lane_id, acc, delta);
            acc = combine(acc, other);
            delta /= 2;
        }
        acc
    }
}

impl Default for WarpState {
    /// Equivalent to [`WarpState::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh warp whose shuffle slot for lane `i` holds `fill(i)`.
    fn warp_with_shuffle_values<F: Fn(u32) -> u32>(fill: F) -> WarpState {
        let ws = WarpState::new();
        for (lane, slot) in (0..WARP_SIZE).zip(ws.shuffle_buf.iter()) {
            slot.store(fill(lane), Ordering::SeqCst);
        }
        ws
    }

    /// Builds a fresh warp whose predicate slot for lane `i` holds `fill(i)` as 0/1.
    fn warp_with_predicates<F: Fn(u32) -> bool>(fill: F) -> WarpState {
        let ws = WarpState::new();
        for (lane, slot) in (0..WARP_SIZE).zip(ws.predicate_buf.iter()) {
            slot.store(fill(lane) as u32, Ordering::SeqCst);
        }
        ws
    }

    #[test]
    fn test_new_warp_state() {
        // A new warp starts with every lane active.
        assert_eq!(WarpState::new().active_mask(), u32::MAX);
    }

    #[test]
    fn test_set_lane_active_inactive() {
        let ws = WarpState::new();

        ws.set_lane_inactive(5);
        assert!(!ws.is_lane_active(5));
        assert!(ws.is_lane_active(0));

        ws.set_lane_active(5);
        assert!(ws.is_lane_active(5));
    }

    #[test]
    fn test_shuffle_basic() {
        let ws = warp_with_shuffle_values(|lane| 100 + lane);
        assert_eq!(ws.shuffle(5, 105, 10), 110);
    }

    #[test]
    fn test_shuffle_xor() {
        let ws = warp_with_shuffle_values(|lane| lane * 10);
        // Lane 3 xor 1 reads from lane 2.
        assert_eq!(ws.shuffle_xor(3, 30, 1), 20);
    }

    #[test]
    fn test_shuffle_up() {
        let ws = warp_with_shuffle_values(|lane| lane);
        assert_eq!(ws.shuffle_up(5, 5, 2), 3);
        // No lane below lane 0: the caller's own value comes back.
        assert_eq!(ws.shuffle_up(0, 0, 1), 0);
    }

    #[test]
    fn test_shuffle_down() {
        let ws = warp_with_shuffle_values(|lane| lane);
        assert_eq!(ws.shuffle_down(5, 5, 3), 8);
        // No lane above lane 31: the caller's own value comes back.
        assert_eq!(ws.shuffle_down(31, 31, 1), 31);
    }

    #[test]
    fn test_shuffle_f32() {
        let ws = warp_with_shuffle_values(|lane| (lane as f32 * 1.5).to_bits());
        let got = ws.shuffle_f32(0, 0.0, 10);
        let want = 10.0_f32 * 1.5;
        assert!((got - want).abs() < 1e-6);
    }

    #[test]
    fn test_vote_all_true() {
        let ws = warp_with_predicates(|_| true);
        assert!(ws.vote_all(0, true));
    }

    #[test]
    fn test_vote_all_one_false() {
        let ws = warp_with_predicates(|lane| lane != 15);
        assert!(!ws.vote_all(0, true));
    }

    #[test]
    fn test_vote_any() {
        let ws = warp_with_predicates(|_| false);
        assert!(ws.vote_any(7, true));
    }

    #[test]
    fn test_ballot() {
        let ws = warp_with_predicates(|lane| lane < 3);
        let bits = ws.ballot(3, false);
        assert_eq!(bits & 0b111, 0b111);
        assert_eq!(bits & (1 << 3), 0);
    }

    #[test]
    fn test_popc_ballot() {
        let ws = warp_with_predicates(|lane| lane < 5);
        assert_eq!(ws.popc_ballot(10, false), 5);
    }

    #[test]
    fn test_reduce_sum_simple() {
        let ws = warp_with_shuffle_values(|_| 1.0f32.to_bits());
        let total = ws.reduce_sum_f32(0, 1.0);
        assert!(total >= 1.0);
    }
}