use interprocess::unnamed_pipe::{Recver, Sender};
use nix::errno::Errno;
use nix::fcntl::OFlag;
use nix::sys::signal::{self, Signal};
use nix::unistd::{fork, pipe2, ForkResult, Pid};
use serde::{Deserialize, Serialize};
use std::io::prelude::*;
use std::path::PathBuf;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use std::{fmt, mem};
use thiserror::Error;
pub mod sys {
pub use nix::sys::signal::{self, Signal};
pub use nix::sys::wait::{waitpid, WaitStatus};
pub use nix::unistd::{fork, getpid, ForkResult, Pid};
}
pub mod ser {
use bincode::{deserialize, serialize, Error};
use serde::{Deserialize, Serialize};
pub fn to_vec<T: Serialize>(value: &T) -> Result<Vec<u8>, Error> {
serialize(value)
}
pub fn from_slice<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, Error> {
let val = deserialize(bytes)?;
Ok(val)
}
}
#[derive(Debug)]
pub struct ProcessWrapper<T> {
child_pid: Pid,
start_time: Option<SystemTime>,
receiver: Option<Recver>,
result: Arc<Mutex<Option<Vec<u8>>>>,
_ghost: std::marker::PhantomData<T>,
}
impl<T> fmt::Display for ProcessWrapper<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Process(pid={})", self.child_pid)
}
}
impl<T> ProcessWrapper<T>
where
T: serde::de::DeserializeOwned,
{
pub fn new(child_pid: Pid, receiver: Recver) -> Self {
Self {
child_pid,
start_time: None,
receiver: Some(receiver),
result: Arc::new(Mutex::new(None)),
_ghost: std::marker::PhantomData,
}
}
pub fn wait(&mut self) -> Result<T, ProcessFunError> {
self.ensure_start_time()?;
if let Some(bytes) = self.result.lock().unwrap().take() {
return ser::from_slice(&bytes).map_err(ProcessFunError::from);
}
let receiver = self.receiver.take().ok_or_else(|| {
ProcessFunError::ProcessError("Process already completed".to_string())
})?;
let mut receiver = receiver;
let result_bytes = read_from_pipe(&mut receiver)?;
let result: T = ser::from_slice(&result_bytes)?;
Ok(result)
}
pub fn timeout(&mut self, duration: Duration) -> Result<T, ProcessFunError> {
self.ensure_start_time()?;
let receiver = self.receiver.take().ok_or_else(|| {
ProcessFunError::ProcessError("Process already completed".to_string())
})?;
let (tx, rx) = mpsc::channel();
let result = self.result.clone();
std::thread::spawn(move || {
let mut receiver = receiver;
if let Ok(bytes) = read_from_pipe(&mut receiver) {
*result.lock().unwrap() = Some(bytes);
let _ = tx.send(true); }
});
match rx.recv_timeout(duration) {
Ok(_) => {
if let Some(bytes) = self.result.lock().unwrap().take() {
return ser::from_slice(&bytes).map_err(ProcessFunError::from);
}
Err(ProcessFunError::ProcessError(
"Process result not found".to_string(),
))
}
Err(_) => {
self.abort()?;
Err(ProcessFunError::TimeoutError)
}
}
}
}
#[inline]
pub fn stat_pid_start(pid: Pid) -> Result<SystemTime, ProcessFunError> {
let proc_path = format!("/proc/{}/stat", pid.as_raw());
nix::sys::stat::stat(proc_path.as_str())
.map_err(|e| ProcessFunError::ProcessError(format!("Failed to stat process: {}", e)))
.and_then(|stat| {
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_secs(stat.st_ctime as u64))
.ok_or_else(|| {
ProcessFunError::ProcessError(
"Failed to calculate process start time".to_string(),
)
})
})
}
impl<T> ProcessWrapper<T> {
#[inline]
fn ensure_start_time(&mut self) -> Result<(), ProcessFunError> {
if self.start_time.is_some() {
return Ok(());
}
if let Some(receiver) = &mut self.receiver {
let start_time = read_start_time_from_pipe(receiver)?;
self.start_time = Some(start_time);
Ok(())
} else {
Err(ProcessFunError::ProcessError(
"Process already completed".to_string(),
))
}
}
#[inline]
fn is_same_process(&mut self) -> bool {
if self.ensure_start_time().is_err() {
return false;
}
if let Some(start_time) = self.start_time {
stat_pid_start(self.child_pid)
.map(|stat| stat == start_time)
.unwrap_or(false)
} else {
false
}
}
#[inline]
fn kill(&mut self) -> Result<(), Errno> {
if self.is_same_process() {
match signal::kill(self.child_pid, Signal::SIGKILL) {
Ok(()) => Ok(()),
Err(Errno::ESRCH) => Ok(()), Err(e) => Err(e),
}
} else {
Ok(()) }
}
pub fn abort(&mut self) -> Result<(), ProcessFunError> {
let _ = self.receiver.take();
self.kill().map_err(|e| {
ProcessFunError::ProcessError(format!("Failed to send SIGKILL to process: {}", e))
})?;
Ok(())
}
}
impl<T> Drop for ProcessWrapper<T> {
fn drop(&mut self) {
let _ = self.receiver.take();
let _ = self.kill();
}
}
#[inline]
pub fn create_pipes() -> Result<(Recver, Sender), ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Creating communication pipes");
let (read_fd, write_fd) = pipe2(OFlag::O_CLOEXEC)
.map_err(|e| ProcessFunError::ProcessError(format!("Failed to create pipe: {}", e)))?;
let recver = Recver::from(read_fd);
let sender = Sender::from(write_fd);
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Pipes created successfully");
Ok((recver, sender))
}
const SYSTEM_TIME_SIZE: usize = mem::size_of::<SystemTime>();
#[inline]
fn system_time_to_bytes_unsafe(time: SystemTime) -> [u8; SYSTEM_TIME_SIZE] {
unsafe { mem::transmute::<SystemTime, [u8; SYSTEM_TIME_SIZE]>(time) }
}
#[inline]
fn bytes_to_system_time_unsafe(bytes: [u8; SYSTEM_TIME_SIZE]) -> SystemTime {
unsafe { mem::transmute::<[u8; SYSTEM_TIME_SIZE], SystemTime>(bytes) }
}
#[inline]
pub fn write_time(fd: &mut Sender, time: SystemTime) -> Result<(), ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Writing start time to pipe");
let time_bytes = system_time_to_bytes_unsafe(time);
fd.write_all(&time_bytes)?;
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Successfully wrote start time to pipe");
Ok(())
}
#[inline]
pub fn write_to_pipe(mut fd: Sender, data: &[u8]) -> Result<(), ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Writing {} bytes to pipe", data.len());
fd.write_all(data)
.map_err(|e| ProcessFunError::ProcessError(format!("Failed to write to pipe: {}", e)))?;
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Successfully wrote data to pipe");
Ok(())
}
#[inline]
pub fn read_start_time_from_pipe(fd: &mut Recver) -> Result<SystemTime, ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Reading start time from pipe");
let mut buffer = [0u8; SYSTEM_TIME_SIZE];
fd.read_exact(&mut buffer)?;
let start_time: SystemTime = bytes_to_system_time_unsafe(buffer);
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Read start time from pipe");
Ok(start_time)
}
#[inline]
pub fn read_from_pipe(fd: &mut Recver) -> Result<Vec<u8>, ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Starting to read from pipe");
let mut buffer = vec![];
#[allow(unused_variables)]
let bytes_read = fd
.read_to_end(&mut buffer)
.map_err(|e| ProcessFunError::ProcessError(format!("Failed to read from pipe: {}", e)))?;
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Read {} bytes from pipe", bytes_read);
Ok(buffer)
}
#[inline]
pub fn fork_process() -> Result<ForkResult, ProcessFunError> {
#[cfg(feature = "debug")]
eprintln!("[process-fun-debug] Forking process");
let result = unsafe {
fork().map_err(|e| ProcessFunError::ProcessError(format!("Failed to fork process: {}", e)))
};
#[cfg(feature = "debug")]
if let Ok(fork_result) = &result {
match fork_result {
ForkResult::Parent { child } => {
eprintln!(
"[process-fun-debug] Fork successful - parent process, child pid: {}",
child
);
}
ForkResult::Child => {
eprintln!("[process-fun-debug] Fork successful - child process");
}
}
}
result
}
pub type FunId = PathBuf;
#[derive(Error, Debug, Serialize, Deserialize)]
pub enum ProcessFunError {
#[error("Multiple #[process] attributes found for function '{fun}'")]
MultipleTags { fun: FunId },
#[error("Expected #[process] attribute only on function with implementation but found '{item_text}'")]
BadItemType { item_text: String },
#[error("Failed to read or write file: {0}")]
IoError(String),
#[error("Failed to parse Rust file: {0}")]
ParseError(String),
#[error("Process communication error: {0}")]
ProcessError(String),
#[error("Failed to serialize or deserialize: {0}")]
SerError(String),
#[error("Process execution timed out")]
TimeoutError,
}
impl From<bincode::Error> for ProcessFunError {
fn from(err: bincode::Error) -> Self {
ProcessFunError::SerError(err.to_string())
}
}
impl From<std::io::Error> for ProcessFunError {
fn from(err: std::io::Error) -> Self {
ProcessFunError::IoError(err.to_string())
}
}
impl From<syn::Error> for ProcessFunError {
fn from(err: syn::Error) -> Self {
ProcessFunError::ParseError(err.to_string())
}
}