// ringkernel_cpu/mock.rs

//! GPU Mock Testing Utilities
//!
//! This module provides utilities for mocking GPU behavior in CPU tests.
//! It simulates GPU intrinsics, thread organization, and memory patterns.
//!
//! # Example
//!
//! ```rust
//! use ringkernel_cpu::mock::{MockGpu, MockThread, MockKernelConfig};
//!
//! // Configure a mock kernel launch
//! let config = MockKernelConfig::new()
//!     .with_grid_size(4, 4, 1)
//!     .with_block_size(32, 8, 1);
//!
//! // Create mock GPU context
//! let gpu = MockGpu::new(config);
//!
//! // Execute kernel with mock threads
//! gpu.dispatch(|thread| {
//!     let gid = thread.global_id();
//!     // Kernel code here
//! });
//! ```

use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Barrier, RwLock};

// ============================================================================
// MOCK KERNEL CONFIGURATION
// ============================================================================

/// Configuration describing a mock kernel launch: grid/block geometry,
/// shared-memory budget, and optional warp simulation.
#[derive(Debug, Clone)]
pub struct MockKernelConfig {
    /// Grid dimensions (number of blocks).
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions (threads per block).
    pub block_dim: (u32, u32, u32),
    /// Shared memory size in bytes.
    pub shared_memory_size: usize,
    /// Whether to simulate warp execution.
    pub simulate_warps: bool,
    /// Warp size (typically 32 for NVIDIA, 64 for AMD).
    pub warp_size: u32,
}

impl Default for MockKernelConfig {
    /// One block of 256 threads, 48 KB of shared memory, warps not simulated.
    fn default() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_memory_size: 48 * 1024, // 48KB default
            simulate_warps: false,
            warp_size: 32,
        }
    }
}

impl MockKernelConfig {
    /// Create a configuration with the default launch parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set the grid dimensions (blocks per grid).
    pub fn with_grid_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    /// Builder: set the block dimensions (threads per block).
    pub fn with_block_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    /// Builder: set the shared-memory budget in bytes.
    pub fn with_shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory_size = bytes;
        self
    }

    /// Builder: enable warp simulation with the given warp size.
    pub fn with_warp_simulation(mut self, warp_size: u32) -> Self {
        self.simulate_warps = true;
        self.warp_size = warp_size;
        self
    }

    /// Total number of threads across the whole launch.
    ///
    /// Each dimension is widened to `u64` before multiplying, so the product
    /// of all six extents cannot overflow intermediate `u32` math.
    pub fn total_threads(&self) -> u64 {
        let (gx, gy, gz) = self.grid_dim;
        let (bx, by, bz) = self.block_dim;
        [gx, gy, gz, bx, by, bz].iter().map(|&d| u64::from(d)).product()
    }

    /// Number of threads in a single block.
    pub fn threads_per_block(&self) -> u32 {
        let (x, y, z) = self.block_dim;
        x * y * z
    }

    /// Number of blocks in the grid.
    pub fn total_blocks(&self) -> u32 {
        let (x, y, z) = self.grid_dim;
        x * y * z
    }
}

