use std::sync::Arc;
use crate::{
emulation::{
runtime::hook::{
matcher::{
HookMatcher, InternalMethodMatcher, NameMatcher, RuntimeMatcher, SignatureMatcher,
},
types::{
HookContext, HookPriority, PostHookFn, PostHookResult, PreHookFn, PreHookResult,
},
},
EmValue, EmulationThread,
},
metadata::typesystem::CilFlavor,
};
/// A named hook made of a priority, a list of matchers, and optional pre/post handlers.
///
/// Instances are assembled builder-style: start from [`Hook::new`], then chain the
/// `with_*`, `match_*`, `pre` and `post` methods, each of which consumes and returns
/// the hook.
pub struct Hook {
/// Label given at construction; returned by [`Hook::name`] and shown in `Debug` output.
name: String,
/// Priority of this hook; [`Hook::new`] starts it at `HookPriority::NORMAL`.
// NOTE(review): how the priority orders hooks is decided by the caller, not in this file.
priority: HookPriority,
/// Matchers that must ALL accept a call for [`Hook::matches`] to return `true`.
/// An empty list never matches.
matchers: Vec<Box<dyn HookMatcher>>,
/// Handler run by [`Hook::execute_pre`]; `None` until [`Hook::pre`] is called.
pre_fn: Option<PreHookFn>,
/// Handler run by [`Hook::execute_post`]; `None` until [`Hook::post`] is called.
post_fn: Option<PostHookFn>,
}
impl Hook {
#[must_use]
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
priority: HookPriority::NORMAL,
matchers: Vec::new(),
pre_fn: None,
post_fn: None,
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn priority(&self) -> HookPriority {
self.priority
}
#[must_use]
pub fn with_priority(mut self, priority: HookPriority) -> Self {
self.priority = priority;
self
}
#[must_use]
pub fn add_matcher<M: HookMatcher + 'static>(mut self, matcher: M) -> Self {
self.matchers.push(Box::new(matcher));
self
}
#[must_use]
pub fn match_name(
self,
namespace: impl Into<String>,
type_name: impl Into<String>,
method_name: impl Into<String>,
) -> Self {
self.add_matcher(NameMatcher::full(namespace, type_name, method_name))
}
#[must_use]
pub fn match_method_name(self, method_name: impl Into<String>) -> Self {
self.add_matcher(NameMatcher::new().method_name(method_name))
}
#[must_use]
pub fn match_type_name(self, type_name: impl Into<String>) -> Self {
self.add_matcher(NameMatcher::new().type_name(type_name))
}
#[must_use]
pub fn match_internal_method(self) -> Self {
self.add_matcher(InternalMethodMatcher)
}
#[must_use]
pub fn match_native(self, dll: impl Into<String>, function: impl Into<String>) -> Self {
self.add_matcher(super::matcher::NativeMethodMatcher::full(dll, function))
}
#[must_use]
pub fn match_native_dll(self, dll: impl Into<String>) -> Self {
self.add_matcher(super::matcher::NativeMethodMatcher::new().dll(dll))
}
#[must_use]
pub fn match_signature(self, params: Vec<CilFlavor>, return_type: Option<CilFlavor>) -> Self {
let mut matcher = SignatureMatcher::new().params(params);
if let Some(ret) = return_type {
matcher = matcher.returns(ret);
}
self.add_matcher(matcher)
}
#[must_use]
pub fn match_runtime<F>(self, description: impl Into<String>, predicate: F) -> Self
where
F: Fn(&HookContext<'_>, &EmulationThread) -> bool + Send + Sync + 'static,
{
self.add_matcher(RuntimeMatcher::new(description, predicate))
}
#[must_use]
pub fn pre<F>(mut self, handler: F) -> Self
where
F: Fn(&HookContext<'_>, &mut EmulationThread) -> PreHookResult + Send + Sync + 'static,
{
self.pre_fn = Some(Arc::new(handler));
self
}
#[must_use]
pub fn post<F>(mut self, handler: F) -> Self
where
F: Fn(&HookContext<'_>, &mut EmulationThread, Option<&EmValue>) -> PostHookResult
+ Send
+ Sync
+ 'static,
{
self.post_fn = Some(Arc::new(handler));
self
}
#[must_use]
pub fn matches(&self, context: &HookContext<'_>, thread: &EmulationThread) -> bool {
if self.matchers.is_empty() {
return false;
}
self.matchers.iter().all(|m| m.matches(context, thread))
}
pub fn execute_pre(
&self,
context: &HookContext<'_>,
thread: &mut EmulationThread,
) -> Option<PreHookResult> {
self.pre_fn.as_ref().map(|hook| hook(context, thread))
}
pub fn execute_post(
&self,
context: &HookContext<'_>,
thread: &mut EmulationThread,
result: Option<&EmValue>,
) -> Option<PostHookResult> {
self.post_fn
.as_ref()
.map(|hook| hook(context, thread, result))
}
#[must_use]
pub fn has_pre_hook(&self) -> bool {
self.pre_fn.is_some()
}
#[must_use]
pub fn has_post_hook(&self) -> bool {
self.post_fn.is_some()
}
}
/// Manual `Debug` implementation: instead of the matchers and handlers themselves,
/// the output reports the matcher count and whether each handler is present.
impl std::fmt::Debug for Hook {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Hook");
        out.field("name", &self.name);
        out.field("priority", &self.priority);
        out.field("matcher_count", &self.matchers.len());
        out.field("has_pre_hook", &self.has_pre_hook());
        out.field("has_post_hook", &self.has_post_hook());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records name and priority and installs no handlers on its own.
    #[test]
    fn test_hook_builder() {
        let prioritized = Hook::new("test-hook").with_priority(HookPriority::HIGH);
        let built = prioritized.match_name("System", "String", "Concat");

        assert_eq!(built.name(), "test-hook");
        assert_eq!(built.priority(), HookPriority::HIGH);
        assert!(!built.has_pre_hook());
        assert!(!built.has_post_hook());
    }

    /// Installing a pre-handler sets only the pre-hook slot.
    #[test]
    fn test_hook_with_pre_handler() {
        let base = Hook::new("test-hook").match_method_name("Test");
        let with_pre = base.pre(|_ctx, _thread| PreHookResult::Continue);

        assert!(with_pre.has_pre_hook());
        assert!(!with_pre.has_post_hook());
    }

    /// Installing a post-handler sets only the post-hook slot.
    #[test]
    fn test_hook_with_post_handler() {
        let base = Hook::new("test-hook").match_method_name("Test");
        let with_post = base.post(|_ctx, _thread, _result| PostHookResult::Keep);

        assert!(!with_post.has_pre_hook());
        assert!(with_post.has_post_hook());
    }

    /// A freshly created hook has no matchers.
    // NOTE(review): this checks only the precondition (an empty matcher list); it does
    // not call `Hook::matches`, which needs a `HookContext` and an `EmulationThread`.
    #[test]
    fn test_empty_matchers_dont_match() {
        let bare = Hook::new("empty");
        assert!(bare.matchers.is_empty());
    }
}