Skip to main content

ib_shell_item/hook/
display_name.rs

1use std::{cell::SyncUnsafeCell, ffi::c_void, mem::MaybeUninit, ptr};
2
3use bon::Builder;
4use serde::{Deserialize, Serialize};
5use tracing::{debug, error};
6use widestring::U16CStr;
7use windows::{
8    Win32::UI::Shell::SIGDN,
9    core::{HRESULT, PWSTR},
10};
11
12use crate::{ShellItemDisplayName, hook::HOOK_CONFIG, string};
13
/// Configuration for the `GetDisplayName` hook implemented in this module.
///
/// Both prefixes are sequences of UTF-16 code units. When one is set, `sh_get_display_name`
/// passes it to `string::prefix_u16cstr_ptr` together with the name returned by the original
/// `GetDisplayName` and hands the resulting string back to the caller instead. `None` leaves
/// the name unchanged.
#[derive(Default, Serialize, Deserialize, Clone, Builder, Debug)]
#[builder(on(Vec<u16>, into))]
pub struct DisplayNameHookConfig {
    /// Prefix used by the display-name branch of `sh_get_display_name`.
    ///
    /// Mainly for testing.
    display_prefix: Option<Vec<u16>>,
    /// Prefix used by the edit-name (`is_for_edit`) branch of `sh_get_display_name`.
    ///
    /// Mainly for testing.
    edit_prefix: Option<Vec<u16>>,
}
22
23pub(crate) type GetDisplayNameFn =
24    unsafe extern "system" fn(*mut c_void, SIGDN, *mut PWSTR) -> HRESULT;
25
/// The `GetDisplayName` implementation that `enable_hook` recorded (lazily initialized).
///
/// `None` until `enable_hook` records an implementation. `enable_hook` compares against this
/// value so that repeated calls for the same implementation are a no-op, and `disable_hook`
/// only unhooks while it is `Some`.
///
/// NOTE(review): read and written through raw pointers without any synchronization; this
/// presumably relies on `enable_hook`/`disable_hook` never running concurrently — TODO confirm.
pub(crate) static ORIGINAL_GET_DISPLAY_NAME: SyncUnsafeCell<Option<GetDisplayNameFn>> =
    SyncUnsafeCell::new(None);
/// Function-pointer slot handed to `SlimDetoursInlineHook` by `hook`.
///
/// Starts uninitialized; `enable_hook` writes the original implementation here before
/// installing the hook, and `sh_get_display_name` calls through it (after `assume_init`) to
/// reach the original code. It must not be read before `enable_hook` has written it.
///
/// NOTE(review): presumably SlimDetours rewrites this slot to point at its trampoline while the
/// hook is installed (Detours-style) — confirm against the SlimDetours docs.
pub(crate) static TRUE_GET_DISPLAY_NAME: SyncUnsafeCell<MaybeUninit<GetDisplayNameFn>> =
    SyncUnsafeCell::new(MaybeUninit::uninit());
31
32/// Hooked GetDisplayName function
33pub(crate) unsafe extern "system" fn sh_get_display_name(
34    this: *mut core::ffi::c_void,
35    sigdn_name: SIGDN,
36    ppsz_name: *mut windows::core::PWSTR,
37) -> HRESULT {
38    let true_get_display_name = TRUE_GET_DISPLAY_NAME.get();
39    let real = || unsafe { (*true_get_display_name).assume_init()(this, sigdn_name, ppsz_name) };
40
41    let config = HOOK_CONFIG.read().unwrap();
42    let Some(config) = &config.display_name else {
43        return real();
44    };
45
46    // Call original function
47    let result = real();
48
49    // Log the display name
50    if result.is_ok() {
51        let name = (unsafe { *ppsz_name }).0;
52        match ShellItemDisplayName::try_from(sigdn_name.0) {
53            Ok(ShellItemDisplayName::FileSystemPath) => (),
54            Ok(d) if d.is_for_display() | d.is_for_edit() => {
55                let name = unsafe { U16CStr::from_ptr_str(name) };
56                debug!(?d, ?name, "GetDisplayName for display");
57                if let Some(prefix) = &config.display_prefix {
58                    let new_name = string::prefix_u16cstr_ptr(name, prefix);
59                    unsafe { *ppsz_name = new_name };
60                }
61            }
62            Ok(d) if d.is_for_edit() => {
63                let name = unsafe { U16CStr::from_ptr_str(name) };
64                debug!(?d, ?name, "GetDisplayName for edit");
65                if let Some(prefix) = &config.edit_prefix {
66                    let new_name = string::prefix_u16cstr_ptr(name, prefix);
67                    unsafe { *ppsz_name = new_name };
68                }
69            }
70            Ok(d) => {
71                let name = unsafe { U16CStr::from_ptr_str(name) };
72                debug!(?d, ?name, "GetDisplayName for parse");
73            }
74            Err(_) => {
75                let name = unsafe { U16CStr::from_ptr_str(name) };
76                debug!(?name, "GetDisplayName unknown");
77            }
78        }
79    }
80
81    result
82}
83
84fn hook(enable: bool) -> windows::core::Result<()> {
85    let res = unsafe {
86        slim_detours_sys::SlimDetoursInlineHook(
87            enable as _,
88            TRUE_GET_DISPLAY_NAME.get().cast(),
89            sh_get_display_name as _,
90        )
91    };
92    windows::core::HRESULT(res).ok()
93}
94
95pub(crate) fn enable_hook(get_display_name: GetDisplayNameFn) -> windows::core::Result<()> {
96    match unsafe { *ORIGINAL_GET_DISPLAY_NAME.get() } {
97        Some(f) if ptr::fn_addr_eq(f, get_display_name) => Ok(()),
98        None => {
99            // Not yet hooked, store original and hook
100            unsafe { *ORIGINAL_GET_DISPLAY_NAME.get() = Some(get_display_name) };
101            unsafe { (*TRUE_GET_DISPLAY_NAME.get()).write(get_display_name) };
102            debug!(?get_display_name, "Hooking GetDisplayName");
103            hook(true)
104        }
105        // Some(f) if ptr::fn_addr_eq(f, sh_get_display_name as GetDisplayNameFn) => Ok(()),
106        Some(f) => {
107            // TODO
108            error!(?f, ?get_display_name, "Multi GetDisplayName");
109            windows::core::HRESULT(1).ok()
110        }
111    }
112}
113
114pub(crate) fn disable_hook() -> windows::core::Result<()> {
115    if unsafe { *ORIGINAL_GET_DISPLAY_NAME.get() }.is_some() {
116        // Unhook and restore original
117        debug!("Unhooking GetDisplayName");
118        hook(false)
119    } else {
120        Ok(())
121    }
122}