use std::marker::PhantomData;
/// Type-level description of which lanes of a 32-lane warp are active.
///
/// Implementors are zero-sized marker types; `MASK` has bit `i` set when
/// lane `i` belongs to the set.
pub trait ActiveSet: Copy + 'static {
    /// Bitmask of the lanes in this set (bit `i` = lane `i`).
    const MASK: u32;
}

/// The lane set containing every one of the 32 lanes.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct All;

impl ActiveSet for All {
    const MASK: u32 = 0xFFFF_FFFF;
}

/// Zero-sized token for a warp whose active lanes are described by `S`.
///
/// Functions that require a particular lane set (e.g. `Warp<All>`) take this
/// token, so the lane-set requirement is checked by the type system.
#[derive(Copy, Clone, Debug)]
pub struct Warp<S: ActiveSet> {
    _marker: PhantomData<S>,
}

impl<S: ActiveSet> Warp<S> {
    /// Creates a warp token for the lane set `S`.
    pub fn new() -> Self {
        Warp {
            _marker: PhantomData,
        }
    }

    /// Returns the active-lane bitmask of this warp, i.e. `S::MASK`.
    pub fn mask(&self) -> u32 {
        S::MASK
    }
}

impl<S: ActiveSet> Default for Warp<S> {
    /// Equivalent to [`Warp::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// One value per lane of a 32-lane warp; element `i` belongs to lane `i`.
#[derive(Copy, Clone, Debug)]
pub struct PerLane<T>(pub [T; 32]);
pub mod ballot_exit {
use super::*;
pub fn ballot(_warp: &Warp<All>, pred: PerLane<bool>) -> u32 {
let mut mask = 0u32;
for lane in 0..32 {
if pred.0[lane] {
mask |= 1 << lane;
}
}
mask
}
pub fn all_done(warp: &Warp<All>, done: PerLane<bool>) -> bool {
ballot(warp, done) == 0xFFFFFFFF
}
pub fn any_done(warp: &Warp<All>, done: PerLane<bool>) -> bool {
ballot(warp, done) != 0
}
pub fn search_ballot<F>(warp: Warp<All>, mut step: F) -> (Warp<All>, PerLane<bool>)
where
F: FnMut(usize) -> PerLane<bool>, {
let mut found = PerLane([false; 32]);
let mut iter = 0;
while !all_done(&warp, found) {
let new_found = step(iter);
for lane in 0..32 {
found.0[lane] = found.0[lane] || new_found.0[lane];
}
iter += 1;
}
(warp, found) }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ballot_exit() {
let warp: Warp<All> = Warp::new();
let (warp_out, found) = search_ballot(warp, |iter| {
let mut result = [false; 32];
for lane in 0..32 {
result[lane] = iter >= lane; }
PerLane(result)
});
assert!(found.0.iter().all(|&f| f));
let _: Warp<All> = warp_out;
}
}
}
pub mod reducing_exit {
    use super::*;

