1use std::cell::RefCell;
27use std::collections::HashMap;
28use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
29use std::sync::{Arc, Barrier, RwLock};
30
/// Launch geometry for a simulated GPU kernel: grid/block dimensions,
/// shared-memory budget, and optional warp simulation parameters.
#[derive(Debug, Clone)]
pub struct MockKernelConfig {
    /// Number of blocks along (x, y, z).
    pub grid_dim: (u32, u32, u32),
    /// Threads per block along (x, y, z).
    pub block_dim: (u32, u32, u32),
    /// Bytes of mock shared memory available to each block.
    pub shared_memory_size: usize,
    /// Whether warp-level behavior should be simulated.
    pub simulate_warps: bool,
    /// Number of lanes per warp (defaults to 32).
    pub warp_size: u32,
}

impl Default for MockKernelConfig {
    /// Single 256-thread block, 48 KiB shared memory, warps off.
    fn default() -> Self {
        MockKernelConfig {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_memory_size: 49152,
            simulate_warps: false,
            warp_size: 32,
        }
    }
}

impl MockKernelConfig {
    /// Same as [`MockKernelConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set the grid dimensions.
    pub fn with_grid_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    /// Builder: set the block dimensions.
    pub fn with_block_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    /// Builder: set the shared-memory size in bytes.
    pub fn with_shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory_size = bytes;
        self
    }

    /// Builder: enable warp simulation with the given warp size.
    pub fn with_warp_simulation(mut self, warp_size: u32) -> Self {
        self.simulate_warps = true;
        self.warp_size = warp_size;
        self
    }

    /// Total logical threads across the whole launch (u64 to avoid overflow).
    pub fn total_threads(&self) -> u64 {
        let (gx, gy, gz) = self.grid_dim;
        let (bx, by, bz) = self.block_dim;
        (gx as u64 * gy as u64 * gz as u64) * (bx as u64 * by as u64 * bz as u64)
    }

    /// Threads in one block (product of block dimensions).
    pub fn threads_per_block(&self) -> u32 {
        let (bx, by, bz) = self.block_dim;
        bx * by * bz
    }

    /// Blocks in the grid (product of grid dimensions).
    pub fn total_blocks(&self) -> u32 {
        let (gx, gy, gz) = self.grid_dim;
        gx * gy * gz
    }
}
111
112#[derive(Debug, Clone)]
118pub struct MockThread {
119 pub thread_idx: (u32, u32, u32),
121 pub block_idx: (u32, u32, u32),
123 pub block_dim: (u32, u32, u32),
125 pub grid_dim: (u32, u32, u32),
127 pub warp_id: u32,
129 pub lane_id: u32,
131 pub warp_size: u32,
133}
134
135impl MockThread {
136 pub fn new(
138 thread_idx: (u32, u32, u32),
139 block_idx: (u32, u32, u32),
140 config: &MockKernelConfig,
141 ) -> Self {
142 let linear_tid = thread_idx.0
143 + thread_idx.1 * config.block_dim.0
144 + thread_idx.2 * config.block_dim.0 * config.block_dim.1;
145
146 Self {
147 thread_idx,
148 block_idx,
149 block_dim: config.block_dim,
150 grid_dim: config.grid_dim,
151 warp_id: linear_tid / config.warp_size,
152 lane_id: linear_tid % config.warp_size,
153 warp_size: config.warp_size,
154 }
155 }
156
157 #[inline]
163 pub fn thread_idx_x(&self) -> u32 {
164 self.thread_idx.0
165 }
166
167 #[inline]
169 pub fn thread_idx_y(&self) -> u32 {
170 self.thread_idx.1
171 }
172
173 #[inline]
175 pub fn thread_idx_z(&self) -> u32 {
176 self.thread_idx.2
177 }
178
179 #[inline]
181 pub fn block_idx_x(&self) -> u32 {
182 self.block_idx.0
183 }
184
185 #[inline]
187 pub fn block_idx_y(&self) -> u32 {
188 self.block_idx.1
189 }
190
191 #[inline]
193 pub fn block_idx_z(&self) -> u32 {
194 self.block_idx.2
195 }
196
197 #[inline]
199 pub fn block_dim_x(&self) -> u32 {
200 self.block_dim.0
201 }
202
203 #[inline]
205 pub fn block_dim_y(&self) -> u32 {
206 self.block_dim.1
207 }
208
209 #[inline]
211 pub fn block_dim_z(&self) -> u32 {
212 self.block_dim.2
213 }
214
215 #[inline]
217 pub fn grid_dim_x(&self) -> u32 {
218 self.grid_dim.0
219 }
220
221 #[inline]
223 pub fn grid_dim_y(&self) -> u32 {
224 self.grid_dim.1
225 }
226
227 #[inline]
229 pub fn grid_dim_z(&self) -> u32 {
230 self.grid_dim.2
231 }
232
233 #[inline]
235 pub fn global_id(&self) -> u64 {
236 let block_linear = self.block_idx.0 as u64
237 + self.block_idx.1 as u64 * self.grid_dim.0 as u64
238 + self.block_idx.2 as u64 * self.grid_dim.0 as u64 * self.grid_dim.1 as u64;
239
240 let threads_per_block =
241 self.block_dim.0 as u64 * self.block_dim.1 as u64 * self.block_dim.2 as u64;
242 let thread_linear = self.thread_idx.0 as u64
243 + self.thread_idx.1 as u64 * self.block_dim.0 as u64
244 + self.thread_idx.2 as u64 * self.block_dim.0 as u64 * self.block_dim.1 as u64;
245
246 block_linear * threads_per_block + thread_linear
247 }
248
249 #[inline]
251 pub fn global_x(&self) -> u32 {
252 self.block_idx.0 * self.block_dim.0 + self.thread_idx.0
253 }
254
255 #[inline]
257 pub fn global_y(&self) -> u32 {
258 self.block_idx.1 * self.block_dim.1 + self.thread_idx.1
259 }
260
261 #[inline]
263 pub fn global_z(&self) -> u32 {
264 self.block_idx.2 * self.block_dim.2 + self.thread_idx.2
265 }
266
267 #[inline]
269 pub fn is_block_leader(&self) -> bool {
270 self.thread_idx == (0, 0, 0)
271 }
272
273 #[inline]
275 pub fn is_warp_leader(&self) -> bool {
276 self.lane_id == 0
277 }
278}
279
/// Byte-addressable stand-in for GPU shared memory.
///
/// Single-threaded only: interior mutability comes from `RefCell`, so
/// conflicting concurrent borrows panic instead of racing.
///
/// All typed accessors accept arbitrary byte offsets; accesses use
/// unaligned loads/stores, so `T`'s alignment does not restrict `offset`.
/// The caller must ensure the bytes form a valid `T` (always true for
/// plain integer/float types).
pub struct MockSharedMemory {
    /// Backing byte buffer, zero-initialized at construction.
    data: RefCell<Vec<u8>>,
    /// Capacity in bytes.
    size: usize,
}

