// memory_rs/internal/memory.rs
use crate::error::{Error, ErrorType};
use crate::wrap_winapi;
use anyhow::{Context, Result};
use std::ffi::{c_void, OsString};
use std::os::windows::ffi::OsStringExt;
use std::path::PathBuf;
use std::ptr::copy_nonoverlapping;
use windows_sys::Win32::System::Diagnostics::Debug::FlushInstructionCache;
use windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW;
use windows_sys::Win32::System::Memory::{
    VirtualProtect, VirtualQuery, MEMORY_BASIC_INFORMATION, MEM_FREE, PAGE_EXECUTE_READWRITE,
};
use windows_sys::Win32::System::Threading::GetCurrentProcess;

14pub struct MemProtect {
15 addr: usize,
16 size: usize,
17 prot: u32,
18}
19
20impl MemProtect {
26 pub fn new(addr: usize, size: usize, prot: Option<u32>) -> Result<Self> {
27 let new_prot = prot.unwrap_or(PAGE_EXECUTE_READWRITE);
28
29 let mut old_prot = 0u32;
30
31 unsafe {
32 wrap_winapi!(
33 VirtualProtect(addr as *const c_void, size, new_prot, &mut old_prot),
34 x == 0
35 )?;
36 }
37
38 Ok(Self {
39 addr,
40 size,
41 prot: old_prot,
42 })
43 }
44}
45
46impl Drop for MemProtect {
47 fn drop(&mut self) {
48 let mut _prot = 0;
49 unsafe {
50 VirtualProtect(self.addr as _, self.size, self.prot, &mut _prot);
51 }
52 }
53}
54
/// A byte pattern described by a window length and a matching predicate.
pub struct MemoryPattern {
    /// Number of bytes the predicate is meant to inspect. Presumably this is the window size
    /// used by the scanner that consumes this pattern; confirm with the caller.
    pub size: usize,
    /// Predicate that returns `true` when the given bytes match the pattern.
    pub pattern: fn(&[u8]) -> bool,
}

impl MemoryPattern {
    /// Pairs a window length with its matching predicate.
    pub fn new(size: usize, pattern: fn(&[u8]) -> bool) -> Self {
        Self { pattern, size }
    }

    /// Applies the predicate to `val` and returns whether it matched.
    pub fn scan(&self, val: &[u8]) -> bool {
        let matches = self.pattern;
        matches(val)
    }
}

70pub unsafe fn write_aob(ptr: usize, source: &[u8]) -> Result<()> {
76 let size = source.len();
77
78 let _mp = MemProtect::new(ptr, size, None)?;
79
80 copy_nonoverlapping(source.as_ptr(), ptr as *mut u8, size);
81
82 let ph = GetCurrentProcess();
83 FlushInstructionCache(ph, ptr as *const c_void, size);
84
85 Ok(())
86}
87
88pub unsafe fn hook_function(
94 original_function: usize,
95 new_function: usize,
96 new_function_end: Option<&mut usize>,
97 len: usize,
98) -> Result<()> {
99 assert!(len >= 12, "Not enough space to inject the shellcode");
100
101 let ph = GetCurrentProcess();
102
103 let _mp = MemProtect::new(original_function, len, None)?;
105
106 let nops = vec![0x90; len];
107 write_aob(original_function, &nops).with_context(|| "Couldn't nop original bytes")?;
108
109 let aob: [u8; std::mem::size_of::<usize>()] = new_function.to_le_bytes();
112
113 let injection = if len < 14 {
114 let mut v = vec![0x48, 0xb8];
115 v.extend_from_slice(&aob);
116 v.extend_from_slice(&[0xff, 0xe0]);
117 v
118 } else {
119 let mut v = if cfg!(target_arch = "x86_64") {
120 vec![0xff, 0x25, 0x00, 0x00, 0x00, 0x00]
121 } else {
122 let mut v = vec![0xFF, 0x25];
123 v.extend_from_slice(&(original_function + 6).to_le_bytes());
124 v
125 };
126 v.extend_from_slice(&aob);
127 v
128 };
129
130 write_aob(original_function, &injection)
131 .with_context(|| "Couldn't write the injection to the original function")?;
132
133 FlushInstructionCache(ph, original_function as *const c_void, injection.len());
134
135 if let Some(p) = new_function_end {
137 *p = original_function + len;
138 }
139
140 Ok(())
141}
142
143pub fn check_valid_region(start_address: usize, len: usize) -> Result<()> {
147 if start_address == 0x0 {
148 return Err(Error::new(ErrorType::Internal, "start_address can't be 0".into()).into());
149 }
150
151 if len == 0x0 {
152 return Err(Error::new(ErrorType::Internal, "len can't be 0".into()).into());
153 }
154
155 let mut region_size = 0_usize;
156 let size_mem_inf = std::mem::size_of::<MEMORY_BASIC_INFORMATION>();
157
158 while region_size < len {
159 let mut information: MEMORY_BASIC_INFORMATION = unsafe { std::mem::zeroed() };
160 unsafe {
161 wrap_winapi!(
162 VirtualQuery(
163 (start_address + region_size) as *const c_void,
164 &mut information,
165 size_mem_inf
166 ),
167 x == 0
168 )?;
169 }
170
171 if information.State == MEM_FREE {
172 return Err(Error::new(
173 ErrorType::Internal,
174 "The region to scan is invalid".to_string(),
175 )
176 .into());
177 }
178
179 region_size += information.RegionSize as usize;
180 }
181
182 Ok(())
183}
184
185pub unsafe fn resolve_module_path(lib: *const c_void) -> Result<PathBuf> {
191 let mut buf: Vec<u16> = vec![0x0; 255];
192
193 wrap_winapi!(GetModuleFileNameW(lib as _, buf.as_mut_ptr(), 255), x == 0)?;
194 let end_ix = buf
195 .iter()
196 .position(|&x| x == 0)
197 .expect("Invalid utf16 name");
198 let name = String::from_utf16(&buf[..end_ix]).unwrap();
199 let mut path: PathBuf = name.into();
200 path.pop();
201 Ok(path)
202}