use std::marker::PhantomData;
/// Type-level description of a set of warp lanes.
///
/// Implementors are marker types: they must be `Copy` and `'static`, and they
/// expose the lane set as a `u32` bitmask through [`ActiveSet::MASK`].
pub trait ActiveSet
where
    Self: Copy + 'static,
{
    /// Bitmask describing the lane set, one bit per lane.
    const MASK: u32;
}
/// Marker relation between lane sets: `A: ComplementOf<B>` records that `A`
/// is registered as the complement of `B`.
///
/// The trait has no items; it exists only so that functions can demand the
/// relation in their `where` clauses. Implementors must also be [`ActiveSet`]s.
pub trait ComplementOf<T>
where
    Self: ActiveSet,
{
}
/// Marker for the lane set in which every lane is active.
///
/// Besides `Copy`/`Clone`, the common value traits are derived so the marker
/// can be printed, compared, hashed and default-constructed like any other
/// unit type.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct All;

impl ActiveSet for All {
    /// All 32 bits set.
    const MASK: u32 = 0xFFFF_FFFF;
}
/// Marker for the lane set made of the even-numbered lanes.
///
/// Besides `Copy`/`Clone`, the common value traits are derived so the marker
/// can be printed, compared, hashed and default-constructed like any other
/// unit type.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Even;

impl ActiveSet for Even {
    /// Bits 0, 2, 4, ... set.
    const MASK: u32 = 0x5555_5555;
}
/// Marker for the lane set made of the odd-numbered lanes.
///
/// Besides `Copy`/`Clone`, the common value traits are derived so the marker
/// can be printed, compared, hashed and default-constructed like any other
/// unit type.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Odd;

impl ActiveSet for Odd {
    /// Bits 1, 3, 5, ... set.
    const MASK: u32 = 0xAAAA_AAAA;
}
// `Even` and `Odd` have disjoint masks that together cover all 32 bits, so the
// complement relation is registered in both directions.
impl ComplementOf<Even> for Odd {}
impl ComplementOf<Odd> for Even {}
/// Zero-sized token for a warp whose active lane set is tracked in the type
/// parameter `S`.
///
/// The only field is a `PhantomData<S>`, so the token carries no runtime data.
/// Note that the token is `Copy`: passing it to a function does not prevent
/// the caller from using it again.
#[derive(Copy, Clone)]
pub struct Warp<S: ActiveSet> {
    _marker: PhantomData<S>,
}

impl<S: ActiveSet> Warp<S> {
    /// Creates a token for the lane set `S`.
    pub fn new() -> Self {
        Warp {
            _marker: PhantomData,
        }
    }
}

/// Same as [`Warp::new`]; provided so the type works with `Default`-based APIs.
impl<S: ActiveSet> Default for Warp<S> {
    fn default() -> Self {
        Self::new()
    }
}

/// Prints the lane mask of `S`, e.g. `Warp { mask: 0x55555555 }`.
///
/// Implemented by hand (rather than derived) so that `S` itself does not have
/// to implement `Debug`.
impl<S: ActiveSet> std::fmt::Debug for Warp<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Warp")
            .field("mask", &format_args!("{:#010X}", S::MASK))
            .finish()
    }
}
pub mod explicit {
use super::*;
pub fn diverge<S1, S2>(_warp: Warp<All>) -> (Warp<S1>, Warp<S2>)
where
S1: ActiveSet + ComplementOf<S2>,
S2: ActiveSet + ComplementOf<S1>,
{
(Warp::new(), Warp::new())
}
pub fn merge<S1, S2>(_w1: Warp<S1>, _w2: Warp<S2>) -> Warp<All>
where
S1: ActiveSet + ComplementOf<S2>,
S2: ActiveSet + ComplementOf<S1>,
{
Warp::new()
}
pub fn example_explicit() {
let warp: Warp<All> = Warp::new();
let (even, odd): (Warp<Even>, Warp<Odd>) = diverge(warp);
let _ = process_even(even);
let _ = process_odd(odd);
let _merged: Warp<All> = merge(even, odd);
}
fn process_even(_w: Warp<Even>) -> i32 {
1
}
fn process_odd(_w: Warp<Odd>) -> i32 {
2
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_explicit_diverge_merge() {
example_explicit();
}
}
}
/// Closure-based wrappers around the diverge/merge pattern.
pub mod sugar {
    use super::*;

