#![allow(unsafe_code)]
use std::collections::{HashMap, HashSet, VecDeque};
use std::ptr::NonNull;
use std::time::Duration;
use rand::TryRngCore;
use sha2::{Digest, Sha256};
use zeroize::Zeroize;
use super::memcall::{os_lock, os_protect, os_unlock, page_size, Protection};
use crate::error::{Error, Result};
/// Default slot width in bytes. 32 bytes holds one AES-256 key, which is also
/// the smallest slot size a coffer-enabled slab accepts.
pub const DEFAULT_SLOT_SIZE: usize = 32;
/// Slot index of the left coffer half (the master key XOR-masked with the
/// digest of the right half).
pub(crate) const COFFER_LEFT: usize = 0;
/// Slot index of the right coffer half (random bytes whose SHA-256 digest is
/// the mask).
pub(crate) const COFFER_RIGHT: usize = 1;
/// First slot index that may be handed to callers; every index below it is
/// reserved for the coffer.
pub(crate) const FIRST_SHARED_SLOT: usize = COFFER_RIGHT + 1;
/// Upper bound on waiting for a slot. NOTE(review): not referenced in this
/// file; presumably used by a caller that blocks until a slot frees up.
pub(crate) const SLOT_WAIT_TIMEOUT: Duration = Duration::from_secs(30);
/// One `mlock`ed OS page divided into fixed-size slots for short-lived secrets.
///
/// Slots [`COFFER_LEFT`] and [`COFFER_RIGHT`] are never placed on the free
/// list. When the slab is built with a coffer, they hold a split master key.
/// Every slot from [`FIRST_SHARED_SLOT`] onward is in exactly one of three
/// states:
/// - free;
/// - owned by a cache entry;
/// - checked out to a caller as a transient slot.
pub struct SecureSlab {
    /// Start of the page returned by `os_alloc` in `new`; freed in `Drop`.
    ptr: NonNull<u8>,
    /// Length in bytes of the page behind `ptr` (`page_size()` at construction).
    page_size: usize,
    /// Bytes per slot.
    ///
    /// NOTE(review): this field is `pub` and therefore writable from outside,
    /// but the unsafe pointer arithmetic in `slot_ptr` / `slot_raw` relies on
    /// it never changing after `new`. Consider making it private and exposing
    /// only the `slot_size()` getter.
    pub slot_size: usize,
    /// `page_size / slot_size`, counting the two reserved coffer slots.
    total_slots: usize,
    /// Whether `new` seeded the coffer slots with key material.
    has_coffer: bool,
    /// Indices of unused slots; `pop` hands out the last entry (LIFO).
    free: Vec<usize>,
    /// Cache id -> index of the slot holding that entry's bytes.
    cache_map: HashMap<u64, usize>,
    /// Cache ids from least- (front) to most-recently (back) used.
    cache_lru: VecDeque<u64>,
    /// Slot indices currently checked out; returned with `release`.
    transient: HashSet<usize>,
}
// SAFETY: `SecureSlab` exclusively owns the page behind `ptr`. The page is
// allocated in `new`, released in `Drop`, and never shared with another
// `SecureSlab`. All other fields are ordinary owned std collections, so moving
// the whole value to another thread is sound. No `Sync` impl appears in this
// file, and every mutating method takes `&mut self` with no internal locking.
// NOTE(review): raw pointers handed out by `slot_raw` are not tracked by the
// borrow checker. Callers must not keep using them on another thread after the
// slab moves.
unsafe impl Send for SecureSlab {}
impl std::fmt::Debug for SecureSlab {
    /// Reports slab geometry and occupancy counts only. Slot contents (secret
    /// material) and the page pointer are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SecureSlab");
        out.field("slot_size", &self.slot_size);
        out.field("total_slots", &self.total_slots);
        out.field("has_coffer", &self.has_coffer);
        // Collections are summarised by length so no indices or ids leak.
        out.field("free", &self.free.len());
        out.field("cached", &self.cache_map.len());
        out.field("transient", &self.transient.len());
        out.finish()
    }
}
impl SecureSlab {
pub fn new(slot_size: usize, init_coffer: bool) -> Result<Self> {
let ps = page_size();
if init_coffer && slot_size < 32 {
return Err(Error::Memory(format!(
"SecureSlab: coffer requires slot_size >= 32 (AES-256 key size), got {slot_size}"
)));
}
if slot_size == 0 || slot_size > ps / 3 {
return Err(Error::Memory(format!(
"SecureSlab: slot_size {slot_size} is invalid (must be 1..={})",
ps / 3
)));
}
let total_slots = ps / slot_size;
if total_slots < 3 {
return Err(Error::Memory(
"SecureSlab: page too small for at least 3 slots (2 coffer + 1 usable)".into(),
));
}
let ptr = unsafe { super::memcall::os_alloc(ps) }
.map_err(|e| Error::Memory(format!("SecureSlab alloc: {e}")))?;
if let Err(e) = unsafe { os_lock(ptr.as_ptr(), ps) } {
drop(unsafe { super::memcall::os_free(ptr.as_ptr(), ps) });
return Err(Error::Memory(format!("SecureSlab mlock: {e}")));
}
let free: Vec<usize> = (FIRST_SHARED_SLOT..total_slots).collect();
let mut slab = Self {
ptr,
page_size: ps,
slot_size,
total_slots,
has_coffer: init_coffer,
free,
cache_map: HashMap::new(),
cache_lru: VecDeque::new(),
transient: HashSet::new(),
};
if init_coffer {
slab.init_coffer_slots()?;
}
Ok(slab)
}
fn init_coffer_slots(&mut self) -> Result<()> {
let right_ptr = self.slot_ptr(COFFER_RIGHT);
unsafe {
let right = std::slice::from_raw_parts_mut(right_ptr, self.slot_size);
rand::rngs::OsRng
.try_fill_bytes(right)
.map_err(|e| Error::Memory(format!("SecureSlab coffer right OsRng: {e}")))?;
}
let mut master_key = zeroize::Zeroizing::new(vec![0_u8; self.slot_size]);
rand::rngs::OsRng
.try_fill_bytes(&mut master_key)
.map_err(|e| Error::Memory(format!("SecureSlab coffer master_key OsRng: {e}")))?;
unsafe {
let right = std::slice::from_raw_parts(right_ptr, self.slot_size);
let mut h = zeroize::Zeroizing::new([0_u8; 32]);
let digest: [u8; 32] = Sha256::digest(right).into();
h.copy_from_slice(&digest);
let left = std::slice::from_raw_parts_mut(self.slot_ptr(COFFER_LEFT), self.slot_size);
for i in 0..self.slot_size {
left[i] = master_key[i] ^ h[i % 32];
}
}
Ok(())
}
fn slot_ptr(&self, index: usize) -> *mut u8 {
debug_assert!(
index < self.total_slots,
"slot index {index} out of range (total={})",
self.total_slots
);
unsafe { self.ptr.as_ptr().add(index * self.slot_size) }
}
fn wipe_slot(&self, index: usize) {
unsafe {
std::slice::from_raw_parts_mut(self.slot_ptr(index), self.slot_size).zeroize();
}
}
pub fn slot_raw(&self, index: usize) -> Option<(*mut u8, usize)> {
if index >= self.total_slots {
return None;
}
Some((self.slot_ptr(index), self.slot_size))
}
#[allow(dead_code)]
pub fn slot_size(&self) -> usize {
self.slot_size
}
#[allow(dead_code)]
pub fn total_slots(&self) -> usize {
self.total_slots
}
pub fn acquire_transient(&mut self) -> Option<usize> {
if let Some(idx) = self.free.pop() {
self.transient.insert(idx);
return Some(idx);
}
let evict_id = *self.cache_lru.front()?;
self.cache_evict(evict_id);
let idx = self.free.pop()?;
self.transient.insert(idx);
Some(idx)
}
pub fn release(&mut self, index: usize) {
let was_transient = self.transient.remove(&index);
debug_assert!(
was_transient,
"SecureSlab::release: slot {index} was not transient (double-release?)"
);
self.wipe_slot(index);
self.free.push(index);
}
pub fn cache_get(&mut self, id: u64) -> Option<usize> {
if !self.cache_map.contains_key(&id) {
return None;
}
let cached_idx = *self.cache_map.get(&id)?;
let out_idx = if !self.free.is_empty() {
self.free.pop()?
} else {
let lru_id = *self.cache_lru.front()?;
if lru_id == id {
return None;
}
self.cache_evict(lru_id);
self.free.pop()?
};
unsafe {
let src = std::slice::from_raw_parts(self.slot_ptr(cached_idx), self.slot_size);
let dst = std::slice::from_raw_parts_mut(self.slot_ptr(out_idx), self.slot_size);
dst.copy_from_slice(src);
}
self.transient.insert(out_idx);
if let Some(pos) = self.cache_lru.iter().position(|&x| x == id) {
self.cache_lru.remove(pos);
}
self.cache_lru.push_back(id);
Some(out_idx)
}
pub fn cache_insert(&mut self, id: u64, data: &[u8]) -> bool {
if data.len() != self.slot_size {
return false;
}
if self.cache_map.contains_key(&id) {
self.cache_evict(id);
}
let slot_idx = if let Some(idx) = self.free.pop() {
idx
} else if let Some(&lru_id) = self.cache_lru.front() {
self.cache_evict(lru_id);
match self.free.pop() {
Some(idx) => idx,
None => return false,
}
} else {
return false;
};
unsafe {
let dst = std::slice::from_raw_parts_mut(self.slot_ptr(slot_idx), self.slot_size);
dst.copy_from_slice(data);
}
self.cache_map.insert(id, slot_idx);
self.cache_lru.push_back(id);
true
}
pub fn cache_evict(&mut self, id: u64) {
if let Some(slot_idx) = self.cache_map.remove(&id) {
if let Some(pos) = self.cache_lru.iter().position(|&x| x == id) {
self.cache_lru.remove(pos);
}
self.wipe_slot(slot_idx);
self.free.push(slot_idx);
}
}
pub fn coffer_view(&mut self) -> Option<usize> {
debug_assert!(
self.has_coffer,
"SecureSlab::coffer_view called on non-coffer slab"
);
let out_idx = self.acquire_transient()?;
unsafe {
let left = std::slice::from_raw_parts(self.slot_ptr(COFFER_LEFT), self.slot_size);
let right = std::slice::from_raw_parts(self.slot_ptr(COFFER_RIGHT), self.slot_size);
let mut h = zeroize::Zeroizing::new([0_u8; 32]);
let digest: [u8; 32] = Sha256::digest(right).into();
h.copy_from_slice(&digest);
let out = std::slice::from_raw_parts_mut(self.slot_ptr(out_idx), self.slot_size);
for i in 0..self.slot_size {
out[i] = left[i] ^ h[i % 32];
}
}
Some(out_idx)
}
#[allow(dead_code)]
pub fn usable_slots(&self) -> usize {
self.total_slots.saturating_sub(FIRST_SHARED_SLOT)
}
}
impl Drop for SecureSlab {
    /// Tears the page down in a fixed order:
    /// 1. make it writable again;
    /// 2. wipe every byte;
    /// 3. unlock it;
    /// 4. return it to the OS.
    ///
    /// Each OS call is best-effort. Its result is discarded because `drop`
    /// cannot report failure.
    fn drop(&mut self) {
        let base = self.ptr.as_ptr();
        let len = self.page_size;
        // SAFETY: `base`/`len` describe the mapping allocated and locked in
        // `SecureSlab::new`. This value owns it exclusively and frees it only here.
        unsafe {
            // Restore write access in case the page's protection was changed
            // elsewhere, so the wipe below can write.
            drop(os_protect(base, len, Protection::ReadWrite));
            // Wipe while the page is still locked in RAM, before unlocking it.
            std::slice::from_raw_parts_mut(base, len).zeroize();
            drop(os_unlock(base, len));
            drop(super::memcall::os_free(base, len));
        }
    }
}
// Unit tests for `SecureSlab`. They allocate real OS pages through
// `super::memcall`, and most sizing assertions are derived from `page_size()`.
#[cfg(test)]
// Tests unwrap and panic freely: a failure is the signal.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // The default slot size yields at least 3 slots, two of them reserved for the coffer.
    #[test]
    fn new_default_slot_size() {
        let slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        assert_eq!(slab.slot_size(), DEFAULT_SLOT_SIZE);
        assert!(slab.total_slots() >= 3);
        assert_eq!(slab.usable_slots(), slab.total_slots() - 2);
    }

    #[test]
    fn slot_size_zero_rejected() {
        assert!(SecureSlab::new(0, false).is_err());
    }

    // One byte over the `page_size / 3` upper bound must be rejected.
    #[test]
    fn slot_size_too_large_rejected() {
        let ps = page_size();
        assert!(SecureSlab::new(ps / 3 + 1, false).is_err());
    }

    #[test]
    fn coffer_slot_size_too_small_rejected() {
        let result = SecureSlab::new(16, true);
        assert!(
            result.is_err(),
            "coffer requires slot_size >= 32 (AES-256 key size)"
        );
    }

    // NOTE(review): the same two assertions are repeated at the start of
    // `out_of_range_slot_raw_is_safe` below.
    #[test]
    fn slot_raw_out_of_bounds_returns_none() {
        let slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        assert!(slab.slot_raw(slab.total_slots()).is_none());
        assert!(slab.slot_raw(usize::MAX).is_none());
    }

    // Two views of the coffer must yield the same non-zero key.
    // NOTE(review): essentially duplicates `coffer_view_reconstructs_key`.
    #[test]
    fn coffer_view_result_is_deterministic() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, true).unwrap();
        let idx1 = slab.coffer_view().unwrap();
        let (ptr1, len1) = slab
            .slot_raw(idx1)
            .expect("slot_raw: index from coffer_view");
        let key1 = unsafe { std::slice::from_raw_parts(ptr1, len1) }.to_vec();
        slab.release(idx1);
        let idx2 = slab.coffer_view().unwrap();
        let (ptr2, len2) = slab
            .slot_raw(idx2)
            .expect("slot_raw: index from coffer_view");
        let key2 = unsafe { std::slice::from_raw_parts(ptr2, len2) }.to_vec();
        slab.release(idx2);
        assert_eq!(
            key1, key2,
            "coffer_view must reconstruct the same key each time"
        );
        assert!(key1.iter().any(|&b| b != 0), "key must not be all zeros");
    }

    // Acquired slots are never coffer slots, are distinct, and can be re-acquired after release.
    #[test]
    fn acquire_and_release() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let idx = slab.acquire_transient().unwrap();
        assert!(
            idx >= FIRST_SHARED_SLOT,
            "coffer slots must not be returned"
        );
        let idx2 = slab.acquire_transient().unwrap();
        assert_ne!(idx, idx2);
        slab.release(idx2);
        slab.release(idx);
        let re = slab.acquire_transient().unwrap();
        assert!(re >= FIRST_SHARED_SLOT);
        slab.release(re);
    }

    // Out-of-range indices yield `None`; adjacent in-range slots are exactly
    // `slot_size` bytes apart.
    #[test]
    fn out_of_range_slot_raw_is_safe() {
        let slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        assert!(slab.slot_raw(slab.total_slots()).is_none());
        assert!(slab.slot_raw(usize::MAX).is_none());
        let (p0, _) = slab.slot_raw(0).expect("slot 0 is valid");
        let (p1, _) = slab.slot_raw(1).expect("slot 1 is valid");
        assert_eq!(unsafe { p1.offset_from(p0) } as usize, DEFAULT_SLOT_SIZE);
    }

    // In debug builds, releasing a non-transient slot trips a `debug_assert`,
    // so only the release-build branch calls `release`. That branch asserts
    // nothing beyond "does not panic".
    #[test]
    fn release_nonexistent_is_noop_in_release_build() {
        #[cfg(not(debug_assertions))]
        {
            let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
            slab.release(5);
        }
        #[cfg(debug_assertions)]
        {
            let slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
            drop(slab);
        }
    }

    // `cache_get` returns a transient copy of the cached bytes, which must be released.
    #[test]
    fn cache_insert_and_get() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let data = [0x42_u8; DEFAULT_SLOT_SIZE];
        assert!(slab.cache_insert(42, &data));
        let slot_idx = slab.cache_get(42).unwrap();
        let (ptr, len) = slab
            .slot_raw(slot_idx)
            .expect("slot_raw: index validated by cache_get");
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
        assert_eq!(slice, &data);
        slab.release(slot_idx);
    }

    #[test]
    fn cache_insert_wrong_size_rejected() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let data = [0x42_u8; DEFAULT_SLOT_SIZE - 1];
        assert!(!slab.cache_insert(42, &data));
        assert!(slab.cache_get(42).is_none());
    }

    #[test]
    fn cache_evict_removes_entry() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let data = [0x55_u8; DEFAULT_SLOT_SIZE];
        slab.cache_insert(99, &data);
        slab.cache_evict(99);
        assert!(slab.cache_get(99).is_none());
    }

    #[test]
    fn coffer_view_reconstructs_key() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, true).unwrap();
        let idx1 = slab.coffer_view().unwrap();
        let (ptr1, len1) = slab
            .slot_raw(idx1)
            .expect("slot_raw: index from coffer_view is valid");
        let key1 = unsafe { std::slice::from_raw_parts(ptr1, len1) }.to_vec();
        slab.release(idx1);
        let idx2 = slab.coffer_view().unwrap();
        let (ptr2, len2) = slab
            .slot_raw(idx2)
            .expect("slot_raw: index from coffer_view is valid");
        let key2 = unsafe { std::slice::from_raw_parts(ptr2, len2) }.to_vec();
        slab.release(idx2);
        assert_eq!(key1, key2, "coffer_view must be deterministic");
        assert!(!key1.iter().all(|&b| b == 0_u8));
    }

    // Filling every usable slot with cache entries forces the next acquire to
    // evict the oldest entry (id 0).
    #[test]
    fn acquire_transient_evicts_lru_when_full() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        for id in 0..(usable as u64) {
            let data = [id as u8; DEFAULT_SLOT_SIZE];
            assert!(slab.cache_insert(id, &data), "insert {id} failed");
        }
        let idx = slab
            .acquire_transient()
            .expect("should evict LRU to make room");
        assert!(
            slab.cache_get(0).is_none(),
            "LRU entry (id=0) should have been evicted"
        );
        slab.release(idx);
    }

    // Reading id 0 should make it MRU, so id 1 is evicted before it.
    // NOTE(review): the `||` makes the assertion weak (it also passes when id 1
    // survives). When the first `cache_get(0)` returns `Some`, that transient
    // copy is never released.
    #[test]
    fn cache_get_promotes_to_mru() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        let count = usable - 1;
        for id in 0..(count as u64) {
            let data = [id as u8; DEFAULT_SLOT_SIZE];
            slab.cache_insert(id, &data);
        }
        let copy_idx = slab.cache_get(0).unwrap();
        slab.release(copy_idx);
        slab.cache_insert(count as u64, &[0xFE_u8; DEFAULT_SLOT_SIZE]);
        slab.cache_insert(count as u64 + 1, &[0xFF_u8; DEFAULT_SLOT_SIZE]);
        assert!(
            slab.cache_get(0).is_some() || slab.cache_get(1).is_none(),
            "id=1 should be evicted before id=0"
        );
        if let Some(idx) = slab.cache_get(0) {
            slab.release(idx);
        }
    }

    // Draining the free list completely must never yield a coffer slot.
    #[test]
    fn coffer_slots_not_in_free_list() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, true).unwrap();
        let mut acquired = Vec::new();
        while let Some(idx) = slab.acquire_transient() {
            assert!(
                idx >= FIRST_SHARED_SLOT,
                "coffer slot {idx} must never appear in the free list"
            );
            acquired.push(idx);
        }
        for idx in acquired {
            slab.release(idx);
        }
    }

    // A fresh slab hands out the highest slot index first.
    #[test]
    fn free_list_is_lifo_highest_index_first() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        let mut indices = Vec::new();
        while let Some(idx) = slab.acquire_transient() {
            indices.push(idx);
        }
        assert_eq!(indices.len(), usable, "all usable slots should be acquired");
        assert_eq!(indices[0], slab.total_slots() - 1);
        for idx in indices {
            slab.release(idx);
        }
    }

    // LIFO reuse means the re-acquired slot is the one just released, so it
    // must come back wiped.
    #[test]
    fn released_slot_is_zeroed() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let idx = slab.acquire_transient().unwrap();
        let (ptr, len) = slab.slot_raw(idx).unwrap();
        unsafe { std::slice::from_raw_parts_mut(ptr, len).fill(0xBE) };
        slab.release(idx);
        let idx2 = slab.acquire_transient().unwrap();
        let (ptr2, len2) = slab.slot_raw(idx2).unwrap();
        let slice = unsafe { std::slice::from_raw_parts(ptr2, len2) };
        assert!(
            slice.iter().all(|&b| b == 0),
            "released slot must be zeroed"
        );
        slab.release(idx2);
    }

    #[test]
    fn double_acquire_returns_different_indices() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let a = slab.acquire_transient().unwrap();
        let b = slab.acquire_transient().unwrap();
        assert_ne!(a, b, "two consecutive acquires must return different slots");
        slab.release(a);
        slab.release(b);
    }

    // With one cache entry and every other slot checked out, the next acquire
    // must evict that entry.
    #[test]
    fn acquire_transient_single_entry_self_eviction_guard() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        let data = [0x42_u8; DEFAULT_SLOT_SIZE];
        slab.cache_insert(999, &data);
        let mut held = Vec::new();
        for _ in 0..(usable - 1) {
            if let Some(idx) = slab.acquire_transient() {
                held.push(idx);
            }
        }
        let result = slab.acquire_transient();
        assert!(
            result.is_some(),
            "should evict cache entry to give us a slot"
        );
        assert!(
            slab.cache_get(999).is_none(),
            "cache entry 999 must be evicted"
        );
        for idx in held {
            slab.release(idx);
        }
        if let Some(idx) = result {
            slab.release(idx);
        }
    }

    // Insert A, B, C. Read A (promoting it). Exhaust the free list, then insert
    // D: B, now the LRU entry, must be the one evicted.
    #[test]
    fn lru_ordering_insert_a_b_c_access_a_evicts_b_next() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        assert!(
            usable >= 4,
            "need at least 4 usable slots for this test (A+B+C+D)"
        );
        let data = [0x11_u8; DEFAULT_SLOT_SIZE];
        slab.cache_insert(1, &data);
        slab.cache_insert(2, &data);
        slab.cache_insert(3, &data);
        let copy = slab.cache_get(1).unwrap();
        slab.release(copy);
        let slots_used = 3;
        let free_left = usable - slots_used;
        let mut held = Vec::new();
        for _ in 0..free_left {
            if let Some(idx) = slab.acquire_transient() {
                held.push(idx);
            }
        }
        slab.cache_insert(4, &data);
        for idx in held {
            slab.release(idx);
        }
        assert!(
            slab.cache_get(2).is_none(),
            "B must be evicted (LRU after A was promoted)"
        );
        if let Some(idx) = slab.cache_get(1) {
            slab.release(idx);
        } else {
            panic!("A must still be cached (was promoted to MRU)");
        }
        if let Some(idx) = slab.cache_get(3) {
            slab.release(idx);
        } else {
            panic!("C must still be cached");
        }
        if let Some(idx) = slab.cache_get(4) {
            slab.release(idx);
        }
    }

    #[test]
    fn cache_same_id_overwritten() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let data1 = [0x11_u8; DEFAULT_SLOT_SIZE];
        let data2 = [0x22_u8; DEFAULT_SLOT_SIZE];
        slab.cache_insert(42, &data1);
        slab.cache_insert(42, &data2);
        let idx = slab.cache_get(42).unwrap();
        let (ptr, len) = slab.slot_raw(idx).unwrap();
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
        assert_eq!(slice, &data2, "second insert must overwrite first");
        slab.release(idx);
    }

    #[test]
    fn coffer_key_not_all_zeros() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, true).unwrap();
        let idx = slab.coffer_view().unwrap();
        let (ptr, len) = slab.slot_raw(idx).unwrap();
        let key = unsafe { std::slice::from_raw_parts(ptr, len) };
        assert!(
            key.iter().any(|&b| b != 0),
            "coffer key must not be all zeros"
        );
        slab.release(idx);
    }

    // The reconstructed key must not simply be a copy of the raw right half.
    #[test]
    fn coffer_key_is_different_from_right_half() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, true).unwrap();
        let key_idx = slab.coffer_view().unwrap();
        let (kptr, klen) = slab.slot_raw(key_idx).unwrap();
        let key = unsafe { std::slice::from_raw_parts(kptr, klen).to_vec() };
        slab.release(key_idx);
        let (rptr, rlen) = slab.slot_raw(COFFER_RIGHT).unwrap();
        let right = unsafe { std::slice::from_raw_parts(rptr, rlen).to_vec() };
        assert_ne!(key, right, "coffer key must differ from raw right half");
    }

    // NOTE(review): despite the name, this exercises DEFAULT_SLOT_SIZE, not the
    // minimum accepted size (1).
    #[test]
    fn minimum_slot_size_valid() {
        let slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let ps = page_size();
        assert_eq!(slab.total_slots(), ps / DEFAULT_SLOT_SIZE);
    }

    // A slot freed by cache eviction and immediately reused must be wiped first.
    #[test]
    fn cache_evicted_entry_slot_is_zeroed() {
        let mut slab = SecureSlab::new(DEFAULT_SLOT_SIZE, false).unwrap();
        let usable = slab.total_slots() - FIRST_SHARED_SLOT;
        let pattern = vec![0xAA_u8; DEFAULT_SLOT_SIZE];
        for id in 0..(usable as u64) {
            slab.cache_insert(id, &pattern);
        }
        let evicted_id = 0_u64;
        let new_idx = slab.acquire_transient().unwrap();
        assert!(slab.cache_get(evicted_id).is_none(), "id=0 must be evicted");
        let (ptr, len) = slab.slot_raw(new_idx).unwrap();
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
        assert!(
            slice.iter().all(|&b| b == 0),
            "evicted cache slot must be zeroed"
        );
        slab.release(new_idx);
    }
}