use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Barrier, RwLock};
/// Launch configuration for a simulated kernel: grid/block dimensions,
/// per-block shared-memory budget, and optional warp-simulation settings.
#[derive(Debug, Clone)]
pub struct MockKernelConfig {
    pub grid_dim: (u32, u32, u32),
    pub block_dim: (u32, u32, u32),
    pub shared_memory_size: usize,
    pub simulate_warps: bool,
    pub warp_size: u32,
}

impl Default for MockKernelConfig {
    /// One block of 256 threads, 48 KiB shared memory, warp simulation off.
    fn default() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_memory_size: 49152,
            simulate_warps: false,
            warp_size: 32,
        }
    }
}

impl MockKernelConfig {
    /// Equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set the grid dimensions (blocks per axis).
    pub fn with_grid_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    /// Builder: set the block dimensions (threads per axis).
    pub fn with_block_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    /// Builder: set the shared-memory budget in bytes.
    pub fn with_shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory_size = bytes;
        self
    }

    /// Builder: enable warp simulation with the given warp size.
    pub fn with_warp_simulation(mut self, warp_size: u32) -> Self {
        self.simulate_warps = true;
        self.warp_size = warp_size;
        self
    }

    /// Total simulated threads in the launch (u64 arithmetic so large
    /// grids do not overflow).
    pub fn total_threads(&self) -> u64 {
        let (gx, gy, gz) = self.grid_dim;
        let (bx, by, bz) = self.block_dim;
        (gx as u64 * gy as u64 * gz as u64) * (bx as u64 * by as u64 * bz as u64)
    }

    /// Threads per block (product of the block dimensions).
    pub fn threads_per_block(&self) -> u32 {
        let (x, y, z) = self.block_dim;
        x * y * z
    }

    /// Blocks per launch (product of the grid dimensions).
    pub fn total_blocks(&self) -> u32 {
        let (x, y, z) = self.grid_dim;
        x * y * z
    }
}
/// Per-thread view of a launch: CUDA-style thread/block indices, the launch
/// dimensions, and this thread's warp placement within its block.
#[derive(Debug, Clone)]
pub struct MockThread {
    pub thread_idx: (u32, u32, u32),
    pub block_idx: (u32, u32, u32),
    pub block_dim: (u32, u32, u32),
    pub grid_dim: (u32, u32, u32),
    pub warp_id: u32,
    pub lane_id: u32,
    pub warp_size: u32,
}

impl MockThread {
    /// Build the context for `thread_idx` within `block_idx`, deriving the
    /// warp/lane placement from the x-fastest linear thread id.
    pub fn new(
        thread_idx: (u32, u32, u32),
        block_idx: (u32, u32, u32),
        config: &MockKernelConfig,
    ) -> Self {
        let (bx, by, _) = config.block_dim;
        // Flatten (x, y, z) with x varying fastest, matching CUDA.
        let flat_tid = (thread_idx.2 * by + thread_idx.1) * bx + thread_idx.0;
        Self {
            thread_idx,
            block_idx,
            block_dim: config.block_dim,
            grid_dim: config.grid_dim,
            warp_id: flat_tid / config.warp_size,
            lane_id: flat_tid % config.warp_size,
            warp_size: config.warp_size,
        }
    }

    /// `threadIdx.x`
    #[inline]
    pub fn thread_idx_x(&self) -> u32 {
        self.thread_idx.0
    }

    /// `threadIdx.y`
    #[inline]
    pub fn thread_idx_y(&self) -> u32 {
        self.thread_idx.1
    }

    /// `threadIdx.z`
    #[inline]
    pub fn thread_idx_z(&self) -> u32 {
        self.thread_idx.2
    }

    /// `blockIdx.x`
    #[inline]
    pub fn block_idx_x(&self) -> u32 {
        self.block_idx.0
    }

    /// `blockIdx.y`
    #[inline]
    pub fn block_idx_y(&self) -> u32 {
        self.block_idx.1
    }

    /// `blockIdx.z`
    #[inline]
    pub fn block_idx_z(&self) -> u32 {
        self.block_idx.2
    }

