use std::marker::PhantomData;
/// Index of a thread (lane) inside its warp.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct LaneId(pub u32);
/// Index of a warp inside its block.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct WarpId(pub u32);
/// Index of a block inside the grid.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct BlockId(pub u32);
/// Full coordinates of one thread: its block, its warp within that block,
/// and its lane within that warp.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ThreadId {
    /// Block containing the thread.
    pub block: BlockId,
    /// Warp inside `block`.
    pub warp: WarpId,
    /// Lane inside `warp`.
    pub lane: LaneId,
}
impl ThreadId {
    /// Flattens `(block, warp, lane)` into one linear thread index:
    /// `block * warps_per_block * lanes_per_warp + warp * lanes_per_warp + lane`.
    ///
    /// The arithmetic is plain `u32`, so overflow panics in debug builds and
    /// wraps in release builds. Nothing checks that `warp < warps_per_block` or
    /// `lane < lanes_per_warp`.
    pub fn global_id(&self, warps_per_block: u32, lanes_per_warp: u32) -> u32 {
        // Keep the left-to-right association of the original expression so the
        // overflow behaviour is identical.
        let block_base = self.block.0 * warps_per_block * lanes_per_warp;
        let warp_base = self.warp.0 * lanes_per_warp;
        block_base + warp_base + self.lane.0
    }
}
pub mod intra_warp {
    //! Operation interfaces for a single warp (scope taken from the module
    //! name). These are declarations only; nothing in this file implements them.

    /// Value exchange between lanes.
    ///
    /// NOTE(review): presumably mirrors CUDA's `__shfl_*_sync` family. Confirm
    /// the intended lane-selection semantics with the implementors.
    pub trait ShuffleOp {
        /// Shuffle `val` using an XOR `mask` on the lane index.
        fn shuffle_xor<T>(val: T, mask: u32) -> T
        where
            T: Copy;

        /// Shuffle `val` from a lane `delta` positions away (downward).
        fn shuffle_down<T>(val: T, delta: u32) -> T
        where
            T: Copy;

        /// Shuffle `val` from a lane `delta` positions away (upward).
        fn shuffle_up<T>(val: T, delta: u32) -> T
        where
            T: Copy;
    }

    /// Predicate voting across lanes.
    pub trait VoteOp {
        /// True when the predicate holds for every participant.
        fn all(pred: bool) -> bool;
        /// True when the predicate holds for at least one participant.
        fn any(pred: bool) -> bool;
        /// Bit mask of participants whose predicate holds (presumably one bit
        /// per lane).
        fn ballot(pred: bool) -> u32;
    }
}
pub mod intra_block {
    use super::*;

    /// Typed handle for a buffer of `SIZE` elements of `T` (named for
    /// block-shared memory).
    ///
    /// NOTE(review): placeholder. It stores nothing, `read` always returns
    /// `T::default()`, and `write` discards its arguments.
    pub struct SharedMem<T, const SIZE: usize> {
        _marker: PhantomData<T>,
    }

    impl<T, const SIZE: usize> SharedMem<T, SIZE> {
        /// Creates a new handle.
        ///
        /// `_marker` is private, so this (or `Default`) is the only way code
        /// outside this module can obtain a `SharedMem`.
        pub fn new() -> Self {
            SharedMem {
                _marker: PhantomData,
            }
        }
    }

    impl<T, const SIZE: usize> Default for SharedMem<T, SIZE> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T: Copy, const SIZE: usize> SharedMem<T, SIZE> {
        /// Reads element `_idx`. Placeholder: ignores the index and returns
        /// `T::default()`.
        pub fn read(&self, _idx: usize) -> T
        where
            T: Default,
        {
            T::default()
        }

        /// Writes `_val` at `_idx`. Placeholder: does nothing.
        pub fn write(&mut self, _idx: usize, _val: T) {}
    }