impl MockSharedMemory {
    /// Allocates `size` bytes of zeroed mock shared memory.
    pub fn new(size: usize) -> Self {
        Self {
            data: RefCell::new(vec![0u8; size]),
            size,
        }
    }

    /// Capacity in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Reads a `T` starting at byte `offset`.
    ///
    /// # Panics
    /// Panics if `offset + size_of::<T>()` exceeds the region.
    pub fn read<T: Copy>(&self, offset: usize) -> T {
        let data = self.data.borrow();
        let end = offset
            .checked_add(std::mem::size_of::<T>())
            .expect("shared memory offset overflow");
        assert!(end <= self.size, "shared memory read out of bounds");
        // SAFETY: [offset, end) is within the buffer (checked above).
        // read_unaligned is required because `offset` need not satisfy
        // T's alignment; a plain ptr::read would be UB for align > 1.
        unsafe { std::ptr::read_unaligned(data.as_ptr().add(offset) as *const T) }
    }

    /// Writes `value` starting at byte `offset`.
    ///
    /// # Panics
    /// Panics if `offset + size_of::<T>()` exceeds the region.
    pub fn write<T: Copy>(&self, offset: usize, value: T) {
        let mut data = self.data.borrow_mut();
        let end = offset
            .checked_add(std::mem::size_of::<T>())
            .expect("shared memory offset overflow");
        assert!(end <= self.size, "shared memory write out of bounds");
        // SAFETY: bounds checked above; write_unaligned tolerates any offset.
        unsafe { std::ptr::write_unaligned(data.as_mut_ptr().add(offset) as *mut T, value) };
    }

