use std::marker::PhantomData;
pub mod existential {
    //! Warps whose set of active lanes is known only at runtime.
    //!
    //! A [`SomeWarp`] stores its lane mask as a value, so whether two warps can
    //! be merged is checked dynamically by [`merge_checked`].

    /// Anything that can report which of the 32 lanes are active.
    pub trait ActiveSet {
        /// Bitmask of active lanes: bit `i` is set when lane `i` is active.
        fn mask(&self) -> u32;
    }

    /// A warp fragment whose active lanes are described by a runtime bitmask.
    ///
    /// The type is neither `Clone` nor `Copy`, so [`merge_checked`] takes
    /// ownership of the fragments it combines.
    #[derive(Debug)]
    pub struct SomeWarp {
        mask: u32,
    }

    impl SomeWarp {
        /// Builds a warp containing every lane in `0..32` for which `pred` is `true`.
        pub fn from_predicate<F: Fn(u32) -> bool>(pred: F) -> Self {
            let mask = (0..32u32)
                .filter(|&lane| pred(lane))
                .fold(0u32, |acc, lane| acc | (1u32 << lane));
            SomeWarp { mask }
        }

        /// Bitmask of active lanes (bit `i` set means lane `i` is active).
        pub fn mask(&self) -> u32 {
            self.mask
        }

        /// Number of active lanes.
        pub fn population(&self) -> u32 {
            self.mask.count_ones()
        }

        /// Returns `true` when `self` and `other` are disjoint and together
        /// cover all 32 lanes.
        pub fn complements(&self, other: &SomeWarp) -> bool {
            (self.mask & other.mask) == 0 && (self.mask | other.mask) == u32::MAX
        }
    }

    // Expose the accessor through the module's trait so a `SomeWarp` can be
    // used wherever an `ActiveSet` is expected.
    impl ActiveSet for SomeWarp {
        fn mask(&self) -> u32 {
            self.mask
        }
    }

    /// Splits all 32 lanes into those satisfying `pred` and those that do not.
    ///
    /// Both halves are built from the same predicate (once negated), so the two
    /// returned warps always complement each other.
    pub fn diverge_arbitrary<F: Fn(u32) -> bool>(pred: F) -> (SomeWarp, SomeWarp) {
        let true_branch = SomeWarp::from_predicate(&pred);
        let false_branch = SomeWarp::from_predicate(|lane| !pred(lane));
        (true_branch, false_branch)
    }

    /// Re-joins two warp fragments into a single warp.
    ///
    /// # Errors
    /// Returns `Err("Warps are not complementary")` unless `left` and `right`
    /// are disjoint and together cover all 32 lanes.
    pub fn merge_checked(left: SomeWarp, right: SomeWarp) -> Result<SomeWarp, &'static str> {
        if left.complements(&right) {
            Ok(SomeWarp {
                mask: left.mask | right.mask,
            })
        } else {
            Err("Warps are not complementary")
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_arbitrary_predicate() {
            let threshold = 10u32;
            let (below, above) = diverge_arbitrary(|lane| lane < threshold);
            assert_eq!(below.population(), 10);
            assert_eq!(above.population(), 22);
            assert!(below.complements(&above));
        }

        #[test]
        fn test_merge_checked() {
            let (a, b) = diverge_arbitrary(|lane| lane % 3 == 0);
            assert!(merge_checked(a, b).is_ok());
            let (c, _) = diverge_arbitrary(|lane| lane < 5);
            let (d, _) = diverge_arbitrary(|lane| lane < 10);
            assert_ne!(
                c.mask, d.mask,
                "overlapping predicates produce different masks"
            );
            assert!(
                merge_checked(c, d).is_err(),
                "overlapping warps should fail merge"
            );
        }

        #[test]
        fn test_active_set_impl() {
            fn mask_of<W: ActiveSet>(w: &W) -> u32 {
                w.mask()
            }
            let (evens, odds) = diverge_arbitrary(|lane| lane % 2 == 0);
            assert_eq!(mask_of(&evens), 0x5555_5555);
            assert_eq!(mask_of(&odds), 0xAAAA_AAAA);
        }
    }
}
pub mod refinement {
    //! Warps whose set of active lanes is fixed by a type-level predicate.
    //!
    //! [`diverge`] yields a `P` half and a `Not<P>` half, and [`merge`] only
    //! accepts exactly that pair, so complementarity is checked by the type
    //! system rather than at runtime.
    use super::*;

    /// A lane predicate evaluated purely from its type (no runtime state).
    pub trait LanePredicate: Copy {
        /// Returns `true` when `lane` satisfies the predicate.
        fn test(lane: u32) -> bool;
        /// Static label for the predicate.
        fn name() -> &'static str;
    }

