use crate::memory::MemoryError;
use windows::Win32::Foundation::HANDLE;
use windows::Win32::System::Memory::{
VirtualProtectEx, PAGE_EXECUTE_READWRITE, PAGE_PROTECTION_FLAGS,
};
/// Copyable wrapper around a Win32 `HANDLE`.
///
/// The `Send`/`Sync` impls below allow the wrapped handle value to be moved to and
/// shared between threads, which a bare `HANDLE` may not permit.
#[derive(Copy, Clone)]
pub struct SendableHandle(pub HANDLE);

// SAFETY: the wrapper only stores and copies the raw handle value; it performs no
// operations of its own. NOTE(review): soundness relies on the handle only being
// passed to thread-agnostic Win32 calls (e.g. `VirtualProtectEx`), which accept a
// process handle from any thread — confirm no thread-affine handles are wrapped.
unsafe impl Sync for SendableHandle {}
unsafe impl Send for SendableHandle {}
/// Guard that remembers the protection a memory range had before
/// `ProtectionGuard::new` made it `PAGE_EXECUTE_READWRITE`.
///
/// The saved protection is re-applied by `restore` or, failing that, when the
/// guard is dropped. The guard must therefore stay alive for as long as the
/// range needs to remain writable.
#[must_use = "dropping the guard immediately restores the original memory protection"]
pub struct ProtectionGuard {
    /// Process handle passed to `VirtualProtectEx`.
    handle: SendableHandle,
    /// Start of the affected range; only handed to the OS, never dereferenced here.
    address: *mut std::ffi::c_void,
    /// Length of the affected range in bytes.
    size: usize,
    /// Protection reported by `VirtualProtectEx` before the change.
    old_protection: PAGE_PROTECTION_FLAGS,
}
impl ProtectionGuard {
    /// Makes `size` bytes at `address` in the process behind `handle`
    /// `PAGE_EXECUTE_READWRITE` and returns a guard holding the previous protection.
    ///
    /// The saved value is what `VirtualProtectEx` writes to its old-protection
    /// out-parameter. Per the Win32 documentation, that is the protection of the
    /// *first* page of the range. A range spanning pages with different protections
    /// is therefore restored uniformly to that first page's protection.
    ///
    /// # Errors
    ///
    /// Returns `MemoryError::WriteFailed` if `VirtualProtectEx` fails. The message
    /// includes the address, size and Windows error.
    pub fn new(handle: HANDLE, address: usize, size: usize) -> Result<Self, MemoryError> {
        let target = address as *mut std::ffi::c_void;
        let mut old_protection = PAGE_PROTECTION_FLAGS(0);
        // SAFETY: the only pointer VirtualProtectEx writes through is
        // `&mut old_protection`, a live local. `target`/`size` are handed to the
        // OS as-is and are never dereferenced by this code.
        let result = unsafe {
            VirtualProtectEx(handle, target, size, PAGE_EXECUTE_READWRITE, &mut old_protection)
        };
        result.map_err(|error_code| {
            MemoryError::WriteFailed(format!(
                "Failed to change memory protection at 0x{:X} (size: {} bytes). \
                 Windows error: {:?}. \
                 This may happen if the process has exited or the memory region is not accessible.",
                address, size, error_code
            ))
        })?;
        Ok(Self {
            handle: SendableHandle(handle),
            address: target,
            size,
            old_protection,
        })
    }

    /// Restores the saved protection now and reports whether that succeeded.
    ///
    /// On success the guard is consumed without running `Drop`, so the protection
    /// is not re-applied a second time. On failure the guard is dropped normally,
    /// and `Drop` makes one more best-effort attempt.
    ///
    /// # Errors
    ///
    /// Returns `MemoryError::WriteFailed` if `VirtualProtectEx` fails.
    pub fn restore(self) -> Result<(), MemoryError> {
        self.do_restore()?;
        // The guard owns no heap data or OS resources (only a copied handle value,
        // a raw address and plain integers), so forgetting it leaks nothing. It just
        // skips the redundant restore that `Drop` would otherwise perform.
        std::mem::forget(self);
        Ok(())
    }

