use std::marker::PhantomData;
pub mod divergent_values {
    //! Values tagged, at the type level, with the set of lanes that hold them.
    use super::*;

    /// A compile-time set of active lanes: bit `i` of `MASK` stands for lane `i`.
    pub trait ActiveSet {
        /// Lane bitmask of the set.
        const MASK: u32;
        /// Human-readable name of the set.
        fn name() -> &'static str;
    }

    /// Every lane representable in the 32-bit mask is active.
    #[derive(Clone, Copy, Debug)]
    pub struct All;
    impl ActiveSet for All {
        const MASK: u32 = 0xFFFFFFFF;
        fn name() -> &'static str {
            "All"
        }
    }

    /// Only even lanes (0, 2, 4, ...) are active.
    #[derive(Clone, Copy, Debug)]
    pub struct Even;
    impl ActiveSet for Even {
        const MASK: u32 = 0x55555555;
        fn name() -> &'static str {
            "Even"
        }
    }

    /// Only odd lanes (1, 3, 5, ...) are active.
    #[derive(Clone, Copy, Debug)]
    pub struct Odd;
    impl ActiveSet for Odd {
        const MASK: u32 = 0xAAAAAAAA;
        fn name() -> &'static str {
            "Odd"
        }
    }

    /// A value of type `T` that is only meaningful on the lanes in `S`.
    pub struct Divergent<T, S: ActiveSet> {
        value: T,
        _marker: PhantomData<S>,
    }

    // Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would
    // also require `S: Clone` / `S: Copy`, even though `S` only appears inside
    // a `PhantomData`. Bounding on `T` alone keeps `Divergent<T, S>` copyable
    // for every active-set marker, including user-defined ones.
    impl<T: Clone, S: ActiveSet> Clone for Divergent<T, S> {
        fn clone(&self) -> Self {
            Divergent::new(self.value.clone())
        }
    }

    impl<T: Copy, S: ActiveSet> Copy for Divergent<T, S> {}

    impl<T, S: ActiveSet> Divergent<T, S> {
        /// Tags `value` as belonging to the lanes in `S`.
        pub fn new(value: T) -> Self {
            Divergent {
                value,
                _marker: PhantomData,
            }
        }

        /// Borrows the wrapped value.
        pub fn get(&self) -> &T {
            &self.value
        }

        /// The lane mask of `S`, i.e. the lanes on which the value is valid.
        pub fn valid_mask() -> u32 {
            S::MASK
        }
    }

    /// Type-level permission: `Self` may be shuffled with permutation `Perm`,
    /// producing `Output`.
    pub trait CanShuffle<Perm> {
        type Output;
    }

    /// Type-level tag for an XOR lane permutation with mask `MASK`.
    pub struct XorPerm<const MASK: u32>;

    // Only a fully active set is currently granted the `xor 1` permutation.
    impl<T> CanShuffle<XorPerm<1>> for Divergent<T, All> {
        type Output = Divergent<T, All>;
    }

    /// Placeholder for an XOR-by-`MASK` shuffle that stays within the lanes of
    /// `S` (intent inferred from the name — TODO confirm).
    ///
    /// NOTE(review): this currently returns `data` unchanged and does not check
    /// that `MASK` maps every lane of `S` back into `S`.
    pub fn shuffle_within<T: Copy, S: ActiveSet, const MASK: u32>(
        data: Divergent<T, S>,
    ) -> Divergent<T, S> {
        data
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_divergent_values() {
            let _all_valid: Divergent<i32, All> = Divergent::new(42);
            let _even_valid: Divergent<i32, Even> = Divergent::new(42);
            assert_eq!(Divergent::<i32, All>::valid_mask(), 0xFFFFFFFF);
            assert_eq!(Divergent::<i32, Even>::valid_mask(), 0x55555555);
        }
    }
}
pub mod explicit_mask {
    //! Lane validity passed around explicitly as a runtime bitmask.

    /// A value held by the current lane.
    #[derive(Clone, Copy)]
    pub struct PerLane<T>(pub T);

