use crate::error::{Error, Result};
use std::ptr::NonNull;
use windows::Win32::Foundation::HANDLE;
use windows::Win32::System::Memory::{
GetProcessHeap, HeapAlloc, HeapCreate, HeapDestroy, HeapFree, HeapReAlloc, HeapSize,
VirtualAlloc, VirtualFree, VirtualLock, VirtualProtect, VirtualQuery, VirtualUnlock, HEAP_NONE,
MEMORY_BASIC_INFORMATION, MEM_COMMIT, MEM_DECOMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_EXECUTE,
PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, PAGE_NOACCESS, PAGE_PROTECTION_FLAGS, PAGE_READONLY,
PAGE_READWRITE,
};
use windows::Win32::System::SystemInformation::{
GetSystemInfo, GlobalMemoryStatusEx, MEMORYSTATUSEX, SYSTEM_INFO,
};
/// Page-protection levels understood by this module.
///
/// Each variant corresponds to exactly one Win32 `PAGE_*` constant (see
/// `Protection::to_flags`). When protection flags are read back from the OS
/// (`VirtualMemory::protect`, `query_memory`), any value that is not one of
/// these six constants is reported as `NoAccess`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protection {
    /// Maps to `PAGE_NOACCESS`.
    NoAccess,
    /// Maps to `PAGE_READONLY`.
    ReadOnly,
    /// Maps to `PAGE_READWRITE`.
    ReadWrite,
    /// Maps to `PAGE_EXECUTE`.
    Execute,
    /// Maps to `PAGE_EXECUTE_READ`.
    ExecuteRead,
    /// Maps to `PAGE_EXECUTE_READWRITE`.
    ExecuteReadWrite,
}
impl Protection {
    /// Converts this protection level into the Win32 `PAGE_*` value expected by
    /// `VirtualAlloc` and `VirtualProtect`.
    fn to_flags(self) -> PAGE_PROTECTION_FLAGS {
        match self {
            Self::NoAccess => PAGE_NOACCESS,
            Self::ReadOnly => PAGE_READONLY,
            Self::ReadWrite => PAGE_READWRITE,
            Self::Execute => PAGE_EXECUTE,
            Self::ExecuteRead => PAGE_EXECUTE_READ,
            Self::ExecuteReadWrite => PAGE_EXECUTE_READWRITE,
        }
    }
}
/// An owned region of virtual memory obtained from `VirtualAlloc`.
///
/// The entire region is released with `VirtualFree(.., MEM_RELEASE)` when the
/// value is dropped.
#[derive(Debug)]
pub struct VirtualMemory {
    /// Base address returned by `VirtualAlloc`.
    ptr: NonNull<u8>,
    /// Size in bytes as requested by the caller (the value passed to
    /// `VirtualAlloc`, not rounded up to a page boundary).
    size: usize,
}
impl VirtualMemory {
pub fn alloc(size: usize, protection: Protection) -> Result<Self> {
Self::alloc_at(None, size, protection)
}
pub fn alloc_at(address: Option<*mut u8>, size: usize, protection: Protection) -> Result<Self> {
let ptr = unsafe {
VirtualAlloc(
address.map(|p| p as *const _),
size,
MEM_COMMIT | MEM_RESERVE,
protection.to_flags(),
)
};
if ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(Self {
ptr: NonNull::new(ptr as *mut u8).unwrap(),
size,
})
}
pub fn reserve(size: usize) -> Result<Self> {
let ptr = unsafe { VirtualAlloc(None, size, MEM_RESERVE, PAGE_NOACCESS) };
if ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(Self {
ptr: NonNull::new(ptr as *mut u8).unwrap(),
size,
})
}
pub fn commit(&self, offset: usize, size: usize, protection: Protection) -> Result<()> {
if offset + size > self.size {
return Err(Error::custom("Commit region exceeds allocation size"));
}
let ptr = unsafe {
VirtualAlloc(
Some(self.ptr.as_ptr().add(offset) as *const _),
size,
MEM_COMMIT,
protection.to_flags(),
)
};
if ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(())
}
pub fn decommit(&self, offset: usize, size: usize) -> Result<()> {
if offset + size > self.size {
return Err(Error::custom("Decommit region exceeds allocation size"));
}
unsafe {
VirtualFree(self.ptr.as_ptr().add(offset) as *mut _, size, MEM_DECOMMIT)?;
}
Ok(())
}
pub fn protect(
&self,
offset: usize,
size: usize,
protection: Protection,
) -> Result<Protection> {
if offset + size > self.size {
return Err(Error::custom("Protect region exceeds allocation size"));
}
let mut old_protect = PAGE_PROTECTION_FLAGS(0);
unsafe {
VirtualProtect(
self.ptr.as_ptr().add(offset) as *const _,
size,
protection.to_flags(),
&mut old_protect,
)?;
}
let old = match old_protect {
PAGE_NOACCESS => Protection::NoAccess,
PAGE_READONLY => Protection::ReadOnly,
PAGE_READWRITE => Protection::ReadWrite,
PAGE_EXECUTE => Protection::Execute,
PAGE_EXECUTE_READ => Protection::ExecuteRead,
PAGE_EXECUTE_READWRITE => Protection::ExecuteReadWrite,
_ => Protection::NoAccess,
};
Ok(old)
}
pub fn lock(&self, offset: usize, size: usize) -> Result<()> {
if offset + size > self.size {
return Err(Error::custom("Lock region exceeds allocation size"));
}
unsafe {
VirtualLock(self.ptr.as_ptr().add(offset) as *const _, size)?;
}
Ok(())
}
pub fn unlock(&self, offset: usize, size: usize) -> Result<()> {
if offset + size > self.size {
return Err(Error::custom("Unlock region exceeds allocation size"));
}
unsafe {
VirtualUnlock(self.ptr.as_ptr().add(offset) as *const _, size)?;
}
Ok(())
}
pub fn as_ptr(&self) -> *mut u8 {
self.ptr.as_ptr()
}
pub fn size(&self) -> usize {
self.size
}
pub unsafe fn as_slice(&self) -> &[u8] {
std::slice::from_raw_parts(self.ptr.as_ptr(), self.size)
}
pub unsafe fn as_mut_slice(&mut self) -> &mut [u8] {
std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size)
}
}
impl Drop for VirtualMemory {
    fn drop(&mut self) {
        // MEM_RELEASE frees the whole reservation; it takes the base address
        // that VirtualAlloc returned and a size of 0. A failure cannot be
        // reported from `drop`, so the result is deliberately discarded.
        // SAFETY: `self.ptr` is the base of a region this value owns exclusively.
        let _released = unsafe { VirtualFree(self.ptr.as_ptr().cast(), 0, MEM_RELEASE) };
    }
}
/// Description of a virtual-memory region, as returned by `query_memory`.
#[derive(Debug)]
pub struct MemoryInfo {
    /// Base address of the region containing the queried address
    /// (`MEMORY_BASIC_INFORMATION::BaseAddress`).
    pub base_address: *mut u8,
    /// Size of the region in bytes (`MEMORY_BASIC_INFORMATION::RegionSize`).
    pub region_size: usize,
    /// Region protection. Values outside the six `Protection` variants are
    /// reported as `Protection::NoAccess`.
    pub protection: Protection,
    /// `true` when the region's state is `MEM_COMMIT`.
    pub is_committed: bool,
    /// `true` when the region's state is `MEM_RESERVE`. NOTE(review): per the
    /// Win32 docs, State holds a single value, so this presumably means
    /// "reserved but not committed" — confirm.
    pub is_reserved: bool,
    /// `true` when the region's state is `MEM_FREE` (0x10000).
    pub is_free: bool,
}
/// Queries the virtual-memory region that contains `address` via `VirtualQuery`.
///
/// Protection flags other than the six values `Protection` models (for
/// example ones carrying modifier bits) are reported as `Protection::NoAccess`.
///
/// # Errors
/// Returns the thread's last Win32 error if `VirtualQuery` writes no data.
pub fn query_memory(address: *const u8) -> Result<MemoryInfo> {
    // Value of `MEM_FREE` in `MEMORY_BASIC_INFORMATION::State`.
    const MEM_FREE_STATE: u32 = 0x10000;

    let mut info = MEMORY_BASIC_INFORMATION::default();
    // SAFETY: `info` is a writable MEMORY_BASIC_INFORMATION and the length
    // passed matches its size; VirtualQuery does not dereference `address`.
    let written = unsafe {
        VirtualQuery(
            Some(address as *const _),
            &mut info,
            std::mem::size_of::<MEMORY_BASIC_INFORMATION>(),
        )
    };
    if written == 0 {
        return Err(crate::error::last_error());
    }

    let protection = match info.Protect {
        PAGE_NOACCESS => Protection::NoAccess,
        PAGE_READONLY => Protection::ReadOnly,
        PAGE_READWRITE => Protection::ReadWrite,
        PAGE_EXECUTE => Protection::Execute,
        PAGE_EXECUTE_READ => Protection::ExecuteRead,
        PAGE_EXECUTE_READWRITE => Protection::ExecuteReadWrite,
        _ => Protection::NoAccess,
    };

    // `State` holds exactly one of MEM_COMMIT / MEM_RESERVE / MEM_FREE, so
    // compare for equality rather than testing individual bits.
    let state = info.State.0;
    Ok(MemoryInfo {
        base_address: info.BaseAddress as *mut u8,
        region_size: info.RegionSize,
        protection,
        is_committed: state == MEM_COMMIT.0,
        is_reserved: state == MEM_RESERVE.0,
        is_free: state == MEM_FREE_STATE,
    })
}
/// A Win32 heap: either one created (and owned) by this wrapper, or the
/// borrowed process heap.
#[derive(Debug)]
pub struct Heap {
    /// Heap handle passed to the `Heap*` functions.
    handle: HANDLE,
    /// Whether `Drop` calls `HeapDestroy` on `handle`; `false` for the process heap.
    owned: bool,
}
impl Heap {
pub fn new() -> Result<Self> {
let handle = unsafe { HeapCreate(HEAP_NONE, 0, 0)? };
Ok(Self {
handle,
owned: true,
})
}
pub fn with_size(initial_size: usize, max_size: usize) -> Result<Self> {
let handle = unsafe { HeapCreate(HEAP_NONE, initial_size, max_size)? };
Ok(Self {
handle,
owned: true,
})
}
pub fn process_heap() -> Result<Self> {
let handle = unsafe { GetProcessHeap()? };
Ok(Self {
handle,
owned: false, })
}
pub fn alloc(&self, size: usize) -> Result<NonNull<u8>> {
let ptr = unsafe { HeapAlloc(self.handle, HEAP_NONE, size) };
if ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(NonNull::new(ptr as *mut u8).unwrap())
}
pub fn alloc_zeroed(&self, size: usize) -> Result<NonNull<u8>> {
use windows::Win32::System::Memory::HEAP_ZERO_MEMORY;
let ptr = unsafe { HeapAlloc(self.handle, HEAP_ZERO_MEMORY, size) };
if ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(NonNull::new(ptr as *mut u8).unwrap())
}
pub unsafe fn realloc(&self, ptr: NonNull<u8>, new_size: usize) -> Result<NonNull<u8>> {
let new_ptr = HeapReAlloc(
self.handle,
HEAP_NONE,
Some(ptr.as_ptr() as *const _),
new_size,
);
if new_ptr.is_null() {
return Err(crate::error::last_error());
}
Ok(NonNull::new(new_ptr as *mut u8).unwrap())
}
pub unsafe fn free(&self, ptr: NonNull<u8>) -> Result<()> {
HeapFree(self.handle, HEAP_NONE, Some(ptr.as_ptr() as *const _))?;
Ok(())
}
pub unsafe fn size(&self, ptr: NonNull<u8>) -> Result<usize> {
let size = HeapSize(self.handle, HEAP_NONE, ptr.as_ptr() as *const _);
if size == usize::MAX {
return Err(crate::error::last_error());
}
Ok(size)
}
}
impl Drop for Heap {
    fn drop(&mut self) {
        // The process heap is only borrowed and must never be destroyed.
        if !self.owned {
            return;
        }
        // SAFETY: `self.handle` came from HeapCreate and is owned by this value.
        // A failure cannot be reported from `drop`, so the result is ignored.
        let _destroyed = unsafe { HeapDestroy(self.handle) };
    }
}
/// Snapshot of system memory usage, copied field-by-field from the Win32
/// `MEMORYSTATUSEX` structure filled in by `memory_status`.
#[derive(Debug, Clone)]
pub struct MemoryStatus {
    /// `dwMemoryLoad`: approximate percentage of physical memory in use.
    pub memory_load: u32,
    /// `ullTotalPhys` (bytes, per the Win32 docs).
    pub total_physical: u64,
    /// `ullAvailPhys` (bytes, per the Win32 docs).
    pub available_physical: u64,
    /// `ullTotalPageFile`: commit limit (bytes, per the Win32 docs).
    pub total_page_file: u64,
    /// `ullAvailPageFile` (bytes, per the Win32 docs).
    pub available_page_file: u64,
    /// `ullTotalVirtual`: user-mode virtual address space of this process.
    pub total_virtual: u64,
    /// `ullAvailVirtual`: unreserved, uncommitted part of that address space.
    pub available_virtual: u64,
}
pub fn memory_status() -> Result<MemoryStatus> {
let mut status = MEMORYSTATUSEX {
dwLength: std::mem::size_of::<MEMORYSTATUSEX>() as u32,
..Default::default()
};
unsafe {
GlobalMemoryStatusEx(&mut status)?;
}
Ok(MemoryStatus {
memory_load: status.dwMemoryLoad,
total_physical: status.ullTotalPhys,
available_physical: status.ullAvailPhys,
total_page_file: status.ullTotalPageFile,
available_page_file: status.ullAvailPageFile,
total_virtual: status.ullTotalVirtual,
available_virtual: status.ullAvailVirtual,
})
}
/// Memory-related system parameters, copied from the Win32 `SYSTEM_INFO`
/// structure filled in by `system_info`.
#[derive(Debug, Clone)]
pub struct SystemMemoryInfo {
    /// `dwPageSize`: page size in bytes.
    pub page_size: u32,
    /// `dwAllocationGranularity`: granularity of `VirtualAlloc` base addresses.
    pub allocation_granularity: u32,
    /// `lpMinimumApplicationAddress`: lowest address available to applications.
    pub minimum_address: *const u8,
    /// `lpMaximumApplicationAddress`: highest address available to applications.
    pub maximum_address: *const u8,
    /// `dwNumberOfProcessors`.
    pub processor_count: u32,
}
/// Returns memory-related system parameters obtained from `GetSystemInfo`.
pub fn system_info() -> SystemMemoryInfo {
    let raw = {
        let mut buffer = SYSTEM_INFO::default();
        // SAFETY: GetSystemInfo only writes into the provided structure.
        unsafe { GetSystemInfo(&mut buffer) };
        buffer
    };

    SystemMemoryInfo {
        page_size: raw.dwPageSize,
        allocation_granularity: raw.dwAllocationGranularity,
        minimum_address: raw.lpMinimumApplicationAddress as *const u8,
        maximum_address: raw.lpMaximumApplicationAddress as *const u8,
        processor_count: raw.dwNumberOfProcessors,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One 4 KiB page, the unit most tests below allocate.
    const PAGE: usize = 4096;

    #[test]
    fn test_virtual_memory_alloc() {
        let mut region = VirtualMemory::alloc(PAGE, Protection::ReadWrite).unwrap();
        assert!(!region.as_ptr().is_null());
        assert_eq!(region.size(), PAGE);

        // SAFETY: `alloc` committed the whole region read/write.
        let bytes = unsafe { region.as_mut_slice() };
        bytes[..2].copy_from_slice(&[42, 43]);
        assert_eq!(bytes[0], 42);
        assert_eq!(bytes[1], 43);
    }

    #[test]
    fn test_virtual_memory_reserve_commit() {
        let region = VirtualMemory::reserve(16 * PAGE).unwrap();
        region.commit(0, PAGE, Protection::ReadWrite).unwrap();

        let first = region.as_ptr();
        // SAFETY: the first page was just committed read/write.
        unsafe {
            first.write(42);
            assert_eq!(first.read(), 42);
        }
    }

    #[test]
    fn test_heap() {
        let heap = Heap::new().unwrap();
        let block = heap.alloc(1024).unwrap();

        // SAFETY: `block` is a live 1024-byte allocation from `heap`, freed last.
        unsafe {
            block.as_ptr().write(42);
            assert_eq!(block.as_ptr().read(), 42);
            assert!(heap.size(block).unwrap() >= 1024);
            heap.free(block).unwrap();
        }
    }

    #[test]
    fn test_memory_status() {
        let snapshot = memory_status().unwrap();
        assert!(snapshot.total_physical > 0);
        assert!(snapshot.available_physical > 0);
        assert!(snapshot.memory_load <= 100);
    }

    #[test]
    fn test_system_info() {
        let sys = system_info();
        assert!(sys.page_size > 0);
        assert!(sys.processor_count > 0);
    }

    #[test]
    fn test_query_memory() {
        let region = VirtualMemory::alloc(PAGE, Protection::ReadWrite).unwrap();
        let details = query_memory(region.as_ptr()).unwrap();
        assert!(details.is_committed);
        assert!(!details.is_free);
        assert!(details.region_size >= PAGE);
    }
}