    /// Zero-sized warp whose active lanes are those satisfying `P`.
    pub struct RefinedWarp<P: LanePredicate> {
        _marker: PhantomData<P>,
    }

    impl<P: LanePredicate> RefinedWarp<P> {
        /// Creates the (zero-sized) warp value for predicate `P`.
        pub fn new() -> Self {
            RefinedWarp {
                _marker: PhantomData,
            }
        }

        /// Bitmask of lanes in `0..32` that satisfy `P`
        /// (bit `i` set means lane `i` is active).
        pub fn mask() -> u32 {
            (0..32u32)
                .filter(|&lane| P::test(lane))
                .fold(0u32, |acc, lane| acc | (1u32 << lane))
        }
    }

    // `new` takes no arguments, so `Default` is its idiomatic companion
    // (this also satisfies `clippy::new_without_default`).
    impl<P: LanePredicate> Default for RefinedWarp<P> {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Negation of predicate `P`: a lane passes exactly when `P` rejects it.
    pub struct Not<P: LanePredicate>(PhantomData<P>);

    impl<P: LanePredicate> Copy for Not<P> {}

    impl<P: LanePredicate> Clone for Not<P> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<P: LanePredicate> LanePredicate for Not<P> {
        fn test(lane: u32) -> bool {
            !P::test(lane)
        }

        // Fixed label; it does not include the inner predicate's name.
        fn name() -> &'static str {
            "Not<P>"
        }
    }

    /// Splits a warp into the lanes satisfying `P` and the lanes satisfying `Not<P>`.
    pub fn diverge<P: LanePredicate>() -> (RefinedWarp<P>, RefinedWarp<Not<P>>) {
        (RefinedWarp::new(), RefinedWarp::new())
    }

    /// Re-joins a `P` half with its `Not<P>` half into a full warp.
    ///
    /// The signature alone guarantees the halves are complementary, so no
    /// runtime check is needed.
    pub fn merge<P: LanePredicate>(
        _left: RefinedWarp<P>,
        _right: RefinedWarp<Not<P>>,
    ) -> RefinedWarp<All> {
        RefinedWarp::new()
    }

    /// Predicate satisfied by every lane.
    #[derive(Copy, Clone, Debug)]
    pub struct All;

    impl LanePredicate for All {
        fn test(_: u32) -> bool {
            true
        }

        fn name() -> &'static str {
            "All"
        }
    }

    /// Predicate satisfied by lanes strictly below `N`.
    #[derive(Copy, Clone, Debug)]
    pub struct LessThan<const N: u32>;

    impl<const N: u32> LanePredicate for LessThan<N> {
        fn test(lane: u32) -> bool {
            lane < N
        }

        // Fixed label; the value of `N` is not substituted.
        fn name() -> &'static str {
            "LessThan<N>"
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_refinement_static() {
            let (below, above) = diverge::<LessThan<10>>();
            assert_eq!(RefinedWarp::<LessThan<10>>::mask(), 0x000003FF);
            assert_eq!(RefinedWarp::<Not<LessThan<10>>>::mask(), 0xFFFFFC00);
            let _all = merge(below, above);
        }

        #[test]
        fn test_default_matches_new() {
            let _w: RefinedWarp<All> = RefinedWarp::default();
            assert_eq!(RefinedWarp::<All>::mask(), u32::MAX);
        }
    }
}
pub mod indexed {
    //! Warps tagged with runtime predicate ids issued by a [`PredicateRegistry`].

    /// Identifier for a registered predicate or its complement.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
    pub struct PredicateId(u32);

    /// Id carried by warps produced by [`merge_indexed`].
    ///
    /// [`PredicateRegistry::register`] never issues this value, so a merged warp
    /// cannot be mistaken for a registered predicate.
    const MERGED_ID: PredicateId = PredicateId(u32::MAX);

    /// A warp fragment tagged with the id of the predicate it came from.
    #[derive(Debug)]
    pub struct IndexedWarp {
        predicate_id: PredicateId,
        mask: u32,
    }

    /// Stores the lane mask of each registered id and which ids complement
    /// each other.
    #[derive(Debug, Default)]
    pub struct PredicateRegistry {
        next_id: u32,
        masks: std::collections::HashMap<PredicateId, u32>,
        complements: std::collections::HashMap<PredicateId, PredicateId>,
    }

    impl PredicateRegistry {
        /// Creates an empty registry.
        pub fn new() -> Self {
            Self::default()
        }

