use std::env;
use std::ffi::{OsStr, OsString};
use std::io;
use std::mem;
use std::panic;
use std::process;
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "safe-shared-libraries")]
use findshlibs::{Avma, IterationControl, Segment, SharedLibrary};
use ipc_channel::ipc::{self, IpcReceiver, IpcSender, OpaqueIpcReceiver, OpaqueIpcSender};
use ipc_channel::ErrorKind as IpcErrorKind;
use serde::{Deserialize, Serialize};
use crate::error::PanicInfo;
use crate::panic::{init_panic_hook, reset_panic_info, take_panic, BacktraceCapture};
use crate::serde::with_ipc_mode;
pub const ENV_NAME: &str = "__PROCSPAWN_CONTENT_PROCESS_ID";
static INITIALIZED: AtomicBool = AtomicBool::new(false);
static PASS_ARGS: AtomicBool = AtomicBool::new(false);
#[cfg(not(feature = "safe-shared-libraries"))]
static ALLOW_UNSAFE_SPAWN: AtomicBool = AtomicBool::new(false);
pub unsafe fn assert_spawn_is_safe() {
#[cfg(not(feature = "safe-shared-libraries"))]
{
ALLOW_UNSAFE_SPAWN.store(true, Ordering::SeqCst);
}
}
pub struct ProcConfig {
callback: Option<Box<dyn FnOnce()>>,
panic_handling: bool,
pass_args: bool,
#[cfg(feature = "backtrace")]
capture_backtraces: bool,
#[cfg(feature = "backtrace")]
resolve_backtraces: bool,
}
impl Default for ProcConfig {
fn default() -> ProcConfig {
ProcConfig {
callback: None,
panic_handling: true,
pass_args: true,
#[cfg(feature = "backtrace")]
capture_backtraces: true,
#[cfg(feature = "backtrace")]
resolve_backtraces: true,
}
}
}
pub fn mark_initialized() {
INITIALIZED.store(true, Ordering::SeqCst);
}
pub fn should_pass_args() -> bool {
PASS_ARGS.load(Ordering::SeqCst)
}
fn find_shared_library_offset_by_name(name: &OsStr) -> isize {
#[cfg(feature = "safe-shared-libraries")]
{
let mut result = None;
findshlibs::TargetSharedLibrary::each(|shlib| {
if shlib.name() == name {
result = Some(
shlib
.segments()
.next()
.map_or(0, |x| x.actual_virtual_memory_address(shlib).0 as isize),
);
return IterationControl::Break;
}
IterationControl::Continue
});
match result {
Some(rv) => rv,
None => panic!("Unable to locate shared library {:?} in subprocess", name),
}
}
#[cfg(not(feature = "safe-shared-libraries"))]
{
let _ = name;
init as *const () as isize
}
}
fn find_library_name_and_offset(f: *const u8) -> (OsString, isize) {
#[cfg(feature = "safe-shared-libraries")]
{
let mut result = None;
findshlibs::TargetSharedLibrary::each(|shlib| {
let start = shlib
.segments()
.next()
.map_or(0, |x| x.actual_virtual_memory_address(shlib).0 as isize);
for seg in shlib.segments() {
if seg.contains_avma(shlib, Avma(f as usize)) {
result = Some((shlib.name().to_owned(), start));
return IterationControl::Break;
}
}
IterationControl::Continue
});
result.expect("Unable to locate function pointer in loaded image")
}
#[cfg(not(feature = "safe-shared-libraries"))]
{
let _ = f;
(OsString::new(), init as *const () as isize)
}
}
impl ProcConfig {
pub fn new() -> ProcConfig {
ProcConfig::default()
}
pub fn config_callback<F: FnOnce() + 'static>(&mut self, f: F) -> &mut Self {
self.callback = Some(Box::new(f));
self
}
pub fn pass_args(&mut self, enabled: bool) -> &mut Self {
self.pass_args = enabled;
self
}
pub fn panic_handling(&mut self, enabled: bool) -> &mut Self {
self.panic_handling = enabled;
self
}
#[cfg(feature = "backtrace")]
pub fn capture_backtraces(&mut self, enabled: bool) -> &mut Self {
self.capture_backtraces = enabled;
self
}
#[cfg(feature = "backtrace")]
pub fn resolve_backtraces(&mut self, enabled: bool) -> &mut Self {
self.resolve_backtraces = enabled;
self
}
pub fn init(&mut self) {
mark_initialized();
PASS_ARGS.store(self.pass_args, Ordering::SeqCst);
if let Ok(token) = env::var(ENV_NAME) {
std::env::remove_var(ENV_NAME);
if let Some(callback) = self.callback.take() {
callback();
}
bootstrap_ipc(token, self);
}
}
fn backtrace_capture(&self) -> BacktraceCapture {
#[cfg(feature = "backtrace")]
{
match (self.capture_backtraces, self.resolve_backtraces) {
(false, _) => BacktraceCapture::No,
(true, true) => BacktraceCapture::Resolved,
(true, false) => BacktraceCapture::Unresolved,
}
}
#[cfg(not(feature = "backtrace"))]
{
BacktraceCapture::No
}
}
}
pub fn init() {
ProcConfig::default().init()
}
#[inline]
pub fn assert_spawn_okay() {
if !INITIALIZED.load(Ordering::SeqCst) {
panic!("procspawn was not initialized");
}
#[cfg(not(feature = "safe-shared-libraries"))]
{
if !ALLOW_UNSAFE_SPAWN.load(Ordering::SeqCst) {
panic!(
"spawn() prevented because safe-shared-library feature was \
disabled and assert_no_shared_libraries was not invoked."
);
}
}
}
fn is_benign_bootstrap_error(err: &io::Error) -> bool {
err.kind() == io::ErrorKind::Other && err.to_string() == "Unknown Mach error: 44e"
}
fn bootstrap_ipc(token: String, config: &ProcConfig) {
if config.panic_handling {
init_panic_hook(config.backtrace_capture());
}
{
let connection_bootstrap: IpcSender<IpcSender<MarshalledCall>> =
match IpcSender::connect(token) {
Ok(sender) => sender,
Err(err) => {
if !is_benign_bootstrap_error(&err) {
panic!("could not bootstrap ipc connection: {:?}", err);
}
process::exit(1);
}
};
let (tx, rx) = ipc::channel().unwrap();
connection_bootstrap.send(tx).unwrap();
let marshalled_call = rx.recv().unwrap();
marshalled_call.call(config.panic_handling);
}
process::exit(0);
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MarshalledCall {
pub lib_name: OsString,
pub fn_offset: isize,
pub wrapper_offset: isize,
pub args_receiver: OpaqueIpcReceiver,
pub return_sender: OpaqueIpcSender,
}
impl MarshalledCall {
pub fn marshal<A, R>(
f: fn(A) -> R,
args_receiver: IpcReceiver<A>,
return_sender: IpcSender<Result<R, PanicInfo>>,
) -> MarshalledCall
where
A: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let (lib_name, offset) = find_library_name_and_offset(f as *const () as *const u8);
let init_loc = init as *const () as isize;
let fn_offset = f as *const () as isize - offset;
let wrapper_offset = run_func::<A, R> as *const () as isize - init_loc;
MarshalledCall {
lib_name,
fn_offset,
wrapper_offset,
args_receiver: args_receiver.to_opaque(),
return_sender: return_sender.to_opaque(),
}
}
pub fn call(self, panic_handling: bool) {
unsafe {
let ptr = self.wrapper_offset + init as *const () as isize;
let func: fn(&OsStr, isize, OpaqueIpcReceiver, OpaqueIpcSender, bool) =
mem::transmute(ptr);
func(
&self.lib_name,
self.fn_offset,
self.args_receiver,
self.return_sender,
panic_handling,
);
}
}
}
unsafe fn run_func<A, R>(
lib_name: &OsStr,
fn_offset: isize,
args_recv: OpaqueIpcReceiver,
sender: OpaqueIpcSender,
panic_handling: bool,
) where
A: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let lib_offset = find_shared_library_offset_by_name(lib_name);
let function: fn(A) -> R = mem::transmute(fn_offset + lib_offset as *const () as isize);
let args = with_ipc_mode(|| args_recv.to().recv().unwrap());
let rv = if panic_handling {
reset_panic_info();
match panic::catch_unwind(panic::AssertUnwindSafe(|| function(args))) {
Ok(rv) => Ok(rv),
Err(panic) => Err(take_panic(&*panic)),
}
} else {
Ok(function(args))
};
if let Err(err) = with_ipc_mode(|| sender.to().send(rv)) {
if let IpcErrorKind::Io(ref io) = *err {
if io.kind() == io::ErrorKind::NotFound || io.kind() == io::ErrorKind::ConnectionReset {
}
} else {
panic!("could not send event over ipc channel: {:?}", err);
}
}
}