use std::alloc::Layout;
use std::ptr::NonNull;
#[cfg(not(windows))]
use std::alloc::handle_alloc_error;
#[cfg(not(windows))]
use std::alloc::{alloc, dealloc};
#[cfg(not(windows))]
use std::cmp::max;
#[cfg(windows)]
use windows_sys::Win32::Foundation::{CloseHandle, ERROR_SUCCESS, GetLastError, HANDLE, LUID};
#[cfg(windows)]
use windows_sys::Win32::Security::{
AdjustTokenPrivileges, LUID_AND_ATTRIBUTES, LookupPrivilegeValueA, SE_PRIVILEGE_ENABLED,
TOKEN_ADJUST_PRIVILEGES, TOKEN_PRIVILEGES, TOKEN_QUERY,
};
#[cfg(windows)]
use windows_sys::Win32::System::Memory::{
GetLargePageMinimum, MEM_COMMIT, MEM_LARGE_PAGES, MEM_RELEASE, MEM_RESERVE, PAGE_READWRITE,
VirtualAlloc, VirtualFree,
};
#[cfg(windows)]
use windows_sys::Win32::System::Threading::{GetCurrentProcess, OpenProcessToken};
/// Records which allocation path produced an `Allocation`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(super) enum AllocKind {
    /// Produced by the large-page path: `VirtualAlloc` with `MEM_LARGE_PAGES` on Windows,
    /// or a 2 MiB-aligned block advised with `MADV_HUGEPAGE` on Linux/Android.
    LargePages,
    /// Produced by the ordinary (non-large-page) path.
    // This variant may go unconstructed on some targets, which would otherwise
    // trigger a dead-code warning.
    #[allow(dead_code)]
    Regular,
}
/// An owned memory region obtained from the platform allocator and released on drop.
pub(super) struct Allocation {
    /// Start of the region; never null.
    ptr: NonNull<u8>,
    /// Which allocation path produced the region.
    kind: AllocKind,
    /// Layout handed to `alloc`; `dealloc` needs the same layout to free the region.
    #[cfg(not(windows))]
    layout: Layout,
}

impl Allocation {
    /// Allocates a region of at least `size` bytes.
    ///
    /// On Windows a large-page allocation is attempted first, falling back to a regular
    /// `VirtualAlloc` allocation; debug builds assert that `alignment` is a power of two.
    /// On other targets the request goes to the global allocator, where a
    /// non-power-of-two `alignment` is rejected by `Layout::from_size_align`.
    pub(super) fn allocate(size: usize, alignment: usize) -> Self {
        #[cfg(windows)]
        {
            debug_assert!(alignment.is_power_of_two(), "alignment must be power of two");
            try_alloc_large_pages(size).unwrap_or_else(|| alloc_windows(size, alignment))
        }
        #[cfg(not(windows))]
        {
            alloc_unix(size, alignment)
        }
    }

    /// Returns the start of the region.
    pub(super) fn ptr(&self) -> NonNull<u8> {
        self.ptr
    }

