use std::fs::{File, OpenOptions};
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Eight-byte tag stored at the very start of the encoded header and checked
/// first when a header is decoded.
pub const SHM_MAGIC: &[u8; 8] = b"RDBSHM01";
/// Format version written into every header produced by this module.
pub const SHM_VERSION: u32 = 1;
/// Size in bytes of the encoded header (magic, fields, padding, checksum).
pub const SHM_HEADER_SIZE: usize = 64;
/// Length the shm file is set to when it is created or re-initialised.
pub const SHM_FILE_SIZE: u64 = 4096;
/// Process-wide provisioning override: `0` means no explicit choice was made
/// (the `REDDB_SHM_PROVISION` environment variable is consulted), `1` means
/// enabled, `2` means disabled.
static SHM_POLICY: AtomicU8 = AtomicU8::new(0);
/// Records an explicit, process-wide choice for shm provisioning.
///
/// Once called, [`shm_provisioning_enabled`] returns this choice and no longer
/// looks at the environment.
pub fn set_shm_provisioning_enabled(enabled: bool) {
    // 1 = explicitly on, 2 = explicitly off; 0 is reserved for "not set".
    let code: u8 = match enabled {
        true => 1,
        false => 2,
    };
    SHM_POLICY.store(code, Ordering::Relaxed);
}
/// Reports whether shm provisioning is turned on.
///
/// An explicit choice made through [`set_shm_provisioning_enabled`] wins.
/// Otherwise the `REDDB_SHM_PROVISION` environment variable is read and only
/// the exact spellings `1`, `true`, `TRUE`, `yes` and `on` count as enabled;
/// an unset, non-Unicode, or differently spelled value means disabled.
pub fn shm_provisioning_enabled() -> bool {
    let policy = SHM_POLICY.load(Ordering::Relaxed);
    if policy == 1 {
        return true;
    }
    if policy == 2 {
        return false;
    }
    // No explicit override: fall back to the environment.
    match std::env::var("REDDB_SHM_PROVISION") {
        Ok(raw) => ["1", "true", "TRUE", "yes", "on"].contains(&raw.as_str()),
        Err(_) => false,
    }
}
/// Derives the shm sidecar path for a data file: the data file's name with
/// `-shm` appended, placed in the same directory.
///
/// When `data_path` has no final file-name component (for example `""`, `/`
/// or `..`), the name `data.rdb` is used instead. When it has no non-empty
/// parent, the result is a bare relative file name.
///
/// The name is assembled as an `OsString`, so file names that are not valid
/// Unicode are carried over unchanged instead of being lossily rewritten
/// (which would point at a different file and could make distinct data files
/// share one sidecar).
pub fn shm_path_for(data_path: &Path) -> PathBuf {
    let mut shm_name = data_path
        .file_name()
        .map(|name| name.to_os_string())
        .unwrap_or_else(|| std::ffi::OsString::from("data.rdb"));
    shm_name.push("-shm");
    match data_path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => parent.join(shm_name),
        _ => PathBuf::from(shm_name),
    }
}
/// How `provision_shm` obtained the shm file it returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShmProvisionState {
/// The file was empty; it was sized and given a new header at generation 1.
Created,
/// A valid header was found whose owner is either this process or another
/// process that `pid_alive` reports as running.
AttachedToLiveOwner,
/// A valid header was found but its owner is no longer running; this
/// process took ownership, bumped the generation and reset the reader count.
RecoveredFromCrash,
/// The existing header could not be decoded (for example a truncated file
/// or a magic/checksum mismatch); it was replaced with a new header at
/// generation 1.
HealedCorruptHeader,
}
/// Decoded form of the fixed-size header at the start of the shm file.
///
/// On disk every field is little-endian; the byte ranges below are the
/// positions used by `encode`/`decode`.
#[derive(Debug, Clone)]
pub struct ShmHeader {
/// Format version (bytes 8..12).
pub version: u32,
/// Process id recorded as the owner of the file (bytes 12..16).
pub owner_pid: u32,
/// Counter that starts at 1 and is incremented when ownership is taken over
/// from an owner that is no longer running (bytes 16..24).
pub generation: u64,
/// Number of attached readers as tracked by attach/detach (bytes 24..32).
pub reader_count: u64,
/// Milliseconds since the Unix epoch of the last heartbeat (bytes 32..40).
pub last_heartbeat_ms: u64,
}
impl ShmHeader {
    /// Serialises the header into its fixed on-disk form.
    ///
    /// Layout (little-endian): magic at 0..8, `version` at 8..12, `owner_pid`
    /// at 12..16, `generation` at 16..24, `reader_count` at 24..32,
    /// `last_heartbeat_ms` at 32..40, zero padding at 40..56, and a checksum
    /// over bytes 0..56 stored at 56..64.
    fn encode(&self) -> [u8; SHM_HEADER_SIZE] {
        /// Copies `bytes` into `dst` starting at offset `at`.
        fn put(dst: &mut [u8], at: usize, bytes: &[u8]) {
            dst[at..at + bytes.len()].copy_from_slice(bytes);
        }

        let mut out = [0u8; SHM_HEADER_SIZE];
        put(&mut out, 0, SHM_MAGIC);
        put(&mut out, 8, &self.version.to_le_bytes());
        put(&mut out, 12, &self.owner_pid.to_le_bytes());
        put(&mut out, 16, &self.generation.to_le_bytes());
        put(&mut out, 24, &self.reader_count.to_le_bytes());
        put(&mut out, 32, &self.last_heartbeat_ms.to_le_bytes());
        // The checksum covers everything before it, padding included.
        let digest = fold_checksum(&out[..56]);
        put(&mut out, 56, &digest.to_le_bytes());
        out
    }