    /// Re-applies the protection captured by `new` to the guarded range.
    fn do_restore(&self) -> Result<(), MemoryError> {
        let mut previous = PAGE_PROTECTION_FLAGS(0);
        // SAFETY: `&mut previous` is a live local. The guarded address/size are
        // passed to the OS unchanged and are not dereferenced here.
        let result = unsafe {
            VirtualProtectEx(
                self.handle.0,
                self.address,
                self.size,
                self.old_protection,
                &mut previous,
            )
        };
        result.map_err(|error_code| {
            MemoryError::WriteFailed(format!(
                "Failed to restore memory protection at 0x{:X} (size: {} bytes). \
                 Windows error: {:?}",
                self.address as usize, self.size, error_code
            ))
        })
    }
}
impl Drop for ProtectionGuard {
    /// Puts the saved protection back when the guard goes out of scope.
    ///
    /// `drop` cannot return an error, so a failed restore is deliberately ignored.
    /// Call `ProtectionGuard::restore` to observe the outcome instead.
    fn drop(&mut self) {
        let _ignored_failure = self.do_restore();
    }
}
/// Helpers for changing protection on several memory regions at once.
pub struct MemoryProtector;

impl MemoryProtector {
    /// Makes every `(address, size)` region in `regions` `PAGE_EXECUTE_READWRITE` in
    /// the process behind `handle`. Returns one guard per region, in the same order.
    ///
    /// If any region fails, the guards created so far are released last-to-first
    /// before the error is returned. Order matters when regions share a page: a later
    /// guard recorded the protection set by an earlier one, so it must be undone
    /// first, or the shared page ends up left `PAGE_EXECUTE_READWRITE`.
    ///
    /// The same applies to the returned `Vec`. Dropping a `Vec` does not release
    /// its elements in a guaranteed last-to-first order. When regions may overlap,
    /// pop guards off the end (e.g. `while let Some(g) = guards.pop() { ... }`).
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `ProtectionGuard::new`.
    pub fn protect_multiple(
        handle: HANDLE,
        regions: &[(usize, usize)],
    ) -> Result<Vec<ProtectionGuard>, MemoryError> {
        let mut guards = Vec::with_capacity(regions.len());
        for &(address, size) in regions {
            match ProtectionGuard::new(handle, address, size) {
                Ok(guard) => guards.push(guard),
                Err(err) => {
                    // Undo in reverse creation order; see the doc comment for why.
                    while let Some(guard) = guards.pop() {
                        drop(guard);
                    }
                    return Err(err);
                }
            }
        }
        Ok(guards)
    }
}
#[cfg(test)]
mod tests {
    /// Exclusive upper bound of user-mode addresses checked for 32-bit processes.
    const USER_SPACE_END_32: usize = 0x8000_0000;
    /// Exclusive upper bound of user-mode addresses checked for 64-bit processes.
    const USER_SPACE_END_64: usize = 0x0000_8000_0000_0000;

    /// Returns `true` when `address` lies below the user-mode ceiling for the
    /// given bitness.
    pub fn is_valid_user_address(address: usize, is_64bit: bool) -> bool {
        let ceiling = if is_64bit {
            USER_SPACE_END_64
        } else {
            USER_SPACE_END_32
        };
        address < ceiling
    }

    /// Widens `[address, address + size)` outward to whole pages. Returns the
    /// page-aligned start and the aligned length.
    ///
    /// The masking assumes `page_size` is a power of two.
    pub fn align_to_page(address: usize, size: usize, page_size: usize) -> (usize, usize) {
        let offset_mask = page_size - 1;
        let start = address & !offset_mask;
        let end = (address + size + page_size - 1) & !offset_mask;
        (start, end - start)
    }

    #[test]
    fn test_is_valid_user_address_32bit() {
        for addr in [0x0040_0000, 0x7FFF_FFFF] {
            assert!(is_valid_user_address(addr, false));
        }
        for addr in [0x8000_0000, 0xFFFF_FFFF] {
            assert!(!is_valid_user_address(addr, false));
        }
    }

    #[test]
    fn test_is_valid_user_address_64bit() {
        for addr in [0x0000_7FFF_0000_0000, 0x0000_7FFF_FFFF_FFFF] {
            assert!(is_valid_user_address(addr, true));
        }
        assert!(!is_valid_user_address(0xFFFF_8000_0000_0000, true));
    }

    #[test]
    fn test_align_to_page() {
        const PAGE: usize = 4096;
        let cases = [
            ((0x1000, 4096), (0x1000, 4096)),
            ((0x1005, 100), (0x1000, 4096)),
        ];
        for ((addr, len), expected) in cases {
            assert_eq!(align_to_page(addr, len, PAGE), expected);
        }
    }
}