use crate::{
emulation::{
runtime::hook::{
core::Hook,
types::{HookContext, HookOutcome, PostHookResult, PreHookResult},
},
EmValue, EmulationError, EmulationThread,
},
Result,
};
use std::cmp::Reverse;
/// Registry of emulation hooks, consulted in priority order.
///
/// Dispatch uses the first registered hook whose matchers accept a call site.
/// Because the list is kept ordered by descending priority, that is the
/// highest-priority matching hook.
#[derive(Default)]
pub struct HookManager {
    /// Registered hooks. `register` keeps them sorted by descending priority;
    /// the sort is stable, so equal priorities stay in registration order.
    hooks: Vec<Hook>,
}
impl HookManager {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, hook: Hook) {
self.hooks.push(hook);
self.hooks.sort_by_key(|h| Reverse(h.priority()));
}
#[must_use]
pub fn find_matching<'a>(
&'a self,
context: &HookContext<'_>,
thread: &EmulationThread,
) -> Option<&'a Hook> {
self.hooks.iter().find(|h| h.matches(context, thread))
}
pub fn execute<F>(
&self,
context: &HookContext<'_>,
thread: &mut EmulationThread,
execute_original: F,
) -> Result<HookOutcome>
where
F: FnOnce(&mut EmulationThread) -> Option<EmValue>,
{
let Some(hook) = self.find_matching(context, thread) else {
return Ok(HookOutcome::NoMatch);
};
let pre_result = hook.execute_pre(context, thread);
match pre_result {
Some(PreHookResult::Bypass(value)) => {
return Ok(HookOutcome::Handled(value));
}
Some(PreHookResult::Error(msg)) => {
return Err(EmulationError::HookError(format!(
"Hook '{}' pre-hook error: {}",
hook.name(),
msg
))
.into());
}
Some(PreHookResult::Continue) | None => {
}
}
let original_result = execute_original(thread);
match hook.execute_post(context, thread, original_result.as_ref()) {
Some(PostHookResult::Replace(new_value)) => Ok(HookOutcome::Handled(new_value)),
Some(PostHookResult::Error(msg)) => Err(EmulationError::HookError(format!(
"Hook '{}' post-hook error: {}",
hook.name(),
msg
))
.into()),
Some(PostHookResult::Keep) => Ok(HookOutcome::Handled(original_result)),
None => {
if original_result.is_none() {
Ok(HookOutcome::NoMatch)
} else {
Ok(HookOutcome::Handled(original_result))
}
}
}
}
#[must_use]
pub fn len(&self) -> usize {
self.hooks.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.hooks.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = &Hook> {
self.hooks.iter()
}
}
impl std::fmt::Debug for HookManager {
    /// Renders a compact summary, `HookManager { hook_count: N }`. The hooks
    /// themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hook_count = self.hooks.len();
        let mut summary = f.debug_struct("HookManager");
        summary.field("hook_count", &hook_count);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        emulation::HookPriority,
        metadata::{token::Token, typesystem::PointerSize},
        test::emulation::create_test_thread,
    };

    /// Builds the `System.String::<method>` call-site context (64-bit pointers)
    /// that every dispatch test below runs against.
    fn string_context(method: &str) -> HookContext<'_> {
        HookContext::new(
            Token::new(0x06000001),
            "System",
            "String",
            method,
            PointerSize::Bit64,
        )
    }

    #[test]
    fn test_hook_manager_empty() {
        let manager = HookManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_hook_manager_registration() {
        let mut manager = HookManager::new();
        manager.register(Hook::new("hook1").match_method_name("Method1"));
        manager.register(Hook::new("hook2").match_method_name("Method2"));
        assert_eq!(manager.len(), 2);
        assert!(!manager.is_empty());
    }

    #[test]
    fn test_hook_manager_priority_sorting() {
        let mut manager = HookManager::new();
        manager.register(
            Hook::new("low")
                .with_priority(HookPriority::LOW)
                .match_method_name("Test"),
        );
        manager.register(
            Hook::new("high")
                .with_priority(HookPriority::HIGH)
                .match_method_name("Test"),
        );
        manager.register(
            Hook::new("normal")
                .with_priority(HookPriority::NORMAL)
                .match_method_name("Test"),
        );
        let names: Vec<_> = manager.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["high", "normal", "low"]);
    }

    /// Hooks with equal priority must stay in registration order, so the
    /// first-registered one wins dispatch.
    #[test]
    fn test_hook_manager_equal_priority_keeps_registration_order() {
        let mut manager = HookManager::new();
        for name in ["first", "second", "third"] {
            manager.register(
                Hook::new(name)
                    .with_priority(HookPriority::NORMAL)
                    .match_method_name("Test"),
            );
        }
        manager.register(
            Hook::new("urgent")
                .with_priority(HookPriority::HIGH)
                .match_method_name("Test"),
        );
        let names: Vec<_> = manager.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["urgent", "first", "second", "third"]);
    }

    #[test]
    fn test_execute_no_match() {
        let manager = HookManager::new();
        let mut thread = create_test_thread();
        let context = string_context("Concat");
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(100)))
            .unwrap();
        assert!(matches!(outcome, HookOutcome::NoMatch));
    }

    /// A registered hook for a different method must not fire, and the
    /// original must not be run on the caller's behalf.
    #[test]
    fn test_execute_non_matching_hook_skips_original() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("other-method")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(42)))),
        );
        let context = string_context("Concat");
        let mut original_called = false;
        let outcome = manager
            .execute(&context, &mut thread, |_| {
                original_called = true;
                Some(EmValue::I32(100))
            })
            .unwrap();
        assert!(matches!(outcome, HookOutcome::NoMatch));
        assert!(!original_called);
    }

    #[test]
    fn test_execute_pre_hook_bypass() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("bypass-test")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Bypass(Some(EmValue::I32(42)))),
        );
        let context = string_context("Test");
        let mut original_called = false;
        let outcome = manager
            .execute(&context, &mut thread, |_| {
                original_called = true;
                Some(EmValue::I32(999))
            })
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(42)))
        ));
        assert!(!original_called);
    }

    #[test]
    fn test_execute_pre_hook_continue_then_post_hook() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("continue-then-modify")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Continue)
                .post(|_ctx, _thread, result| {
                    if let Some(EmValue::I32(v)) = result {
                        PostHookResult::Replace(Some(EmValue::I32(v * 2)))
                    } else {
                        PostHookResult::Keep
                    }
                }),
        );
        let context = string_context("Test");
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(50)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(100)))
        ));
    }

    /// With no post-hook installed, the original's value is passed through.
    #[test]
    fn test_execute_without_post_hook_returns_original() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("pre-only")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Continue),
        );
        let context = string_context("Test");
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(7)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(7)))
        ));
    }

    #[test]
    fn test_execute_pre_hook_error() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("error-test")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Error("test error".to_string())),
        );
        let context = string_context("Test");
        let result = manager.execute(&context, &mut thread, |_| Some(EmValue::I32(100)));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("test error"));
        assert!(err_msg.contains("error-test"));
    }

    #[test]
    fn test_execute_post_hook_keep() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("post-keep")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Continue)
                .post(|_ctx, _thread, _result| PostHookResult::Keep),
        );
        let context = string_context("Test");
        let outcome = manager
            .execute(&context, &mut thread, |_| Some(EmValue::I32(123)))
            .unwrap();
        assert!(matches!(
            outcome,
            HookOutcome::Handled(Some(EmValue::I32(123)))
        ));
    }

    #[test]
    fn test_execute_post_hook_error() {
        let mut manager = HookManager::new();
        let mut thread = create_test_thread();
        manager.register(
            Hook::new("post-error")
                .match_name("System", "String", "Test")
                .pre(|_ctx, _thread| PreHookResult::Continue)
                .post(|_ctx, _thread, _result| PostHookResult::Error("post error".to_string())),
        );
        let context = string_context("Test");
        let result = manager.execute(&context, &mut thread, |_| Some(EmValue::I32(100)));
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("post error"));
        assert!(err_msg.contains("post-error"));
    }
}