use std::marker::PhantomData;
use crate::WARP_SIZE;
/// Type-level description of how the lanes of one warp address memory.
///
/// Implementors must be `Copy + 'static`. Both items are associated functions
/// (no `self`), so a pattern is queried through its type alone and never needs
/// a value.
pub trait AccessPattern: Copy + 'static {
    /// Short, fixed label identifying the pattern.
    fn name() -> &'static str;

    /// Number of transactions this pattern is modelled to cost for a single
    /// warp-wide access.
    fn transactions_per_warp() -> usize;
}
/// Marker for the "Uniform" pattern.
///
/// `load::uniform` / `store::uniform` use it for accesses that touch only the
/// single element at the base pointer.
#[derive(Clone, Copy)]
pub struct Uniform;

impl AccessPattern for Uniform {
    /// Always `"Uniform"`.
    fn name() -> &'static str {
        "Uniform"
    }

    /// Modelled as a single transaction per warp.
    fn transactions_per_warp() -> usize {
        1
    }
}
/// Marker for the "Consecutive" pattern.
///
/// `load::consecutive` / `store::consecutive` use it for accesses where lane
/// `i` touches element `base + i`.
#[derive(Clone, Copy)]
pub struct Consecutive;

impl AccessPattern for Consecutive {
    /// Always `"Consecutive"`.
    fn name() -> &'static str {
        "Consecutive"
    }

    /// Modelled as a single transaction per warp.
    fn transactions_per_warp() -> usize {
        1
    }
}
/// Marker for the "Strided" pattern, with the stride carried as a const
/// generic parameter.
#[derive(Clone, Copy)]
pub struct Strided<const STRIDE: usize>;

impl<const STRIDE: usize> AccessPattern for Strided<STRIDE> {
    /// Always `"Strided"`; the stride value is not part of the label.
    fn name() -> &'static str {
        "Strided"
    }

    /// `STRIDE` capped at `WARP_SIZE`, but never less than one.
    fn transactions_per_warp() -> usize {
        let warp_lanes = WARP_SIZE as usize;
        // Cap first, then apply the floor of one, mirroring
        // `max(1, min(WARP_SIZE, STRIDE))`.
        let capped = STRIDE.min(warp_lanes);
        capped.max(1)
    }
}
/// Marker for the "Random" pattern: the most expensive pattern in this model.
#[derive(Clone, Copy)]
pub struct Random;

impl AccessPattern for Random {
    /// Always `"Random"`.
    fn name() -> &'static str {
        "Random"
    }

    /// Modelled as one transaction per lane, i.e. `WARP_SIZE` in total.
    fn transactions_per_warp() -> usize {
        let warp_lanes = WARP_SIZE as usize;
        warp_lanes
    }
}
/// Read-only raw pointer tagged at the type level with an [`AccessPattern`].
///
/// The pattern `P` is stored only as `PhantomData`, so the wrapper holds
/// nothing but the base pointer.
#[derive(Clone)]
pub struct WarpPtr<T, P>
where
    T: Copy,
    P: AccessPattern,
{
    /// Address of the first element.
    base: *const T,
    /// Zero-sized tag recording the access pattern.
    _pattern: PhantomData<P>,
}

impl<T, P> WarpPtr<T, P>
where
    T: Copy,
    P: AccessPattern,
{
    /// Wraps `base` without performing any checks.
    ///
    /// # Safety
    ///
    /// The safe functions in the `load` module dereference this pointer
    /// without further validation (`load::uniform` reads one element,
    /// `load::consecutive` reads `WARP_SIZE` elements starting at `base`).
    /// The caller must guarantee `base` is valid for the reads implied by `P`
    /// for as long as the returned value is used.
    pub unsafe fn new(base: *const T) -> Self {
        Self {
            base,
            _pattern: PhantomData,
        }
    }

    /// Returns the wrapped base pointer.
    pub fn base(&self) -> *const T {
        self.base
    }

    /// Label of the access pattern `P`.
    pub fn pattern_name() -> &'static str {
        <P as AccessPattern>::name()
    }

    /// Transaction count reported by the access pattern `P`.
    pub fn expected_transactions() -> usize {
        <P as AccessPattern>::transactions_per_warp()
    }
}
/// Mutable raw pointer tagged at the type level with an [`AccessPattern`].
///
/// Writable counterpart of `WarpPtr`; the pattern `P` is stored only as
/// `PhantomData`.
pub struct WarpPtrMut<T: Copy, P: AccessPattern> {
    /// Address of the first element.
    base: *mut T,
    /// Zero-sized tag recording the access pattern.
    _pattern: PhantomData<P>,
}

