#[allow(unused_imports)]
use crate::injector_core::common::*;
use crate::injector_core::internal::*;
pub use crate::interface::func_ptr::FuncPtr;
pub use crate::interface::macros::__assert_future_output;
pub use crate::interface::macros::__type_id_of_val;
pub use crate::interface::verifier::CallCountVerifier;
use std::future::Future;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::RwLock;
use std::sync::RwLockReadGuard;
use std::sync::RwLockWriteGuard;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
use crate::injector_core::thread_local_registry::ThreadRegistration;
/// Returns `sig` with every occurrence of the anonymous-lifetime reference prefix
/// `&'_ ` (ampersand, `'_`, one space) collapsed to a bare `&`.
///
/// Used so that two signature strings that differ only in whether the anonymous
/// lifetime is spelled out compare equal. Named lifetimes such as `&'a ` are left
/// untouched.
fn normalize_signature(sig: &str) -> String {
    // Splitting on the pattern and re-joining with "&" yields the same result as a
    // non-overlapping, left-to-right replacement of every match.
    let pieces: Vec<&str> = sig.split("&'_ ").collect();
    pieces.join("&")
}
/// A [`Mutex`] wrapper whose `lock` never surfaces poisoning to the caller.
///
/// If a previous holder panicked while the lock was held, the guard is recovered
/// from the poison error and handed out as usual.
#[allow(dead_code)]
struct NoPoisonMutex<T> {
    /// The wrapped standard-library mutex.
    inner: Mutex<T>,
}

#[allow(dead_code)]
impl<T> NoPoisonMutex<T> {
    /// Wraps `value` in a new, unlocked mutex.
    ///
    /// `const` so the wrapper can be used to initialise a `static`.
    const fn new(value: T) -> Self {
        Self { inner: Mutex::new(value) }
    }

