use std::alloc::{self, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Number of banks that byte addresses are spread across by
/// `BankConflictDetector::address_to_bank`.
pub const NUM_BANKS: usize = 32;
/// Width of one bank slot in bytes (one 32-bit word): byte address `a` lies in
/// word `a / BANK_WIDTH_BYTES`.
pub const BANK_WIDTH_BYTES: usize = std::mem::size_of::<u32>();
/// A fixed-length, heap-allocated buffer of `T` elements.
///
/// Storage comes from [`alloc::alloc_zeroed`], so every element starts out as
/// the all-zero bit pattern. Elements are dropped in place when the buffer is
/// dropped, before the allocation is released.
///
/// NOTE(review): `new` is a safe function but does not check that all-zero
/// bytes form a valid `T`. That holds for integers, floats, `bool`, raw
/// pointers and `Option<Box<_>>`, but NOT for references, `NonZero*`, `Box`,
/// `Vec`, `String`, etc.; instantiating with such a type is undefined
/// behaviour. Closing this hole needs a "zeroable" bound, which would change
/// the public API.
pub struct SharedMemory<T: Send + Sync> {
    /// Start of the allocation (dangling but well-aligned when `T` is zero-sized).
    ptr: NonNull<T>,
    /// Number of elements; always > 0 for values built by `new`.
    len: usize,
    /// Tells drop-check that this type owns values of `T`.
    _marker: PhantomData<T>,
}

// SAFETY: `SharedMemory<T>` uniquely owns its elements, much like `Box<[T]>`.
// Sending it moves the `T`s to another thread (needs `T: Send`); sharing `&Self`
// only hands out `&T` (needs `T: Sync`). The struct's bounds guarantee both.
unsafe impl<T: Send + Sync> Send for SharedMemory<T> {}
unsafe impl<T: Send + Sync> Sync for SharedMemory<T> {}

impl<T: Send + Sync> SharedMemory<T> {
    /// Allocates a zero-filled buffer of `count` elements.
    ///
    /// Zero-sized `T` get a dangling, aligned pointer and no allocation.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`, if the total byte size would exceed `isize::MAX`,
    /// or if the allocator returns null.
    pub fn new(count: usize) -> Self {
        assert!(count > 0, "SharedMemory: count must be > 0");
        let layout = Layout::array::<T>(count).expect("SharedMemory: layout overflow");
        let ptr = if layout.size() > 0 {
            // SAFETY: `layout` has a non-zero size, as `alloc_zeroed` requires.
            let raw = unsafe { alloc::alloc_zeroed(layout) };
            NonNull::new(raw as *mut T).expect("SharedMemory: allocation failed")
        } else {
            NonNull::dangling()
        };
        Self {
            ptr,
            len: count,
            _marker: PhantomData,
        }
    }

    /// Number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the buffer has no elements. Always `false` for values built by
    /// `new`, which rejects a zero count.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Shared reference to the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn get(&self, index: usize) -> &T {
        assert!(index < self.len, "SharedMemory: index {index} out of bounds (len={})", self.len);
        // SAFETY: `index < len`, so the pointer stays inside the allocation.
        unsafe { &*self.ptr.as_ptr().add(index) }
    }

    /// Exclusive reference to the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn get_mut(&mut self, index: usize) -> &mut T {
        assert!(index < self.len, "SharedMemory: index {index} out of bounds (len={})", self.len);
        // SAFETY: `index < len`, and `&mut self` guarantees exclusive access.
        unsafe { &mut *self.ptr.as_ptr().add(index) }
    }

    /// Raw read-only pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr() as *const T
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Views the whole buffer as a slice.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `ptr` is non-null, aligned and valid for `len` elements for as
        // long as `&self` is borrowed.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const T, self.len) }
    }

    /// Views the whole buffer as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: as in `as_slice`, plus `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }
}

