use crate::error::{Error, Result};
use crate::{Function, GenericDetour};
use std::marker::Tuple;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::{mem, ptr};
/// A detour whose state is created lazily at runtime and accessed through `&self`.
///
/// Its constructor is a `const fn`, so a value can be placed in a `static`.
/// All mutable state sits behind atomic pointers, and both pointers are null
/// until `initialize` succeeds.
pub struct StaticDetour<T: Function> {
    /// Value passed as the second argument to `GenericDetour::new` during
    /// initialization.
    ffi: T,
    /// Heap-allocated `GenericDetour`, or null while uninitialized.
    detour: AtomicPtr<GenericDetour<T>>,
    /// Heap-allocated user closure handed out by `__detour`, or null when no
    /// closure has been installed yet.
    closure: AtomicPtr<Box<dyn Fn<T::Arguments, Output = T::Output>>>,
}
impl<T: Function> StaticDetour<T> {
#[doc(hidden)]
pub const fn __new(ffi: T) -> Self {
StaticDetour {
closure: AtomicPtr::new(ptr::null_mut()),
detour: AtomicPtr::new(ptr::null_mut()),
ffi,
}
}
pub unsafe fn initialize<D>(&self, target: T, closure: D) -> Result<&Self>
where
D: Fn<T::Arguments, Output = T::Output> + Send + 'static,
<T as Function>::Arguments: Tuple,
{
let mut detour = Box::new(GenericDetour::new(target, self.ffi)?);
if self
.detour
.compare_exchange(
ptr::null_mut(),
&mut *detour,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_err()
{
Err(Error::AlreadyInitialized)?;
}
self.set_detour(closure);
mem::forget(detour);
Ok(self)
}
pub unsafe fn enable(&self) -> Result<()> {
self
.detour
.load(Ordering::SeqCst)
.as_ref()
.ok_or(Error::NotInitialized)?
.enable()
}
pub unsafe fn disable(&self) -> Result<()> {
self
.detour
.load(Ordering::SeqCst)
.as_ref()
.ok_or(Error::NotInitialized)?
.disable()
}
pub fn is_enabled(&self) -> bool {
unsafe { self.detour.load(Ordering::SeqCst).as_ref() }
.map(|detour| detour.is_enabled())
.unwrap_or(false)
}
pub fn set_detour<C>(&self, closure: C)
where
C: Fn<T::Arguments, Output = T::Output> + Send + 'static,
<T as Function>::Arguments: Tuple,
{
let previous = self
.closure
.swap(Box::into_raw(Box::new(Box::new(closure))), Ordering::SeqCst);
if !previous.is_null() {
mem::drop(unsafe { Box::from_raw(previous) });
}
}
pub(crate) fn trampoline(&self) -> Result<&()> {
Ok(
unsafe { self.detour.load(Ordering::SeqCst).as_ref() }
.ok_or(Error::NotInitialized)?
.trampoline(),
)
}
#[doc(hidden)]
pub fn __detour(&self) -> &dyn Fn<T::Arguments, Output = T::Output>
where
<T as Function>::Arguments: Tuple,
{
unsafe { self.closure.load(Ordering::SeqCst).as_ref() }
.ok_or(Error::NotInitialized)
.expect("retrieving detour closure")
}
}
impl<T: Function> Drop for StaticDetour<T> {
    /// Frees the heap allocations referenced by the atomic pointers.
    ///
    /// `&mut self` guarantees exclusive access, so the pointers are read
    /// through `AtomicPtr::get_mut` instead of atomic operations. The detour
    /// is released before the closure, so the hook is gone before the handler
    /// it may route to is freed.
    /// NOTE(review): this assumes dropping `GenericDetour` removes the hook;
    /// confirm.
    fn drop(&mut self) {
        let detour = mem::replace(self.detour.get_mut(), ptr::null_mut());
        if !detour.is_null() {
            // SAFETY: non-null values are only ever stored from heap `Box`
            // allocations whose ownership was relinquished to this struct.
            mem::drop(unsafe { Box::from_raw(detour) });
        }
        let closure = mem::replace(self.closure.get_mut(), ptr::null_mut());
        if !closure.is_null() {
            // SAFETY: non-null values are only ever stored from
            // `Box::into_raw`, with ownership held by this struct.
            mem::drop(unsafe { Box::from_raw(closure) });
        }
    }
}