#![cfg(unix)]
use std::ptr;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use crate::NodeId;
use crate::id::NetId64;
use crate::ring::Frame;
use crate::shm::{self, ShmRegion};
// `MAGIC` is the value written to (and checked against) the header's `magic` field;
// its bytes 0x4F 0x52 0x42 0x54 spell "ORBT" in ASCII.
// `VERSION` is the layout version stored in the header; attaching to a segment whose
// stored version differs is rejected.
const MAGIC: u32 = 0x4F524254; const VERSION: u32 = 1;
/// Size in bytes of the payload area of one slot; writes with a larger payload are rejected.
pub const PAYLOAD_MAX: usize = 256;
// Header placed at the very start of a ring segment; the slot array follows it at
// byte offset `size_of::<ShmRingHeader>()`.
//
// NOTE(review): the `_reserved` length expression sums the explicit field sizes
// (4 + 4 + 1 + 3 + 8 + 8 = 28) to pad the struct to 64 bytes, but `repr(C)` also inserts
// 4 bytes of implicit padding before the first 8-byte-aligned field that follows the
// 12 bytes of `magic`/`version`/`kind`/`_pad0`. The fields therefore total 68 bytes and
// `align(64)` rounds `size_of` up to 128. Changing this alters the shared layout, so it
// would need a `VERSION` bump — confirm whether 128 is intended.
#[repr(C, align(64))]
struct ShmRingHeader {
// Must equal `MAGIC` for the segment to be accepted.
magic: u32,
// Must equal `VERSION` for the segment to be accepted.
version: u32,
// Ring kind this segment was created for.
kind: u8,
// Explicit padding after `kind`.
_pad0: [u8; 3],
// Number of slots in the ring, as given at creation time.
capacity: u64,
// Monotonic counter of claimed writes; each write takes the current value via `fetch_add(1)`.
write_pos: AtomicU64,
// Zero-filled filler (see the layout note above).
_reserved: [u8; 64 - 4 - 4 - 1 - 3 - 8 - 8],
}
// One ring entry. `seq` guards the remaining fields: 0 means the slot has never been
// written, an odd value means a write is in progress, and a non-zero even value means the
// contents are committed. The explicit fields total 296 bytes; `align(64)` rounds the slot
// size up to the next multiple of 64.
#[repr(C, align(64))]
struct ShmSlot {
// Sequence word: set to `2 * counter + 1` before the fields are written and to
// `2 * counter + 2` once they are complete.
seq: AtomicU64,
// Raw `NetId64` of the frame stored in this slot.
id: u64,
// Frame kind supplied by the writer.
kind: u8,
// Explicit padding so `ver` is 8-byte aligned.
_pad0: [u8; 7],
// Version value supplied by the writer.
ver: u64,
// Number of valid bytes at the start of `payload`.
payload_len: u32,
// Explicit padding after `payload_len`.
_pad1: [u8; 4],
// Payload storage; only the first `payload_len` bytes are meaningful.
payload: [u8; PAYLOAD_MAX],
}
// Byte stride between consecutive slots in the segment.
const SLOT_SIZE: usize = std::mem::size_of::<ShmSlot>();
// Byte offset of the first slot from the start of the segment.
const HEADER_SIZE: usize = std::mem::size_of::<ShmRingHeader>();
/// Returns the number of bytes a ring segment with `capacity` slots occupies: the header
/// followed by `capacity` slots.
///
/// # Panics
///
/// Panics if the size does not fit in `usize`. The previous unchecked arithmetic wrapped
/// silently in release builds, which would yield an undersized segment that later slot
/// accesses index past.
pub fn segment_size_for_capacity(capacity: usize) -> usize {
    capacity
        .checked_mul(SLOT_SIZE)
        .and_then(|slots_bytes| slots_bytes.checked_add(HEADER_SIZE))
        .expect("ShmRing segment size overflows usize")
}
/// Handle to one ring stored in a shared-memory segment: a header followed by a
/// fixed number of fixed-size slots.
pub struct ShmRing {
// Mapped segment holding the header and the slot array.
region: ShmRegion,
// Ring kind this handle was opened for; also embedded in every id produced by `write`.
kind: u8,
// Number of slots; asserted to be a non-zero power of two when the ring is opened.
capacity: usize,
}
impl ShmRing {
    /// How many times the read paths retry a slot whose snapshot was unavailable or torn.
    const READ_ATTEMPTS: usize = 3;