impl<T: Send + Sync> Drop for SharedMemory<T> {
    // Runs every element's destructor, then releases the allocation. The old
    // version skipped the first step and leaked any resources owned by elements.
    fn drop(&mut self) {
        if self.len == 0 {
            return;
        }
        // SAFETY: `ptr` addresses `len` elements owned exclusively by `self`,
        // and none of them is accessed again after this call. For types without
        // drop glue this is a no-op.
        unsafe {
            std::ptr::drop_in_place(std::ptr::slice_from_raw_parts_mut(
                self.ptr.as_ptr(),
                self.len,
            ));
        }
        let layout = Layout::array::<T>(self.len)
            .expect("SharedMemory::drop: layout overflow");
        if layout.size() > 0 {
            // SAFETY: the pointer came from `alloc_zeroed` in `new` with this
            // exact layout; zero-sized layouts were never allocated.
            unsafe {
                alloc::dealloc(self.ptr.as_ptr() as *mut u8, layout);
            }
        }
    }
}
/// A zero-initialised, untyped byte buffer, 16-byte aligned, whose element type
/// is chosen at access time via `as_typed_slice` / `as_typed_slice_mut`.
///
/// NOTE(review): the typed views reinterpret the raw bytes as `[T]` for any `T`.
/// This is only sound for plain-old-data types without padding, for which every
/// bit pattern is valid (integers, floats). Viewing the buffer as `bool`, `char`,
/// references, `Cell<_>`, etc. can produce invalid values or unsynchronised
/// shared mutation. The signatures cannot express this without an API change,
/// so callers must uphold it.
pub struct DynamicSharedMemory {
    /// Start of the allocation; always non-null and `ALIGN`-aligned.
    ptr: NonNull<u8>,
    /// Allocation size in bytes; always > 0.
    size_bytes: usize,
}

// SAFETY: the buffer is uniquely owned raw memory with no thread affinity, and
// all mutation goes through `&mut self`. This holds only when the typed views
// are restricted to plain-data types, as described on the struct.
unsafe impl Send for DynamicSharedMemory {}
unsafe impl Sync for DynamicSharedMemory {}

impl DynamicSharedMemory {
    /// Alignment, in bytes, of every allocation. `new` and `Drop` share it so
    /// that deallocation always uses the same layout as allocation.
    const ALIGN: usize = 16;

    /// Allocates `size_bytes` zeroed bytes aligned to 16.
    ///
    /// # Panics
    ///
    /// Panics if `size_bytes == 0`, if the size is too large for a valid layout,
    /// or if the allocator returns null.
    pub fn new(size_bytes: usize) -> Self {
        assert!(size_bytes > 0, "DynamicSharedMemory: size must be > 0");
        let layout = Layout::from_size_align(size_bytes, Self::ALIGN)
            .expect("DynamicSharedMemory: invalid layout");
        // SAFETY: `size_bytes > 0` was asserted above, so `layout` is non-zero-sized.
        let raw = unsafe { alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(raw).expect("DynamicSharedMemory: allocation failed");
        Self { ptr, size_bytes }
    }

    /// Size of the buffer in bytes.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes
    }

    /// Checks that the buffer can be viewed as `[T]` and returns the element count.
    ///
    /// # Panics
    ///
    /// Panics if `T` is zero-sized, if `size_bytes` is not a multiple of
    /// `size_of::<T>()`, or if the buffer start is not aligned for `T`.
    fn typed_len<T>(&self) -> usize {
        let elem_size = std::mem::size_of::<T>();
        assert!(elem_size > 0, "DynamicSharedMemory: zero-sized type");
        assert!(
            self.size_bytes % elem_size == 0,
            "DynamicSharedMemory: size {} not a multiple of element size {}",
            self.size_bytes,
            elem_size
        );
        assert!(
            self.ptr.as_ptr() as usize % std::mem::align_of::<T>() == 0,
            "DynamicSharedMemory: alignment mismatch for type"
        );
        self.size_bytes / elem_size
    }

    /// Views the whole buffer as a read-only `[T]`.
    ///
    /// # Panics
    ///
    /// Same conditions as the size/alignment checks: zero-sized `T`, a size that
    /// is not a multiple of `size_of::<T>()`, or insufficient alignment.
    pub fn as_typed_slice<T>(&self) -> &[T] {
        let count = self.typed_len::<T>();
        // SAFETY: `typed_len` verified size and alignment, so `count` elements of
        // `T` fit exactly inside the live allocation borrowed through `&self`.
        // Bit-validity of the bytes as `T` is the caller's responsibility (see
        // the note on the struct).
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const T, count) }
    }

    /// Views the whole buffer as a mutable `[T]`.
    ///
    /// # Panics
    ///
    /// Same conditions as `as_typed_slice`.
    pub fn as_typed_slice_mut<T>(&mut self) -> &mut [T] {
        let count = self.typed_len::<T>();
        // SAFETY: as in `as_typed_slice`, plus `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut T, count) }
    }

    /// Raw read-only pointer to the first byte.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr() as *const u8
    }

    /// Raw mutable pointer to the first byte.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}