112// ============================================================================
113// MOCK THREAD CONTEXT
114// ============================================================================
115
116/// Mock thread context providing GPU intrinsics.
117#[derive(Debug, Clone)]
118pub struct MockThread {
119    /// Thread index within block (x, y, z).
120    pub thread_idx: (u32, u32, u32),
121    /// Block index within grid (x, y, z).
122    pub block_idx: (u32, u32, u32),
123    /// Block dimensions.
124    pub block_dim: (u32, u32, u32),
125    /// Grid dimensions.
126    pub grid_dim: (u32, u32, u32),
127    /// Warp ID (within block).
128    pub warp_id: u32,
129    /// Lane ID (within warp).
130    pub lane_id: u32,
131    /// Warp size.
132    pub warp_size: u32,
133}
134
135impl MockThread {
136    /// Create a new mock thread.
137    pub fn new(
138        thread_idx: (u32, u32, u32),
139        block_idx: (u32, u32, u32),
140        config: &MockKernelConfig,
141    ) -> Self {
142        let linear_tid = thread_idx.0
143            + thread_idx.1 * config.block_dim.0
144            + thread_idx.2 * config.block_dim.0 * config.block_dim.1;
145
146        Self {
147            thread_idx,
148            block_idx,
149            block_dim: config.block_dim,
150            grid_dim: config.grid_dim,
151            warp_id: linear_tid / config.warp_size,
152            lane_id: linear_tid % config.warp_size,
153            warp_size: config.warp_size,
154        }
155    }
156
157    // ========================================================================
158    // GPU Intrinsics
159    // ========================================================================
160
161    /// Get thread index X.
162    #[inline]
163    pub fn thread_idx_x(&self) -> u32 {
164        self.thread_idx.0
165    }
166
167    /// Get thread index Y.
168    #[inline]
169    pub fn thread_idx_y(&self) -> u32 {
170        self.thread_idx.1
171    }
172
173    /// Get thread index Z.
174    #[inline]
175    pub fn thread_idx_z(&self) -> u32 {
176        self.thread_idx.2
177    }
178
179    /// Get block index X.
180    #[inline]
181    pub fn block_idx_x(&self) -> u32 {
182        self.block_idx.0
183    }
184
185    /// Get block index Y.
186    #[inline]
187    pub fn block_idx_y(&self) -> u32 {
188        self.block_idx.1
189    }
190
191    /// Get block index Z.
192    #[inline]
193    pub fn block_idx_z(&self) -> u32 {
194        self.block_idx.2
195    }
196
197    /// Get block dimension X.
198    #[inline]
199    pub fn block_dim_x(&self) -> u32 {
200        self.block_dim.0
201    }
202
203    /// Get block dimension Y.
204    #[inline]
205    pub fn block_dim_y(&self) -> u32 {
206        self.block_dim.1
207    }
208
209    /// Get block dimension Z.
210    #[inline]
211    pub fn block_dim_z(&self) -> u32 {
212        self.block_dim.2
213    }
214
215    /// Get grid dimension X.
216    #[inline]
217    pub fn grid_dim_x(&self) -> u32 {
218        self.grid_dim.0
219    }
220
221    /// Get grid dimension Y.
222    #[inline]
223    pub fn grid_dim_y(&self) -> u32 {
224        self.grid_dim.1
225    }
226
227    /// Get grid dimension Z.
228    #[inline]
229    pub fn grid_dim_z(&self) -> u32 {
230        self.grid_dim.2
231    }
232
233    /// Get global thread ID (1D linearized).
234    #[inline]
235    pub fn global_id(&self) -> u64 {
236        let block_linear = self.block_idx.0 as u64
237            + self.block_idx.1 as u64 * self.grid_dim.0 as u64
238            + self.block_idx.2 as u64 * self.grid_dim.0 as u64 * self.grid_dim.1 as u64;
239
240        let threads_per_block =
241            self.block_dim.0 as u64 * self.block_dim.1 as u64 * self.block_dim.2 as u64;
242        let thread_linear = self.thread_idx.0 as u64
243            + self.thread_idx.1 as u64 * self.block_dim.0 as u64
244            + self.thread_idx.2 as u64 * self.block_dim.0 as u64 * self.block_dim.1 as u64;
245
246        block_linear * threads_per_block + thread_linear
247    }
248
249    /// Get global X coordinate.
250    #[inline]
251    pub fn global_x(&self) -> u32 {
252        self.block_idx.0 * self.block_dim.0 + self.thread_idx.0
253    }
254
255    /// Get global Y coordinate.
256    #[inline]
257    pub fn global_y(&self) -> u32 {
258        self.block_idx.1 * self.block_dim.1 + self.thread_idx.1
259    }
260
261    /// Get global Z coordinate.
262    #[inline]
263    pub fn global_z(&self) -> u32 {
264        self.block_idx.2 * self.block_dim.2 + self.thread_idx.2
265    }
266
267    /// Check if this is the first thread in the block.
268    #[inline]
269    pub fn is_block_leader(&self) -> bool {
270        self.thread_idx == (0, 0, 0)
271    }
272
273    /// Check if this is the first thread in the warp.
274    #[inline]
275    pub fn is_warp_leader(&self) -> bool {
276        self.lane_id == 0
277    }
278}
279
// ============================================================================
// MOCK SHARED MEMORY
// ============================================================================

