//! A mock GPU execution environment for CPU-side testing: thread/block
//! indexing, shared memory, global atomics, and warp-level primitives.

use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Barrier, RwLock};

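/// Launch configuration for a mock kernel: grid and block dimensions,
/// per-block shared memory, and optional warp simulation, mirroring the
/// parameters of a CUDA-style kernel launch.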
#[derive(Debug, Clone)]
pub struct MockKernelConfig {
    /// Grid dimensions: blocks per grid in x, y, z.
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions: threads per block in x, y, z.
    pub block_dim: (u32, u32, u32),
    /// Shared memory available to each block, in bytes.
    pub shared_memory_size: usize,
    /// Whether warp-level execution should be modeled.
    pub simulate_warps: bool,
    /// Lanes per warp; 32 on most real hardware.
    pub warp_size: u32,
}

impl Default for MockKernelConfig {
    fn default() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_memory_size: 49152, // 48 KiB, a common per-block limit
            simulate_warps: false,
            warp_size: 32,
        }
    }
}

impl MockKernelConfig {
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set the grid dimensions.
    pub fn with_grid_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    /// Builder: set the block dimensions.
    pub fn with_block_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    /// Builder: set the per-block shared memory size in bytes.
    pub fn with_shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory_size = bytes;
        self
    }

    /// Builder: enable warp simulation with the given warp size.
    pub fn with_warp_simulation(mut self, warp_size: u32) -> Self {
        self.simulate_warps = true;
        self.warp_size = warp_size;
        self
    }

    /// Total number of threads across the whole grid.
    pub fn total_threads(&self) -> u64 {
        let blocks = self.grid_dim.0 as u64 * self.grid_dim.1 as u64 * self.grid_dim.2 as u64;
        let threads_per_block =
            self.block_dim.0 as u64 * self.block_dim.1 as u64 * self.block_dim.2 as u64;
        blocks * threads_per_block
    }

    /// Number of threads in a single block.
    pub fn threads_per_block(&self) -> u32 {
        self.block_dim.0 * self.block_dim.1 * self.block_dim.2
    }

    /// Number of blocks in the grid.
    pub fn total_blocks(&self) -> u32 {
        self.grid_dim.0 * self.grid_dim.1 * self.grid_dim.2
    }
}

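/// Per-thread view of a launch: thread/block indices plus derived warp
/// coordinates, standing in for CUDA's threadIdx/blockIdx/blockDim/gridDim
/// built-ins.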
#[derive(Debug, Clone)]
pub struct MockThread {
    /// Thread index within its block.
    pub thread_idx: (u32, u32, u32),
    /// Block index within the grid.
    pub block_idx: (u32, u32, u32),
    /// Dimensions of the owning block.
    pub block_dim: (u32, u32, u32),
    /// Dimensions of the grid.
    pub grid_dim: (u32, u32, u32),
    /// Index of this thread's warp within its block.
    pub warp_id: u32,
    /// This thread's lane within its warp.
    pub lane_id: u32,
    /// Lanes per warp.
    pub warp_size: u32,
}

impl MockThread {
    /// Builds a thread's view from its indices, deriving `warp_id` and
    /// `lane_id` from the thread's linear index within its block.
    pub fn new(
        thread_idx: (u32, u32, u32),
        block_idx: (u32, u32, u32),
        config: &MockKernelConfig,
    ) -> Self {
        let linear_tid = thread_idx.0
            + thread_idx.1 * config.block_dim.0
            + thread_idx.2 * config.block_dim.0 * config.block_dim.1;

        Self {
            thread_idx,
            block_idx,
            block_dim: config.block_dim,
            grid_dim: config.grid_dim,
            warp_id: linear_tid / config.warp_size,
            lane_id: linear_tid % config.warp_size,
            warp_size: config.warp_size,
        }
    }

    /// `threadIdx.x`
    #[inline]
    pub fn thread_idx_x(&self) -> u32 {
        self.thread_idx.0
    }

    /// `threadIdx.y`
    #[inline]
    pub fn thread_idx_y(&self) -> u32 {
        self.thread_idx.1
    }

    /// `threadIdx.z`
    #[inline]
    pub fn thread_idx_z(&self) -> u32 {
        self.thread_idx.2
    }

    /// `blockIdx.x`
    #[inline]
    pub fn block_idx_x(&self) -> u32 {
        self.block_idx.0
    }

    /// `blockIdx.y`
    #[inline]
    pub fn block_idx_y(&self) -> u32 {
        self.block_idx.1
    }

    /// `blockIdx.z`
    #[inline]
    pub fn block_idx_z(&self) -> u32 {
        self.block_idx.2
    }

    /// `blockDim.x`
    #[inline]
    pub fn block_dim_x(&self) -> u32 {
        self.block_dim.0
    }

    /// `blockDim.y`
    #[inline]
    pub fn block_dim_y(&self) -> u32 {
        self.block_dim.1
    }

    /// `blockDim.z`
    #[inline]
    pub fn block_dim_z(&self) -> u32 {
        self.block_dim.2
    }

    /// `gridDim.x`
    #[inline]
    pub fn grid_dim_x(&self) -> u32 {
        self.grid_dim.0
    }

    /// `gridDim.y`
    #[inline]
    pub fn grid_dim_y(&self) -> u32 {
        self.grid_dim.1
    }

    /// `gridDim.z`
    #[inline]
    pub fn grid_dim_z(&self) -> u32 {
        self.grid_dim.2
    }

    /// Globally unique linear thread id across the entire grid.
    #[inline]
    pub fn global_id(&self) -> u64 {
        let block_linear = self.block_idx.0 as u64
            + self.block_idx.1 as u64 * self.grid_dim.0 as u64
            + self.block_idx.2 as u64 * self.grid_dim.0 as u64 * self.grid_dim.1 as u64;

        let threads_per_block =
            self.block_dim.0 as u64 * self.block_dim.1 as u64 * self.block_dim.2 as u64;
        let thread_linear = self.thread_idx.0 as u64
            + self.thread_idx.1 as u64 * self.block_dim.0 as u64
            + self.thread_idx.2 as u64 * self.block_dim.0 as u64 * self.block_dim.1 as u64;

        block_linear * threads_per_block + thread_linear
    }

    /// Global x coordinate: `blockIdx.x * blockDim.x + threadIdx.x`.
    #[inline]
    pub fn global_x(&self) -> u32 {
        self.block_idx.0 * self.block_dim.0 + self.thread_idx.0
    }

    /// Global y coordinate.
    #[inline]
    pub fn global_y(&self) -> u32 {
        self.block_idx.1 * self.block_dim.1 + self.thread_idx.1
    }

    /// Global z coordinate.
    #[inline]
    pub fn global_z(&self) -> u32 {
        self.block_idx.2 * self.block_dim.2 + self.thread_idx.2
    }

    /// True for thread (0, 0, 0) of its block.
    #[inline]
    pub fn is_block_leader(&self) -> bool {
        self.thread_idx == (0, 0, 0)
    }

    /// True for lane 0 of the warp.
    #[inline]
    pub fn is_warp_leader(&self) -> bool {
        self.lane_id == 0
    }
}

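/// A byte-addressable buffer emulating per-block shared memory. Backed by a
/// `RefCell`, so it is meant for the single-threaded `dispatch` path, not
/// for `dispatch_with_sync`.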
pub struct MockSharedMemory {
    data: RefCell<Vec<u8>>,
    size: usize,
}

impl MockSharedMemory {
    pub fn new(size: usize) -> Self {
        Self {
            data: RefCell::new(vec![0u8; size]),
            size,
        }
    }

    pub fn size(&self) -> usize {
        self.size
    }

    /// Reads a `T` from the given byte offset (which may be unaligned).
    pub fn read<T: Copy>(&self, offset: usize) -> T {
        let data = self.data.borrow();
        assert!(
            offset + std::mem::size_of::<T>() <= self.size,
            "shared memory read out of bounds"
        );
        // `read_unaligned` because `offset` need not satisfy `T`'s alignment.
        unsafe { std::ptr::read_unaligned(data.as_ptr().add(offset) as *const T) }
    }

    /// Writes a `T` at the given byte offset (which may be unaligned).
    pub fn write<T: Copy>(&self, offset: usize, value: T) {
        let mut data = self.data.borrow_mut();
        assert!(
            offset + std::mem::size_of::<T>() <= self.size,
            "shared memory write out of bounds"
        );
        unsafe { std::ptr::write_unaligned(data.as_mut_ptr().add(offset) as *mut T, value) };
    }

    /// Copies `count` values of `T` starting at `offset` into a `Vec`.
    pub fn as_slice<T: Copy>(&self, offset: usize, count: usize) -> Vec<T> {
        let data = self.data.borrow();
        let byte_size = count * std::mem::size_of::<T>();
        assert!(offset + byte_size <= self.size, "slice read out of bounds");

        let mut result = Vec::with_capacity(count);
        unsafe {
            let ptr = data.as_ptr().add(offset) as *const T;
            for i in 0..count {
                result.push(std::ptr::read_unaligned(ptr.add(i)));
            }
        }
        result
    }

    /// Writes a slice of `T` starting at byte `offset`.
    pub fn write_slice<T: Copy>(&self, offset: usize, values: &[T]) {
        let mut data = self.data.borrow_mut();
        let byte_size = std::mem::size_of_val(values);
        assert!(offset + byte_size <= self.size, "slice write out of bounds");

        unsafe {
            let ptr = data.as_mut_ptr().add(offset) as *mut T;
            for (i, v) in values.iter().enumerate() {
                std::ptr::write_unaligned(ptr.add(i), *v);
            }
        }
    }
}

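/// Sparse atomic storage keyed by address, emulating global-memory atomics.
/// Unset addresses read as zero.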
pub struct MockAtomics {
    u32_values: RwLock<HashMap<usize, AtomicU32>>,
    u64_values: RwLock<HashMap<usize, AtomicU64>>,
}

impl Default for MockAtomics {
    fn default() -> Self {
        Self::new()
    }
}

impl MockAtomics {
    pub fn new() -> Self {
        Self {
            u32_values: RwLock::new(HashMap::new()),
            u64_values: RwLock::new(HashMap::new()),
        }
    }

    /// Atomically adds `val` at `addr`, returning the previous value
    /// (CUDA `atomicAdd` semantics).
    pub fn atomic_add_u32(&self, addr: usize, val: u32) -> u32 {
        let mut map = self.u32_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU32::new(0));
        atomic.fetch_add(val, Ordering::SeqCst)
    }

    /// 64-bit variant of `atomic_add_u32`.
    pub fn atomic_add_u64(&self, addr: usize, val: u64) -> u64 {
        let mut map = self.u64_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU64::new(0));
        atomic.fetch_add(val, Ordering::SeqCst)
    }

    /// Compare-and-swap: stores `new` if the current value equals `expected`,
    /// returning the observed value either way (CUDA `atomicCAS` semantics).
    pub fn atomic_cas_u32(&self, addr: usize, expected: u32, new: u32) -> u32 {
        let mut map = self.u32_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU32::new(0));
        match atomic.compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst) {
            Ok(v) | Err(v) => v,
        }
    }

    /// Atomic max, returning the previous value.
    pub fn atomic_max_u32(&self, addr: usize, val: u32) -> u32 {
        let mut map = self.u32_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU32::new(0));
        atomic.fetch_max(val, Ordering::SeqCst)
    }

    /// Atomic min, returning the previous value.
    pub fn atomic_min_u32(&self, addr: usize, val: u32) -> u32 {
        let mut map = self.u32_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU32::new(0));
        atomic.fetch_min(val, Ordering::SeqCst)
    }

    /// Reads the value at `addr`; unset addresses read as zero.
    pub fn load_u32(&self, addr: usize) -> u32 {
        let map = self.u32_values.read().unwrap();
        map.get(&addr)
            .map(|a| a.load(Ordering::SeqCst))
            .unwrap_or(0)
    }

    pub fn store_u32(&self, addr: usize, val: u32) {
        let mut map = self.u32_values.write().unwrap();
        let atomic = map.entry(addr).or_insert_with(|| AtomicU32::new(0));
        atomic.store(val, Ordering::SeqCst);
    }
}

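/// A CPU-side stand-in for a GPU device: holds a launch configuration and a
/// set of atomics, and dispatches kernels over every thread in the grid.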
pub struct MockGpu {
    config: MockKernelConfig,
    atomics: Arc<MockAtomics>,
}

impl MockGpu {
    pub fn new(config: MockKernelConfig) -> Self {
        Self {
            config,
            atomics: Arc::new(MockAtomics::new()),
        }
    }

    pub fn config(&self) -> &MockKernelConfig {
        &self.config
    }

    pub fn atomics(&self) -> &MockAtomics {
        &self.atomics
    }

    /// Runs `kernel` once per thread in the grid, sequentially and in a
    /// deterministic order. Suitable for logic that needs no synchronization.
    pub fn dispatch<F>(&self, kernel: F)
    where
        F: Fn(&MockThread),
    {
        for bz in 0..self.config.grid_dim.2 {
            for by in 0..self.config.grid_dim.1 {
                for bx in 0..self.config.grid_dim.0 {
                    for tz in 0..self.config.block_dim.2 {
                        for ty in 0..self.config.block_dim.1 {
                            for tx in 0..self.config.block_dim.0 {
                                let thread =
                                    MockThread::new((tx, ty, tz), (bx, by, bz), &self.config);
                                kernel(&thread);
                            }
                        }
                    }
                }
            }
        }
    }

    /// Runs each block on real OS threads (one per GPU thread), handing the
    /// kernel a `Barrier` that plays the role of `__syncthreads()`. Blocks
    /// execute one after another; threads within a block run concurrently.
    pub fn dispatch_with_sync<F>(&self, kernel: F)
    where
        F: Fn(&MockThread, &Barrier) + Send + Sync,
    {
        let threads_per_block = self.config.threads_per_block() as usize;

        for bz in 0..self.config.grid_dim.2 {
            for by in 0..self.config.grid_dim.1 {
                for bx in 0..self.config.grid_dim.0 {
                    let barrier = Arc::new(Barrier::new(threads_per_block));
                    std::thread::scope(|s| {
                        for tz in 0..self.config.block_dim.2 {
                            for ty in 0..self.config.block_dim.1 {
                                for tx in 0..self.config.block_dim.0 {
                                    let barrier = Arc::clone(&barrier);
                                    let config = &self.config;
                                    let kernel_ref = &kernel;
                                    s.spawn(move || {
                                        let thread =
                                            MockThread::new((tx, ty, tz), (bx, by, bz), config);
                                        kernel_ref(&thread, &barrier);
                                    });
                                }
                            }
                        }
                    });
                }
            }
        }
    }
}

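/// A single warp's lanes, for testing warp-level primitives (shuffles,
/// ballots, reductions) without a device.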
pub struct MockWarp {
    lane_values: Vec<u32>,
    warp_size: u32,
}
528
529impl MockWarp {
530 pub fn new(warp_size: u32) -> Self {
532 Self {
533 lane_values: vec![0; warp_size as usize],
534 warp_size,
535 }
536 }
537
538 pub fn set_lane(&mut self, lane: u32, value: u32) {
540 if (lane as usize) < self.lane_values.len() {
541 self.lane_values[lane as usize] = value;
542 }
543 }
544
545 pub fn shuffle(&self, src_lane: u32) -> u32 {
547 self.lane_values
548 .get(src_lane as usize)
549 .copied()
550 .unwrap_or(0)
551 }
552
553 pub fn shuffle_xor(&self, lane_id: u32, mask: u32) -> u32 {
555 let src = lane_id ^ mask;
556 self.shuffle(src)
557 }
558
559 pub fn shuffle_up(&self, lane_id: u32, delta: u32) -> u32 {
561 if lane_id >= delta {
562 self.shuffle(lane_id - delta)
563 } else {
564 self.lane_values[lane_id as usize]
565 }
566 }
567
568 pub fn shuffle_down(&self, lane_id: u32, delta: u32) -> u32 {
570 if lane_id + delta < self.warp_size {
571 self.shuffle(lane_id + delta)
572 } else {
573 self.lane_values[lane_id as usize]
574 }
575 }
576
577 pub fn ballot(&self, predicate: impl Fn(u32) -> bool) -> u64 {
579 let mut result = 0u64;
580 for lane in 0..self.warp_size {
581 if predicate(lane) {
582 result |= 1 << lane;
583 }
584 }
585 result
586 }
587
588 pub fn any(&self, predicate: impl Fn(u32) -> bool) -> bool {
590 (0..self.warp_size).any(predicate)
591 }
592
593 pub fn all(&self, predicate: impl Fn(u32) -> bool) -> bool {
595 (0..self.warp_size).all(predicate)
596 }
597
598 pub fn reduce_sum(&self) -> u32 {
600 self.lane_values.iter().sum()
601 }
602
603 pub fn prefix_sum_exclusive(&self) -> Vec<u32> {
605 let mut result = Vec::with_capacity(self.warp_size as usize);
606 let mut sum = 0;
607 for &v in &self.lane_values {
608 result.push(sum);
609 sum += v;
610 }
611 result
612 }
613}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_config() {
        let config = MockKernelConfig::new()
            .with_grid_size(4, 4, 1)
            .with_block_size(32, 8, 1);

        assert_eq!(config.total_blocks(), 16);
        assert_eq!(config.threads_per_block(), 256);
        assert_eq!(config.total_threads(), 4096);
    }

    #[test]
    fn test_mock_thread_intrinsics() {
        let config = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(16, 16, 1);

        let thread = MockThread::new((5, 3, 0), (1, 0, 0), &config);

        assert_eq!(thread.thread_idx_x(), 5);
        assert_eq!(thread.thread_idx_y(), 3);
        assert_eq!(thread.block_idx_x(), 1);
        assert_eq!(thread.block_dim_x(), 16);
        assert_eq!(thread.global_x(), 21);
        assert_eq!(thread.global_y(), 3);
    }

    #[test]
    fn test_mock_shared_memory() {
        let shmem = MockSharedMemory::new(1024);

        shmem.write::<f32>(0, 3.125);
        shmem.write::<f32>(4, 2.75);

        assert!((shmem.read::<f32>(0) - 3.125).abs() < 0.001);
        assert!((shmem.read::<f32>(4) - 2.75).abs() < 0.001);

        shmem.write_slice::<u32>(100, &[1, 2, 3, 4]);
        let slice = shmem.as_slice::<u32>(100, 4);
        assert_eq!(slice, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_mock_atomics() {
        let atomics = MockAtomics::new();

        let old = atomics.atomic_add_u32(0, 5);
        assert_eq!(old, 0);

        let old = atomics.atomic_add_u32(0, 3);
        assert_eq!(old, 5);

        assert_eq!(atomics.load_u32(0), 8);
    }
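
    // Sketch of compare-and-swap semantics: like CUDA's atomicCAS, the call
    // returns the value observed at the address, whether or not the swap won.
    // The address 42 is arbitrary.
    #[test]
    fn test_mock_atomics_cas_max_min() {
        let atomics = MockAtomics::new();

        assert_eq!(atomics.atomic_cas_u32(42, 0, 7), 0); // swap succeeds
        assert_eq!(atomics.atomic_cas_u32(42, 0, 9), 7); // swap fails, returns current
        assert_eq!(atomics.load_u32(42), 7);

        assert_eq!(atomics.atomic_max_u32(42, 3), 7); // 3 < 7: value unchanged
        assert_eq!(atomics.atomic_max_u32(42, 11), 7); // returns old, stores 11
        assert_eq!(atomics.load_u32(42), 11);

        assert_eq!(atomics.atomic_min_u32(42, 2), 11); // returns old, stores 2
        assert_eq!(atomics.load_u32(42), 2);
    }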

    #[test]
    fn test_mock_gpu_dispatch() {
        let config = MockKernelConfig::new()
            .with_grid_size(2, 1, 1)
            .with_block_size(4, 1, 1);

        let gpu = MockGpu::new(config);
        let counter = Arc::new(AtomicU32::new(0));

        let c = Arc::clone(&counter);
        gpu.dispatch(move |_thread| {
            c.fetch_add(1, Ordering::SeqCst);
        });

        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }
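
    // A minimal barrier exercise: every thread in the block increments a
    // counter before the barrier, so all threads must observe the full count
    // after it — the property __syncthreads() guarantees on real hardware.
    #[test]
    fn test_mock_gpu_dispatch_with_sync() {
        let config = MockKernelConfig::new()
            .with_grid_size(1, 1, 1)
            .with_block_size(4, 1, 1);

        let gpu = MockGpu::new(config);

        gpu.dispatch_with_sync(|_thread, barrier| {
            gpu.atomics().atomic_add_u32(0, 1);
            barrier.wait();
            assert_eq!(gpu.atomics().load_u32(0), 4);
        });
    }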

    #[test]
    fn test_mock_warp_shuffle() {
        let mut warp = MockWarp::new(32);

        for i in 0..32 {
            warp.set_lane(i, i * 2);
        }

        assert_eq!(warp.shuffle(5), 10);
        assert_eq!(warp.shuffle(15), 30);

        assert_eq!(warp.shuffle_xor(0, 1), 2);
        assert_eq!(warp.shuffle_xor(2, 1), 6);
    }
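
    // Sketch of shuffle-up/down boundary behavior: out-of-range lanes return
    // the calling lane's own value, mirroring __shfl_up_sync/__shfl_down_sync.
    #[test]
    fn test_mock_warp_shuffle_up_down() {
        let mut warp = MockWarp::new(8);
        for i in 0..8 {
            warp.set_lane(i, i + 1); // lane k holds k + 1
        }

        assert_eq!(warp.shuffle_up(3, 1), 3); // lane 3 reads lane 2's value
        assert_eq!(warp.shuffle_up(0, 1), 1); // lane 0 keeps its own value
        assert_eq!(warp.shuffle_down(6, 1), 8); // lane 6 reads lane 7's value
        assert_eq!(warp.shuffle_down(7, 1), 8); // lane 7 keeps its own value
    }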

    #[test]
    fn test_mock_warp_ballot() {
        let warp = MockWarp::new(32);

        let ballot = warp.ballot(|lane| lane % 2 == 0);
        assert_eq!(ballot, 0x55555555);
    }
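
    // Note: any/all evaluate the predicate over lane indices, not stored lane
    // values — a simplification relative to CUDA's __any_sync/__all_sync.
    #[test]
    fn test_mock_warp_any_all() {
        let warp = MockWarp::new(32);

        assert!(warp.any(|lane| lane == 31));
        assert!(warp.all(|lane| lane < 32));
        assert!(!warp.all(|lane| lane < 31));
    }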

    #[test]
    fn test_mock_warp_reduce() {
        let mut warp = MockWarp::new(4);

        warp.set_lane(0, 1);
        warp.set_lane(1, 2);
        warp.set_lane(2, 3);
        warp.set_lane(3, 4);

        assert_eq!(warp.reduce_sum(), 10);

        let prefix = warp.prefix_sum_exclusive();
        assert_eq!(prefix, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_thread_global_id() {
        let config = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(4, 4, 1);

        // First thread of the first block.
        let t1 = MockThread::new((0, 0, 0), (0, 0, 0), &config);
        assert_eq!(t1.global_id(), 0);

        // First thread of the second block: offset by one block (16 threads).
        let t2 = MockThread::new((0, 0, 0), (1, 0, 0), &config);
        assert_eq!(t2.global_id(), 16);

        // Last thread of the first block.
        let t3 = MockThread::new((3, 3, 0), (0, 0, 0), &config);
        assert_eq!(t3.global_id(), 15);
    }
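
    // Warp decomposition sketch: with the default warp size of 32, linear
    // thread id 37 in a (64, 1, 1) block lands in warp 1, lane 5.
    #[test]
    fn test_thread_warp_mapping() {
        let config = MockKernelConfig::new().with_block_size(64, 1, 1);
        let thread = MockThread::new((37, 0, 0), (0, 0, 0), &config);

        assert_eq!(thread.warp_id, 1);
        assert_eq!(thread.lane_id, 5);
        assert!(!thread.is_warp_leader());
    }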
}