// cuda_rust_wasm/kernel/shared_memory.rs

//! Host-side emulation of CUDA shared memory: fixed-size and dynamically
//! sized block-local buffers, plus a detector for shared-memory bank
//! conflicts.

use std::alloc::{self, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Number of shared-memory banks (32 on current CUDA hardware).
pub const NUM_BANKS: usize = 32;

/// Width of each bank in bytes; successive 4-byte words map to successive banks.
pub const BANK_WIDTH_BYTES: usize = 4;

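/// A fixed-size, zero-initialized region standing in for statically declared
/// CUDA `__shared__` memory on the host side.
///
/// # Examples
///
/// A minimal usage sketch; the crate path is assumed from this file's
/// location:
///
/// ```
/// use cuda_rust_wasm::kernel::shared_memory::SharedMemory;
///
/// let mut tile: SharedMemory<f32> = SharedMemory::new(256);
/// *tile.get_mut(0) = 1.0;
/// assert_eq!(*tile.get(0), 1.0);
/// assert_eq!(tile.len(), 256);
/// assert_eq!(*tile.get(1), 0.0); // zero-initialized
/// ```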
pub struct SharedMemory<T: Send + Sync> {
    /// Pointer to the zero-initialized allocation (dangling for zero-sized `T`).
    ptr: NonNull<T>,
    /// Number of elements in the region.
    len: usize,
    _marker: PhantomData<T>,
}

// SAFETY: the region is uniquely owned, and `T: Send + Sync` bounds the payload.
unsafe impl<T: Send + Sync> Send for SharedMemory<T> {}
unsafe impl<T: Send + Sync> Sync for SharedMemory<T> {}

impl<T: Send + Sync> SharedMemory<T> {
    /// Allocates a zero-initialized region for `count` elements.
    ///
    /// Panics if `count` is zero, if the layout overflows, or if the
    /// allocation fails.
    pub fn new(count: usize) -> Self {
        assert!(count > 0, "SharedMemory: count must be > 0");
        let layout = Layout::array::<T>(count).expect("SharedMemory: layout overflow");

        let ptr = if layout.size() > 0 {
            // SAFETY: the layout has non-zero size.
            let raw = unsafe { alloc::alloc_zeroed(layout) };
            NonNull::new(raw as *mut T).expect("SharedMemory: allocation failed")
        } else {
            // Zero-sized `T`: no allocation is needed.
            NonNull::dangling()
        };

        Self {
            ptr,
            len: count,
            _marker: PhantomData,
        }
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the region holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a reference to the element at `index`, panicking on
    /// out-of-bounds access.
    pub fn get(&self, index: usize) -> &T {
        assert!(index < self.len, "SharedMemory: index {index} out of bounds (len={})", self.len);
        // SAFETY: `index` is in bounds and the allocation is live.
        unsafe { &*self.ptr.as_ptr().add(index) }
    }

    /// Returns a mutable reference to the element at `index`, panicking on
    /// out-of-bounds access.
    pub fn get_mut(&mut self, index: usize) -> &mut T {
        assert!(index < self.len, "SharedMemory: index {index} out of bounds (len={})", self.len);
        // SAFETY: `index` is in bounds and `&mut self` guarantees exclusivity.
        unsafe { &mut *self.ptr.as_ptr().add(index) }
    }

    /// Returns a raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr() as *const T
    }

    /// Returns a raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Views the whole region as a shared slice.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `ptr` is valid for `len` zero-initialized elements.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const T, self.len) }
    }

    /// Views the whole region as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: `ptr` is valid for `len` elements and `&mut self` is exclusive.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }
}

impl<T: Send + Sync> Drop for SharedMemory<T> {
    fn drop(&mut self) {
        if self.len > 0 {
            let layout = Layout::array::<T>(self.len)
                .expect("SharedMemory::drop: layout overflow");
            if layout.size() > 0 {
                // SAFETY: this layout matches the one used in `new`.
                unsafe {
                    alloc::dealloc(self.ptr.as_ptr() as *mut u8, layout);
                }
            }
        }
    }
}

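/// A byte-addressed, 16-byte-aligned region standing in for dynamically sized
/// CUDA shared memory (`extern __shared__`); it can be reinterpreted as a
/// typed slice.
///
/// # Examples
///
/// A minimal usage sketch; the crate path is assumed from this file's
/// location:
///
/// ```
/// use cuda_rust_wasm::kernel::shared_memory::DynamicSharedMemory;
///
/// let mut dsmem = DynamicSharedMemory::new(64); // 64 bytes = 16 f32 values
/// dsmem.as_typed_slice_mut::<f32>()[0] = 1.5;
/// assert_eq!(dsmem.as_typed_slice::<f32>()[0], 1.5);
/// ```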
pub struct DynamicSharedMemory {
    /// Pointer to the zero-initialized, 16-byte-aligned allocation.
    ptr: NonNull<u8>,
    /// Size of the region in bytes.
    size_bytes: usize,
}

// SAFETY: the region is uniquely owned raw bytes.
unsafe impl Send for DynamicSharedMemory {}
unsafe impl Sync for DynamicSharedMemory {}

impl DynamicSharedMemory {
    /// Allocates a zero-initialized region of `size_bytes` bytes, aligned to
    /// 16 bytes. Panics if `size_bytes` is zero or the allocation fails.
    pub fn new(size_bytes: usize) -> Self {
        assert!(size_bytes > 0, "DynamicSharedMemory: size must be > 0");

        let layout = Layout::from_size_align(size_bytes, 16)
            .expect("DynamicSharedMemory: invalid layout");

        // SAFETY: the layout has non-zero size.
        let ptr = unsafe { alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(ptr).expect("DynamicSharedMemory: allocation failed");

        Self { ptr, size_bytes }
    }

    /// Returns the size of the region in bytes.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes
    }

    /// Reinterprets the region as a shared slice of `T`, panicking if `T` is
    /// zero-sized, if the region size is not a multiple of `size_of::<T>()`,
    /// or if the buffer is not aligned for `T`.
    pub fn as_typed_slice<T>(&self) -> &[T] {
        let elem_size = std::mem::size_of::<T>();
        assert!(elem_size > 0, "DynamicSharedMemory: zero-sized type");
        assert!(
            self.size_bytes % elem_size == 0,
            "DynamicSharedMemory: size {} not a multiple of element size {}",
            self.size_bytes,
            elem_size
        );
        assert!(
            self.ptr.as_ptr() as usize % std::mem::align_of::<T>() == 0,
            "DynamicSharedMemory: alignment mismatch for type"
        );

        let count = self.size_bytes / elem_size;
        // SAFETY: size and alignment were checked above; the bytes are zeroed.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const T, count) }
    }

    /// Reinterprets the region as a mutable slice of `T`, with the same
    /// checks as [`as_typed_slice`](Self::as_typed_slice).
    pub fn as_typed_slice_mut<T>(&mut self) -> &mut [T] {
        let elem_size = std::mem::size_of::<T>();
        assert!(elem_size > 0, "DynamicSharedMemory: zero-sized type");
        assert!(
            self.size_bytes % elem_size == 0,
            "DynamicSharedMemory: size {} not a multiple of element size {}",
            self.size_bytes,
            elem_size
        );
        assert!(
            self.ptr.as_ptr() as usize % std::mem::align_of::<T>() == 0,
            "DynamicSharedMemory: alignment mismatch for type"
        );

        let count = self.size_bytes / elem_size;
        // SAFETY: size and alignment were checked above; `&mut self` is exclusive.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut T, count) }
    }

    /// Returns a raw const pointer to the first byte.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr() as *const u8
    }

    /// Returns a raw mutable pointer to the first byte.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}

impl Drop for DynamicSharedMemory {
    fn drop(&mut self) {
        let layout = Layout::from_size_align(self.size_bytes, 16)
            .expect("DynamicSharedMemory::drop: invalid layout");
        // SAFETY: this layout matches the one used in `new`.
        unsafe {
            alloc::dealloc(self.ptr.as_ptr(), layout);
        }
    }
}

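/// Counts shared-memory bank conflicts: per-bank access counters track the
/// current cycle, and any second or later access to an already-hit bank is
/// recorded as a conflict.
///
/// # Examples
///
/// A minimal usage sketch; the crate path is assumed from this file's
/// location:
///
/// ```
/// use cuda_rust_wasm::kernel::shared_memory::BankConflictDetector;
///
/// let detector = BankConflictDetector::new();
/// detector.begin_cycle();
/// detector.record_access(0);   // bank 0
/// detector.record_access(128); // bank 0 again -> conflict
/// assert_eq!(detector.total_accesses(), 2);
/// assert_eq!(detector.conflict_count(), 1);
/// ```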
pub struct BankConflictDetector {
    /// Total number of recorded accesses.
    total_accesses: AtomicUsize,
    /// Number of accesses that hit an already-used bank in the current cycle.
    conflict_count: AtomicUsize,
    /// Per-bank access counts for the current cycle.
    bank_accesses: [AtomicUsize; NUM_BANKS],
}

impl BankConflictDetector {
    /// Creates a detector with all counters at zero.
    pub fn new() -> Self {
        const INIT: AtomicUsize = AtomicUsize::new(0);
        Self {
            total_accesses: AtomicUsize::new(0),
            conflict_count: AtomicUsize::new(0),
            bank_accesses: [INIT; NUM_BANKS],
        }
    }

    /// Records one access at `byte_address` and counts it as a conflict if
    /// its bank was already hit in the current cycle.
    ///
    /// This is a simplified model: two accesses to the *same* address also
    /// count as a conflict, even though real hardware would broadcast reads
    /// of the same word.
    pub fn record_access(&self, byte_address: usize) {
        let bank = Self::address_to_bank(byte_address);
        let prev = self.bank_accesses[bank].fetch_add(1, Ordering::Relaxed);
        self.total_accesses.fetch_add(1, Ordering::Relaxed);

        // Any second or later access to the same bank within a cycle conflicts.
        if prev > 0 {
            self.conflict_count.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Starts a new access cycle by clearing the per-bank counters; the
    /// cumulative totals are preserved.
    pub fn begin_cycle(&self) {
        for bank in &self.bank_accesses {
            bank.store(0, Ordering::Relaxed);
        }
    }

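    /// Maps a byte address to its bank index: successive 4-byte words fall in
    /// successive banks, wrapping every `NUM_BANKS` words. A doctest sketch
    /// (crate path assumed from this file's location):
    ///
    /// ```
    /// # use cuda_rust_wasm::kernel::shared_memory::BankConflictDetector;
    /// assert_eq!(BankConflictDetector::address_to_bank(4), 1);   // word 1
    /// assert_eq!(BankConflictDetector::address_to_bank(128), 0); // word 32 wraps
    /// ```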
    pub fn address_to_bank(byte_address: usize) -> usize {
        (byte_address / BANK_WIDTH_BYTES) % NUM_BANKS
    }

    /// Returns the cumulative number of recorded accesses.
    pub fn total_accesses(&self) -> usize {
        self.total_accesses.load(Ordering::Relaxed)
    }

    /// Returns the cumulative number of detected conflicts.
    pub fn conflict_count(&self) -> usize {
        self.conflict_count.load(Ordering::Relaxed)
    }

    /// Returns the fraction of accesses that conflicted, or `0.0` if no
    /// accesses have been recorded.
    pub fn conflict_rate(&self) -> f64 {
        let total = self.total_accesses() as f64;
        if total == 0.0 {
            0.0
        } else {
            self.conflict_count() as f64 / total
        }
    }

    /// Clears all counters, cumulative and per-bank.
    pub fn reset(&self) {
        self.total_accesses.store(0, Ordering::Relaxed);
        self.conflict_count.store(0, Ordering::Relaxed);
        for bank in &self.bank_accesses {
            bank.store(0, Ordering::Relaxed);
        }
    }

    /// Formats a one-line human-readable report of the counters.
    pub fn summary(&self) -> String {
        format!(
            "Bank conflicts: {} / {} accesses ({:.1}% conflict rate)",
            self.conflict_count(),
            self.total_accesses(),
            self.conflict_rate() * 100.0,
        )
    }
}

impl Default for BankConflictDetector {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_static_shared_memory_new() {
        let smem: SharedMemory<f32> = SharedMemory::new(256);
        assert_eq!(smem.len(), 256);
        assert!(!smem.is_empty());
    }

    #[test]
    fn test_static_shared_memory_read_write() {
        let mut smem: SharedMemory<i32> = SharedMemory::new(16);
        *smem.get_mut(0) = 42;
        *smem.get_mut(15) = 99;
        assert_eq!(*smem.get(0), 42);
        assert_eq!(*smem.get(15), 99);
        // Untouched elements stay zero-initialized.
        assert_eq!(*smem.get(1), 0);
    }

    #[test]
    fn test_static_shared_memory_slice() {
        let mut smem: SharedMemory<f32> = SharedMemory::new(8);
        {
            let slice = smem.as_mut_slice();
            for (i, val) in slice.iter_mut().enumerate() {
                *val = i as f32 * 2.0;
            }
        }
        let slice = smem.as_slice();
        assert!((slice[3] - 6.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "index 16 out of bounds")]
    fn test_static_shared_memory_out_of_bounds() {
        let smem: SharedMemory<u32> = SharedMemory::new(16);
        let _ = smem.get(16);
    }

    #[test]
    fn test_dynamic_shared_memory_new() {
        let dsmem = DynamicSharedMemory::new(1024);
        assert_eq!(dsmem.size_bytes(), 1024);
    }

    #[test]
    fn test_dynamic_shared_memory_typed_access() {
        let mut dsmem = DynamicSharedMemory::new(64);
        {
            let slice: &mut [f32] = dsmem.as_typed_slice_mut();
            assert_eq!(slice.len(), 16);
            slice[0] = 3.14;
            slice[15] = 2.71;
        }

        let slice: &[f32] = dsmem.as_typed_slice();
        assert!((slice[0] - 3.14).abs() < 1e-6);
        assert!((slice[15] - 2.71).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "size must be > 0")]
    fn test_dynamic_shared_memory_zero_size() {
        let _ = DynamicSharedMemory::new(0);
    }

    #[test]
    fn test_bank_address_mapping() {
        assert_eq!(BankConflictDetector::address_to_bank(0), 0); // word 0 -> bank 0
        assert_eq!(BankConflictDetector::address_to_bank(4), 1); // word 1 -> bank 1
        assert_eq!(BankConflictDetector::address_to_bank(128), 0); // word 32 wraps to bank 0
        assert_eq!(BankConflictDetector::address_to_bank(132), 1); // word 33 wraps to bank 1
    }

    #[test]
    fn test_no_bank_conflicts() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        // 32 consecutive words hit 32 distinct banks: no conflicts.
        for i in 0..32 {
            detector.record_access(i * 4);
        }

        assert_eq!(detector.total_accesses(), 32);
        assert_eq!(detector.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflicts_detected() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        // Addresses 0 and 128 both map to bank 0.
        detector.record_access(0);
        detector.record_access(128);

        assert_eq!(detector.total_accesses(), 2);
        assert_eq!(detector.conflict_count(), 1);
    }

    #[test]
    fn test_bank_conflict_rate() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();

        detector.record_access(0); // bank 0
        detector.record_access(128); // bank 0: conflict
        detector.record_access(256); // bank 0: conflict
        detector.record_access(4); // bank 1
        assert_eq!(detector.total_accesses(), 4);
        assert_eq!(detector.conflict_count(), 2);
        assert!((detector.conflict_rate() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_bank_conflict_reset() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();
        detector.record_access(0);
        detector.record_access(128);

        detector.reset();
        assert_eq!(detector.total_accesses(), 0);
        assert_eq!(detector.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflict_summary() {
        let detector = BankConflictDetector::new();
        let summary = detector.summary();
        assert!(summary.contains("Bank conflicts"));
    }
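
    // Illustrative sketch: a stride-2 word pattern maps 32 accesses onto the
    // 16 even banks, so each bank is hit twice and half of the accesses
    // conflict under this detector's model.
    #[test]
    fn test_strided_access_half_conflicts() {
        let detector = BankConflictDetector::new();
        detector.begin_cycle();
        for i in 0..32 {
            detector.record_access(i * 2 * BANK_WIDTH_BYTES);
        }
        assert_eq!(detector.total_accesses(), 32);
        assert_eq!(detector.conflict_count(), 16);
        assert!((detector.conflict_rate() - 0.5).abs() < 1e-6);
    }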
}