    /// Blocks until the mutex is acquired and returns its guard.
    ///
    /// A poisoned mutex is treated like a healthy one: the guard carried inside
    /// the poison error is extracted and returned.
    fn lock(&self) -> MutexGuard<'_, T> {
        self.inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
/// Process-wide mutex that only exists on architectures other than
/// x86_64/aarch64/arm. `InjectorPP::new`, `InjectorPP::new_global` and
/// `InjectorPP::prevent` each lock it and store the guard in the value they
/// return, so on those architectures only one such value can exist at a time.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
static LOCK_FUNCTION: NoPoisonMutex<()> = NoPoisonMutex::new(());
/// Reader/writer lock shared by every injector. `InjectorPP::new` takes the read
/// side and `InjectorPP::new_global` takes the write side, so a global injector
/// never coexists with any other injector.
static GLOBAL_FAKE_LOCK: RwLock<()> = RwLock::new(());
/// Owner of every fake installed through its builders, plus the lock guards that
/// were acquired when it was created.
///
/// Field order is significant: Rust drops fields in declaration order, so the
/// registrations, guards and verifiers are dropped before the lock guards that
/// follow them.
pub struct InjectorPP {
/// Values returned by the `*_thread_local` installers; only present on
/// x86_64/aarch64/arm.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
registrations: Vec<ThreadRegistration>,
/// Values returned by the `*_guard` installers.
guards: Vec<PatchGuard>,
/// Verifiers handed to `WhenCalledBuilder::will_execute`, kept for as long as
/// the injector lives.
verifiers: Vec<CallCountVerifier>,
/// Read or write guard on `GLOBAL_FAKE_LOCK`, held until the injector is dropped.
_rw_guard: RwGuard,
/// `true` when built by `new_global`; the builders then always use the
/// `*_guard` installers instead of the thread-local ones.
use_global: bool,
/// Raw-pointer marker: `*const ()` is neither `Send` nor `Sync`, so the injector
/// is neither on these architectures.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
_not_send: PhantomData<*const ()>,
/// Guard on `LOCK_FUNCTION`, held until the injector is dropped; only present on
/// architectures other than x86_64/aarch64/arm.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
_lock: MutexGuard<'static, ()>,
}
/// The guard on `GLOBAL_FAKE_LOCK` that an injector keeps alive. The payloads are
/// never read; they exist only to be dropped together with the injector.
#[allow(dead_code)]
enum RwGuard {
/// No guard held. NOTE(review): not constructed anywhere in the visible code —
/// confirm whether it is still needed.
None,
/// Shared guard, as taken by `InjectorPP::new`.
Read(RwLockReadGuard<'static, ()>),
/// Exclusive guard, as taken by `InjectorPP::new_global`.
Write(RwLockWriteGuard<'static, ()>),
}
impl InjectorPP {
    /// Creates an injector that is not in global mode (`use_global == false`).
    ///
    /// The injector holds the read side of `GLOBAL_FAKE_LOCK` until it is dropped,
    /// so it can coexist with other injectors made by `new` but never with one made
    /// by [`InjectorPP::new_global`]. On architectures other than
    /// x86_64/aarch64/arm it first acquires `LOCK_FUNCTION` and keeps that guard as
    /// well. A poisoned `GLOBAL_FAKE_LOCK` is recovered instead of causing a panic.
    pub fn new() -> Self {
        // Acquisition order matters: the function lock is taken before the RwLock.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
        let function_lock = LOCK_FUNCTION.lock();
        let shared = GLOBAL_FAKE_LOCK
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        Self {
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
            registrations: Vec::new(),
            guards: Vec::new(),
            verifiers: Vec::new(),
            _rw_guard: RwGuard::Read(shared),
            use_global: false,
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
            _not_send: PhantomData,
            #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
            _lock: function_lock,
        }
    }

    /// Creates an injector in global mode (`use_global == true`).
    ///
    /// Takes the write side of `GLOBAL_FAKE_LOCK`, which blocks until every other
    /// injector has been dropped and keeps new ones from being created while this
    /// one is alive. Per the `RwLock::write` documentation, calling this while the
    /// current thread still owns another injector may deadlock or panic. On
    /// architectures other than x86_64/aarch64/arm `LOCK_FUNCTION` is acquired
    /// first. A poisoned `GLOBAL_FAKE_LOCK` is recovered instead of causing a panic.
    pub fn new_global() -> Self {
        // Acquisition order matters: the function lock is taken before the RwLock.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
        let function_lock = LOCK_FUNCTION.lock();
        let exclusive = GLOBAL_FAKE_LOCK
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        Self {
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
            registrations: Vec::new(),
            guards: Vec::new(),
            verifiers: Vec::new(),
            _rw_guard: RwGuard::Write(exclusive),
            use_global: true,
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
            _not_send: PhantomData,
            #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
            _lock: function_lock,
        }
    }

    /// Returns a [`Preventer`].
    ///
    /// On architectures other than x86_64/aarch64/arm the preventer owns
    /// `LOCK_FUNCTION`, so `new`, `new_global` and further `prevent` calls block
    /// until it is dropped. On x86_64/aarch64/arm no lock is taken and the value
    /// only carries a `!Send` marker.
    pub fn prevent() -> Preventer {
        Preventer {
            #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
            _lock: LOCK_FUNCTION.lock(),
            #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
            _not_send: PhantomData,
        }
    }

    /// Starts configuring a fake for `func`.
    ///
    /// The returned builder remembers `func`'s signature string and type id so the
    /// checked `will_*` methods can compare them against the replacement.
    pub fn when_called(&mut self, func: FuncPtr) -> WhenCalledBuilder<'_> {
        let expected_signature = func.signature;
        let expected_type_id = func.type_id;
        WhenCalledBuilder {
            lib: self,
            when: WhenCalled::new(func.func_ptr_internal),
            expected_signature,
            expected_type_id,
        }
    }

    /// Starts configuring a fake for `func` without recording its signature.
    ///
    /// # Safety
    ///
    /// The builder is created with an empty expected signature and no type id, so
    /// nothing recorded here describes `func`'s real signature.
    /// NOTE(review): the caller is presumably responsible for pairing it with a
    /// compatible replacement — confirm against the crate-level documentation.
    pub unsafe fn when_called_unchecked(&mut self, func: FuncPtr) -> WhenCalledBuilder<'_> {
        WhenCalledBuilder {
            lib: self,
            when: WhenCalled::new(func.func_ptr_internal),
            expected_signature: "",
            expected_type_id: None,
        }
    }

