use std::sync::atomic::Ordering;
use shared_memory::{Shmem, ShmemConf, ShmemError};
use crate::config::SlotBusConfig;
use crate::error::SlotBusError;
use crate::types::*;
/// A mapped shared-memory region used as the slot-bus transport.
///
/// Right after `create`/`open` only the raw mapping is known; the layout fields
/// (`num_slots`, `heap_offset`, `heap_size`) stay zero until `init_control` (on the
/// creating side) or `validate_control` (on the opening side) fills them in.
pub struct ShmRegion {
    /// OS id the mapping was created or opened under.
    name: String,
    /// Owning handle of the mapping; kept only so the mapping lives as long as `self`.
    _shmem: Shmem,
    /// Base address of the mapping, cached from `_shmem`.
    ptr: *mut u8,
    /// Length of the mapping in bytes, cached from `_shmem`.
    len: usize,
    /// Number of slot records that follow the control header.
    num_slots: usize,
    /// Byte offset of the bump-allocated heap from the mapping base.
    heap_offset: usize,
    /// Heap capacity in bytes.
    heap_size: usize,
}

// SAFETY: `ShmRegion` only exposes the mapping through a raw pointer. Shared mutation
// happens either through atomics (slot status, heap `alloc_head`) or through `unsafe`
// accessors whose callers must coordinate access.
// NOTE(review): this relies on the `Shmem` handle being safe to move/share across
// threads — confirm against the `shared_memory` crate.
unsafe impl Send for ShmRegion {}
unsafe impl Sync for ShmRegion {}
impl ShmRegion {
pub fn create(name: &str, size: usize) -> Result<Self, SlotBusError> {
let shmem = ShmemConf::new()
.os_id(name)
.size(size)
.create()
.map_err(|e| SlotBusError::SharedMemory(format!("create '{name}': {e}")))?;
let ptr = shmem.as_ptr();
let len = shmem.len();
Ok(Self {
_shmem: shmem,
ptr,
len,
name: name.to_string(),
num_slots: 0,
heap_offset: 0,
heap_size: 0,
})
}
pub fn open(name: &str) -> Result<Self, SlotBusError> {
let shmem = ShmemConf::new()
.os_id(name)
.open()
.map_err(|e| SlotBusError::SharedMemory(format!("open '{name}': {e}")))?;
let ptr = shmem.as_ptr();
let len = shmem.len();
Ok(Self {
_shmem: shmem,
ptr,
len,
name: name.to_string(),
num_slots: 0,
heap_offset: 0,
heap_size: 0,
})
}
pub fn create_or_open(name: &str, size: usize) -> Result<Self, SlotBusError> {
match ShmemConf::new().os_id(name).size(size).create() {
Ok(shmem) => {
let ptr = shmem.as_ptr();
let len = shmem.len();
Ok(Self {
_shmem: shmem,
ptr,
len,
name: name.to_string(),
num_slots: 0,
heap_offset: 0,
heap_size: 0,
})
}
Err(ShmemError::MappingIdExists) => Self::open(name),
Err(e) => Err(SlotBusError::SharedMemory(format!(
"create_or_open '{name}': {e}"
))),
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn as_ptr(&self) -> *mut u8 {
self.ptr
}
pub fn num_slots(&self) -> usize {
self.num_slots
}
pub unsafe fn header(&self) -> &ShmHeader {
&*(self.ptr as *const ShmHeader)
}
pub unsafe fn slot(&self, index: usize) -> &SlotMeta {
debug_assert!(index < self.num_slots);
let offset = SHM_HEADER_SIZE + index * SLOT_META_SIZE;
&*(self.ptr.add(offset) as *const SlotMeta)
}
fn heap_ptr(&self) -> *mut u8 {
unsafe { self.ptr.add(self.heap_offset) }
}
pub unsafe fn heap_read(&self, offset: u32, len: usize) -> &[u8] {
let p = self.heap_ptr().add(offset as usize);
std::slice::from_raw_parts(p, len)
}
pub unsafe fn heap_write(&self, offset: u32, data: &[u8]) {
let p = self.heap_ptr().add(offset as usize);
std::ptr::copy_nonoverlapping(data.as_ptr(), p, data.len());
}
pub fn alloc_heap(&self, size: usize) -> Option<u32> {
let aligned = (size + 7) & !7; let header = unsafe { self.header() };
let heap_size = self.heap_size;
loop {
let head = header.alloc_head.load(Ordering::Acquire);
let new_head = head as usize + aligned;
if new_head > heap_size {
return None; }
if header
.alloc_head
.compare_exchange(head, new_head as u32, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return Some(head);
}
}
}
pub fn reset_heap(&self) {
let header = unsafe { self.header() };
header.alloc_head.store(0, Ordering::Release);
}
pub fn has_inflight_slots(&self) -> bool {
for i in 0..self.num_slots {
let slot = unsafe { self.slot(i) };
if slot.status.load(Ordering::Acquire) != SLOT_FREE {
return true;
}
}
false
}
pub fn try_reset_heap(&self) {
if !self.has_inflight_slots() {
self.reset_heap();
}
}
pub fn init_control(&mut self, config: &SlotBusConfig) {
let (heap_offset, heap_size) = compute_layout(config.num_slots, config.region_size);
self.num_slots = config.num_slots;
self.heap_offset = heap_offset;
self.heap_size = heap_size;
unsafe {
std::ptr::write_bytes(self.ptr, 0, self.len);
}
unsafe {
let h = self.ptr as *mut u32;
h.write(SHM_MAGIC);
h.add(1).write(SHM_VERSION);
h.add(2).write(config.num_slots as u32);
h.add(3).write(heap_offset as u32);
h.add(4).write(heap_size as u32);
}
let header = unsafe { self.header() };
header.alloc_head.store(0, Ordering::Release);
}
pub fn validate_control(&mut self) -> Result<(), SlotBusError> {
unsafe {
let h = self.ptr as *const u32;
let magic = h.read();
let version = h.add(1).read();
let num_slots = h.add(2).read() as usize;
let heap_offset = h.add(3).read() as usize;
let heap_size = h.add(4).read() as usize;
if magic != SHM_MAGIC {
return Err(SlotBusError::InvalidRegion(format!(
"bad magic: expected 0x{SHM_MAGIC:08X}, got 0x{magic:08X}"
)));
}
if version != SHM_VERSION {
return Err(SlotBusError::InvalidRegion(format!(
"bad version: expected {SHM_VERSION}, got {version}"
)));
}
self.num_slots = num_slots;
self.heap_offset = heap_offset;
self.heap_size = heap_size;
}
Ok(())
}
pub fn create_overflow(name: &str, data: &[u8]) -> Result<Self, SlotBusError> {
let size = (data.len() + 4095) & !4095;
let size = size.max(4096);
let region = Self::create_or_open(name, size)?;
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), region.ptr, data.len());
}
Ok(region)
}
pub fn read_overflow(name: &str, len: usize) -> Result<Vec<u8>, SlotBusError> {
let region = Self::open(name)?;
if len > region.len {
return Err(SlotBusError::SharedMemory(format!(
"overflow region '{name}' too small: need {len}, have {}",
region.len
)));
}
let data = unsafe { std::slice::from_raw_parts(region.ptr, len) };
Ok(data.to_vec())
}
}
impl std::fmt::Debug for ShmRegion {
    /// Prints the region's name, mapping length and slot count.
    ///
    /// The raw pointer and the heap layout fields are intentionally omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("ShmRegion");
        out.field("name", &self.name);
        out.field("len", &self.len);
        out.field("num_slots", &self.num_slots);
        out.finish()
    }
}
/// Atomically claims the first free slot.
///
/// Slots are scanned from index 0 upwards. Each one is CAS-ed from `SLOT_FREE` to
/// `SLOT_WRITING`, and the scan stops at the first success. Returns the claimed index,
/// or `None` if no slot could be claimed.
pub fn claim_free_slot(region: &ShmRegion) -> Option<usize> {
    (0..region.num_slots()).find(|&idx| {
        // SAFETY: idx < num_slots, satisfying `slot`'s index precondition.
        let meta = unsafe { region.slot(idx) };
        meta.status
            .compare_exchange(SLOT_FREE, SLOT_WRITING, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    })
}
/// Returns the index of the first slot currently observed as `SLOT_FREE`.
///
/// Only reads the status and does not reserve the slot, so another claimant may take
/// it before the caller does. Prefer [`claim_free_slot`].
#[deprecated(note = "Use claim_free_slot() which atomically reserves the slot")]
pub fn find_free_slot(region: &ShmRegion) -> Option<usize> {
    (0..region.num_slots()).find(|&idx| {
        // SAFETY: idx < num_slots, satisfying `slot`'s index precondition.
        let meta = unsafe { region.slot(idx) };
        meta.status.load(Ordering::Acquire) == SLOT_FREE
    })
}
pub fn write_request(
region: &ShmRegion,
slot_index: usize,
req_id: &str,
method: u8,
meta_bytes: &[u8],
body: &[u8],
config: &SlotBusConfig,
) -> Result<Option<ShmRegion>, SlotBusError> {
match write_request_inner(region, slot_index, req_id, method, meta_bytes, body, config) {
Ok(overflow) => Ok(overflow),
Err(e) => {
let slot = unsafe { region.slot(slot_index) };
slot.status.store(SLOT_FREE, Ordering::Release);
Err(e)
}
}
}
/// Fills slot `slot_index` with a request and publishes it as `SLOT_READY`.
///
/// Fields written (byte offsets within the slot record):
/// - `4..40`: request id. It is NUL-padded and truncated to 36 bytes on a char boundary.
/// - `40`: method.
/// - `44` / `48`: meta heap offset (`u32`) / meta length (`u16`).
/// - `52` / `56`: body heap offset (`u32`) / body length (`u32`).
/// - `60`: body overflow flag (1 = the body lives in a separate overflow region).
///
/// An empty body is recorded as offset 0 / length 0. A body that does not fit in the
/// heap goes to the overflow region named by `config.request_overflow_name(slot_index)`,
/// which is returned. The caller must have claimed the slot and checked `slot_index`.
///
/// # Errors
/// `SlotBusError::SharedMemory` if any of the following holds:
/// - the meta exceeds `u16::MAX` bytes,
/// - the body exceeds `u32::MAX` bytes,
/// - the heap has no room for the meta, or
/// - the overflow region cannot be created.
fn write_request_inner(
    region: &ShmRegion,
    slot_index: usize,
    req_id: &str,
    method: u8,
    meta_bytes: &[u8],
    body: &[u8],
    config: &SlotBusConfig,
) -> Result<Option<ShmRegion>, SlotBusError> {
    // The slot stores the meta length as u16 and the body length as u32. Refuse
    // oversized input instead of silently truncating, which would corrupt the reader.
    if meta_bytes.len() > usize::from(u16::MAX) {
        return Err(SlotBusError::SharedMemory(format!(
            "request meta too large: {} bytes (max {})",
            meta_bytes.len(),
            u16::MAX
        )));
    }
    if body.len() as u64 > u64::from(u32::MAX) {
        return Err(SlotBusError::SharedMemory(format!(
            "request body too large: {} bytes (max {})",
            body.len(),
            u32::MAX
        )));
    }

    // Truncate the id to the 36-byte field without splitting a UTF-8 sequence.
    let mut id_len = req_id.len().min(36);
    while !req_id.is_char_boundary(id_len) {
        id_len -= 1;
    }

    // SAFETY: the caller guarantees slot_index < num_slots, so this is the start of an
    // in-range slot record.
    let slot_ptr = unsafe {
        region
            .as_ptr()
            .add(SHM_HEADER_SIZE + slot_index * SLOT_META_SIZE)
    };

    // SAFETY: the request-id and method fields of the slot record at `slot_ptr`.
    unsafe {
        let req_id_ptr = slot_ptr.add(4);
        std::ptr::write_bytes(req_id_ptr, 0, 36);
        std::ptr::copy_nonoverlapping(req_id.as_ptr(), req_id_ptr, id_len);
        slot_ptr.add(40).write(method);
    }

    let meta_offset = region
        .alloc_heap(meta_bytes.len())
        .ok_or_else(|| SlotBusError::SharedMemory("heap full for request meta".into()))?;
    // SAFETY: alloc_heap reserved `meta_bytes.len()` bytes at `meta_offset`; the other
    // two writes target the meta fields of this slot record.
    unsafe {
        region.heap_write(meta_offset, meta_bytes);
        (slot_ptr.add(44) as *mut u32).write(meta_offset);
        (slot_ptr.add(48) as *mut u16).write(meta_bytes.len() as u16);
    }

    let (body_offset, body_overflow, overflow_region) = if body.is_empty() {
        (0, 0u8, None)
    } else if let Some(offset) = region.alloc_heap(body.len()) {
        // SAFETY: alloc_heap reserved `body.len()` bytes at `offset`.
        unsafe {
            region.heap_write(offset, body);
        }
        (offset, 0, None)
    } else {
        let name = config.request_overflow_name(slot_index);
        (0, 1, Some(ShmRegion::create_overflow(&name, body)?))
    };
    // SAFETY: the body fields of the slot record at `slot_ptr`.
    unsafe {
        (slot_ptr.add(52) as *mut u32).write(body_offset);
        (slot_ptr.add(56) as *mut u32).write(body.len() as u32);
        slot_ptr.add(60).write(body_overflow);
    }

    // Take the shared `&SlotMeta` only after the raw-pointer writes above, so no
    // reference to the record is live while its bytes are being mutated.
    // SAFETY: slot_index < num_slots (caller contract).
    let slot = unsafe { region.slot(slot_index) };
    slot.status.store(SLOT_READY, Ordering::Release);
    Ok(overflow_region)
}
/// Writes a response into slot `slot_index` and publishes it as `SLOT_DONE`.
///
/// Fields written (byte offsets within the slot record):
/// - `64`: response status (`u16`).
/// - `68` / `72`: meta heap offset (`u32`) / meta length (`u16`).
/// - `76` / `80`: body heap offset (`u32`) / body length (`u32`).
/// - `84`: body overflow flag (1 = the body lives in a separate overflow region).
///
/// An empty body is recorded as offset 0 / length 0. A body that does not fit in the
/// heap goes to the overflow region named by `config.response_overflow_name(slot_index)`,
/// which is returned. On error the slot status is left unchanged.
///
/// # Errors
/// `SlotBusError::SharedMemory` if any of the following holds:
/// - `slot_index` is out of range,
/// - the meta exceeds `u16::MAX` bytes,
/// - the body exceeds `u32::MAX` bytes,
/// - the heap has no room for the meta, or
/// - the overflow region cannot be created.
pub fn write_response(
    region: &ShmRegion,
    slot_index: usize,
    status: u16,
    meta_bytes: &[u8],
    body: &[u8],
    config: &SlotBusConfig,
) -> Result<Option<ShmRegion>, SlotBusError> {
    // This is a safe fn, so an out-of-range index must not reach the raw slot accessors.
    if slot_index >= region.num_slots() {
        return Err(SlotBusError::SharedMemory(format!(
            "slot index {slot_index} out of range ({} slots)",
            region.num_slots()
        )));
    }
    // The slot stores the meta length as u16 and the body length as u32. Refuse
    // oversized input instead of silently truncating, which would corrupt the reader.
    if meta_bytes.len() > usize::from(u16::MAX) {
        return Err(SlotBusError::SharedMemory(format!(
            "response meta too large: {} bytes (max {})",
            meta_bytes.len(),
            u16::MAX
        )));
    }
    if body.len() as u64 > u64::from(u32::MAX) {
        return Err(SlotBusError::SharedMemory(format!(
            "response body too large: {} bytes (max {})",
            body.len(),
            u32::MAX
        )));
    }

    // SAFETY: slot_index < num_slots (checked above), so this is the start of an
    // in-range slot record.
    let slot_ptr = unsafe {
        region
            .as_ptr()
            .add(SHM_HEADER_SIZE + slot_index * SLOT_META_SIZE)
    };
    // SAFETY: the response-status field of the slot record at `slot_ptr`.
    unsafe {
        (slot_ptr.add(64) as *mut u16).write(status);
    }

    let meta_offset = region
        .alloc_heap(meta_bytes.len())
        .ok_or_else(|| SlotBusError::SharedMemory("heap full for response meta".into()))?;
    // SAFETY: alloc_heap reserved `meta_bytes.len()` bytes at `meta_offset`; the other
    // two writes target the response meta fields of this slot record.
    unsafe {
        region.heap_write(meta_offset, meta_bytes);
        (slot_ptr.add(68) as *mut u32).write(meta_offset);
        (slot_ptr.add(72) as *mut u16).write(meta_bytes.len() as u16);
    }

    let (body_offset, body_overflow, overflow_region) = if body.is_empty() {
        (0, 0u8, None)
    } else if let Some(offset) = region.alloc_heap(body.len()) {
        // SAFETY: alloc_heap reserved `body.len()` bytes at `offset`.
        unsafe {
            region.heap_write(offset, body);
        }
        (offset, 0, None)
    } else {
        let name = config.response_overflow_name(slot_index);
        (0, 1, Some(ShmRegion::create_overflow(&name, body)?))
    };
    // SAFETY: the response body fields of the slot record at `slot_ptr`.
    unsafe {
        (slot_ptr.add(76) as *mut u32).write(body_offset);
        (slot_ptr.add(80) as *mut u32).write(body.len() as u32);
        slot_ptr.add(84).write(body_overflow);
    }

    // Take the shared `&SlotMeta` only after the raw-pointer writes above, so no
    // reference to the record is live while its bytes are being mutated.
    // SAFETY: slot_index < num_slots (checked above).
    let slot = unsafe { region.slot(slot_index) };
    slot.status.store(SLOT_DONE, Ordering::Release);
    Ok(overflow_region)
}
pub fn read_request(
region: &ShmRegion,
slot_index: usize,
config: &SlotBusConfig,
) -> Result<(String, u8, RequestMeta, Vec<u8>), SlotBusError> {
let slot = unsafe { region.slot(slot_index) };
let req_id = {
let raw = &slot.req_id;
let end = raw.iter().position(|&b| b == 0).unwrap_or(36);
String::from_utf8_lossy(&raw[..end]).to_string()
};
let method = slot.method;
let meta: RequestMeta = unsafe {
let meta_bytes = region.heap_read(slot.meta_offset, slot.meta_len as usize);
postcard::from_bytes(meta_bytes)?
};
let body_len = slot.body_len as usize;
let body = if body_len == 0 {
Vec::new()
} else if slot.body_overflow == 0 {
let bytes = unsafe { region.heap_read(slot.body_offset, body_len) };
bytes.to_vec()
} else {
let name = config.request_overflow_name(slot_index);
ShmRegion::read_overflow(&name, body_len)?
};
Ok((req_id, method, meta, body))
}
pub fn read_response(
region: &ShmRegion,
slot_index: usize,
config: &SlotBusConfig,
) -> Result<(u16, ResponseMeta, Vec<u8>), SlotBusError> {
let slot = unsafe { region.slot(slot_index) };
let status = slot.resp_status;
let meta: ResponseMeta = unsafe {
let meta_bytes = region.heap_read(slot.resp_meta_offset, slot.resp_meta_len as usize);
postcard::from_bytes(meta_bytes)?
};
let body_len = slot.resp_body_len as usize;
let body = if body_len == 0 {
Vec::new()
} else if slot.resp_body_overflow == 0 {
let bytes = unsafe { region.heap_read(slot.resp_body_offset, body_len) };
bytes.to_vec()
} else {
let name = config.response_overflow_name(slot_index);
ShmRegion::read_overflow(&name, body_len)?
};
Ok((status, meta, body))
}