use std::ffi::OsString;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};
#[cfg(unix)]
use std::os::unix::process::CommandExt;
use crate::{CallbackResponse, DispatchEnvelope, FailureClass, PayloadEnvelope, ValidationError};
use super::failure_mapping::{TransportError, validate_receipt_eligible};
use super::plan::RoutingPlan;
use super::seams::CallbackInvoker;
use super::validation::RouteError;
/// Launch parameters for the callback subprocess: which program to run, the
/// arguments to pass it, and the timeout applied to an invocation.
#[derive(Debug, Clone)]
pub struct SubprocessInvokerConfig {
    /// Path of the executable to spawn.
    program: PathBuf,
    /// Arguments passed to the executable, in insertion order.
    args: Vec<OsString>,
    /// Time budget for one invocation.
    timeout: Duration,
}

impl SubprocessInvokerConfig {
    /// Creates a configuration for `program` with no arguments and the given
    /// `timeout`.
    pub fn new(program: impl Into<PathBuf>, timeout: Duration) -> Self {
        let program = program.into();
        Self {
            program,
            args: Vec::new(),
            timeout,
        }
    }

    /// Appends a single argument and returns the updated configuration.
    pub fn arg(self, arg: impl Into<OsString>) -> Self {
        self.args(std::iter::once(arg))
    }

    /// Appends every item of `args`, preserving iteration order, and returns
    /// the updated configuration.
    pub fn args<I, A>(self, args: I) -> Self
    where
        I: IntoIterator<Item = A>,
        A: Into<OsString>,
    {
        let mut updated = self;
        updated.args.extend(args.into_iter().map(A::into));
        updated
    }

    /// Path of the executable that will be spawned.
    pub fn program(&self) -> &PathBuf {
        &self.program
    }

