use super::detector::{HookInfo, HookType};
use std::collections::HashMap;
use std::sync::Mutex;
/// Process-wide tracker slot. Starts as `None`; `init_global_tracker` fills it with an
/// empty `HookTracker`, and `with_global_tracker` returns `None` while it is still empty.
static GLOBAL_TRACKER: Mutex<Option<HookTracker>> = Mutex::new(None);
/// Lifecycle state recorded for a tracked hook.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookState {
    /// Initial state assigned when a hook is first registered.
    Active,
    /// The hook has been marked as removed.
    Removed,
    /// The hook has been marked as restored.
    Restored,
}
/// A detected hook together with its tracking state and timestamps.
#[derive(Debug, Clone)]
pub struct TrackedHook {
    /// Detection details for the hook.
    pub info: HookInfo,
    /// Current lifecycle state.
    pub state: HookState,
    /// Moment this entry was created.
    pub detected_at: std::time::Instant,
    /// Moment `state` was last set; equal to `detected_at` until the first change.
    pub last_changed: std::time::Instant,
}

impl TrackedHook {
    /// Wraps `info` as a freshly detected hook in the `Active` state.
    ///
    /// Both timestamps come from a single `Instant::now()` reading, so they compare equal.
    fn new(info: HookInfo) -> Self {
        let stamp = std::time::Instant::now();
        TrackedHook {
            info,
            state: HookState::Active,
            detected_at: stamp,
            last_changed: stamp,
        }
    }
}
pub struct HookTracker {
hooks: HashMap<usize, TrackedHook>,
by_module: HashMap<String, Vec<usize>>,
}
impl HookTracker {
pub fn new() -> Self {
Self {
hooks: HashMap::new(),
by_module: HashMap::new(),
}
}
pub fn register(&mut self, info: HookInfo) {
let addr = info.function_address;
let module = info.module_name.clone();
self.hooks.insert(addr, TrackedHook::new(info));
self.by_module
.entry(module)
.or_insert_with(Vec::new)
.push(addr);
}
pub fn register_all(&mut self, hooks: impl IntoIterator<Item = HookInfo>) {
for hook in hooks {
self.register(hook);
}
}
pub fn mark_removed(&mut self, address: usize) {
if let Some(tracked) = self.hooks.get_mut(&address) {
tracked.state = HookState::Removed;
tracked.last_changed = std::time::Instant::now();
}
}
pub fn mark_restored(&mut self, address: usize) {
if let Some(tracked) = self.hooks.get_mut(&address) {
tracked.state = HookState::Restored;
tracked.last_changed = std::time::Instant::now();
}
}
pub fn get(&self, address: usize) -> Option<&TrackedHook> {
self.hooks.get(&address)
}
pub fn get_by_module(&self, module_name: &str) -> Vec<&TrackedHook> {
self.by_module
.get(module_name)
.map(|addrs| {
addrs
.iter()
.filter_map(|&addr| self.hooks.get(&addr))
.collect()
})
.unwrap_or_default()
}
pub fn active_hooks(&self) -> Vec<&TrackedHook> {
self.hooks
.values()
.filter(|h| h.state == HookState::Active)
.collect()
}
pub fn removed_hooks(&self) -> Vec<&TrackedHook> {
self.hooks
.values()
.filter(|h| h.state == HookState::Removed)
.collect()
}
pub fn get_by_type(&self, hook_type: HookType) -> Vec<&TrackedHook> {
self.hooks
.values()
.filter(|h| h.info.hook_type == hook_type)
.collect()
}
pub fn count(&self) -> usize {
self.hooks.len()
}
pub fn active_count(&self) -> usize {
self.hooks
.values()
.filter(|h| h.state == HookState::Active)
.count()
}
pub fn removed_count(&self) -> usize {
self.hooks
.values()
.filter(|h| h.state == HookState::Removed)
.count()
}
pub fn is_tracked(&self, address: usize) -> bool {
self.hooks.contains_key(&address)
}
pub fn unregister(&mut self, address: usize) -> Option<TrackedHook> {
if let Some(hook) = self.hooks.remove(&address) {
if let Some(addrs) = self.by_module.get_mut(&hook.info.module_name) {
addrs.retain(|&a| a != address);
}
Some(hook)
} else {
None
}
}
pub fn clear(&mut self) {
self.hooks.clear();
self.by_module.clear();
}
pub fn modules(&self) -> Vec<&str> {
self.by_module.keys().map(|s| s.as_str()).collect()
}
pub fn stats(&self) -> HookStats {
let mut stats = HookStats::default();
for hook in self.hooks.values() {
match hook.state {
HookState::Active => stats.active += 1,
HookState::Removed => stats.removed += 1,
HookState::Restored => stats.restored += 1,
}
match hook.info.hook_type {
HookType::JmpRel32 => stats.jmp_rel32 += 1,
HookType::JmpIndirect => stats.jmp_indirect += 1,
HookType::MovJmpRax => stats.mov_jmp_rax += 1,
HookType::PushRet => stats.push_ret += 1,
HookType::Breakpoint => stats.breakpoints += 1,
HookType::Unknown => stats.unknown += 1,
}
}
stats.total = self.hooks.len();
stats.modules = self.by_module.len();
stats
}
}
impl Default for HookTracker {
fn default() -> Self {
Self::new()
}
}
/// Aggregate counters over a set of tracked hooks.
#[derive(Debug, Default, Clone)]
pub struct HookStats {
    /// Number of tracked hooks.
    pub total: usize,
    /// Hooks in the `Active` state.
    pub active: usize,
    /// Hooks in the `Removed` state.
    pub removed: usize,
    /// Hooks in the `Restored` state.
    pub restored: usize,
    /// Number of distinct modules with tracked hooks.
    pub modules: usize,
    /// Hooks detected as `JmpRel32`.
    pub jmp_rel32: usize,
    /// Hooks detected as `JmpIndirect`.
    pub jmp_indirect: usize,
    /// Hooks detected as `MovJmpRax`.
    pub mov_jmp_rax: usize,
    /// Hooks detected as `PushRet`.
    pub push_ret: usize,
    /// Hooks detected as `Breakpoint`.
    pub breakpoints: usize,
    /// Hooks whose type was `Unknown`.
    pub unknown: usize,
}