    /// Copies `count` elements of `T` starting at byte `offset` into a Vec.
    ///
    /// # Panics
    /// Panics if the requested range exceeds the region.
    pub fn as_slice<T: Copy>(&self, offset: usize, count: usize) -> Vec<T> {
        let data = self.data.borrow();
        let elem = std::mem::size_of::<T>();
        let byte_size = count.checked_mul(elem).expect("shared memory size overflow");
        let end = offset
            .checked_add(byte_size)
            .expect("shared memory offset overflow");
        assert!(end <= self.size, "shared memory slice read out of bounds");

        let base = data.as_ptr();
        (0..count)
            // SAFETY: each element's range lies inside [offset, end), which
            // was bounds-checked above; unaligned read handles any offset.
            .map(|i| unsafe { std::ptr::read_unaligned(base.add(offset + i * elem) as *const T) })
            .collect()
    }

    /// Copies `values` into the region starting at byte `offset`.
    ///
    /// # Panics
    /// Panics if the destination range exceeds the region.
    pub fn write_slice<T: Copy>(&self, offset: usize, values: &[T]) {
        let mut data = self.data.borrow_mut();
        let elem = std::mem::size_of::<T>();
        let end = offset
            .checked_add(std::mem::size_of_val(values))
            .expect("shared memory offset overflow");
        assert!(end <= self.size, "shared memory slice write out of bounds");

        let base = data.as_mut_ptr();
        for (i, v) in values.iter().enumerate() {
            // SAFETY: destination lies inside [offset, end), checked above;
            // unaligned write handles any offset.
            unsafe { std::ptr::write_unaligned(base.add(offset + i * elem) as *mut T, *v) };
        }
    }
}
348
/// Simulated global-memory atomics, keyed by address.
///
/// Cells are created lazily with value 0 the first time an address is
/// touched, so untouched addresses read as zero.
pub struct MockAtomics {
    /// 32-bit cells indexed by address.
    u32_values: RwLock<HashMap<usize, AtomicU32>>,
    /// 64-bit cells indexed by address.
    u64_values: RwLock<HashMap<usize, AtomicU64>>,
}

impl Default for MockAtomics {
    fn default() -> Self {
        Self::new()
    }
}

impl MockAtomics {
    /// Creates an empty atomic store.
    pub fn new() -> Self {
        Self {
            u32_values: RwLock::new(HashMap::new()),
            u64_values: RwLock::new(HashMap::new()),
        }
    }

    /// Runs `op` on the 32-bit cell at `addr`, creating it as 0 if absent.
    fn with_u32<R>(&self, addr: usize, op: impl FnOnce(&AtomicU32) -> R) -> R {
        let mut cells = self
            .u32_values
            .write()
            .expect("MockAtomics u32 RwLock poisoned");
        op(cells.entry(addr).or_default())
    }

    /// Runs `op` on the 64-bit cell at `addr`, creating it as 0 if absent.
    fn with_u64<R>(&self, addr: usize, op: impl FnOnce(&AtomicU64) -> R) -> R {
        let mut cells = self
            .u64_values
            .write()
            .expect("MockAtomics u64 RwLock poisoned");
        op(cells.entry(addr).or_default())
    }