    /// Parses a header previously produced by [`ShmHeader::encode`].
    ///
    /// # Errors
    /// Returns `InvalidData` when the magic tag does not match, or when the
    /// stored checksum differs from the one recomputed over bytes 0..56. The
    /// magic is checked before the checksum.
    fn decode(buf: &[u8; SHM_HEADER_SIZE]) -> io::Result<Self> {
        let invalid = |reason: &'static str| io::Error::new(io::ErrorKind::InvalidData, reason);
        let u32_at = |at: usize| -> u32 {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(&buf[at..at + 4]);
            u32::from_le_bytes(raw)
        };
        let u64_at = |at: usize| -> u64 {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&buf[at..at + 8]);
            u64::from_le_bytes(raw)
        };

        if &buf[..8] != &SHM_MAGIC[..] {
            return Err(invalid("shm magic mismatch"));
        }
        if u64_at(56) != fold_checksum(&buf[..56]) {
            return Err(invalid("shm checksum mismatch"));
        }
        Ok(Self {
            version: u32_at(8),
            owner_pid: u32_at(12),
            generation: u64_at(16),
            reader_count: u64_at(24),
            last_heartbeat_ms: u64_at(32),
        })
    }
}
/// An open shm file together with the header this process last wrote to it.
pub struct ShmHandle {
/// Location of the shm file on disk.
pub path: PathBuf,
/// In-memory copy of the header; the handle's methods update it and then
/// write it back to the start of the file.
pub header: ShmHeader,
/// How the file was obtained when the handle was provisioned.
pub state: ShmProvisionState,
// Open read/write descriptor used to rewrite the header in place.
file: File,
}
impl ShmHandle {
    /// Returns the generation stored in this handle's copy of the header.
    pub fn generation(&self) -> u64 {
        self.header.generation
    }

    /// Increments the reader count (saturating at `u64::MAX`), writes the
    /// header back to the file, and returns the new count.
    ///
    /// # Errors
    /// Propagates any I/O error from rewriting the header. The in-memory count
    /// has already been updated when that happens.
    pub fn attach_reader(&mut self) -> io::Result<u64> {
        let bumped = self.header.reader_count.saturating_add(1);
        self.header.reader_count = bumped;
        self.rewrite_header()?;
        Ok(bumped)
    }

