use std::marker::PhantomData;
/// Type-level description of which lanes of a 32-lane warp are active.
///
/// Implementors are zero-sized markers; bit `i` of `MASK` is set when lane `i`
/// belongs to the set.
pub trait ActiveSet: 'static + Copy {
    /// Bit mask of active lanes (bit `i` corresponds to lane `i`).
    const MASK: u32;
    /// Human-readable name of the set, used for diagnostics.
    const NAME: &'static str;
}

/// Marker relating two active sets that may be merged back together.
///
/// Because `merge` turns any `S1: ComplementOf<S2>` pair into a `Warp<All>`, an
/// impl is only sound when the two masks are disjoint and together cover all
/// 32 lanes.
pub trait ComplementOf<Other>: ActiveSet
where
    Other: ActiveSet,
{
}
/// Every one of the 32 lanes is active (all mask bits set).
#[derive(Debug, Default, Clone, Copy)]
pub struct All;

impl ActiveSet for All {
    const MASK: u32 = u32::MAX;
    const NAME: &'static str = "All";
}
/// No lane is active (empty mask).
///
/// NOTE(review): this unit struct shadows the prelude's `Option::None` variant in
/// this module and in every module that glob-imports it (e.g. `use super::*`), so a
/// bare `None` there names this type, not the `Option` variant. Consider renaming
/// (e.g. `NoLanes`) when a breaking change is acceptable.
#[derive(Debug, Default, Clone, Copy)]
pub struct None;

impl ActiveSet for None {
    const MASK: u32 = 0;
    const NAME: &'static str = "None";
}
/// Even-numbered lanes only (0, 2, 4, …, 30).
#[derive(Debug, Default, Clone, Copy)]
pub struct Even;

impl ActiveSet for Even {
    const MASK: u32 = 0x5555_5555;
    const NAME: &'static str = "Even";
}
/// Odd-numbered lanes only (1, 3, 5, …, 31).
#[derive(Debug, Default, Clone, Copy)]
pub struct Odd;

impl ActiveSet for Odd {
    const MASK: u32 = 0xAAAA_AAAA;
    const NAME: &'static str = "Odd";
}
/// Lower half of the warp: lanes 0 through 15.
#[derive(Debug, Default, Clone, Copy)]
pub struct LowHalf;

impl ActiveSet for LowHalf {
    const MASK: u32 = 0x0000_FFFF;
    const NAME: &'static str = "LowHalf";
}
/// Upper half of the warp: lanes 16 through 31.
#[derive(Debug, Default, Clone, Copy)]
pub struct HighHalf;

impl ActiveSet for HighHalf {
    const MASK: u32 = 0xFFFF_0000;
    const NAME: &'static str = "HighHalf";
}
/// Only lane 0 is active.
#[derive(Debug, Default, Clone, Copy)]
pub struct Lane0;

impl ActiveSet for Lane0 {
    const MASK: u32 = 0b1;
    const NAME: &'static str = "Lane0";
}
/// Every lane except lane 0 (lanes 1 through 31).
#[derive(Debug, Default, Clone, Copy)]
pub struct NotLane0;

impl ActiveSet for NotLane0 {
    // All bits except bit 0.
    const MASK: u32 = !0b1;
    const NAME: &'static str = "NotLane0";
}
// Full-warp complementary pairs. Within each pair the masks are disjoint and
// their union is `All::MASK`; each pair is declared in both directions so
// either side may be passed first to `merge`.

// Even (0x5555_5555) / Odd (0xAAAA_AAAA)
impl ComplementOf<Even> for Odd {}
impl ComplementOf<Odd> for Even {}

// LowHalf (0x0000_FFFF) / HighHalf (0xFFFF_0000)
impl ComplementOf<LowHalf> for HighHalf {}
impl ComplementOf<HighHalf> for LowHalf {}

// Lane0 (0x0000_0001) / NotLane0 (0xFFFF_FFFE)
impl ComplementOf<Lane0> for NotLane0 {}
impl ComplementOf<NotLane0> for Lane0 {}

// All (0xFFFF_FFFF) / None (0x0000_0000)
impl ComplementOf<All> for None {}
impl ComplementOf<None> for All {}
#[derive(Copy, Clone)]
pub struct Warp<S: ActiveSet> {
_phantom: PhantomData<S>,
}
impl<S: ActiveSet> Warp<S> {
    /// Creates a handle whose active lanes are described by `S`.
    ///
    /// NOTE(review): this constructor is public for every `S`, so callers can
    /// create e.g. a `Warp<All>` directly instead of obtaining one via `merge`.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    /// Returns `S::NAME`, the name of the active set.
    pub fn active_set_name(&self) -> &'static str {
        <S as ActiveSet>::NAME
    }

    /// Returns `S::MASK`; bit `i` is set when lane `i` is active.
    pub fn active_mask(&self) -> u32 {
        <S as ActiveSet>::MASK
    }
}
impl<S: ActiveSet> Default for Warp<S> {
fn default() -> Self {
Self::new()
}
}
/// Divergence points available only from a fully converged warp.
///
/// Each method takes the handle by value and returns the two sub-warps
/// produced by the split.
impl Warp<All> {
    /// Splits the warp into its even-numbered and odd-numbered lanes.
    pub fn diverge_even_odd(self) -> (Warp<Even>, Warp<Odd>) {
        let evens = Warp::<Even>::new();
        let odds = Warp::<Odd>::new();
        (evens, odds)
    }