impl Drop for DynamicSharedMemory {
    // Releases the allocation using the same size and alignment it was made with.
    fn drop(&mut self) {
        let layout = Layout::from_size_align(self.size_bytes, Self::ALIGN)
            .expect("DynamicSharedMemory::drop: invalid layout");
        // SAFETY: `ptr` came from `alloc_zeroed` in `new` with exactly this layout
        // and is never used after this point.
        unsafe {
            alloc::dealloc(self.ptr.as_ptr(), layout);
        }
    }
}
/// Counts accesses per memory bank within a "cycle" and tallies how many of
/// them landed on a bank that was already hit earlier in the same cycle.
///
/// An access counts as a conflict whenever its bank's per-cycle counter was
/// already non-zero, regardless of which word inside the bank is addressed.
/// Every counter is an atomic updated with `Ordering::Relaxed`, so the detector
/// can be used through a shared reference. The aggregate getters are not
/// mutually consistent while recording is still in progress.
pub struct BankConflictDetector {
    /// Number of `record_access` calls since construction or the last `reset`.
    total_accesses: AtomicUsize,
    /// Accesses that hit a bank already used in the current cycle.
    conflict_count: AtomicUsize,
    /// Per-bank hit counters for the current cycle; cleared by `begin_cycle`.
    bank_accesses: [AtomicUsize; NUM_BANKS],
}

impl BankConflictDetector {
    /// Creates a detector with every counter at zero.
    pub fn new() -> Self {
        // A `const` item lets the array-repeat expression build non-`Copy` atomics.
        const ZERO: AtomicUsize = AtomicUsize::new(0);
        BankConflictDetector {
            total_accesses: ZERO,
            conflict_count: ZERO,
            bank_accesses: [ZERO; NUM_BANKS],
        }
    }

