use crate::error::{IpcError, Result};
use crate::shm::SharedMemory;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Bytes reserved at the start of every segment for the header; the payload begins at this
/// offset.
///
/// Header fields, as written by `write_header`:
/// - offset 0: magic (`u32`, little-endian)
/// - offset 4: reference count (`u32`, accessed as an `AtomicU32`)
/// - offset 8: creation time (`u64`, little-endian, whole seconds since the UNIX epoch)
/// - offset 16: payload length (`u64`, little-endian)
/// - offset 24: resource kind (`u8`)
///
/// The remaining header bytes are currently unused.
const HEADER_SIZE: usize = 64;
/// Marker identifying ResourceLink segments: the ASCII bytes `RLK!` read as a big-endian
/// `u32` (`0x524C_4B21`).
const MAGIC: u32 = u32::from_be_bytes(*b"RLK!");
/// Kind of resource a segment represents.
///
/// Stored as a single byte (`kind as u8`) at header offset 24 and decoded with the
/// `TryFrom<u8>` impl. The discriminants are shared with every process mapping the
/// segment, so existing values should not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum ResourceKind {
    /// Encoded as byte `0`.
    SharedMemory = 0,
    /// Encoded as byte `1`.
    MappedFile = 1,
}
impl TryFrom<u8> for ResourceKind {
    type Error = IpcError;