/// Mock shared memory for a block.
///
/// Backed by a zeroed byte buffer behind a `RefCell`, so a single-threaded
/// mock kernel can read and write typed values at arbitrary byte offsets
/// through `&self`. All typed accesses use unaligned reads/writes, because a
/// caller-supplied byte offset carries no alignment guarantee for `T`
/// (the previous `ptr::read`/`ptr::write` at unaligned offsets was UB).
pub struct MockSharedMemory {
    /// Raw backing bytes; `RefCell` provides interior mutability.
    data: RefCell<Vec<u8>>,
    /// Capacity in bytes.
    size: usize,
}

impl MockSharedMemory {
    /// Create new shared memory of `size` bytes, zero-initialized.
    pub fn new(size: usize) -> Self {
        Self {
            data: RefCell::new(vec![0u8; size]),
            size,
        }
    }

    /// Get size in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Panic unless the byte range `[offset, offset + len)` fits the buffer.
    /// Uses checked addition so a huge `offset` cannot wrap the comparison.
    fn check_bounds(&self, offset: usize, len: usize) {
        let end = offset
            .checked_add(len)
            .expect("shared memory range overflows usize");
        assert!(
            end <= self.size,
            "shared memory access out of bounds: offset {offset} + {len} bytes > size {}",
            self.size
        );
    }

    /// Read a `T` at a byte offset (the offset need not be aligned for `T`).
    ///
    /// # Panics
    /// Panics if the value does not fit within the buffer.
    pub fn read<T: Copy>(&self, offset: usize) -> T {
        let data = self.data.borrow();
        self.check_bounds(offset, std::mem::size_of::<T>());
        // SAFETY: the range was bounds-checked above; `read_unaligned` is
        // required because `offset` has no alignment guarantee for `T`.
        unsafe { std::ptr::read_unaligned(data.as_ptr().add(offset) as *const T) }
    }

    /// Write a `T` at a byte offset (the offset need not be aligned for `T`).
    ///
    /// # Panics
    /// Panics if the value does not fit within the buffer.
    pub fn write<T: Copy>(&self, offset: usize, value: T) {
        let mut data = self.data.borrow_mut();
        self.check_bounds(offset, std::mem::size_of::<T>());
        // SAFETY: the range was bounds-checked above; `write_unaligned` is
        // required because `offset` has no alignment guarantee for `T`.
        unsafe { std::ptr::write_unaligned(data.as_mut_ptr().add(offset) as *mut T, value) };
    }

    /// Copy `count` values of `T` starting at `offset` into a fresh `Vec`.
    ///
    /// # Panics
    /// Panics if the requested range does not fit within the buffer.
    pub fn as_slice<T: Copy>(&self, offset: usize, count: usize) -> Vec<T> {
        let data = self.data.borrow();
        let elem = std::mem::size_of::<T>();
        let byte_size = count.checked_mul(elem).expect("byte length overflows usize");
        self.check_bounds(offset, byte_size);

        let mut result = Vec::with_capacity(count);
        for i in 0..count {
            // SAFETY: element i lies inside the bounds-checked range;
            // `read_unaligned` tolerates arbitrary byte offsets.
            let v = unsafe {
                std::ptr::read_unaligned(data.as_ptr().add(offset + i * elem) as *const T)
            };
            result.push(v);
        }
        result
    }

    /// Write a slice of values starting at `offset`.
    ///
    /// # Panics
    /// Panics if the slice does not fit within the buffer.
    pub fn write_slice<T: Copy>(&self, offset: usize, values: &[T]) {
        let mut data = self.data.borrow_mut();
        let elem = std::mem::size_of::<T>();
        self.check_bounds(offset, std::mem::size_of_val(values));

        for (i, v) in values.iter().enumerate() {
            // SAFETY: element i lies inside the bounds-checked range;
            // `write_unaligned` tolerates arbitrary byte offsets.
            unsafe {
                std::ptr::write_unaligned(data.as_mut_ptr().add(offset + i * elem) as *mut T, *v);
            }
        }
    }
}

// ============================================================================
// MOCK ATOMICS
// ============================================================================