    /// Barrier stand-in; the body is currently empty, so this is a no-op.
    pub fn sync_threads() {}

    /// Typestate barrier. `sync` exists only on `BlockBarrier<BeforeSync>` and
    /// yields `BlockBarrier<AfterSync>`.
    pub struct BlockBarrier<State>(PhantomData<State>);

    /// Barrier state before `sync` has been called.
    pub struct BeforeSync;

    /// Barrier state after `sync` has been called.
    pub struct AfterSync;

    impl BlockBarrier<BeforeSync> {
        /// Creates a barrier in the `BeforeSync` state.
        ///
        /// The tuple field is private, so this (or `Default`) is the only way
        /// code outside this module can obtain a barrier.
        pub fn new() -> Self {
            BlockBarrier(PhantomData)
        }

        /// Calls [`sync_threads`] and moves the barrier to `AfterSync`.
        pub fn sync(self) -> BlockBarrier<AfterSync> {
            sync_threads();
            BlockBarrier(PhantomData)
        }
    }

    impl Default for BlockBarrier<BeforeSync> {
        fn default() -> Self {
            Self::new()
        }
    }
}
pub mod inter_block {
    use super::*;

    /// Typed handle for a buffer of `T` (named for device-global memory).
    ///
    /// NOTE(review): placeholder. It stores nothing, reads and atomics return
    /// `T::default()`, and writes are discarded.
    pub struct GlobalMem<T> {
        _marker: PhantomData<T>,
    }

    impl<T> GlobalMem<T> {
        /// Creates a new handle.
        ///
        /// `_marker` is private, so this (or `Default`) is the only way code
        /// outside this module can obtain a `GlobalMem`. Without it, nothing
        /// that takes a `GlobalMem` argument could ever be called.
        pub fn new() -> Self {
            GlobalMem {
                _marker: PhantomData,
            }
        }
    }

    impl<T> Default for GlobalMem<T> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T: Copy + Default> GlobalMem<T> {
        /// Reads element `_idx`. Placeholder: returns `T::default()`.
        pub fn read(&self, _idx: usize) -> T {
            T::default()
        }

        /// Writes `_val` at `_idx`. Placeholder: does nothing.
        pub fn write(&mut self, _idx: usize, _val: T) {}

        /// Adds `_val` to element `_idx`. Takes `&self`, which suggests atomic
        /// semantics.
        ///
        /// NOTE(review): placeholder. It returns `T::default()`. A real
        /// fetch-add would presumably return the previous value; confirm.
        pub fn atomic_add(&self, _idx: usize, _val: T) -> T
        where
            T: std::ops::Add<Output = T>,
        {
            T::default()
        }
    }

    /// Memory-fence stand-in; currently a no-op (empty body).
    pub fn thread_fence_system() {}

    /// Grid-wide barrier stand-in; currently a no-op (empty body).
    pub fn grid_sync() {}
}
/// Compile-time role tag for a block taking part in a `BlockSession`.
pub trait BlockRole {
    /// Human-readable role name.
    const NAME: &'static str;
}

/// Role whose sessions expose `broadcast` and `wait_for_workers`.
pub struct Leader;

/// Role whose sessions expose `receive` and `signal_done`.
pub struct Worker;

impl BlockRole for Leader {
    const NAME: &'static str = "Leader";
}

impl BlockRole for Worker {
    const NAME: &'static str = "Worker";
}
/// Marker trait for the typestates a `BlockSession` can be in.
pub trait ProtocolState {}

/// Starting state of a session.
pub struct Initial;

/// State reached after the leader's `broadcast` or a worker's `receive`.
pub struct WorkDistributed;

/// State reached after the leader's `wait_for_workers` or a worker's `signal_done`.
pub struct WorkComplete;

/// Final state; no transition visible in this file produces it.
pub struct ResultsCollected;

