#![allow(unsafe_code)]
use std::sync::{Condvar, Mutex, OnceLock};
use super::memcall::page_size;
use super::secure_buffer::SecureBuffer;
use super::slab::{SecureSlab, DEFAULT_SLOT_SIZE, SLOT_WAIT_TIMEOUT};
use crate::error::{Error, Result};
/// Where the memory behind a [`PoolSlot`] came from, which decides how it is
/// returned when the slot is dropped.
enum PoolSlotOrigin {
    /// Slot `slot_index` inside the slab of tier `tier_index` of the global pool.
    Slab {
        tier_index: usize,
        slot_index: usize,
    },
    /// Overflow fallback: a dedicated buffer owned by the slot itself.
    Standalone(SecureBuffer),
}

/// A handle to a pooled byte region.
///
/// The region is either a slab slot or a standalone `SecureBuffer`. Its bytes
/// are zeroized when the handle is dropped.
pub struct PoolSlot {
    /// Start of the usable region.
    ptr: *mut u8,
    /// Number of usable bytes starting at `ptr`.
    len: usize,
    /// Backing storage; consulted by `Drop` to give the memory back.
    origin: PoolSlotOrigin,
}

// SAFETY: a `PoolSlot` is the sole handle to `ptr..ptr + len` for its lifetime
// (it is not `Clone`), so moving it to another thread moves that exclusive
// access with it. Slab release in `Drop` goes through the tier mutex.
// Assumes the slab never hands the same slot to two live `PoolSlot`s — TODO confirm.
unsafe impl Send for PoolSlot {}

impl std::fmt::Debug for PoolSlot {
    /// Prints only the length; the pointer and contents are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PoolSlot");
        out.field("len", &self.len);
        out.finish()
    }
}
impl PoolSlot {
    /// Wraps slot `slot_index` of tier `tier_index`, whose memory is `len`
    /// bytes at `ptr`.
    ///
    /// `ptr`/`len` are expected to be exactly what the slab reported for that
    /// slot. On drop the slot is handed back to tier `tier_index` of the global
    /// pool.
    pub(crate) fn from_slab(
        ptr: *mut u8,
        len: usize,
        tier_index: usize,
        slot_index: usize,
    ) -> Self {
        let origin = PoolSlotOrigin::Slab {
            tier_index,
            slot_index,
        };
        Self { ptr, len, origin }
    }

    /// Takes ownership of a standalone buffer and exposes its whole contents.
    ///
    /// `melt` is called first; its result is deliberately discarded
    /// (NOTE(review): presumably makes the buffer writable — confirm).
    fn from_standalone(mut buf: SecureBuffer) -> Self {
        drop(buf.melt());
        let region_start = buf.bytes().as_mut_ptr();
        let region_len = buf.size();
        Self {
            ptr: region_start,
            len: region_len,
            origin: PoolSlotOrigin::Standalone(buf),
        }
    }

