use crate::active_set::{sealed, All};
use crate::data::PerLane;
use crate::gpu::GpuShuffle;
use crate::warp::Warp;
use crate::GpuValue;
use core::marker::PhantomData;
/// Zero-sized handle for a partition of the warp into tiles of `SIZE` lanes.
///
/// Obtained from `Warp::<All>::tile` or by narrowing a wider tile with one of
/// the `partition_*` methods. Tile operations are only available when
/// `Tile<SIZE>: ValidTileSize`, which this module implements for
/// `SIZE` in {4, 8, 16, 32}.
///
/// The handle stores no data. It derives `Debug` so public APIs holding it stay
/// debuggable. It derives `Clone`/`Copy` because every method takes `&self` and
/// a warp can already mint any number of tiles, so copying grants nothing new.
#[must_use = "a Tile represents a partitioned warp — dropping it silently discards the partition"]
#[derive(Debug, Clone, Copy)]
pub struct Tile<const SIZE: usize> {
    _phantom: PhantomData<()>,
}
pub trait ValidTileSize: sealed::Sealed {
const TILE_MASK: u32;
}
#[allow(private_interfaces)]
impl sealed::Sealed for Tile<4> {
fn _sealed() -> sealed::SealToken {
sealed::SealToken
}
}
#[allow(private_interfaces)]
impl sealed::Sealed for Tile<8> {
fn _sealed() -> sealed::SealToken {
sealed::SealToken
}
}
#[allow(private_interfaces)]
impl sealed::Sealed for Tile<16> {
fn _sealed() -> sealed::SealToken {
sealed::SealToken
}
}
#[allow(private_interfaces)]
impl sealed::Sealed for Tile<32> {
fn _sealed() -> sealed::SealToken {
sealed::SealToken
}
}
impl ValidTileSize for Tile<4> {
const TILE_MASK: u32 = 0xF; }
impl ValidTileSize for Tile<8> {
const TILE_MASK: u32 = 0xFF; }
impl ValidTileSize for Tile<16> {
const TILE_MASK: u32 = 0xFFFF; }
impl ValidTileSize for Tile<32> {
const TILE_MASK: u32 = 0xFFFFFFFF; }
impl Warp<All> {
    /// Partitions this warp into tiles of `SIZE` lanes.
    ///
    /// Only available on `Warp<All>`. Only widths with a `ValidTileSize` impl
    /// (4, 8, 16, 32) type-check. The returned `Tile` is a zero-sized handle,
    /// so no runtime work is done here.
    pub fn tile<const SIZE: usize>(&self) -> Tile<SIZE>
    where
        Tile<SIZE>: ValidTileSize,
    {
        Tile::<SIZE> { _phantom: PhantomData }
    }
}
impl<const SIZE: usize> Tile<SIZE>
where
    Tile<SIZE>: ValidTileSize,
{
    /// XOR shuffle inside the tile.
    ///
    /// Forwards to `GpuShuffle::gpu_shfl_xor_width` with `mask` as the lane
    /// mask and `SIZE` as the segment width.
    ///
    /// # Panics
    ///
    /// Panics if `mask >= SIZE`.
    pub fn shuffle_xor<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, mask: u32) -> PerLane<T> {
        // Every ValidTileSize width is <= 32, so this cast cannot truncate.
        let width = SIZE as u32;
        assert!(
            mask < width,
            "shuffle_xor: mask {mask} >= tile SIZE {SIZE}"
        );
        let exchanged = data.get().gpu_shfl_xor_width(mask, width);
        PerLane::new(exchanged)
    }

    /// Down-shuffle inside the tile.
    ///
    /// Forwards to `GpuShuffle::gpu_shfl_down_width` with `delta` and `SIZE`
    /// as the segment width.
    ///
    /// # Panics
    ///
    /// Panics if `delta >= SIZE`.
    pub fn shuffle_down<T: GpuValue + GpuShuffle>(
        &self,
        data: PerLane<T>,
        delta: u32,
    ) -> PerLane<T> {
        let width = SIZE as u32;
        assert!(
            delta < width,
            "shuffle_down: delta {delta} >= tile SIZE {SIZE}"
        );
        let shifted = data.get().gpu_shfl_down_width(delta, width);
        PerLane::new(shifted)
    }

    /// Sums `data` across the tile with an XOR butterfly.
    ///
    /// Each step adds the value from the lane at XOR distance 1, 2, 4, … (all
    /// strictly below `SIZE`), so there are `log2(SIZE)` shuffle+add steps.
    /// Overflow behaviour is whatever `T`'s `Add` impl does.
    pub fn reduce_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
        &self,
        data: PerLane<T>,
    ) -> T {
        let width = SIZE as u32;
        let mut total = data.get();
        let mut lane_mask = 1u32;
        while lane_mask < width {
            total = total + total.gpu_shfl_xor_width(lane_mask, width);
            lane_mask <<= 1;
        }
        total
    }

    /// Hillis–Steele-shaped scan using `gpu_shfl_up_width` over strides
    /// 1, 2, 4, … below `SIZE`.
    ///
    /// Deprecated: the shuffled value is added on *every* lane, including lanes
    /// that have no predecessor at the current stride, because there is no
    /// lane-id guard. The result is therefore not a correct inclusive prefix sum.
    #[deprecated(
        note = "Not correct on any target — Hillis-Steele without lane_id guard. Use SimWarp for tested scan."
    )]
    pub fn inclusive_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
        &self,
        data: PerLane<T>,
    ) -> PerLane<T> {
        let width = SIZE as u32;
        let mut running = data.get();
        let mut offset = 1u32;
        while offset < width {
            let from_below = running.gpu_shfl_up_width(offset, width);
            running = running + from_below;
            offset <<= 1;
        }
        PerLane::new(running)
    }

    /// Number of lanes in each tile (the `SIZE` parameter).
    pub const fn size(&self) -> usize {
        SIZE
    }
}
/// Narrowing operations for a full 32-lane tile.
impl Tile<32> {
    /// Narrows to 16-lane tiles; only constructs a zero-sized handle.
    pub fn partition_16(&self) -> Tile<16> {
        Tile::<16> { _phantom: PhantomData }
    }