impl ProtocolState for Initial {}
impl ProtocolState for WorkDistributed {}
impl ProtocolState for WorkComplete {}
impl ProtocolState for ResultsCollected {}
/// Typestate handle for one block's side of a leader/worker protocol.
///
/// `Role` decides which transitions exist, and `State` records the protocol
/// step reached. `N` appears only at the type level and is not read anywhere
/// in this file (presumably a participant count; TODO confirm).
pub struct BlockSession<Role: BlockRole, State: ProtocolState, const N: usize> {
    /// Block this session belongs to.
    block_id: BlockId,
    /// Zero-sized role tag.
    _role: PhantomData<Role>,
    /// Zero-sized protocol-state tag.
    _state: PhantomData<State>,
}
impl<Role: BlockRole, State: ProtocolState, const N: usize> BlockSession<Role, State, N> {
    /// Creates a session for `block_id` in whatever `State` the caller names.
    ///
    /// NOTE(review): because this is generic over `State`, callers can start a
    /// session directly in any protocol state and skip earlier transitions.
    pub fn new(block_id: BlockId) -> Self {
        let (_role, _state) = (PhantomData, PhantomData);
        Self {
            block_id,
            _role,
            _state,
        }
    }

    /// Returns the block this session belongs to.
    pub fn block_id(&self) -> BlockId {
        let id = self.block_id;
        id
    }
}
impl<State: ProtocolState, const N: usize> BlockSession<Leader, State, N> {
pub fn broadcast<T: Copy>(
self,
_data: T,
_global: &mut inter_block::GlobalMem<T>,
) -> BlockSession<Leader, WorkDistributed, N>
where
State: Into<Initial>, {
BlockSession::new(self.block_id)
}
pub fn wait_for_workers(
self,
_signal: &inter_block::GlobalMem<u32>,
) -> BlockSession<Leader, WorkComplete, N>
where
State: Into<WorkDistributed>,
{
BlockSession::new(self.block_id)
}
}
impl<State: ProtocolState, const N: usize> BlockSession<Worker, State, N> {
pub fn receive<T: Copy + Default>(
self,
_global: &inter_block::GlobalMem<T>,
) -> (T, BlockSession<Worker, WorkDistributed, N>)
where
State: Into<Initial>,
{
let data: T = T::default(); (data, BlockSession::new(self.block_id))
}
pub fn signal_done(
self,
_signal: &inter_block::GlobalMem<u32>,
) -> BlockSession<Worker, WorkComplete, N>
where
State: Into<WorkDistributed>,
{
BlockSession::new(self.block_id)
}
}
pub mod cooperative {
    use super::*;

    /// Common interface for the thread groups defined in this module.
    pub trait CooperativeGroup {
        /// Number of threads in the group.
        fn size(&self) -> u32;
        /// Position of a thread within the group. Every implementation in this
        /// module is a placeholder that returns `0`.
        fn thread_rank(&self) -> u32;
        /// Synchronizes the group. Some implementations below are no-ops.
        fn sync(&self);
    }