impl<T: Copy, P: AccessPattern> WarpPtrMut<T, P> {
    /// Wraps `base` without performing any checks.
    ///
    /// # Safety
    ///
    /// The safe functions in the `store` module write through this pointer
    /// without further validation (`store::uniform` writes one element,
    /// `store::consecutive` writes `WARP_SIZE` elements starting at `base`).
    /// The caller must guarantee `base` is valid for the writes implied by `P`
    /// for as long as the returned value is used.
    pub unsafe fn new(base: *mut T) -> Self {
        Self {
            base,
            _pattern: PhantomData,
        }
    }

    /// Returns the wrapped base pointer.
    pub fn base(&self) -> *mut T {
        self.base
    }

    /// Label of the access pattern `P`.
    ///
    /// Mirrors `WarpPtr::pattern_name` so both pointer flavours expose the
    /// same pattern queries.
    pub fn pattern_name() -> &'static str {
        P::name()
    }

    /// Transaction count reported by the access pattern `P`.
    ///
    /// Mirrors `WarpPtr::expected_transactions`.
    pub fn expected_transactions() -> usize {
        P::transactions_per_warp()
    }
}
/// Type-level combination of two access patterns.
///
/// `<A as WorstOf<B>>::Result` names the pattern chosen for the pair
/// `(A, B)`; the concrete choices are given by the individual impls.
pub trait WorstOf<Other>: AccessPattern
where
    Other: AccessPattern,
{
    /// Pattern selected for `Self` combined with `Other`.
    type Result: AccessPattern;
}
// `WorstOf` table. Rows are grouped by the left-hand (`Self`) pattern; within
// the table the result follows the ordering
// Uniform < Consecutive < Strided<S> < Random.

// --- Uniform combined with ... ---
impl WorstOf<Uniform> for Uniform {
    type Result = Uniform;
}
impl WorstOf<Consecutive> for Uniform {
    type Result = Consecutive;
}
impl<const S: usize> WorstOf<Strided<S>> for Uniform {
    type Result = Strided<S>;
}
impl WorstOf<Random> for Uniform {
    type Result = Random;
}

// --- Consecutive combined with ... ---
impl WorstOf<Uniform> for Consecutive {
    type Result = Consecutive;
}
impl WorstOf<Consecutive> for Consecutive {
    type Result = Consecutive;
}
impl<const S: usize> WorstOf<Strided<S>> for Consecutive {
    type Result = Strided<S>;
}
impl WorstOf<Random> for Consecutive {
    type Result = Random;
}

// --- Strided<S> combined with ... ---
// Mirrors the `Strided<S>` column of the other rows so the relation is usable
// with `Strided` on either side. Two *different* strides are intentionally not
// covered: picking the larger one needs const-generic arithmetic in an
// associated type, which stable Rust cannot express.
impl<const S: usize> WorstOf<Uniform> for Strided<S> {
    type Result = Strided<S>;
}
impl<const S: usize> WorstOf<Consecutive> for Strided<S> {
    type Result = Strided<S>;
}
impl<const S: usize> WorstOf<Strided<S>> for Strided<S> {
    type Result = Strided<S>;
}
impl<const S: usize> WorstOf<Random> for Strided<S> {
    type Result = Random;
}