    /// Timeout configured for one invocation.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
/// Everything that can go wrong while invoking a callback through a subprocess.
#[derive(Debug)]
pub enum SubprocessInvokerError {
/// `validate_receipt_eligible` rejected the routing plan before any process was spawned.
ReceiptEmittedRejected(RouteError),
/// The subprocess could not be spawned.
Spawn(std::io::Error),
/// Writing the serialized request to the subprocess's stdin failed.
WriteRequest(std::io::Error),
/// The request envelope could not be serialized to JSON.
SerializeRequest(serde_json::Error),
/// Reading the subprocess's stdout failed (this includes exceeding the stdout size limit).
ReadResponse(std::io::Error),
/// The subprocess's stdout was not valid JSON for a `CallbackResponse`.
ParseResponse(serde_json::Error),
/// The parsed `CallbackResponse` failed its own `validate()` check.
InvalidResponse(ValidationError),
/// The subprocess exited unsuccessfully; `code` is `None` when no exit code is available
/// (rendered as "terminated by signal"), and `stderr` holds the captured stderr text.
NonZeroExit { code: Option<i32>, stderr: String },
/// The subprocess (or draining its pipes) did not finish within the configured timeout.
Timeout,
}
impl std::fmt::Display for SubprocessInvokerError {
    /// Writes a one-line, human-readable description of the error. Variants
    /// that wrap an underlying cause append it after a colon.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout => write!(f, "callback subprocess exceeded configured timeout"),
            Self::NonZeroExit {
                code: Some(c),
                stderr,
            } => write!(f, "callback subprocess exited with code {c}: {stderr}"),
            Self::NonZeroExit { code: None, stderr } => {
                write!(f, "callback subprocess terminated by signal: {stderr}")
            }
            Self::InvalidResponse(e) => write!(
                f,
                "subprocess returned a CallbackResponse that failed validation: {e}"
            ),
            Self::ParseResponse(e) => write!(
                f,
                "subprocess stdout was not a valid JSON CallbackResponse: {e}"
            ),
            Self::ReadResponse(e) => write!(f, "failed to read subprocess stdout: {e}"),
            Self::SerializeRequest(e) => write!(f, "failed to serialize CallbackRequest: {e}"),
            Self::WriteRequest(e) => {
                write!(f, "failed to write request to subprocess stdin: {e}")
            }
            Self::Spawn(e) => write!(f, "failed to spawn callback subprocess: {e}"),
            Self::ReceiptEmittedRejected(e) => {
                write!(f, "subprocess invoker refused receipt.emitted plan: {e}")
            }
        }
    }
}
// Marker impl: no `source()` override, so the wrapped cause is exposed only through the
// `Display` text rather than through the error chain.
impl std::error::Error for SubprocessInvokerError {}
pub fn failure_class_for_subprocess_error(err: &SubprocessInvokerError) -> FailureClass {
use SubprocessInvokerError as E;
match err {
E::ReceiptEmittedRejected(_) => FailureClass::InvalidRequest,
E::Spawn(_) | E::WriteRequest(_) | E::ReadResponse(_) | E::NonZeroExit { .. } => {
FailureClass::TransportError
}
E::SerializeRequest(_) => FailureClass::InternalError,
E::ParseResponse(_) | E::InvalidResponse(_) => FailureClass::InvalidRequest,
E::Timeout => FailureClass::Timeout,
}
}
impl From<&SubprocessInvokerError> for FailureClass {
fn from(err: &SubprocessInvokerError) -> Self {
failure_class_for_subprocess_error(err)
}
}
pub fn transport_error_for(err: &SubprocessInvokerError) -> Option<TransportError> {
use SubprocessInvokerError as E;
match err {
E::Spawn(e) | E::WriteRequest(e) | E::ReadResponse(e) => {
Some(TransportError::Io(e.to_string()))
}
E::NonZeroExit { code, stderr } => Some(TransportError::Io(match code {
Some(c) => format!("exit code {c}: {stderr}"),
None => format!("terminated by signal: {stderr}"),
})),
E::Timeout => Some(TransportError::Timeout),
E::SerializeRequest(e) => Some(TransportError::Internal(e.to_string())),
E::ReceiptEmittedRejected(_) | E::ParseResponse(_) | E::InvalidResponse(_) => None,
}
}
/// Upper bound (16 MiB) on the bytes accepted from the subprocess's stdout; exceeding it is
/// reported as a read error by `read_to_end_limited`.
const MAX_STDOUT_BYTES: u64 = 16 * 1024 * 1024;
/// Upper bound (256 KiB) on the bytes kept from the subprocess's stderr; anything beyond it is
/// cut off by `read_to_end_truncated`.
const MAX_STDERR_BYTES: u64 = 256 * 1024;
/// Callback invoker that spawns a subprocess per invocation, sends the request on its stdin
/// and reads the response from its stdout.
#[derive(Debug, Clone)]
pub struct SubprocessCallbackInvoker {
// Program, arguments and timeout used for every invocation.
config: SubprocessInvokerConfig,
}
impl SubprocessCallbackInvoker {
    /// Builds an invoker that launches the program described by `config`.
    pub fn new(config: SubprocessInvokerConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this invoker was constructed with.
    pub fn config(&self) -> &SubprocessInvokerConfig {
        &self.config
    }

    /// Performs one request/response round-trip with a freshly spawned subprocess.
    ///
    /// The JSON-serialized [`DispatchEnvelope`] is written to the child's stdin, the child's
    /// stdout is parsed and validated as a [`CallbackResponse`], and stderr is captured for
    /// error reporting. Each pipe is serviced on its own detached thread so that a child
    /// blocked on one pipe cannot stall the others while this thread polls for exit.
    ///
    /// # Errors
    ///
    /// * `ReceiptEmittedRejected` - `validate_receipt_eligible` refused the plan.
    /// * `SerializeRequest` - the envelope could not be serialized.
    /// * `Spawn` - the program could not be started.
    /// * `Timeout` - the child did not exit, or its pipes did not drain, in time; the child's
    ///   process tree is terminated first.
    /// * `WriteRequest` / `ReadResponse` - pipe I/O failed (stdout larger than
    ///   `MAX_STDOUT_BYTES` is reported as `ReadResponse`).
    /// * `NonZeroExit` - the child exited unsuccessfully.
    /// * `ParseResponse` / `InvalidResponse` - stdout was not a valid response.
    ///
    /// # Panics
    ///
    /// Panics if a pipe requested with `Stdio::piped()` is missing from the spawned child.
    fn invoke_inner(
        &self,
        plan: &RoutingPlan,
        payloads: &[PayloadEnvelope],
    ) -> Result<CallbackResponse, SubprocessInvokerError> {
        validate_receipt_eligible(plan).map_err(SubprocessInvokerError::ReceiptEmittedRejected)?;
        let request = super::callbacks::synthesize_request(plan);
        let envelope = DispatchEnvelope::new(request, payloads.to_vec());
        let request_bytes =
            serde_json::to_vec(&envelope).map_err(SubprocessInvokerError::SerializeRequest)?;

        let mut command = Command::new(&self.config.program);
        command
            .args(&self.config.args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        #[cfg(unix)]
        {
            // Give the child its own process group so the whole tree can be signalled through
            // the negative pid in `terminate_process_group`.
            command.process_group(0);
        }
        let mut child = command.spawn().map_err(SubprocessInvokerError::Spawn)?;
        let stdin = child.stdin.take().expect("stdin piped");
        let stdout = child.stdout.take().expect("stdout piped");
        let stderr = child.stderr.take().expect("stderr piped");

        // Writer: sends the request, then drops stdin when the thread ends so the child sees EOF.
        let (writer_tx, writer_rx) = mpsc::channel::<std::io::Result<()>>();
        thread::spawn({
            let bytes = request_bytes;
            move || {
                let mut stdin = stdin;
                let result = stdin.write_all(&bytes).and_then(|()| stdin.flush());
                let _ = writer_tx.send(result);
            }
        });

        // Stdout reader: bounded; overflow surfaces as an I/O error.
        let (stdout_tx, stdout_rx) = mpsc::channel::<std::io::Result<Vec<u8>>>();
        thread::spawn(move || {
            let mut s = stdout;
            let result = read_to_end_limited(&mut s, MAX_STDOUT_BYTES, "stdout");
            let _ = stdout_tx.send(result);
        });

        // Stderr reader: best effort; read failures degrade to an empty capture.
        let (stderr_tx, stderr_rx) = mpsc::channel::<Vec<u8>>();
        thread::spawn(move || {
            let mut s = stderr;
            let buf = read_to_end_truncated(&mut s, MAX_STDERR_BYTES).unwrap_or_default();
            let _ = stderr_tx.send(buf);
        });

        let deadline = Instant::now() + self.config.timeout;
        let exit_status = match wait_with_deadline(&mut child, deadline) {
            Ok(status) => status,
            Err(WaitError::Timeout) => {
                terminate_child_tree(&mut child);
                return Err(SubprocessInvokerError::Timeout);
            }
            Err(WaitError::Io(e)) => {
                terminate_child_tree(&mut child);
                return Err(SubprocessInvokerError::ReadResponse(e));
            }
        };

        // The helper threads normally finish right after the child exits, but a descendant that
        // inherited a pipe can keep it open. The wait budget is recomputed from the deadline
        // before every join so the sequential waits below cannot add up to a multiple of the
        // remaining time; each wait still gets at least `grace`.
        let grace = Duration::from_millis(100);
        let join_budget = || remaining(deadline).max(grace);

        match writer_rx.recv_timeout(join_budget()) {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                // A broken pipe from a child that also exited unsuccessfully is reported as the
                // exit failure, since that carries the child's stderr.
                if e.kind() == std::io::ErrorKind::BrokenPipe && !exit_status.success() {
                    let stderr_bytes = stderr_rx.recv_timeout(join_budget()).unwrap_or_default();
                    let stderr_text = String::from_utf8_lossy(&stderr_bytes).into_owned();
                    return Err(SubprocessInvokerError::NonZeroExit {
                        code: exit_status.code(),
                        stderr: stderr_text,
                    });
                }
                return Err(SubprocessInvokerError::WriteRequest(e));
            }
            Err(mpsc::RecvTimeoutError::Timeout) => {
                terminate_child_tree(&mut child);
                return Err(SubprocessInvokerError::Timeout);
            }
            Err(mpsc::RecvTimeoutError::Disconnected) => {
                return Err(SubprocessInvokerError::WriteRequest(std::io::Error::other(
                    "writer thread disconnected before reporting result",
                )));
            }
        }

        let stdout_bytes = match stdout_rx.recv_timeout(join_budget()) {
            Ok(Ok(buf)) => buf,
            Ok(Err(e)) => return Err(SubprocessInvokerError::ReadResponse(e)),
            Err(mpsc::RecvTimeoutError::Timeout) => {
                terminate_child_tree(&mut child);
                return Err(SubprocessInvokerError::Timeout);
            }
            Err(mpsc::RecvTimeoutError::Disconnected) => {
                return Err(SubprocessInvokerError::ReadResponse(std::io::Error::other(
                    "stdout reader thread disconnected before reporting result",
                )));
            }
        };

        // Stderr is diagnostic only: a missing capture never fails the invocation by itself.
        let stderr_bytes = match stderr_rx.recv_timeout(join_budget()) {
            Ok(buf) => buf,
            Err(mpsc::RecvTimeoutError::Timeout) => {
                terminate_child_tree(&mut child);
                Vec::new()
            }
            Err(mpsc::RecvTimeoutError::Disconnected) => Vec::new(),
        };
        let stderr_text = String::from_utf8_lossy(&stderr_bytes).into_owned();

        if !exit_status.success() {
            return Err(SubprocessInvokerError::NonZeroExit {
                code: exit_status.code(),
                stderr: stderr_text,
            });
        }

        let response: CallbackResponse =
            serde_json::from_slice(&stdout_bytes).map_err(SubprocessInvokerError::ParseResponse)?;
        response
            .validate()
            .map_err(SubprocessInvokerError::InvalidResponse)?;
        Ok(response)
    }
}
/// Exposes the subprocess round-trip through the [`CallbackInvoker`] seam.
impl CallbackInvoker for SubprocessCallbackInvoker {
    type Error = SubprocessInvokerError;