    /// Opens the ring segment for `(fleet_name, kind)`, initialising it when this call
    /// created the underlying region.
    ///
    /// When the region already existed, its header is validated against `MAGIC`,
    /// `VERSION`, `kind` and `capacity`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero or not a power of two.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ShmRegion::open_or_create`, and returns
    /// `ErrorKind::InvalidData` when an existing segment's header does not match.
    pub fn open_or_create(fleet_name: &str, kind: u8, capacity: usize) -> std::io::Result<Self> {
        assert!(capacity > 0, "ShmRing capacity must be > 0");
        assert!(
            capacity.is_power_of_two(),
            "ShmRing capacity should be power of two for cheap modulo"
        );
        let name = shm::ring_segment_name(fleet_name, kind);
        let size = segment_size_for_capacity(capacity);
        let region = ShmRegion::open_or_create(&name, size)?;
        if region.created() {
            // SAFETY: assumes `ShmRegion` maps at least `size` writable bytes at an address
            // aligned for `ShmRingHeader` (TODO confirm against `ShmRegion`). `size` covers
            // the header plus `capacity` slots, so both writes stay inside the mapping.
            unsafe {
                let header_ptr = region.as_ptr() as *mut ShmRingHeader;
                ptr::write(
                    header_ptr,
                    ShmRingHeader {
                        magic: MAGIC,
                        version: VERSION,
                        kind,
                        _pad0: [0; 3],
                        capacity: capacity as u64,
                        write_pos: AtomicU64::new(0),
                        _reserved: [0; 64 - 4 - 4 - 1 - 3 - 8 - 8],
                    },
                );
                // All-zero slots have `seq == 0`, which readers treat as "never written".
                let slots_ptr = region.as_ptr().add(HEADER_SIZE);
                ptr::write_bytes(slots_ptr, 0, capacity * SLOT_SIZE);
            }
        } else {
            // SAFETY: assumes an existing region is at least header-sized and suitably
            // aligned (TODO confirm against `ShmRegion`).
            let header = unsafe { &*(region.as_ptr() as *const ShmRingHeader) };
            let invalid =
                |msg: String| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);
            if header.magic != MAGIC {
                return Err(invalid(format!(
                    "SHM segment {} has wrong magic 0x{:08X} (expected 0x{:08X})",
                    name, header.magic, MAGIC
                )));
            }
            if header.version != VERSION {
                return Err(invalid(format!(
                    "SHM segment {} version {} != local {}",
                    name, header.version, VERSION
                )));
            }
            if header.kind != kind {
                return Err(invalid(format!(
                    "SHM segment {} kind {} != requested {}",
                    name, header.kind, kind
                )));
            }
            // Compare in u64: narrowing the stored value to `usize` could truncate on
            // 32-bit targets and accept a segment with a different capacity.
            if header.capacity != capacity as u64 {
                return Err(invalid(format!(
                    "SHM segment {} capacity {} != requested {}",
                    name, header.capacity, capacity
                )));
            }
        }
        Ok(Self {
            region,
            kind,
            capacity,
        })
    }

    /// Returns whether the underlying region reports having been created by this handle.
    pub fn created(&self) -> bool {
        self.region.created()
    }

    /// Forwards to `ShmRegion::unlink` for the underlying region.
    pub fn unlink(&self) -> std::io::Result<()> {
        self.region.unlink()
    }

    /// Returns the ring kind this handle was opened with.
    pub fn kind(&self) -> u8 {
        self.kind
    }

    /// Returns the number of slots in the ring.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Borrows the header at the start of the mapped region.
    fn header(&self) -> &ShmRingHeader {
        // SAFETY: `open_or_create` either wrote or validated a header at the region start;
        // assumes the mapping stays valid for the lifetime of `self.region`.
        unsafe { &*(self.region.as_ptr() as *const ShmRingHeader) }
    }

    /// Returns a raw pointer to slot `idx` (must be `< capacity`).
    fn slot_ptr(&self, idx: usize) -> *mut ShmSlot {
        debug_assert!(idx < self.capacity);
        // SAFETY: for `idx < capacity` the offset stays within the
        // `HEADER_SIZE + capacity * SLOT_SIZE` bytes the segment was sized for.
        unsafe {
            let base = self.region.as_ptr().add(HEADER_SIZE);
            base.add(idx * SLOT_SIZE) as *mut ShmSlot
        }
    }

    /// Maps the low bits of a counter onto its slot; `capacity` is a power of two, so
    /// masking with `capacity - 1` is the modulo.
    fn slot_for(&self, counter_bits: usize) -> *mut ShmSlot {
        self.slot_ptr(counter_bits & (self.capacity - 1))
    }

    /// Tries up to `READ_ATTEMPTS` times to take a consistent snapshot of the slot,
    /// returning the first one obtained.
    fn read_slot_with_retry(&self, slot_ptr: *mut ShmSlot) -> Option<Frame> {
        (0..Self::READ_ATTEMPTS).find_map(|_| {
            // SAFETY: `slot_ptr` comes from `slot_ptr`/`slot_for`, so it addresses a slot
            // inside the mapped segment.
            unsafe { read_committed_frame(slot_ptr) }
        })
    }

    /// Returns the current write position, i.e. the number of writes claimed so far.
    ///
    /// A write claims its position before filling its slot, so the newest claimed slot
    /// may not be committed yet.
    pub fn head(&self) -> u64 {
        self.header().write_pos.load(Ordering::Acquire)
    }

    /// Claims the next counter, stores the frame in its slot and returns the frame's id.
    ///
    /// The slot's `seq` is set to an odd value while the fields are being written and to
    /// the following even value once they are complete, so readers can detect in-progress
    /// or torn contents.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` (without claiming a counter) when `payload` is
    /// longer than `PAYLOAD_MAX`.
    ///
    /// # Panics
    ///
    /// Panics with "seq overflow" if `2 * counter + 1` does not fit in `u64`.
    pub fn write(
        &self,
        node_id: NodeId,
        frame_kind: u8,
        ver: u64,
        payload: Bytes,
    ) -> std::io::Result<NetId64> {
        let len = payload.len();
        if len > PAYLOAD_MAX {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("payload {} > PAYLOAD_MAX {}", payload.len(), PAYLOAD_MAX),
            ));
        }
        let counter = self.header().write_pos.fetch_add(1, Ordering::AcqRel);
        let id = NetId64::make(self.kind, node_id.get(), counter);
        let slot_ptr = self.slot_for(counter as usize);
        let mid_seq = counter
            .checked_mul(2)
            .and_then(|v| v.checked_add(1))
            .expect("seq overflow");
        let final_seq = mid_seq.wrapping_add(1);
        // SAFETY: `slot_ptr` addresses a slot inside the mapped segment, and `len` was
        // checked against `PAYLOAD_MAX`, the size of the slot's payload array. Only the
        // atomic `seq` field is accessed through a reference; the plain fields are written
        // through raw pointers because other handles may read them concurrently.
        unsafe {
            let seq = &*ptr::addr_of!((*slot_ptr).seq);
            seq.store(mid_seq, Ordering::Release);
            // A Release store only orders *earlier* accesses before it. This fence keeps
            // the field writes below from becoming visible before the odd "write in
            // progress" marker, which a reader could otherwise miss and accept torn data.
            std::sync::atomic::fence(Ordering::Release);
            ptr::addr_of_mut!((*slot_ptr).id).write(id.raw());
            ptr::addr_of_mut!((*slot_ptr).kind).write(frame_kind);
            ptr::addr_of_mut!((*slot_ptr).ver).write(ver);
            // `len <= PAYLOAD_MAX` (256), so the narrowing cast cannot truncate.
            ptr::addr_of_mut!((*slot_ptr).payload_len).write(len as u32);
            let payload_dst = ptr::addr_of_mut!((*slot_ptr).payload) as *mut u8;
            ptr::copy_nonoverlapping(payload.as_ptr(), payload_dst, len);
            // Publish: the Release store orders all field writes before the even value.
            seq.store(final_seq, Ordering::Release);
        }
        Ok(id)
    }

    /// Reads the frame identified by `id`.
    ///
    /// Returns `None` when `id` belongs to a different ring kind, when no consistent
    /// snapshot of the slot could be taken, or when the slot holds a frame whose id
    /// counter differs from `id`'s.
    pub fn read(&self, id: NetId64) -> Option<Frame> {
        if id.kind() != self.kind {
            return None;
        }
        let counter = id.counter();
        self.read_slot_with_retry(self.slot_for(counter as usize))
            .filter(|frame| frame.id.counter() == counter)
    }

    /// Reads the slot of the most recently claimed write (`head() - 1`).
    ///
    /// Returns `None` when nothing has been written or no consistent snapshot could be
    /// taken. The frame is not checked against the claimed counter: if that write has not
    /// committed yet, the slot may still hold the frame from `capacity` writes earlier.
    pub fn read_head(&self) -> Option<Frame> {
        let newest = self.head().checked_sub(1)?;
        self.read_slot_with_retry(self.slot_for(newest as usize))
    }

    /// Reads whatever committed frame currently occupies the slot that `counter` maps to.
    ///
    /// The frame is not checked against `counter`; callers that need an exact match
    /// should inspect the returned frame's `id`.
    pub fn read_at(&self, counter: u64) -> Option<Frame> {
        self.read_slot_with_retry(self.slot_for(counter as usize))
    }

    /// Zeroes every slot and then sets the write position back to 0.
    ///
    /// NOTE(review): nothing here excludes concurrent readers or writers on this or other
    /// handles to the same segment; callers presumably must quiesce them first — confirm.
    pub fn reset(&self) {
        // SAFETY: the slot array spans `capacity * SLOT_SIZE` bytes starting at
        // `HEADER_SIZE`, all inside the mapped segment.
        unsafe {
            let slots_ptr = self.region.as_ptr().add(HEADER_SIZE);
            ptr::write_bytes(slots_ptr, 0, self.capacity * SLOT_SIZE);
        }
        self.header().write_pos.store(0, Ordering::Release);
    }
}
/// Cache of opened rings for one fleet, keyed by ring kind.
pub struct ShmRingRegistry {
// Fleet name passed to `ShmRing::open_or_create` for every ring.
fleet_name: String,
// Slot count used for every ring opened through this registry.
capacity: usize,
// Rings opened so far, keyed by kind.
rings: dashmap::DashMap<u8, std::sync::Arc<ShmRing>>,
}
impl ShmRingRegistry {
    /// Builds an empty registry whose rings all share `fleet_name` and `capacity`.
    pub fn new(fleet_name: impl Into<String>, capacity: usize) -> Self {
        let fleet_name = fleet_name.into();
        let rings = dashmap::DashMap::new();
        Self {
            fleet_name,
            capacity,
            rings,
        }
    }