    /// Group of `num_threads` threads tied to one block.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct ThreadBlockGroup {
        block_id: BlockId,
        num_threads: u32,
    }

    impl ThreadBlockGroup {
        /// Creates a group for `block_id` with `num_threads` members.
        ///
        /// The fields are private, so this is the only way code outside this
        /// module can build one.
        pub fn new(block_id: BlockId, num_threads: u32) -> Self {
            ThreadBlockGroup {
                block_id,
                num_threads,
            }
        }

        /// Block this group belongs to.
        pub fn block_id(&self) -> BlockId {
            self.block_id
        }
    }

    impl CooperativeGroup for ThreadBlockGroup {
        fn size(&self) -> u32 {
            self.num_threads
        }

        fn thread_rank(&self) -> u32 {
            0
        }

        fn sync(&self) {
            intra_block::sync_threads();
        }
    }

    /// Group spanning `num_blocks` blocks of `threads_per_block` threads each.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct GridGroup {
        num_blocks: u32,
        threads_per_block: u32,
    }

    impl GridGroup {
        /// Creates a grid group. The fields are private, so this is the only
        /// way code outside this module can build one.
        pub fn new(num_blocks: u32, threads_per_block: u32) -> Self {
            GridGroup {
                num_blocks,
                threads_per_block,
            }
        }
    }

    impl CooperativeGroup for GridGroup {
        // Plain `u32` product: panics on overflow in debug, wraps in release.
        fn size(&self) -> u32 {
            self.num_blocks * self.threads_per_block
        }

        fn thread_rank(&self) -> u32 {
            0
        }

        fn sync(&self) {
            inter_block::grid_sync();
        }
    }

    /// Group whose members are the set bits of `mask`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct CoalescedGroup {
        mask: u32,
    }

    impl CoalescedGroup {
        /// Creates a group from a membership bit mask. The field is private,
        /// so this is the only way code outside this module can build one.
        pub fn new(mask: u32) -> Self {
            CoalescedGroup { mask }
        }
    }

    impl CooperativeGroup for CoalescedGroup {
        fn size(&self) -> u32 {
            self.mask.count_ones()
        }

        fn thread_rank(&self) -> u32 {
            0
        }

        // NOTE(review): no-op; presumably should synchronize the lanes in `mask`.
        fn sync(&self) {}
    }

    /// Tile whose size is the compile-time constant `SIZE`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub struct TiledPartition<const SIZE: u32> {
        _marker: PhantomData<()>,
    }

    impl<const SIZE: u32> TiledPartition<SIZE> {
        /// Creates a tile handle. `_marker` is private, so this (or `Default`)
        /// is the only way code outside this module can build one.
        pub fn new() -> Self {
            TiledPartition {
                _marker: PhantomData,
            }
        }
    }

    impl<const SIZE: u32> Default for TiledPartition<SIZE> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<const SIZE: u32> CooperativeGroup for TiledPartition<SIZE> {
        fn size(&self) -> u32 {
            SIZE
        }

        fn thread_rank(&self) -> u32 {
            0
        }

        // NOTE(review): no-op placeholder.
        fn sync(&self) {}
    }
}
/// Describes one level of the execution hierarchy a session can run at.
pub trait SessionLevel {
    /// Upper bound on the number of participants at this level.
    const MAX_PARTICIPANTS: u32;

    /// Identifier for one participant at this level.
    type Id: Copy + Eq;

    /// Communication mechanism for this level (opaque here).
    type CommPrimitive;

    /// Synchronization mechanism for this level (opaque here).
    type SyncPrimitive;
}
/// Warp level: participants are identified by `LaneId`, with at most 32 of them.
pub struct WarpLevel;

impl SessionLevel for WarpLevel {
    const MAX_PARTICIPANTS: u32 = 32;
    type Id = LaneId;
    type CommPrimitive = ();
    type SyncPrimitive = ();
}
/// Block level: participants are identified by `WarpId`, with at most 32 of them.
pub struct BlockLevel;

impl SessionLevel for BlockLevel {
    const MAX_PARTICIPANTS: u32 = 32;
    type Id = WarpId;
    type CommPrimitive = ();
    type SyncPrimitive = ();
}
/// Grid level: participants are identified by `BlockId`, with at most 65535 of them.
///
/// NOTE(review): 65535 matches CUDA's y/z grid-dimension limit, while the x
/// dimension allows far more blocks. Confirm which bound is intended.
pub struct GridLevel;

impl SessionLevel for GridLevel {
    const MAX_PARTICIPANTS: u32 = 65535;
    type Id = BlockId;
    type CommPrimitive = ();
    type SyncPrimitive = ();
}
/// Three-level reduction interface: warp, block, and grid (presumably applied
/// in that order).
///
/// Declaration only; nothing in this file implements it.
pub trait HierarchicalProtocol {
    /// Result of the warp-level reduction.
    fn warp_reduce(&self) -> u32;