    /// Starts configuring a fake for the future type `F` by targeting
    /// `<F as Future>::poll`.
    ///
    /// Only the second element of `fake_pair` (the expected signature string) is
    /// read; the pinned future serves to select `F` and `T`. No type id is
    /// recorded, so the later signature check is string-based.
    pub fn when_called_async<F, T>(
        &mut self,
        fake_pair: (Pin<&mut F>, &'static str),
    ) -> WhenCalledBuilderAsync<'_>
    where
        F: Future<Output = T>,
    {
        let (_, expected_signature) = fake_pair;
        let poll_fn: fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T> = <F as Future>::poll;
        // SAFETY: the pointer comes from a real `fn` item and the signature string is
        // derived from that same value. NOTE(review): confirm this is all that
        // `FuncPtr::new` requires.
        let poll_ptr =
            unsafe { FuncPtr::new(poll_fn as *const (), std::any::type_name_of_val(&poll_fn)) };
        WhenCalledBuilderAsync {
            lib: self,
            when: WhenCalled::new(poll_ptr.func_ptr_internal),
            expected_signature,
            expected_type_id: None,
        }
    }

    /// Like [`InjectorPP::when_called_async`], but records no expected signature.
    ///
    /// The pinned future argument is ignored at runtime; it only selects `F` and `T`.
    ///
    /// # Safety
    ///
    /// The builder is created with an empty expected signature and no type id.
    /// NOTE(review): the caller is presumably responsible for supplying a
    /// compatible replacement — confirm against the crate-level documentation.
    pub unsafe fn when_called_async_unchecked<F, T>(
        &mut self,
        _: Pin<&mut F>,
    ) -> WhenCalledBuilderAsync<'_>
    where
        F: Future<Output = T>,
    {
        let poll_fn: fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T> = <F as Future>::poll;
        // SAFETY: the pointer comes from a real `fn` item and the signature string is
        // derived from that same value. NOTE(review): confirm this is all that
        // `FuncPtr::new` requires.
        let poll_ptr =
            unsafe { FuncPtr::new(poll_fn as *const (), std::any::type_name_of_val(&poll_fn)) };
        WhenCalledBuilderAsync {
            lib: self,
            when: WhenCalled::new(poll_ptr.func_ptr_internal),
            expected_signature: "",
            expected_type_id: None,
        }
    }
}
impl Default for InjectorPP {
fn default() -> Self {
Self::new()
}
}
/// Value returned by `InjectorPP::prevent`.
///
/// On architectures other than x86_64/aarch64/arm it owns a guard on
/// `LOCK_FUNCTION` for as long as it lives; on those three architectures it holds
/// no lock and only carries a marker.
pub struct Preventer {
/// Guard on `LOCK_FUNCTION`, released when the preventer is dropped.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
_lock: MutexGuard<'static, ()>,
/// Raw-pointer marker: `*const ()` is neither `Send` nor `Sync`, so the
/// preventer is neither on these architectures.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
_not_send: PhantomData<*const ()>,
}
impl Preventer {
/// Always returns `true`; a `Preventer` has no inactive state while it exists.
pub fn is_active(&self) -> bool {
true
}
}
/// Builder returned by `InjectorPP::when_called` / `when_called_unchecked`; one of
/// its `will_*` methods consumes it to install the fake.
pub struct WhenCalledBuilder<'a> {
/// The injector that will own whatever the installer returns.
lib: &'a mut InjectorPP,
/// Handle built from the function being faked.
when: WhenCalled,
/// Signature string of the function being faked; empty when created by
/// `when_called_unchecked`.
expected_signature: &'static str,
/// Type id of the function being faked; `None` when created by
/// `when_called_unchecked`.
expected_type_id: Option<std::any::TypeId>,
}
impl WhenCalledBuilder<'_> {
    /// Installs `target` as the replacement after checking that its signature
    /// matches the function being faked.
    ///
    /// # Panics
    ///
    /// Panics with a "Signature mismatch" message when both type ids are known and
    /// differ, or — when either type id is missing — when the normalized signature
    /// strings differ.
    pub fn will_execute_raw(self, target: FuncPtr) {
        self.assert_signature_matches(&target);
        self.install_replacement(target);
    }

    /// Installs `target` as the replacement without any signature check.
    ///
    /// # Safety
    ///
    /// Nothing verifies that `target` is compatible with the function being faked.
    /// NOTE(review): the caller is presumably responsible for that — confirm
    /// against the crate-level documentation.
    pub unsafe fn will_execute_raw_unchecked(self, target: FuncPtr) {
        self.install_replacement(target);
    }

    /// Stores the verifier half of `fake_pair` in the injector, then installs the
    /// function half through [`WhenCalledBuilder::will_execute_raw`].
    ///
    /// The verifier is stored before the signature check runs.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `will_execute_raw`.
    pub fn will_execute(self, fake_pair: (FuncPtr, CallCountVerifier)) {
        self.lib.verifiers.push(fake_pair.1);
        self.will_execute_raw(fake_pair.0);
    }