    /// Returns the cached ring for `kind`, opening (or creating) it on first use.
    ///
    /// The ring is opened before the map entry is taken, so two callers racing on the
    /// same uncached `kind` may both open the segment; the first to insert wins and the
    /// other handle is dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ShmRing::open_or_create`.
    pub fn get_or_create_for(&self, kind: u8) -> std::io::Result<std::sync::Arc<ShmRing>> {
        if let Some(cached) = self.rings.get(&kind) {
            return Ok(cached.clone());
        }
        let opened = ShmRing::open_or_create(&self.fleet_name, kind, self.capacity)?;
        let candidate = std::sync::Arc::new(opened);
        let winner = self
            .rings
            .entry(kind)
            .or_insert_with(|| std::sync::Arc::clone(&candidate));
        Ok(winner.clone())
    }

    /// Returns the cached ring for `kind`, or `None` if none has been opened yet.
    pub fn lookup(&self, kind: u8) -> Option<std::sync::Arc<ShmRing>> {
        let guard = self.rings.get(&kind)?;
        Some(guard.clone())
    }
}
/// Takes a snapshot of the slot at `slot_ptr` and returns it as a `Frame` if the slot was
/// committed and did not change while being copied.
///
/// Returns `None` when the slot has never been written (`seq == 0`), a write is in
/// progress (`seq` odd), the stored payload length exceeds `PAYLOAD_MAX`, or `seq`
/// changed between the start and the end of the copy.
///
/// # Safety
///
/// `slot_ptr` must point to a properly aligned `ShmSlot` that stays mapped for the
/// duration of the call.
unsafe fn read_committed_frame(slot_ptr: *mut ShmSlot) -> Option<Frame> {
    // Borrow only the atomic sequence word; the other fields may be overwritten by a
    // concurrent writer and are therefore accessed through raw pointers only.
    // SAFETY: the caller guarantees `slot_ptr` addresses a valid, aligned slot.
    let seq = unsafe { &*ptr::addr_of!((*slot_ptr).seq) };
    let seq_pre = seq.load(Ordering::Acquire);
    if seq_pre == 0 || seq_pre & 1 == 1 {
        return None;
    }
    // SAFETY (all raw reads below): the caller guarantees the slot is mapped; the values
    // are treated as untrusted until the sequence re-check at the end succeeds.
    let id = NetId64::from_raw(unsafe { ptr::addr_of!((*slot_ptr).id).read() });
    let kind = unsafe { ptr::addr_of!((*slot_ptr).kind).read() };
    let ver = unsafe { ptr::addr_of!((*slot_ptr).ver).read() };
    let payload_len = unsafe { ptr::addr_of!((*slot_ptr).payload_len).read() } as usize;
    // A torn length could exceed the payload array; reject it before copying.
    if payload_len > PAYLOAD_MAX {
        return None;
    }
    let payload_src = unsafe { ptr::addr_of!((*slot_ptr).payload) as *const u8 };
    let mut payload_buf = vec![0u8; payload_len];
    unsafe { ptr::copy_nonoverlapping(payload_src, payload_buf.as_mut_ptr(), payload_len) };
    // An Acquire load only orders *later* accesses after it. Without this fence the field
    // reads above could be ordered after the validating load below, letting a torn
    // snapshot pass the check.
    std::sync::atomic::fence(Ordering::Acquire);
    let seq_post = seq.load(Ordering::Acquire);
    if seq_pre != seq_post {
        return None;
    }
    Some(Frame {
        id,
        kind,
        ver,
        payload: Bytes::from(payload_buf),
    })
}