use std::sync::Arc;
use crate::{
emulation::{runtime::hook::types::HookContext, EmulationThread},
metadata::typesystem::CilFlavor,
};
/// Unsized closure type used for predicates that are evaluated at match time.
///
/// A value of this type is callable with a [`HookContext`] and an
/// [`EmulationThread`], returns a `bool`, and is required to be `Send + Sync`.
/// Because it is a `dyn` type it is always used behind a pointer; in this file
/// `RuntimeMatcher` holds it inside an `Arc`.
pub type RuntimePredicate = dyn Fn(&HookContext<'_>, &EmulationThread) -> bool + Send + Sync;
/// Common interface implemented by every matcher in this file.
///
/// A matcher answers a yes/no question about a [`HookContext`] (optionally
/// consulting the [`EmulationThread`]) and can describe itself as text.
/// Implementors must be `Send + Sync`.
pub trait HookMatcher: Send + Sync {
/// Returns `true` when `context` (and, if the implementation uses it,
/// `thread`) satisfies this matcher's criteria.
fn matches(&self, context: &HookContext<'_>, thread: &EmulationThread) -> bool;
/// Returns a human-readable summary of the criteria this matcher checks.
fn description(&self) -> String;
}
/// Matcher configured with an optional namespace, type name and method name.
///
/// Each criterion starts out unset; the builder methods fill them in one at a
/// time, and [`NameMatcher::full`] sets all three at once.
#[derive(Debug, Default, Clone)]
pub struct NameMatcher {
    /// Expected namespace, or `None` when the namespace is not constrained.
    namespace: Option<String>,
    /// Expected type name, or `None` when the type name is not constrained.
    type_name: Option<String>,
    /// Expected method name, or `None` when the method name is not constrained.
    method_name: Option<String>,
}

impl NameMatcher {
    /// Creates a matcher with no criteria set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            namespace: None,
            type_name: None,
            method_name: None,
        }
    }

    /// Returns the matcher with its namespace criterion replaced by `ns`.
    #[must_use]
    pub fn namespace(self, ns: impl Into<String>) -> Self {
        Self {
            namespace: Some(ns.into()),
            ..self
        }
    }

    /// Returns the matcher with its type-name criterion replaced by `name`.
    #[must_use]
    pub fn type_name(self, name: impl Into<String>) -> Self {
        Self {
            type_name: Some(name.into()),
            ..self
        }
    }

    /// Returns the matcher with its method-name criterion replaced by `name`.
    #[must_use]
    pub fn method_name(self, name: impl Into<String>) -> Self {
        Self {
            method_name: Some(name.into()),
            ..self
        }
    }

    /// Creates a matcher with namespace, type name and method name all set.
    #[must_use]
    pub fn full(
        namespace: impl Into<String>,
        type_name: impl Into<String>,
        method_name: impl Into<String>,
    ) -> Self {
        Self::new()
            .namespace(namespace)
            .type_name(type_name)
            .method_name(method_name)
    }
}
impl HookMatcher for NameMatcher {
    /// Accepts the context when every criterion that is set equals the
    /// corresponding name in `context`; unset criteria are ignored. The thread
    /// is not consulted.
    fn matches(&self, context: &HookContext<'_>, _thread: &EmulationThread) -> bool {
        if let Some(expected) = &self.namespace {
            if expected != context.namespace {
                return false;
            }
        }
        if let Some(expected) = &self.type_name {
            if expected != context.type_name {
                return false;
            }
        }
        if let Some(expected) = &self.method_name {
            if expected != context.method_name {
                return false;
            }
        }
        true
    }

    /// Lists the set criteria as `key=value` pairs joined by `", "`, in the
    /// order namespace, type, method; yields `"any"` when nothing is set.
    fn description(&self) -> String {
        let parts: Vec<String> = self
            .namespace
            .iter()
            .map(|ns| format!("namespace={ns}"))
            .chain(self.type_name.iter().map(|t| format!("type={t}")))
            .chain(self.method_name.iter().map(|m| format!("method={m}")))
            .collect();

        if parts.is_empty() {
            "any".to_string()
        } else {
            parts.join(", ")
        }
    }
}
/// Stateless matcher that accepts a context exactly when its `is_internal`
/// flag is set.
#[derive(Debug, Default, Clone)]
pub struct InternalMethodMatcher;

impl HookMatcher for InternalMethodMatcher {
    /// Returns the context's `is_internal` flag; the thread is not consulted.
    fn matches(&self, context: &HookContext<'_>, _thread: &EmulationThread) -> bool {
        context.is_internal
    }