    /// `blockDim.x`
    #[inline]
    pub fn block_dim_x(&self) -> u32 {
        self.block_dim.0
    }

    /// `blockDim.y`
    #[inline]
    pub fn block_dim_y(&self) -> u32 {
        self.block_dim.1
    }

    /// `blockDim.z`
    #[inline]
    pub fn block_dim_z(&self) -> u32 {
        self.block_dim.2
    }

    /// `gridDim.x`
    #[inline]
    pub fn grid_dim_x(&self) -> u32 {
        self.grid_dim.0
    }

    /// `gridDim.y`
    #[inline]
    pub fn grid_dim_y(&self) -> u32 {
        self.grid_dim.1
    }

    /// `gridDim.z`
    #[inline]
    pub fn grid_dim_z(&self) -> u32 {
        self.grid_dim.2
    }

    /// Unique linear id across the entire grid: block-major, x-fastest
    /// within both the grid and the block. Computed in u64 to avoid
    /// overflow on large launches.
    #[inline]
    pub fn global_id(&self) -> u64 {
        let (gx, gy) = (self.grid_dim.0 as u64, self.grid_dim.1 as u64);
        let (bx, by, bz) = (
            self.block_dim.0 as u64,
            self.block_dim.1 as u64,
            self.block_dim.2 as u64,
        );
        let block_linear = (self.block_idx.2 as u64 * gy + self.block_idx.1 as u64) * gx
            + self.block_idx.0 as u64;
        let thread_linear = (self.thread_idx.2 as u64 * by + self.thread_idx.1 as u64) * bx
            + self.thread_idx.0 as u64;
        block_linear * (bx * by * bz) + thread_linear
    }

    /// Global x coordinate: `blockIdx.x * blockDim.x + threadIdx.x`.
    #[inline]
    pub fn global_x(&self) -> u32 {
        self.block_idx.0 * self.block_dim.0 + self.thread_idx.0
    }

    /// Global y coordinate.
    #[inline]
    pub fn global_y(&self) -> u32 {
        self.block_idx.1 * self.block_dim.1 + self.thread_idx.1
    }

    /// Global z coordinate.
    #[inline]
    pub fn global_z(&self) -> u32 {
        self.block_idx.2 * self.block_dim.2 + self.thread_idx.2
    }

    /// True for thread (0, 0, 0) of its block.
    #[inline]
    pub fn is_block_leader(&self) -> bool {
        matches!(self.thread_idx, (0, 0, 0))
    }

    /// True for lane 0 of this thread's warp.
    #[inline]
    pub fn is_warp_leader(&self) -> bool {
        self.lane_id == 0
    }
}
/// Simulated per-block shared memory backed by a plain byte buffer.
///
/// Single-threaded only (`RefCell` interior mutability; the type is not
/// `Sync`). Typed accessors accept arbitrary byte offsets; a `Vec<u8>`
/// gives no alignment guarantee for `T`, so all typed accesses use
/// unaligned loads/stores. (The previous aligned `ptr::read`/`ptr::write`
/// were undefined behavior for offsets not aligned to `T`; the previous
/// `offset + len` bounds checks could also wrap in release builds and let
/// an out-of-bounds access through.)
pub struct MockSharedMemory {
    data: RefCell<Vec<u8>>,
    size: usize,
}

impl MockSharedMemory {
    /// Allocate `size` bytes, zero-initialized.
    pub fn new(size: usize) -> Self {
        Self {
            data: RefCell::new(vec![0u8; size]),
            size,
        }
    }

    /// Buffer capacity in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Panics unless `[offset, offset + len)` lies inside the buffer.
    /// Overflow-safe: a wrapping `offset + len` counts as out of bounds.
    fn check_range(&self, offset: usize, len: usize) {
        let in_bounds = offset
            .checked_add(len)
            .map_or(false, |end| end <= self.size);
        assert!(
            in_bounds,
            "shared memory access out of bounds: offset {offset}, len {len}, size {}",
            self.size
        );
    }

