use std::{
cmp::Reverse,
collections::HashMap,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
},
};
use dashmap::DashMap;
use crate::{
emulation::{
runtime::hook::{
core::Hook,
types::{HookContext, HookOutcome, PostHookResult, PreHookResult},
},
EmValue, EmulationError, EmulationThread,
},
Result,
};
/// Per-namespace value of `HookManager::full_index`: type name → method name →
/// hooks registered for that exact (namespace, type, method) key, each bucket
/// ordered by descending priority.
type TypeMethodHookIndex = HashMap<Arc<str>, HashMap<Arc<str>, Vec<HookEntry>>>;
/// Which keyed index a hook is stored in, as chosen by `extract_route`.
///
/// Hooks for which no route can be extracted are kept in the wildcard list
/// instead.
enum IndexRoute {
    /// Namespace, type name and method name are all known; the hook goes into
    /// `full_index` under that exact key.
    Full(Arc<str>, Arc<str>, Arc<str>),
    /// Namespace and type name are known but no method name; the hook goes into
    /// `type_index`.
    TypeLevel(Arc<str>, Arc<str>),
    /// Only the method name is used as the key; the hook goes into
    /// `method_index`.
    MethodOnly(Arc<str>),
}
/// A registered hook together with facts about its matchers that are computed
/// once, at registration time, rather than on every lookup.
struct HookEntry {
    /// The hook itself, reference-counted so a lookup can hand out a clone and
    /// run the hook after the index guard has been released.
    hook: Arc<Hook>,
    /// Whether any of the hook's matchers reports `is_runtime_matcher()`.
    has_runtime_matchers: bool,
    /// Whether any of the hook's matchers reports `is_signature_matcher()`.
    has_signature_matchers: bool,
}

