Skip to main content

cuda_rust_wasm/kernel/
shared_memory.rs

1//! Shared memory management for CUDA kernel emulation
2//!
3//! Emulates CUDA shared memory (`__shared__`) on the CPU, including:
4//! - Static shared memory allocation (known size at compile time)
5//! - Dynamic (extern) shared memory allocation (size provided at launch)
6//! - Bank conflict detection for profiling/debugging
7//!
8//! In CUDA, shared memory is per-block SRAM accessible by all threads in a
9//! block. On the CPU we emulate this with heap-allocated buffers shared among
10//! threads in the same block.
11
12use std::alloc::{self, Layout};
13use std::marker::PhantomData;
14use std::ptr::NonNull;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// Count of shared-memory banks; NVIDIA GPUs expose 32 of them.
pub const NUM_BANKS: usize = 32;

/// Width of one bank in bytes: a single 32-bit word, as on CUDA hardware.
pub const BANK_WIDTH_BYTES: usize = std::mem::size_of::<u32>();
22
/// Static shared memory allocation.
///
/// Represents a fixed-size shared memory buffer that is known at compile time.
/// Analogous to `__shared__ T data[N]` in CUDA.
///
/// Every element starts out as all-zero bytes (the buffer comes from
/// [`std::alloc::alloc_zeroed`]), so `T` should be a type for which the
/// all-zero bit pattern is a valid value: integers, floats, raw pointers,
/// `Option<Box<_>>` and similar. The elements' destructors run when the
/// buffer is dropped.
///
/// # Type Parameters
/// - `T`: Element type (must be `Send + Sync` since shared across threads)
pub struct SharedMemory<T: Send + Sync> {
    /// Pointer to the first of `len` elements (dangling if `T` is zero-sized).
    ptr: NonNull<T>,
    /// Number of elements (always > 0 after construction).
    len: usize,
    /// Marks that this type logically owns values of type `T`.
    _marker: PhantomData<T>,
}

// SAFETY: `SharedMemory<T>` owns its elements much like a `Box<[T]>`: shared
// access only hands out `&T`, and mutable access requires `&mut self`, so it
// can be sent and shared across threads whenever `T` is `Send + Sync`.
unsafe impl<T: Send + Sync> Send for SharedMemory<T> {}
unsafe impl<T: Send + Sync> Sync for SharedMemory<T> {}

impl<T: Send + Sync> SharedMemory<T> {
    /// Allocate a new shared memory buffer with `count` elements, all zeroed.
    ///
    /// # Panics
    /// Panics if `count` is 0, if `count * size_of::<T>()` overflows, or if
    /// the allocation fails.
    pub fn new(count: usize) -> Self {
        assert!(count > 0, "SharedMemory: count must be > 0");
        let layout = Layout::array::<T>(count).expect("SharedMemory: layout overflow");

        let ptr = if layout.size() > 0 {
            // SAFETY: `layout` has a non-zero size, as `alloc_zeroed` requires.
            let raw = unsafe { alloc::alloc_zeroed(layout) };
            NonNull::new(raw.cast::<T>()).expect("SharedMemory: allocation failed")
        } else {
            // Zero-sized `T`: no memory is needed, but the pointer must still be
            // non-null and well aligned.
            NonNull::dangling()
        };

        Self {
            ptr,
            len: count,
            _marker: PhantomData,
        }
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns true if the buffer is empty (always false after construction,
    /// because [`new`](Self::new) rejects a zero count).
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get a reference to the element at `index`.
    ///
    /// # Panics
    /// Panics if `index >= len`.
    pub fn get(&self, index: usize) -> &T {
        assert!(
            index < self.len,
            "SharedMemory: index {index} out of bounds (len={})",
            self.len
        );
        // SAFETY: `index < len`, so the pointer stays inside the allocation and
        // refers to an initialised element borrowed no longer than `self`.
        unsafe { &*self.ptr.as_ptr().add(index) }
    }

    /// Get a mutable reference to the element at `index`.
    ///
    /// Exclusive access is guaranteed statically by the `&mut self` receiver,
    /// so the caller needs no additional synchronisation.
    ///
    /// # Panics
    /// Panics if `index >= len`.
    pub fn get_mut(&mut self, index: usize) -> &mut T {
        assert!(
            index < self.len,
            "SharedMemory: index {index} out of bounds (len={})",
            self.len
        );
        // SAFETY: bounds checked above; `&mut self` guarantees uniqueness.
        unsafe { &mut *self.ptr.as_ptr().add(index) }
    }

    /// Get the raw pointer to the underlying buffer.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr() as *const T
    }

    /// Get a mutable raw pointer to the underlying buffer.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Get a slice view of the shared memory.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `ptr` is valid and aligned for `len` initialised elements.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const T, self.len) }
    }