    /// Decrements the reader count (saturating at zero), writes the header
    /// back to the file, and returns the new count.
    ///
    /// # Errors
    /// Propagates any I/O error from rewriting the header. The in-memory count
    /// has already been updated when that happens.
    pub fn detach_reader(&mut self) -> io::Result<u64> {
        let lowered = self.header.reader_count.saturating_sub(1);
        self.header.reader_count = lowered;
        self.rewrite_header()?;
        Ok(lowered)
    }

    /// Stamps the header with the current wall-clock time in milliseconds and
    /// writes it back to the file.
    ///
    /// # Errors
    /// Propagates any I/O error from rewriting the header.
    pub fn heartbeat(&mut self) -> io::Result<()> {
        let now_ms = unix_ms_now();
        self.header.last_heartbeat_ms = now_ms;
        self.rewrite_header()
    }

    /// Encodes the in-memory header, overwrites the first bytes of the file
    /// with it, and syncs the file's data before returning.
    fn rewrite_header(&mut self) -> io::Result<()> {
        let encoded = self.header.encode();
        self.file.seek(SeekFrom::Start(0))?;
        self.file.write_all(&encoded)?;
        self.file.sync_data()
    }
}
/// Opens the shm sidecar for `data_path` (creating the file and any missing
/// parent directory) and writes a header that claims, joins, or repairs it.
///
/// The outcome is reported in the returned handle's `state`:
/// - empty file: `Created`, sized to `SHM_FILE_SIZE`, generation 1;
/// - valid header owned by another process that `pid_alive` reports as
///   running: `AttachedToLiveOwner`, with `reader_count` incremented;
/// - valid header owned by this process: `AttachedToLiveOwner`, with the
///   heartbeat refreshed;
/// - valid header whose owner is gone: `RecoveredFromCrash`, ownership taken,
///   generation incremented, reader count reset to 0;
/// - file too short for a header, or header failing validation:
///   `HealedCorruptHeader`, file resized to `SHM_FILE_SIZE` and restarted at
///   generation 1.
///
/// # Errors
/// Returns any I/O error from creating the directory or from opening, sizing,
/// reading, writing or syncing the file. A read failure other than a short
/// file is propagated instead of being treated as corruption, so that a failed
/// read does not cause an otherwise valid header to be reset.
pub fn provision_shm(data_path: &Path) -> io::Result<ShmHandle> {
    let path = shm_path_for(data_path);
    if let Some(parent) = path.parent() {
        if !parent.as_os_str().is_empty() {
            std::fs::create_dir_all(parent)?;
        }
    }
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(false)
        .open(&path)?;

    let (header, state) = if file.metadata()?.len() == 0 {
        // Brand-new (or zero-length) file: size it and start a new lineage.
        file.set_len(SHM_FILE_SIZE)?;
        let header = ShmHeader {
            version: SHM_VERSION,
            owner_pid: current_pid(),
            generation: 1,
            reader_count: 0,
            last_heartbeat_ms: unix_ms_now(),
        };
        (header, ShmProvisionState::Created)
    } else {
        let mut buf = [0u8; SHM_HEADER_SIZE];
        file.seek(SeekFrom::Start(0))?;
        let existing = match file.read_exact(&mut buf) {
            // A header that fails magic/checksum validation is healed below.
            Ok(()) => ShmHeader::decode(&buf).ok(),
            // A file shorter than one header is corrupt by definition.
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => None,
            // Anything else says nothing about the header's contents; do not
            // overwrite a possibly valid header because of it.
            Err(err) => return Err(err),
        };
        match existing {
            Some(prev) if pid_alive(prev.owner_pid) && prev.owner_pid != current_pid() => {
                // Another running process owns the file: join it as a reader
                // and leave its identity and heartbeat untouched.
                let next = ShmHeader {
                    version: SHM_VERSION,
                    owner_pid: prev.owner_pid,
                    generation: prev.generation,
                    reader_count: prev.reader_count.saturating_add(1),
                    last_heartbeat_ms: prev.last_heartbeat_ms,
                };
                (next, ShmProvisionState::AttachedToLiveOwner)
            }
            Some(prev) if prev.owner_pid == current_pid() => {
                // The header already carries our pid: keep it, refresh the heartbeat.
                let next = ShmHeader {
                    version: SHM_VERSION,
                    owner_pid: prev.owner_pid,
                    generation: prev.generation,
                    reader_count: prev.reader_count,
                    last_heartbeat_ms: unix_ms_now(),
                };
                (next, ShmProvisionState::AttachedToLiveOwner)
            }
            Some(prev) => {
                // The recorded owner is gone: take over under a new generation.
                let next = ShmHeader {
                    version: SHM_VERSION,
                    owner_pid: current_pid(),
                    generation: prev.generation.saturating_add(1),
                    reader_count: 0,
                    last_heartbeat_ms: unix_ms_now(),
                };
                (next, ShmProvisionState::RecoveredFromCrash)
            }
            None => {
                let next = ShmHeader {
                    version: SHM_VERSION,
                    owner_pid: current_pid(),
                    generation: 1,
                    reader_count: 0,
                    last_heartbeat_ms: unix_ms_now(),
                };
                file.set_len(SHM_FILE_SIZE)?;
                (next, ShmProvisionState::HealedCorruptHeader)
            }
        }
    };

    file.seek(SeekFrom::Start(0))?;
    file.write_all(&header.encode())?;
    file.sync_data()?;
    Ok(ShmHandle {
        path,
        header,
        state,
        file,
    })
}
/// Reads and validates the header of the shm sidecar belonging to `data_path`
/// without modifying the file.
///
/// Returns `Ok(None)` when the sidecar does not exist.
///
/// # Errors
/// Returns the underlying I/O error when the file cannot be opened for a
/// reason other than not existing, or is shorter than one header, and
/// `InvalidData` when the header fails magic or checksum validation.
pub fn read_shm_header(data_path: &Path) -> io::Result<Option<ShmHeader>> {
    let path = shm_path_for(data_path);
    // Open directly instead of probing with `exists()` first: the probe races
    // with concurrent removal and reports every metadata failure as "absent".
    let mut file = match OpenOptions::new().read(true).open(&path) {
        Ok(file) => file,
        Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(None),
        Err(err) => return Err(err),
    };
    let mut buf = [0u8; SHM_HEADER_SIZE];
    file.read_exact(&mut buf)?;
    ShmHeader::decode(&buf).map(Some)
}
/// Computes the 64-bit FNV-1a hash of `bytes`: starting from the offset basis,
/// each byte is XOR-ed into the state, which is then multiplied (wrapping) by
/// the FNV prime. An empty slice yields the offset basis itself.
fn fold_checksum(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |state, &b| (state ^ u64::from(b)).wrapping_mul(PRIME))
}
/// Returns the current wall-clock time as whole milliseconds since the Unix
/// epoch, or 0 when the system clock reads earlier than the epoch.
fn unix_ms_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Returns the operating-system identifier of the calling process.
fn current_pid() -> u32 {
    use std::process;
    process::id()
}
#[cfg(unix)]
fn pid_alive(pid: u32) -> bool {
if pid == 0 {
return false;
}
let rc = unsafe { libc::kill(pid as libc::pid_t, 0) };
if rc == 0 {
return true;
}
io::Error::last_os_error()
.raw_os_error()
.map(|e| e == libc::EPERM)
.unwrap_or(false)
}
/// Fallback for non-Unix targets: no liveness probe is implemented here, so
/// every pid is reported as alive.
// NOTE(review): with this stub `provision_shm` can never reach its
// crashed-owner branch on these platforms — confirm that is intended.
#[cfg(not(unix))]
fn pid_alive(_pid: u32) -> bool {
true
}