impl HookEntry {
    /// Wraps `hook`, recording which kinds of matchers it carries.
    fn new(hook: Hook) -> Self {
        // Scope the borrow of the matcher list so `hook` can be moved below.
        let (has_runtime_matchers, has_signature_matchers) = {
            let matchers = hook.matchers();
            (
                matchers.iter().any(|matcher| matcher.is_runtime_matcher()),
                matchers.iter().any(|matcher| matcher.is_signature_matcher()),
            )
        };

        Self {
            has_runtime_matchers,
            has_signature_matchers,
            hook: Arc::new(hook),
        }
    }
}
/// Picks the keyed index a hook should be stored in.
///
/// Only the FIRST matcher whose `name_components()` returns `Some` is
/// consulted. Its `(namespace, type, method)` parts select the route:
/// * all three present → `IndexRoute::Full`
/// * namespace and type but no method → `IndexRoute::TypeLevel`
/// * any other shape that has a method name → `IndexRoute::MethodOnly`
/// * anything else (e.g. a type name without a namespace) → `None`, which
///   makes `HookManager::register` store the hook as a wildcard.
///
/// Later matchers are never examined once a name-bearing matcher has been
/// found, even when that matcher's components produce `None`.
fn extract_route(hook: &Hook) -> Option<IndexRoute> {
    hook.matchers()
        .iter()
        .find_map(|matcher| matcher.name_components())
        .and_then(|components| match components {
            (Some(ns), Some(ty), Some(method)) => {
                Some(IndexRoute::Full(ns.into(), ty.into(), method.into()))
            }
            (Some(ns), Some(ty), None) => Some(IndexRoute::TypeLevel(ns.into(), ty.into())),
            (_, _, Some(method)) => Some(IndexRoute::MethodOnly(method.into())),
            _ => None,
        })
}
fn insert_sorted(bucket: &mut Vec<HookEntry>, entry: HookEntry) {
bucket.push(entry);
bucket.sort_by_key(|e| Reverse(e.hook.priority()));
}
pub struct HookManager {
full_index: DashMap<Arc<str>, TypeMethodHookIndex>,
type_index: DashMap<Arc<str>, HashMap<Arc<str>, Vec<HookEntry>>>,
method_index: DashMap<Arc<str>, Vec<HookEntry>>,
wildcard_hooks: RwLock<Vec<HookEntry>>,
total_count: AtomicUsize,
}
impl Default for HookManager {
fn default() -> Self {
Self {
full_index: DashMap::new(),
type_index: DashMap::new(),
method_index: DashMap::new(),
wildcard_hooks: RwLock::new(Vec::new()),
total_count: AtomicUsize::new(0),
}
}
}
impl HookManager {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register(&self, hook: Hook) -> Result<()> {
let route = extract_route(&hook);
let entry = HookEntry::new(hook);
match route {
Some(IndexRoute::Full(ns, ty, method)) => {
let mut ns_map = self.full_index.entry(ns).or_default();
let bucket = ns_map.entry(ty).or_default().entry(method).or_default();
insert_sorted(bucket, entry);
}
Some(IndexRoute::TypeLevel(ns, ty)) => {
let mut ns_map = self.type_index.entry(ns).or_default();
let bucket = ns_map.entry(ty).or_default();
insert_sorted(bucket, entry);
}
Some(IndexRoute::MethodOnly(method)) => {
let mut bucket = self.method_index.entry(method).or_default();
insert_sorted(&mut bucket, entry);
}
None => {
let mut wildcards =
self.wildcard_hooks
.write()
.map_err(|_| EmulationError::LockPoisoned {
description: "hook manager",
})?;
insert_sorted(&mut wildcards, entry);
}
}
self.total_count.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn has_potential_match(
&self,
namespace: &str,
type_name: &str,
method_name: &str,
) -> Result<bool> {
let wildcards = self
.wildcard_hooks
.read()
.map_err(|_| EmulationError::LockPoisoned {
description: "hook manager",
})?;
if !wildcards.is_empty() {
return Ok(true);
}
drop(wildcards);
if let Some(type_map) = self.full_index.get(namespace) {
if let Some(method_map) = type_map.get(type_name) {
if method_map.contains_key(method_name) {
return Ok(true);
}
}
}
if let Some(type_map) = self.type_index.get(namespace) {
if type_map.contains_key(type_name) {
return Ok(true);
}
}
Ok(self.method_index.contains_key(method_name))
}
pub fn execute<F>(
&self,
context: &HookContext<'_>,
thread: &mut EmulationThread,
execute_original: F,
) -> Result<HookOutcome>
where
F: FnOnce(&mut EmulationThread) -> Option<EmValue>,
{
let Some(hook) = self.find_matching(context, thread)? else {
return Ok(HookOutcome::NoMatch);
};
let pre_result = hook.execute_pre(context, thread);
match pre_result {
Some(PreHookResult::Bypass(value)) => {
return Ok(HookOutcome::Handled(value));
}
Some(PreHookResult::ReflectionInvoke {
request,
bypass_value,
}) => {
return Ok(HookOutcome::ReflectionInvoke {
request,
bypass_value,
});
}
Some(PreHookResult::Error(msg)) => {
return Err(EmulationError::HookError(format!(
"Hook '{}' pre-hook error: {}",
hook.name(),
msg
))
.into());
}
Some(PreHookResult::Throw {
exception_type,
message,
}) => {
return Ok(HookOutcome::ThrewException {
exception_type,
message,
});
}
Some(PreHookResult::Continue) | None => {
}
}
let original_result = execute_original(thread);
match hook.execute_post(context, thread, original_result.as_ref()) {
Some(PostHookResult::Replace(new_value)) => Ok(HookOutcome::Handled(new_value)),
Some(PostHookResult::Error(msg)) => Err(EmulationError::HookError(format!(
"Hook '{}' post-hook error: {}",
hook.name(),
msg
))
.into()),
Some(PostHookResult::Keep) => Ok(HookOutcome::Handled(original_result)),
None => {
if original_result.is_none() {
Ok(HookOutcome::NoMatch)
} else {
Ok(HookOutcome::Handled(original_result))
}
}
}
}
#[must_use]
pub fn len(&self) -> usize {
self.total_count.load(Ordering::Relaxed)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
fn find_matching(
&self,
context: &HookContext<'_>,
thread: &EmulationThread,
) -> Result<Option<Arc<Hook>>> {
if let Some(type_map) = self.full_index.get(context.namespace) {
if let Some(method_map) = type_map.get(context.type_name) {
if let Some(candidates) = method_map.get(context.method_name) {
for entry in candidates {
if !entry.has_runtime_matchers
&& !entry.has_signature_matchers
&& entry.hook.matchers().len() == 1
{
return Ok(Some(Arc::clone(&entry.hook)));
}
if entry.hook.matches(context, thread) {
return Ok(Some(Arc::clone(&entry.hook)));
}
}
}
}
}
if let Some(type_map) = self.type_index.get(context.namespace) {
if let Some(candidates) = type_map.get(context.type_name) {
for entry in candidates {
if entry.hook.matches(context, thread) {
return Ok(Some(Arc::clone(&entry.hook)));
}
}
}
}
if let Some(candidates) = self.method_index.get(context.method_name) {
for entry in candidates.iter() {
if entry.hook.matches(context, thread) {
return Ok(Some(Arc::clone(&entry.hook)));
}
}
}
let wildcards = self
.wildcard_hooks
.read()
.map_err(|_| EmulationError::LockPoisoned {
description: "hook manager",
})?;
Ok(wildcards
.iter()
.find(|e| e.hook.matches(context, thread))
.map(|e| Arc::clone(&e.hook)))
}
}
impl std::fmt::Debug for HookManager {
    /// Summarizes the manager with counts only; individual hooks are not
    /// printed.
    ///
    /// A poisoned wildcard lock is read through instead of being reported as a
    /// formatting failure. `fmt::Error` must only signal that the underlying
    /// writer failed, and returning it here would make `format!("{:?}", ..)`
    /// panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wildcard_count = match self.wildcard_hooks.read() {
            Ok(guard) => guard.len(),
            // Poisoning only means a writer panicked. The inner Vec is still a
            // valid Vec, and this length is used purely for diagnostics.
            Err(poisoned) => poisoned.into_inner().len(),
        };
        f.debug_struct("HookManager")
            .field("hook_count", &self.total_count.load(Ordering::Relaxed))
            .field("full_index_namespaces", &self.full_index.len())
            .field("type_index_namespaces", &self.type_index.len())
            .field("method_index_entries", &self.method_index.len())
            .field("wildcard_hooks", &wildcard_count)
            .finish()
    }
}
// Unit tests for `HookManager`: registration counting, index routing, priority
// ordering within a bucket, and the pre-hook / original / post-hook flow of
// `execute`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        emulation::HookPriority,
        metadata::{token::Token, typesystem::PointerSize},
        test::emulation::create_test_thread,
    };

    // A freshly constructed manager reports zero hooks.
    #[test]
    fn test_hook_manager_empty() {
        let manager = HookManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    // Each successful `register` call is reflected in `len`/`is_empty`.
    #[test]
    fn test_hook_manager_registration() {
        let manager = HookManager::new();
        manager
            .register(Hook::new("hook1").match_method_name("Method1"))
            .unwrap();
        manager
            .register(Hook::new("hook2").match_method_name("Method2"))
            .unwrap();
        assert_eq!(manager.len(), 2);
        assert!(!manager.is_empty());
    }

    // Three hooks share one method-name bucket and are registered out of
    // priority order (LOW, HIGH, NORMAL). The HIGH hook's bypass value (3)
    // must be the one returned.
    #[test]
    fn test_hook_manager_priority_sorting() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("low")
                    .with_priority(HookPriority::LOW)
                    .match_method_name("Test")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(1)))),
            )
            .unwrap();
        manager
            .register(
                Hook::new("high")
                    .with_priority(HookPriority::HIGH)
                    .match_method_name("Test")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(3)))),
            )
            .unwrap();
        manager
            .register(
                Hook::new("normal")
                    .with_priority(HookPriority::NORMAL)
                    .match_method_name("Test")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(2)))),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let outcome = manager.execute(&context, &mut thread, |_| None).unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(3)))
        ));
    }

    // With only a full-key hook registered, `has_potential_match` is true for
    // that exact key. It is false when either the method or the type differs.
    #[test]
    fn test_has_potential_match() {
        let manager = HookManager::new();
        manager
            .register(Hook::new("string-concat").match_name("System", "String", "Concat"))
            .unwrap();
        assert!(manager
            .has_potential_match("System", "String", "Concat")
            .unwrap());
        assert!(!manager
            .has_potential_match("System", "String", "Replace")
            .unwrap());
        assert!(!manager
            .has_potential_match("System", "Math", "Abs")
            .unwrap());
    }

    // A hook with only a runtime matcher has no name key. Its presence makes
    // any lookup report a potential match.
    #[test]
    fn test_has_potential_match_with_wildcard() {
        let manager = HookManager::new();
        manager
            .register(Hook::new("wildcard").match_runtime("always", |_, _| true))
            .unwrap();
        assert!(manager
            .has_potential_match("Any", "Thing", "AtAll")
            .unwrap());
    }

    // A hook that names only a type must still be reported as a potential
    // match for a method on that type, whichever index it ends up in.
    #[test]
    fn test_has_potential_match_type_level() {
        let manager = HookManager::new();
        manager
            .register(Hook::new("type-hook").match_type_name("Console"))
            .unwrap();
        assert!(manager
            .has_potential_match("System", "Console", "WriteLine")
            .unwrap());
    }

    // A full-key hook is found for its exact key. Its bypass value (42) is
    // returned instead of the original's value (100).
    #[test]
    fn test_index_full_key_match() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("exact-match")
                    .match_name("System", "String", "Concat")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(42)))),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Concat",
            PointerSize::Bit64,
        );
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(100)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(42)))
        ));
    }

    // A method-name-only hook matches regardless of the namespace and type in
    // the context.
    #[test]
    fn test_index_method_only_match() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("method-only")
                    .match_method_name("Decrypt")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(99)))),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "Custom",
            "Obfuscator",
            "Decrypt",
            PointerSize::Bit64,
        );
        let outcome = manager.execute(&context, &mut thread, |_| None).unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(99)))
        ));
    }

    // With no hooks registered, `execute` reports `NoMatch`, even though the
    // original closure would have produced a value.
    #[test]
    fn test_execute_no_match() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Concat",
            PointerSize::Bit64,
        );
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(100)))
            .unwrap();
        assert!(matches!(outcome, HookOutcome::NoMatch));
    }

    // A pre-hook `Bypass` returns its value, and the flag proves that the
    // original closure was never invoked.
    #[test]
    fn test_execute_pre_hook_bypass() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("bypass-test")
                    .match_name("System", "String", "Test")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(42)))),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let original_called = std::sync::atomic::AtomicBool::new(false);
        let outcome = manager
            .execute(&context, &mut thread, |_| {
                original_called.store(true, std::sync::atomic::Ordering::SeqCst);
                Some(EmValue::I32(999))
            })
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(42)))
        ));
        assert!(!original_called.load(std::sync::atomic::Ordering::SeqCst));
    }

    // A pre-hook `Continue` lets the original run (returning 50). The post-hook
    // then replaces that result with its doubled value (100).
    #[test]
    fn test_execute_pre_hook_continue_then_post_hook() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("continue-then-modify")
                    .match_name("System", "String", "Test")
                    .pre(|_ctx, _thread| PreHookResult::Continue)
                    .post(|_ctx, _thread, result| {
                        if let Some(EmValue::I32(v)) = result {
                            PostHookResult::Replace(Some(EmValue::I32(v * 2)))
                        } else {
                            PostHookResult::Keep
                        }
                    }),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(50)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(100)))
        ));
    }

    // A pre-hook `Error` becomes an `Err` whose message includes the hook's
    // error text.
    #[test]
    fn test_execute_pre_hook_error() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("error-test")
                    .match_name("System", "String", "Test")
                    .pre(|_ctx, _thread| PreHookResult::Error("test error".to_string())),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let result = manager.execute(&context, &mut thread, |_| Some(EmValue::I32(100)));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("test error"));
    }

    // A post-hook `Keep` returns the original's result (123) unchanged.
    #[test]
    fn test_execute_post_hook_keep() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("post-keep")
                    .match_name("System", "String", "Test")
                    .pre(|_ctx, _thread| PreHookResult::Continue)
                    .post(|_ctx, _thread, _result| PostHookResult::Keep),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(123)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(123)))
        ));
    }

    // A post-hook `Error` becomes an `Err` whose message includes the hook's
    // error text.
    #[test]
    fn test_execute_post_hook_error() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("post-error")
                    .match_name("System", "String", "Test")
                    .pre(|_ctx, _thread| PreHookResult::Continue)
                    .post(|_ctx, _thread, _result| PostHookResult::Error("post error".to_string())),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Test",
            PointerSize::Bit64,
        );
        let result = manager.execute(&context, &mut thread, |_| Some(EmValue::I32(100)));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("post error"));
    }

    // Two hooks share the same full key; the HIGH one is registered second.
    // Its bypass value (2) must still win over the LOW hook's.
    #[test]
    fn test_multiple_hooks_same_key_priority_order() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        manager
            .register(
                Hook::new("low-priority")
                    .with_priority(HookPriority::LOW)
                    .match_name("System", "String", "Concat")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(1)))),
            )
            .unwrap();
        manager
            .register(
                Hook::new("high-priority")
                    .with_priority(HookPriority::HIGH)
                    .match_name("System", "String", "Concat")
                    .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(2)))),
            )
            .unwrap();
        let context = HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            "Concat",
            PointerSize::Bit64,
        );
        let outcome = manager.execute(&context, &mut thread, |_| None).unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(2)))
        ));
    }

    // Hooks routed to different indexes (two full keys, one method-only) are
    // all counted.
    #[test]
    fn test_registration_count() {
        let manager = HookManager::new();
        manager
            .register(Hook::new("hook1").match_name("System", "String", "Concat"))
            .unwrap();
        manager
            .register(Hook::new("hook2").match_name("System", "Math", "Abs"))
            .unwrap();
        manager
            .register(Hook::new("hook3").match_method_name("Decrypt"))
            .unwrap();
        assert_eq!(manager.len(), 3);
    }
}