impl std::fmt::Display for HookStats {
    /// Writes a multi-line summary, one counter per line, ending with a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let HookStats {
            total,
            active,
            removed,
            restored,
            modules,
            jmp_rel32,
            jmp_indirect,
            mov_jmp_rax,
            push_ret,
            breakpoints,
            unknown,
        } = self;

        // Overall state counters.
        writeln!(f, "Hook Statistics:")?;
        writeln!(f, " Total: {}", total)?;
        writeln!(f, " Active: {}", active)?;
        writeln!(f, " Removed: {}", removed)?;
        writeln!(f, " Restored: {}", restored)?;
        writeln!(f, " Modules: {}", modules)?;

        // Breakdown by detected hook type.
        writeln!(f, " By type:")?;
        writeln!(f, " jmp rel32: {}", jmp_rel32)?;
        writeln!(f, " jmp indirect: {}", jmp_indirect)?;
        writeln!(f, " mov rax; jmp rax: {}", mov_jmp_rax)?;
        writeln!(f, " push; ret: {}", push_ret)?;
        writeln!(f, " breakpoints: {}", breakpoints)?;
        writeln!(f, " unknown: {}", unknown)
    }
}
/// Ensures the global tracker slot holds a `HookTracker`, creating an empty one if needed.
///
/// An existing tracker is left untouched, so repeated calls are harmless. If the mutex was
/// poisoned by a panic in another thread, the poison is ignored and the slot is still
/// initialized.
///
/// Always returns `true`; the `bool` is kept for API compatibility.
pub fn init_global_tracker() -> bool {
    // Treat a poisoned lock like a healthy one: both former match arms did the same work.
    let mut slot = GLOBAL_TRACKER
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    slot.get_or_insert_with(HookTracker::new);
    true
}
/// Locks the global tracker slot and hands back the guard.
///
/// A poisoned mutex is recovered rather than reported, so this always returns `Some`.
/// The slot behind the guard may still be `None` if the tracker was never initialized.
pub fn global_tracker() -> Option<std::sync::MutexGuard<'static, Option<HookTracker>>> {
    let guard = GLOBAL_TRACKER
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    Some(guard)
}
/// Runs `f` on the global tracker while holding its lock and returns `f`'s result.
///
/// Returns `None`, without calling `f`, when the tracker slot is still empty. The lock is
/// held for the whole call. `f` must not call back into the global-tracker functions:
/// re-locking a `std::sync::Mutex` on the same thread may deadlock or panic.
pub fn with_global_tracker<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&mut HookTracker) -> R,
{
    let mut locked = global_tracker()?;
    let tracker = locked.as_mut()?;
    Some(f(tracker))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a `JmpRel32` hook called `name` at `addr`, attributed to `test.dll`.
    fn dummy_hook(name: &str, addr: usize) -> HookInfo {
        HookInfo {
            module_name: "test.dll".to_string(),
            function_name: name.to_string(),
            function_address: addr,
            hook_type: HookType::JmpRel32,
            hook_destination: Some(0xDEADBEEF),
            // Five 0x90 bytes before patching vs. an 0xE9 opcode with a zero displacement after.
            original_bytes: vec![0x90; 5],
            hooked_bytes: vec![0xE9, 0x00, 0x00, 0x00, 0x00],
        }
    }

    #[test]
    fn test_tracker_basic() {
        let mut tracker = HookTracker::new();
        for (name, addr) in [("NtReadVirtualMemory", 0x1000), ("NtWriteVirtualMemory", 0x2000)] {
            tracker.register(dummy_hook(name, addr));
        }
        assert_eq!(tracker.count(), 2);
        assert_eq!(tracker.active_count(), 2);

        tracker.mark_removed(0x1000);
        assert_eq!(tracker.active_count(), 1);
        assert_eq!(tracker.removed_count(), 1);
    }

    #[test]
    fn test_stats() {
        let mut tracker = HookTracker::new();
        for (name, addr) in [("Func1", 0x1000), ("Func2", 0x2000)] {
            tracker.register(dummy_hook(name, addr));
        }
        tracker.mark_removed(0x1000);

        let HookStats {
            total,
            active,
            removed,
            jmp_rel32,
            ..
        } = tracker.stats();
        assert_eq!(total, 2);
        assert_eq!(active, 1);
        assert_eq!(removed, 1);
        assert_eq!(jmp_rel32, 2);
    }
}