use std::marker::PhantomData;
use crate::active_set::All;
use crate::data::PerLane;
use crate::warp::Warp;
/// Loop strategy in which the loop body is never handed a `Warp` handle.
///
/// The body closure only receives two `u32` arguments, so it has no warp value
/// to call methods on; the warp is moved in and handed straight back.
pub mod forbid_warp_ops {
    use super::*;

    /// Calls `body(0, i)` for every `i` in `0..trip_counts.get()` and then
    /// returns `warp` as it was received.
    ///
    /// The first argument given to `body` is always the constant `0` (callers
    /// in this file name that parameter `_lane`); the second is the iteration
    /// index.
    pub fn varying_loop<F>(warp: Warp<All>, trip_counts: PerLane<u32>, mut body: F) -> Warp<All>
    where
        F: FnMut(u32, u32),
    {
        (0..trip_counts.get()).for_each(|iter| body(0, iter));
        warp
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// The warp returned by `varying_loop` still reports the `"All"` set.
        #[test]
        fn test_varying_loop_compiles() {
            let full_warp: Warp<All> = Warp::new();
            let ten_trips = PerLane::new(10u32);

            let returned = varying_loop(full_warp, ten_trips, |_lane, iter| {
                let _ = iter * 2;
            });

            assert_eq!(returned.active_set_name(), "All");
        }
    }
}
/// Loop strategy in which the loop always runs a fixed number of iterations
/// and the body is told, through a boolean flag, whether the current iteration
/// is still below the `done_at` threshold.
pub mod uniform_with_mask {
    use super::*;

    /// Invokes `body(&warp, i, i < done_at.get())` for each `i` in
    /// `0..max_iters`, then returns `warp`.
    ///
    /// The loop never exits early: iterations at or past `done_at.get()` still
    /// call `body`, just with the flag set to `false`.
    pub fn uniform_loop<F>(
        warp: Warp<All>,
        max_iters: u32,
        done_at: PerLane<u32>,
        mut body: F,
    ) -> Warp<All>
    where
        F: FnMut(&Warp<All>, u32, bool),
    {
        let mut iter = 0;
        while iter < max_iters {
            body(&warp, iter, iter < done_at.get());
            iter += 1;
        }
        warp
    }

    /// Drives [`uniform_loop`] for 10 iterations with a threshold of 5, calling
    /// `broadcast(iter)` on the borrowed warp only in iterations flagged
    /// active. Both the broadcast result and the returned warp are discarded.
    pub fn example_with_warp_ops(warp: Warp<All>) {
        let threshold = PerLane::new(5u32);
        let _ = uniform_loop(warp, 10, threshold, |w, iter, is_active| {
            if !is_active {
                return;
            }
            let _ = w.broadcast(iter);
        });
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// With 5 iterations and a threshold of 3, exactly 3 calls are flagged
        /// active.
        #[test]
        fn test_uniform_loop() {
            let full_warp: Warp<All> = Warp::new();
            let threshold = PerLane::new(3u32);
            let mut active_iters = 0;

            let _ = uniform_loop(full_warp, 5, threshold, |_w, _iter, is_active| {
                active_iters += i32::from(is_active);
            });

            assert_eq!(active_iters, 3);
        }
    }
}
/// Loop strategy split into two phases: a first phase whose body borrows the
/// warp, followed by a second phase whose body only receives two `u32` values.
pub mod phased_loop {
    use super::*;

    /// Runs the two phases in order and returns `warp`.
    ///
    /// 1. `uniform_body(&warp, i)` for every `i` in `0..min_trips`.
    /// 2. `cleanup_body(0, j)` for every `j` from `min_trips` up to, but not
    ///    including, `trip_counts.get()`. This phase is empty whenever
    ///    `trip_counts.get() <= min_trips` (the difference saturates at zero).
    ///
    /// The first argument to `cleanup_body` is always the constant `0`.
    pub fn phased_loop<F1, F2>(
        warp: Warp<All>,
        trip_counts: PerLane<u32>,
        min_trips: u32,
        mut uniform_body: F1,
        mut cleanup_body: F2,
    ) -> Warp<All>
    where
        F1: FnMut(&Warp<All>, u32),
        F2: FnMut(u32, u32),
    {
        (0..min_trips).for_each(|iter| uniform_body(&warp, iter));

        // `min_trips + tail_len` is either `trip_counts.get()` or `min_trips`
        // itself, so the upper bound cannot overflow.
        let tail_len = trip_counts.get().saturating_sub(min_trips);
        for iter in min_trips..min_trips + tail_len {
            cleanup_body(0, iter);
        }

        warp
    }

