use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use crate::{
error::LoadProcedureError,
function::{FunctionPtr, RawFunctionPtr},
process::{
memory::{ProcessMemoryBuffer, RemoteBoxAllocator},
BorrowedProcess, BorrowedProcessModule, ModuleHandle,
},
rpc::{error::PayloadRpcError, RemoteRawProcedure, Truncate},
utils::ArrayOrVecBuf,
ArgAndResultBufInfo, Syringe,
};
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "rpc-payload")))]
impl Syringe {
    /// Looks up the exported symbol `name` in `module` and wraps it as a typed
    /// `RemotePayloadProcedure<F>`.
    ///
    /// The address is resolved through `get_procedure_address`. When that lookup
    /// yields no address, `Ok(None)` is returned; otherwise the address is
    /// reinterpreted as the payload stub pointer and bundled with a clone of this
    /// syringe's remote allocator and the handle of `module`.
    ///
    /// # Errors
    /// Returns the `LoadProcedureError` produced by `get_procedure_address`, unchanged.
    ///
    /// # Safety
    /// The resolved address is converted into a function pointer without any
    /// verification. The caller must guarantee that the exported symbol really is
    /// a payload procedure stub and that `F` describes the argument and result
    /// types it expects.
    #[allow(rustdoc::broken_intra_doc_links)]
    pub unsafe fn get_payload_procedure<F: PayloadRpcFunctionPtr>(
        &self,
        module: BorrowedProcessModule<'_>,
        name: &str,
    ) -> Result<Option<RemotePayloadProcedure<F>>, LoadProcedureError> {
        let address = self.get_procedure_address(module, name)?;

        Ok(address.map(|procedure| {
            // SAFETY: the caller of this `unsafe fn` vouches that the exported
            // symbol has the payload stub signature (see `# Safety` above).
            let stub = unsafe { RealPayloadRpcFunctionPtr::from_ptr(procedure) };
            RemotePayloadProcedure::new(stub, self.remote_allocator.clone(), module.handle())
        }))
    }
}
/// Marker trait for function pointer types that may be used as the typed
/// signature `F` of a `RemotePayloadProcedure<F>`.
///
/// It adds no methods on top of `FunctionPtr`. In this file it is implemented
/// by the `impl_call!` macro for `fn` and `unsafe fn` pointers whose parameters
/// are `'static + Serialize` and whose return type is `'static + DeserializeOwned`.
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "rpc-payload")))]
pub trait PayloadRpcFunctionPtr: FunctionPtr {}
// Signature of the stub that is actually invoked for a payload procedure:
// an `extern "system"` function taking one pointer to an `ArgAndResultBufInfo`
// and returning nothing. The typed signature `F` is only carried at the type
// level by `RemotePayloadProcedure<F>`.
// NOTE(review): the pointer is wrapped in `Truncate` — presumably so it can be
// narrowed for a target with a smaller pointer width; confirm against `rpc::Truncate`.
type RealPayloadRpcFunctionPtr = extern "system" fn(Truncate<*mut ArgAndResultBufInfo>);
/// A typed handle to a payload procedure located in another process.
///
/// `F` is the logical signature (argument and return types) used by the `call`
/// methods; the wrapped raw procedure always has the fixed stub signature
/// `RealPayloadRpcFunctionPtr`.
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "rpc-payload")))]
#[derive(Debug)]
pub struct RemotePayloadProcedure<F> {
// Untyped remote stub together with the allocator and module handle it was created with.
f: RemoteRawProcedure<RealPayloadRpcFunctionPtr>,
// Records `F` at the type level only; no value of type `F` is stored.
phantom: PhantomData<fn() -> F>,
}
impl<F: FunctionPtr> RemotePayloadProcedure<F> {
    /// Wraps the stub pointer `ptr` into a typed payload procedure handle.
    ///
    /// `remote_allocator` and `module_handle` are forwarded untouched to
    /// `RemoteRawProcedure::new`; the signature `F` is tracked purely through
    /// `PhantomData`.
    pub(crate) fn new(
        ptr: RealPayloadRpcFunctionPtr,
        remote_allocator: RemoteBoxAllocator,
        module_handle: ModuleHandle,
    ) -> Self {
        let f = RemoteRawProcedure::new(ptr, remote_allocator, module_handle);
        Self {
            f,
            phantom: PhantomData,
        }
    }