// --- Random combined with ... ---
// Random dominates everything.
impl WorstOf<Uniform> for Random {
    type Result = Random;
}
impl WorstOf<Consecutive> for Random {
    type Result = Random;
}
impl<const S: usize> WorstOf<Strided<S>> for Random {
    type Result = Random;
}
impl WorstOf<Random> for Random {
    type Result = Random;
}
/// Pattern-specific reads through a [`WarpPtr`].
pub mod load {
    use super::*;

    /// Reads the single element at the base pointer.
    pub fn uniform<T: Copy>(ptr: &WarpPtr<T, Uniform>) -> T {
        let src = ptr.base();
        // SAFETY: relies on the contract accepted by the caller of the
        // `unsafe` constructor `WarpPtr::new`: `base` is readable for one `T`.
        unsafe { src.read() }
    }

    /// Reads `WARP_SIZE` elements, with lane `i` taking `base + i`.
    ///
    /// The output array is first filled with `T::default()` and then
    /// overwritten lane by lane.
    pub fn consecutive<T: Copy + Default>(
        ptr: &WarpPtr<T, Consecutive>,
    ) -> [T; WARP_SIZE as usize] {
        let mut gathered = [T::default(); WARP_SIZE as usize];
        let first = ptr.base();
        for (lane, slot) in gathered.iter_mut().enumerate() {
            // SAFETY: relies on the `WarpPtr::new` contract: `base` is
            // readable for `WARP_SIZE` consecutive elements.
            *slot = unsafe { first.add(lane).read() };
        }
        gathered
    }

    /// Reads one element per lane at `base + indices[lane]`, for any pattern.
    ///
    /// # Safety
    ///
    /// No bounds are checked. Every `base + indices[lane]` must be in bounds
    /// of the allocation behind `ptr` and valid for a read of `T`.
    pub unsafe fn generic<T: Copy + Default, P: AccessPattern>(
        ptr: &WarpPtr<T, P>,
        indices: &[usize; WARP_SIZE as usize],
    ) -> [T; WARP_SIZE as usize] {
        let mut gathered = [T::default(); WARP_SIZE as usize];
        let first = ptr.base();
        for (slot, &offset) in gathered.iter_mut().zip(indices.iter()) {
            // SAFETY: the caller guarantees every offset is readable (see the
            // `# Safety` section above).
            *slot = unsafe { first.add(offset).read() };
        }
        gathered
    }
}
/// Pattern-specific writes through a [`WarpPtrMut`].
pub mod store {
    use super::*;

    /// Writes `value` to the single element at the base pointer.
    pub fn uniform<T: Copy>(ptr: &mut WarpPtrMut<T, Uniform>, value: T) {
        let dst = ptr.base();
        // SAFETY: relies on the contract accepted by the caller of the
        // `unsafe` constructor `WarpPtrMut::new`: `base` is writable for one
        // `T`.
        unsafe { dst.write(value) };
    }

    /// Writes `WARP_SIZE` elements, with lane `i` storing `values[i]` at
    /// `base + i`.
    pub fn consecutive<T: Copy>(
        ptr: &mut WarpPtrMut<T, Consecutive>,
        values: &[T; WARP_SIZE as usize],
    ) {
        let first = ptr.base();
        for (lane, &value) in values.iter().enumerate() {
            // SAFETY: relies on the `WarpPtrMut::new` contract: `base` is
            // writable for `WARP_SIZE` consecutive elements.
            unsafe { first.add(lane).write(value) };
        }
    }
}
/// Classifies an index expression into the name of an access pattern.
pub mod infer {
    /// Shape of the per-lane index used by a memory access.
    pub enum IndexExpr {
        /// The same constant index for every lane.
        Constant(usize),
        /// The lane id itself.
        LaneId,
        /// The lane id multiplied by a constant factor.
        LaneIdTimes(usize),
        /// The lane id plus a constant offset.
        LaneIdPlus(usize),
        /// Any other index computation.
        Computed,
    }

