Skip to main content

ib_hook/process/
module.rs

/*!
Process module (EXE/DLL) utilities.
*/
4use std::{ffi::OsString, os::windows::ffi::OsStringExt, path::PathBuf};
5
6use derive_more::{Deref, From};
7use windows::{
8    Win32::{
9        Foundation::{HMODULE, MAX_PATH},
10        System::LibraryLoader::{
11            GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
12            GetModuleFileNameW, GetModuleHandleExW,
13        },
14    },
15    core::PCWSTR,
16};
17
18/// A process module (EXE/DLL).
19#[derive(Clone, Copy, From, Deref, Debug, Default, PartialEq, Eq)]
20#[repr(transparent)]
21pub struct Module(pub HMODULE);
22
23impl Module {
24    /// Get the handle of the current executable or DLL.
25    ///
26    /// Ref: https://github.com/compio-rs/winio/issues/35
27    pub fn current() -> Self {
28        let mut module = Module::default();
29        _ = unsafe {
30            GetModuleHandleExW(
31                GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS
32                    | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
33                PCWSTR(Module::current as *const _),
34                &mut module.0,
35            )
36        };
37        module
38    }
39
40    /// Get the file path of a module (EXE/DLL).
41    ///
42    /// [GetModuleFileNameW function (libloaderapi.h)](https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-getmodulefilenamew)
43    pub fn get_path(self) -> PathBuf {
44        let hmodule = Some(self.0);
45
46        let mut buf_stack = [0; MAX_PATH as usize];
47        let mut buf = buf_stack.as_mut_slice();
48        let result = unsafe { GetModuleFileNameW(hmodule, buf) };
49
50        let mut buf_heap;
51        let len = if result == 0 {
52            // Error occurred
53            0
54        } else if result == buf.len() as u32 {
55            // Buffer was too small (truncated), try with a larger buffer
56            // Extended path length
57            let mut size = 512;
58            loop {
59                buf_heap = vec![0; size];
60                buf = buf_heap.as_mut_slice();
61                let result = unsafe { GetModuleFileNameW(hmodule, buf) };
62                if result == 0 {
63                    break 0;
64                }
65                if result != size as u32 {
66                    // Success - result is the actual length
67                    break result as usize;
68                }
69                // Still truncated, try larger buffer
70                size *= 2;
71            }
72        } else {
73            // Success - result is the actual length
74            result as usize
75        };
76
77        let path_str = OsString::from_wide(&buf[..len]);
78        PathBuf::from(path_str)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// The path reported for the module containing this test must exist on disk.
    #[test]
    fn get_path() {
        let path = Module::current().get_path();
        println!("Current module path: {:?}", path);
        assert!(path.exists(), "Module path should exist: {:?}", path);
    }
}