    /// Decodes a kind byte (as produced by `kind as u8`) back into a [`ResourceKind`].
    ///
    /// # Errors
    ///
    /// Returns `IpcError::Other` for any byte other than `0` or `1`.
    fn try_from(byte: u8) -> Result<Self> {
        let kind = match byte {
            0 => Self::SharedMemory,
            1 => Self::MappedFile,
            unknown => {
                return Err(IpcError::Other(format!(
                    "unknown ResourceKind byte {unknown}"
                )))
            }
        };
        Ok(kind)
    }
}
/// Point-in-time snapshot of a [`ResourceLink`], as returned by `ResourceLink::info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceLinkInfo {
    /// Segment name the handle was created or opened with.
    pub key: String,
    /// Total segment size in bytes, including the header.
    pub len: usize,
    /// Usable payload bytes (total size minus the header).
    pub payload_len: usize,
    /// Resource kind recorded for the segment.
    pub kind: ResourceKind,
    /// Creation time from the header, at whole-second resolution.
    pub created_at: SystemTime,
    /// TTL held by the handle that produced this snapshot (`None` for handles obtained via
    /// `ResourceLink::acquire`, which does not know the TTL).
    pub ttl: Option<Duration>,
    /// Shared reference count at the moment the snapshot was taken.
    pub refcount: u32,
}
/// Handle to a named shared-memory segment that starts with a `HEADER_SIZE`-byte header
/// followed by the payload.
///
/// Each live handle counts once in the segment's shared reference count: `create` starts it
/// at 1, `acquire` increments it, and this type's `Drop` decrements it. That `Drop` does not
/// remove the segment itself; zero-count segments can later be removed by
/// `ResourceLink::gc_orphans`.
pub struct ResourceLink {
    // Mapped segment (header + payload).
    shm: SharedMemory,
    // Name the segment was created or opened with.
    key: String,
    // Kind written to (create) or decoded from (acquire) the header.
    kind: ResourceKind,
    // TTL known to this handle only; it is not written into the segment.
    ttl: Option<Duration>,
}
fn read_magic(shm: &SharedMemory) -> Result<u32> {
let bytes = shm.read(0, 4)?;
Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
}
/// Returns a pointer to the shared reference counter stored at byte offset 4 of `shm`.
///
/// # Safety
///
/// The caller must ensure that:
/// - `shm` maps at least 8 bytes, so that offset 4..8 is in bounds; and
/// - the mapping's base address is aligned for `u32`, so that offset 4 is a properly
///   aligned `AtomicU32` location.
///
/// The returned pointer is derived from `shm.as_ptr()` and must not be used after `shm` is
/// dropped.
unsafe fn refcount_ptr(shm: &SharedMemory) -> *const AtomicU32 {
    debug_assert!(shm.size() >= 8, "segment too small to hold a refcount");
    // SAFETY: the caller guarantees the mapping is at least 8 bytes long, so offset 4 stays
    // within the same allocation.
    let ptr = unsafe { shm.as_ptr().add(4) as *const AtomicU32 };
    debug_assert_eq!(
        ptr as usize % std::mem::align_of::<AtomicU32>(),
        0,
        "refcount location is not aligned for AtomicU32"
    );
    ptr
}
fn load_refcount(shm: &SharedMemory) -> u32 {
unsafe { (*refcount_ptr(shm)).load(Ordering::SeqCst) }
}
fn increment_refcount(shm: &SharedMemory) -> u32 {
unsafe { (*refcount_ptr(shm)).fetch_add(1, Ordering::SeqCst) + 1 }
}
fn decrement_refcount(shm: &SharedMemory) -> u32 {
unsafe { (*refcount_ptr(shm)).fetch_sub(1, Ordering::SeqCst) - 1 }
}
fn read_created_at_secs(shm: &SharedMemory) -> Result<u64> {
let bytes = shm.read(8, 8)?;
Ok(u64::from_le_bytes(bytes.try_into().unwrap()))
}
/// Reads the kind byte at header offset 24 and decodes it into a [`ResourceKind`].
///
/// # Errors
///
/// Propagates read errors from `shm` and rejects unknown kind bytes.
fn read_kind(shm: &SharedMemory) -> Result<ResourceKind> {
    let raw = shm.read(24, 1)?;
    raw[0].try_into()
}
fn write_header(shm: &mut SharedMemory, payload_len: usize, kind: ResourceKind) -> Result<()> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
shm.write(0, &MAGIC.to_le_bytes())?;
shm.write(4, &1u32.to_le_bytes())?;
shm.write(8, &now.to_le_bytes())?;
shm.write(16, &(payload_len as u64).to_le_bytes())?;
shm.write(24, &[kind as u8])?;
Ok(())
}
impl ResourceLink {
pub fn create(
key: &str,
payload_size: usize,
kind: ResourceKind,
ttl: Option<Duration>,
) -> Result<Self> {
let total = payload_size + HEADER_SIZE;
let mut shm = SharedMemory::create(key, total)?;
write_header(&mut shm, payload_size, kind)?;
Ok(Self {
shm,
key: key.to_string(),
kind,
ttl,
})
}
pub fn acquire(key: &str) -> Result<Self> {
let shm = SharedMemory::open(key)?;
if read_magic(&shm)? != MAGIC {
return Err(IpcError::Other(format!(
"ResourceLink: segment '{key}' has invalid magic — not a ResourceLink segment"
)));
}
let kind = read_kind(&shm)?;
increment_refcount(&shm);
Ok(Self {
shm,
key: key.to_string(),
kind,
ttl: None,
})
}
pub fn refcount(&self) -> u32 {
load_refcount(&self.shm)
}
pub fn key(&self) -> &str {
&self.key
}
pub fn len(&self) -> usize {
self.shm.size()
}
pub fn is_empty(&self) -> bool {
self.shm.size() <= HEADER_SIZE
}
pub fn payload_len(&self) -> usize {
self.shm.size().saturating_sub(HEADER_SIZE)
}
pub fn kind(&self) -> ResourceKind {
self.kind
}
pub fn created_at(&self) -> Result<SystemTime> {
let secs = read_created_at_secs(&self.shm)?;
Ok(UNIX_EPOCH + Duration::from_secs(secs))
}
pub fn ttl(&self) -> Option<Duration> {
self.ttl
}
pub fn is_expired(&self) -> bool {
let Some(ttl) = self.ttl else { return false };
self.created_at()
.ok()
.and_then(|t| t.elapsed().ok())
.is_some_and(|age| age > ttl)
}
pub fn info(&self) -> Result<ResourceLinkInfo> {
Ok(ResourceLinkInfo {
key: self.key.clone(),
len: self.shm.size(),
payload_len: self.payload_len(),
kind: self.kind,
created_at: self.created_at()?,
ttl: self.ttl,
refcount: self.refcount(),
})
}
pub fn write_payload(&mut self, data: &[u8]) -> Result<()> {
if data.len() > self.payload_len() {
return Err(IpcError::BufferTooSmall {
needed: data.len(),
got: self.payload_len(),
});
}
self.shm.write(HEADER_SIZE, data)
}
pub fn read_payload(&self, payload_offset: usize, len: usize) -> Result<Vec<u8>> {
self.shm.read(HEADER_SIZE + payload_offset, len)
}
pub fn gc_orphans(max_age: Duration) -> usize {
#[cfg(target_os = "linux")]
{
gc_orphans_unix("/dev/shm", max_age)
}
#[cfg(target_os = "macos")]
{
gc_orphans_unix("/tmp", max_age)
}
#[cfg(windows)]
{
let _ = max_age;
0
}
#[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
{
let _ = max_age;
0
}
}
}
impl Drop for ResourceLink {
    /// Releases this handle's share of the segment's reference count.
    fn drop(&mut self) {
        // The remaining count is deliberately not acted on. Nothing here removes the segment
        // when the count reaches zero; such segments can later be reclaimed by
        // `ResourceLink::gc_orphans`.
        decrement_refcount(&self.shm);
    }
}
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn gc_orphans_unix(shm_dir: &str, max_age: Duration) -> usize {
use std::ffi::CString;
let dir = match std::fs::read_dir(shm_dir) {
Ok(d) => d,
Err(_) => return 0,
};
let now = SystemTime::now();
let mut removed = 0;
for entry in dir.flatten() {
let path = entry.path();
let fname = match path.file_name().and_then(|n| n.to_str()) {
Some(n) => n.to_string(),
None => continue,
};
let shm = match SharedMemory::open(&fname) {
Ok(s) => s,
Err(_) => continue,
};
if read_magic(&shm).ok() != Some(MAGIC) {
continue;
}
if load_refcount(&shm) > 0 {
continue;
}
let age_ok = read_created_at_secs(&shm)
.ok()
.map(|secs| UNIX_EPOCH + Duration::from_secs(secs))
.and_then(|created| now.duration_since(created).ok())
.is_some_and(|age| age > max_age);
if !age_ok {
continue;
}
#[cfg(unix)]
{
let c_name = match CString::new(format!("/{}", fname)) {
Ok(n) => n,
Err(_) => continue,
};
unsafe {
libc::shm_unlink(c_name.as_ptr());
}
removed += 1;
}
}
removed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a segment name that is unique per test tag, process, and creation instant.
    fn unique_key(tag: &str) -> String {
        format!(
            "rl_test_{}_{}_{}",
            tag,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .subsec_nanos()
        )
    }

    #[test]
    fn test_create_and_read_payload() {
        let key = unique_key("crp");
        let payload = b"Hello, ResourceLink!";
        let mut link = ResourceLink::create(&key, 256, ResourceKind::SharedMemory, None).unwrap();
        link.write_payload(payload).unwrap();
        let read_back = link.read_payload(0, payload.len()).unwrap();
        assert_eq!(read_back, payload);
    }

    #[test]
    fn test_refcount_increments_on_acquire() {
        let key = unique_key("rcia");
        let link = ResourceLink::create(&key, 64, ResourceKind::SharedMemory, None).unwrap();
        assert_eq!(link.refcount(), 1);
        let consumer = ResourceLink::acquire(&key).unwrap();
        assert_eq!(link.refcount(), 2);
        assert_eq!(consumer.refcount(), 2);
        drop(consumer);
        assert_eq!(link.refcount(), 1);
    }

    #[test]
    fn test_payload_too_large_returns_error() {
        let key = unique_key("ptl");
        let mut link = ResourceLink::create(&key, 8, ResourceKind::SharedMemory, None).unwrap();
        let result = link.write_payload(&[0u8; 100]);
        assert!(result.is_err());
    }

    #[test]
    fn test_kind_round_trip() {
        let key = unique_key("krt");
        let link = ResourceLink::create(&key, 16, ResourceKind::SharedMemory, None).unwrap();
        assert_eq!(link.kind(), ResourceKind::SharedMemory);
    }

    #[test]
    fn test_info_snapshot() {
        let key = unique_key("info");
        let link = ResourceLink::create(
            &key,
            128,
            ResourceKind::SharedMemory,
            Some(Duration::from_secs(60)),
        )
        .unwrap();
        let info = link.info().unwrap();
        assert_eq!(info.key, key);
        assert_eq!(info.payload_len, 128);
        assert_eq!(info.kind, ResourceKind::SharedMemory);
        assert_eq!(info.refcount, 1);
        assert_eq!(info.ttl, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_acquire_invalid_magic_fails() {
        let key = unique_key("bad");
        // A raw segment never gets a ResourceLink header written, so `acquire` must refuse it.
        let _raw = SharedMemory::create(&key, 64).unwrap();
        assert!(ResourceLink::acquire(&key).is_err());
    }

    #[test]
    fn test_ttl_not_expired() {
        let key = unique_key("ttlok");
        let link = ResourceLink::create(
            &key,
            32,
            ResourceKind::SharedMemory,
            Some(Duration::from_secs(3600)),
        )
        .unwrap();
        assert!(!link.is_expired());
    }

    #[test]
    fn test_no_ttl_never_expired() {
        let key = unique_key("nottl");
        let link = ResourceLink::create(&key, 32, ResourceKind::SharedMemory, None).unwrap();
        assert!(!link.is_expired());
    }

    #[test]
    fn test_gc_orphans_returns_usize() {
        let _ = ResourceLink::gc_orphans(Duration::from_secs(1));
    }
}