    /// Mutable view of the slot's full region.
    pub fn bytes(&mut self) -> &mut [u8] {
        let (start, count) = (self.ptr, self.len);
        // SAFETY: `start`/`count` describe the region this slot was built with,
        // which it exclusively owns; `&mut self` prevents any other live view.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Read-only view of the slot's full region.
    pub fn as_slice(&self) -> &[u8] {
        let (start, count) = (self.ptr, self.len);
        // SAFETY: same region as `bytes`; `&self` rules out a concurrent `&mut`.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Length of the region in bytes (the tier's slot size for slab slots,
    /// the requested size for standalone slots).
    pub fn size(&self) -> usize {
        self.len
    }

    /// Slab slot index, or `None` for a standalone slot.
    #[allow(dead_code)]
    pub(crate) fn slab_index(&self) -> Option<usize> {
        if let PoolSlotOrigin::Slab { slot_index, .. } = self.origin {
            Some(slot_index)
        } else {
            None
        }
    }

    /// Tier index, or `None` for a standalone slot.
    #[allow(dead_code)]
    pub(crate) fn tier_index(&self) -> Option<usize> {
        if let PoolSlotOrigin::Slab { tier_index, .. } = self.origin {
            Some(tier_index)
        } else {
            None
        }
    }
}
impl Drop for PoolSlot {
    /// Wipes the slot's bytes and gives slab slots back to their tier.
    ///
    /// Slab slots are zeroized, released into tier `tier_index` of the global
    /// pool, and one waiter on that tier's condition variable is woken.
    /// Standalone slots are `melt`ed and zeroized; the owned `SecureBuffer` is
    /// then dropped together with `self`.
    ///
    /// NOTE(review): release always targets `global_pool()`. A slot obtained
    /// from a different `TieredPool` instance would be released into the wrong
    /// slab — confirm no caller acquires from a non-global pool.
    fn drop(&mut self) {
        match &mut self.origin {
            PoolSlotOrigin::Slab {
                tier_index,
                slot_index,
            } => {
                // SAFETY: `ptr`/`len` describe the slab slot this handle still
                // exclusively owns; it has not been released yet.
                unsafe {
                    use zeroize::Zeroize;
                    std::slice::from_raw_parts_mut(self.ptr, self.len).zeroize();
                }
                let pool = global_pool();
                // `get` instead of indexing: a panic inside `drop` while the
                // thread is already unwinding aborts the process.
                if let Some(tier) = pool.tiers.get(*tier_index) {
                    // Recover from poisoning like every other lock site in this
                    // module. Skipping the release would leak the slot for the
                    // rest of the process and push later `acquire` calls into
                    // the timeout/standalone fallback.
                    tier.slab
                        .lock()
                        .unwrap_or_else(|e| e.into_inner())
                        .release(*slot_index);
                    // The guard above is a temporary, already dropped here, so
                    // the woken waiter can take the lock immediately.
                    tier.cv.notify_one();
                }
            }
            PoolSlotOrigin::Standalone(buf) => {
                drop(buf.melt());
                // SAFETY: `ptr`/`len` were taken from `buf.bytes()` at
                // construction and `buf` is still alive (owned by `self`).
                unsafe {
                    use zeroize::Zeroize;
                    std::slice::from_raw_parts_mut(self.ptr, self.len).zeroize();
                }
            }
        }
    }
}
/// One size class of the pool: a slab of equally sized slots, plus the
/// condition variable that `acquire` waits on while the slab is full.
struct Tier {
    /// Size in bytes of every slot in `slab`.
    slot_size: usize,
    /// The slab; all acquire/release/cache operations go through this lock.
    slab: Mutex<SecureSlab>,
    /// Notified when a slot is released (`PoolSlot::drop`) or a cache entry
    /// is evicted.
    cv: Condvar,
}

impl std::fmt::Debug for Tier {
    /// Prints only `slot_size`; the slab and condvar are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { slot_size, .. } = self;
        f.debug_struct("Tier").field("slot_size", slot_size).finish()
    }
}
/// Configuration for [`TieredPool`]: the slot size of each tier.
///
/// Sizes may be listed in any order and may repeat; `TieredPool::new` sorts
/// and de-duplicates them.
#[derive(Debug, Clone)]
pub struct TieredPoolConfig {
    /// Slot size in bytes for each tier.
    pub tier_sizes: Vec<usize>,
}

impl Default for TieredPoolConfig {
    /// A single tier of `DEFAULT_SLOT_SIZE`-byte slots.
    fn default() -> Self {
        let tier_sizes = Vec::from([DEFAULT_SLOT_SIZE]);
        Self { tier_sizes }
    }
}
/// Slab-backed pool of secure byte slots, split into size tiers.
#[derive(Debug)]
pub struct TieredPool {
    /// Tiers in ascending order of unique `slot_size` (established by `new`).
    tiers: Vec<Tier>,
}

// SAFETY: each `SecureSlab` is reachable only through its tier's `Mutex`, so
// slab bookkeeping is touched by one thread at a time. Slot memory handed out
// as `PoolSlot`s is exclusively owned by that handle.
// Assumes `SecureSlab` holds no thread-affine state — TODO confirm.
unsafe impl Send for TieredPool {}
unsafe impl Sync for TieredPool {}
impl TieredPool {
    /// Builds a pool with one slab-backed tier per distinct size in
    /// `config.tier_sizes`.
    ///
    /// Sizes are sorted ascending and de-duplicated. Only the smallest tier is
    /// created with `init_coffer = true`. Outside of tests this first calls
    /// `crate::harden_process()`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Memory`] in these cases:
    /// - `tier_sizes` is empty;
    /// - any size is 0 or larger than `page_size() / 3`;
    /// - the smallest size is below 32 bytes (the coffer needs 32).
    ///
    /// Errors from `SecureSlab::new` are propagated.
    pub fn new(config: TieredPoolConfig) -> Result<Self> {
        #[cfg(not(test))]
        crate::harden_process();
        let ps = page_size();
        let max_slot = ps / 3;
        if config.tier_sizes.is_empty() {
            return Err(Error::Memory(
                "TieredPoolConfig: tier_sizes must be non-empty".into(),
            ));
        }
        let mut sizes = config.tier_sizes;
        sizes.sort_unstable();
        sizes.dedup();
        for &sz in &sizes {
            if sz == 0 {
                return Err(Error::Memory(
                    "TieredPoolConfig: tier size 0 is invalid".into(),
                ));
            }
            if sz > max_slot {
                return Err(Error::Memory(format!(
                    "TieredPoolConfig: tier size {sz} exceeds page_size/3 ({max_slot})"
                )));
            }
        }
        // `sizes` is non-empty (checked above) and sorted, so index 0 is the
        // smallest tier, which is the one that hosts the coffer.
        if sizes[0] < 32 {
            return Err(Error::Memory(format!(
                "TieredPool: first tier slot_size must be >= 32 for coffer, got {}",
                sizes[0]
            )));
        }
        let mut tiers = Vec::with_capacity(sizes.len());
        for (i, sz) in sizes.into_iter().enumerate() {
            let init_coffer = i == 0;
            let slab = SecureSlab::new(sz, init_coffer)?;
            tiers.push(Tier {
                slot_size: sz,
                slab: Mutex::new(slab),
                cv: Condvar::new(),
            });
        }
        Ok(Self { tiers })
    }