    /// Result of the block-level reduction.
    fn block_reduce(&self) -> u32;

    /// Result of the grid-level reduction.
    fn grid_reduce(&self) -> u32;
}
pub mod hierarchical_reduce {
    use super::*;

    /// Phase tag: the warp step has not run yet.
    pub struct WarpPhase;
    /// Phase tag: ready for the block step.
    pub struct BlockPhase;
    /// Phase tag: ready for the grid step.
    pub struct GridPhase;
    /// Phase tag: all steps have run; only `result` remains.
    pub struct Complete;

    /// Carries a `u32` through the warp, block, and grid phases. Each phase
    /// method exists only in its own phase and consumes the session.
    ///
    /// NOTE(review): every phase currently passes `value` through unchanged.
    pub struct ReductionSession<Phase> {
        value: u32,
        _phase: PhantomData<Phase>,
    }

    impl<Phase> ReductionSession<Phase> {
        /// Re-tags the session with the `Next` phase and keeps the value,
        /// returning that value alongside the new session.
        fn advance<Next>(self) -> (u32, ReductionSession<Next>) {
            let carried = self.value;
            let next = ReductionSession {
                value: carried,
                _phase: PhantomData,
            };
            (carried, next)
        }
    }

    impl ReductionSession<WarpPhase> {
        /// Starts a session holding `value`.
        pub fn new(value: u32) -> Self {
            Self {
                value,
                _phase: PhantomData,
            }
        }

        /// Warp step: yields the carried value and moves to `BlockPhase`.
        pub fn warp_reduce(self) -> (u32, ReductionSession<BlockPhase>) {
            self.advance()
        }
    }

    impl ReductionSession<BlockPhase> {
        /// Block step: yields the carried value and moves to `GridPhase`.
        pub fn block_reduce(self) -> (u32, ReductionSession<GridPhase>) {
            self.advance()
        }
    }

    impl ReductionSession<GridPhase> {
        /// Grid step: yields the carried value and moves to `Complete`.
        pub fn grid_reduce(self) -> (u32, ReductionSession<Complete>) {
            self.advance()
        }
    }

    impl ReductionSession<Complete> {
        /// Consumes the finished session and returns its value.
        pub fn result(self) -> u32 {
            self.value
        }
    }
}
#[cfg(test)]
mod tests {
    use super::hierarchical_reduce::*;
    use super::*;

    /// `global_id` flattens (block, warp, lane) with lanes varying fastest.
    #[test]
    fn test_thread_id() {
        let (warps_per_block, lanes_per_warp) = (4, 32);
        let tid = ThreadId {
            block: BlockId(2),
            warp: WarpId(3),
            lane: LaneId(7),
        };
        let expected = 2 * warps_per_block * lanes_per_warp + 3 * lanes_per_warp + 7;
        assert_eq!(tid.global_id(warps_per_block, lanes_per_warp), expected);
    }

    /// Each reduction phase hands the seed value through to the next.
    #[test]
    fn test_hierarchical_reduction_types() {
        let seed = 42;
        let (warp_result, block_phase) = ReductionSession::<WarpPhase>::new(seed).warp_reduce();
        let (block_result, grid_phase) = block_phase.block_reduce();
        let (grid_result, complete) = grid_phase.grid_reduce();
        let final_result = complete.result();
        for observed in [warp_result, block_result, grid_result, final_result].iter() {
            assert_eq!(*observed, seed);
        }
    }

    /// Sessions remember the block they were created for.
    #[test]
    fn test_block_session_state_transitions() {
        let leader = BlockSession::<Leader, Initial, 4>::new(BlockId(0));
        let worker = BlockSession::<Worker, Initial, 4>::new(BlockId(1));
        assert_eq!(leader.block_id(), BlockId(0));
        assert_eq!(worker.block_id(), BlockId(1));
    }
}