    /// Makes the faked function return `value`.
    ///
    /// # Panics
    ///
    /// Panics when the recorded signature, after trimming, does not end in
    /// `-> bool`. This includes builders created by `when_called_unchecked`, whose
    /// recorded signature is empty.
    pub fn will_return_boolean(self, value: bool) {
        let returns_bool = self.expected_signature.trim().ends_with("-> bool");
        assert!(
            returns_bool,
            "Signature mismatch: will_return_boolean requires a function returning bool but got {}",
            self.expected_signature
        );

        if self.lib.use_global {
            let patch = self.when.will_return_boolean_guard(value);
            self.lib.guards.push(patch);
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
        {
            let registration = self.when.will_return_boolean_thread_local(value);
            self.lib.registrations.push(registration);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
        {
            let patch = self.when.will_return_boolean_guard(value);
            self.lib.guards.push(patch);
        }
    }

    /// Panics unless `target`'s signature agrees with the recorded one.
    ///
    /// Type ids are compared when both sides have one; otherwise the normalized
    /// signature strings are compared.
    fn assert_signature_matches(&self, target: &FuncPtr) {
        let mismatch = match (self.expected_type_id, target.type_id) {
            (Some(expected), Some(actual)) => expected != actual,
            _ => {
                normalize_signature(target.signature)
                    != normalize_signature(self.expected_signature)
            }
        };
        if mismatch {
            panic!(
                "Signature mismatch: expected {:?} but got {:?}",
                self.expected_signature, target.signature
            );
        }
    }

    /// Hands `target` to the installer that fits the injector's mode and stores
    /// the result in the injector.
    ///
    /// Global injectors, and every injector on architectures other than
    /// x86_64/aarch64/arm, use `will_execute_guard`; the rest use
    /// `will_execute_thread_local`.
    fn install_replacement(self, target: FuncPtr) {
        let replacement = target.func_ptr_internal;
        if self.lib.use_global {
            let patch = self.when.will_execute_guard(replacement);
            self.lib.guards.push(patch);
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
        {
            let registration = self.when.will_execute_thread_local(replacement);
            self.lib.registrations.push(registration);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
        {
            let patch = self.when.will_execute_guard(replacement);
            self.lib.guards.push(patch);
        }
    }
}
/// Builder returned by `InjectorPP::when_called_async` /
/// `when_called_async_unchecked`; one of its `will_return_async*` methods consumes
/// it to install the fake.
pub struct WhenCalledBuilderAsync<'a> {
/// The injector that will own whatever the installer returns.
lib: &'a mut InjectorPP,
/// Handle built from the `poll` function of the future being faked.
when: WhenCalled,
/// Expected signature string; empty when created by
/// `when_called_async_unchecked`.
expected_signature: &'static str,
/// Expected type id; both visible constructors set this to `None`.
expected_type_id: Option<std::any::TypeId>,
}
impl WhenCalledBuilderAsync<'_> {
    /// Installs `target` as the replacement after checking that its signature
    /// matches the one recorded by the builder.
    ///
    /// # Panics
    ///
    /// Panics with a "Signature mismatch" message when both type ids are known and
    /// differ, or — when either type id is missing — when the normalized signature
    /// strings differ.
    pub fn will_return_async(self, target: FuncPtr) {
        self.assert_signature_matches(&target);
        self.install_replacement(target);
    }

    /// Installs `target` as the replacement without any signature check.
    ///
    /// # Safety
    ///
    /// Nothing verifies that `target` is compatible with the `poll` function being
    /// faked. NOTE(review): the caller is presumably responsible for that —
    /// confirm against the crate-level documentation.
    pub unsafe fn will_return_async_unchecked(self, target: FuncPtr) {
        self.install_replacement(target);
    }

    /// Panics unless `target`'s signature agrees with the recorded one.
    ///
    /// Type ids are compared when both sides have one; otherwise the normalized
    /// signature strings are compared.
    fn assert_signature_matches(&self, target: &FuncPtr) {
        let mismatch = match (self.expected_type_id, target.type_id) {
            (Some(expected), Some(actual)) => expected != actual,
            _ => {
                normalize_signature(target.signature)
                    != normalize_signature(self.expected_signature)
            }
        };
        if mismatch {
            panic!(
                "Signature mismatch: expected {:?} but got {:?}",
                self.expected_signature, target.signature
            );
        }
    }

    /// Hands `target` to the installer that fits the injector's mode and stores
    /// the result in the injector.
    ///
    /// Global injectors, and every injector on architectures other than
    /// x86_64/aarch64/arm, use `will_execute_guard`; the rest use
    /// `will_execute_thread_local`.
    fn install_replacement(self, target: FuncPtr) {
        let replacement = target.func_ptr_internal;
        if self.lib.use_global {
            let patch = self.when.will_execute_guard(replacement);
            self.lib.guards.push(patch);
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
        {
            let registration = self.when.will_execute_thread_local(replacement);
            self.lib.registrations.push(registration);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm")))]
        {
            let patch = self.when.will_execute_guard(replacement);
            self.lib.guards.push(patch);
        }
    }
}