    /// A warp whose active-lane set is tracked at run time and can only shrink.
    ///
    /// Bit `i` of the mask is set while lane `i` is still active.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub struct ReducingWarp {
        active_mask: u32,
    }

    impl ReducingWarp {
        /// Creates a warp with all 32 lanes active.
        pub fn new() -> Self {
            ReducingWarp {
                active_mask: 0xFFFFFFFF,
            }
        }

        /// Returns the current active-lane bitmask.
        pub fn active_mask(&self) -> u32 {
            self.active_mask
        }

        /// Returns `true` while at least one lane is still active.
        pub fn any_active(&self) -> bool {
            self.active_mask != 0
        }

        /// Returns `true` once every lane has exited.
        pub fn all_exited(&self) -> bool {
            self.active_mask == 0
        }

        /// Clears every bit in `exiting` from the active mask.
        ///
        /// Bits for lanes that have already exited have no effect.
        pub fn exit_lanes(&mut self, exiting: u32) {
            self.active_mask &= !exiting;
        }
    }

    impl Default for ReducingWarp {
        /// Equivalent to [`ReducingWarp::new`]: all lanes active.
        fn default() -> Self {
            Self::new()
        }
    }

    /// Runs `step(iter, active_mask)` for `iter` = 0, 1, 2, … until every lane has
    /// exited.
    ///
    /// `step` receives the current active mask. It returns the bitmask of lanes that
    /// found their result this round. Those lanes are marked `true` in the returned
    /// `PerLane` and removed from the active set. If `step` never retires all 32
    /// lanes, this loops forever.
    pub fn search_reducing<F>(mut step: F) -> PerLane<bool>
    where
        F: FnMut(usize, u32) -> u32,
    {
        let mut warp = ReducingWarp::new();
        let mut found = PerLane([false; 32]);
        let mut iter = 0;
        while warp.any_active() {
            let newly_found = step(iter, warp.active_mask());
            for (lane, slot) in found.0.iter_mut().enumerate() {
                if (newly_found & (1u32 << lane)) != 0 {
                    *slot = true;
                }
            }
            warp.exit_lanes(newly_found);
            iter += 1;
        }
        found
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_reducing_exit() {
            // Lane `k` exits on round `k`.
            let found = search_reducing(|iter, active| {
                let mut exiting = 0u32;
                for lane in 0..32 {
                    if active & (1 << lane) != 0 && iter == lane {
                        exiting |= 1 << lane;
                    }
                }
                exiting
            });
            assert!(found.0.iter().all(|&f| f));
        }

        #[test]
        fn test_default_matches_new() {
            assert_eq!(ReducingWarp::default(), ReducingWarp::new());
            assert_eq!(ReducingWarp::default().active_mask(), 0xFFFF_FFFF);
        }
    }
}
pub mod work_redistribution {
use super::*;
#[derive(Copy, Clone, Debug)]
pub struct WorkItem {
pub lane: usize,
pub data: i32,
}
pub fn redistribute_work(
_warp: &Warp<All>,
_done: PerLane<bool>,
work: PerLane<Option<WorkItem>>,
) -> PerLane<Option<WorkItem>> {
work
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_redistribute_placeholder() {
let warp: Warp<All> = Warp::new();
let done = PerLane([false; 32]);
let work = PerLane([None; 32]);
let _redistributed = redistribute_work(&warp, done, work);
}
}
}
pub mod existential_exit {
    /// A warp that tracks which of its 32 lanes are active, along with a count of
    /// lanes that have not yet exited.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub struct BoundedWarp {
        /// Bit `i` is set while lane `i` has not exited.
        active_mask: u32,
        /// Number of lanes that have not exited. It starts at 32 and is reduced
        /// only by lanes that actually leave the active set.
        max_active: usize,
    }

    impl BoundedWarp {
        /// Creates a warp with all 32 lanes active.
        pub fn new() -> Self {
            BoundedWarp {
                active_mask: 0xFFFFFFFF,
                max_active: 32,
            }
        }

        /// Returns the number of lanes that have not yet exited.
        pub fn max_active(&self) -> usize {
            self.max_active
        }

        /// Returns the bitmask of lanes that have not yet exited.
        pub fn active_mask(&self) -> u32 {
            self.active_mask
        }

        /// Marks the lanes whose bits are set in `exiting` as exited.
        ///
        /// Bits for lanes that already exited are ignored. Each lane is counted
        /// once, no matter how many times it appears across calls.
        pub fn exit_lanes(&mut self, exiting: u32) {
            // Only lanes still in the active set can exit. Counting the raw
            // `exiting` bits would decrement again for lanes that already left.
            // That would push `max_active` below the true count and make
            // `all_exited` report true while lanes are still active.
            let newly_exited = exiting & self.active_mask;
            self.active_mask &= !newly_exited;
            self.max_active = self
                .max_active
                .saturating_sub(newly_exited.count_ones() as usize);
        }

        /// Returns `true` once every lane has exited.
        pub fn all_exited(&self) -> bool {
            self.max_active == 0
        }
    }

    impl Default for BoundedWarp {
        /// Equivalent to [`BoundedWarp::new`]: all lanes active.
        fn default() -> Self {
            Self::new()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_bounded_warp() {
            let mut warp = BoundedWarp::new();
            assert_eq!(warp.max_active(), 32);
            warp.exit_lanes(0x0000FFFF);
            assert_eq!(warp.max_active(), 16);
            warp.exit_lanes(0xFFFF0000);
            assert!(warp.all_exited());
        }

        #[test]
        fn test_repeated_exit_not_double_counted() {
            let mut warp = BoundedWarp::new();
            warp.exit_lanes(0x1);
            warp.exit_lanes(0x1);
            assert_eq!(warp.max_active(), 31);
            warp.exit_lanes(0xFFFF_FFFE);
            assert!(warp.all_exited());
            assert_eq!(warp.active_mask(), 0);
        }
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// `search_ballot` must hand back a `Warp<All>` token.
    #[test]
    fn test_ballot_preserves_warp_all() {
        let entry = Warp::<All>::new();
        let (exit_token, _) = ballot_exit::search_ballot(entry, |round| {
            let flag = round > 10;
            PerLane([flag; 32])
        });
        let _: Warp<All> = exit_token;
    }

    /// The run-time reducing search finds every lane when lane `k` exits on round `k`.
    #[test]
    fn test_reducing_no_warp_type() {
        let found = reducing_exit::search_reducing(|round, _active| match round {
            0..=31 => 1u32 << round,
            _ => 0,
        });
        assert!(found.0.iter().all(|&lane_found| lane_found));
    }
}