1#![allow(unsafe_code)]
12
13use std::sync::{Condvar, Mutex, OnceLock};
14
15use super::memcall::page_size;
16use super::secure_buffer::SecureBuffer;
17use super::slab::{SecureSlab, DEFAULT_SLOT_SIZE, SLOT_WAIT_TIMEOUT};
18use crate::error::{Error, Result};
19
20enum PoolSlotOrigin {
23 Slab {
25 tier_index: usize,
26 slot_index: usize,
27 },
28 Standalone(SecureBuffer),
30}
31
32pub struct PoolSlot {
52 ptr: *mut u8,
53 len: usize,
54 origin: PoolSlotOrigin,
55}
56
57unsafe impl Send for PoolSlot {}
63
64impl std::fmt::Debug for PoolSlot {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("PoolSlot").field("len", &self.len).finish()
67 }
68}
69
70impl PoolSlot {
71 pub(crate) fn from_slab(
72 ptr: *mut u8,
73 len: usize,
74 tier_index: usize,
75 slot_index: usize,
76 ) -> Self {
77 Self {
78 ptr,
79 len,
80 origin: PoolSlotOrigin::Slab {
81 tier_index,
82 slot_index,
83 },
84 }
85 }
86
87 fn from_standalone(mut buf: SecureBuffer) -> Self {
88 drop(buf.melt());
90 let ptr = buf.bytes().as_mut_ptr();
91 let len = buf.size();
92 Self {
93 ptr,
94 len,
95 origin: PoolSlotOrigin::Standalone(buf),
96 }
97 }
98
99 pub fn bytes(&mut self) -> &mut [u8] {
101 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
105 }
106
107 pub fn as_slice(&self) -> &[u8] {
109 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
112 }
113
114 pub fn size(&self) -> usize {
116 self.len
117 }
118
119 #[allow(dead_code)]
121 pub(crate) fn slab_index(&self) -> Option<usize> {
122 match &self.origin {
123 PoolSlotOrigin::Slab { slot_index, .. } => Some(*slot_index),
124 PoolSlotOrigin::Standalone(_) => None,
125 }
126 }
127
128 #[allow(dead_code)]
130 pub(crate) fn tier_index(&self) -> Option<usize> {
131 match &self.origin {
132 PoolSlotOrigin::Slab { tier_index, .. } => Some(*tier_index),
133 PoolSlotOrigin::Standalone(_) => None,
134 }
135 }
136}
137
138impl Drop for PoolSlot {
139 fn drop(&mut self) {
140 match &mut self.origin {
141 PoolSlotOrigin::Slab {
142 tier_index,
143 slot_index,
144 } => {
145 unsafe {
150 use zeroize::Zeroize;
151 std::slice::from_raw_parts_mut(self.ptr, self.len).zeroize();
152 }
153 let pool = global_pool();
154 if let Ok(mut slab) = pool.tiers[*tier_index].slab.lock() {
155 slab.release(*slot_index);
156 }
157 pool.tiers[*tier_index].cv.notify_one();
159 }
160 PoolSlotOrigin::Standalone(buf) => {
161 drop(buf.melt());
163 unsafe {
166 use zeroize::Zeroize;
167 std::slice::from_raw_parts_mut(self.ptr, self.len).zeroize();
168 }
169 }
171 }
172 }
173}
174
175struct Tier {
179 slot_size: usize,
180 slab: Mutex<SecureSlab>,
181 cv: Condvar,
183}
184
185impl std::fmt::Debug for Tier {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("Tier")
188 .field("slot_size", &self.slot_size)
189 .finish()
190 }
191}
192
193#[derive(Debug, Clone)]
199pub struct TieredPoolConfig {
200 pub tier_sizes: Vec<usize>,
202}
203
204impl Default for TieredPoolConfig {
205 fn default() -> Self {
207 Self {
208 tier_sizes: vec![DEFAULT_SLOT_SIZE],
209 }
210 }
211}
212
213#[derive(Debug)]
222pub struct TieredPool {
223 tiers: Vec<Tier>,
224}
225
226unsafe impl Send for TieredPool {}
231unsafe impl Sync for TieredPool {}
232
233impl TieredPool {
234 pub fn new(config: TieredPoolConfig) -> Result<Self> {
239 #[cfg(not(test))]
244 crate::internal::core::process::harden_process();
245
246 let ps = page_size();
247 let max_slot = ps / 3;
248
249 if config.tier_sizes.is_empty() {
250 return Err(Error::Memory(
251 "TieredPoolConfig: tier_sizes must be non-empty".into(),
252 ));
253 }
254
255 let mut sizes = config.tier_sizes;
257 sizes.sort_unstable();
258 sizes.dedup();
259
260 for &sz in &sizes {
262 if sz == 0 {
263 return Err(Error::Memory(
264 "TieredPoolConfig: tier size 0 is invalid".into(),
265 ));
266 }
267 if sz > max_slot {
268 return Err(Error::Memory(format!(
269 "TieredPoolConfig: tier size {sz} exceeds page_size/3 ({max_slot})"
270 )));
271 }
272 }
273
274 if sizes[0] < 32 {
276 return Err(Error::Memory(format!(
277 "TieredPool: first tier slot_size must be >= 32 for coffer, got {}",
278 sizes[0]
279 )));
280 }
281
282 let mut tiers = Vec::with_capacity(sizes.len());
285 for (i, sz) in sizes.into_iter().enumerate() {
286 let init_coffer = i == 0;
287 let slab = SecureSlab::new(sz, init_coffer)?;
288 tiers.push(Tier {
289 slot_size: sz,
290 slab: Mutex::new(slab),
291 cv: Condvar::new(),
292 });
293 }
294
295 Ok(Self { tiers })
296 }
297
298 fn tier_for_size(&self, size: usize) -> Option<usize> {
300 self.tiers.iter().position(|t| t.slot_size >= size)
301 }
302
303 pub(crate) fn acquire(&self, size: usize) -> Result<PoolSlot> {
314 if let Some(tier_idx) = self.tier_for_size(size) {
315 let deadline = std::time::Instant::now() + SLOT_WAIT_TIMEOUT;
316 let mut guard = self.tiers[tier_idx]
317 .slab
318 .lock()
319 .unwrap_or_else(|e| e.into_inner());
320 loop {
321 if let Some(slot_idx) = guard.acquire_transient() {
322 let (ptr, len) = guard
323 .slot_raw(slot_idx)
324 .expect("slot_raw: index validated by acquire_transient");
325 drop(guard);
326 return Ok(PoolSlot::from_slab(ptr, len, tier_idx, slot_idx));
327 }
328 let timeout = deadline.saturating_duration_since(std::time::Instant::now());
329 if timeout.is_zero() {
330 tracing::warn!(
331 size,
332 tier_idx,
333 "pool acquire: all slab slots exhausted; using standalone SecureBuffer"
334 );
335 drop(guard);
336 break;
337 }
338 let result = self.tiers[tier_idx]
339 .cv
340 .wait_timeout(guard, timeout)
341 .unwrap_or_else(|e| e.into_inner());
342 guard = result.0;
343 }
344 }
345
346 Ok(PoolSlot::from_standalone(SecureBuffer::new(size)?))
348 }
349
350 pub(crate) fn coffer_view(&self) -> Result<PoolSlot> {
356 let mut guard = self.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
357 let slot_idx = guard
358 .coffer_view()
359 .ok_or_else(|| Error::Memory("coffer_view: no free slab slot".into()))?;
360 let (ptr, len) = guard
361 .slot_raw(slot_idx)
362 .expect("slot_raw: index validated by coffer_view");
363 drop(guard);
364 Ok(PoolSlot::from_slab(ptr, len, 0, slot_idx))
365 }
366
367 pub fn max_slab_slot_size(&self) -> usize {
369 self.tiers.iter().map(|t| t.slot_size).max().unwrap_or(0)
370 }
371
372 pub fn tier_count(&self) -> usize {
374 self.tiers.len()
375 }
376
377 pub fn tier_slot_size(&self, i: usize) -> Option<usize> {
379 self.tiers.get(i).map(|t| t.slot_size)
380 }
381}
382
/// Backing storage for [`global_pool`]. It is set at most once: either explicitly by
/// [`init_pool`], or lazily with the default configuration on first use.
static POOL: OnceLock<TieredPool> = OnceLock::new();
386
387pub fn init_pool(config: TieredPoolConfig) -> Result<()> {
393 let pool = TieredPool::new(config)?;
394 POOL.set(pool)
395 .map_err(|_| Error::Memory("pool already initialized".into()))
396}
397
398pub(crate) fn global_pool() -> &'static TieredPool {
399 POOL.get_or_init(|| {
400 TieredPool::new(TieredPoolConfig::default())
401 .expect("enclave: default tiered pool init failed — OsRng unavailable")
402 })
403}
404
405pub fn pool_acquire(size: usize) -> Result<PoolSlot> {
410 global_pool().acquire(size)
411}
412
413pub fn pool_release(slot: PoolSlot) {
416 drop(slot);
417}
418
419pub fn coffer_view() -> Result<PoolSlot> {
422 global_pool().coffer_view()
423}
424
425pub(super) fn hot_cache_insert(id: u64, data: &[u8]) {
430 let pool = global_pool();
431 let mut slab = pool.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
432 slab.cache_insert(id, data);
433}
434
435pub(super) fn hot_cache_get(id: u64) -> Option<PoolSlot> {
438 let pool = global_pool();
439 let mut guard = pool.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
440 let slot_idx = guard.cache_get(id)?;
441 let (ptr, len) = guard.slot_raw(slot_idx)?;
443 drop(guard);
444 Some(PoolSlot::from_slab(ptr, len, 0, slot_idx))
445}
446
447pub(super) fn hot_cache_evict(id: u64) {
449 let pool = global_pool();
450 {
451 let mut slab = pool.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
452 slab.cache_evict(id);
453 }
454 pool.tiers[0].cv.notify_one();
455}
456
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use std::sync::Mutex;

    use super::super::slab::FIRST_SHARED_SLOT;
    use super::*;

    // Serialises tests that touch the shared global pool, so they do not contend for slots
    // or trample each other's hot-cache entries. Tests that only build local pools skip it.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn pool_acquire_small_uses_slab() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert!(slot.slab_index().is_some());
        // Slab slots report the tier's slot size, not the requested size.
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
    }

    #[test]
    fn pool_acquire_large_uses_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(8192).unwrap();
        assert!(slot.slab_index().is_none());
        assert_eq!(slot.size(), 8192);
    }

    #[test]
    fn pool_acquire_zero_uses_slab() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(0).unwrap();
        // A zero-byte request still fits the first tier.
        assert!(slot.slab_index().is_some());
    }

    #[test]
    fn pool_slot_write_and_read() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(16).unwrap();
        let data = b"test data 12345!";
        slot.bytes()[..data.len()].copy_from_slice(data);
        assert_eq!(&slot.as_slice()[..data.len()], data);
    }

    #[test]
    fn pool_slot_zeroized_on_drop() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(16).unwrap();
        let sz = slot.size();
        slot.bytes().iter_mut().for_each(|b| *b = 0xDE);
        let slot_idx = slot.slab_index().unwrap();
        drop(slot);
        let pool = global_pool();
        // Take slots straight from the slab until the one just released comes back, so its
        // memory can be inspected. If it is never handed out again, the check is skipped.
        let mut guard = pool.tiers[0].slab.lock().unwrap_or_else(|e| e.into_inner());
        let mut acquired = vec![];
        while let Some(idx) = guard.acquire_transient() {
            acquired.push(idx);
            if idx == slot_idx {
                break;
            }
        }
        if acquired.last() == Some(&slot_idx) {
            let (ptr, _) = guard
                .slot_raw(slot_idx)
                .expect("slot_raw: index just acquired from slab");
            // SAFETY: the slot is reserved by this test and `sz` is its reported size.
            let s = unsafe { std::slice::from_raw_parts(ptr, sz) };
            // NOTE(review): if this assertion fails, the slots taken above are never
            // released (and the mutex is poisoned); later tests may then see fewer slots.
            assert!(s.iter().all(|&b| b == 0), "slot must be zeroed after drop");
        }
        // Return every slot taken above to the slab.
        for idx in acquired {
            guard.release(idx);
        }
    }

    #[test]
    fn coffer_view_returns_key_sized_slot() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
        assert_eq!(slot.tier_index(), Some(0));
    }

    #[test]
    fn coffer_view_is_deterministic() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let s1 = coffer_view().unwrap();
        let key1 = s1.as_slice().to_vec();
        drop(s1);
        let s2 = coffer_view().unwrap();
        let key2 = s2.as_slice().to_vec();
        drop(s2);
        assert_eq!(key1, key2, "coffer_view must return same key each call");
        assert!(!key1.iter().all(|&b| b == 0));
    }

    #[test]
    fn hot_cache_insert_get_evict() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data = [0xAB_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(1001, &data);
        let slot = hot_cache_get(1001).unwrap();
        assert_eq!(slot.as_slice(), &data);
        drop(slot);
        hot_cache_evict(1001);
        assert!(hot_cache_get(1001).is_none());
    }

    #[test]
    fn hot_cache_get_returns_pool_slot() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data = [0xCC_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(2002, &data);
        let slot = hot_cache_get(2002).expect("should be a cache hit");
        assert_eq!(slot.tier_index(), Some(0));
        // Cache hits are expected to come from the shared region of the slab (presumably
        // the slots below FIRST_SHARED_SLOT are reserved).
        assert!(slot
            .slab_index()
            .map(|i| i >= FIRST_SHARED_SLOT)
            .unwrap_or(false));
        drop(slot);
        hot_cache_evict(2002);
    }

    #[test]
    fn tiered_pool_routes_small_to_first_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert_eq!(
            slot.tier_index(),
            Some(0),
            "should route to tier 0 (32-byte)"
        );
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
    }

    // NOTE(review): despite the name, the global pool has a single default tier, so this
    // checks that a 48-byte request falls back to a standalone buffer.
    #[test]
    fn tiered_pool_routes_medium_to_second_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(48).unwrap();
        assert!(
            slot.tier_index().is_none(),
            "48-byte request exceeds default tier; should be standalone"
        );
        assert_eq!(slot.size(), 48);
    }

    #[test]
    fn tiered_pool_routes_large_to_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(8192).unwrap();
        assert!(slot.tier_index().is_none(), "should be standalone");
        assert_eq!(slot.size(), 8192);
    }

    // Relies on no test calling `init_pool`, so the global pool always has the default config.
    #[test]
    fn init_pool_default_config_has_one_tier() {
        let pool = global_pool();
        assert_eq!(pool.tier_count(), 1);
        assert_eq!(pool.tier_slot_size(0), Some(DEFAULT_SLOT_SIZE));
        assert_eq!(pool.max_slab_slot_size(), DEFAULT_SLOT_SIZE);
    }

    #[test]
    fn tiered_pool_config_validates_ascending() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![32, 32],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 1, "duplicates should be deduped");
    }

    #[test]
    fn tiered_pool_config_validates_max_slot_size() {
        let ps = page_size();
        let too_large = ps / 3 + 1;
        let err = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![too_large],
        });
        assert!(err.is_err(), "slot size > page_size/3 must be rejected");
    }

    // NOTE(review): this uses the global `coffer_view`, not a locally built pool.
    #[test]
    fn local_pool_coffer_view_works() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), DEFAULT_SLOT_SIZE);
        assert_eq!(slot.tier_index(), Some(0));
    }

    #[test]
    fn tiered_pool_first_tier_must_be_32_bytes() {
        let result = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![16],
        });
        assert!(
            result.is_err(),
            "first tier < 32 should fail (coffer requires slot_size >= 32)"
        );
    }

    #[test]
    fn coffer_view_key_is_32_bytes_and_nonzero() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = coffer_view().unwrap();
        assert_eq!(slot.size(), 32);
        assert!(
            slot.as_slice().iter().any(|&b| b != 0),
            "coffer key must not be all zeros"
        );
    }

    #[test]
    fn empty_tier_sizes_rejected() {
        let result = TieredPool::new(TieredPoolConfig { tier_sizes: vec![] });
        assert!(result.is_err(), "empty tier_sizes must be rejected");
    }

    #[test]
    fn tier_sizes_sorted_ascending_internally() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![64, 32],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 2);
        assert_eq!(pool.tier_slot_size(0), Some(32));
        assert_eq!(pool.tier_slot_size(1), Some(64));
    }

    // NOTE(review): only the tier layout of the local pool is checked. The routing
    // assertions go through the single-tier *global* pool, presumably because `PoolSlot`'s
    // drop always releases into `global_pool()`.
    #[test]
    fn multi_tier_routing_smallest_fit() {
        let pool = TieredPool::new(TieredPoolConfig {
            tier_sizes: vec![32, 64],
        })
        .unwrap();
        assert_eq!(pool.tier_count(), 2);
        assert_eq!(pool.tier_slot_size(0), Some(32));
        assert_eq!(pool.tier_slot_size(1), Some(64));
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(33).unwrap();
        assert!(
            slot.tier_index().is_none(),
            "size 33 exceeds single 32-byte tier → standalone"
        );
        assert_eq!(slot.size(), 33);
        let slot2 = pool_acquire(32).unwrap();
        assert_eq!(
            slot2.tier_index(),
            Some(0),
            "size 32 must use tier 0 (32-byte)"
        );
        drop(slot);
        drop(slot2);
    }

    #[test]
    fn pool_slot_tier_index_matches_acquisition_tier() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(16).unwrap();
        assert_eq!(slot.tier_index(), Some(0));
        assert_eq!(
            slot.slab_index().map(|i| i >= FIRST_SHARED_SLOT),
            Some(true)
        );
    }

    #[test]
    fn standalone_slot_has_no_tier_or_slab_index() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let slot = pool_acquire(9999).unwrap();
        assert!(slot.tier_index().is_none());
        assert!(slot.slab_index().is_none());
        assert_eq!(slot.size(), 9999);
    }

    // NOTE(review): there is no assertion here. It only checks that dropping a filled
    // standalone slot does not crash; the freed memory cannot be inspected afterwards.
    #[test]
    fn pool_slot_zeroized_on_drop_standalone() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut slot = pool_acquire(512).unwrap();
        slot.bytes().fill(0xBE);
        drop(slot);
    }

    #[test]
    fn hot_cache_not_populated_for_large_plaintext() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // 64 bytes is larger than the default tier-0 slot.
        let big_data = [0x42_u8; 64];
        hot_cache_insert(9876, &big_data);
        let result = hot_cache_get(9876);
        assert!(result.is_none(), "oversized data must not be cached");
    }

    #[test]
    fn hot_cache_multiple_ids_are_independent() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let data_a = [0xAA_u8; DEFAULT_SLOT_SIZE];
        let data_b = [0xBB_u8; DEFAULT_SLOT_SIZE];
        hot_cache_insert(100, &data_a);
        hot_cache_insert(101, &data_b);
        let slot_a = hot_cache_get(100).unwrap();
        let slot_b = hot_cache_get(101).unwrap();
        assert_eq!(slot_a.as_slice(), &data_a, "id 100 must return data_a");
        assert_eq!(slot_b.as_slice(), &data_b, "id 101 must return data_b");
        drop(slot_a);
        drop(slot_b);
        hot_cache_evict(100);
        hot_cache_evict(101);
    }

    #[test]
    fn coffer_view_returns_same_key_every_time() {
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let s1 = coffer_view().unwrap();
        let k1 = s1.as_slice().to_vec();
        drop(s1);
        let s2 = coffer_view().unwrap();
        let k2 = s2.as_slice().to_vec();
        drop(s2);
        let s3 = coffer_view().unwrap();
        let k3 = s3.as_slice().to_vec();
        drop(s3);
        assert_eq!(k1, k2, "coffer key must be same on second call");
        assert_eq!(k2, k3, "coffer key must be same on third call");
        assert!(
            k1.iter().any(|&b| b != 0),
            "coffer key must not be all zeros"
        );
    }

    // Eight threads each hold a slot at the same time (enforced by the barrier). Each then
    // checks that its own marker byte survived, i.e. no two threads shared a slot. If the
    // slab has fewer free slots, some threads wait up to SLOT_WAIT_TIMEOUT and then get
    // standalone buffers.
    #[test]
    fn concurrent_pool_acquire_and_release() {
        use std::sync::Arc;
        use std::thread;
        let _guard = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let barrier = Arc::new(std::sync::Barrier::new(8));
        let handles: Vec<_> = (0..8_u8)
            .map(|i| {
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    let mut slot = pool_acquire(16).unwrap();
                    slot.bytes()[0] = i;
                    b.wait();
                    assert_eq!(slot.as_slice()[0], i, "thread {i}: slot content must match");
                    drop(slot);
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }
}