    /// Index of the smallest tier whose slots hold at least `size` bytes.
    ///
    /// Returns `None` if `size` exceeds every tier. Relies on `tiers` being
    /// sorted ascending.
    fn tier_for_size(&self, size: usize) -> Option<usize> {
        self.tiers.iter().position(|t| t.slot_size >= size)
    }

    /// Acquires a slot of at least `size` bytes.
    ///
    /// If a tier fits, waits up to `SLOT_WAIT_TIMEOUT` for one of its slots to
    /// become free. If no tier fits, or the wait times out, falls back to a
    /// standalone `SecureBuffer` of exactly `size` bytes. A poisoned tier lock
    /// is recovered rather than propagated.
    ///
    /// # Errors
    ///
    /// Propagates errors from `SecureBuffer::new` on the fallback path.
    pub(crate) fn acquire(&self, size: usize) -> Result<PoolSlot> {
        if let Some(tier_idx) = self.tier_for_size(size) {
            let tier = &self.tiers[tier_idx];
            let deadline = std::time::Instant::now() + SLOT_WAIT_TIMEOUT;
            let mut guard = tier.slab.lock().unwrap_or_else(|e| e.into_inner());
            loop {
                if let Some(slot_idx) = guard.acquire_transient() {
                    let (ptr, len) = guard
                        .slot_raw(slot_idx)
                        .expect("slot_raw: index validated by acquire_transient");
                    drop(guard);
                    return Ok(PoolSlot::from_slab(ptr, len, tier_idx, slot_idx));
                }
                let timeout = deadline.saturating_duration_since(std::time::Instant::now());
                if timeout.is_zero() {
                    // Unlock before logging. A tracing subscriber may perform
                    // I/O, and holding the tier lock during it would stall every
                    // other acquire/release on this tier.
                    drop(guard);
                    tracing::warn!(
                        size,
                        tier_idx,
                        "pool acquire: all slab slots exhausted; using standalone SecureBuffer"
                    );
                    break;
                }
                // Woken by `PoolSlot::drop` / `hot_cache_evict`, by a spurious
                // wakeup, or by the timeout; the loop re-checks in every case.
                guard = tier
                    .cv
                    .wait_timeout(guard, timeout)
                    .unwrap_or_else(|e| e.into_inner())
                    .0;
            }
        }
        Ok(PoolSlot::from_standalone(SecureBuffer::new(size)?))
    }

    /// Returns a tier-0 slot obtained from the slab's `coffer_view`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Memory`] immediately (without waiting) when the slab
    /// reports no free slot.
    pub(crate) fn coffer_view(&self) -> Result<PoolSlot> {
        let mut guard = self.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
        let slot_idx = guard
            .coffer_view()
            .ok_or_else(|| Error::Memory("coffer_view: no free slab slot".into()))?;
        let (ptr, len) = guard
            .slot_raw(slot_idx)
            .expect("slot_raw: index validated by coffer_view");
        drop(guard);
        Ok(PoolSlot::from_slab(ptr, len, 0, slot_idx))
    }