    /// Always yields the fixed text `"internal method only"`.
    fn description(&self) -> String {
        String::from("internal method only")
    }
}
/// Matcher configured with an optional parameter-type list and an optional
/// return type, both expressed as [`CilFlavor`] values.
#[derive(Debug, Default, Clone)]
pub struct SignatureMatcher {
    /// Expected parameter types in order, or `None` when parameters are not
    /// constrained.
    param_types: Option<Vec<CilFlavor>>,
    /// Expected return type, or `None` when the return type is not constrained.
    return_type: Option<CilFlavor>,
}

impl SignatureMatcher {
    /// Creates a matcher with neither parameters nor return type constrained.
    #[must_use]
    pub fn new() -> Self {
        Self {
            param_types: None,
            return_type: None,
        }
    }

    /// Returns the matcher with its expected parameter list replaced by `types`.
    #[must_use]
    pub fn params(self, types: Vec<CilFlavor>) -> Self {
        Self {
            param_types: Some(types),
            ..self
        }
    }

    /// Returns the matcher with its expected return type replaced by
    /// `return_type`.
    #[must_use]
    pub fn returns(self, return_type: CilFlavor) -> Self {
        Self {
            return_type: Some(return_type),
            ..self
        }
    }
}
impl HookMatcher for SignatureMatcher {
    /// Accepts the context when every configured part of the signature agrees
    /// with it.
    ///
    /// A configured parameter list must have the same length as the context's
    /// list and be element-wise equal; a configured return type must equal the
    /// context's. If a criterion is configured but the context carries no
    /// corresponding information, the result is `false`. The thread is not
    /// consulted.
    fn matches(&self, context: &HookContext<'_>, _thread: &EmulationThread) -> bool {
        let params_ok = match (&self.param_types, context.param_types) {
            (None, _) => true,
            (Some(_), None) => false,
            (Some(expected), Some(actual)) => {
                expected.len() == actual.len()
                    && expected
                        .iter()
                        .zip(actual.iter())
                        .all(|(want, got)| want == got)
            }
        };
        if !params_ok {
            return false;
        }

        match (&self.return_type, &context.return_type) {
            (None, _) => true,
            (Some(_), None) => false,
            (Some(expected), Some(actual)) => expected == actual,
        }
    }

    /// Lists the configured criteria as `params=…` and `returns=…` (using the
    /// `Debug` rendering of the values) joined by `", "`; yields
    /// `"any signature"` when nothing is configured.
    fn description(&self) -> String {
        let parts: Vec<String> = self
            .param_types
            .iter()
            .map(|params| format!("params={params:?}"))
            .chain(self.return_type.iter().map(|ret| format!("returns={ret:?}")))
            .collect();

        if parts.is_empty() {
            "any signature".to_string()
        } else {
            parts.join(", ")
        }
    }
}
/// Matcher that delegates the decision to a caller-supplied closure.
///
/// The closure is held in an [`Arc`], so cloning the matcher shares the same
/// predicate rather than copying it. The description is supplied by the caller
/// because the closure itself cannot be inspected.
#[derive(Clone)]
pub struct RuntimeMatcher {
    /// Predicate invoked with the hook context and the emulation thread.
    predicate: Arc<RuntimePredicate>,
    /// Caller-provided text describing what the predicate checks.
    description: String,
}

// `Debug` is written by hand because the stored closure has no `Debug`
// implementation; only the description is printed and the predicate is elided.
impl std::fmt::Debug for RuntimeMatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RuntimeMatcher")
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}

impl RuntimeMatcher {
    /// Creates a matcher from a description and a predicate.
    ///
    /// # Arguments
    ///
    /// * `description` - Text describing what `predicate` checks.
    /// * `predicate` - Closure called with the hook context and the emulation
    ///   thread; it must be `Send + Sync + 'static` because it is stored in an
    ///   `Arc` for the lifetime of the matcher.
    #[must_use]
    pub fn new<F>(description: impl Into<String>, predicate: F) -> Self
    where
        F: Fn(&HookContext<'_>, &EmulationThread) -> bool + Send + Sync + 'static,
    {
        Self {
            predicate: Arc::new(predicate),
            description: description.into(),
        }
    }
}
impl HookMatcher for RuntimeMatcher {
fn matches(&self, context: &HookContext<'_>, thread: &EmulationThread) -> bool {
(self.predicate)(context, thread)
}
fn description(&self) -> String {
self.description.clone()
}
}
/// Matcher configured with an optional DLL name and an optional function name.
///
/// The DLL name is stored in normalized form: lowercased, with any trailing
/// `.dll` removed.
#[derive(Clone, Debug, Default)]
pub struct NativeMethodMatcher {
    /// Normalized expected DLL name, or `None` when the DLL is not constrained.
    dll_name: Option<String>,
    /// Expected function name (stored as given), or `None` when the function
    /// is not constrained.
    function_name: Option<String>,
}

impl NativeMethodMatcher {
    /// Creates a matcher with neither DLL nor function constrained.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the expected DLL name.
    ///
    /// The name is lowercased and a trailing `.dll` is stripped, so
    /// `"Kernel32.DLL"`, `"kernel32.dll"` and `"kernel32"` all configure the
    /// same matcher.
    #[must_use]
    pub fn dll(mut self, dll: impl Into<String>) -> Self {
        // Lowercasing first makes the suffix check case-insensitive, so a
        // single trim of ".dll" covers every capitalisation of the extension.
        let mut normalized = dll.into().to_lowercase();
        let stem_len = normalized.trim_end_matches(".dll").len();
        // Shorten in place instead of allocating a second string for the stem.
        normalized.truncate(stem_len);
        self.dll_name = Some(normalized);
        self
    }