    /// Read a `T` at byte `offset`.
    ///
    /// # Panics
    /// If the access would run past the end of the buffer.
    pub fn read<T: Copy>(&self, offset: usize) -> T {
        self.check_range(offset, std::mem::size_of::<T>());
        let data = self.data.borrow();
        // SAFETY: range bounds-checked above; the byte buffer carries no
        // alignment guarantee for T, so an unaligned load is required.
        unsafe { std::ptr::read_unaligned(data.as_ptr().add(offset) as *const T) }
    }

    /// Write `value` at byte `offset`.
    ///
    /// # Panics
    /// If the access would run past the end of the buffer.
    pub fn write<T: Copy>(&self, offset: usize, value: T) {
        self.check_range(offset, std::mem::size_of::<T>());
        let mut data = self.data.borrow_mut();
        // SAFETY: bounds-checked above; unaligned store (see `read`).
        unsafe { std::ptr::write_unaligned(data.as_mut_ptr().add(offset) as *mut T, value) };
    }

    /// Copy `count` elements of `T` starting at byte `offset` into a `Vec`.
    /// (Returns owned data; the `as_slice` name is kept for compatibility.)
    ///
    /// # Panics
    /// If the requested span runs past the end of the buffer, or if
    /// `count * size_of::<T>()` overflows `usize`.
    pub fn as_slice<T: Copy>(&self, offset: usize, count: usize) -> Vec<T> {
        let elem = std::mem::size_of::<T>();
        let byte_len = count
            .checked_mul(elem)
            .expect("slice byte length overflows usize");
        self.check_range(offset, byte_len);
        let data = self.data.borrow();
        let mut out = Vec::with_capacity(count);
        unsafe {
            let base = data.as_ptr().add(offset);
            // SAFETY: every element lies inside the checked range; reads
            // are unaligned because the base pointer is only byte-aligned.
            for i in 0..count {
                out.push(std::ptr::read_unaligned(base.add(i * elem) as *const T));
            }
        }
        out
    }

    /// Write all of `values` starting at byte `offset`.
    ///
    /// # Panics
    /// If the span would run past the end of the buffer.
    pub fn write_slice<T: Copy>(&self, offset: usize, values: &[T]) {
        let elem = std::mem::size_of::<T>();
        self.check_range(offset, std::mem::size_of_val(values));
        let mut data = self.data.borrow_mut();
        unsafe {
            let base = data.as_mut_ptr().add(offset);
            // SAFETY: the full span was bounds-checked; unaligned stores.
            for (i, v) in values.iter().enumerate() {
                std::ptr::write_unaligned(base.add(i * elem) as *mut T, *v);
            }
        }
    }
}
/// Simulated global-memory atomic cells, keyed by a fake "address".
///
/// Cells are created lazily with an initial value of 0. The maps are
/// guarded by `RwLock`s so the store is `Send + Sync`; the per-cell
/// operations use real hardware atomics, which only need `&AtomicU32`,
/// so the shared read lock suffices for every operation on an existing
/// cell. (Previously every operation took the exclusive write lock,
/// serializing all "atomic" traffic through the map lock.)
pub struct MockAtomics {
    u32_values: RwLock<HashMap<usize, AtomicU32>>,
    u64_values: RwLock<HashMap<usize, AtomicU64>>,
}

impl Default for MockAtomics {
    fn default() -> Self {
        Self::new()
    }
}

impl MockAtomics {
    /// Create an empty atomic store.
    pub fn new() -> Self {
        Self {
            u32_values: RwLock::new(HashMap::new()),
            u64_values: RwLock::new(HashMap::new()),
        }
    }

    /// Apply `op` to the u32 cell at `addr`, creating it (as 0) on first
    /// touch. Fast path holds only the read lock; the write lock is taken
    /// solely to insert a missing entry.
    fn with_u32<R>(&self, addr: usize, op: impl FnOnce(&AtomicU32) -> R) -> R {
        {
            let map = self.u32_values.read().unwrap();
            if let Some(cell) = map.get(&addr) {
                return op(cell);
            }
        } // read guard dropped here, before the write lock is taken
        let mut map = self.u32_values.write().unwrap();
        // `entry` handles a racing insert between the two lock acquisitions.
        op(map.entry(addr).or_insert_with(|| AtomicU32::new(0)))
    }