    /// Runs `then_fn` with a `Warp<S1>` token and then `else_fn` with a
    /// `Warp<S2>` token, returning both results as `(then, else)`.
    ///
    /// The incoming full-warp token is not inspected.
    pub fn with_diverged<S1, S2, A, F1, F2>(_warp: Warp<All>, then_fn: F1, else_fn: F2) -> (A, A)
    where
        S1: ActiveSet + ComplementOf<S2>,
        S2: ActiveSet + ComplementOf<S1>,
        F1: FnOnce(Warp<S1>) -> A,
        F2: FnOnce(Warp<S2>) -> A,
    {
        // Tuple fields are evaluated left to right, so `then_fn` runs first.
        (then_fn(Warp::<S1>::new()), else_fn(Warp::<S2>::new()))
    }

    /// `if`-shaped wrapper that always splits the warp into `Even` / `Odd`.
    ///
    /// `_pred` is accepted but never invoked: the split is fixed by the types
    /// of `then_fn` and `else_fn`, not by the predicate.
    pub fn warp_if<A, F1, F2>(
        warp: Warp<All>,
        _pred: impl Fn(usize) -> bool,
        then_fn: F1,
        else_fn: F2,
    ) -> (A, A)
    where
        F1: FnOnce(Warp<Even>) -> A,
        F2: FnOnce(Warp<Odd>) -> A,
    {
        with_diverged::<Even, Odd, _, _, _>(warp, then_fn, else_fn)
    }

    /// Demonstrates `warp_if` with constant-returning branches.
    pub fn example_sugar() {
        let full = Warp::<All>::new();
        let (from_even, from_odd) = warp_if(
            full,
            |lane| lane % 2 == 0,
            |_even_warp| 100,
            |_odd_warp| 200,
        );
        assert_eq!(from_even, 100);
        assert_eq!(from_odd, 200);
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_with_diverged() {
            let full = Warp::<All>::new();
            let (a, b) = with_diverged::<Even, Odd, i32, _, _>(full, |_| 1, |_| 2);
            assert_eq!(a, 1);
            assert_eq!(b, 2);
        }

        #[test]
        fn test_warp_if() {
            example_sugar();
        }
    }
}
/// Divergence packaged as a value that holds both halves until merged.
pub mod scoped {
    use super::*;

    /// Pair of tokens produced by splitting a full warp.
    ///
    /// Both halves are public fields; because `Warp` is `Copy`, reading them
    /// does not consume the `ScopedDiverge` value.
    pub struct ScopedDiverge<S1: ActiveSet, S2: ActiveSet> {
        pub left: Warp<S1>,
        pub right: Warp<S2>,
    }

    impl<S1, S2> ScopedDiverge<S1, S2>
    where
        S1: ActiveSet + ComplementOf<S2>,
        S2: ActiveSet + ComplementOf<S1>,
    {
        /// Splits a full-warp token into the two complementary halves.
        ///
        /// The incoming token is not inspected.
        pub fn new(_warp: Warp<All>) -> Self {
            let left = Warp::<S1>::new();
            let right = Warp::<S2>::new();
            Self { left, right }
        }

        /// Consumes the pair and hands back a full-warp token.
        pub fn merge(self) -> Warp<All> {
            Warp::<All>::new()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_scoped_diverge() {
            let full = Warp::<All>::new();
            let diverged = ScopedDiverge::<Even, Odd>::new(full);
            let _even: Warp<Even> = diverged.left;
            let _odd: Warp<Odd> = diverged.right;
            let _merged: Warp<All> = diverged.merge();
        }
    }
}
/// Runtime-masked execution, where the active lane set is tracked in a field
/// and checked with a panic instead of being encoded in the type.
pub mod ispc_style {
    /// Executor that tracks the currently active lanes as a 32-bit mask.
    #[derive(Debug)]
    pub struct MaskedExecution {
        active_mask: u32,
    }

