use crate::active_set::All;
use crate::data::PerLane;
use crate::gpu::GpuShuffle;
use crate::warp::Warp;
use crate::GpuValue;
/// Warp-wide collectives defined for a [`Warp`] whose active set is [`All`].
impl Warp<All> {
    /// Folds `data` with `op` through a fixed sequence of XOR shuffles and
    /// returns the value the calling lane ends up with.
    ///
    /// Each step evaluates `op(val, val.gpu_shfl_xor(mask))` for the masks
    /// 16, 8, 4, 2 and 1, in that order. When the `warp64` feature is
    /// enabled, a mask-32 step runs before them.
    ///
    /// NOTE(review): the shuffled value is always passed as the second
    /// argument of `op`, so the result presumably only agrees across lanes
    /// when `op` is commutative and associative — confirm before passing
    /// any other operator.
    pub fn reduce<T, F>(&self, data: PerLane<T>, op: F) -> T
    where
        T: GpuValue + GpuShuffle,
        F: Fn(T, T) -> T,
    {
        let mut val = data.get();
        // The widest step is compiled in only for `warp64` builds.
        #[cfg(feature = "warp64")]
        {
            val = op(val, val.gpu_shfl_xor(32));
        }
        val = op(val, val.gpu_shfl_xor(16));
        val = op(val, val.gpu_shfl_xor(8));
        val = op(val, val.gpu_shfl_xor(4));
        val = op(val, val.gpu_shfl_xor(2));
        val = op(val, val.gpu_shfl_xor(1));
        val
    }

