//! retour-utils 0.2.1
//!
//! Utility crate for creating hooks with `retour`.
pub mod error;

use std::iter;

pub use error::Error;
/// Macro used to hook multiple `retour::StaticDetour`s at once.
///
/// Reads a `mod` block and generates a static detour for every function annotated with
/// `#[hook]`. `#[hook]` is not a macro of its own that gets expanded; it is read and removed
/// by [`hook_module`] while it processes the module.
///
/// # `#[hook]` syntax
/// Hook based on an exported symbol name:
/// ```text
/// #[hook(<unsafe> <extern> <ABI> DETOUR_NAME, symbol = "SYMBOL_NAME")]
/// ```
/// Hook based on an offset:
/// ```text
/// #[hook(<unsafe> <extern> <ABI> DETOUR_NAME, offset = 0xDEAD_BEEF)]
/// ```
/// Examples:
/// ```ignore
/// #[hook(unsafe extern "system" MessageBoxA_Detour, symbol = "MessageBoxA")]
/// fn messageboxa_detour(hwnd: HWND, text: PCSTR, _caption: PCSTR, u_type: u32) -> i32 { ... }
///
/// #[hook(Add_Detour, symbol = "add")]
/// fn add(a: i32, b: i32) -> i32 { ... }
///
/// #[hook(lua_newstate_Detour, offset = 0x4321)]
/// fn newstate(f: *mut lua_Alloc, ud: *mut std::ffi::c_void) -> *mut lua_State { ... }
/// ```
///
/// # Example
/// ```ignore
/// use retour_utils::hook_module;
///
/// #[hook_module("lua52.dll")]
/// mod lua {
///     // Creates a StaticDetour called Lua_newstate with the same function type as our function
///     // (minus abi/unsafe to work with retour crate)
///     #[hook(unsafe extern "C" Lua_newstate, symbol = "Lua_newstate")]
///     pub fn newstate(f: *mut lua_Alloc, ud: *mut std::ffi::c_void) -> *mut lua_State {
///         unsafe {
///             Lua_newstate.call(f, ud)
///         }
///     }
///
///     // == Generated by macro: ==
///     // #[hook_module] will create `MODULE_NAME`:
///     const MODULE_NAME: &str = "lua52.dll";
///     // and an `init_detours` function
///     pub unsafe fn init_detours() -> crate::Result<()> { .. }
///     // which will initialize all the StaticDetours generated by the macro inside this module
/// }
/// ```
pub use retour_utils_impl::hook_module;
use windows::{
    core::{PCSTR, PCWSTR},
    Win32::{
        Foundation::HMODULE,
        System::LibraryLoader::{GetModuleHandleW, GetProcAddress},
    },
};

type Result<T> = std::result::Result<T, error::Error>;

/// Describes where to find the function a detour should be attached to.
///
/// Both variants name the module (DLL) the function lives in; they differ only in how the
/// function's address is resolved inside that module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LookupData {
    /// The function lives `offset` bytes past the module's base address.
    Offset {
        /// Module name, looked up with `GetModuleHandleW`.
        module: &'static str,
        /// Byte offset from the module's base address.
        offset: usize,
    },
    /// The function is resolved by name with `GetProcAddress`.
    Symbol {
        /// Module name, looked up with `GetModuleHandleW`.
        module: &'static str,
        /// Name of the symbol to resolve.
        symbol: &'static str,
    },
}

impl LookupData {
    /// Creates a lookup for the function `offset` bytes past the base address of `module`.
    pub const fn from_offset(module: &'static str, offset: usize) -> Self {
        Self::Offset { module, offset }
    }

    /// Creates a lookup that resolves `symbol` by name inside `module`.
    pub const fn from_symbol(module: &'static str, symbol: &'static str) -> Self {
        Self::Symbol { module, symbol }
    }

    /// Name of the module (DLL) this lookup is relative to.
    fn get_module(&self) -> &'static str {
        match *self {
            Self::Offset { module, .. } | Self::Symbol { module, .. } => module,
        }
    }

    /// From a Windows module handle, get the address of the target function in memory.
    ///
    /// Returns `None` when:
    /// - `base + offset` would overflow, or
    /// - the symbol name contains an interior NUL byte, or
    /// - `GetProcAddress` cannot find the symbol.
    #[cfg(windows)]
    fn address_from_handle(&self, handle: HMODULE) -> Option<*const ()> {
        use std::ffi::CString;

        match *self {
            Self::Offset { offset, .. } => {
                // On Windows, the module handle is the start address of the library,
                // so the target is base + offset. `checked_add` turns an out-of-range
                // offset into `None` instead of a debug-build overflow panic (or a
                // silently wrapped address in release builds).
                let base = handle.0 as usize;
                base.checked_add(offset).map(|addr| addr as *const ())
            }
            Self::Symbol { symbol, .. } => {
                // GetProcAddress needs a NUL-terminated narrow string. `c_symbol` must stay
                // alive until the call returns because `wrapped_ptr` points into it.
                let c_symbol = CString::new(symbol).ok()?;
                let wrapped_ptr = PCSTR::from_raw(c_symbol.as_ptr().cast::<u8>());
                // SAFETY: `wrapped_ptr` points at the live, NUL-terminated `c_symbol` buffer;
                // `handle` is assumed to be a valid module handle supplied by the caller.
                unsafe { GetProcAddress(handle, wrapped_ptr) }.map(|func_ptr| func_ptr as *const ())
            }
        }
    }
}

/// Initialize detour by passing the address of original function to `init_detour_fn`
/// 
/// This is called by `init_detours`, which is generated by the [`hook_module`] macro
/// 
/// ## Support
/// Only works on Windows by calling `GetModuleHandleW` to get the process' handle
pub unsafe fn init_detour(
    lookup_data: LookupData,
    init_detour_fn: fn(*const ()) -> retour::Result<()>,
) -> Result<()> {
    let module = lookup_data.get_module().to_string();
    let module_w_ptr = module
        .encode_utf16()
        .chain(iter::once(0))
        .collect::<Vec<u16>>()
        .as_ptr();
    let wrapped_ptr = PCWSTR::from_raw(module_w_ptr);

    // Get handle to module (aka dll)
    if let Ok(handle) = unsafe { GetModuleHandleW(wrapped_ptr) } {
        let Some(addr) = lookup_data.address_from_handle(handle) else {
            return Err(Error::ModuleNotLoaded);
        };
        init_detour_fn(addr)?;
    }
    Ok(())
}