memory_rs/internal/
memory.rs

1use crate::error::{Error, ErrorType};
2use crate::wrap_winapi;
3use anyhow::{Context, Result};
4use std::ffi::c_void;
5use std::path::PathBuf;
6use std::ptr::copy_nonoverlapping;
7use windows_sys::Win32::System::Diagnostics::Debug::FlushInstructionCache;
8use windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW;
9use windows_sys::Win32::System::Memory::{
10    VirtualProtect, VirtualQuery, MEMORY_BASIC_INFORMATION, MEM_FREE, PAGE_EXECUTE_READWRITE,
11};
12use windows_sys::Win32::System::Threading::GetCurrentProcess;
13
/// Guard produced by [`MemProtect::new`]: it records a memory range together with the
/// protection that range had before it was changed, so the `Drop` impl can restore it.
pub struct MemProtect {
    /// First byte of the range whose protection was changed.
    addr: usize,
    /// Length of the range, in bytes.
    size: usize,
    /// Protection value reported by `VirtualProtect` before the change; reapplied on drop.
    prot: u32,
}
19
20/// Scoped VirtualProtect.
21/// # Safety
22/// The only unsafe bit is the VirtualProtect, which according to msdn
23/// it shouldn't have undefined behavior, so we wrap that function with an
24/// `try_winapi!` macro.
25impl MemProtect {
26    pub fn new(addr: usize, size: usize, prot: Option<u32>) -> Result<Self> {
27        let new_prot = prot.unwrap_or(PAGE_EXECUTE_READWRITE);
28
29        let mut old_prot = 0u32;
30
31        unsafe {
32            wrap_winapi!(
33                VirtualProtect(addr as *const c_void, size, new_prot, &mut old_prot),
34                x == 0
35            )?;
36        }
37
38        Ok(Self {
39            addr,
40            size,
41            prot: old_prot,
42        })
43    }
44}
45
46impl Drop for MemProtect {
47    fn drop(&mut self) {
48        let mut _prot = 0;
49        unsafe {
50            VirtualProtect(self.addr as _, self.size, self.prot, &mut _prot);
51        }
52    }
53}
54
/// A byte matcher: a predicate over a byte slice, plus the slice length it is meant for.
pub struct MemoryPattern {
    /// Number of bytes the predicate is meant to inspect.
    /// NOTE(review): not enforced by this type — presumably the scanner slices `size`
    /// bytes before calling `pattern`; confirm against callers.
    pub size: usize,
    /// Predicate that returns `true` when the given bytes match.
    pub pattern: fn(&[u8]) -> bool,
}
59
60impl MemoryPattern {
61    pub fn new(size: usize, pattern: fn(&[u8]) -> bool) -> Self {
62        MemoryPattern { size, pattern }
63    }
64
65    pub fn scan(&self, val: &[u8]) -> bool {
66        (self.pattern)(val)
67    }
68}
69
70/// Write an array of bytes to the desired ptr address.
71/// # Safety
72/// This function can cause the target program to crash due to
73/// incorrect writing, or it could simply make crash the software in case
74/// the virtual protect doesn't succeed.
75pub unsafe fn write_aob(ptr: usize, source: &[u8]) -> Result<()> {
76    let size = source.len();
77
78    let _mp = MemProtect::new(ptr, size, None)?;
79
80    copy_nonoverlapping(source.as_ptr(), ptr as *mut u8, size);
81
82    let ph = GetCurrentProcess();
83    FlushInstructionCache(ph, ptr as *const c_void, size);
84
85    Ok(())
86}
87
88/// Injects a jmp in the target address. The minimum length of it is 12 bytes.
89/// In case the space is bigger than 14 bytes, it'll inject a non-dirty
90/// trampoline, and will nop the rest of the instructions.
91/// # Safety
92/// this function is inherently unsafe since it does a lot of nasty stuff.
93pub unsafe fn hook_function(
94    original_function: usize,
95    new_function: usize,
96    new_function_end: Option<&mut usize>,
97    len: usize,
98) -> Result<()> {
99    assert!(len >= 12, "Not enough space to inject the shellcode");
100
101    let ph = GetCurrentProcess();
102
103    // Unprotect zone we'll write
104    let _mp = MemProtect::new(original_function, len, None)?;
105
106    let nops = vec![0x90; len];
107    write_aob(original_function, &nops).with_context(|| "Couldn't nop original bytes")?;
108
109    // Inject the jmp to the original function
110    // address as an AoB
111    let aob: [u8; std::mem::size_of::<usize>()] = new_function.to_le_bytes();
112
113    let injection = if len < 14 {
114        let mut v = vec![0x48, 0xb8];
115        v.extend_from_slice(&aob);
116        v.extend_from_slice(&[0xff, 0xe0]);
117        v
118    } else {
119        let mut v = if cfg!(target_arch = "x86_64") {
120            vec![0xff, 0x25, 0x00, 0x00, 0x00, 0x00]
121        } else {
122            let mut v = vec![0xFF, 0x25];
123            v.extend_from_slice(&(original_function + 6).to_le_bytes());
124            v
125        };
126        v.extend_from_slice(&aob);
127        v
128    };
129
130    write_aob(original_function, &injection)
131        .with_context(|| "Couldn't write the injection to the original function")?;
132
133    FlushInstructionCache(ph, original_function as *const c_void, injection.len());
134
135    // Inject the jmp back if required
136    if let Some(p) = new_function_end {
137        *p = original_function + len;
138    }
139
140    Ok(())
141}
142
143/// This function will use the WinAPI to check if the region to scan is valid.
144/// A region is not valid when it's free or when VirtualQuery returns an
145/// error at the moment of querying that region.
146pub fn check_valid_region(start_address: usize, len: usize) -> Result<()> {
147    if start_address == 0x0 {
148        return Err(Error::new(ErrorType::Internal, "start_address can't be 0".into()).into());
149    }
150
151    if len == 0x0 {
152        return Err(Error::new(ErrorType::Internal, "len can't be 0".into()).into());
153    }
154
155    let mut region_size = 0_usize;
156    let size_mem_inf = std::mem::size_of::<MEMORY_BASIC_INFORMATION>();
157
158    while region_size < len {
159        let mut information: MEMORY_BASIC_INFORMATION = unsafe { std::mem::zeroed() };
160        unsafe {
161            wrap_winapi!(
162                VirtualQuery(
163                    (start_address + region_size) as *const c_void,
164                    &mut information,
165                    size_mem_inf
166                ),
167                x == 0
168            )?;
169        }
170
171        if information.State == MEM_FREE {
172            return Err(Error::new(
173                ErrorType::Internal,
174                "The region to scan is invalid".to_string(),
175            )
176            .into());
177        }
178
179        region_size += information.RegionSize as usize;
180    }
181
182    Ok(())
183}
184
185/// Get DLL's parent path
186/// # Safety
187/// This function can fail on the
188/// GetModuleFileNameA, everything else is safe
189/// TODO: Find a way to test this one.
190pub unsafe fn resolve_module_path(lib: *const c_void) -> Result<PathBuf> {
191    let mut buf: Vec<u16> = vec![0x0; 255];
192
193    wrap_winapi!(GetModuleFileNameW(lib as _, buf.as_mut_ptr(), 255), x == 0)?;
194    let end_ix = buf
195        .iter()
196        .position(|&x| x == 0)
197        .expect("Invalid utf16 name");
198    let name = String::from_utf16(&buf[..end_ix]).unwrap();
199    let mut path: PathBuf = name.into();
200    path.pop();
201    Ok(path)
202}