        /// Registers `pred` and its negation, returning `(id_true, id_false)`.
        ///
        /// `id_true`'s mask holds the lanes in `0..32` where `pred` is `true`;
        /// `id_false`'s mask is its bitwise complement. The two ids are
        /// recorded as complements of each other.
        ///
        /// # Panics
        /// Panics with "predicate id space exhausted" once roughly 2^31 pairs
        /// have been registered.
        pub fn register<F: Fn(u32) -> bool>(&mut self, pred: F) -> (PredicateId, PredicateId) {
            let base = self.next_id;
            // Both new ids must stay strictly below `MERGED_ID`; this bound also
            // guarantees `base + 2` cannot overflow.
            assert!(base < MERGED_ID.0 - 1, "predicate id space exhausted");
            let id_true = PredicateId(base);
            let id_false = PredicateId(base + 1);
            self.next_id = base + 2;

            let mask_true = (0..32u32)
                .filter(|&lane| pred(lane))
                .fold(0u32, |acc, lane| acc | (1u32 << lane));
            let mask_false = !mask_true;

            self.masks.insert(id_true, mask_true);
            self.masks.insert(id_false, mask_false);
            self.complements.insert(id_true, id_false);
            self.complements.insert(id_false, id_true);
            (id_true, id_false)
        }

        /// Returns `true` if `a` and `b` were registered together as complements.
        pub fn are_complements(&self, a: PredicateId, b: PredicateId) -> bool {
            self.complements.get(&a) == Some(&b)
        }

        /// Lane mask recorded for `id`, or `0` if `id` is not registered.
        pub fn mask(&self, id: PredicateId) -> u32 {
            self.masks.get(&id).copied().unwrap_or(0)
        }
    }

    impl IndexedWarp {
        /// Creates a warp tagged with `id` and covering the lanes in `mask`.
        pub fn new(id: PredicateId, mask: u32) -> Self {
            IndexedWarp {
                predicate_id: id,
                mask,
            }
        }

        /// Predicate id this warp is tagged with.
        pub fn id(&self) -> PredicateId {
            self.predicate_id
        }

        /// Bitmask of active lanes.
        pub fn mask(&self) -> u32 {
            self.mask
        }
    }

    /// Re-joins two warps whose predicates were registered as complements.
    ///
    /// # Errors
    /// - `"Predicates are not registered complements"` if the ids are not a
    ///   registered complement pair.
    /// - `"Warp mask does not match its registered predicate"` if either warp's
    ///   mask differs from the mask the registry recorded for its id.
    pub fn merge_indexed(
        registry: &PredicateRegistry,
        left: IndexedWarp,
        right: IndexedWarp,
    ) -> Result<IndexedWarp, &'static str> {
        if !registry.are_complements(left.id(), right.id()) {
            return Err("Predicates are not registered complements");
        }
        // `IndexedWarp::new` accepts any mask, so the id alone does not prove
        // coverage; without this check, mismatched masks would "merge" into a
        // warp that does not cover every lane.
        if left.mask() != registry.mask(left.id()) || right.mask() != registry.mask(right.id()) {
            return Err("Warp mask does not match its registered predicate");
        }
        Ok(IndexedWarp::new(MERGED_ID, left.mask | right.mask))
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_indexed_predicates() {
            let mut registry = PredicateRegistry::new();
            let threshold = 15u32;
            let (below_id, above_id) = registry.register(|lane| lane < threshold);
            let below = IndexedWarp::new(below_id, registry.mask(below_id));
            let above = IndexedWarp::new(above_id, registry.mask(above_id));
            assert!(registry.are_complements(below_id, above_id));
            assert!(merge_indexed(&registry, below, above).is_ok());
        }

        #[test]
        fn test_forged_masks_rejected() {
            let mut registry = PredicateRegistry::new();
            let (a, b) = registry.register(|lane| lane < 8);
            let forged = merge_indexed(&registry, IndexedWarp::new(a, 0), IndexedWarp::new(b, 0));
            assert!(forged.is_err());
        }
    }
}
pub mod hybrid_shape {
    //! Warps tagged with a coarse [`Shape`] alongside their lane mask.