    /// Same as [`MaskedExecution::new`]: all 32 lanes active.
    impl Default for MaskedExecution {
        fn default() -> Self {
            Self::new()
        }
    }

    impl MaskedExecution {
        /// Creates an executor with all 32 lanes active.
        pub fn new() -> Self {
            MaskedExecution {
                active_mask: 0xFFFFFFFF,
            }
        }

        /// Returns the current lane mask (bit `i` set means lane `i` is active).
        pub fn active_mask(&self) -> u32 {
            self.active_mask
        }

        /// Runs `body` with the mask narrowed to the active lanes for which
        /// `pred` returns `true`, then restores the previous mask.
        ///
        /// `pred` is only called for lanes that are currently active, in
        /// ascending lane order. The previous mask is restored even if `body`
        /// panics, so an executor observed after a caught unwind is never left
        /// in the narrowed state.
        pub fn masked_if<F>(&mut self, pred: impl Fn(usize) -> bool, body: F)
        where
            F: FnOnce(&mut Self),
        {
            /// Writes the saved mask back when dropped, including during unwinding.
            struct RestoreMask<'a> {
                exec: &'a mut MaskedExecution,
                saved: u32,
            }

            impl Drop for RestoreMask<'_> {
                fn drop(&mut self) {
                    self.exec.active_mask = self.saved;
                }
            }

            let old_mask = self.active_mask;
            let new_mask = (0..32usize)
                .filter(|&lane| old_mask & (1u32 << lane) != 0 && pred(lane))
                .fold(0u32, |mask, lane| mask | (1u32 << lane));

            self.active_mask = new_mask;
            let guard = RestoreMask {
                exec: self,
                saved: old_mask,
            };
            body(&mut *guard.exec);
            // `guard` is dropped here (or while unwinding) and restores `old_mask`.
        }

        /// Stand-in for an operation that needs every lane: it performs no data
        /// movement and only checks the mask at runtime.
        ///
        /// # Panics
        ///
        /// Panics with "Shuffle requires all lanes!" unless all 32 lanes are
        /// active.
        pub fn unsafe_shuffle(&self) {
            if self.active_mask != 0xFFFFFFFF {
                panic!("Shuffle requires all lanes!");
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_ispc_style() {
            let mut exec = MaskedExecution::new();
            assert_eq!(exec.active_mask(), 0xFFFFFFFF);
            exec.masked_if(
                |lane| lane % 2 == 0,
                |inner| {
                    assert_eq!(inner.active_mask() & 0x55555555, inner.active_mask());
                },
            );
            assert_eq!(exec.active_mask(), 0xFFFFFFFF);
        }

        #[test]
        #[should_panic(expected = "Shuffle requires all lanes")]
        fn test_ispc_shuffle_bug() {
            let mut exec = MaskedExecution::new();
            exec.masked_if(
                |lane| lane % 2 == 0,
                |inner| {
                    inner.unsafe_shuffle();
                },
            );
        }

        #[test]
        fn test_mask_restored_after_panic() {
            let mut exec = MaskedExecution::new();
            let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                exec.masked_if(|lane| lane % 2 == 0, |inner| inner.unsafe_shuffle());
            }));
            assert!(outcome.is_err());
            assert_eq!(exec.active_mask(), 0xFFFFFFFF);
        }
    }
}
// Unit-valued public constant: it carries no data, and nothing in the visible
// source reads it.
// NOTE(review): presumably kept as an anchor for concluding documentation —
// confirm before removing.
pub const _CONCLUSION: () = ();
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// Stand-in shuffle that only accepts a full-warp token; always returns
    /// 32 zeros.
    fn shuffle_xor(_warp: &Warp<All>, _delta: u32) -> [i32; 32] {
        [0_i32; 32]
    }

    /// A shuffle type-checks on the original full warp and again on the token
    /// obtained by merging the two halves.
    #[test]
    fn test_explicit_is_required_for_shuffle() {
        let full = Warp::<All>::new();
        let _ = shuffle_xor(&full, 1);

        let (even, odd) = explicit::diverge::<Even, Odd>(full);
        let rejoined = explicit::merge(even, odd);
        let _ = shuffle_xor(&rejoined, 1);
    }
}