    /// Returns whether bit `lane` of `mask` is set.
    ///
    /// Lanes that do not fit in a `u32` mask (`lane >= 32`) are reported as not
    /// set. A plain `mask >> lane` would overflow there: a panic in debug builds,
    /// and a wrapped shift amount (testing the wrong bit) in release builds.
    fn lane_in_mask(mask: u32, lane: u32) -> bool {
        mask.checked_shr(lane).map_or(false, |bits| bits & 1 != 0)
    }

    /// Reads from source lane `lane_id ^ xor_mask`, returning `default` when that
    /// lane is not set in `valid_mask`.
    ///
    /// In this model a valid source lane yields `data.0` itself; no other lane's
    /// value is fetched.
    pub fn shuffle_xor_masked<T: Copy>(
        data: PerLane<T>,
        xor_mask: u32,
        valid_mask: u32,
        default: T,
        lane_id: u32,
    ) -> T {
        shuffle_xor_checked(data, xor_mask, valid_mask, lane_id).unwrap_or(default)
    }

    /// Like [`shuffle_xor_masked`], but returns `None` instead of a default when
    /// the source lane `lane_id ^ xor_mask` is not set in `valid_mask` (including
    /// when it lies beyond the 32 lanes a `u32` mask can describe).
    pub fn shuffle_xor_checked<T: Copy>(
        data: PerLane<T>,
        xor_mask: u32,
        valid_mask: u32,
        lane_id: u32,
    ) -> Option<T> {
        let source_lane = lane_id ^ xor_mask;
        if lane_in_mask(valid_mask, source_lane) {
            Some(data.0)
        } else {
            None
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_masked_shuffle() {
            let data = PerLane(42);
            let even_mask = 0x55555555u32;
            let result = shuffle_xor_checked(data, 1, even_mask, 0);
            assert!(result.is_none());
            let result = shuffle_xor_checked(data, 2, even_mask, 0);
            assert!(result.is_some());
        }
    }
}
pub mod sentinel {
    //! Lane validity carried alongside each value as an enum tag.

    /// A per-lane value that is either present (`Valid`) or absent (`Invalid`).
    #[derive(Clone, Copy, Debug, PartialEq)]
    pub enum MaybeValid<T> {
        Valid(T),
        Invalid,
    }

    impl<T> MaybeValid<T> {
        /// Applies `f` to a valid value; `Invalid` stays `Invalid`.
        pub fn map<U>(self, f: impl FnOnce(T) -> U) -> MaybeValid<U> {
            match self {
                MaybeValid::Valid(x) => MaybeValid::Valid(f(x)),
                MaybeValid::Invalid => MaybeValid::Invalid,
            }
        }

        /// Returns the valid value, or `default` when `Invalid`.
        pub fn unwrap_or(self, default: T) -> T {
            match self {
                MaybeValid::Valid(x) => x,
                MaybeValid::Invalid => default,
            }
        }
    }

    /// Wraps `value` as `Valid` when bit `lane_id` of `active_mask` is set, and
    /// as `Invalid` otherwise.
    ///
    /// A `lane_id` of 32 or more cannot be described by a `u32` mask and yields
    /// `Invalid`. A plain shift would overflow there: a panic in debug builds,
    /// and a wrapped shift amount in release builds.
    pub fn diverge_with_sentinel<T>(value: T, lane_id: u32, active_mask: u32) -> MaybeValid<T> {
        let active = active_mask
            .checked_shr(lane_id)
            .map_or(false, |bits| bits & 1 != 0);
        if active {
            MaybeValid::Valid(value)
        } else {
            MaybeValid::Invalid
        }
    }

    /// Placeholder XOR shuffle over sentinel-tagged values.
    ///
    /// NOTE(review): currently returns `data` unchanged; `_xor_mask` and
    /// `_lane_id` are ignored.
    pub fn shuffle_xor_sentinel<T: Copy>(
        data: MaybeValid<T>,
        _xor_mask: u32,
        _lane_id: u32,
    ) -> MaybeValid<T> {
        data
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_sentinel_propagation() {
            let valid = MaybeValid::Valid(42);
            let invalid: MaybeValid<i32> = MaybeValid::Invalid;
            assert_eq!(valid.map(|x| x * 2), MaybeValid::Valid(84));
            assert_eq!(invalid.map(|x| x * 2), MaybeValid::Invalid);
            assert_eq!(valid.unwrap_or(0), 42);
            assert_eq!(invalid.unwrap_or(0), 0);
        }
    }
}
pub mod warp_restricted {
    //! Warp handles whose active-lane set restricts the available operations.
    use super::*;