    /// Splits the warp into lanes 0–15 (`LowHalf`) and lanes 16–31 (`HighHalf`).
    pub fn diverge_halves(self) -> (Warp<LowHalf>, Warp<HighHalf>) {
        let low = Warp::<LowHalf>::new();
        let high = Warp::<HighHalf>::new();
        (low, high)
    }

    /// Splits lane 0 off from the remaining lanes 1–31.
    pub fn extract_lane0(self) -> (Warp<Lane0>, Warp<NotLane0>) {
        let leader = Warp::<Lane0>::new();
        let rest = Warp::<NotLane0>::new();
        (leader, rest)
    }
}
/// Further divergence available from the even-lane sub-warp.
impl Warp<Even> {
    /// Splits the even lanes into those in lanes 0–15 (`EvenLow`) and those in
    /// lanes 16–31 (`EvenHigh`).
    pub fn diverge_halves(self) -> (Warp<EvenLow>, Warp<EvenHigh>) {
        let lower = Warp::<EvenLow>::new();
        let upper = Warp::<EvenHigh>::new();
        (lower, upper)
    }
}
#[derive(Copy, Clone, Debug, Default)]
pub struct EvenLow;
impl ActiveSet for EvenLow {
const MASK: u32 = 0x00005555; const NAME: &'static str = "EvenLow";
}
#[derive(Copy, Clone, Debug, Default)]
pub struct EvenHigh;
impl ActiveSet for EvenHigh {
const MASK: u32 = 0x55550000; const NAME: &'static str = "EvenHigh";
}
impl ComplementOf<EvenHigh> for EvenLow {}
impl ComplementOf<EvenLow> for EvenHigh {}
pub fn merge<S1, S2>(_left: Warp<S1>, _right: Warp<S2>) -> Warp<All>
where
S1: ComplementOf<S2>,
S2: ActiveSet,
{
Warp::new()
}
/// Reconverges the two halves of the even lanes (`EvenLow`, `EvenHigh`) into a
/// `Warp<Even>`.
///
/// This undoes `Warp<Even>::diverge_halves`; both handles are taken by value.
pub fn merge_to_even(_left: Warp<EvenLow>, _right: Warp<EvenHigh>) -> Warp<Even> {
    Warp::<Even>::new()
}
/// Wrapper marking a value as per-lane data for the warp-collective methods on
/// `Warp<All>`.
///
/// NOTE(review): only a single `T` is stored, so distinct per-lane values are not
/// modelled.
#[derive(Debug, Clone, Copy)]
pub struct PerLane<T>(pub T);
impl Warp<All> {
pub fn shuffle_xor<T: Copy>(&self, data: PerLane<T>, _mask: u32) -> PerLane<T> {
data
}
pub fn shuffle_down<T: Copy>(&self, data: PerLane<T>, _delta: u32) -> PerLane<T> {
data
}
pub fn reduce_sum<T: Copy + std::ops::Add<Output = T>>(&self, data: PerLane<T>) -> T {
data.0
}
pub fn broadcast<T: Copy>(&self, value: T) -> PerLane<T> {
PerLane(value)
}
}
impl<S: ActiveSet> Warp<S> {
    /// Intended as a synchronisation point for the lanes in `S`; callable on any
    /// active set.
    ///
    /// Currently a no-op.
    pub fn sync(&self) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diverge_merge_even_odd() {
        let (even_lanes, odd_lanes) = Warp::<All>::new().diverge_even_odd();
        assert_eq!(even_lanes.active_set_name(), "Even");
        assert_eq!(odd_lanes.active_set_name(), "Odd");

        let rejoined: Warp<All> = merge(even_lanes, odd_lanes);
        assert_eq!(rejoined.active_set_name(), "All");
    }

    #[test]
    fn test_diverge_merge_halves() {
        let (lower, upper) = Warp::<All>::new().diverge_halves();
        assert_eq!(merge(lower, upper).active_mask(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_nested_diverge() {
        let (even_lanes, odd_lanes) = Warp::<All>::new().diverge_even_odd();

        let (lower_evens, upper_evens) = even_lanes.diverge_halves();
        assert_eq!(lower_evens.active_mask(), 0x0000_5555);
        assert_eq!(upper_evens.active_mask(), 0x5555_0000);

        let evens: Warp<Even> = merge_to_even(lower_evens, upper_evens);
        assert_eq!(evens.active_set_name(), "Even");
        assert_eq!(merge(evens, odd_lanes).active_set_name(), "All");
    }

    #[test]
    fn test_shuffle_only_on_all() {
        let _shuffled = Warp::<All>::new().shuffle_xor(PerLane(42i32), 1);
    }

    #[test]
    fn test_reduce_only_on_all() {
        let _sum = Warp::<All>::new().reduce_sum(PerLane(1i32));
    }

    #[test]
    fn test_sync_on_any_active_set() {
        let full = Warp::<All>::new();
        let (even_lanes, odd_lanes) = full.diverge_even_odd();
        // `full` is still usable after diverging only because `Warp` is `Copy`.
        full.sync();
        even_lanes.sync();
        odd_lanes.sync();
    }
}