Skip to main content

memexec/peloader/
mod.rs

1pub mod def;
2pub mod error;
3pub mod winapi;
4
5use crate::peparser::def::*;
6use crate::peparser::PE;
7use def::{DllMain, DLL_PROCESS_ATTACH, MEM_COMMIT, MEM_RESERVE};
8use error::{Error, Result};
9use std::ffi::CStr;
10use std::mem;
11use std::os::raw::c_void;
12use std::ptr;
13
14#[cfg(feature = "hook")]
15use hook::{ProcDesc, Thunk};
16#[cfg(feature = "hook")]
17use std::collections::HashMap;
18#[cfg(feature = "hook")]
19pub mod hook;
20
21unsafe fn patch_reloc_table(pe: &PE, base_addr: *const c_void) -> Result<()> {
22    let reloc_entry = &pe.pe_header.nt_header.get_data_directory()[IMAGE_DIRECTORY_ENTRY_BASERELOC];
23    let image_base_offset = base_addr as isize - pe.pe_header.nt_header.get_image_base() as isize;
24
25    if image_base_offset != 0 && reloc_entry.VirtualAddress != 0 && reloc_entry.Size != 0 {
26        let mut reloc_table_ptr =
27            base_addr.offset(reloc_entry.VirtualAddress as isize) as *const u8;
28
29        loop {
30            let reloc_block =
31                &*mem::transmute::<*const u8, *const IMAGE_BASE_RELOCATION>(reloc_table_ptr);
32            if reloc_block.SizeOfBlock == 0 && reloc_block.VirtualAddress == 0 {
33                break;
34            }
35
36            for i in 0..(reloc_block.SizeOfBlock as isize - 8) / 2 {
37                let item = *(reloc_table_ptr.offset(8 + i * 2) as *const u16);
38                if (item >> 12) == IMAGE_REL_BASED {
39                    let patch_addr = base_addr
40                        .offset(reloc_block.VirtualAddress as isize + (item & 0xfff) as isize)
41                        as *mut isize;
42                    *patch_addr = *patch_addr + image_base_offset;
43                }
44            }
45
46            reloc_table_ptr = reloc_table_ptr.offset(reloc_block.SizeOfBlock as isize);
47        }
48    }
49    Ok(())
50}
51
52unsafe fn resolve_import_symbols(
53    pe: &PE,
54    base_addr: *const c_void,
55    #[cfg(feature = "hook")] hooks: Option<&HashMap<ProcDesc, *const c_void>>,
56) -> Result<()> {
57    let import_entry = &pe.pe_header.nt_header.get_data_directory()[IMAGE_DIRECTORY_ENTRY_IMPORT];
58    if import_entry.Size != 0 && import_entry.VirtualAddress != 0 {
59        let mut import_desc_ptr = base_addr.offset(import_entry.VirtualAddress as isize)
60            as *const IMAGE_IMPORT_DESCRIPTOR;
61        loop {
62            let import_desc = &*import_desc_ptr;
63            if 0 == import_desc.Name
64                && 0 == import_desc.FirstThunk
65                && 0 == import_desc.OriginalFirstThunk
66                && 0 == import_desc.TimeDateStamp
67                && 0 == import_desc.ForwarderChain
68            {
69                break;
70            }
71
72            let dll_name = CStr::from_ptr(base_addr.offset(import_desc.Name as isize) as *const i8)
73                .to_str()?;
74            // TODO: implement loading module by calling self recursively
75            let hmod = winapi::load_library(&dll_name)?;
76
77            // Whether the ILT (called INT in IDA) exists? (some linkers didn't generate the ILT)
78            let (mut iat_ptr, mut ilt_ptr) = if import_desc.OriginalFirstThunk != 0 {
79                (
80                    base_addr.offset(import_desc.FirstThunk as isize) as *mut IMAGE_THUNK_DATA,
81                    base_addr.offset(import_desc.OriginalFirstThunk as isize)
82                        as *const IMAGE_THUNK_DATA,
83                )
84            } else {
85                (
86                    base_addr.offset(import_desc.FirstThunk as isize) as *mut IMAGE_THUNK_DATA,
87                    base_addr.offset(import_desc.FirstThunk as isize) as *const IMAGE_THUNK_DATA,
88                )
89            };
90
91            loop {
92                let thunk_data = *ilt_ptr as isize;
93                if thunk_data == 0 {
94                    break;
95                }
96
97                if thunk_data & IMAGE_ORDINAL_FLAG != 0 {
98                    // Import by ordinal number
99
100                    #[cfg(not(feature = "hook"))]
101                    let proc_addr = winapi::get_proc_address_by_ordinal(hmod, thunk_data & 0xffff)?;
102                    #[cfg(feature = "hook")]
103                    let proc_addr = match hooks {
104                        Some(hooks) => {
105                            match hooks.get(&ProcDesc::new(
106                                dll_name,
107                                Thunk::Ordinal(thunk_data & 0xffff),
108                            )) {
109                                Some(addr) => *addr,
110                                None => {
111                                    winapi::get_proc_address_by_ordinal(hmod, thunk_data & 0xffff)?
112                                }
113                            }
114                        }
115                        None => winapi::get_proc_address_by_ordinal(hmod, thunk_data & 0xffff)?,
116                    };
117
118                    // rust-lang/rust/issues/15701
119                    *iat_ptr = proc_addr as IMAGE_THUNK_DATA;
120                } else {
121                    // TODO: implement resolving proc address by `IMAGE_IMPORT_BY_NAME.Hint`
122                    let hint_name_table = &*mem::transmute::<
123                        *const c_void,
124                        *const IMAGE_IMPORT_BY_NAME,
125                    >(base_addr.offset(thunk_data));
126                    if 0 == hint_name_table.Name {
127                        break;
128                    }
129
130                    #[cfg(not(feature = "hook"))]
131                    let proc_addr = winapi::get_proc_address_by_name(
132                        hmod,
133                        CStr::from_ptr(&hint_name_table.Name as _).to_str()?,
134                    )?;
135                    #[cfg(feature = "hook")]
136                    let proc_addr = match hooks {
137                        Some(hooks) => match hooks.get(&ProcDesc::new(
138                            dll_name,
139                            Thunk::Name(CStr::from_ptr(&hint_name_table.Name as _).to_str()?),
140                        )) {
141                            Some(addr) => *addr,
142                            None => winapi::get_proc_address_by_name(
143                                hmod,
144                                CStr::from_ptr(&hint_name_table.Name as _).to_str()?,
145                            )?,
146                        },
147                        None => winapi::get_proc_address_by_name(
148                            hmod,
149                            CStr::from_ptr(&hint_name_table.Name as _).to_str()?,
150                        )?,
151                    };
152
153                    *iat_ptr = proc_addr as IMAGE_THUNK_DATA;
154                }
155
156                iat_ptr = iat_ptr.offset(1);
157                ilt_ptr = ilt_ptr.offset(1);
158            }
159
160            import_desc_ptr = import_desc_ptr.offset(1);
161        }
162    }
163    Ok(())
164}
165
166unsafe fn call_tls_callback(pe: &PE, base_addr: *const c_void) -> Result<()> {
167    let tls_entry = &pe.pe_header.nt_header.get_data_directory()[IMAGE_DIRECTORY_ENTRY_TLS];
168    if tls_entry.Size != 0 && tls_entry.VirtualAddress != 0 {
169        let tls = &*mem::transmute::<*const c_void, *const IMAGE_TLS_DIRECTORY>(
170            base_addr.offset(tls_entry.VirtualAddress as isize),
171        );
172        let mut tls_callback_addr = tls.AddressOfCallBacks as *const *const c_void;
173
174        loop {
175            if *tls_callback_addr == 0 as _ {
176                break;
177            }
178
179            mem::transmute::<*const c_void, PIMAGE_TLS_CALLBACK>(*tls_callback_addr)(
180                base_addr,
181                DLL_PROCESS_ATTACH,
182                0 as _,
183            );
184            tls_callback_addr = tls_callback_addr.offset(1);
185        }
186    }
187    Ok(())
188}
189
190unsafe fn load_pe_into_mem(
191    pe: &PE,
192    #[cfg(feature = "hook")] hooks: Option<&HashMap<ProcDesc, *const c_void>>,
193) -> Result<*const c_void> {
194    // Step1: allocate memory for image
195    let mut base_addr = pe.pe_header.nt_header.get_image_base();
196    let size = pe.pe_header.nt_header.get_size_of_image();
197
198    // ASLR
199    if winapi::nt_alloc_vm(
200        &base_addr as _,
201        &size as _,
202        MEM_RESERVE | MEM_COMMIT,
203        PAGE_READWRITE,
204    )
205    .is_err()
206    {
207        base_addr = 0 as *const c_void;
208        winapi::nt_alloc_vm(
209            &base_addr as _,
210            &size as _,
211            MEM_RESERVE | MEM_COMMIT,
212            PAGE_READWRITE,
213        )?;
214    }
215
216    // Step2: copy sections
217    for section in pe.section_area.section_table {
218        ptr::copy_nonoverlapping(
219            pe.raw.as_ptr().offset(section.PointerToRawData as isize),
220            base_addr.offset(section.VirtualAddress as isize) as *mut u8,
221            section.SizeOfRawData as usize,
222        );
223    }
224
225    // Step3: handle base relocataion table
226    patch_reloc_table(pe, base_addr)?;
227
228    // Step4: resolve import symbols
229    #[cfg(feature = "hook")]
230    resolve_import_symbols(pe, base_addr, hooks)?;
231    #[cfg(not(feature = "hook"))]
232    resolve_import_symbols(pe, base_addr)?;
233
234    // Step5: restore sections' protection
235    for section in pe.section_area.section_table {
236        let size = section.SizeOfRawData as usize;
237        if size == 0 {
238            continue;
239        }
240
241        winapi::nt_protect_vm(
242            &(base_addr.offset(section.VirtualAddress as isize)) as _,
243            &size as _,
244            section.get_protection(),
245        )?;
246    }
247
248    // Step6: call TLS callback
249    call_tls_callback(pe, base_addr)?;
250
251    Ok(base_addr)
252}
253
254fn check_platform(pe: &PE) -> Result<()> {
255    if (mem::size_of::<usize>() == 4 && pe.is_x86())
256        || (mem::size_of::<usize>() == 8 && pe.is_x64())
257    {
258        Ok(())
259    } else {
260        Err(Error::MismatchedArch)
261    }
262}
263
264pub struct ExeLoader {
265    entry_point_va: *const c_void,
266}
267
268impl ExeLoader {
269    pub unsafe fn new(
270        pe: &PE,
271        #[cfg(feature = "hook")] hooks: Option<&HashMap<ProcDesc, *const c_void>>,
272    ) -> Result<ExeLoader> {
273        check_platform(pe)?;
274        if pe.is_dll() {
275            return Err(Error::MismatchedLoader);
276        }
277
278        if pe.is_dot_net() {
279            return Err(Error::UnsupportedDotNetExecutable);
280        }
281
282        let entry_point = pe.pe_header.nt_header.get_address_of_entry_point();
283        if entry_point == 0 {
284            Err(Error::NoEntryPoint)
285        } else {
286            #[cfg(feature = "hook")]
287            let entry_point_va = load_pe_into_mem(pe, hooks)?.offset(entry_point);
288            #[cfg(not(feature = "hook"))]
289            let entry_point_va = load_pe_into_mem(pe)?.offset(entry_point);
290            Ok(ExeLoader {
291                entry_point_va: entry_point_va,
292            })
293        }
294    }
295
296    pub unsafe fn invoke_entry_point(&self) {
297        mem::transmute::<*const c_void, extern "system" fn()>(self.entry_point_va)()
298    }
299}
300
301pub struct DllLoader {
302    entry_point_va: *const c_void,
303}
304
305impl DllLoader {
306    pub unsafe fn new(
307        pe: &PE,
308        #[cfg(feature = "hook")] hooks: Option<&HashMap<ProcDesc, *const c_void>>,
309    ) -> Result<DllLoader> {
310        check_platform(pe)?;
311        if !pe.is_dll() {
312            return Err(Error::MismatchedLoader);
313        }
314
315        if pe.is_dot_net() {
316            return Err(Error::UnsupportedDotNetExecutable);
317        }
318
319        let entry_point = pe.pe_header.nt_header.get_address_of_entry_point();
320        if entry_point == 0 {
321            Err(Error::NoEntryPoint)
322        } else {
323            #[cfg(feature = "hook")]
324            let entry_point_va = load_pe_into_mem(pe, hooks)?.offset(entry_point);
325            #[cfg(not(feature = "hook"))]
326            let entry_point_va = load_pe_into_mem(pe)?.offset(entry_point);
327            Ok(DllLoader {
328                entry_point_va: entry_point_va,
329            })
330        }
331    }
332
333    pub unsafe fn invoke_entry_point(
334        &self,
335        hmod: *const c_void,
336        reason_for_call: u32,
337        lp_reserved: *const c_void,
338    ) -> bool {
339        mem::transmute::<*const c_void, DllMain>(self.entry_point_va)(
340            hmod,
341            reason_for_call,
342            lp_reserved,
343        )
344    }
345}