/// Mock atomic operations keyed by a fake "address".
///
/// Cells spring into existence holding zero the first time an address is
/// touched, mirroring zero-initialized device memory.
pub struct MockAtomics {
    u32_values: RwLock<HashMap<usize, AtomicU32>>,
    u64_values: RwLock<HashMap<usize, AtomicU64>>,
}

impl Default for MockAtomics {
    fn default() -> Self {
        Self::new()
    }
}

impl MockAtomics {
    /// Create new atomics storage.
    pub fn new() -> Self {
        Self {
            u32_values: RwLock::new(HashMap::new()),
            u64_values: RwLock::new(HashMap::new()),
        }
    }

    /// Run `op` against the u32 cell at `addr`, creating it as 0 if absent.
    fn with_u32<R>(&self, addr: usize, op: impl FnOnce(&AtomicU32) -> R) -> R {
        let mut map = self.u32_values.write().unwrap();
        op(map.entry(addr).or_insert_with(|| AtomicU32::new(0)))
    }

    /// Atomic add (u32); returns the previous value.
    pub fn atomic_add_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_add(val, Ordering::SeqCst))
    }

    /// Atomic add (u64); returns the previous value.
    pub fn atomic_add_u64(&self, addr: usize, val: u64) -> u64 {
        let mut map = self.u64_values.write().unwrap();
        map.entry(addr)
            .or_insert_with(|| AtomicU64::new(0))
            .fetch_add(val, Ordering::SeqCst)
    }

    /// Atomic compare-and-swap (u32); returns the value observed before the
    /// operation, whether or not the swap succeeded.
    pub fn atomic_cas_u32(&self, addr: usize, expected: u32, new: u32) -> u32 {
        self.with_u32(addr, |a| {
            match a.compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst) {
                Ok(prev) | Err(prev) => prev,
            }
        })
    }

    /// Atomic max (u32); returns the previous value.
    pub fn atomic_max_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_max(val, Ordering::SeqCst))
    }

    /// Atomic min (u32); returns the previous value.
    pub fn atomic_min_u32(&self, addr: usize, val: u32) -> u32 {
        self.with_u32(addr, |a| a.fetch_min(val, Ordering::SeqCst))
    }

    /// Load value (u32); addresses never written read as 0.
    pub fn load_u32(&self, addr: usize) -> u32 {
        self.u32_values
            .read()
            .unwrap()
            .get(&addr)
            .map_or(0, |a| a.load(Ordering::SeqCst))
    }

    /// Store value (u32).
    pub fn store_u32(&self, addr: usize, val: u32) {
        self.with_u32(addr, |a| a.store(val, Ordering::SeqCst));
    }
}