    /// Get a mutable slice view of the shared memory.
    ///
    /// Exclusive access is guaranteed by the `&mut self` receiver.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: as in `as_slice`, plus `&mut self` guarantees uniqueness.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }
}

impl<T: Send + Sync> Drop for SharedMemory<T> {
    fn drop(&mut self) {
        // SAFETY: `ptr` points to `len` initialised elements owned by this
        // buffer, which nothing can observe once `drop` returns. Running their
        // destructors releases resources they hold (e.g. a `Box` stored via
        // `get_mut`) that would otherwise leak.
        unsafe {
            std::ptr::drop_in_place(std::ptr::slice_from_raw_parts_mut(
                self.ptr.as_ptr(),
                self.len,
            ));
        }

        let layout =
            Layout::array::<T>(self.len).expect("SharedMemory::drop: layout overflow");
        if layout.size() > 0 {
            // SAFETY: the buffer was allocated in `new` with exactly this layout
            // (same `T`, same `len`); zero-sized layouts were never allocated.
            unsafe { alloc::dealloc(self.ptr.as_ptr().cast::<u8>(), layout) };
        }
    }
}
136
137// ---------------------------------------------------------------------------
138// Dynamic (extern) shared memory
139// ---------------------------------------------------------------------------
140
/// Dynamic shared memory allocation.
///
/// Represents a shared memory buffer whose size is determined at kernel launch
/// time. Analogous to `extern __shared__ T data[]` in CUDA where the size is
/// passed as a launch parameter.
///
/// The buffer is untyped (byte-level), zero-initialised and aligned to
/// [`DynamicSharedMemory::ALIGN`] bytes; callers can reinterpret it as needed
/// via [`as_typed_slice`](Self::as_typed_slice) and
/// [`as_typed_slice_mut`](Self::as_typed_slice_mut).
pub struct DynamicSharedMemory {
    /// Start of the zero-initialised byte buffer.
    ptr: NonNull<u8>,
    /// Size and alignment the buffer was allocated with; `Drop` hands this
    /// exact layout back to the allocator.
    layout: Layout,
}

// SAFETY: the buffer is owned exclusively by this value. Shared access only
// yields `&[T]` / `*const u8`, and mutable access requires `&mut self`.
unsafe impl Send for DynamicSharedMemory {}
unsafe impl Sync for DynamicSharedMemory {}

impl DynamicSharedMemory {
    /// Alignment, in bytes, of every buffer (16 bytes, for SIMD compatibility).
    pub const ALIGN: usize = 16;

