use core::fmt;
use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
use parking_lot::Mutex;
use zng_clone_move::{async_clmv, clmv};
use zng_txt::Txt;
use zng_unique_id::IdMap;
use zng_unit::TimeUnits as _;
use crate::{
TaskPanicError,
channel::{self, ChannelError, IpcReceiver, IpcSender, IpcValue, NamedIpcSender},
};
// Env var set on the worker process with the app-process worker API version,
// compared against `VERSION` in `run_worker_server`.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
// Env var set on the worker process with the name of the IPC init channel it must connect to.
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
// Env var set on the worker process with the worker name, matched against the `run_worker` name.
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";
// Env var read by the app-process to override the worker connection timeout, in seconds.
// Defaults to 10 when unset, empty or invalid.
// NOTE(review): unlike the other names this one has no `_IPC_` segment; kept as-is because
// renaming it would silently break existing configurations.
const WORKER_TIMEOUT: &str = "ZNG_TASK_WORKER_TIMEOUT";
/// Worker API version, the crate package version.
///
/// The app-process passes its version to the worker process, and the worker process exits
/// if it is not equal to its own.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Handle to a worker process that runs requests of type `I` and responds with `O`.
///
/// Created by [`Worker::start`], [`Worker::start_with`] or [`Worker::start_other`]; the worker
/// process side is implemented using [`run_worker`].
pub struct Worker<I: IpcValue, O: IpcValue> {
    /// Response receiver thread and worker child process.
    ///
    /// Taken by `shutdown`, by drop, and by `crash_error` once the receiver thread has exited.
    running: Option<(std::thread::JoinHandle<()>, std::process::Child)>,
    /// Sends requests to the worker process.
    sender: IpcSender<(RequestId, Request<I>)>,
    /// Pending requests, each with the channel that delivers its response.
    requests: Arc<Mutex<IdMap<RequestId, channel::Sender<O>>>>,
    /// Marks the request/response types without storing them.
    _p: PhantomData<fn(I) -> O>,
    /// Exit status of the worker process, recorded by `crash_error`.
    crash: Option<WorkerCrashError>,
}
impl<I: IpcValue, O: IpcValue> Worker<I, O> {
pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
Self::start_impl(worker_name.into(), std::env::current_exe()?, &[], &[]).await
}
pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
Self::start_impl(worker_name.into(), std::env::current_exe()?, env_vars, args).await
}
pub async fn start_other(
worker_name: impl Into<Txt>,
worker_exe: impl Into<PathBuf>,
env_vars: &[(&str, &str)],
args: &[&str],
) -> std::io::Result<Self> {
Self::start_impl(worker_name.into(), worker_exe.into(), env_vars, args).await
}
async fn start_impl(worker_name: Txt, exe: PathBuf, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
let chan_sender = NamedIpcSender::<WorkerInit<I, O>>::new()?;
let mut worker = std::process::Command::new(dunce::canonicalize(exe)?);
for (key, value) in env_vars {
worker.env(key, value);
}
for arg in args {
worker.arg(arg);
}
worker
.env(WORKER_VERSION, crate::process::worker::VERSION)
.env(WORKER_SERVER, chan_sender.name())
.env(WORKER_NAME, worker_name)
.env("RUST_BACKTRACE", "full");
let mut worker = blocking::unblock(move || worker.spawn()).await?;
let timeout = match std::env::var(WORKER_TIMEOUT) {
Ok(t) if !t.is_empty() => match t.parse::<u64>() {
Ok(t) => t.max(1),
Err(e) => {
tracing::error!("invalid {WORKER_TIMEOUT:?} value, {e}");
10
}
},
_ => 10,
};
let (request_sender, mut response_receiver) = match Self::connect_worker(chan_sender, timeout).await {
Ok(r) => r,
Err(ce) => {
let cleanup = blocking::unblock(move || {
worker.kill()?;
worker.wait()
});
match cleanup.await {
Ok(status) => {
let code = status.code().unwrap_or(0);
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("worker process did not connect in {timeout}s\nworker exit code: {code}\nchannel error: {ce}"),
));
}
Err(e) => {
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("worker process did not connect in {timeout}s\ncannot kill worker process, {e}\nchannel error: {ce}"),
));
}
}
}
};
let requests = Arc::new(Mutex::new(IdMap::<RequestId, channel::Sender<O>>::new()));
let receiver = std::thread::Builder::new()
.name("task-ipc-recv".into())
.stack_size(256 * 1024)
.spawn(clmv!(requests, || {
loop {
match response_receiver.recv_blocking() {
Ok((id, r)) => match requests.lock().remove(&id) {
Some(s) => match r {
Response::Out(r) => {
let _ = s.send_blocking(r);
}
},
None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
},
Err(e) => match e {
ChannelError::Disconnected { .. } => {
requests.lock().clear();
break;
}
e => {
tracing::error!("worker response error, will shutdown, {e}");
break;
}
},
}
}
}))
.expect("failed to spawn thread");
Ok(Self {
running: Some((receiver, worker)),
sender: request_sender,
_p: PhantomData,
crash: None,
requests,
})
}
async fn connect_worker(
chan_sender: NamedIpcSender<WorkerInit<I, O>>,
timeout: u64,
) -> Result<(IpcSender<(RequestId, Request<I>)>, IpcReceiver<(RequestId, Response<O>)>), ChannelError> {
let mut chan_sender = chan_sender.connect_deadline(timeout.secs()).await?;
let (request_sender, request_receiver) =
channel::ipc_unbounded::<(RequestId, Request<I>)>().map_err(ChannelError::disconnected_by)?;
let (response_sender, response_receiver) =
channel::ipc_unbounded::<(RequestId, Response<O>)>().map_err(ChannelError::disconnected_by)?;
chan_sender.send_blocking((request_receiver, response_sender))?;
Ok((request_sender, response_receiver))
}
pub async fn shutdown(mut self) -> std::io::Result<()> {
if let Some((receiver, mut process)) = self.running.take() {
while !self.requests.lock().is_empty() {
crate::deadline(100.ms()).await;
}
let r = blocking::unblock(move || process.kill()).await;
match crate::with_deadline(blocking::unblock(move || receiver.join()), 1.secs()).await {
Ok(r) => {
if let Err(p) = r {
tracing::error!(
"worker receiver thread exited panicked, {}",
TaskPanicError::new(p).panic_str().unwrap_or("")
);
}
}
Err(_) => {
if r.is_ok() {
panic!("worker receiver thread did not exit after worker process did");
}
}
}
r
} else {
Ok(())
}
}
pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
self.run_request(Request::Run(input))
}
fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
if self.crash_error().is_some() {
return Box::pin(std::future::ready(Err(RunError::Disconnected)));
}
let id = RequestId::new_unique();
let (sx, rx) = channel::bounded(1);
let requests = self.requests.clone();
requests.lock().insert(id, sx);
let mut sender = self.sender.clone();
let send_r = blocking::unblock(move || sender.send_blocking((id, request)));
Box::pin(async move {
if let Err(e) = send_r.await {
requests.lock().remove(&id);
return Err(RunError::Other(Arc::new(e)));
}
match rx.recv().await {
Ok(r) => Ok(r),
Err(e) => match e {
ChannelError::Disconnected { .. } => {
requests.lock().remove(&id);
Err(RunError::Disconnected)
}
_ => unreachable!(),
},
}
})
}
pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
if let Some((t, _)) = &self.running
&& t.is_finished()
{
let (t, mut p) = self.running.take().unwrap();
if let Err(e) = t.join() {
tracing::error!(
"panic in worker receiver thread, {}",
TaskPanicError::new(e).panic_str().unwrap_or("")
);
}
if let Err(e) = p.kill() {
tracing::error!("error killing worker process after receiver exit, {e}");
}
match p.wait() {
Ok(o) => {
self.crash = Some(WorkerCrashError { status: o });
}
Err(e) => tracing::error!("error reading crashed worker output, {e}"),
}
}
self.crash.as_ref()
}
}
impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
fn drop(&mut self) {
if let Some((receiver, mut process)) = self.running.take() {
if !receiver.is_finished() {
tracing::error!("dropped worker without shutdown");
}
if let Err(e) = process.kill() {
tracing::error!("failed to kill worker process on drop, {e}");
}
}
}
}
/// Runs the worker loop if the current process was spawned as the worker named `worker_name`.
///
/// If the current process is not that worker (the worker env vars are not all set, or they name
/// another worker), this returns immediately. Otherwise it connects to the app-process init
/// channel and spawns `handler` for each received request, sending each output back as the
/// response. When the request channel disconnects, it exits the process with code 0.
///
/// # Panics
///
/// Panics if the init channel cannot be connected or the initial channels cannot be received.
pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
where
    I: IpcValue,
    O: IpcValue,
    F: Future<Output = O> + Send + Sync + 'static,
{
    let name: Txt = worker_name.into();
    let Some(server_name) = run_worker_server(&name) else {
        return;
    };
    zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));

    let mut init_receiver = match IpcReceiver::<WorkerInit<I, O>>::connect(server_name) {
        Ok(receiver) => receiver,
        Err(e) => panic!("failed to connect to '{name}' init channel, {e}"),
    };
    let (mut request_receiver, response_sender) = match init_receiver.recv_blocking() {
        Ok(channels) => channels,
        Err(e) => panic!("failed to connect initial channels, {e}"),
    };

    let handler = Arc::new(handler);
    loop {
        let (id, Request::Run(request)) = match request_receiver.recv_blocking() {
            Ok(message) => message,
            Err(ChannelError::Disconnected { .. }) => break,
            Err(ChannelError::Timeout) => unreachable!(),
        };
        // each request runs in its own task with its own clone of the handler and response sender
        crate::spawn(async_clmv!(handler, mut response_sender, {
            let output = handler(RequestArgs { request }).await;
            let _ = response_sender.send_blocking((id, Response::Out(output)));
        }));
    }
    zng_env::exit(0);
}
/// Checks whether the current process was spawned as the worker named `worker_name`.
///
/// Returns the init channel server name when the worker name, version and server env vars are
/// all set and the name matches. If the app-process version differs from this build's
/// [`VERSION`], prints an error to stderr and exits the process.
fn run_worker_server(worker_name: &str) -> Option<String> {
    let (Ok(w_name), Ok(version), Ok(server_name)) = (
        std::env::var(WORKER_NAME),
        std::env::var(WORKER_VERSION),
        std::env::var(WORKER_SERVER),
    ) else {
        return None;
    };

    if w_name != worker_name {
        return None;
    }

    if version != VERSION {
        eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
        zng_env::exit(i32::from_le_bytes(*b"vapi"));
    }
    Some(server_name)
}
/// Arguments for a worker request handler, see [`run_worker`].
#[non_exhaustive]
pub struct RequestArgs<I: IpcValue> {
    /// The request input, as passed to [`Worker::run`] in the app-process.
    pub request: I,
}
/// Error running a request in a [`Worker`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The worker process has crashed, or the connection was lost while awaiting the response.
    Disconnected,
    /// Other error, for example failing to send the request to the worker process.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(source) => write!(f, "run error, {source}"),
        }
    }
}
impl std::error::Error for RunError {}
/// Error describing how a worker process ended after its response channel closed.
///
/// See [`Worker::crash_error`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WorkerCrashError {
    /// Exit status of the worker process.
    pub status: std::process::ExitStatus,
}
impl fmt::Display for WorkerCrashError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = &self.status;
        write!(f, "{status:?}")
    }
}
impl std::error::Error for WorkerCrashError {}
/// Message sent from the app-process to the worker process.
// NOTE: variant order is part of the serialized format, both processes must agree on it.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    /// Run the worker handler with this input.
    Run(I),
}
/// Message sent from the worker process back to the app-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    /// Output of the worker handler for a `Request::Run`.
    Out(O),
}
/// Message sent once over the named init channel: the request receiver and response sender that
/// the worker process uses for the rest of its lifetime.
type WorkerInit<I, O> = (
    channel::IpcReceiver<(RequestId, Request<I>)>,
    channel::IpcSender<(RequestId, Response<O>)>,
);
// Unique id paired with each request and echoed back with its response, so the app-process can
// route the response to the awaiting caller.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}