    /// Largest slot size served from a slab (requests above this go standalone).
    pub fn max_slab_slot_size(&self) -> usize {
        self.tiers.iter().map(|t| t.slot_size).max().unwrap_or(0)
    }

    /// Number of tiers.
    pub fn tier_count(&self) -> usize {
        self.tiers.len()
    }

    /// Slot size of tier `i`, or `None` if `i` is out of range.
    pub fn tier_slot_size(&self, i: usize) -> Option<usize> {
        self.tiers.get(i).map(|t| t.slot_size)
    }
}
static POOL: OnceLock<TieredPool> = OnceLock::new();
pub fn init_pool(config: TieredPoolConfig) -> Result<()> {
let pool = TieredPool::new(config)?;
POOL.set(pool)
.map_err(|_| Error::Memory("pool already initialized".into()))
}
pub(crate) fn global_pool() -> &'static TieredPool {
POOL.get_or_init(|| {
TieredPool::new(TieredPoolConfig::default())
.expect("enclave: default tiered pool init failed — OsRng unavailable")
})
}
/// Acquires a slot of at least `size` bytes from the global pool.
///
/// Falls back to a standalone buffer when no tier fits or the tier stays full
/// past the wait timeout.
pub fn pool_acquire(size: usize) -> Result<PoolSlot> {
    let pool = global_pool();
    pool.acquire(size)
}

/// Returns `slot` to where it came from; identical to dropping it.
pub fn pool_release(slot: PoolSlot) {
    std::mem::drop(slot);
}

/// Obtains the coffer view from tier 0 of the global pool.
pub fn coffer_view() -> Result<PoolSlot> {
    let pool = global_pool();
    pool.coffer_view()
}
/// Stores `data` under `id` in tier 0's hot cache.
///
/// The slab decides whether the data is actually cached; oversized data is
/// not, per the tests.
pub(super) fn hot_cache_insert(id: u64, data: &[u8]) {
    global_pool().tiers[0]
        .slab
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .cache_insert(id, data);
}

/// Looks up `id` in tier 0's hot cache and wraps the hit as a [`PoolSlot`].
///
/// Returns `None` on a miss.
// NOTE(review): if `slot_raw` fails after `cache_get` succeeded, the index from
// `cache_get` is not released here — confirm `cache_get` does not reserve a
// slot, or release it on that path.
pub(super) fn hot_cache_get(id: u64) -> Option<PoolSlot> {
    let tier0 = &global_pool().tiers[0];
    let (slot_idx, ptr, len) = {
        let mut slab = tier0.slab.lock().unwrap_or_else(|e| e.into_inner());
        let idx = slab.cache_get(id)?;
        let (ptr, len) = slab.slot_raw(idx)?;
        (idx, ptr, len)
    };
    Some(PoolSlot::from_slab(ptr, len, 0, slot_idx))
}