    /// Allocate dynamic shared memory of the given size in bytes.
    ///
    /// # Panics
    /// Panics if `size_bytes` is 0, if it is too large to form a valid
    /// [`Layout`], or if allocation fails.
    pub fn new(size_bytes: usize) -> Self {
        assert!(size_bytes > 0, "DynamicSharedMemory: size must be > 0");

        let layout = Layout::from_size_align(size_bytes, Self::ALIGN)
            .expect("DynamicSharedMemory: invalid layout");

        // SAFETY: `layout` has a non-zero size (asserted above).
        let raw = unsafe { alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(raw).expect("DynamicSharedMemory: allocation failed");

        Self { ptr, layout }
    }

    /// Returns the size of the buffer in bytes.
    pub fn size_bytes(&self) -> usize {
        self.layout.size()
    }

    /// Checks that the buffer can be viewed as a `[T]` and returns the number
    /// of `T` elements it holds. Shared by both typed accessors so they
    /// enforce identical rules and panic messages.
    fn typed_len<T>(&self) -> usize {
        let elem_size = std::mem::size_of::<T>();
        assert!(elem_size > 0, "DynamicSharedMemory: zero-sized type");
        assert!(
            self.size_bytes() % elem_size == 0,
            "DynamicSharedMemory: size {} not a multiple of element size {}",
            self.size_bytes(),
            elem_size
        );
        assert!(
            self.ptr.as_ptr() as usize % std::mem::align_of::<T>() == 0,
            "DynamicSharedMemory: alignment mismatch for type"
        );
        self.size_bytes() / elem_size
    }

    /// Reinterpret the buffer as a typed slice of `T`.
    ///
    /// The elements are the buffer's raw bytes read as `T`, so `T` should be a
    /// plain-old-data type for which every bit pattern is valid (integers,
    /// floats, or padding-free arrays / `#[repr(C)]` structs of them).
    /// NOTE(review): this is a safe fn, yet viewing the bytes as e.g. `bool`,
    /// `char` or a reference can be undefined behaviour — consider making it
    /// `unsafe` or bounding `T` in a future breaking release.
    ///
    /// # Panics
    /// Panics if `T` is zero-sized, if the buffer size is not a multiple of
    /// `size_of::<T>()`, or if the alignment is insufficient for `T`.
    pub fn as_typed_slice<T>(&self) -> &[T] {
        let count = self.typed_len::<T>();
        // SAFETY: `typed_len` verified size and alignment for `count` elements;
        // the bytes were initialised (zeroed) at allocation and are borrowed no
        // longer than `self`.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::<T>(), count) }
    }

    /// Reinterpret the buffer as a mutable typed slice of `T`.
    ///
    /// Exclusive access is guaranteed by the `&mut self` receiver. The same
    /// plain-old-data caveat as [`as_typed_slice`](Self::as_typed_slice)
    /// applies, and writing a `T` that contains padding leaves those bytes
    /// uninitialised for later byte-level readers.
    ///
    /// # Panics
    /// Panics under the same conditions as
    /// [`as_typed_slice`](Self::as_typed_slice).
    pub fn as_typed_slice_mut<T>(&mut self) -> &mut [T] {
        let count = self.typed_len::<T>();
        // SAFETY: as in `as_typed_slice`, plus `&mut self` guarantees uniqueness.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr().cast::<T>(), count) }
    }

    /// Get the raw byte pointer.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr() as *const u8
    }