    /// Returns the process reported by the underlying raw procedure.
    #[must_use]
    pub fn process(&self) -> BorrowedProcess<'_> {
        self.f.process()
    }

    /// Returns the raw function pointer reported by the underlying raw procedure.
    #[must_use]
    pub fn as_raw_ptr(&self) -> RawFunctionPtr {
        self.f.as_raw_ptr()
    }
}
impl<F> RemotePayloadProcedure<F>
where
    F: PayloadRpcFunctionPtr,
    for<'r> F::RefArgs<'r>: Serialize,
    F::Output: DeserializeOwned,
{
    /// Serializes `args`, invokes the remote stub and decodes its reply.
    ///
    /// Steps, in order:
    /// 1. `args` is serialized with bincode into a local buffer.
    /// 2. A buffer of the same size is allocated through the remote allocator
    ///    and the serialized bytes are written into it.
    /// 3. An `ArgAndResultBufInfo` describing that buffer (`is_error: false`) is
    ///    allocated and copied through the remote allocator, and the stub is
    ///    called with a pointer to it.
    /// 4. The info struct is read back; it now describes where the reply lives
    ///    and whether it is an error.
    /// 5. The reply bytes are read either from the argument buffer (when the
    ///    reported address equals it) or from the reported address.
    ///
    /// # Errors
    /// * `PayloadRpcError::RemoteProcedure` with the message bytes sent back when
    ///   the reply has `is_error` set. Bytes that are not valid UTF-8 are
    ///   replaced with U+FFFD instead of being trusted blindly.
    /// * Any error propagated via `?` from serialization, remote allocation,
    ///   remote reads/writes, the stub call itself, or deserialization of the reply.
    fn call_with_args(&self, args: F::RefArgs<'_>) -> Result<F::Output, PayloadRpcError> {
        // Serialize the arguments into a local scratch buffer sized up front.
        let arg_bytes = bincode::serialized_size(&args)? as usize;
        let mut local_arg_buf = ArrayOrVecBuf::<_, 512>::with_capacity(arg_bytes);
        bincode::serialize_into(local_arg_buf.spare_writer(), &args)?;
        // SAFETY: NOTE(review) relies on `serialize_into` having written exactly
        // `arg_bytes` bytes (the size bincode reported above) into the spare
        // capacity — confirm against `ArrayOrVecBuf::spare_writer`.
        unsafe { local_arg_buf.set_len(arg_bytes) };

        // Mirror the serialized arguments into a buffer obtained from the remote allocator.
        let remote_arg_buf = self.f.remote_allocator.alloc_raw(local_arg_buf.len())?;
        remote_arg_buf.write_bytes(&local_arg_buf)?;

        // Descriptor handed to the stub; the stub reports its reply through the same struct.
        let parameter_buf = self
            .f
            .remote_allocator
            .alloc_and_copy(&ArgAndResultBufInfo {
                data: remote_arg_buf.as_ptr().as_ptr() as u64,
                len: remote_arg_buf.len() as u64,
                is_error: false,
            })?;

        self.f.call(Truncate(parameter_buf.as_ptr().as_ptr()))?;

        // Reuse the argument scratch buffer for the reply to avoid a second allocation.
        let result_buf_info = parameter_buf.read()?;
        let mut local_result_buf = local_arg_buf;
        local_result_buf.clear();
        let result_buf_len = result_buf_info.len as usize;
        local_result_buf.ensure_capacity(result_buf_len);
        // SAFETY: capacity for `result_buf_len` bytes was ensured on the line
        // above. NOTE(review): the contents are only filled by the read below —
        // assumes the read overwrites the whole slice on success.
        unsafe { local_result_buf.set_len(result_buf_len) };

        if result_buf_info.data == remote_arg_buf.as_ptr().as_ptr() as u64 {
            // The stub wrote its reply in place over the argument buffer.
            remote_arg_buf.read_bytes(&mut local_result_buf)?;
        } else {
            // The stub reported a different address for its reply.
            // NOTE(review): nothing visible here releases that memory — confirm
            // who owns the buffer described by `result_buf_info`.
            // SAFETY: trusts the address/length pair reported by the stub.
            let result_memory = unsafe {
                ProcessMemoryBuffer::from_raw_parts(
                    result_buf_info.data as *mut u8,
                    result_buf_info.len as usize,
                    self.process(),
                )
            };
            result_memory.read(0, &mut local_result_buf)?;
        };

        if result_buf_info.is_error {
            // The bytes come from another process, so they must not be assumed
            // to be valid UTF-8: validate them and fall back to lossy decoding.
            let message = String::from_utf8(local_result_buf.into_vec()).unwrap_or_else(
                |invalid| String::from_utf8_lossy(invalid.as_bytes()).into_owned(),
            );
            Err(PayloadRpcError::RemoteProcedure(message))
        } else {
            Ok(bincode::deserialize(&local_result_buf)?)
        }
    }
}
// Generates the `PayloadRpcFunctionPtr` impls and the typed `call` methods of
// `RemotePayloadProcedure` for plain and `unsafe` function pointers.
//
// The public entry arm (last one) takes a comma-separated list of
// `name: Type` pairs. `@recurse` walks that list with an accumulator and emits
// `@impl_all` once for every prefix of it, so impls are produced for every
// arity from zero up to the full list length.
macro_rules! impl_call {
// Nothing left to move: emit the impls for the accumulated (full) list and stop.
(@recurse () ($($nm:ident : $ty:ident),*)) => {
impl_call!(@impl_all ($($nm : $ty),*));
};
// Emit the impls for the current accumulator, then move the head of the
// remaining list to the end of the accumulator and recurse.
(@recurse ($hd_nm:ident : $hd_ty:ident $(, $tl_nm:ident : $tl_ty:ident)*) ($($nm:ident : $ty:ident),*)) => {
impl_call!(@impl_all ($($nm : $ty),*));
impl_call!(@recurse ($($tl_nm : $tl_ty),*) ($($nm : $ty,)* $hd_nm : $hd_ty));
};
// Emits all items for one concrete parameter list.
(@impl_all ($($nm:ident : $ty:ident),*)) => {
// Marker impls for the safe and the `unsafe` function pointer of this arity.
impl <$($ty,)* Output> PayloadRpcFunctionPtr for fn($($ty),*) -> Output where $($ty : 'static + Serialize,)* Output: 'static + DeserializeOwned { }
impl <$($ty,)* Output> PayloadRpcFunctionPtr for unsafe fn($($ty),*) -> Output where $($ty : 'static + Serialize,)* Output: 'static + DeserializeOwned { }
// Safe `call`: takes every argument by reference and forwards them as a tuple to `call_with_args`.
impl <$($ty,)* Output> RemotePayloadProcedure<fn($($ty),*) -> Output> where $($ty: 'static + Serialize,)* Output: 'static + DeserializeOwned, {
#[allow(clippy::too_many_arguments)]
pub fn call(&self, $($nm: &$ty),*) -> Result<Output, PayloadRpcError> {
self.call_with_args(($($nm,)*))
}
}
// Same as above for `unsafe fn` signatures; the generated `call` is itself `unsafe`.
// NOTE(review): presumably the caller must uphold whatever contract the remote
// `unsafe fn` requires — the requirement is not spelled out in this file.
impl <$($ty,)* Output> RemotePayloadProcedure<unsafe fn($($ty),*) -> Output> where $($ty: 'static + Serialize,)* Output: 'static + DeserializeOwned, {
#[allow(clippy::too_many_arguments)]
pub unsafe fn call(&self, $($nm: &$ty),*) -> Result<Output, PayloadRpcError> {
self.call_with_args(($($nm,)*))
}
}
};
// Token counter: expands to `1 + 1 + ... + 0`, one `1` per token tree.
// NOTE(review): no arm in the visible source invokes `@count`.
(@count ()) => {
0
};
(@count ($hd:tt $($tl:tt)*)) => {
1 + impl_call!(@count ($($tl)*))
};
// Public entry point: start the recursion with the full list and an empty accumulator.
($($nm:ident : $ty:ident),*) => {
impl_call!(@recurse ($($nm : $ty),*) ());
};
}
// Instantiates the marker impls and `call` methods for function pointers with
// zero up to twelve parameters (one expansion per prefix of this list).
impl_call! {
arg0: A, arg1: B, arg2: C, arg3: D, arg4: E, arg5: F,arg6: G,
arg7: H, arg8: I, arg9: J, arg10: K, arg11: L
}