use crate::addr::{ShimPhysAddr, ShimVirtAddr, SHIM_VIRT_OFFSET};
use crate::exec::NEXT_MMAP_RWLOCK;
use crate::hostcall::{HostCall, SHIM_LOCAL_STORAGE};
use crate::paging::SHIM_PAGETABLE;
use crate::snp::{get_cbit_mask, pvalidate, snp_active, PvalidateSize};
use crate::spin::Locked;
use core::alloc::{GlobalAlloc, Layout};
use core::cmp::{max, min};
use core::convert::TryFrom;
use core::mem::{align_of, size_of};
use core::ptr::NonNull;
use goblin::elf::header::header64::Header;
use goblin::elf::header::ELFMAG;
use goblin::elf::program_header::program_header64::*;
use linked_list_allocator::Heap;
use lset::{Line, Span};
use nbytes::bytes;
use primordial::{Address, Page as Page4KiB};
use sallyport::guest::Handler;
use spinning::Lazy;
use x86_64::instructions::tlb::flush_all;
use x86_64::structures::paging::mapper::{MapToError, UnmapError};
use x86_64::structures::paging::{
self, Mapper, Page, PageTableFlags, PhysFrame, Size1GiB, Size2MiB, Size4KiB,
};
use x86_64::{align_down, align_up, PhysAddr, VirtAddr};
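/// A page of 2 MiB size, aligned to 2 MiB.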
#[derive(Copy, Clone)]
#[repr(C, align(0x200000))]
#[allow(clippy::integer_arithmetic)]
pub struct Page2MiB([u8; bytes![2; MiB]]);
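/// The global allocator of the shim, lazily initialized on first use.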
pub static ALLOCATOR: Lazy<Locked<EnarxAllocator>> =
Lazy::new(|| Locked::new(unsafe { EnarxAllocator::new() }));
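/// The allocator of the shim.
///
/// It keeps track of the last and maximum balloon request sizes, the end
/// of the physical memory received from the host, and the backing
/// linked-list heap.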
pub struct EnarxAllocator {
last_alloc: usize,
max_alloc: usize,
end_of_mem: PhysAddr,
allocator: Heap,
}
impl core::fmt::Debug for EnarxAllocator {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
f.debug_struct("EnarxAllocator")
.field("last_alloc", &self.last_alloc)
.field("max_alloc", &self.max_alloc)
.finish()
}
}
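/// Returns the position of the most significant set bit of `val`.
///
/// For example, `msb(0b100) == 2`. Note that `msb(0)` and `msb(1)` both
/// return `0`.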
#[inline]
#[allow(clippy::integer_arithmetic)]
fn msb(val: usize) -> u32 {
let mut val = val;
let mut r = 0;
loop {
val >>= 1;
if val == 0 {
return r;
}
r += 1;
}
}
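/// Error conditions returned by the allocation and mapping functions.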
#[derive(Debug)]
pub enum AllocateError {
NotAligned,
OutOfMemory,
ZeroSize,
PageAlreadyMapped,
ParentEntryHugePage,
}
impl EnarxAllocator {
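    /// Create the allocator.
    ///
    /// Queries the host for the available memory and parses the program
    /// headers of the embedded exec ELF image to find the start of the
    /// free memory.
    ///
    /// # Safety
    ///
    /// Must only be called once, before any allocation takes place.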
unsafe fn new() -> Self {
let mut tls = SHIM_LOCAL_STORAGE.write();
let meminfo = {
let mut host_call = HostCall::try_new(&mut tls).unwrap();
host_call.mem_info().unwrap()
};
drop(tls);
        // Start with a minimum allocation size of 2^24 = 16 MiB.
        const MIN_EXP: u32 = 24;

        let c_bit_mask = get_cbit_mask();

        // Aim for a maximum of 2^47 bytes, unless the position of the
        // memory encryption bit limits the usable physical address width.
        let target_exp: u32 = if c_bit_mask > 0 {
            u32::min(47, msb(c_bit_mask as _).checked_sub(1).unwrap())
        } else {
            47
        };
debug_assert!(
meminfo > (target_exp.checked_sub(MIN_EXP).unwrap()) as _,
"Not enough memory slots available"
);
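        // Choose the maximum single balloon size so that, after doubling up
        // from 2^MIN_EXP, the remaining host memory slots filled with
        // `max_alloc` bytes each can still cover 2^target_exp bytes in total.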
let log_rest = msb(meminfo
.checked_sub(target_exp.checked_sub(MIN_EXP).unwrap() as usize)
.unwrap());
let max_exp = target_exp
.checked_sub(log_rest.min(target_exp.checked_sub(MIN_EXP).unwrap()))
.unwrap();
let next_alloc = (2usize).checked_pow(MIN_EXP).unwrap();
let max_alloc = (2usize).checked_pow(max_exp).unwrap();
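        // The physical address where the shim's memory starts.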
let mem_start: PhysAddr = {
let shim_virt = ShimVirtAddr::from(&crate::_ENARX_MEM_START);
PhysAddr::new(ShimPhysAddr::try_from(shim_virt).unwrap().raw().raw())
};
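        // Parse the program headers of the embedded exec ELF image to
        // determine how much memory its `PT_LOAD` segments occupy.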
let code_size = {
let header: &Header = &crate::_ENARX_EXEC_START;
let header_ptr = header as *const _;
if !header.e_ident[..ELFMAG.len()].eq(ELFMAG) {
panic!("Not valid ELF");
}
let headers: &[ProgramHeader] = core::slice::from_raw_parts(
(header_ptr as usize as *const u8).offset(header.e_phoff as _)
as *const ProgramHeader,
header.e_phnum as _,
);
let region = Span::from(
headers
.iter()
.filter(|ph| ph.p_type == PT_LOAD)
.map(|x| {
Line::from(
x.p_vaddr as usize
..(x.p_vaddr as usize)
.checked_add(x.p_memsz as usize)
.unwrap(),
)
})
.fold(
Line {
start: usize::MAX,
end: usize::MIN,
},
|l, r| Line::from(min(l.start, r.start)..max(l.end, r.end)),
),
);
assert!(
(&crate::_ENARX_EXEC_END as *const _ as usize)
.checked_sub(&crate::_ENARX_EXEC_START as *const _ as usize)
.unwrap()
> region.count
);
align_up(region.count as u64, Page4KiB::SIZE as u64) as usize
};
let mem_size = align_up(
(&crate::_ENARX_EXEC_START as *const _ as u64)
.checked_sub(&crate::_ENARX_MEM_START as *const _ as u64)
.unwrap()
.checked_add(code_size as u64)
.unwrap(),
Page4KiB::SIZE as u64,
) as usize;
let end_of_mem = mem_start + mem_size;
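        // Advance the next mmap address past the exec code.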
*NEXT_MMAP_RWLOCK.write() =
(*NEXT_MMAP_RWLOCK.read() + code_size).align_up(Page::<Size4KiB>::SIZE);
let allocator = Heap::empty();
EnarxAllocator {
last_alloc: next_alloc.checked_div(2).unwrap(),
max_alloc,
end_of_mem,
allocator,
}
}
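    /// Balloon additional memory from the host.
    ///
    /// Each successful request roughly doubles the previous one, up to
    /// `max_alloc`; on failure the request is halved and retried. Returns
    /// `false` once the request would shrink below one 4 KiB page.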
fn balloon(&mut self) -> bool {
let mut last_size: usize = self.last_alloc;
loop {
let new_size: usize = 2u64
.checked_mul(last_size as u64)
.unwrap_or(last_size as u64) as _;
let new_size = new_size.min(self.max_alloc);
let num_pages = new_size.checked_div(Page4KiB::SIZE as _).unwrap();
let mut tls = SHIM_LOCAL_STORAGE.write();
            // The first argument is assumed to be the page size exponent
            // (2^12 = 4 KiB pages).
            let ret = HostCall::try_new(&mut tls).unwrap().balloon_memory(
                12,
                num_pages,
                self.end_of_mem.as_u64() as _,
            );
drop(tls);
match ret {
Ok(_) => {
let shim_phys_page = ShimPhysAddr::<u8>::try_from(self.end_of_mem).unwrap();
let free_start: *mut u8 = ShimVirtAddr::from(shim_phys_page).into();
self.end_of_mem += new_size;
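                    // On SEV-SNP, newly ballooned pages must be validated
                    // before they can be used.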
if snp_active() {
let virt_region = Span::new(free_start as usize, new_size);
let virt_line = Line::from(virt_region);
for addr in
(virt_line.start..virt_line.end).step_by(Page::<Size4KiB>::SIZE as _)
{
let va = VirtAddr::new(addr as _);
pvalidate(va, PvalidateSize::Size4K, true).unwrap();
}
}
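                    // The ballooned memory starts at the previous end of
                    // memory, so the heap can simply be extended; on the
                    // first call the heap is initialized instead.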
unsafe {
if self.allocator.size() > 0 {
self.allocator.extend(new_size);
} else {
self.allocator.init(free_start as _, new_size);
}
}
self.last_alloc = new_size;
return true;
}
Err(_) => {
last_size = last_size.checked_div(2).unwrap();
if last_size < Page4KiB::SIZE {
return false;
}
}
}
}
}
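    /// Try to allocate `size` bytes, halving the request on failure until
    /// it reaches a single 4 KiB page.
    ///
    /// Returns the zeroed allocation and its actual size, or a null
    /// pointer if not even one page could be allocated.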
fn try_alloc_half(&mut self, mut size: usize) -> (*mut u8, usize) {
assert!(size >= size_of::<Page4KiB>());
loop {
let p = self.alloc_pages(size);
if let Ok(p) = p {
unsafe {
core::ptr::write_bytes(p.as_ptr(), 0, size);
}
return (p.as_ptr(), size);
}
if size == size_of::<Page4KiB>() {
return (core::ptr::null_mut(), size);
}
size = size.checked_div(2).unwrap();
}
}
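    /// Allocate `size` bytes with 4 KiB page alignment.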
fn alloc_pages(&mut self, size: usize) -> Result<NonNull<u8>, ()> {
self.allocator
.allocate_first_fit(Layout::from_size_align(size, align_of::<Page4KiB>()).unwrap())
}
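    /// Deallocate pages previously returned by `alloc_pages`.
    ///
    /// # Safety
    ///
    /// `ptr` and `size` must match a previous `alloc_pages` allocation.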
unsafe fn dealloc_pages(&mut self, ptr: *mut u8, size: usize) {
self.allocator.deallocate(
NonNull::new_unchecked(ptr),
Layout::from_size_align_unchecked(size, align_of::<Page4KiB>()),
)
}
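    /// Allocate memory and map it to the given virtual address `map_to`.
    ///
    /// The request is served in power-of-two chunks: the largest chunk the
    /// heap can provide is mapped first and the remainder is handled
    /// recursively. On failure, everything allocated so far is deallocated
    /// and unmapped again.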
pub fn allocate_and_map_memory(
&mut self,
map_to: VirtAddr,
size: usize,
flags: PageTableFlags,
parent_flags: PageTableFlags,
) -> Result<&'static mut [u8], AllocateError> {
if size == 0 {
return Err(AllocateError::ZeroSize);
}
if !map_to.is_aligned(align_of::<Page4KiB>() as u64) {
return Err(AllocateError::NotAligned);
}
if size != align_down(size as _, Page::<Size4KiB>::SIZE) as usize {
return Err(AllocateError::NotAligned);
}
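        // Start with the largest power of two not exceeding `size`.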
let curr_size = (2usize).checked_pow(msb(size)).unwrap();
let (first_half, first_half_size) = {
            while self.allocator.free() < curr_size {
                if !self.balloon() {
                    // The host cannot provide more memory; let
                    // `try_alloc_half` shrink the request instead of
                    // looping forever.
                    break;
                }
            }
let (chunk, chunk_size) = self.try_alloc_half(curr_size);
if chunk.is_null() {
self.balloon();
self.try_alloc_half(curr_size)
} else {
(chunk, chunk_size)
}
};
if first_half.is_null() {
return Err(AllocateError::OutOfMemory);
}
let second_half_size = size.checked_sub(first_half_size).unwrap();
if second_half_size > 0 {
if let Err(e) = self.allocate_and_map_memory(
map_to + first_half_size,
second_half_size,
flags,
parent_flags,
) {
unsafe {
self.dealloc_pages(first_half, first_half_size);
}
return Err(e);
}
}
let phys = shim_virt_to_enc_phys(first_half);
if let Err(e) = self.map_memory(phys, map_to, first_half_size, flags, parent_flags) {
unsafe {
self.dealloc_pages(first_half, first_half_size);
}
let _ = self.unmap_memory(map_to + first_half_size, second_half_size);
return Err(e);
}
Ok(unsafe { core::slice::from_raw_parts_mut(map_to.as_mut_ptr() as *mut u8, size) })
}
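    /// Map physical memory starting at `map_from` to the virtual address
    /// `map_to` with the given page table flags.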
pub fn map_memory(
&mut self,
map_from: PhysAddr,
map_to: VirtAddr,
size: usize,
flags: PageTableFlags,
parent_flags: PageTableFlags,
) -> Result<(), AllocateError> {
if size == 0 {
return Err(AllocateError::ZeroSize);
}
let frame_range_from = {
let start = map_from;
let end = start + size - 1u64;
let start_frame = PhysFrame::<Size4KiB>::containing_address(start);
let end_frame = PhysFrame::<Size4KiB>::containing_address(end);
PhysFrame::range_inclusive(start_frame, end_frame)
};
let page_range_to = {
let start = map_to;
let end = start + size - 1u64;
let start_page = Page::<Size4KiB>::containing_address(start);
let end_page = Page::<Size4KiB>::containing_address(end);
Page::range_inclusive(start_page, end_page)
};
for (frame_from, page_to) in frame_range_from.zip(page_range_to) {
unsafe {
SHIM_PAGETABLE
.write()
.map_to_with_table_flags(page_to, frame_from, flags, parent_flags, self)
.map_err(|e| match e {
MapToError::FrameAllocationFailed => AllocateError::OutOfMemory,
MapToError::ParentEntryHugePage => AllocateError::ParentEntryHugePage,
MapToError::PageAlreadyMapped(_) => AllocateError::PageAlreadyMapped,
})?
.ignore();
}
}
flush_all();
Ok(())
}
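    /// Unmap the pages starting at `virt_addr` and return the backing
    /// memory to the heap.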
pub fn unmap_memory(&mut self, virt_addr: VirtAddr, size: usize) -> Result<(), UnmapError> {
if size == 0 {
return Ok(());
}
let page_range_to = {
let start = virt_addr;
let end = start + size - 1u64;
let start_page = Page::<Size4KiB>::containing_address(start);
let end_page = Page::<Size4KiB>::containing_address(end);
Page::range_inclusive(start_page, end_page)
};
        for page in page_range_to {
            let phys = {
                let (phys_frame, _) = SHIM_PAGETABLE.write().unmap(page)?;
                phys_frame.start_address()
            };
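            // Convert the physical address back to the shim virtual address
            // the heap handed out. This assumes the `ShimPhysAddr`
            // conversion strips the encryption bit which
            // `shim_virt_to_enc_phys` set when the page was mapped.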
let free_start_phys = Address::<usize, _>::from(phys.as_u64() as *const u8);
let shim_phys_page = ShimPhysAddr::from(free_start_phys);
let shim_virt: *mut u8 = ShimVirtAddr::from(shim_phys_page).into();
unsafe {
self.dealloc_pages(shim_virt, Page::<Size4KiB>::SIZE as usize);
}
}
flush_all();
Ok(())
}
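    /// Try to allocate memory for the given `layout`, ballooning more
    /// memory from the host if the heap is exhausted.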
pub fn try_alloc(&mut self, layout: Layout) -> Option<NonNull<u8>> {
let b = self.allocator.allocate_first_fit(layout).ok();
if b.is_none() && self.balloon() {
self.allocator.allocate_first_fit(layout).ok()
} else {
b
}
}
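    /// Deallocate memory previously allocated via `try_alloc`.
    ///
    /// # Safety
    ///
    /// `ptr` must denote a block allocated via `try_alloc` with the same
    /// `layout`.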
pub unsafe fn deallocate(&mut self, ptr: *mut u8, layout: Layout) {
self.allocator
.deallocate(NonNull::new(ptr).unwrap(), layout);
}
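    /// Returns the number of free bytes in the heap.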
pub fn free(&self) -> usize {
self.allocator.free()
}
}
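// The `FrameAllocator` implementations hand out frames from the heap,
// translating the shim virtual address of each allocation into the
// corresponding encrypted physical address.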
unsafe impl paging::FrameAllocator<Size4KiB> for EnarxAllocator {
#[allow(unused_unsafe)]
fn allocate_frame(&mut self) -> Option<PhysFrame<Size4KiB>> {
self.try_alloc(unsafe {
Layout::from_size_align_unchecked(
Page::<Size4KiB>::SIZE as _,
Page::<Size4KiB>::SIZE as _,
)
})
.map(|a| a.as_ptr())
.map(shim_virt_to_enc_phys)
.map(PhysFrame::containing_address)
}
}
unsafe impl paging::FrameAllocator<Size2MiB> for EnarxAllocator {
#[allow(unused_unsafe)]
fn allocate_frame(&mut self) -> Option<PhysFrame<Size2MiB>> {
self.try_alloc(unsafe {
Layout::from_size_align_unchecked(
Page::<Size2MiB>::SIZE as _,
Page::<Size2MiB>::SIZE as _,
)
})
.map(|a| a.as_ptr())
.map(shim_virt_to_enc_phys)
.map(PhysFrame::containing_address)
}
}
unsafe impl paging::FrameAllocator<Size1GiB> for EnarxAllocator {
#[allow(unused_unsafe)]
fn allocate_frame(&mut self) -> Option<PhysFrame<Size1GiB>> {
self.try_alloc(unsafe {
Layout::from_size_align_unchecked(
Page::<Size1GiB>::SIZE as _,
Page::<Size1GiB>::SIZE as _,
)
})
.map(|a| a.as_ptr())
.map(shim_virt_to_enc_phys)
.map(PhysFrame::containing_address)
}
}
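/// Convert a shim virtual address to the corresponding physical address
/// with the memory encryption bit set.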
#[inline]
fn shim_virt_to_enc_phys<T>(p: *mut T) -> PhysAddr {
let virt = VirtAddr::from_ptr(p);
debug_assert!(virt.as_u64() > SHIM_VIRT_OFFSET);
PhysAddr::new(virt.as_u64().checked_sub(SHIM_VIRT_OFFSET).unwrap() | get_cbit_mask())
}
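// Delegate the global allocator to the spinlock-protected `EnarxAllocator`.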
unsafe impl GlobalAlloc for Locked<EnarxAllocator> {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
let mut this = self.lock();
this.try_alloc(layout)
.map_or(core::ptr::null_mut(), |p| p.as_ptr())
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
let mut this = self.lock();
this.deallocate(ptr, layout);
}
}