    /// Adds `val` to the 32-bit cell; returns the previous value.
    pub fn atomic_add_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |cell| cell.fetch_add(val, Ordering::SeqCst))
    }

    /// Adds `val` to the 64-bit cell; returns the previous value.
    pub fn atomic_add_u64(&self, addr: usize, val: u64) -> u64 {
        self.with_u64(addr, |cell| cell.fetch_add(val, Ordering::SeqCst))
    }

    /// Compare-and-swap on the 32-bit cell; returns the value observed
    /// before the operation regardless of success.
    pub fn atomic_cas_u32(&self, addr: usize, expected: u32, new: u32) -> u32 {
        self.with_u32(addr, |cell| {
            match cell.compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst) {
                Ok(prev) | Err(prev) => prev,
            }
        })
    }

    /// Stores max(current, val); returns the previous value.
    pub fn atomic_max_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |cell| cell.fetch_max(val, Ordering::SeqCst))
    }

    /// Stores min(current, val); returns the previous value.
    pub fn atomic_min_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |cell| cell.fetch_min(val, Ordering::SeqCst))
    }

    /// Reads the 32-bit cell; untouched addresses read as 0.
    pub fn load_u32(&self, addr: usize) -> u32 {
        let cells = self
            .u32_values
            .read()
            .expect("MockAtomics u32 RwLock poisoned");
        cells.get(&addr).map_or(0, |cell| cell.load(Ordering::SeqCst))
    }

    /// Overwrites the 32-bit cell with `val`.
    pub fn store_u32(&self, addr: usize, val: u32) {
        self.with_u32(addr, |cell| cell.store(val, Ordering::SeqCst));
    }
}
447
448pub struct MockGpu {
454 config: MockKernelConfig,
455 atomics: Arc<MockAtomics>,
456}
457
458impl MockGpu {
459 pub fn new(config: MockKernelConfig) -> Self {
461 Self {
462 config,
463 atomics: Arc::new(MockAtomics::new()),
464 }
465 }
466
467 pub fn config(&self) -> &MockKernelConfig {
469 &self.config
470 }
471
472 pub fn atomics(&self) -> &MockAtomics {
474 &self.atomics
475 }
476
477 pub fn dispatch<F>(&self, kernel: F)
482 where
483 F: Fn(&MockThread),
484 {
485 for bz in 0..self.config.grid_dim.2 {
486 for by in 0..self.config.grid_dim.1 {
487 for bx in 0..self.config.grid_dim.0 {
488 for tz in 0..self.config.block_dim.2 {
489 for ty in 0..self.config.block_dim.1 {
490 for tx in 0..self.config.block_dim.0 {
491 let thread =
492 MockThread::new((tx, ty, tz), (bx, by, bz), &self.config);
493 kernel(&thread);
494 }
495 }
496 }
497 }
498 }
499 }
500 }
501
502 pub fn dispatch_with_sync<F>(&self, kernel: F)
506 where
507 F: Fn(&MockThread, &Barrier) + Send + Sync,
508 {
509 let threads_per_block = self.config.threads_per_block() as usize;
510
511 for bz in 0..self.config.grid_dim.2 {
512 for by in 0..self.config.grid_dim.1 {
513 for bx in 0..self.config.grid_dim.0 {
514 let barrier = Arc::new(Barrier::new(threads_per_block));
516 std::thread::scope(|s| {
517 for tz in 0..self.config.block_dim.2 {
518 for ty in 0..self.config.block_dim.1 {
519 for tx in 0..self.config.block_dim.0 {
520 let barrier = Arc::clone(&barrier);
521 let config = &self.config;
522 let kernel_ref = &kernel;
523 s.spawn(move || {
524 let thread =
525 MockThread::new((tx, ty, tz), (bx, by, bz), config);
526 kernel_ref(&thread, &barrier);
527 });
528 }
529 }
530 }
531 });
532 }
533 }
534 }
535 }
536}
537
538pub struct MockWarp {
544 lane_values: Vec<u32>,
546 warp_size: u32,
548}
549
550impl MockWarp {
551 pub fn new(warp_size: u32) -> Self {
553 Self {
554 lane_values: vec![0; warp_size as usize],
555 warp_size,
556 }
557 }
558
559 pub fn set_lane(&mut self, lane: u32, value: u32) {
561 if (lane as usize) < self.lane_values.len() {
562 self.lane_values[lane as usize] = value;
563 }
564 }
565
566 pub fn shuffle(&self, src_lane: u32) -> u32 {
568 self.lane_values
569 .get(src_lane as usize)
570 .copied()
571 .unwrap_or(0)
572 }
573
574 pub fn shuffle_xor(&self, lane_id: u32, mask: u32) -> u32 {
576 let src = lane_id ^ mask;
577 self.shuffle(src)
578 }
579
580 pub fn shuffle_up(&self, lane_id: u32, delta: u32) -> u32 {
582 if lane_id >= delta {
583 self.shuffle(lane_id - delta)
584 } else {
585 self.lane_values[lane_id as usize]
586 }
587 }
588
589 pub fn shuffle_down(&self, lane_id: u32, delta: u32) -> u32 {
591 if lane_id + delta < self.warp_size {
592 self.shuffle(lane_id + delta)
593 } else {
594 self.lane_values[lane_id as usize]
595 }
596 }
597
598 pub fn ballot(&self, predicate: impl Fn(u32) -> bool) -> u64 {
600 let mut result = 0u64;
601 for lane in 0..self.warp_size {
602 if predicate(lane) {
603 result |= 1 << lane;
604 }
605 }
606 result
607 }
608
609 pub fn any(&self, predicate: impl Fn(u32) -> bool) -> bool {
611 (0..self.warp_size).any(predicate)
612 }
613
614 pub fn all(&self, predicate: impl Fn(u32) -> bool) -> bool {
616 (0..self.warp_size).all(predicate)
617 }
618
619 pub fn reduce_sum(&self) -> u32 {
621 self.lane_values.iter().sum()
622 }
623
624 pub fn prefix_sum_exclusive(&self) -> Vec<u32> {
626 let mut result = Vec::with_capacity(self.warp_size as usize);
627 let mut sum = 0;
628 for &v in &self.lane_values {
629 result.push(sum);
630 sum += v;
631 }
632 result
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods and the derived thread/block counts.
    #[test]
    fn test_mock_config() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(4, 4, 1)
            .with_block_size(32, 8, 1);

        assert_eq!(cfg.total_blocks(), 16);
        assert_eq!(cfg.threads_per_block(), 256);
        assert_eq!(cfg.total_threads(), 4096);
    }

    /// Per-thread index accessors and global coordinates.
    #[test]
    fn test_mock_thread_intrinsics() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(16, 16, 1);
        let thread = MockThread::new((5, 3, 0), (1, 0, 0), &cfg);

        assert_eq!(thread.thread_idx_x(), 5);
        assert_eq!(thread.thread_idx_y(), 3);
        assert_eq!(thread.block_idx_x(), 1);
        assert_eq!(thread.block_dim_x(), 16);
        assert_eq!(thread.global_x(), 21);
        assert_eq!(thread.global_y(), 3);
    }

    /// Typed scalar and slice round-trips through mock shared memory.
    #[test]
    fn test_mock_shared_memory() {
        let mem = MockSharedMemory::new(1024);

        mem.write::<f32>(0, 3.125);
        mem.write::<f32>(4, 2.75);
        assert!((mem.read::<f32>(0) - 3.125).abs() < 0.001);
        assert!((mem.read::<f32>(4) - 2.75).abs() < 0.001);

        mem.write_slice::<u32>(100, &[1, 2, 3, 4]);
        assert_eq!(mem.as_slice::<u32>(100, 4), vec![1, 2, 3, 4]);
    }

    /// fetch_add returns the old value; load sees the accumulated sum.
    #[test]
    fn test_mock_atomics() {
        let atomics = MockAtomics::new();

        assert_eq!(atomics.atomic_add_u32(0, 5), 0);
        assert_eq!(atomics.atomic_add_u32(0, 3), 5);
        assert_eq!(atomics.load_u32(0), 8);
    }

    /// Sequential dispatch invokes the kernel once per logical thread.
    #[test]
    fn test_mock_gpu_dispatch() {
        let gpu = MockGpu::new(
            MockKernelConfig::new()
                .with_grid_size(2, 1, 1)
                .with_block_size(4, 1, 1),
        );
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);

        gpu.dispatch(move |_thread| {
            c.fetch_add(1, Ordering::SeqCst);
        });

        // 2 blocks * 4 threads each.
        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }

    /// Direct and xor-based shuffles read the expected source lanes.
    #[test]
    fn test_mock_warp_shuffle() {
        let mut warp = MockWarp::new(32);
        for lane in 0..32 {
            warp.set_lane(lane, lane * 2);
        }

        assert_eq!(warp.shuffle(5), 10);
        assert_eq!(warp.shuffle(15), 30);
        assert_eq!(warp.shuffle_xor(0, 1), 2);
        assert_eq!(warp.shuffle_xor(2, 1), 6);
    }

    /// Ballot sets one mask bit per lane passing the predicate.
    #[test]
    fn test_mock_warp_ballot() {
        let warp = MockWarp::new(32);
        let even_lanes = warp.ballot(|lane| lane % 2 == 0);
        assert_eq!(even_lanes, 0x55555555);
    }

    /// Warp-wide sum and exclusive prefix sum.
    #[test]
    fn test_mock_warp_reduce() {
        let mut warp = MockWarp::new(4);
        for (lane, value) in [(0, 1), (1, 2), (2, 3), (3, 4)] {
            warp.set_lane(lane, value);
        }

        assert_eq!(warp.reduce_sum(), 10);
        assert_eq!(warp.prefix_sum_exclusive(), vec![0, 1, 3, 6]);
    }

    /// global_id linearizes block-major then thread-major.
    #[test]
    fn test_thread_global_id() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(4, 4, 1);

        // First thread of first block.
        assert_eq!(MockThread::new((0, 0, 0), (0, 0, 0), &cfg).global_id(), 0);
        // First thread of second block: one full 16-thread block ahead.
        assert_eq!(MockThread::new((0, 0, 0), (1, 0, 0), &cfg).global_id(), 16);
        // Last thread of first block.
        assert_eq!(MockThread::new((3, 3, 0), (0, 0, 0), &cfg).global_id(), 15);
    }
}