    /// A compile-time set of active lanes: bit `i` of `MASK` stands for lane `i`.
    pub trait ActiveSet {
        const MASK: u32;
    }

    /// Every lane is active.
    pub struct All;
    impl ActiveSet for All {
        const MASK: u32 = 0xFFFFFFFF;
    }

    /// Only even lanes are active.
    pub struct Even;
    impl ActiveSet for Even {
        const MASK: u32 = 0x55555555;
    }

    /// Only odd lanes are active.
    pub struct Odd;
    impl ActiveSet for Odd {
        const MASK: u32 = 0xAAAAAAAA;
    }

    /// A handle to a warp whose active lanes are described by `S`.
    pub struct Warp<S: ActiveSet> {
        _marker: PhantomData<S>,
    }

    impl<S: ActiveSet> Warp<S> {
        /// Creates a handle for the active set `S`.
        pub fn new() -> Self {
            Warp {
                _marker: PhantomData,
            }
        }
    }

    impl<S: ActiveSet> Default for Warp<S> {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Shared rule for the half-warp sets: an even XOR mask preserves lane
    /// parity, so it never reads from the inactive half.
    fn parity_preserving<T>(data: T, mask: u32) -> Option<T> {
        if mask % 2 == 0 {
            Some(data)
        } else {
            None
        }
    }

    impl Warp<All> {
        /// Unrestricted XOR shuffle; this model returns `data` unchanged.
        pub fn shuffle_xor<T: Copy>(&self, data: T, _mask: u32) -> T {
            data
        }

        /// Unrestricted indexed shuffle; this model returns `data` unchanged.
        pub fn shuffle_idx<T: Copy>(&self, data: T, _source: u32) -> T {
            data
        }
    }

    impl Warp<Even> {
        /// XOR shuffle among even lanes; `None` for odd `mask`.
        pub fn shuffle_xor_within<T: Copy>(&self, data: T, mask: u32) -> Option<T> {
            parity_preserving(data, mask)
        }

        /// Broadcast from lane 0; this model returns `data` unchanged.
        pub fn broadcast_from_0<T: Copy>(&self, data: T) -> T {
            data
        }
    }

    impl Warp<Odd> {
        /// XOR shuffle among odd lanes; `None` for odd `mask`.
        pub fn shuffle_xor_within<T: Copy>(&self, data: T, mask: u32) -> Option<T> {
            parity_preserving(data, mask)
        }

        /// Broadcast from lane 1; this model returns `data` unchanged.
        pub fn broadcast_from_1<T: Copy>(&self, data: T) -> T {
            data
        }
    }

    /// Holds a compile-time check that two active sets together cover all lanes.
    // Never constructed: it exists only to host the associated const below.
    #[allow(dead_code)]
    struct CoversWarp<S1, S2>(PhantomData<(S1, S2)>);

    impl<S1: ActiveSet, S2: ActiveSet> CoversWarp<S1, S2> {
        const OK: () = assert!(
            S1::MASK | S2::MASK == All::MASK,
            "merge: the two active sets must together cover every lane"
        );
    }