    /// Returns which allocation path produced the region.
    pub(super) fn kind(&self) -> AllocKind {
        self.kind
    }
}
/// Rounds `value` up to the nearest multiple of `align` (`align` must be non-zero).
///
/// Debug builds assert that the round-up cannot overflow `usize`; release builds do
/// not check, so an overflowing round-up wraps.
#[cfg(windows)]
fn align_up(value: usize, align: usize) -> usize {
    debug_assert!(
        value.checked_add(align - 1).is_some(),
        "align_up overflow: value={value}, align={align}"
    );
    let units = value.div_ceil(align);
    units * align
}
/// Tries to allocate `size` bytes, rounded up to a multiple of the minimum large-page
/// size, backed by Windows large pages.
///
/// `SeLockMemoryPrivilege` is enabled on the process token only around the
/// `VirtualAlloc` call, then restored to its previous state (best effort).
///
/// Returns `None`, so the caller can fall back to a regular allocation, when any of
/// these occur:
/// - large pages are unsupported (`GetLargePageMinimum` returns 0);
/// - rounding `size` up would overflow `usize`;
/// - the process token cannot be opened;
/// - the privilege cannot be looked up or enabled;
/// - `VirtualAlloc` fails.
#[cfg(windows)]
fn try_alloc_large_pages(size: usize) -> Option<Allocation> {
    /// Owns a process-token handle and closes it when dropped, so every exit path
    /// releases it.
    struct TokenGuard(HANDLE);

    impl Drop for TokenGuard {
        fn drop(&mut self) {
            // SAFETY: the handle came from a successful `OpenProcessToken` call and is
            // closed exactly once, here.
            unsafe {
                CloseHandle(self.0);
            }
        }
    }

    // SAFETY: `GetLargePageMinimum` takes no arguments.
    let large_page_size = unsafe { GetLargePageMinimum() } as usize;
    if large_page_size == 0 {
        return None;
    }
    // Reject sizes whose round-up would overflow *before* touching the token. This
    // prevents `align_up` from wrapping or panicking while the privilege is enabled
    // and the handle is open.
    size.checked_add(large_page_size - 1)?;
    let alloc_size = align_up(size, large_page_size);

    // SAFETY: every pointer argument below refers to a live local of the type the API
    // expects, and `token` keeps the handle valid until it is explicitly dropped.
    unsafe {
        let mut raw_token: HANDLE = std::ptr::null_mut();
        if OpenProcessToken(
            GetCurrentProcess(),
            TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
            &mut raw_token,
        ) == 0
        {
            return None;
        }
        let token = TokenGuard(raw_token);

        let mut luid = LUID {
            LowPart: 0,
            HighPart: 0,
        };
        let privilege_name = c"SeLockMemoryPrivilege";
        if LookupPrivilegeValueA(std::ptr::null(), privilege_name.as_ptr() as *const u8, &mut luid)
            == 0
        {
            return None;
        }

        let tp = TOKEN_PRIVILEGES {
            PrivilegeCount: 1,
            Privileges: [LUID_AND_ATTRIBUTES {
                Luid: luid,
                Attributes: SE_PRIVILEGE_ENABLED,
            }],
        };
        // Receives the privilege's prior state so it can be restored after allocating.
        let mut prev_tp = TOKEN_PRIVILEGES {
            PrivilegeCount: 0,
            Privileges: [LUID_AND_ATTRIBUTES {
                Luid: LUID {
                    LowPart: 0,
                    HighPart: 0,
                },
                Attributes: 0,
            }],
        };
        let mut prev_len = std::mem::size_of::<TOKEN_PRIVILEGES>() as u32;
        // Per the AdjustTokenPrivileges documentation, the call can return success
        // without enabling the privilege (the last error is then not ERROR_SUCCESS).
        // Both results are therefore checked.
        if AdjustTokenPrivileges(token.0, 0, &tp, prev_len, &mut prev_tp, &mut prev_len) == 0
            || GetLastError() != ERROR_SUCCESS
        {
            return None;
        }

        let ptr = VirtualAlloc(
            std::ptr::null_mut(),
            alloc_size,
            MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES,
            PAGE_READWRITE,
        );
        // Best-effort restore of the privilege's previous state; a failure is ignored.
        AdjustTokenPrivileges(token.0, 0, &prev_tp, 0, std::ptr::null_mut(), std::ptr::null_mut());
        drop(token);

        let ptr = NonNull::new(ptr as *mut u8)?;
        Some(Allocation {
            ptr,
            kind: AllocKind::LargePages,
        })
    }
}
/// Allocates `size` bytes of committed read/write memory with `VirtualAlloc`, without
/// large pages.
///
/// `alignment` is not passed to `VirtualAlloc`. It is only used, raised to at least
/// 4096, to build the `Layout` reported to `handle_alloc_error` when the call fails.
/// NOTE(review): alignment therefore comes solely from `VirtualAlloc`'s own placement
/// guarantees; confirm callers never request more than that.
#[cfg(windows)]
fn alloc_windows(size: usize, alignment: usize) -> Allocation {
    // SAFETY: a null base address lets the system choose the placement; no other
    // pointer arguments are involved.
    let raw = unsafe {
        VirtualAlloc(std::ptr::null_mut(), size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE)
    };
    match NonNull::new(raw.cast::<u8>()) {
        Some(ptr) => Allocation {
            ptr,
            kind: AllocKind::Regular,
        },
        None => {
            let failed_layout = Layout::from_size_align(size, alignment.max(4096)).unwrap();
            std::alloc::handle_alloc_error(failed_layout)
        }
    }
}
/// Allocates at least `size` bytes from the global allocator on non-Windows targets.
///
/// Alignment:
/// - raised to at least 2 MiB on Linux/Android, so the region can be backed by
///   transparent huge pages;
/// - raised to at least 4 KiB elsewhere;
/// - the size is then padded to a multiple of the alignment.
///
/// A zero `size` is treated as one byte, because `std::alloc::alloc` must not be called
/// with a zero-sized layout.
///
/// On Linux/Android the region is advised with `MADV_HUGEPAGE`. The returned kind is
/// `LargePages` only when that call succeeds, and `Regular` otherwise.
///
/// # Panics
/// Panics with "Invalid TT allocation layout" if `alignment` is not a power of two or
/// the padded size overflows `isize`. Allocation failure is routed to
/// `handle_alloc_error`.
#[cfg(not(windows))]
fn alloc_unix(size: usize, alignment: usize) -> Allocation {
    #[cfg(any(target_os = "linux", target_os = "android"))]
    let page_align = 2 * 1024 * 1024;
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    let page_align = 4096;

    // `max(size, 1)` keeps the layout non-zero-sized; `alloc` is UB otherwise.
    let layout = Layout::from_size_align(max(size, 1), max(alignment, page_align))
        .expect("Invalid TT allocation layout")
        .pad_to_align();
    // SAFETY: `layout` has a non-zero size (at least one byte is always requested).
    let raw = unsafe { alloc(layout) };
    let Some(ptr) = NonNull::new(raw) else {
        handle_alloc_error(layout);
    };

    #[cfg(any(target_os = "linux", target_os = "android"))]
    let kind = {
        // SAFETY: `ptr..ptr + layout.size()` is the block just returned by `alloc`, and
        // its start is 2 MiB-aligned, hence page-aligned as `madvise` requires.
        let result =
            unsafe { libc::madvise(ptr.as_ptr().cast(), layout.size(), libc::MADV_HUGEPAGE) };
        if result == 0 {
            AllocKind::LargePages
        } else {
            // The hint was rejected, so do not claim large pages.
            #[cfg(debug_assertions)]
            eprintln!("Warning: madvise MADV_HUGEPAGE failed");
            AllocKind::Regular
        }
    };
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    let kind = AllocKind::Regular;

    Allocation { ptr, kind, layout }
}
impl Drop for Allocation {
    /// Returns the region to the allocator that produced it.
    ///
    /// - Windows: `VirtualFree` with `MEM_RELEASE` and size 0. A failure is reported on
    ///   stderr and trips a debug assertion rather than panicking in release builds.
    /// - Elsewhere: `dealloc` with the recorded layout.
    fn drop(&mut self) {
        #[cfg(windows)]
        {
            // SAFETY: on Windows `self.ptr` is the base address returned by `VirtualAlloc`,
            // and it is released only here.
            let released = unsafe { VirtualFree(self.ptr.as_ptr().cast(), 0, MEM_RELEASE) };
            if released == 0 {
                // SAFETY: `GetLastError` takes no arguments.
                let code = unsafe { GetLastError() };
                eprintln!("Warning: VirtualFree failed with error {}", code);
                debug_assert!(false, "VirtualFree failed");
            }
        }
        #[cfg(not(windows))]
        {
            // SAFETY: `self.ptr` was returned by `alloc(self.layout)` and is freed only here.
            unsafe { dealloc(self.ptr.as_ptr(), self.layout) };
        }
    }
}
// SAFETY: `Allocation` uniquely owns its region and stores only plain data (a pointer,
// a `Copy` kind and, off Windows, a `Layout`). Moving it to another thread and freeing
// it there therefore only requires that `dealloc` / `VirtualFree` may be called from
// any thread. NOTE(review): confirm this holds for the global allocator in use.
unsafe impl Send for Allocation {}
// SAFETY: the `&self` methods visible here (`ptr()`, `kind()`) only copy values out, so
// shared references cannot mutate the struct. Synchronising reads and writes made
// *through* the returned pointer is the caller's responsibility.
// NOTE(review): confirm every user of `ptr()` upholds this.
unsafe impl Sync for Allocation {}