    /// Get a mutable raw byte pointer.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}

impl Drop for DynamicSharedMemory {
    fn drop(&mut self) {
        // SAFETY: `ptr` was returned by `alloc_zeroed` in `new` for exactly
        // `self.layout` and is released only here.
        unsafe { alloc::dealloc(self.ptr.as_ptr(), self.layout) };
    }
}
247
248// ---------------------------------------------------------------------------
249// Bank conflict detection (profiling)
250// ---------------------------------------------------------------------------
251
252/// Tracks shared memory access patterns to detect bank conflicts.
253///
254/// In CUDA, shared memory is divided into banks. Simultaneous accesses to the
255/// same bank by different threads cause serialisation (bank conflicts). This
256/// profiler counts such conflicts to help developers optimise access patterns.
257pub struct BankConflictDetector {
258    /// Total accesses recorded
259    total_accesses: AtomicUsize,
260    /// Number of bank conflicts detected
261    conflict_count: AtomicUsize,
262    /// Per-bank access counters for the current "cycle"
263    bank_accesses: [AtomicUsize; NUM_BANKS],
264}
265
266impl BankConflictDetector {
267    /// Create a new bank conflict detector.
268    pub fn new() -> Self {
269        const INIT: AtomicUsize = AtomicUsize::new(0);
270        Self {
271            total_accesses: AtomicUsize::new(0),
272            conflict_count: AtomicUsize::new(0),
273            bank_accesses: [INIT; NUM_BANKS],
274        }
275    }
276
277    /// Record an access to a shared memory address.
278    ///
279    /// Computes which bank the byte address maps to and counts conflicts
280    /// when multiple threads in the same warp access the same bank in one
281    /// cycle (represented by a batch of `record_access` calls between
282    /// `begin_cycle` / `end_cycle`).
283    ///
284    /// # Arguments
285    /// * `byte_address` - The byte offset into shared memory
286    pub fn record_access(&self, byte_address: usize) {
287        let bank = Self::address_to_bank(byte_address);
288        let prev = self.bank_accesses[bank].fetch_add(1, Ordering::Relaxed);
289        self.total_accesses.fetch_add(1, Ordering::Relaxed);
290
291        // If this bank was already accessed in the current cycle, it is a conflict
292        if prev > 0 {
293            self.conflict_count.fetch_add(1, Ordering::Relaxed);
294        }
295    }
296
297    /// Begin a new access cycle (e.g., a new warp instruction).
298    /// Resets the per-bank counters.
299    pub fn begin_cycle(&self) {
300        for bank in &self.bank_accesses {
301            bank.store(0, Ordering::Relaxed);
302        }
303    }
304
305    /// Compute which bank a byte address maps to.
306    ///
307    /// Bank index = `(byte_address / BANK_WIDTH_BYTES) % NUM_BANKS`
308    pub fn address_to_bank(byte_address: usize) -> usize {
309        (byte_address / BANK_WIDTH_BYTES) % NUM_BANKS
310    }
311
312    /// Get the total number of accesses recorded.
313    pub fn total_accesses(&self) -> usize {
314        self.total_accesses.load(Ordering::Relaxed)
315    }
316
317    /// Get the number of bank conflicts detected.
318    pub fn conflict_count(&self) -> usize {
319        self.conflict_count.load(Ordering::Relaxed)
320    }
321
322    /// Get the conflict rate (conflicts / total accesses).
323    /// Returns 0.0 if no accesses have been recorded.
324    pub fn conflict_rate(&self) -> f64 {
325        let total = self.total_accesses() as f64;
326        if total == 0.0 {
327            0.0
328        } else {
329            self.conflict_count() as f64 / total
330        }
331    }
332
333    /// Reset all counters.
334    pub fn reset(&self) {
335        self.total_accesses.store(0, Ordering::Relaxed);
336        self.conflict_count.store(0, Ordering::Relaxed);
337        for bank in &self.bank_accesses {
338            bank.store(0, Ordering::Relaxed);
339        }
340    }
341
342    /// Returns a human-readable summary of bank conflict statistics.
343    pub fn summary(&self) -> String {
344        format!(
345            "Bank conflicts: {} / {} accesses ({:.1}% conflict rate)",
346            self.conflict_count(),
347            self.total_accesses(),
348            self.conflict_rate() * 100.0,
349        )
350    }
351}
352
353impl Default for BankConflictDetector {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359// ---------------------------------------------------------------------------
360// Tests
361// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_static_shared_memory_new() {
        let smem: SharedMemory<f32> = SharedMemory::new(256);
        assert_eq!(smem.len(), 256);
        assert!(!smem.is_empty());
    }

    #[test]
    #[should_panic(expected = "count must be > 0")]
    fn test_static_shared_memory_zero_count() {
        let _ = SharedMemory::<u32>::new(0);
    }

    #[test]
    fn test_static_shared_memory_read_write() {
        let mut smem: SharedMemory<i32> = SharedMemory::new(16);
        *smem.get_mut(0) = 42;
        *smem.get_mut(15) = 99;
        assert_eq!(*smem.get(0), 42);
        assert_eq!(*smem.get(15), 99);
        // Zeroed elements
        assert_eq!(*smem.get(1), 0);
    }

