use crate::active_set::All;
use crate::warp::Warp;
/// Result of splitting a `Warp<All>` with `Warp::<All>::diverge_dynamic`.
///
/// Holds three `u64` bitmasks: the parent-mask bits selected by the predicate,
/// the remaining parent-mask bits, and the parent mask itself. Consume the value
/// with `merge` or `with_branches` to get a `Warp<All>` back.
#[must_use = "a DynDiverge must be merged to recover Warp<All> — dropping it loses both branches"]
pub struct DynDiverge {
// Parent-mask bits that were set in the predicate.
true_mask: u64,
// Parent-mask bits that were clear in the predicate.
false_mask: u64,
// Active mask of the warp at the time of the split; `merge` checks that the
// two branch masks OR together to exactly this value.
parent_mask: u64,
}
impl DynDiverge {
    /// Bitmask recorded for the branch where the predicate was set.
    pub fn true_mask(&self) -> u64 {
        self.true_mask
    }

    /// Bitmask recorded for the branch where the predicate was clear.
    pub fn false_mask(&self) -> u64 {
        self.false_mask
    }

    /// Number of set bits in the true-branch mask.
    pub fn true_count(&self) -> u32 {
        self.true_mask().count_ones()
    }

    /// Number of set bits in the false-branch mask.
    pub fn false_count(&self) -> u32 {
        self.false_mask().count_ones()
    }

    /// Consumes the split and returns a newly constructed `Warp<All>`.
    ///
    /// # Panics
    ///
    /// Panics if the union of the two branch masks differs from the parent
    /// mask stored when the split was created.
    pub fn merge(self) -> Warp<All> {
        let Self {
            true_mask,
            false_mask,
            parent_mask,
        } = self;
        // The two branches together must cover exactly the parent mask.
        let recombined = true_mask | false_mask;
        assert_eq!(
            recombined,
            parent_mask,
            "DynDiverge invariant violated: true_mask | false_mask != parent_mask \
             (0x{:016X} | 0x{:016X} = 0x{:016X}, expected 0x{:016X})",
            true_mask,
            false_mask,
            recombined,
            parent_mask,
        );
        Warp::new()
    }

    /// Runs one closure per branch and then merges.
    ///
    /// `true_fn` is called first with the true-branch mask, then `false_fn`
    /// with the false-branch mask. Both closures are always invoked, even when
    /// their mask is zero. The result is whatever `merge` returns, so the same
    /// panic condition applies (checked after both closures have run).
    pub fn with_branches<F1, F2>(self, true_fn: F1, false_fn: F2) -> Warp<All>
    where
        F1: FnOnce(u64),
        F2: FnOnce(u64),
    {
        let (taken, not_taken) = (self.true_mask, self.false_mask);
        true_fn(taken);
        false_fn(not_taken);
        self.merge()
    }
}
impl Warp<All> {
    /// Splits this warp on a runtime predicate mask, consuming it.
    ///
    /// The returned `DynDiverge` stores the warp's active mask as the parent,
    /// the parent bits set in `predicate_mask` as the true branch, and the
    /// parent bits clear in `predicate_mask` as the false branch.
    ///
    /// # Panics
    ///
    /// Panics if `predicate_mask` has any bit set that is not set in the
    /// warp's active mask.
    pub fn diverge_dynamic(self, predicate_mask: u64) -> DynDiverge {
        let parent_mask = self.active_mask();

        // Reject predicates that name bits the warp does not have.
        let stray_bits = predicate_mask & !parent_mask;
        assert!(
            stray_bits == 0,
            "diverge_dynamic: predicate_mask 0x{:016X} has bits outside warp mask 0x{:016X}",
            predicate_mask,
            parent_mask,
        );

        let true_mask = parent_mask & predicate_mask;
        let false_mask = parent_mask & !predicate_mask;
        DynDiverge {
            true_mask,
            false_mask,
            parent_mask,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::active_set::{ActiveSet, Even, HighHalf, LowHalf, Odd};
    use crate::data::PerLane;

    /// Builds the `Warp<All>` each divergence test starts from.
    fn full_warp() -> Warp<All> {
        Warp::kernel_entry()
    }

    /// Splitting on the low half yields complementary, disjoint masks.
    #[test]
    fn test_diverge_dynamic_masks() {
        let split = full_warp().diverge_dynamic(LowHalf::MASK);
        let (taken, not_taken) = (split.true_mask(), split.false_mask());
        let half = crate::WARP_SIZE / 2;

        assert_eq!(taken, LowHalf::MASK);
        assert_eq!(not_taken, HighHalf::MASK);
        assert_eq!(split.true_count(), half);
        assert_eq!(split.false_count(), half);
        assert_eq!(taken | not_taken, All::MASK);
        assert_eq!(taken & not_taken, 0);
    }

    /// Merging gives back a warp whose active mask is `All::MASK` and which
    /// can be used for `shuffle_xor`.
    #[test]
    fn test_diverge_dynamic_merge() {
        let merged = full_warp().diverge_dynamic(Even::MASK).merge();
        assert_eq!(merged.active_mask(), All::MASK);

        let _result = merged.shuffle_xor(PerLane::new(1i32), 1);
    }

    /// `with_branches` hands each closure its own branch mask and then merges.
    #[test]
    fn test_diverge_dynamic_with_branches() {
        let (mut true_seen, mut false_seen) = (0u64, 0u64);

        let merged = full_warp()
            .diverge_dynamic(Even::MASK)
            .with_branches(|mask| true_seen = mask, |mask| false_seen = mask);

        assert_eq!(true_seen, Even::MASK);
        assert_eq!(false_seen, Odd::MASK);
        assert_eq!(merged.population(), crate::WARP_SIZE);
    }

    /// A predicate covering the whole warp leaves the false branch empty and
    /// still merges.
    #[test]
    fn test_diverge_dynamic_empty_branch() {
        let split = full_warp().diverge_dynamic(All::MASK);
        assert_eq!(split.true_count(), crate::WARP_SIZE);
        assert_eq!(split.false_count(), 0);

        let merged = split.merge();
        let _ = merged.reduce_sum(PerLane::new(1i32));
    }

    /// Splitting on the high half gives two equally sized branches that merge.
    #[test]
    fn test_diverge_dynamic_arbitrary_predicate() {
        let split = full_warp().diverge_dynamic(HighHalf::MASK);
        let half = crate::WARP_SIZE / 2;
        assert_eq!(split.true_count(), half);
        assert_eq!(split.false_count(), half);

        let rejoined = split.merge();
        let _ = rejoined.reduce_sum(PerLane::new(1i32));
    }

    /// Accepts a warp over any active set and reports its population.
    fn generic_helper<S: ActiveSet>(warp: &Warp<S>) -> u32 {
        warp.population()
    }

    /// Accepts only `Warp<All>` and sums `data` across it.
    fn all_only_helper(warp: &Warp<All>, data: PerLane<i32>) -> i32 {
        let total = warp.reduce_sum(data);
        total.get()
    }

    /// Warps can be passed to both generic and `All`-only helpers, and the
    /// halves from `diverge_even_odd` still fit the generic one.
    #[test]
    fn test_cross_function_inference() {
        let warp: Warp<All> = Warp::kernel_entry();
        assert_eq!(generic_helper(&warp), crate::WARP_SIZE);

        let total = all_only_helper(&warp, PerLane::new(1i32));
        assert_eq!(total, crate::WARP_SIZE as i32);

        let (evens, odds) = warp.diverge_even_odd();
        let half = crate::WARP_SIZE / 2;
        assert_eq!(generic_helper(&evens), half);
        assert_eq!(generic_helper(&odds), half);
    }
}