    /// u64 counterpart of `with_u32`.
    fn with_u64<R>(&self, addr: usize, op: impl FnOnce(&AtomicU64) -> R) -> R {
        {
            let map = self.u64_values.read().unwrap();
            if let Some(cell) = map.get(&addr) {
                return op(cell);
            }
        }
        let mut map = self.u64_values.write().unwrap();
        op(map.entry(addr).or_insert_with(|| AtomicU64::new(0)))
    }

    /// Atomic fetch-add; returns the previous value (CUDA `atomicAdd`).
    pub fn atomic_add_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_add(val, Ordering::SeqCst))
    }

    /// 64-bit atomic fetch-add; returns the previous value.
    pub fn atomic_add_u64(&self, addr: usize, val: u64) -> u64 {
        self.with_u64(addr, |a| a.fetch_add(val, Ordering::SeqCst))
    }

    /// Compare-and-swap; returns the value observed before the attempt
    /// (equal to `expected` iff the swap happened), like CUDA `atomicCAS`.
    pub fn atomic_cas_u32(&self, addr: usize, expected: u32, new: u32) -> u32 {
        self.with_u32(addr, |a| {
            match a.compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst) {
                Ok(v) | Err(v) => v,
            }
        })
    }

    /// Atomic fetch-max; returns the previous value.
    pub fn atomic_max_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_max(val, Ordering::SeqCst))
    }

    /// Atomic fetch-min; returns the previous value.
    pub fn atomic_min_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_min(val, Ordering::SeqCst))
    }

    /// Plain load; untouched addresses read as 0 (no cell is created).
    pub fn load_u32(&self, addr: usize) -> u32 {
        self.u32_values
            .read()
            .unwrap()
            .get(&addr)
            .map_or(0, |a| a.load(Ordering::SeqCst))
    }

    /// Plain store, creating the cell if needed.
    pub fn store_u32(&self, addr: usize, val: u32) {
        self.with_u32(addr, |a| a.store(val, Ordering::SeqCst));
    }
}
/// A mock GPU device: owns a launch configuration and a device-wide atomic
/// store, and dispatches kernel closures over the configured grid on the CPU.
pub struct MockGpu {
    config: MockKernelConfig,
    atomics: Arc<MockAtomics>,
}

impl MockGpu {
    /// Create a device for the given launch configuration.
    pub fn new(config: MockKernelConfig) -> Self {
        Self {
            config,
            atomics: Arc::new(MockAtomics::new()),
        }
    }

    /// The launch configuration this device was built with.
    pub fn config(&self) -> &MockKernelConfig {
        &self.config
    }

    /// Device-wide atomic store shared by all dispatched threads.
    pub fn atomics(&self) -> &MockAtomics {
        &self.atomics
    }

    /// Invoke `kernel` once per simulated thread, serially on the calling
    /// thread, in block-major order with x varying fastest. Suitable for
    /// kernels that need no `__syncthreads`-style barrier.
    pub fn dispatch<F>(&self, kernel: F)
    where
        F: Fn(&MockThread),
    {
        for bz in 0..self.config.grid_dim.2 {
            for by in 0..self.config.grid_dim.1 {
                for bx in 0..self.config.grid_dim.0 {
                    for tz in 0..self.config.block_dim.2 {
                        for ty in 0..self.config.block_dim.1 {
                            for tx in 0..self.config.block_dim.0 {
                                let thread =
                                    MockThread::new((tx, ty, tz), (bx, by, bz), &self.config);
                                kernel(&thread);
                            }
                        }
                    }
                }
            }
        }
    }