    #[test]
    fn test_static_shared_memory_slice() {
        let mut smem: SharedMemory<f32> = SharedMemory::new(8);
        {
            let slice = smem.as_mut_slice();
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 2.0;
            }
        }
        let slice = smem.as_slice();
        assert!((slice[3] - 6.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "index 16 out of bounds")]
    fn test_static_shared_memory_out_of_bounds() {
        let smem: SharedMemory<u32> = SharedMemory::new(16);
        let _ = smem.get(16);
    }

    #[test]
    fn test_dynamic_shared_memory_new() {
        let dsmem = DynamicSharedMemory::new(1024);
        assert_eq!(dsmem.size_bytes(), 1024);
    }

    #[test]
    fn test_dynamic_shared_memory_alignment() {
        let dsmem = DynamicSharedMemory::new(24);
        assert_eq!(dsmem.as_ptr() as usize % 16, 0);
    }

    #[test]
    fn test_dynamic_shared_memory_typed_access() {
        let mut dsmem = DynamicSharedMemory::new(64); // 16 f32s

        {
            let slice: &mut [f32] = dsmem.as_typed_slice_mut();
            assert_eq!(slice.len(), 16);
            // Exactly representable values that do not resemble well-known
            // constants (clippy's deny-by-default `approx_constant` flags 3.14).
            slice[0] = 1.5;
            slice[15] = -2.25;
        }

        let slice: &[f32] = dsmem.as_typed_slice();
        assert!((slice[0] - 1.5).abs() < 1e-6);
        assert!((slice[15] + 2.25).abs() < 1e-6);
        // Untouched elements remain zero-initialised.
        assert!(slice[1].abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "not a multiple of element size")]
    fn test_dynamic_shared_memory_size_mismatch() {
        let dsmem = DynamicSharedMemory::new(6);
        let _: &[u32] = dsmem.as_typed_slice();
    }

    #[test]
    #[should_panic(expected = "size must be > 0")]
    fn test_dynamic_shared_memory_zero_size() {
        let _ = DynamicSharedMemory::new(0);
    }

    #[test]
    fn test_bank_address_mapping() {
        // Address 0 -> bank 0
        assert_eq!(BankConflictDetector::address_to_bank(0), 0);
        // Address 4 -> bank 1
        assert_eq!(BankConflictDetector::address_to_bank(4), 1);
        // Address 128 -> bank 0 (128 / 4 = 32 % 32 = 0)
        assert_eq!(BankConflictDetector::address_to_bank(128), 0);
        // Address 132 -> bank 1
        assert_eq!(BankConflictDetector::address_to_bank(132), 1);
    }

    #[test]
    fn test_no_bank_conflicts() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        // Each access goes to a different bank: addresses 0, 4, 8, 12, ...
        for i in 0..32 {
            detector.record_access(i * 4);
        }

        assert_eq!(detector.total_accesses(), 32);
        assert_eq!(detector.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflicts_detected() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        // Two accesses to the same bank (bank 0): address 0 and address 128
        detector.record_access(0);
        detector.record_access(128);

        assert_eq!(detector.total_accesses(), 2);
        assert_eq!(detector.conflict_count(), 1);
    }

    #[test]
    fn test_begin_cycle_clears_bank_state() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();
        detector.record_access(0);

        // Bank 0 again, but in a new cycle: not a conflict.
        detector.begin_cycle();
        detector.record_access(128);

        assert_eq!(detector.total_accesses(), 2);
        assert_eq!(detector.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflict_rate() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        // 4 accesses, 2 conflicts (same bank hit 3 times -> 2 conflicts)
        detector.record_access(0); // bank 0, first
        detector.record_access(128); // bank 0, conflict
        detector.record_access(256); // bank 0, conflict
        detector.record_access(4); // bank 1, first

        assert_eq!(detector.total_accesses(), 4);
        assert_eq!(detector.conflict_count(), 2);
        assert!((detector.conflict_rate() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_bank_conflict_reset() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();
        detector.record_access(0);
        detector.record_access(128);

        detector.reset();
        assert_eq!(detector.total_accesses(), 0);
        assert_eq!(detector.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflict_summary() {
        let detector = BankConflictDetector::new();
        let summary = detector.summary();
        assert!(summary.contains("Bank conflicts"));
    }
}