    /// Returns the pattern label (`"Uniform"`, `"Consecutive"`, `"Strided"`
    /// or `"Random"`) for `expr`.
    ///
    /// Degenerate multipliers are folded into the simpler pattern they
    /// actually describe: a factor of `0` is a constant index and a factor of
    /// `1` is the lane id unchanged.
    pub fn pattern_from_index(expr: &IndexExpr) -> &'static str {
        match expr {
            // `lane_id * 0` is index 0 for every lane, i.e. a constant index,
            // not a stride.
            IndexExpr::Constant(_) | IndexExpr::LaneIdTimes(0) => "Uniform",
            // A constant offset shifts the window but keeps lanes adjacent.
            IndexExpr::LaneId | IndexExpr::LaneIdTimes(1) | IndexExpr::LaneIdPlus(_) => {
                "Consecutive"
            }
            IndexExpr::LaneIdTimes(_) => "Strided",
            IndexExpr::Computed => "Random",
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_pattern_inference() {
            assert_eq!(pattern_from_index(&IndexExpr::Constant(0)), "Uniform");
            assert_eq!(pattern_from_index(&IndexExpr::LaneId), "Consecutive");
            assert_eq!(pattern_from_index(&IndexExpr::LaneIdTimes(4)), "Strided");
            assert_eq!(pattern_from_index(&IndexExpr::Computed), "Random");
        }

        #[test]
        fn test_degenerate_multipliers() {
            assert_eq!(pattern_from_index(&IndexExpr::LaneIdTimes(0)), "Uniform");
            assert_eq!(pattern_from_index(&IndexExpr::LaneIdTimes(1)), "Consecutive");
            assert_eq!(pattern_from_index(&IndexExpr::LaneIdPlus(3)), "Consecutive");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each built-in pattern reports its modelled transaction count.
    #[test]
    fn test_transaction_counts() {
        let warp_lanes = WARP_SIZE as usize;
        assert_eq!(<Uniform as AccessPattern>::transactions_per_warp(), 1);
        assert_eq!(<Consecutive as AccessPattern>::transactions_per_warp(), 1);
        assert_eq!(<Strided<4> as AccessPattern>::transactions_per_warp(), 4);
        assert_eq!(<Random as AccessPattern>::transactions_per_warp(), warp_lanes);
    }

    /// A uniform load returns the element at the base pointer.
    #[test]
    fn test_uniform_load() {
        let backing = [42i32; 1];
        // SAFETY: `backing` holds one element and outlives `ptr`.
        let ptr = unsafe { WarpPtr::<i32, Uniform>::new(backing.as_ptr()) };
        assert_eq!(load::uniform(&ptr), 42);
    }

    /// A consecutive load returns element `lane` for every lane.
    #[test]
    fn test_consecutive_load() {
        let backing: [i32; WARP_SIZE as usize] = core::array::from_fn(|lane| lane as i32);
        // SAFETY: `backing` holds `WARP_SIZE` elements and outlives `ptr`.
        let ptr = unsafe { WarpPtr::<i32, Consecutive>::new(backing.as_ptr()) };
        let loaded = load::consecutive(&ptr);
        for (lane, value) in loaded.iter().enumerate() {
            assert_eq!(*value, lane as i32);
        }
    }

    /// `WorstOf` resolves to the costlier of the two patterns.
    #[test]
    fn test_pattern_hierarchy() {
        type UniformVsConsecutive = <Uniform as WorstOf<Consecutive>>::Result;
        type ConsecutiveVsRandom = <Consecutive as WorstOf<Random>>::Result;

        assert_eq!(UniformVsConsecutive::name(), "Consecutive");
        assert_eq!(ConsecutiveVsRandom::name(), "Random");
    }
}
/// Public unit-typed constant; it carries no data.
// NOTE(review): nothing in this file reads `_SUMMARY`, so its purpose cannot be
// determined from here — confirm whether other code depends on it.
pub const _SUMMARY: () = ();