/// Removes `id` from tier 0's hot cache and wakes one waiter on that tier.
pub(super) fn hot_cache_evict(id: u64) {
    let tier0 = &global_pool().tiers[0];
    tier0
        .slab
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .cache_evict(id);
    // The guard above was a temporary, so the lock is already released here.
    tier0.cv.notify_one();
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use std::sync::Mutex;
    use super::super::slab::FIRST_SHARED_SLOT;
    use super::*;

    /// Serializes tests that touch the process-wide pool, so slot and cache
    /// assertions do not interfere with each other.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn pool_acquire_small_uses_slab() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert!(slot.slab_index().is_some());
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
    }

    #[test]
    fn pool_acquire_large_uses_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(8192).unwrap();
        assert!(slot.slab_index().is_none());
        assert_eq!(slot.size(), 8192);
    }

    #[test]
    fn pool_acquire_zero_uses_slab() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(0).unwrap();
        assert!(slot.slab_index().is_some());
    }

    #[test]
    fn pool_slot_write_and_read() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(16).unwrap();
        let data = b"test data 12345!";
        slot.bytes()[..data.len()].copy_from_slice(data);
        assert_eq!(&slot.as_slice()[..data.len()], data);
    }

    #[test]
    fn pool_slot_zeroized_on_drop() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(16).unwrap();
        let sz = slot.size();
        slot.bytes().iter_mut().for_each(|b| *b = 0xDE);
        let slot_idx = slot.slab_index().unwrap();
        drop(slot);
        let pool = global_pool();
        let mut guard = pool.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
        // Re-acquire slots until the one just released comes back, then check
        // that its memory was wiped.
        let mut acquired = vec![];
        while let Some(idx) = guard.acquire_transient() {
            acquired.push(idx);
            if idx == slot_idx {
                break;
            }
        }
        if acquired.last() == Some(&slot_idx) {
            let (ptr, _) = guard
                .slot_raw(slot_idx)
                .expect("slot_raw: index just acquired from slab");
            // SAFETY: the slot is held via `acquired` and `sz` equals its length.
            let s = unsafe { std::slice::from_raw_parts(ptr, sz) };
            assert!(s.iter().all(|&b| b == 0), "slot must be zeroed after drop");
        }
        for idx in acquired {
            guard.release(idx);
        }
    }

    #[test]
    fn coffer_view_returns_key_sized_slot() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
        assert_eq!(slot.tier_index(), Some(0));
    }

    #[test]
    fn coffer_view_is_deterministic() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let s1 = coffer_view().unwrap();
        let key1 = s1.as_slice().to_vec();
        drop(s1);
        let s2 = coffer_view().unwrap();
        let key2 = s2.as_slice().to_vec();
        drop(s2);
        assert_eq!(key1, key2, "coffer_view must return same key each call");
        assert!(!key1.iter().all(|&b| b == 0));
    }

    #[test]
    fn hot_cache_insert_get_evict() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data = [0xAB_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(1001, &data);
        let slot = hot_cache_get(1001).unwrap();
        assert_eq!(slot.as_slice(), &data);
        drop(slot);
        hot_cache_evict(1001);
        assert!(hot_cache_get(1001).is_none());
    }

    #[test]
    fn hot_cache_get_returns_pool_slot() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data = [0xCC_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(2002, &data);
        let slot = hot_cache_get(2002).expect("should be a cache hit");
        assert_eq!(slot.tier_index(), Some(0));
        assert!(slot
            .slab_index()
            .map(|i| i >= FIRST_SHARED_SLOT)
            .unwrap_or(false));
        drop(slot);
        hot_cache_evict(2002);
    }

    #[test]
    fn tiered_pool_routes_small_to_first_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert_eq!(
            slot.tier_index(),
            Some(0),
            "should route to tier 0 (32-byte)"
        );
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
    }

    #[test]
    fn tiered_pool_routes_medium_to_second_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(48).unwrap();
        assert!(
            slot.tier_index().is_none(),
            "48-byte request exceeds default tier; should be standalone"
        );
        assert_eq!(slot.size(), 48);
    }

    #[test]
    fn tiered_pool_routes_large_to_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(8192).unwrap();
        assert!(slot.tier_index().is_none(), "should be standalone");
        assert_eq!(slot.size(), 8192);
    }

    #[test]
    fn init_pool_default_config_has_one_tier() {
        let pool = global_pool();
        assert_eq!(pool.tier_count(), 1);
        assert_eq!(pool.tier_slot_size(0), Some(DEFAULT_SLOT_SIZE));
        assert_eq!(pool.max_slab_slot_size(), DEFAULT_SLOT_SIZE);
    }

    #[test]
    fn tiered_pool_config_validates_ascending() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![32, 32],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 1, "duplicates should be deduped");
        assert_eq!(pool.tier_slot_size(0), Some(32));
    }

    #[test]
    fn tiered_pool_config_validates_max_slot_size() {
        let ps = page_size();
        let too_large = ps / 3 + 1;
        let err = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![too_large],
        });
        assert!(err.is_err(), "slot size > page_size/3 must be rejected");
    }

    #[test]
    fn local_pool_coffer_view_works() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
        assert_eq!(slot.tier_index(), Some(0));
    }

    #[test]
    fn tiered_pool_first_tier_must_be_32_bytes() {
        let result = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![16],
        });
        assert!(
            result.is_err(),
            "first tier < 32 should fail (coffer requires slot_size >= 32)"
        );
    }

    #[test]
    fn coffer_view_key_is_32_bytes_and_nonzero() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), 32);
        assert!(
            slot.as_slice().iter().any(|&b| b != 0),
            "coffer key must not be all zeros"
        );
    }

    #[test]
    fn empty_tier_sizes_rejected() {
        let result = TieredPool::new(TieredPoolConfig { tier_sizes: vec![] });
        assert!(result.is_err(), "empty tier_sizes must be rejected");
    }

    #[test]
    fn tier_sizes_sorted_ascending_internally() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![64, 32],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 2);
        assert_eq!(pool.tier_slot_size(0), Some(32));
        assert_eq!(pool.tier_slot_size(1), Some(64));
    }

    #[test]
    fn multi_tier_routing_smallest_fit() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![32, 64],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 2);
        assert_eq!(pool.tier_slot_size(0), Some(32));
        assert_eq!(pool.tier_slot_size(1), Some(64));
        // Check routing on the local two-tier pool itself. No slots are
        // acquired from it, because `PoolSlot::drop` always releases into the
        // global pool.
        assert_eq!(pool.tier_for_size(0), Some(0));
        assert_eq!(pool.tier_for_size(32), Some(0), "32 bytes fits tier 0 exactly");
        assert_eq!(pool.tier_for_size(33), Some(1), "33 bytes must use the 64-byte tier");
        assert_eq!(pool.tier_for_size(64), Some(1));
        assert_eq!(pool.tier_for_size(65), None, "oversized requests have no tier");
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(33).unwrap();
        assert!(
            slot.tier_index().is_none(),
            "size 33 exceeds single 32-byte tier → standalone"
        );
        assert_eq!(slot.size(), 33);
        let slot2 = pool_acquire(32).unwrap();
        assert_eq!(
            slot2.tier_index(),
            Some(0),
            "size 32 must use tier 0 (32-byte)"
        );
        drop(slot);
        drop(slot2);
    }

    #[test]
    fn pool_slot_tier_index_matches_acquisition_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert_eq!(slot.tier_index(), Some(0));
        assert_eq!(
            slot.slab_index().map(|i| i >= FIRST_SHARED_SLOT),
            Some(true)
        );
    }

    #[test]
    fn standalone_slot_has_no_tier_or_slab_index() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(9999).unwrap();
        assert!(slot.tier_index().is_none());
        assert!(slot.slab_index().is_none());
        assert_eq!(slot.size(), 9999);
    }

    #[test]
    fn pool_slot_zeroized_on_drop_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(512).unwrap();
        // Make sure this really exercises the standalone branch of `Drop`.
        assert!(
            slot.slab_index().is_none(),
            "512 bytes exceeds the default tier; must be standalone"
        );
        slot.bytes().fill(0xBE);
        drop(slot);
    }

    #[test]
    fn hot_cache_not_populated_for_large_plaintext() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let big_data = [0x42_u8; 64];
        hot_cache_insert(9876, &big_data);
        let result = hot_cache_get(9876);
        assert!(result.is_none(), "oversized data must not be cached");
    }

    #[test]
    fn hot_cache_multiple_ids_are_independent() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data_a = [0xAA_u8; DEFAULT_SLOT_SIZE];
        let data_b = [0xBB_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(100, &data_a);
        hot_cache_insert(101, &data_b);
        let slot_a = hot_cache_get(100).unwrap();
        let slot_b = hot_cache_get(101).unwrap();
        assert_eq!(slot_a.as_slice(), &data_a, "id 100 must return data_a");
        assert_eq!(slot_b.as_slice(), &data_b, "id 101 must return data_b");
        drop(slot_a);
        drop(slot_b);
        hot_cache_evict(100);
        hot_cache_evict(101);
    }

    #[test]
    fn coffer_view_returns_same_key_every_time() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let s1 = coffer_view().unwrap();
        let k1 = s1.as_slice().to_vec();
        drop(s1);
        let s2 = coffer_view().unwrap();
        let k2 = s2.as_slice().to_vec();
        drop(s2);
        let s3 = coffer_view().unwrap();
        let k3 = s3.as_slice().to_vec();
        drop(s3);
        assert_eq!(k1, k2, "coffer key must be same on second call");
        assert_eq!(k2, k3, "coffer key must be same on third call");
        assert!(
            k1.iter().any(|&b| b != 0),
            "coffer key must not be all zeros"
        );
    }

    #[test]
    fn concurrent_pool_acquire_and_release() {
        use std::sync::Arc;
        use std::thread;
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let barrier = Arc::new(std::sync::Barrier::new(8));
        let handles: Vec<_> = (0..8_u8)
            .map(|i| {
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    let mut slot = pool_acquire(16).unwrap();
                    slot.bytes()[0] = i;
                    // All threads hold their slots simultaneously here, so a
                    // slot shared between threads would show up as a mismatch.
                    b.wait();
                    assert_eq!(slot.as_slice()[0], i, "thread {i}: slot content must match");
                    drop(slot);
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }
}