    /// Run each block with one OS thread per simulated thread so the kernel
    /// can synchronize through the supplied `Barrier` (one `wait` per
    /// simulated `__syncthreads`). Blocks still execute one after another.
    ///
    /// NOTE: this spawns `threads_per_block` OS threads per block — keep
    /// block sizes modest when using this entry point.
    pub fn dispatch_with_sync<F>(&self, kernel: F)
    where
        F: Fn(&MockThread, &Barrier) + Send + Sync,
    {
        let threads_per_block = self.config.threads_per_block() as usize;
        for bz in 0..self.config.grid_dim.2 {
            for by in 0..self.config.grid_dim.1 {
                for bx in 0..self.config.grid_dim.0 {
                    // Scoped threads may borrow the barrier directly; the
                    // former Arc wrapper and per-thread clone were needless.
                    let barrier = Barrier::new(threads_per_block);
                    std::thread::scope(|s| {
                        for tz in 0..self.config.block_dim.2 {
                            for ty in 0..self.config.block_dim.1 {
                                for tx in 0..self.config.block_dim.0 {
                                    let barrier = &barrier;
                                    let config = &self.config;
                                    let kernel = &kernel;
                                    s.spawn(move || {
                                        let thread = MockThread::new(
                                            (tx, ty, tz),
                                            (bx, by, bz),
                                            config,
                                        );
                                        kernel(&thread, barrier);
                                    });
                                }
                            }
                        }
                    });
                }
            }
        }
    }
}
/// One simulated warp: a u32 value per lane plus CUDA-style collective
/// operations (`__shfl_*`, `__ballot`, `__any`/`__all`, reductions).
pub struct MockWarp {
    lane_values: Vec<u32>,
    warp_size: u32,
}

impl MockWarp {
    /// Create a warp of `warp_size` lanes, all initialized to 0.
    pub fn new(warp_size: u32) -> Self {
        Self {
            lane_values: vec![0; warp_size as usize],
            warp_size,
        }
    }

    /// Set a lane's value; out-of-range lanes are silently ignored.
    pub fn set_lane(&mut self, lane: u32, value: u32) {
        if let Some(slot) = self.lane_values.get_mut(lane as usize) {
            *slot = value;
        }
    }

    /// `__shfl`: the value held by `src_lane`, or 0 if out of range.
    pub fn shuffle(&self, src_lane: u32) -> u32 {
        self.lane_values
            .get(src_lane as usize)
            .copied()
            .unwrap_or(0)
    }

    /// `__shfl_xor`: read from the lane at `lane_id ^ mask`.
    pub fn shuffle_xor(&self, lane_id: u32, mask: u32) -> u32 {
        self.shuffle(lane_id ^ mask)
    }

    /// `__shfl_up`: read from `lane_id - delta`; lanes with no source keep
    /// their own value. Fix: the keep-own-value branch previously indexed
    /// `lane_values` directly and panicked for an out-of-range `lane_id`;
    /// it now goes through `shuffle` (returning 0) like every other path.
    pub fn shuffle_up(&self, lane_id: u32, delta: u32) -> u32 {
        match lane_id.checked_sub(delta) {
            Some(src) => self.shuffle(src),
            None => self.shuffle(lane_id),
        }
    }

    /// `__shfl_down`: read from `lane_id + delta`; lanes past the top keep
    /// their own value. Fix: `lane_id + delta` previously could both
    /// overflow (debug panic) and, in the fallback, panic on an
    /// out-of-range `lane_id`; overflow now counts as out of range and the
    /// fallback goes through `shuffle`.
    pub fn shuffle_down(&self, lane_id: u32, delta: u32) -> u32 {
        match lane_id.checked_add(delta) {
            Some(src) if src < self.warp_size => self.shuffle(src),
            _ => self.shuffle(lane_id),
        }
    }

    /// `__ballot`: bitmask with bit `lane` set where `predicate` holds.
    /// Only the first 64 lanes fit in the mask; previously a warp wider
    /// than 64 lanes overflowed the shift.
    pub fn ballot(&self, predicate: impl Fn(u32) -> bool) -> u64 {
        let mut mask = 0u64;
        for lane in 0..self.warp_size.min(64) {
            if predicate(lane) {
                mask |= 1u64 << lane;
            }
        }
        mask
    }

    /// `__any`: true if the predicate holds for at least one lane.
    pub fn any(&self, predicate: impl Fn(u32) -> bool) -> bool {
        (0..self.warp_size).any(predicate)
    }

    /// `__all`: true if the predicate holds for every lane.
    pub fn all(&self, predicate: impl Fn(u32) -> bool) -> bool {
        (0..self.warp_size).all(predicate)
    }