427// ============================================================================
428// MOCK GPU
429// ============================================================================
430
431/// Mock GPU for testing kernel execution.
432pub struct MockGpu {
433    config: MockKernelConfig,
434    atomics: Arc<MockAtomics>,
435}
436
437impl MockGpu {
438    /// Create a new mock GPU.
439    pub fn new(config: MockKernelConfig) -> Self {
440        Self {
441            config,
442            atomics: Arc::new(MockAtomics::new()),
443        }
444    }
445
446    /// Get configuration.
447    pub fn config(&self) -> &MockKernelConfig {
448        &self.config
449    }
450
451    /// Get atomics.
452    pub fn atomics(&self) -> &MockAtomics {
453        &self.atomics
454    }
455
456    /// Dispatch kernel execution sequentially.
457    ///
458    /// Executes the kernel function for each thread in order.
459    /// Useful for deterministic testing.
460    pub fn dispatch<F>(&self, kernel: F)
461    where
462        F: Fn(&MockThread),
463    {
464        for bz in 0..self.config.grid_dim.2 {
465            for by in 0..self.config.grid_dim.1 {
466                for bx in 0..self.config.grid_dim.0 {
467                    for tz in 0..self.config.block_dim.2 {
468                        for ty in 0..self.config.block_dim.1 {
469                            for tx in 0..self.config.block_dim.0 {
470                                let thread =
471                                    MockThread::new((tx, ty, tz), (bx, by, bz), &self.config);
472                                kernel(&thread);
473                            }
474                        }
475                    }
476                }
477            }
478        }
479    }
480
481    /// Dispatch with block synchronization.
482    ///
483    /// Provides a barrier for `sync_threads()` simulation within blocks.
484    pub fn dispatch_with_sync<F>(&self, kernel: F)
485    where
486        F: Fn(&MockThread, &Barrier) + Send + Sync,
487    {
488        let threads_per_block = self.config.threads_per_block() as usize;
489
490        for bz in 0..self.config.grid_dim.2 {
491            for by in 0..self.config.grid_dim.1 {
492                for bx in 0..self.config.grid_dim.0 {
493                    // Each block runs in parallel threads
494                    let barrier = Arc::new(Barrier::new(threads_per_block));
495                    std::thread::scope(|s| {
496                        for tz in 0..self.config.block_dim.2 {
497                            for ty in 0..self.config.block_dim.1 {
498                                for tx in 0..self.config.block_dim.0 {
499                                    let barrier = Arc::clone(&barrier);
500                                    let config = &self.config;
501                                    let kernel_ref = &kernel;
502                                    s.spawn(move || {
503                                        let thread =
504                                            MockThread::new((tx, ty, tz), (bx, by, bz), config);
505                                        kernel_ref(&thread, &barrier);
506                                    });
507                                }
508                            }
509                        }
510                    });
511                }
512            }
513        }
514    }
515}
516
517// ============================================================================
518// MOCK WARP OPERATIONS
519// ============================================================================
520
521/// Mock warp operations for testing warp-level primitives.
522pub struct MockWarp {
523    /// Lane values (up to 64 lanes for AMD).
524    lane_values: Vec<u32>,
525    /// Warp size.
526    warp_size: u32,
527}
528
529impl MockWarp {
530    /// Create a new mock warp.
531    pub fn new(warp_size: u32) -> Self {
532        Self {
533            lane_values: vec![0; warp_size as usize],
534            warp_size,
535        }
536    }
537
538    /// Set lane value.
539    pub fn set_lane(&mut self, lane: u32, value: u32) {
540        if (lane as usize) < self.lane_values.len() {
541            self.lane_values[lane as usize] = value;
542        }
543    }
544
545    /// Simulate warp shuffle.
546    pub fn shuffle(&self, src_lane: u32) -> u32 {
547        self.lane_values
548            .get(src_lane as usize)
549            .copied()
550            .unwrap_or(0)
551    }
552
553    /// Simulate warp shuffle XOR.
554    pub fn shuffle_xor(&self, lane_id: u32, mask: u32) -> u32 {
555        let src = lane_id ^ mask;
556        self.shuffle(src)
557    }
558
559    /// Simulate warp shuffle up.
560    pub fn shuffle_up(&self, lane_id: u32, delta: u32) -> u32 {
561        if lane_id >= delta {
562            self.shuffle(lane_id - delta)
563        } else {
564            self.lane_values[lane_id as usize]
565        }
566    }
567
568    /// Simulate warp shuffle down.
569    pub fn shuffle_down(&self, lane_id: u32, delta: u32) -> u32 {
570        if lane_id + delta < self.warp_size {
571            self.shuffle(lane_id + delta)
572        } else {
573            self.lane_values[lane_id as usize]
574        }
575    }
576
577    /// Simulate warp ballot.
578    pub fn ballot(&self, predicate: impl Fn(u32) -> bool) -> u64 {
579        let mut result = 0u64;
580        for lane in 0..self.warp_size {
581            if predicate(lane) {
582                result |= 1 << lane;
583            }
584        }
585        result
586    }
587
588    /// Simulate warp any.
589    pub fn any(&self, predicate: impl Fn(u32) -> bool) -> bool {
590        (0..self.warp_size).any(predicate)
591    }
592
593    /// Simulate warp all.
594    pub fn all(&self, predicate: impl Fn(u32) -> bool) -> bool {
595        (0..self.warp_size).all(predicate)
596    }
597
598    /// Simulate warp reduction (sum).
599    pub fn reduce_sum(&self) -> u32 {
600        self.lane_values.iter().sum()
601    }
602
603    /// Simulate warp prefix sum (exclusive).
604    pub fn prefix_sum_exclusive(&self) -> Vec<u32> {
605        let mut result = Vec::with_capacity(self.warp_size as usize);
606        let mut sum = 0;
607        for &v in &self.lane_values {
608            result.push(sum);
609            sum += v;
610        }
611        result
612    }
613}
614
// ============================================================================
// TESTS
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_config() {
        // 4x4 grid of 32x8 blocks.
        let cfg = MockKernelConfig::new()
            .with_grid_size(4, 4, 1)
            .with_block_size(32, 8, 1);

        assert_eq!(cfg.total_blocks(), 16);
        assert_eq!(cfg.threads_per_block(), 256);
        assert_eq!(cfg.total_threads(), 4096);
    }

    #[test]
    fn test_mock_thread_intrinsics() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(16, 16, 1);

        let t = MockThread::new((5, 3, 0), (1, 0, 0), &cfg);

        assert_eq!(t.thread_idx_x(), 5);
        assert_eq!(t.thread_idx_y(), 3);
        assert_eq!(t.block_idx_x(), 1);
        assert_eq!(t.block_dim_x(), 16);
        assert_eq!(t.global_x(), 21); // 1*16 + 5
        assert_eq!(t.global_y(), 3); // 0*16 + 3
    }

    #[test]
    fn test_mock_shared_memory() {
        let mem = MockSharedMemory::new(1024);

        mem.write::<f32>(0, 3.125);
        mem.write::<f32>(4, 2.75);

        assert!((mem.read::<f32>(0) - 3.125).abs() < 0.001);
        assert!((mem.read::<f32>(4) - 2.75).abs() < 0.001);

        mem.write_slice::<u32>(100, &[1, 2, 3, 4]);
        assert_eq!(mem.as_slice::<u32>(100, 4), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_mock_atomics() {
        let atomics = MockAtomics::new();

        // First add returns the implicit initial value (0).
        assert_eq!(atomics.atomic_add_u32(0, 5), 0);
        assert_eq!(atomics.atomic_add_u32(0, 3), 5);
        assert_eq!(atomics.load_u32(0), 8);
    }

    #[test]
    fn test_mock_gpu_dispatch() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 1, 1)
            .with_block_size(4, 1, 1);

        let gpu = MockGpu::new(cfg);
        let counter = AtomicU32::new(0);

        gpu.dispatch(|_thread| {
            counter.fetch_add(1, Ordering::SeqCst);
        });

        assert_eq!(counter.load(Ordering::SeqCst), 8); // 2 blocks * 4 threads
    }

    #[test]
    fn test_mock_warp_shuffle() {
        let mut warp = MockWarp::new(32);

        // Lane i holds 2*i.
        for lane in 0..32 {
            warp.set_lane(lane, lane * 2);
        }

        assert_eq!(warp.shuffle(5), 10);
        assert_eq!(warp.shuffle(15), 30);

        assert_eq!(warp.shuffle_xor(0, 1), 2); // lane 0 XOR 1 = lane 1 value
        assert_eq!(warp.shuffle_xor(2, 1), 6); // lane 2 XOR 1 = lane 3 value
    }

    #[test]
    fn test_mock_warp_ballot() {
        let warp = MockWarp::new(32);

        // Predicate true on every even lane -> alternating bit pattern.
        let ballot = warp.ballot(|lane| lane % 2 == 0);
        assert_eq!(ballot, 0x55555555);
    }

    #[test]
    fn test_mock_warp_reduce() {
        let mut warp = MockWarp::new(4);

        for (lane, value) in [(0, 1), (1, 2), (2, 3), (3, 4)] {
            warp.set_lane(lane, value);
        }

        assert_eq!(warp.reduce_sum(), 10);
        assert_eq!(warp.prefix_sum_exclusive(), vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_thread_global_id() {
        let cfg = MockKernelConfig::new()
            .with_grid_size(2, 2, 1)
            .with_block_size(4, 4, 1);

        // Thread (0,0) in block (0,0) -> global ID 0
        assert_eq!(MockThread::new((0, 0, 0), (0, 0, 0), &cfg).global_id(), 0);

        // Thread (0,0) in block (1,0) -> global ID 16 (one block worth)
        assert_eq!(MockThread::new((0, 0, 0), (1, 0, 0), &cfg).global_id(), 16);

        // Thread (3,3) in block (0,0) -> linear ID 15
        assert_eq!(MockThread::new((3, 3, 0), (0, 0, 0), &cfg).global_id(), 15);
    }
}