    /// Returns `trip_counts.get()`; the `_warp` argument is not used.
    ///
    /// NOTE(review): the name suggests a minimum taken across the warp, but
    /// the code here simply returns the value it was given. Confirm intent.
    pub fn compute_min_trips(_warp: &Warp<All>, trip_counts: PerLane<u32>) -> u32 {
        trip_counts.get()
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// 10 trips with `min_trips = 5` gives 5 uniform calls and 5 cleanup
        /// calls.
        #[test]
        fn test_phased_loop() {
            let full_warp: Warp<All> = Warp::new();
            let trips = PerLane::new(10u32);
            let min_trips = 5;
            let (mut uniform_calls, mut cleanup_calls) = (0, 0);

            let _ = phased_loop(
                full_warp,
                trips,
                min_trips,
                |_w, _i| uniform_calls += 1,
                |_lane, _i| cleanup_calls += 1,
            );

            assert_eq!(uniform_calls, 5);
            assert_eq!(cleanup_calls, 5);
        }
    }
}
/// Helpers that fetch a per-lane count of items and feed each one to a
/// processing callback.
pub mod work_redistribution {
    use super::*;

    /// Fetches and processes `items_per_lane.get()` items.
    ///
    /// First calls `warp.reduce_sum(items_per_lane)`; the result is bound but
    /// otherwise unused. Then, for each `i` in `0..items_per_lane.get()`, it
    /// obtains `get_item(0, i)` and passes it to `process`. The first argument
    /// to `get_item` is always the constant `0`.
    pub fn redistribute_work<T: Copy + Default, F>(
        warp: &Warp<All>,
        items_per_lane: PerLane<u32>,
        get_item: impl Fn(u32, u32) -> T,
        process: F,
    ) where
        F: Fn(T),
    {
        let _warp_total = warp.reduce_sum(items_per_lane);

        // `process` is borrowed so it lives until the function returns.
        (0..items_per_lane.get())
            .map(|idx| get_item(0, idx))
            .for_each(&process);
    }

    /// Calls [`redistribute_work`] with a count of 5, an item getter that
    /// returns its two arguments as a pair, and a processor that only
    /// destructures the pair.
    pub fn example_variable_arrays(warp: &Warp<All>) {
        let array_lengths = PerLane::new(5u32);
        let fetch = |lane: u32, idx: u32| (lane, idx);
        let consume = |(_lane, _idx): (u32, u32)| {};

        redistribute_work(warp, array_lengths, fetch, consume);
    }
}
/// Tags values with a zero-sized marker type so that function signatures can
/// state which kind of computation they accept and which kind they produce.
pub mod effect_system {
    use super::*;

    /// Marker trait; implemented in this module by [`DivergentEffect`].
    pub trait MayDiverge {}

    /// Marker trait; implemented in this module by [`UniformEffect`].
    pub trait Uniform {}

    /// Tag type carried by computations that [`shuffle`] accepts.
    #[derive(Copy, Clone)]
    pub struct UniformEffect;

    /// Tag type carried by the result of [`varying_loop`].
    #[derive(Copy, Clone)]
    pub struct DivergentEffect;

    impl Uniform for UniformEffect {}
    impl MayDiverge for DivergentEffect {}

    /// A value of type `T` labelled, at the type level only, with the tag `E`.
    pub struct Computation<E, T> {
        _effect: PhantomData<E>,
        value: T,
    }

    impl<E, T> Computation<E, T> {
        /// Wraps `value` under whichever tag `E` the call site requires.
        fn tagged(value: T) -> Self {
            Self {
                _effect: PhantomData,
                value,
            }
        }
    }

    /// Takes a `UniformEffect` computation and returns a new `UniformEffect`
    /// computation holding the same value. The `_warp` argument is not used
    /// beyond being required by the signature.
    pub fn shuffle<T: Copy>(
        _warp: &Warp<All>,
        data: Computation<UniformEffect, T>,
    ) -> Computation<UniformEffect, T> {
        Computation::tagged(data.value)
    }

    /// Calls `body` exactly once and wraps its result as a `DivergentEffect`
    /// computation.
    pub fn varying_loop<T>(body: impl Fn() -> T) -> Computation<DivergentEffect, T> {
        Computation::tagged(body())
    }

    /// Re-tags a `DivergentEffect` computation as `UniformEffect`, moving the
    /// wrapped value across unchanged. No other work is performed.
    pub fn synchronize<T>(comp: Computation<DivergentEffect, T>) -> Computation<UniformEffect, T> {
        Computation::tagged(comp.value)
    }
}
/// Currently an empty module: it declares no items.
pub mod recursive_session {
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use super::forbid_warp_ops::varying_loop;
    use super::phased_loop::phased_loop;
    use super::uniform_with_mask::uniform_loop;

    /// Threads a single warp through each of the three loop helpers in turn
    /// and checks after every step that it still reports the `"All"` set.
    #[test]
    fn test_all_approaches_preserve_warp_all() {
        let start: Warp<All> = Warp::new();

        let after_varying = varying_loop(start, PerLane::new(5), |_, _| {});
        assert_eq!(after_varying.active_set_name(), "All");

        let after_uniform = uniform_loop(after_varying, 5, PerLane::new(3), |_, _, _| {});
        assert_eq!(after_uniform.active_set_name(), "All");

        let after_phased = phased_loop(after_uniform, PerLane::new(10), 5, |_, _| {}, |_, _| {});
        assert_eq!(after_phased.active_set_name(), "All");
    }
}