    /// Warp-wide sum of all lane values.
    pub fn reduce_sum(&self) -> u32 {
        self.lane_values.iter().sum()
    }

    /// Exclusive prefix sum: element `i` is the sum of lanes `0..i`.
    pub fn prefix_sum_exclusive(&self) -> Vec<u32> {
        let mut running = 0u32;
        self.lane_values
            .iter()
            .map(|&v| {
                let before = running;
                running += v;
                before
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_config() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(4, 4, 1)
            .with_block_size(32, 8, 1);
        assert_eq!(cfg.total_blocks(), 16);
        assert_eq!(cfg.threads_per_block(), 256);
        assert_eq!(cfg.total_threads(), 4096);
    }

    #[test]
    fn test_mock_thread_intrinsics() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(16, 16, 1);
        let t = MockThread::new((5, 3, 0), (1, 0, 0), &cfg);
        assert_eq!(t.thread_idx_x(), 5);
        assert_eq!(t.thread_idx_y(), 3);
        assert_eq!(t.block_idx_x(), 1);
        assert_eq!(t.block_dim_x(), 16);
        // global coordinate = block_idx * block_dim + thread_idx
        assert_eq!(t.global_x(), 21);
        assert_eq!(t.global_y(), 3);
    }

    #[test]
    fn test_mock_shared_memory() {
        let shmem = MockSharedMemory::new(1024);
        shmem.write::<f32>(0, 3.125);
        shmem.write::<f32>(4, 2.75);
        assert!((shmem.read::<f32>(0) - 3.125).abs() < 0.001);
        assert!((shmem.read::<f32>(4) - 2.75).abs() < 0.001);
        shmem.write_slice::<u32>(100, &[1, 2, 3, 4]);
        assert_eq!(shmem.as_slice::<u32>(100, 4), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_mock_atomics() {
        let atomics = MockAtomics::new();
        assert_eq!(atomics.atomic_add_u32(0, 5), 0);
        assert_eq!(atomics.atomic_add_u32(0, 3), 5);
        assert_eq!(atomics.load_u32(0), 8);
    }

    #[test]
    fn test_mock_gpu_dispatch() {
        let gpu = MockGpu::new(
            MockKernelConfig::new()
                .with_grid_size(2, 1, 1)
                .with_block_size(4, 1, 1),
        );
        let counter = Arc::new(AtomicU32::new(0));
        let inner = Arc::clone(&counter);
        gpu.dispatch(move |_thread| {
            inner.fetch_add(1, Ordering::SeqCst);
        });
        // 2 blocks x 4 threads = 8 kernel invocations
        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }

    #[test]
    fn test_mock_warp_shuffle() {
        let mut warp = MockWarp::new(32);
        (0..32).for_each(|lane| warp.set_lane(lane, lane * 2));
        assert_eq!(warp.shuffle(5), 10);
        assert_eq!(warp.shuffle(15), 30);
        assert_eq!(warp.shuffle_xor(0, 1), 2);
        assert_eq!(warp.shuffle_xor(2, 1), 6);
    }

    #[test]
    fn test_mock_warp_ballot() {
        let warp = MockWarp::new(32);
        let even_lanes = warp.ballot(|lane| lane % 2 == 0);
        assert_eq!(even_lanes, 0x55555555);
    }

    #[test]
    fn test_mock_warp_reduce() {
        let mut warp = MockWarp::new(4);
        for (lane, value) in [(0, 1), (1, 2), (2, 3), (3, 4)] {
            warp.set_lane(lane, value);
        }
        assert_eq!(warp.reduce_sum(), 10);
        assert_eq!(warp.prefix_sum_exclusive(), vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_thread_global_id() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(4, 4, 1);
        assert_eq!(MockThread::new((0, 0, 0), (0, 0, 0), &cfg).global_id(), 0);
        assert_eq!(MockThread::new((0, 0, 0), (1, 0, 0), &cfg).global_id(), 16);
        assert_eq!(MockThread::new((3, 3, 0), (0, 0, 0), &cfg).global_id(), 15);
    }
}