Skip to main content

ib_hook/process/
module.rs

/*!
Process module (EXE/DLL) utilities.
*/
4#[cfg(feature = "std")]
5use std::{ffi::OsString, os::windows::ffi::OsStringExt, path::PathBuf};
6
7use derive_more::{Deref, From};
8use windows::{
9    Win32::{
10        Foundation::{HMODULE, MAX_PATH},
11        System::LibraryLoader::{
12            GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
13            GetModuleFileNameW, GetModuleHandleExW,
14        },
15    },
16    core::PCWSTR,
17};
18
19/// A process module (EXE/DLL).
20#[derive(Clone, Copy, From, Deref, Debug, Default, PartialEq, Eq)]
21#[repr(transparent)]
22pub struct Module(pub HMODULE);
23
24impl Module {
25    /// Get the handle of the current executable or DLL.
26    ///
27    /// Ref: https://github.com/compio-rs/winio/issues/35
28    pub fn current() -> Self {
29        let mut module = Module::default();
30        _ = unsafe {
31            GetModuleHandleExW(
32                GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS
33                    | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
34                PCWSTR(Module::current as *const _),
35                &mut module.0,
36            )
37        };
38        module
39    }
40
41    /// Get the file path of a module (EXE/DLL).
42    ///
43    /// [GetModuleFileNameW function (libloaderapi.h)](https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-getmodulefilenamew)
44    #[cfg(feature = "std")]
45    pub fn get_path(self) -> PathBuf {
46        let hmodule = Some(self.0);
47
48        let mut buf_stack = [0; MAX_PATH as usize];
49        let mut buf = buf_stack.as_mut_slice();
50        let result = unsafe { GetModuleFileNameW(hmodule, buf) };
51
52        let mut buf_heap;
53        let len = if result == 0 {
54            // Error occurred
55            0
56        } else if result == buf.len() as u32 {
57            // Buffer was too small (truncated), try with a larger buffer
58            // Extended path length
59            let mut size = 512;
60            loop {
61                buf_heap = vec![0; size];
62                buf = buf_heap.as_mut_slice();
63                let result = unsafe { GetModuleFileNameW(hmodule, buf) };
64                if result == 0 {
65                    break 0;
66                }
67                if result != size as u32 {
68                    // Success - result is the actual length
69                    break result as usize;
70                }
71                // Still truncated, try larger buffer
72                size *= 2;
73            }
74        } else {
75            // Success - result is the actual length
76            result as usize
77        };
78
79        let path_str = OsString::from_wide(&buf[..len]);
80        PathBuf::from(path_str)
81    }
82}
83
// `Module::get_path` (and `Path::exists`) are only available with the `std`
// feature. Gating the tests on it keeps `cargo test --no-default-features`
// compiling.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// The module containing the test code must resolve to a path that exists on disk.
    #[test]
    fn get_path() {
        let module = Module::current();
        let path = module.get_path();
        println!("Current module path: {:?}", path);
        assert!(path.exists(), "Module path should exist: {:?}", path);
    }
}