    /// Reconverges two divergent halves into a fully active warp.
    ///
    /// `S1` and `S2` must together cover every lane (e.g. `Even` + `Odd`).
    /// Merging sets that leave lanes uncovered (e.g. `Even` + `Even`) is
    /// rejected when `merge` is instantiated. Accepting it would yield a
    /// `Warp<All>` that permits full-warp shuffles while some lanes never
    /// reconverged.
    pub fn merge<S1: ActiveSet, S2: ActiveSet>(_left: Warp<S1>, _right: Warp<S2>) -> Warp<All> {
        let () = CoversWarp::<S1, S2>::OK;
        Warp::new()
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_warp_all_has_full_shuffle() {
            let warp: Warp<All> = Warp::new();
            let data = 42;
            let _ = warp.shuffle_xor(data, 1);
            let _ = warp.shuffle_xor(data, 5);
            let _ = warp.shuffle_xor(data, 31);
        }
        #[test]
        fn test_warp_even_restricted() {
            let warp: Warp<Even> = Warp::new();
            let data = 42;
            assert!(warp.shuffle_xor_within(data, 2).is_some());
            assert!(warp.shuffle_xor_within(data, 4).is_some());
            assert!(warp.shuffle_xor_within(data, 1).is_none());
            assert!(warp.shuffle_xor_within(data, 3).is_none());
        }
        #[test]
        fn test_must_merge_for_full_shuffle() {
            let warp_even: Warp<Even> = Warp::new();
            let warp_odd: Warp<Odd> = Warp::new();
            let warp_all = merge(warp_even, warp_odd);
            let _ = warp_all.shuffle_xor(42, 1);
        }
    }
}
pub mod hybrid {
    //! Compile-time active sets backed by a runtime copy of the mask.
    use super::*;

    /// A compile-time set of active lanes: bit `i` of `MASK` stands for lane `i`.
    pub trait ActiveSet: Copy {
        const MASK: u32;
    }

    /// Every lane is active.
    #[derive(Copy, Clone)]
    pub struct All;
    impl ActiveSet for All {
        const MASK: u32 = 0xFFFFFFFF;
    }

    /// Only even lanes are active.
    #[derive(Copy, Clone)]
    pub struct Even;
    impl ActiveSet for Even {
        const MASK: u32 = 0x55555555;
    }

    /// A value tagged with active set `S`, plus a runtime copy of `S::MASK`.
    #[derive(Clone, Copy)]
    pub struct TrackedValue<T, S: ActiveSet> {
        value: T,
        runtime_mask: u32,
        _marker: PhantomData<S>,
    }

    impl<T, S: ActiveSet> TrackedValue<T, S> {
        /// Wraps `value`, initialising the runtime mask from `S::MASK`.
        pub fn new(value: T) -> Self {
            TrackedValue {
                value,
                runtime_mask: S::MASK,
                _marker: PhantomData,
            }
        }

        /// Returns whether the runtime mask still agrees with `S::MASK`.
        /// Within this module only `new` writes `runtime_mask`.
        pub fn verify(&self) -> bool {
            self.runtime_mask == S::MASK
        }
    }

    /// Reads from source lane `lane_id ^ xor_mask`, checked against the runtime
    /// mask.
    ///
    /// # Errors
    /// - `"Source lane is outside the warp"` when the source lane is 32 or more.
    ///   A plain `>>` would overflow there.
    /// - `"Source lane is inactive"` when its bit in the runtime mask is clear.
    ///
    /// On success this model returns `data`'s own value.
    pub fn checked_shuffle_xor<T: Copy, S: ActiveSet>(
        data: TrackedValue<T, S>,
        xor_mask: u32,
        lane_id: u32,
    ) -> Result<T, &'static str> {
        let source_lane = lane_id ^ xor_mask;
        match data.runtime_mask.checked_shr(source_lane) {
            None => Err("Source lane is outside the warp"),
            Some(bits) if bits & 1 == 0 => Err("Source lane is inactive"),
            Some(_) => Ok(data.value),
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn test_hybrid_checking() {
            let data: TrackedValue<i32, Even> = TrackedValue::new(42);
            assert!(checked_shuffle_xor(data, 1, 0).is_err());
            assert!(checked_shuffle_xor(data, 2, 0).is_ok());
        }
    }
}
#[cfg(test)]
mod integration_tests {
    use super::warp_restricted::{merge, All, Even, Odd, Warp};

    /// Full-warp shuffles are reachable two ways: from a warp that starts fully
    /// active, or from one rebuilt by merging its even and odd halves.
    #[test]
    fn test_recommended_pattern() {
        let value = 42;

        let converged = Warp::<All>::new();
        let _ = converged.shuffle_xor(value, 1);

        let evens = Warp::<Even>::new();
        let odds = Warp::<Odd>::new();
        let reconverged = merge(evens, odds);
        let _ = reconverged.shuffle_xor(value, 1);
    }
}