wraith/manipulation/hooks/
tracker.rs

1//! Hook state tracking
2//!
3//! Tracks hooks that have been detected and/or removed, allowing
4//! for continuous monitoring and selective re-hooking if needed.
5
6use super::detector::{HookInfo, HookType};
7use std::collections::HashMap;
8use std::sync::Mutex;
9
10/// global hook tracker instance
11static GLOBAL_TRACKER: Mutex<Option<HookTracker>> = Mutex::new(None);
12
/// Lifecycle state of a tracked hook.
///
/// Derives `Hash` (in addition to `Eq`) so states can be used as keys in
/// `HashMap`/`HashSet`, e.g. when grouping hooks by state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookState {
    /// The hook is in place; every newly registered hook starts here.
    Active,
    /// The hook was removed by us.
    Removed,
    /// The hook was restored (re-installed).
    Restored,
}
23
24/// tracked hook entry
25#[derive(Debug, Clone)]
26pub struct TrackedHook {
27    /// original hook information
28    pub info: HookInfo,
29    /// current state
30    pub state: HookState,
31    /// when this hook was first detected
32    pub detected_at: std::time::Instant,
33    /// when state last changed
34    pub last_changed: std::time::Instant,
35}
36
37impl TrackedHook {
38    fn new(info: HookInfo) -> Self {
39        let now = std::time::Instant::now();
40        Self {
41            info,
42            state: HookState::Active,
43            detected_at: now,
44            last_changed: now,
45        }
46    }
47}
48
49/// tracks hooks across the process
50pub struct HookTracker {
51    /// hooks by function address
52    hooks: HashMap<usize, TrackedHook>,
53    /// hooks indexed by module name
54    by_module: HashMap<String, Vec<usize>>,
55}
56
57impl HookTracker {
58    /// create new hook tracker
59    pub fn new() -> Self {
60        Self {
61            hooks: HashMap::new(),
62            by_module: HashMap::new(),
63        }
64    }
65
66    /// register a detected hook
67    pub fn register(&mut self, info: HookInfo) {
68        let addr = info.function_address;
69        let module = info.module_name.clone();
70
71        self.hooks.insert(addr, TrackedHook::new(info));
72
73        self.by_module
74            .entry(module)
75            .or_insert_with(Vec::new)
76            .push(addr);
77    }
78
79    /// register multiple hooks
80    pub fn register_all(&mut self, hooks: impl IntoIterator<Item = HookInfo>) {
81        for hook in hooks {
82            self.register(hook);
83        }
84    }
85
86    /// mark hook as removed
87    pub fn mark_removed(&mut self, address: usize) {
88        if let Some(tracked) = self.hooks.get_mut(&address) {
89            tracked.state = HookState::Removed;
90            tracked.last_changed = std::time::Instant::now();
91        }
92    }
93
94    /// mark hook as restored
95    pub fn mark_restored(&mut self, address: usize) {
96        if let Some(tracked) = self.hooks.get_mut(&address) {
97            tracked.state = HookState::Restored;
98            tracked.last_changed = std::time::Instant::now();
99        }
100    }
101
102    /// get hook info by address
103    pub fn get(&self, address: usize) -> Option<&TrackedHook> {
104        self.hooks.get(&address)
105    }
106
107    /// get all hooks for a module
108    pub fn get_by_module(&self, module_name: &str) -> Vec<&TrackedHook> {
109        self.by_module
110            .get(module_name)
111            .map(|addrs| {
112                addrs
113                    .iter()
114                    .filter_map(|&addr| self.hooks.get(&addr))
115                    .collect()
116            })
117            .unwrap_or_default()
118    }
119
120    /// get all active hooks
121    pub fn active_hooks(&self) -> Vec<&TrackedHook> {
122        self.hooks
123            .values()
124            .filter(|h| h.state == HookState::Active)
125            .collect()
126    }
127
128    /// get all removed hooks
129    pub fn removed_hooks(&self) -> Vec<&TrackedHook> {
130        self.hooks
131            .values()
132            .filter(|h| h.state == HookState::Removed)
133            .collect()
134    }
135
136    /// get hooks by type
137    pub fn get_by_type(&self, hook_type: HookType) -> Vec<&TrackedHook> {
138        self.hooks
139            .values()
140            .filter(|h| h.info.hook_type == hook_type)
141            .collect()
142    }
143
144    /// total number of tracked hooks
145    pub fn count(&self) -> usize {
146        self.hooks.len()
147    }
148
149    /// number of active hooks
150    pub fn active_count(&self) -> usize {
151        self.hooks
152            .values()
153            .filter(|h| h.state == HookState::Active)
154            .count()
155    }
156
157    /// number of removed hooks
158    pub fn removed_count(&self) -> usize {
159        self.hooks
160            .values()
161            .filter(|h| h.state == HookState::Removed)
162            .count()
163    }
164
165    /// check if an address is tracked
166    pub fn is_tracked(&self, address: usize) -> bool {
167        self.hooks.contains_key(&address)
168    }
169
170    /// unregister a hook (remove from tracking)
171    pub fn unregister(&mut self, address: usize) -> Option<TrackedHook> {
172        if let Some(hook) = self.hooks.remove(&address) {
173            // remove from module index
174            if let Some(addrs) = self.by_module.get_mut(&hook.info.module_name) {
175                addrs.retain(|&a| a != address);
176            }
177            Some(hook)
178        } else {
179            None
180        }
181    }
182
183    /// clear all tracked hooks
184    pub fn clear(&mut self) {
185        self.hooks.clear();
186        self.by_module.clear();
187    }
188
189    /// get modules with tracked hooks
190    pub fn modules(&self) -> Vec<&str> {
191        self.by_module.keys().map(|s| s.as_str()).collect()
192    }
193
194    /// get statistics
195    pub fn stats(&self) -> HookStats {
196        let mut stats = HookStats::default();
197
198        for hook in self.hooks.values() {
199            match hook.state {
200                HookState::Active => stats.active += 1,
201                HookState::Removed => stats.removed += 1,
202                HookState::Restored => stats.restored += 1,
203            }
204
205            match hook.info.hook_type {
206                HookType::JmpRel32 => stats.jmp_rel32 += 1,
207                HookType::JmpIndirect => stats.jmp_indirect += 1,
208                HookType::MovJmpRax => stats.mov_jmp_rax += 1,
209                HookType::PushRet => stats.push_ret += 1,
210                HookType::Breakpoint => stats.breakpoints += 1,
211                HookType::Unknown => stats.unknown += 1,
212            }
213        }
214
215        stats.total = self.hooks.len();
216        stats.modules = self.by_module.len();
217
218        stats
219    }
220}
221
222impl Default for HookTracker {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
/// Aggregate hook counts, as produced by `HookTracker::stats`.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly
/// (e.g. to detect changes between scans, or in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct HookStats {
    /// Number of tracked hooks, in any state.
    pub total: usize,
    /// Hooks in `HookState::Active`.
    pub active: usize,
    /// Hooks in `HookState::Removed`.
    pub removed: usize,
    /// Hooks in `HookState::Restored`.
    pub restored: usize,
    /// Number of modules in the tracker's module index.
    pub modules: usize,
    /// Hooks of type `HookType::JmpRel32`.
    pub jmp_rel32: usize,
    /// Hooks of type `HookType::JmpIndirect`.
    pub jmp_indirect: usize,
    /// Hooks of type `HookType::MovJmpRax`.
    pub mov_jmp_rax: usize,
    /// Hooks of type `HookType::PushRet`.
    pub push_ret: usize,
    /// Hooks of type `HookType::Breakpoint`.
    pub breakpoints: usize,
    /// Hooks of type `HookType::Unknown`.
    pub unknown: usize,
}
243
244impl std::fmt::Display for HookStats {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        writeln!(f, "Hook Statistics:")?;
247        writeln!(f, "  Total: {}", self.total)?;
248        writeln!(f, "  Active: {}", self.active)?;
249        writeln!(f, "  Removed: {}", self.removed)?;
250        writeln!(f, "  Restored: {}", self.restored)?;
251        writeln!(f, "  Modules: {}", self.modules)?;
252        writeln!(f, "  By type:")?;
253        writeln!(f, "    jmp rel32: {}", self.jmp_rel32)?;
254        writeln!(f, "    jmp indirect: {}", self.jmp_indirect)?;
255        writeln!(f, "    mov rax; jmp rax: {}", self.mov_jmp_rax)?;
256        writeln!(f, "    push; ret: {}", self.push_ret)?;
257        writeln!(f, "    breakpoints: {}", self.breakpoints)?;
258        writeln!(f, "    unknown: {}", self.unknown)
259    }
260}
261
262// global tracker functions
263
264/// initialize global tracker
265///
266/// returns false if mutex is poisoned (previous panic while holding lock)
267pub fn init_global_tracker() -> bool {
268    match GLOBAL_TRACKER.lock() {
269        Ok(mut guard) => {
270            if guard.is_none() {
271                *guard = Some(HookTracker::new());
272            }
273            true
274        }
275        Err(poisoned) => {
276            // recover from poisoned mutex by taking the guard anyway
277            let mut guard = poisoned.into_inner();
278            if guard.is_none() {
279                *guard = Some(HookTracker::new());
280            }
281            true
282        }
283    }
284}
285
286/// get reference to global tracker (returns MutexGuard)
287///
288/// returns None if mutex is poisoned and recovery fails
289pub fn global_tracker() -> Option<std::sync::MutexGuard<'static, Option<HookTracker>>> {
290    match GLOBAL_TRACKER.lock() {
291        Ok(guard) => Some(guard),
292        Err(poisoned) => {
293            // recover from poisoned mutex
294            Some(poisoned.into_inner())
295        }
296    }
297}
298
299/// modify global tracker
300///
301/// returns None if tracker is not initialized or mutex is poisoned
302pub fn with_global_tracker<F, R>(f: F) -> Option<R>
303where
304    F: FnOnce(&mut HookTracker) -> R,
305{
306    let mut guard = global_tracker()?;
307    guard.as_mut().map(f)
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_hook(name: &str, addr: usize) -> HookInfo {
        HookInfo {
            function_name: name.to_string(),
            function_address: addr,
            hook_type: HookType::JmpRel32,
            hook_destination: Some(0xDEADBEEF),
            original_bytes: vec![0x90; 5],
            hooked_bytes: vec![0xE9, 0x00, 0x00, 0x00, 0x00],
            module_name: "test.dll".to_string(),
        }
    }

    /// tracker holding two `test.dll` hooks at 0x1000 and 0x2000
    fn tracker_with_two_hooks() -> HookTracker {
        let mut tracker = HookTracker::new();
        tracker.register(dummy_hook("NtReadVirtualMemory", 0x1000));
        tracker.register(dummy_hook("NtWriteVirtualMemory", 0x2000));
        tracker
    }

    #[test]
    fn test_tracker_basic() {
        let mut tracker = HookTracker::new();

        tracker.register(dummy_hook("NtReadVirtualMemory", 0x1000));
        tracker.register(dummy_hook("NtWriteVirtualMemory", 0x2000));

        assert_eq!(tracker.count(), 2);
        assert_eq!(tracker.active_count(), 2);

        tracker.mark_removed(0x1000);
        assert_eq!(tracker.active_count(), 1);
        assert_eq!(tracker.removed_count(), 1);
    }

    #[test]
    fn test_stats() {
        let mut tracker = HookTracker::new();

        tracker.register(dummy_hook("Func1", 0x1000));
        tracker.register(dummy_hook("Func2", 0x2000));
        tracker.mark_removed(0x1000);

        let stats = tracker.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
        assert_eq!(stats.removed, 1);
        assert_eq!(stats.jmp_rel32, 2);
    }

    #[test]
    fn test_restore_updates_state_and_stats() {
        let mut tracker = tracker_with_two_hooks();
        tracker.mark_removed(0x1000);
        tracker.mark_restored(0x1000);

        let hook = tracker.get(0x1000).expect("0x1000 is tracked");
        assert_eq!(hook.state, HookState::Restored);
        assert!(hook.last_changed >= hook.detected_at);

        assert_eq!(tracker.removed_count(), 0);
        assert!(tracker.removed_hooks().is_empty());

        let stats = tracker.stats();
        assert_eq!(stats.restored, 1);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_unregister() {
        let mut tracker = tracker_with_two_hooks();

        let removed = tracker.unregister(0x2000).expect("0x2000 was registered");
        assert_eq!(removed.info.function_name, "NtWriteVirtualMemory");
        assert!(tracker.unregister(0x2000).is_none());
        assert!(!tracker.is_tracked(0x2000));
        assert_eq!(tracker.count(), 1);

        let remaining = tracker.get_by_module("test.dll");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].info.function_address, 0x1000);
    }

    #[test]
    fn test_mark_untracked_is_noop() {
        let mut tracker = tracker_with_two_hooks();
        tracker.mark_removed(0xDEAD);
        tracker.mark_restored(0xDEAD);

        assert!(!tracker.is_tracked(0xDEAD));
        assert_eq!(tracker.count(), 2);
        assert_eq!(tracker.active_count(), 2);
    }

    #[test]
    fn test_queries_by_type_and_module() {
        let tracker = tracker_with_two_hooks();

        assert_eq!(tracker.get_by_type(HookType::JmpRel32).len(), 2);
        assert!(tracker.get_by_type(HookType::Breakpoint).is_empty());
        assert!(tracker.get_by_module("other.dll").is_empty());
        assert_eq!(tracker.modules(), vec!["test.dll"]);
    }

    #[test]
    fn test_clear() {
        let mut tracker = tracker_with_two_hooks();
        tracker.clear();

        assert_eq!(tracker.count(), 0);
        assert!(tracker.modules().is_empty());
        assert_eq!(tracker.stats().modules, 0);
    }

    #[test]
    fn test_stats_display() {
        let text = tracker_with_two_hooks().stats().to_string();

        assert!(text.starts_with("Hook Statistics:\n"));
        assert!(text.contains("  Total: 2\n"));
        assert!(text.contains("    jmp rel32: 2\n"));
    }
}