    /// Deprecated attempt at an inclusive prefix sum.
    ///
    /// Adds `val.gpu_shfl_up(delta)` to the running value for the deltas 1,
    /// 2, 4, 8 and 16. Every addition is unconditional; the deprecation note
    /// records that the missing lane-id guard makes the result incorrect on
    /// every target.
    ///
    /// NOTE(review): unlike [`Warp::reduce`], there is no additional
    /// delta-32 step for the `warp64` feature.
    #[deprecated(
        note = "Not correct on any target — Hillis-Steele without lane_id guard. Use SimWarp for tested scan."
    )]
    pub fn inclusive_sum<T>(&self, data: PerLane<T>) -> PerLane<T>
    where
        T: GpuValue + GpuShuffle + core::ops::Add<Output = T>,
    {
        let mut val = data.get();
        let s1 = val.gpu_shfl_up(1);
        val = val + s1;
        let s2 = val.gpu_shfl_up(2);
        val = val + s2;
        let s4 = val.gpu_shfl_up(4);
        val = val + s4;
        let s8 = val.gpu_shfl_up(8);
        val = val + s8;
        let s16 = val.gpu_shfl_up(16);
        val = val + s16;
        PerLane::new(val)
    }

    /// Deprecated attempt at an exclusive prefix sum.
    ///
    /// Runs the deprecated [`Warp::inclusive_sum`] and passes its result
    /// through `gpu_shfl_up(1)`. `identity` is accepted but never used, so
    /// no lane is seeded with it.
    // The previous note recommended `inclusive_sum`, which is itself
    // deprecated as incorrect; point at the same replacement it names.
    #[deprecated(
        note = "produces incorrect results — built on the deprecated inclusive_sum and ignores `identity`. Use SimWarp for tested scan."
    )]
    pub fn exclusive_sum<T>(&self, data: PerLane<T>, identity: T) -> PerLane<T>
    where
        T: GpuValue + GpuShuffle + core::ops::Add<Output = T>,
    {
        // Calling the deprecated sibling is deliberate: this method is
        // defined in terms of it.
        #[allow(deprecated)]
        let inclusive = self.inclusive_sum(data);
        let shifted = inclusive.get().gpu_shfl_up(1);
        // Explicitly discard the parameter so it does not trigger an
        // unused-variable warning.
        let _ = identity;
        PerLane::new(shifted)
    }

    /// Returns `self.reduce_sum(data).get()`.
    ///
    /// `reduce_sum` is not defined in this block; its exact semantics have
    /// to be checked where it is declared.
    pub fn reduce_add<T>(&self, data: PerLane<T>) -> T
    where
        T: GpuValue + GpuShuffle + core::ops::Add<Output = T>,
    {
        self.reduce_sum(data).get()
    }

    /// Runs [`Warp::reduce`] with an operator that keeps the larger of its
    /// two arguments (the first one when they compare equal).
    pub fn reduce_max<T>(&self, data: PerLane<T>) -> T
    where
        T: GpuValue + GpuShuffle + Ord,
    {
        self.reduce(data, |a, b| if a >= b { a } else { b })
    }

    /// Runs [`Warp::reduce`] with an operator that keeps the smaller of its
    /// two arguments (the first one when they compare equal).
    pub fn reduce_min<T>(&self, data: PerLane<T>) -> T
    where
        T: GpuValue + GpuShuffle + Ord,
    {
        self.reduce(data, |a, b| if a <= b { a } else { b })
    }

    /// Wraps the result of `gpu_shfl_idx(src_lane)` on this lane's value —
    /// presumably the value held by lane `src_lane`; confirm against the
    /// `GpuShuffle` implementation.
    ///
    /// # Panics
    ///
    /// Panics when `src_lane` is not below `crate::WARP_SIZE`.
    pub fn broadcast_lane<T: GpuValue + GpuShuffle>(
        &self,
        data: PerLane<T>,
        src_lane: u32,
    ) -> PerLane<T> {
        assert!(
            src_lane < crate::WARP_SIZE,
            "broadcast_lane: src_lane {src_lane} >= {}",
            crate::WARP_SIZE
        );
        PerLane::new(data.get().gpu_shfl_idx(src_lane))
    }

    /// Wraps the result of `gpu_shfl_up(delta)` on this lane's value.
    ///
    /// `delta` is forwarded as-is; unlike [`Warp::broadcast_lane`], no range
    /// check is performed here.
    pub fn shuffle_up<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, delta: u32) -> PerLane<T> {
        PerLane::new(data.get().gpu_shfl_up(delta))
    }
}
/// Smoke tests for the `Warp<All>` collectives.
///
/// Every test feeds the same value to the collective and checks the single
/// value that comes back on the host build.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::PerLane;

    /// Builds the fully-active warp handle every test starts from.
    fn full_warp() -> Warp<All> {
        Warp::kernel_entry()
    }

    /// A custom "keep the larger" operator leaves a uniform input unchanged.
    #[test]
    fn test_reduce_custom_op() {
        let warp = full_warp();
        let keep_larger = |a: i32, b: i32| if a > b { a } else { b };
        let reduced = warp.reduce(PerLane::new(5i32), keep_larger);
        assert_eq!(reduced, 5);
    }

    /// Summing a uniform `1` yields `WARP_SIZE`.
    #[test]
    fn test_reduce_add() {
        let warp = full_warp();
        let total = warp.reduce_add(PerLane::new(1i32));
        let expected = crate::WARP_SIZE as i32;
        assert_eq!(total, expected);
    }

    /// The maximum of a uniform input is that input.
    #[test]
    fn test_reduce_max() {
        let warp = full_warp();
        let largest = warp.reduce_max(PerLane::new(42i32));
        assert_eq!(largest, 42);
    }

    /// The minimum of a uniform input is that input.
    #[test]
    fn test_reduce_min() {
        let warp = full_warp();
        let smallest = warp.reduce_min(PerLane::new(7i32));
        assert_eq!(smallest, 7);
    }

    /// Pins the current output of the deprecated scan for a uniform `1`.
    #[test]
    #[allow(deprecated)]
    fn test_inclusive_sum() {
        let warp = full_warp();
        let scanned = warp.inclusive_sum(PerLane::new(1i32));
        assert_eq!(scanned.get(), 32);
    }

    /// Broadcasting from lane 0 returns the uniform input.
    #[test]
    fn test_broadcast_lane() {
        let warp = full_warp();
        let received = warp.broadcast_lane(PerLane::new(99i32), 0);
        assert_eq!(received.get(), 99);
    }

    /// Shifting a uniform input up by one lane returns the same value.
    #[test]
    fn test_shuffle_up() {
        let warp = full_warp();
        let shifted = warp.shuffle_up(PerLane::new(10i32), 1);
        assert_eq!(shifted.get(), 10);
    }

    /// Calls each collective once on a `Warp<All>`; only checks that the
    /// calls type-check and run, the results are discarded.
    #[test]
    #[allow(deprecated)]
    fn test_cub_requires_all() {
        let warp = full_warp();
        let ones = PerLane::new(1i32);
        let _ = warp.reduce_add(ones);
        let _ = warp.reduce_max(ones);
        let _ = warp.reduce_min(ones);
        let _ = warp.inclusive_sum(ones);
        let _ = warp.broadcast_lane(ones, 0);
        let _ = warp.shuffle_up(ones, 1);
    }
}