use std::marker::PhantomData;
/// A compile-time description of which of the 32 lanes are active.
///
/// Implementors are zero-sized marker types: the set lives entirely in the type and is read
/// through [`ActiveSet::MASK`].
pub trait ActiveSet: Copy + 'static {
    /// Lane-membership bitmask: bit `i` is set when lane `i` belongs to the set.
    const MASK: u32;
}

/// Marker trait declaring that `Self` and `T` split the lanes between them.
///
/// The trait has no items, so nothing checks the masks. Only implement it for pairs whose
/// masks are disjoint and together cover every lane, as `Even`/`Odd` below do.
/// `phi_merge` gives the left set precedence on overlapping lanes and fills uncovered lanes
/// with a default.
pub trait ComplementOf<T>: ActiveSet {}

// The sets derive `Debug` (and friends) because `#[derive(Debug)]` on the generic wrappers
// (`ValueIn`, `UniformIn`, ...) requires the set parameter itself to be `Debug`.

/// Every lane active.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct All;
impl ActiveSet for All {
    const MASK: u32 = 0xFFFF_FFFF;
}

/// Even-numbered lanes (0, 2, 4, ...) active.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Even;
impl ActiveSet for Even {
    const MASK: u32 = 0x5555_5555;
}

/// Odd-numbered lanes (1, 3, 5, ...) active.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Odd;
impl ActiveSet for Odd {
    const MASK: u32 = 0xAAAA_AAAA;
}

// Even and odd lanes are disjoint and together cover all 32 lanes.
impl ComplementOf<Odd> for Even {}
impl ComplementOf<Even> for Odd {}
/// Per-lane values that exist only on the lanes of the active set `S`.
///
/// Both constructors keep this invariant: lane `i` holds `Some(_)` exactly when bit `i` of
/// `S::MASK` is set, and every other lane holds `None`.
#[derive(Copy, Clone, Debug)]
pub struct ValueIn<T, S: ActiveSet> {
    values: [Option<T>; 32],
    _scope: PhantomData<S>,
}

// Only `T: Copy` is needed here (for `uniform` and `get`); the previous `T: Default`
// bound was never used.
impl<T: Copy, S: ActiveSet> ValueIn<T, S> {
    /// Builds a value by calling `compute(lane)` for each lane in `S`, in ascending lane order.
    ///
    /// `compute` is not called for lanes outside `S`; those lanes are `None`.
    pub fn new(compute: impl Fn(usize) -> T) -> Self {
        let values: [Option<T>; 32] = std::array::from_fn(|lane| {
            // Only lanes whose bit is set in the mask carry a value.
            (S::MASK & (1 << lane) != 0).then(|| compute(lane))
        });
        ValueIn {
            values,
            _scope: PhantomData,
        }
    }

    /// Builds a value equal to `value` on every lane of `S` and `None` elsewhere.
    pub fn uniform(value: T) -> Self {
        Self::new(|_| value)
    }

    /// Returns the value on `lane`, or `None` if `lane` is not in `S`.
    ///
    /// # Panics
    /// Panics if `lane >= 32`.
    pub fn get(&self, lane: usize) -> Option<T> {
        self.values[lane]
    }
}
/// A single value shared by every lane of the active set `S`.
///
/// No per-lane storage is kept; the one stored value stands for all lanes of `S`.
#[derive(Clone, Copy, Debug)]
pub struct UniformIn<T, S: ActiveSet> {
    value: T,
    _scope: PhantomData<S>,
}

impl<T: Copy, S: ActiveSet> UniformIn<T, S> {
    /// Wraps `value` as uniform over `S`.
    pub fn new(value: T) -> Self {
        Self {
            _scope: PhantomData,
            value,
        }
    }

