use core::alloc::Layout;
use core::ffi::c_void;
use core::mem::{align_of, size_of};
use core::ptr;
use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
use hyperlight_guest::exit::abort_with_code;
/// Bookkeeping record stored immediately before every pointer handed out by this module.
///
/// It holds the exact [`Layout`] that was passed to the global allocator, so `free` and
/// `realloc` can recover both the allocation size and its alignment from the user pointer.
#[repr(transparent)]
struct Header(Layout);

/// Alignment used by `malloc` and `calloc`: the alignment of `u128` on the target.
const DEFAULT_ALIGN: usize = align_of::<u128>();

/// Number of bytes occupied by a [`Header`] in front of each user pointer.
const HEADER_LEN: usize = size_of::<Header>();
unsafe fn alloc_helper(size: usize, alignment: usize, zero: bool) -> *mut c_void {
if size == 0 {
return ptr::null_mut();
}
let actual_align = alignment.max(align_of::<Header>());
let data_offset = HEADER_LEN.next_multiple_of(actual_align);
let Some(total_size) = data_offset.checked_add(size) else {
abort_with_code(&[ErrorCode::MallocFailed as u8]);
};
let layout =
Layout::from_size_align(total_size, actual_align).expect("Invalid layout parameters");
unsafe {
let raw_ptr = match zero {
true => alloc::alloc::alloc_zeroed(layout),
false => alloc::alloc::alloc(layout),
};
if raw_ptr.is_null() {
abort_with_code(&[ErrorCode::MallocFailed as u8]);
}
let header_ptr = raw_ptr.add(data_offset - HEADER_LEN).cast::<Header>();
header_ptr.write(Header(layout));
raw_ptr.add(data_offset) as *mut c_void
}
}
/// C `malloc`: allocates `size` uninitialised bytes aligned to [`DEFAULT_ALIGN`].
///
/// Returns null when `size` is zero; allocation failure aborts the guest (see `alloc_helper`).
///
/// # Safety
/// The returned pointer must only be released with `free` or resized with `realloc`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn malloc(size: usize) -> *mut c_void {
    // `malloc` makes no promise about the initial contents of the block.
    const ZERO_FILL: bool = false;
    // SAFETY: forwards directly to the shared allocation routine with the default alignment.
    unsafe { alloc_helper(size, DEFAULT_ALIGN, ZERO_FILL) }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn calloc(nmemb: usize, size: usize) -> *mut c_void {
unsafe {
let total_size = nmemb
.checked_mul(size)
.expect("nmemb * size should not overflow in calloc");
alloc_helper(total_size, DEFAULT_ALIGN, true)
}
}
/// C `aligned_alloc`: allocates `size` uninitialised bytes aligned to `alignment`.
///
/// `alignment` must be a non-zero power of two; any other value yields a null pointer.
/// Returns null when `size` is zero; allocation failure aborts the guest (see `alloc_helper`).
///
/// # Safety
/// The returned pointer must only be released with `free` or resized with `realloc`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn aligned_alloc(alignment: usize, size: usize) -> *mut c_void {
    // `is_power_of_two` is false for 0, so this also rejects a zero alignment.
    if alignment.is_power_of_two() {
        // SAFETY: the alignment has been validated; the helper handles the size.
        unsafe { alloc_helper(size, alignment, false) }
    } else {
        ptr::null_mut()
    }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn free(ptr: *mut c_void) {
if ptr.is_null() {
return;
}
let user_ptr = ptr as *const u8;
unsafe {
let header_ptr = user_ptr.sub(HEADER_LEN).cast::<Header>();
let layout = header_ptr.read().0;
let offset = HEADER_LEN.next_multiple_of(layout.align());
let raw_ptr = user_ptr.sub(offset) as *mut u8;
alloc::alloc::dealloc(raw_ptr, layout);
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn realloc(ptr: *mut c_void, size: usize) -> *mut c_void {
if ptr.is_null() {
return unsafe { malloc(size) };
}
if size == 0 {
unsafe {
free(ptr);
}
return ptr::null_mut();
}
let user_ptr = ptr as *const u8;
unsafe {
let header_ptr = user_ptr.sub(HEADER_LEN).cast::<Header>();
let old_layout = header_ptr.read().0;
let old_offset = HEADER_LEN.next_multiple_of(old_layout.align());
let old_user_size = old_layout.size() - old_offset;
let new_ptr = alloc_helper(size, old_layout.align(), false);
if new_ptr.is_null() {
return ptr::null_mut();
}
let copy_size = old_user_size.min(size);
ptr::copy_nonoverlapping(user_ptr, new_ptr as *mut u8, copy_size);
free(ptr);
new_ptr
}
}