#[cfg(unix)]
use std::io;
#[cfg(unix)]
use std::os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd};
#[cfg(unix)]
use nix::unistd::{dup, dup2, pipe};
#[cfg(unix)]
use tokio::sync::mpsc;
#[cfg(unix)]
use crate::bridge::protocol::{ControlResponse, LogSource};
/// Descriptor number that receives a copy of the process's original stdin (fd 0).
///
/// The three relocation targets are consecutive high numbers (99, 100, 101);
/// presumably chosen high to stay clear of descriptors the process opens on its
/// own — TODO confirm.
#[cfg(unix)]
const CONTROL_STDIN_FD: std::os::fd::RawFd = 99;
/// Descriptor number that receives a copy of the process's original stdout (fd 1).
#[cfg(unix)]
const CONTROL_STDOUT_FD: std::os::fd::RawFd = CONTROL_STDIN_FD + 1;
/// Descriptor number that receives a copy of the process's original stderr (fd 2).
#[cfg(unix)]
const WORKER_STDERR_FD: std::os::fd::RawFd = CONTROL_STDOUT_FD + 1;
/// Owned handles to the relocated control channel returned by
/// `redirect_fds_for_subprocess_isolation`.
///
/// Both descriptors are `OwnedFd`s and are therefore closed when this value is
/// dropped.
#[cfg(unix)]
#[derive(Debug)]
pub struct ControlChannelFds {
    /// Copy of the process's original stdin (fd 0).
    pub stdin_fd: OwnedFd,
    /// Copy of the process's original stdout (fd 1).
    pub stdout_fd: OwnedFd,
}
#[cfg(unix)]
pub fn redirect_fds_for_subprocess_isolation(
setup_log_tx: mpsc::Sender<ControlResponse>,
) -> io::Result<ControlChannelFds> {
tracing::debug!("Preserving control channel to high fds");
let control_stdin = unsafe {
let fd = BorrowedFd::borrow_raw(0);
dup(fd)
}
.map_err(|e| io::Error::other(format!("dup(0) failed: {}", e)))?;
let control_stdout = unsafe {
let fd = BorrowedFd::borrow_raw(1);
dup(fd)
}
.map_err(|e| io::Error::other(format!("dup(1) failed: {}", e)))?;
let worker_stderr = unsafe {
let fd = BorrowedFd::borrow_raw(2);
dup(fd)
}
.map_err(|e| io::Error::other(format!("dup(2) failed: {}", e)))?;
tracing::trace!(
control_stdin = control_stdin.as_raw_fd(),
control_stdout = control_stdout.as_raw_fd(),
worker_stderr = worker_stderr.as_raw_fd(),
"Duped original fds"
);
let mut target_stdin = unsafe { OwnedFd::from_raw_fd(CONTROL_STDIN_FD) };
dup2(&control_stdin, &mut target_stdin)
.map_err(|e| io::Error::other(format!("dup2 stdin failed: {}", e)))?;
std::mem::forget(target_stdin);
let mut target_stdout = unsafe { OwnedFd::from_raw_fd(CONTROL_STDOUT_FD) };
dup2(&control_stdout, &mut target_stdout)
.map_err(|e| io::Error::other(format!("dup2 stdout failed: {}", e)))?;
std::mem::forget(target_stdout);
let mut target_stderr = unsafe { OwnedFd::from_raw_fd(WORKER_STDERR_FD) };
dup2(&worker_stderr, &mut target_stderr)
.map_err(|e| io::Error::other(format!("dup2 stderr failed: {}", e)))?;
std::mem::forget(target_stderr);
tracing::trace!(
stdin_fd = CONTROL_STDIN_FD,
stdout_fd = CONTROL_STDOUT_FD,
stderr_fd = WORKER_STDERR_FD,
"Moved control channel to high fds"
);
drop(control_stdin);
drop(control_stdout);
drop(worker_stderr);
tracing::debug!("Creating capture pipes for stdout/stderr");
let (stdout_read, stdout_write) =
pipe().map_err(|e| io::Error::other(format!("pipe failed: {}", e)))?;
let (stderr_read, stderr_write) =
pipe().map_err(|e| io::Error::other(format!("pipe failed: {}", e)))?;
tracing::trace!(
stdout_read = stdout_read.as_raw_fd(),
stdout_write = stdout_write.as_raw_fd(),
stderr_read = stderr_read.as_raw_fd(),
stderr_write = stderr_write.as_raw_fd(),
"Created capture pipes"
);
let mut target_fd1 = unsafe { OwnedFd::from_raw_fd(1) };
dup2(&stdout_write, &mut target_fd1)
.map_err(|e| io::Error::other(format!("dup2(stdout) failed: {}", e)))?;
std::mem::forget(target_fd1);
let mut target_fd2 = unsafe { OwnedFd::from_raw_fd(2) };
dup2(&stderr_write, &mut target_fd2)
.map_err(|e| io::Error::other(format!("dup2(stderr) failed: {}", e)))?;
std::mem::forget(target_fd2);
tracing::trace!("Replaced fd 1/2 with capture pipes");
drop(stdout_write);
drop(stderr_write);
tracing::debug!("Spawning capture threads");
let stdout_tx = setup_log_tx.clone();
let stdout_read_raw = stdout_read.as_raw_fd();
std::thread::spawn(move || {
let mut file = unsafe { std::fs::File::from_raw_fd(stdout_read_raw) };
let mut buf = [0u8; 4096];
loop {
match std::io::Read::read(&mut file, &mut buf) {
Ok(0) => break,
Ok(n) => {
let data = String::from_utf8_lossy(&buf[..n]).to_string();
if stdout_tx
.blocking_send(ControlResponse::Log {
source: LogSource::Stdout,
data,
})
.is_err()
{
break;
}
}
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(_) => break,
}
}
});
std::mem::forget(stdout_read);
let stderr_tx = setup_log_tx;
let stderr_read_raw = stderr_read.as_raw_fd();
std::thread::spawn(move || {
let mut file = unsafe { std::fs::File::from_raw_fd(stderr_read_raw) };
let mut buf = [0u8; 4096];
loop {
match std::io::Read::read(&mut file, &mut buf) {
Ok(0) => break,
Ok(n) => {
let data = String::from_utf8_lossy(&buf[..n]).to_string();
if stderr_tx
.blocking_send(ControlResponse::Log {
source: LogSource::Stderr,
data,
})
.is_err()
{
break;
}
}
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(_) => break,
}
}
});
std::mem::forget(stderr_read);
tracing::info!("File descriptor redirection complete");
Ok(ControlChannelFds {
stdin_fd: unsafe { OwnedFd::from_raw_fd(CONTROL_STDIN_FD) },
stdout_fd: unsafe { OwnedFd::from_raw_fd(CONTROL_STDOUT_FD) },
})
}
/// Non-Unix stand-in for the control channel: the process's ordinary
/// stdin/stdout handles, since no descriptor relocation happens here.
#[cfg(not(unix))]
#[derive(Debug)]
pub struct ControlChannelFds {
    /// Handle to the process's standard input.
    pub stdin_fd: std::io::Stdin,
    /// Handle to the process's standard output.
    pub stdout_fd: std::io::Stdout,
}
/// Non-Unix fallback: performs no descriptor redirection and returns the
/// process's standard input/output handles.
///
/// `_setup_log_tx` is accepted only so the signature matches the Unix version;
/// nothing is ever sent on it. This function never returns an error.
#[cfg(not(unix))]
pub fn redirect_fds_for_subprocess_isolation(
    _setup_log_tx: tokio::sync::mpsc::Sender<crate::bridge::protocol::ControlResponse>,
) -> std::io::Result<ControlChannelFds> {
    // `std::io` is spelled out: the file-level `use std::io;` is `#[cfg(unix)]`,
    // so a bare `io::Result` does not resolve on non-Unix targets.
    Ok(ControlChannelFds {
        stdin_fd: std::io::stdin(),
        stdout_fd: std::io::stdout(),
    })
}