    /// Returns a copy of the shared value.
    pub fn get(&self) -> T {
        let Self { value, .. } = self;
        *value
    }
}
/// One value for each of the 32 lanes, indexed by lane number.
///
/// Derives equality, hashing and `Default` so whole-warp results can be compared and
/// initialised directly instead of lane by lane.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PerLane<T>(pub [T; 32]);
/// Re-converges two divergent branches into one full per-lane result.
///
/// A lane in `S1` takes `left`'s value. A lane not in `S1` but in `S2` takes `right`'s value.
/// A lane in neither set, or one whose stored value is `None`, gets `T::default()`.
pub fn phi_merge<T: Copy + Default, S1: ActiveSet + ComplementOf<S2>, S2: ActiveSet>(
    left: ValueIn<T, S1>,
    right: ValueIn<T, S2>,
) -> PerLane<T> {
    // Select the branch that owns `lane`; `S1` wins if both claim it.
    let pick = |lane: usize| -> Option<T> {
        let bit = 1u32 << lane;
        if S1::MASK & bit != 0 {
            left.values[lane]
        } else if S2::MASK & bit != 0 {
            right.values[lane]
        } else {
            None
        }
    };
    PerLane(std::array::from_fn(|lane| pick(lane).unwrap_or_default()))
}
/// Re-converges two branch-uniform values into one full per-lane result.
///
/// Lanes in `S1` receive `left`'s value. Lanes not in `S1` but in `S2` receive `right`'s value.
/// Lanes in neither set keep `T::default()`.
pub fn phi_merge_uniform<T: Copy + Default, S1: ActiveSet + ComplementOf<S2>, S2: ActiveSet>(
    left: UniformIn<T, S1>,
    right: UniformIn<T, S2>,
) -> PerLane<T> {
    let mut merged = PerLane([T::default(); 32]);
    for (lane, slot) in merged.0.iter_mut().enumerate() {
        let bit = 1u32 << lane;
        *slot = match (S1::MASK & bit != 0, S2::MASK & bit != 0) {
            (true, _) => left.value,
            (false, true) => right.value,
            // Not covered by either branch: keep the default already in place.
            (false, false) => continue,
        };
    }
    merged
}
/// Per-lane optional values; `None` marks a lane that has not been assigned.
#[derive(Copy, Clone, Debug)]
pub struct MaybeValue<T>(pub [Option<T>; 32]);

impl<T: Copy> MaybeValue<T> {
    /// Creates a value with every lane unassigned (`None`).
    pub fn new() -> Self {
        MaybeValue([None; 32])
    }

    /// Assigns `value` to every lane of the active set `S`; other lanes are left unchanged.
    pub fn set_where<S: ActiveSet>(&mut self, value: T) {
        for (lane, slot) in self.0.iter_mut().enumerate() {
            if S::MASK & (1 << lane) != 0 {
                *slot = Some(value);
            }
        }
    }

    /// Returns the value on `lane`, or the caller-supplied `default` if the lane is unassigned.
    ///
    /// # Panics
    /// Panics if `lane >= 32`.
    pub fn get_or_default(&self, lane: usize, default: T) -> T {
        self.0[lane].unwrap_or(default)
    }

    /// Converts to a full per-lane value, filling unassigned lanes with `default`.
    pub fn to_per_lane(&self, default: T) -> PerLane<T> {
        PerLane(self.0.map(|slot| slot.unwrap_or(default)))
    }
}

// Types with a no-argument `new()` should also implement `Default` (clippy `new_without_default`).
impl<T: Copy> Default for MaybeValue<T> {
    /// Same as `MaybeValue::new`: every lane unassigned.
    fn default() -> Self {
        Self::new()
    }
}
/// A zero-sized token tagged with the active lane set `S`.
///
/// It carries no runtime data (its only field is a `PhantomData`). `diverge_compute` takes a
/// `Warp<All>` and hands `Warp<S1>`/`Warp<S2>` to its branches.
#[derive(Copy, Clone, Debug)]
pub struct Warp<S: ActiveSet> {
    _marker: PhantomData<S>,
}

impl<S: ActiveSet> Warp<S> {
    /// Creates a token for the active set `S`.
    pub fn new() -> Self {
        Warp {
            _marker: PhantomData,
        }
    }
}

// Manual impl (not a derive) so that `S` itself need not implement `Default`.
impl<S: ActiveSet> Default for Warp<S> {
    fn default() -> Self {
        Self::new()
    }
}
/// Runs both sides of a two-way divergence and re-converges them with `phi_merge`.
///
/// The lane split is fixed by the types. `then_branch` receives a `Warp<S1>` and produces a
/// value on `S1`'s lanes; `else_branch` receives a `Warp<S2>` and produces a value on `S2`'s
/// lanes. Both branches are always invoked, `then_branch` first, and their results are merged
/// lane by lane.
///
/// `pred` is the runtime condition that the type-level split stands for. It is not used to
/// select lanes. In debug builds it is checked against `S1`, so a mismatched predicate is
/// reported instead of silently producing a merge that disagrees with it.
///
/// # Panics
/// In debug builds, panics if `pred(lane)` differs from whether `lane` is in `S1`, for any
/// lane `0..32`.
pub fn diverge_compute<T, F1, F2, S1, S2>(
    _warp: Warp<All>,
    pred: impl Fn(usize) -> bool,
    then_branch: F1,
    else_branch: F2,
) -> PerLane<T>
where
    T: Copy + Default,
    S1: ActiveSet + ComplementOf<S2>,
    S2: ActiveSet + ComplementOf<S1>,
    F1: FnOnce(Warp<S1>) -> ValueIn<T, S1>,
    F2: FnOnce(Warp<S2>) -> ValueIn<T, S2>,
{
    debug_assert!(
        (0..32).all(|lane| pred(lane) == (S1::MASK & (1 << lane) != 0)),
        "diverge_compute: predicate disagrees with the then-branch active set"
    );
    let then_val = then_branch(Warp::new());
    let else_val = else_branch(Warp::new());
    phi_merge(then_val, else_val)
}
/// A variable holding one value per lane that can be reassigned on a subset of lanes.
#[derive(Clone, Copy, Debug)]
pub struct TrackedVar<T: Copy> {
    values: [T; 32],
}