    /// Delegates to the inherent `invoke_inner` implementation.
    fn invoke(
        &self,
        plan: &RoutingPlan,
        payloads: &[PayloadEnvelope],
    ) -> Result<CallbackResponse, Self::Error> {
        Self::invoke_inner(self, plan, payloads)
    }
}
/// Reads `reader` to EOF, failing if it yields more than `max_bytes` bytes.
///
/// `stream_name` is used only in the overflow error message.
///
/// # Errors
///
/// Returns any error produced by the underlying reader, or an
/// `std::io::Error::other` with the message
/// `"subprocess {stream_name} exceeded {max_bytes} bytes"` when the limit is
/// exceeded.
fn read_to_end_limited<R: Read>(
    reader: &mut R,
    max_bytes: u64,
    stream_name: &'static str,
) -> std::io::Result<Vec<u8>> {
    let mut buf = Vec::new();
    // Read one byte past the limit so "exactly full" and "overflowed" can be
    // told apart. The probe limit saturates so `max_bytes == u64::MAX` cannot
    // overflow (which would panic in debug builds and wrap to 0 in release).
    reader
        .by_ref()
        .take(max_bytes.saturating_add(1))
        .read_to_end(&mut buf)?;
    let read = u64::try_from(buf.len()).unwrap_or(u64::MAX);
    if read > max_bytes {
        return Err(std::io::Error::other(format!(
            "subprocess {stream_name} exceeded {max_bytes} bytes"
        )));
    }
    Ok(buf)
}
/// Reads `reader` to EOF but keeps at most `max_bytes` bytes of it.
///
/// When the stream is longer than `max_bytes`, the returned buffer holds the
/// first `max_bytes` bytes followed by the marker `"\n[stderr truncated]\n"`,
/// and the remainder of the stream is read and discarded.
///
/// # Errors
///
/// Returns any error the reader produces while the retained prefix is being
/// read. Errors while discarding the excess are ignored.
fn read_to_end_truncated<R: Read>(reader: &mut R, max_bytes: u64) -> std::io::Result<Vec<u8>> {
    let mut buf = Vec::new();
    // Probe one byte past the limit to detect overflow; saturate so a limit of
    // `u64::MAX` cannot overflow the addition.
    reader
        .by_ref()
        .take(max_bytes.saturating_add(1))
        .read_to_end(&mut buf)?;
    // A limit that does not fit in `usize` can never be exceeded by a `Vec`.
    let keep = usize::try_from(max_bytes).unwrap_or(usize::MAX);
    if buf.len() > keep {
        buf.truncate(keep);
        buf.extend_from_slice(b"\n[stderr truncated]\n");
        // Keep consuming the stream. Returning here would let the caller drop
        // the pipe while the writer is still producing output, and that writer
        // would then fail with a broken pipe just for being verbose. The
        // retained prefix is already complete, so a failure here is ignored.
        let _ = std::io::copy(reader, &mut std::io::sink());
    }
    Ok(buf)
}
/// Best-effort teardown of `child` and, on Unix, of the process group named
/// after its pid. The child is then killed and reaped; every failure along the
/// way is ignored.
fn terminate_child_tree(child: &mut Child) {
    #[cfg(unix)]
    {
        let group_leader = child.id();
        terminate_process_group(group_leader);
    }
    let _kill_outcome = child.kill();
    let _wait_outcome = child.wait();
}
/// Signals the process group whose id is `child_pid` by running the external
/// `kill` utility: first with `-TERM`, then, 20 ms later, with `-KILL`.
///
/// The group is addressed with the negative-pid form (`-<pid>`). Output of the
/// `kill` command is discarded and failures to run it are ignored.
#[cfg(unix)]
fn terminate_process_group(child_pid: u32) {
    let target = format!("-{child_pid}");
    let send = |signal: &str| {
        let _ = Command::new("kill")
            .args([signal, "--", target.as_str()])
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .status();
    };
    send("-TERM");
    // Short pause before escalating to SIGKILL.
    thread::sleep(Duration::from_millis(20));
    send("-KILL");
}
/// Reasons `wait_with_deadline` can fail to produce an exit status.
enum WaitError {
/// The deadline passed while the child was still running.
Timeout,
/// Polling the child's status failed.
Io(std::io::Error),
}
/// Time left until `deadline`, or `Duration::ZERO` once it has passed.
fn remaining(deadline: Instant) -> Duration {
    let now = Instant::now();
    deadline
        .checked_duration_since(now)
        .unwrap_or(Duration::ZERO)
}
/// Polls `child` until it exits or `deadline` passes.
///
/// Polling starts at a 2 ms interval and doubles up to a 50 ms cap; a sleep
/// never extends past the deadline. When the deadline is reached the child is
/// polled one last time before `WaitError::Timeout` is reported. The child is
/// not killed here; that is left to the caller.
fn wait_with_deadline(
    child: &mut Child,
    deadline: Instant,
) -> Result<std::process::ExitStatus, WaitError> {
    const MAX_POLL_INTERVAL: Duration = Duration::from_millis(50);
    let mut backoff = Duration::from_millis(2);
    loop {
        if let Some(status) = child.try_wait().map_err(WaitError::Io)? {
            return Ok(status);
        }
        let now = Instant::now();
        if now >= deadline {
            // Final check so an exit that raced with the deadline still counts.
            return match child.try_wait() {
                Ok(Some(status)) => Ok(status),
                Ok(None) => Err(WaitError::Timeout),
                Err(e) => Err(WaitError::Io(e)),
            };
        }
        let until_deadline = deadline.saturating_duration_since(now);
        thread::sleep(backoff.min(until_deadline));
        backoff = (backoff * 2).min(MAX_POLL_INTERVAL);
    }
}
// Unit tests for the pure helpers in this file; none of them spawns a subprocess.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
// Each listed error variant maps to its expected failure class, and the `From` impl agrees
// with the free function.
#[test]
fn failure_class_mapping_is_deterministic() {
let cases: Vec<(SubprocessInvokerError, FailureClass)> = vec![
(
SubprocessInvokerError::ReceiptEmittedRejected(RouteError::InvalidEventEnvelope {
detail: "x".into(),
}),
FailureClass::InvalidRequest,
),
(
SubprocessInvokerError::Spawn(std::io::Error::other("nope")),
FailureClass::TransportError,
),
(
SubprocessInvokerError::WriteRequest(std::io::Error::other("epipe")),
FailureClass::TransportError,
),
(
SubprocessInvokerError::ReadResponse(std::io::Error::other("eof")),
FailureClass::TransportError,
),
(
SubprocessInvokerError::NonZeroExit {
code: Some(1),
stderr: "bang".into(),
},
FailureClass::TransportError,
),
(SubprocessInvokerError::Timeout, FailureClass::Timeout),
(
SubprocessInvokerError::ParseResponse(
serde_json::from_str::<serde_json::Value>("not json").unwrap_err(),
),
FailureClass::InvalidRequest,
),
];
for (err, expected) in cases {
let fc = failure_class_for_subprocess_error(&err);
assert_eq!(fc, expected, "subprocess err -> failure class: {err}");
let via_from: FailureClass = (&err).into();
assert_eq!(via_from, fc);
}
}
// Timeout and spawn failures yield a transport error; a rejected plan yields none.
#[test]
fn transport_error_for_distinguishes_retryable_shapes() {
assert!(matches!(
transport_error_for(&SubprocessInvokerError::Timeout),
Some(TransportError::Timeout)
));
assert!(matches!(
transport_error_for(&SubprocessInvokerError::Spawn(std::io::Error::other("x"))),
Some(TransportError::Io(_))
));
assert!(
transport_error_for(&SubprocessInvokerError::ReceiptEmittedRejected(
RouteError::InvalidEventEnvelope { detail: "x".into() }
))
.is_none()
);
}
// Input of exactly `max_bytes` is returned whole; one byte more is an error.
#[test]
fn read_to_end_limited_allows_exact_limit_and_rejects_overflow() {
let mut exact = Cursor::new(b"abcd".to_vec());
assert_eq!(
read_to_end_limited(&mut exact, 4, "stdout").unwrap(),
b"abcd"
);
let mut over = Cursor::new(b"abcde".to_vec());
let err = read_to_end_limited(&mut over, 4, "stdout").unwrap_err();
assert!(
err.to_string()
.contains("subprocess stdout exceeded 4 bytes")
);
}
// The truncation marker is appended only when the input exceeds the limit.
#[test]
fn read_to_end_truncated_marks_only_over_limit_stderr() {
let mut exact = Cursor::new(b"abcd".to_vec());
assert_eq!(read_to_end_truncated(&mut exact, 4).unwrap(), b"abcd");
let mut over = Cursor::new(b"abcde".to_vec());
let mut expected = b"abcd".to_vec();
expected.extend_from_slice(b"\n[stderr truncated]\n");
assert_eq!(read_to_end_truncated(&mut over, 4).unwrap(), expected);
}
// A future deadline reports time left; a past deadline saturates to zero.
#[test]
fn remaining_reports_future_duration_and_saturates_past_deadline() {
let future = Instant::now() + Duration::from_secs(60);
assert!(remaining(future) > Duration::from_secs(59));
let past = Instant::now() - Duration::from_millis(1);
assert_eq!(remaining(past), Duration::ZERO);
}
// Builder methods chain by value and the accessors return what was configured.
#[test]
fn config_is_chainable() {
let cfg = SubprocessInvokerConfig::new("/bin/cat", Duration::from_secs(1))
.arg("-")
.args(["--flag"]);
assert_eq!(cfg.program(), &PathBuf::from("/bin/cat"));
assert_eq!(cfg.timeout(), Duration::from_secs(1));
}
}