    /// Sets the expected function name; it is stored exactly as given.
    #[must_use]
    pub fn function(mut self, function: impl Into<String>) -> Self {
        self.function_name = Some(function.into());
        self
    }

    /// Creates a matcher with both the DLL name (normalized as in
    /// [`NativeMethodMatcher::dll`]) and the function name set.
    #[must_use]
    pub fn full(dll: impl Into<String>, function: impl Into<String>) -> Self {
        Self::new().dll(dll).function(function)
    }
}
impl HookMatcher for NativeMethodMatcher {
    /// Accepts the context only when its `is_native` flag is set and every
    /// configured criterion agrees with it.
    ///
    /// The context's DLL name is lowercased and stripped of a trailing `.dll`
    /// before being compared with the stored (already normalized) name; a
    /// context without a DLL name is compared as the empty string. The function
    /// name is compared exactly against the context's method name. The thread
    /// is not consulted.
    fn matches(&self, context: &HookContext<'_>, _thread: &EmulationThread) -> bool {
        if !context.is_native {
            return false;
        }
        if let Some(expected_dll) = &self.dll_name {
            // Lowercasing first makes the suffix check case-insensitive, so a
            // single trim of ".dll" is enough; comparing against the trimmed
            // slice avoids allocating a second string per call.
            let actual_dll = context
                .dll_name
                .map(|d| d.to_lowercase())
                .unwrap_or_default();
            if expected_dll.as_str() != actual_dll.trim_end_matches(".dll") {
                return false;
            }
        }
        if let Some(expected_fn) = &self.function_name {
            if expected_fn != context.method_name {
                return false;
            }
        }
        true
    }

    /// Yields `"native"` followed by `dll=…` and `function=…` for each
    /// configured criterion, joined by `", "`.
    fn description(&self) -> String {
        let mut parts = vec!["native".to_string()];
        parts.extend(self.dll_name.iter().map(|dll| format!("dll={dll}")));
        parts.extend(
            self.function_name
                .iter()
                .map(|func| format!("function={func}")),
        );
        parts.join(", ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{token::Token, typesystem::PointerSize};

    /// Builds a context for `System.String::Concat` with token `0x06000001`
    /// and a 64-bit pointer size. None of the tests in this module call it yet.
    fn create_test_context<'a>() -> HookContext<'a> {
        let token = Token::new(0x06000001);
        HookContext::new(token, "System", "String", "Concat", PointerSize::Bit64)
    }

    /// A fully specified name matcher lists all three criteria in order.
    #[test]
    fn test_name_matcher_full() {
        let description = NameMatcher::full("System", "String", "Concat").description();
        assert_eq!(
            description,
            "namespace=System, type=String, method=Concat"
        );
    }

    /// Only the criteria that were set appear in the description.
    #[test]
    fn test_name_matcher_partial() {
        let description = NameMatcher::new().method_name("Decrypt").description();
        assert_eq!(description, "method=Decrypt");
    }

    /// A name matcher without criteria describes itself as "any".
    #[test]
    fn test_name_matcher_empty() {
        let description = NameMatcher::new().description();
        assert_eq!(description, "any");
    }

    /// The internal-method matcher has a fixed description.
    #[test]
    fn test_internal_method_matcher_description() {
        let description = InternalMethodMatcher.description();
        assert_eq!(description, "internal method only");
    }

    /// Both configured parts of a signature show up in the description.
    #[test]
    fn test_signature_matcher_description() {
        let description = SignatureMatcher::new()
            .params(vec![CilFlavor::I4, CilFlavor::I4])
            .returns(CilFlavor::I4)
            .description();
        for fragment in ["params=", "returns="] {
            assert!(description.contains(fragment));
        }
    }
}