    /// Records one access at `byte_address`. It counts as a conflict if its bank
    /// was already hit during the current cycle.
    pub fn record_access(&self, byte_address: usize) {
        let slot = &self.bank_accesses[Self::address_to_bank(byte_address)];
        let bank_was_busy = slot.fetch_add(1, Ordering::Relaxed) != 0;
        self.total_accesses.fetch_add(1, Ordering::Relaxed);
        if bank_was_busy {
            self.conflict_count.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Starts a new cycle by clearing the per-bank counters. The running totals
    /// are kept.
    pub fn begin_cycle(&self) {
        self.bank_accesses
            .iter()
            .for_each(|slot| slot.store(0, Ordering::Relaxed));
    }

    /// Maps a byte address to its bank. Consecutive `BANK_WIDTH_BYTES`-wide words
    /// go to consecutive banks, wrapping around every `NUM_BANKS` words.
    pub fn address_to_bank(byte_address: usize) -> usize {
        let word_index = byte_address / BANK_WIDTH_BYTES;
        word_index % NUM_BANKS
    }

    /// Total number of recorded accesses.
    pub fn total_accesses(&self) -> usize {
        self.total_accesses.load(Ordering::Relaxed)
    }

    /// Number of recorded accesses that were conflicts.
    pub fn conflict_count(&self) -> usize {
        self.conflict_count.load(Ordering::Relaxed)
    }

    /// Fraction of recorded accesses that conflicted. Returns `0.0` when nothing
    /// has been recorded.
    pub fn conflict_rate(&self) -> f64 {
        match self.total_accesses() {
            0 => 0.0,
            total => self.conflict_count() as f64 / total as f64,
        }
    }

    /// Clears the totals and the per-bank counters.
    pub fn reset(&self) {
        self.total_accesses.store(0, Ordering::Relaxed);
        self.conflict_count.store(0, Ordering::Relaxed);
        self.begin_cycle();
    }

    /// One-line human-readable report of conflicts, accesses and conflict rate.
    pub fn summary(&self) -> String {
        let conflicts = self.conflict_count();
        let total = self.total_accesses();
        let percent = self.conflict_rate() * 100.0;
        format!(
            "Bank conflicts: {} / {} accesses ({:.1}% conflict rate)",
            conflicts, total, percent,
        )
    }
}

impl Default for BankConflictDetector {
    /// Same as `BankConflictDetector::new`.
    fn default() -> Self {
        BankConflictDetector::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SharedMemory`, `DynamicSharedMemory` and
    //! `BankConflictDetector`.
    use super::*;

    #[test]
    fn test_static_shared_memory_new() {
        let buf = SharedMemory::<f32>::new(256);
        assert!(!buf.is_empty());
        assert_eq!(buf.len(), 256);
    }

    #[test]
    fn test_static_shared_memory_read_write() {
        let mut buf = SharedMemory::<i32>::new(16);
        for &(idx, value) in &[(0usize, 42i32), (15, 99)] {
            *buf.get_mut(idx) = value;
        }
        // Written slots read back; untouched slots keep their zero fill.
        assert_eq!(*buf.get(0), 42);
        assert_eq!(*buf.get(15), 99);
        assert_eq!(*buf.get(1), 0);
    }

    #[test]
    fn test_static_shared_memory_slice() {
        let mut buf = SharedMemory::<f32>::new(8);
        buf.as_mut_slice()
            .iter_mut()
            .enumerate()
            .for_each(|(i, slot)| *slot = i as f32 * 2.0);
        let fourth = buf.as_slice()[3];
        assert!((fourth - 6.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "index 16 out of bounds")]
    fn test_static_shared_memory_out_of_bounds() {
        let buf = SharedMemory::<u32>::new(16);
        // One past the last valid index.
        let _ = buf.get(buf.len());
    }

    #[test]
    fn test_dynamic_shared_memory_new() {
        let raw = DynamicSharedMemory::new(1024);
        assert_eq!(raw.size_bytes(), 1024);
    }

    #[test]
    fn test_dynamic_shared_memory_typed_access() {
        let mut raw = DynamicSharedMemory::new(64);
        let floats = raw.as_typed_slice_mut::<f32>();
        assert_eq!(floats.len(), 16);
        floats[0] = 3.14;
        floats[15] = 2.71;

        let floats = raw.as_typed_slice::<f32>();
        assert!((floats[0] - 3.14).abs() < 1e-6);
        assert!((floats[15] - 2.71).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "size must be > 0")]
    fn test_dynamic_shared_memory_zero_size() {
        drop(DynamicSharedMemory::new(0));
    }

    #[test]
    fn test_bank_address_mapping() {
        // (byte address, expected bank): 4-byte words over 32 banks wrap every 128 bytes.
        let cases: [(usize, usize); 4] = [(0, 0), (4, 1), (128, 0), (132, 1)];
        for &(addr, bank) in cases.iter() {
            assert_eq!(BankConflictDetector::address_to_bank(addr), bank);
        }
    }

    #[test]
    fn test_no_bank_conflicts() {
        let det = BankConflictDetector::new();
        det.begin_cycle();
        // One access per bank: consecutive 4-byte words.
        (0..32usize).for_each(|lane| det.record_access(lane * 4));
        assert_eq!(det.total_accesses(), 32);
        assert_eq!(det.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflicts_detected() {
        let det = BankConflictDetector::new();
        det.begin_cycle();
        // 0 and 128 both map to bank 0.
        for &addr in &[0usize, 128] {
            det.record_access(addr);
        }
        assert_eq!(det.total_accesses(), 2);
        assert_eq!(det.conflict_count(), 1);
    }

    #[test]
    fn test_bank_conflict_rate() {
        let det = BankConflictDetector::new();
        det.begin_cycle();
        // Three hits on bank 0 (two conflicts) plus one hit on bank 1.
        for &addr in &[0usize, 128, 256, 4] {
            det.record_access(addr);
        }
        assert_eq!(det.total_accesses(), 4);
        assert_eq!(det.conflict_count(), 2);
        assert!((det.conflict_rate() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_bank_conflict_reset() {
        let det = BankConflictDetector::new();
        det.begin_cycle();
        for &addr in &[0usize, 128] {
            det.record_access(addr);
        }
        det.reset();
        assert_eq!(det.total_accesses(), 0);
        assert_eq!(det.conflict_count(), 0);
    }

    #[test]
    fn test_bank_conflict_summary() {
        let summary = BankConflictDetector::new().summary();
        assert!(summary.contains("Bank conflicts"));
    }
}