    /// Narrows to 8-lane tiles; only constructs a zero-sized handle.
    pub fn partition_8(&self) -> Tile<8> {
        Tile::<8> { _phantom: PhantomData }
    }

    /// Narrows to 4-lane tiles; only constructs a zero-sized handle.
    pub fn partition_4(&self) -> Tile<4> {
        Tile::<4> { _phantom: PhantomData }
    }
}
/// Narrowing operations for a 16-lane tile.
impl Tile<16> {
    /// Narrows to 8-lane tiles; only constructs a zero-sized handle.
    pub fn partition_8(&self) -> Tile<8> {
        Tile::<8> { _phantom: PhantomData }
    }

    /// Narrows to 4-lane tiles; only constructs a zero-sized handle.
    pub fn partition_4(&self) -> Tile<4> {
        Tile::<4> { _phantom: PhantomData }
    }
}
/// Narrowing operation for an 8-lane tile.
impl Tile<8> {
    /// Narrows to 4-lane tiles; only constructs a zero-sized handle.
    pub fn partition_4(&self) -> Tile<4> {
        Tile::<4> { _phantom: PhantomData }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::PerLane;

    /// Warp handle as obtained at kernel entry.
    fn entry_warp() -> Warp<All> {
        Warp::kernel_entry()
    }

    /// Reduces an all-ones input over a `Tile<N>` taken from a fresh warp.
    fn reduce_ones<const N: usize>() -> i32
    where
        Tile<N>: ValidTileSize,
    {
        let warp = entry_warp();
        let tile: Tile<N> = warp.tile();
        tile.reduce_sum(PerLane::new(1i32))
    }

    #[test]
    fn test_tile_from_warp() {
        let warp = entry_warp();
        let (t32, t16, t8, t4) = (
            warp.tile::<32>(),
            warp.tile::<16>(),
            warp.tile::<8>(),
            warp.tile::<4>(),
        );
        assert_eq!(t32.size(), 32);
        assert_eq!(t16.size(), 16);
        assert_eq!(t8.size(), 8);
        assert_eq!(t4.size(), 4);
    }

    #[test]
    fn test_tile_shuffle() {
        let warp = entry_warp();
        let tile: Tile<16> = warp.tile();
        let shuffled = tile.shuffle_xor(PerLane::new(42i32), 1);
        assert_eq!(shuffled.get(), 42);
    }

    #[test]
    fn test_tile_reduce() {
        assert_eq!(reduce_ones::<8>(), 8);
    }

    #[test]
    fn test_tile_reduce_32() {
        assert_eq!(reduce_ones::<32>(), 32);
    }

    #[test]
    fn test_tile_reduce_4() {
        assert_eq!(reduce_ones::<4>(), 4);
    }

    #[test]
    fn test_tile_sub_partition() {
        let warp = entry_warp();
        let full: Tile<32> = warp.tile();
        let half = full.partition_16();
        let quarter = half.partition_8();
        let eighth = quarter.partition_4();
        assert_eq!(half.size(), 16);
        assert_eq!(quarter.size(), 8);
        assert_eq!(eighth.size(), 4);
    }

    #[test]
    fn test_tile_shuffle_64bit() {
        let warp = entry_warp();
        let tile: Tile<16> = warp.tile();
        let shuffled = tile.shuffle_xor(PerLane::new(123456789_i64), 1);
        assert_eq!(shuffled.get(), 123456789_i64);
    }

    // NOTE(review): this pins the current (deprecated, known-incorrect) result.
    // With an all-ones input and three strides (1, 2, 4) the value doubles on
    // each step, giving 8. That is not the per-lane prefix count a real
    // inclusive scan would produce.
    #[test]
    #[allow(deprecated)]
    fn test_tile_inclusive_sum() {
        let warp = entry_warp();
        let tile: Tile<8> = warp.tile();
        let scanned = tile.inclusive_sum(PerLane::new(1i32));
        assert_eq!(scanned.get(), 8);
    }
}