impl<T: Copy> TrackedVar<T> {
    /// Starts with `value` on all 32 lanes.
    pub fn uniform(value: T) -> Self {
        Self::per_lane([value; 32])
    }

    /// Starts with the given per-lane values.
    pub fn per_lane(values: [T; 32]) -> Self {
        Self { values }
    }

    /// Overwrites every lane in `S` with `new_value`, leaving the other lanes untouched.
    pub fn update_where<S: ActiveSet>(&mut self, new_value: T) {
        self.values
            .iter_mut()
            .enumerate()
            .filter(|(lane, _)| (S::MASK >> *lane) & 1 == 1)
            .for_each(|(_, slot)| *slot = new_value);
    }

    /// Overwrites every lane in `S` with `compute(lane)`, called in ascending lane order.
    /// Lanes outside `S` are left untouched.
    pub fn update_where_with<S: ActiveSet>(&mut self, compute: impl Fn(usize) -> T) {
        for (lane, slot) in self.values.iter_mut().enumerate() {
            if (S::MASK >> lane) & 1 == 1 {
                *slot = compute(lane);
            }
        }
    }

    /// Returns the current value on `lane`.
    ///
    /// # Panics
    /// Panics if `lane >= 32`.
    pub fn get(&self, lane: usize) -> T {
        self.values[lane]
    }

    /// Consumes the variable and yields its per-lane values.
    pub fn to_per_lane(self) -> PerLane<T> {
        let Self { values } = self;
        PerLane(values)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_value_in_scope() {
        let evens: ValueIn<i32, Even> = ValueIn::uniform(100);
        assert_eq!(evens.get(0), Some(100));
        assert_eq!(evens.get(1), None);
        assert_eq!(evens.get(2), Some(100));
    }

    #[test]
    fn test_phi_merge_values() {
        let evens: ValueIn<i32, Even> = ValueIn::uniform(100);
        let odds: ValueIn<i32, Odd> = ValueIn::uniform(200);
        let PerLane(lanes) = phi_merge(evens, odds);
        assert_eq!(lanes[0], 100);
        assert_eq!(lanes[1], 200);
        assert_eq!(lanes[2], 100);
        assert_eq!(lanes[3], 200);
    }

    #[test]
    fn test_phi_merge_uniform() {
        let evens: UniformIn<i32, Even> = UniformIn::new(42);
        let odds: UniformIn<i32, Odd> = UniformIn::new(99);
        let PerLane(lanes) = phi_merge_uniform(evens, odds);
        assert_eq!(lanes[0], 42);
        assert_eq!(lanes[1], 99);
        assert_eq!(lanes[2], 42);
        assert_eq!(lanes[3], 99);
    }

    #[test]
    fn test_maybe_value_asymmetric() {
        let mut maybe: MaybeValue<i32> = MaybeValue::new();
        maybe.set_where::<Even>(100);
        assert_eq!(maybe.0[0], Some(100));
        assert_eq!(maybe.0[1], None);
        let PerLane(filled) = maybe.to_per_lane(0);
        assert_eq!(filled[0], 100);
        assert_eq!(filled[1], 0);
    }

    #[test]
    fn test_tracked_var_reassignment() {
        let mut var = TrackedVar::uniform(0i32);
        var.update_where::<Even>(100);
        assert_eq!(var.get(0), 100);
        assert_eq!(var.get(1), 0);
        assert_eq!(var.get(2), 100);
        assert_eq!(var.get(3), 0);
    }

    #[test]
    fn test_tracked_var_varying_update() {
        let mut var = TrackedVar::uniform(0i32);
        var.update_where_with::<Even>(|lane| lane as i32 * 10);
        assert_eq!(var.get(0), 0);
        assert_eq!(var.get(1), 0);
        assert_eq!(var.get(2), 20);
        assert_eq!(var.get(4), 40);
    }

    #[test]
    fn test_nested_divergence_values() {
        let mut var = TrackedVar::uniform(0i32);
        // Outer branch: every even lane takes 1.
        var.update_where::<Even>(1);
        // Inner branch: only the even lanes below 16 take 2.
        for slot in var.values[..16].iter_mut().step_by(2) {
            *slot = 2;
        }
        assert_eq!(var.get(0), 2);
        assert_eq!(var.get(1), 0);
        assert_eq!(var.get(2), 2);
        assert_eq!(var.get(16), 1);
        assert_eq!(var.get(17), 0);
    }
}
/// Unit-valued placeholder constant; it carries no data.
/// NOTE(review): nothing in the visible source references it — confirm whether it is still needed.
pub const _DOC: () = ();