use std::io;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Marker bytes written directly after the usable bytes of every region and
/// compared again before each read to detect writes past the usable range.
const CANARY: [u8; 18] = *b"~!GUARDED_CANARY!~";
/// Largest usable size, in bytes, accepted by `ProtectedRegion::new` (1 MiB).
const MAX_SECRET_SIZE: usize = 1024 * 1024;
/// An anonymous memory mapping for secret bytes.
///
/// The mapping consists of one leading page, the data pages, and one trailing
/// page. Only the data pages ever have their protection changed, so the first
/// and last page keep the `PROT_NONE` protection they were mapped with.
pub struct ProtectedRegion {
/// Start of the whole mapping, including the leading guard page.
base: *mut u8,
/// Length in bytes of the whole mapping (both guard pages plus data pages).
total_len: usize,
/// Start of the data pages, one page after `base`.
data: *mut u8,
/// Length in bytes of the data pages (usable bytes plus canary, rounded up
/// to a multiple of the page size).
data_pages_len: usize,
/// Number of bytes exposed to callers; the canary starts at this offset.
usable_len: usize,
/// Protection most recently recorded for the data pages.
protection: Protection,
}
// SAFETY: the raw pointers refer to a mapping created in `ProtectedRegion::new`
// and released in `Drop`; the fields are private and every accessor takes
// `&mut self`, so moving the value to another thread moves sole access with it.
unsafe impl Send for ProtectedRegion {}
impl ProtectedRegion {
    /// Allocates a memory-locked region with `size` usable bytes, surrounded by
    /// guard pages.
    ///
    /// The mapping is laid out as one guard page, the data pages, and one more
    /// guard page. The canary is written immediately after the usable bytes and
    /// the data pages are then made inaccessible until [`Self::with_read`] or
    /// [`Self::with_write`] is called.
    ///
    /// # Errors
    ///
    /// * [`ProtectionError::InvalidSize`] if `size` is zero or larger than
    ///   `MAX_SECRET_SIZE`.
    /// * [`ProtectionError::Mmap`] if the mapping cannot be created.
    /// * [`ProtectionError::Mprotect`] or [`ProtectionError::Mlock`] if the
    ///   corresponding system call fails; the mapping is released before the
    ///   error is returned.
    pub fn new(size: usize) -> Result<Self, ProtectionError> {
        if size == 0 || size > MAX_SECRET_SIZE {
            return Err(ProtectionError::InvalidSize);
        }
        let page_size = page_size();
        let needed = size + CANARY.len();
        let data_pages_len = needed.next_multiple_of(page_size);
        let total_len = page_size + data_pages_len + page_size;
        let base = mmap_anon(total_len)?;
        // SAFETY: the mapping is `total_len` bytes long, which is more than one
        // page, so `base + page_size` stays inside it.
        let data = unsafe { base.add(page_size) };

        // Every failure from here on must give the mapping back; otherwise the
        // address space (and possibly locked memory) leaks on each failed call.
        let release = |err: ProtectionError| -> ProtectionError {
            // SAFETY: `base`/`total_len` describe exactly the mapping returned
            // by `mmap_anon` above, and no `ProtectedRegion` owns it yet.
            unsafe { libc::munmap(base.cast(), total_len) };
            err
        };

        mprotect(data, data_pages_len, Protection::ReadWrite).map_err(&release)?;
        // SAFETY: `data..data + data_pages_len` lies inside the mapping.
        if unsafe { libc::mlock(data.cast(), data_pages_len) } != 0 {
            // Read errno before `munmap` gets a chance to overwrite it.
            let cause = io::Error::last_os_error();
            return Err(release(ProtectionError::Mlock(cause)));
        }
        // SAFETY: the data pages are writable here and
        // `size + CANARY.len() <= data_pages_len`, so the canary fits; the
        // constant and the fresh mapping cannot overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(CANARY.as_ptr(), data.add(size), CANARY.len());
        }
        mprotect(data, data_pages_len, Protection::NoAccess).map_err(&release)?;
        Ok(Self {
            base,
            total_len,
            data,
            data_pages_len,
            usable_len: size,
            protection: Protection::NoAccess,
        })
    }

    /// Returns the number of usable bytes in the region.
    #[must_use]
    #[expect(clippy::len_without_is_empty, reason = "regions are always non-empty")]
    pub fn len(&self) -> usize {
        self.usable_len
    }

    /// Verifies the canary, then runs `f` with read-only access to the usable
    /// bytes and makes the data pages inaccessible again.
    ///
    /// If `f` panics the data pages are not re-protected by this method.
    ///
    /// # Errors
    ///
    /// * [`ProtectionError::CanaryCorrupted`] if the bytes after the usable
    ///   range no longer match the canary; `f` is not called.
    /// * [`ProtectionError::Mprotect`] if changing the page protection fails.
    pub fn with_read<T>(&mut self, f: impl FnOnce(&[u8]) -> T) -> Result<T, ProtectionError> {
        if let Err(err) = self.verify_canary() {
            // `verify_canary` may have made the pages readable before failing.
            // Lock them again on a best-effort basis; the original error is the
            // one worth reporting.
            let _ = self.mprotect_noaccess();
            return Err(err);
        }
        // SAFETY: `verify_canary` succeeded, so the data pages are readable and
        // `usable_len` bytes starting at `data` lie inside them.
        let slice = unsafe { std::slice::from_raw_parts(self.data, self.usable_len) };
        let result = f(slice);
        self.mprotect_noaccess()?;
        Ok(result)
    }

    /// Runs `f` with mutable access to the usable bytes, then makes the data
    /// pages inaccessible again.
    ///
    /// The canary is not checked by this method. If `f` panics the data pages
    /// are not re-protected by this method.
    ///
    /// # Errors
    ///
    /// [`ProtectionError::Mprotect`] if changing the page protection fails.
    pub fn with_write<T>(&mut self, f: impl FnOnce(&mut [u8]) -> T) -> Result<T, ProtectionError> {
        mprotect(self.data, self.data_pages_len, Protection::ReadWrite)?;
        self.protection = Protection::ReadWrite;
        // SAFETY: the data pages were just made readable and writable,
        // `usable_len` bytes starting at `data` lie inside them, and `&mut self`
        // guarantees no other slice into the region exists.
        let slice = unsafe { std::slice::from_raw_parts_mut(self.data, self.usable_len) };
        let result = f(slice);
        self.mprotect_noaccess()?;
        Ok(result)
    }

    /// Makes the data pages inaccessible and records the new state.
    fn mprotect_noaccess(&mut self) -> Result<(), ProtectionError> {
        mprotect(self.data, self.data_pages_len, Protection::NoAccess)?;
        self.protection = Protection::NoAccess;
        Ok(())
    }

    /// Makes the data pages read-only and checks that the canary is intact.
    ///
    /// On success the pages are left read-only for the caller. On failure the
    /// pages may already be read-only; the caller is responsible for locking
    /// them again.
    fn verify_canary(&mut self) -> Result<(), ProtectionError> {
        mprotect(self.data, self.data_pages_len, Protection::ReadOnly)?;
        self.protection = Protection::ReadOnly;
        // SAFETY: the canary occupies `CANARY.len()` bytes starting at offset
        // `usable_len`, which `new` sized the data pages to contain, and the
        // pages are readable after the `mprotect` above.
        let stored = unsafe {
            std::slice::from_raw_parts(self.data.add(self.usable_len), CANARY.len())
        };
        if stored != CANARY {
            return Err(ProtectionError::CanaryCorrupted);
        }
        Ok(())
    }
}
impl Zeroize for ProtectedRegion {
    /// Overwrites every byte of the data pages (usable bytes, canary and
    /// padding) with zero and leaves the pages recorded as read-write.
    fn zeroize(&mut self) {
        let (start, span) = (self.data, self.data_pages_len);
        // `zeroize` returns `()`, so a failed protection change cannot be
        // reported from here; the result is discarded.
        let _ = mprotect(start, span, Protection::ReadWrite);
        self.protection = Protection::ReadWrite;
        // SAFETY: `start..start + span` is the data-page range of the mapping
        // created in `ProtectedRegion::new`; it is writable provided the
        // `mprotect` call above succeeded.
        unsafe { start.write_bytes(0, span) };
    }
}
impl Drop for ProtectedRegion {
fn drop(&mut self) {
self.zeroize();
unsafe {
libc::munlock(self.data.cast(), self.data_pages_len);
libc::munmap(self.base.cast(), self.total_len);
}
}
}
// Marker impl: the `Drop` impl for `ProtectedRegion` calls `zeroize` before
// releasing the mapping.
impl ZeroizeOnDrop for ProtectedRegion {}
/// Access level applied to the data pages; translated to `PROT_*` flags by
/// `mprotect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Protection {
/// No access (`PROT_NONE`).
NoAccess,
/// Read access only (`PROT_READ`).
ReadOnly,
/// Read and write access (`PROT_READ | PROT_WRITE`).
ReadWrite,
}
/// Errors reported while creating or accessing a `ProtectedRegion`.
#[derive(Debug, thiserror::Error)]
pub enum ProtectionError {
/// The requested size was zero or larger than `MAX_SECRET_SIZE`.
#[error("requested size must be 1..=1Mb")]
InvalidSize,
/// `mmap` failed; carries the OS error.
#[error("mmap failed: {0}")]
Mmap(io::Error),
/// `mprotect` failed; carries the OS error.
#[error("mprotect failed: {0}")]
Mprotect(io::Error),
/// `mlock` failed; carries the OS error.
#[error("mlock failed: {0}")]
Mlock(io::Error),
/// The bytes following the usable range no longer match the canary.
#[error("canary corrupted. potential buffer overflow")]
CanaryCorrupted,
}
/// Returns the page size reported by `sysconf(_SC_PAGESIZE)`.
///
/// Falls back to 4096 when the reported value does not fit in a `usize`
/// (for example when it is negative).
fn page_size() -> usize {
    // SAFETY: `sysconf` takes a plain integer selector and no pointers.
    let reported = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    match usize::try_from(reported) {
        Ok(bytes) => bytes,
        Err(_) => 4096,
    }
}
fn mmap_anon(len: usize) -> Result<*mut u8, ProtectionError> {
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
len,
libc::PROT_NONE,
libc::MAP_ANON | libc::MAP_PRIVATE,
-1,
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(ProtectionError::Mmap(io::Error::last_os_error()));
}
Ok(ptr.cast())
}
fn mprotect(addr: *mut u8, len: usize, prot: Protection) -> Result<(), ProtectionError> {
let flags = match prot {
Protection::NoAccess => libc::PROT_NONE,
Protection::ReadOnly => libc::PROT_READ,
Protection::ReadWrite => libc::PROT_READ | libc::PROT_WRITE,
};
if unsafe { libc::mprotect(addr.cast(), len, flags) } != 0 {
return Err(ProtectionError::Mprotect(io::Error::last_os_error()));
}
Ok(())
}
// Unit tests for `ProtectedRegion`: round-trips, canary handling, zeroization,
// and guard-page faults observed from a forked child process.
#[cfg(test)]
#[expect(clippy::unwrap_used, reason = "tests")]
mod tests {
use super::*;
// Bytes written through `with_write` are returned unchanged by `with_read`.
#[test]
fn basic_write_read_cycle() {
let mut region = ProtectedRegion::new(32).unwrap();
region
.with_write(|buf| {
buf.copy_from_slice(&[0xAB; 32]);
})
.unwrap();
let data = region.with_read(<[u8]>::to_vec).unwrap();
assert_eq!(data, vec![0xAB; 32]);
}
// `mprotect_noaccess` records `NoAccess` in the `protection` field.
#[test]
fn mprotect_noaccess_sets_state() {
let mut region = ProtectedRegion::new(16).unwrap();
region.mprotect_noaccess().unwrap();
assert_eq!(region.protection, Protection::NoAccess);
}
// A write confined to the usable bytes leaves the canary intact, so the
// following read succeeds.
#[test]
fn canary_verified_on_read() {
let mut region = ProtectedRegion::new(8).unwrap();
region
.with_write(|buf| buf.copy_from_slice(&[1; 8]))
.unwrap();
let result = region.with_read(<[u8]>::to_vec);
assert!(result.is_ok());
}
// A size of zero is rejected.
#[test]
fn zero_length_rejected() {
assert!(ProtectedRegion::new(0).is_err());
}
// A size one byte above `MAX_SECRET_SIZE` is rejected.
#[test]
fn oversized_rejected() {
assert!(ProtectedRegion::new(MAX_SECRET_SIZE + 1).is_err());
}
// Creating and immediately dropping a region completes without crashing.
#[test]
fn test_dropping() {
let region = ProtectedRegion::new(64).unwrap();
drop(region);
}
// `zeroize` clears the canary along with the data, so the next read reports
// `CanaryCorrupted`.
#[test]
fn zeroize_wipes_canary() {
let mut region = ProtectedRegion::new(32).unwrap();
region
.with_write(|buf| buf.copy_from_slice(&[0xAB; 32]))
.unwrap();
region.zeroize();
let err = region.with_read(<[u8]>::to_vec).unwrap_err();
assert!(matches!(err, ProtectionError::CanaryCorrupted));
}
// Calling `zeroize` twice and then dropping the region does not crash.
#[test]
fn zeroize_is_idempotent() {
let mut region = ProtectedRegion::new(16).unwrap();
region
.with_write(|buf| buf.copy_from_slice(&[0xFF; 16]))
.unwrap();
region.zeroize();
region.zeroize();
drop(region);
}
// Repeated write/read cycles keep working and each read observes the value
// written in the same iteration.
#[test]
fn multiple_read_write_cycles() {
let mut region = ProtectedRegion::new(16).unwrap();
for i in 0u8..5 {
region
.with_write(|buf| {
for b in buf.iter_mut() {
*b = i;
}
})
.unwrap();
let data = region.with_read(|buf| buf[0]).unwrap();
assert_eq!(data, i);
}
}
// After `zeroize`, every byte of the data pages (not only the usable bytes)
// reads as zero.
#[test]
fn zeroize_writes_zeros_across_entire_data_region() {
let mut region = ProtectedRegion::new(64).unwrap();
region
.with_write(|buf| buf.copy_from_slice(&[0xAA; 64]))
.unwrap();
region.zeroize();
// Reads the raw pages directly; this relies on `zeroize` leaving them
// read-write.
let observed = unsafe { std::slice::from_raw_parts(region.data, region.data_pages_len) };
let first_nonzero = observed.iter().position(|&b| b != 0);
assert!(
first_nonzero.is_none(),
"byte at offset {first_nonzero:?} was not zeroed"
);
}
// Flipping the first canary byte behind the API's back is reported as
// `CanaryCorrupted` by the next read.
#[test]
fn canary_corruption_is_detected_on_next_read() {
let mut region = ProtectedRegion::new(64).unwrap();
region
.with_write(|buf| buf.copy_from_slice(&[0xCC; 64]))
.unwrap();
// Unlock the pages manually so the canary byte can be modified.
mprotect(region.data, region.data_pages_len, Protection::ReadWrite).unwrap();
region.protection = Protection::ReadWrite;
unsafe {
*region.data.add(region.usable_len) ^= 0xFF;
}
region.mprotect_noaccess().unwrap();
let err = region.with_read(<[u8]>::to_vec).unwrap_err();
assert!(matches!(err, ProtectionError::CanaryCorrupted));
}
// Forks, runs `child_work` in the child, and asserts in the parent that the
// child was terminated by SIGSEGV or SIGBUS. If `child_work` returns, the
// child exits with status 0 and the parent's assertion fails.
#[cfg(unix)]
fn assert_child_segfaults(child_work: impl FnOnce()) {
let pid = unsafe { libc::fork() };
assert!(pid >= 0, "fork failed");
if pid == 0 {
child_work();
unsafe { libc::_exit(0) };
}
let mut status: libc::c_int = 0;
let waited = unsafe { libc::waitpid(pid, &raw mut status, 0) };
assert_eq!(waited, pid, "waitpid did not return our child");
assert!(
libc::WIFSIGNALED(status),
"child exited cleanly (status {status}); expected to be killed by a signal"
);
let sig = libc::WTERMSIG(status);
assert!(
sig == libc::SIGSEGV || sig == libc::SIGBUS,
"expected SIGSEGV or SIGBUS, got signal {sig}"
);
}
// Reading the byte just before the data pages (the last byte of the leading
// guard page) kills the child.
#[cfg(unix)]
#[test]
fn underflow_into_front_guard_page_segfaults() {
assert_child_segfaults(|| {
let region = ProtectedRegion::new(64).unwrap();
let guard = unsafe { region.data.sub(1) };
let _ = unsafe { std::ptr::read_volatile(guard) };
});
}
// Reading the first byte after the data pages (the first byte of the
// trailing guard page) kills the child.
#[cfg(unix)]
#[test]
fn overflow_past_back_guard_page_segfaults() {
assert_child_segfaults(|| {
let region = ProtectedRegion::new(64).unwrap();
let overflow = unsafe { region.data.add(region.data_pages_len) };
let _ = unsafe { std::ptr::read_volatile(overflow) };
});
}
}