    /// Coarse description of which lanes a [`ShapedWarp`] covers.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub enum Shape {
        /// Every lane.
        All,
        /// Lower half of a [`diverge_by_threshold`] split.
        LowRange,
        /// Upper half of a [`diverge_by_threshold`] split.
        HighRange,
        /// No further structure; merges are validated by mask only.
        Arbitrary,
    }

    /// A warp fragment carrying both its [`Shape`] tag and its lane mask.
    #[derive(Debug)]
    pub struct ShapedWarp {
        shape: Shape,
        mask: u32,
    }

    impl ShapedWarp {
        /// A warp with every lane active.
        pub fn all() -> Self {
            ShapedWarp {
                shape: Shape::All,
                mask: u32::MAX,
            }
        }

        /// Shape tag of this warp.
        pub fn shape(&self) -> Shape {
            self.shape
        }

        /// Bitmask of active lanes.
        pub fn mask(&self) -> u32 {
            self.mask
        }

        /// Returns `Some(data)` only for a full warp. `_mask` is currently ignored.
        pub fn shuffle_xor(&self, data: i32, _mask: u32) -> Option<i32> {
            match self.shape {
                Shape::All => Some(data),
                Shape::LowRange | Shape::HighRange | Shape::Arbitrary => None,
            }
        }

        /// Returns `Some(data)` for every shape except [`Shape::Arbitrary`].
        pub fn broadcast_first(&self, data: i32) -> Option<i32> {
            match self.shape {
                Shape::All | Shape::LowRange | Shape::HighRange => Some(data),
                Shape::Arbitrary => None,
            }
        }
    }

    /// Splits lanes into `[0, threshold)` (low) and `[threshold, 32)` (high).
    ///
    /// A `threshold` of 32 or more puts every lane in the low half.
    pub fn diverge_by_threshold(threshold: u32) -> (ShapedWarp, ShapedWarp) {
        let low_mask = if threshold >= 32 {
            u32::MAX
        } else {
            (1u32 << threshold) - 1
        };
        let high_mask = !low_mask;
        (
            ShapedWarp {
                shape: Shape::LowRange,
                mask: low_mask,
            },
            ShapedWarp {
                shape: Shape::HighRange,
                mask: high_mask,
            },
        )
    }

    /// `true` when `a` and `b` are disjoint and together cover all 32 lanes.
    fn masks_complement(a: u32, b: u32) -> bool {
        (a & b) == 0 && (a | b) == u32::MAX
    }

    /// Re-joins two fragments into a full warp.
    ///
    /// Accepted pairs are one `LowRange` plus one `HighRange` (in either order),
    /// or two `Arbitrary` fragments. In both cases the masks must be exact
    /// complements.
    ///
    /// # Errors
    /// - `"Range shapes don't complement"` for a low/high pair from different
    ///   thresholds.
    /// - `"Arbitrary shapes don't complement"` for non-complementary arbitrary
    ///   fragments.
    /// - `"Incompatible shapes"` for any other pairing.
    pub fn merge_shaped(left: ShapedWarp, right: ShapedWarp) -> Result<ShapedWarp, &'static str> {
        let complementary = masks_complement(left.mask, right.mask);
        match (left.shape, right.shape) {
            // The shape tags alone do not prove the two halves came from the
            // same threshold, so the masks must be checked as well.
            (Shape::LowRange, Shape::HighRange) | (Shape::HighRange, Shape::LowRange) => {
                if complementary {
                    Ok(ShapedWarp::all())
                } else {
                    Err("Range shapes don't complement")
                }
            }
            (Shape::Arbitrary, Shape::Arbitrary) => {
                if complementary {
                    Ok(ShapedWarp::all())
                } else {
                    Err("Arbitrary shapes don't complement")
                }
            }
            _ => Err("Incompatible shapes"),
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_shape_merge() {
            let threshold = 20u32;
            let (low, high) = diverge_by_threshold(threshold);
            assert_eq!(low.shape(), Shape::LowRange);
            assert_eq!(high.shape(), Shape::HighRange);
            let merged = merge_shaped(low, high).unwrap();
            assert_eq!(merged.shape(), Shape::All);
        }

        #[test]
        fn test_shape_operations() {
            let (low, _high) = diverge_by_threshold(16);
            assert!(low.shuffle_xor(42, 1).is_none());
            assert!(low.broadcast_first(42).is_some());
        }

        #[test]
        fn test_mismatched_thresholds_rejected() {
            let (low, _) = diverge_by_threshold(10);
            let (_, high) = diverge_by_threshold(20);
            assert!(merge_shaped(low, high).is_err());
        }
    }
}
pub mod dependent_sketch {
    //! Currently empty: this module defines no items.
}
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// Exercises the shape, registry, and existential layers side by side.
    #[test]
    fn test_layered_approach() {
        // Shape layer: a threshold split merges back into a full warp.
        let (lower, upper) = hybrid_shape::diverge_by_threshold(20u32);
        let whole = hybrid_shape::merge_shaped(lower, upper).unwrap();
        assert_eq!(whole.shape(), hybrid_shape::Shape::All);

        // Registry layer: a registered predicate and its negation are complements.
        let mut reg = indexed::PredicateRegistry::new();
        let (multiples_of_five, others) = reg.register(|lane| lane % 5 == 0);
        assert!(reg.are_complements(multiples_of_five, others));

        // Existential layer: a parity split always complements itself.
        let even_parity = |lane: u32| lane.count_ones() % 2 == 0;
        let (even, odd) = existential::diverge_arbitrary(even_parity);
        assert!(even.complements(&odd));
    }
}