use std::os::fd::RawFd;
use nix::fcntl::{OFlag, open};
use nix::sys::stat::Mode;
use crate::env::ShellEnv;
use crate::expand::expand_word_to_string;
use crate::parser::ast::{Redirect, RedirectKind};
/// Thin wrapper over `dup2(2)`: makes `newfd` refer to the same open file as `oldfd`,
/// closing whatever `newfd` referred to before.
///
/// The call is retried when it is interrupted by a signal (`EINTR`), so callers only
/// see genuine failures such as `EBADF`.
fn raw_dup2(oldfd: RawFd, newfd: RawFd) -> nix::Result<()> {
    loop {
        // SAFETY: dup2 only manipulates the process descriptor table; invalid
        // descriptors are reported through errno rather than causing UB.
        if unsafe { libc::dup2(oldfd, newfd) } != -1 {
            return Ok(());
        }
        match nix::errno::Errno::last() {
            // Interrupted before completing: simply try again.
            nix::errno::Errno::EINTR => continue,
            errno => return Err(errno),
        }
    }
}
#[derive(Default)]
pub struct RedirectState {
saved_fds: Vec<(RawFd, RawFd)>,
}
impl RedirectState {
pub fn new() -> Self {
Self::default()
}
pub fn apply(
&mut self,
redirects: &[Redirect],
env: &mut ShellEnv,
save: bool,
) -> Result<(), String> {
for redirect in redirects {
if let Err(e) = self.apply_one(redirect, env, save) {
self.restore();
return Err(e);
}
}
Ok(())
}
fn apply_one(
&mut self,
redirect: &Redirect,
env: &mut ShellEnv,
save: bool,
) -> Result<(), String> {
match &redirect.kind {
RedirectKind::Input(word) => {
let target_fd = redirect.fd.unwrap_or(0);
let path = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
let fd = open(path.as_str(), OFlag::O_RDONLY, Mode::empty())
.map_err(|e| format!("{}: {}", path, e))?
.into_raw_fd();
if save {
self.save_fd(target_fd)?;
}
if fd != target_fd {
raw_dup2(fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(fd) };
}
}
RedirectKind::Output(word) => {
let target_fd = redirect.fd.unwrap_or(1);
let path = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
if env.mode.options.noclobber && std::path::Path::new(&path).exists() {
return Err(format!("{}: cannot overwrite existing file", path));
}
let flags = OFlag::O_WRONLY | OFlag::O_CREAT | OFlag::O_TRUNC;
let fd = open(path.as_str(), flags, Mode::from_bits_truncate(0o644))
.map_err(|e| format!("{}: {}", path, e))?
.into_raw_fd();
if save {
self.save_fd(target_fd)?;
}
if fd != target_fd {
raw_dup2(fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(fd) };
}
}
RedirectKind::OutputClobber(word) => {
let target_fd = redirect.fd.unwrap_or(1);
let path = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
let flags = OFlag::O_WRONLY | OFlag::O_CREAT | OFlag::O_TRUNC;
let fd = open(path.as_str(), flags, Mode::from_bits_truncate(0o644))
.map_err(|e| format!("{}: {}", path, e))?
.into_raw_fd();
if save {
self.save_fd(target_fd)?;
}
if fd != target_fd {
raw_dup2(fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(fd) };
}
}
RedirectKind::Append(word) => {
let target_fd = redirect.fd.unwrap_or(1);
let path = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
let flags = OFlag::O_WRONLY | OFlag::O_CREAT | OFlag::O_APPEND;
let fd = open(path.as_str(), flags, Mode::from_bits_truncate(0o644))
.map_err(|e| format!("{}: {}", path, e))?
.into_raw_fd();
if save {
self.save_fd(target_fd)?;
}
if fd != target_fd {
raw_dup2(fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(fd) };
}
}
RedirectKind::DupOutput(word) => {
let target_fd = redirect.fd.unwrap_or(1);
let src = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
if src == "-" {
if save {
self.save_fd(target_fd)?;
}
unsafe { libc::close(target_fd) };
} else {
let src_fd: RawFd = src
.parse()
.map_err(|_| format!("{}: invalid file descriptor", src))?;
if src_fd != target_fd {
if save {
self.save_fd(target_fd)?;
}
raw_dup2(src_fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
}
}
}
RedirectKind::DupInput(word) => {
let target_fd = redirect.fd.unwrap_or(0);
let src = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
if src == "-" {
if save {
self.save_fd(target_fd)?;
}
unsafe { libc::close(target_fd) };
} else {
let src_fd: RawFd = src
.parse()
.map_err(|_| format!("{}: invalid file descriptor", src))?;
if src_fd != target_fd {
if save {
self.save_fd(target_fd)?;
}
raw_dup2(src_fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
}
}
}
RedirectKind::ReadWrite(word) => {
let target_fd = redirect.fd.unwrap_or(0);
let path = expand_word_to_string(env, word).map_err(|e| e.to_string())?;
let flags = OFlag::O_RDWR | OFlag::O_CREAT;
let fd = open(path.as_str(), flags, Mode::from_bits_truncate(0o644))
.map_err(|e| format!("{}: {}", path, e))?
.into_raw_fd();
if save {
self.save_fd(target_fd)?;
}
if fd != target_fd {
raw_dup2(fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(fd) };
}
}
RedirectKind::HereDoc(heredoc) => {
let target_fd = redirect.fd.unwrap_or(0);
let body = crate::expand::expand_heredoc_body(env, &heredoc.body, heredoc.quoted);
let mut fds: [RawFd; 2] = [0; 2];
if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 {
return Err(format!("pipe: {}", std::io::Error::last_os_error()));
}
let (read_fd, write_fd) = (fds[0], fds[1]);
{
use std::io::Write;
use std::os::unix::io::FromRawFd;
let mut write_file = unsafe { std::fs::File::from_raw_fd(write_fd) };
let _ = write_file.write_all(body.as_bytes());
}
if save {
self.save_fd(target_fd)?;
}
if read_fd != target_fd {
raw_dup2(read_fd, target_fd).map_err(|e| format!("dup2: {}", e))?;
unsafe { libc::close(read_fd) };
}
}
}
Ok(())
}
fn save_fd(&mut self, fd: RawFd) -> Result<(), String> {
let saved = unsafe { libc::fcntl(fd, libc::F_DUPFD_CLOEXEC, 10) };
if saved == -1 {
return Err(format!("dup: {}", std::io::Error::last_os_error()));
}
self.saved_fds.push((fd, saved));
Ok(())
}
pub fn restore(&mut self) {
for (original, saved) in self.saved_fds.drain(..).rev() {
raw_dup2(saved, original).ok();
unsafe { libc::close(saved) };
}
}
}
impl Drop for RedirectState {
fn drop(&mut self) {
for (_original, saved) in self.saved_fds.drain(..) {
unsafe { libc::close(saved) };
}
}
}
use std::os::unix::io::IntoRawFd;
#[cfg(test)]
mod tests {
    // These tests redirect the process's own standard descriptors (0, 1 and 2), which are
    // shared by every test thread, so each test holds FD_TEST_LOCK for its whole body.
    use super::*;
    use crate::env::ShellEnv;
    use crate::parser::ast::{Redirect, RedirectKind, Word};

    /// Builds a fresh shell environment (presumably shell name "yosh" and no positional
    /// arguments — TODO confirm against `ShellEnv::new`).
    fn make_env() -> ShellEnv {
        ShellEnv::new("yosh", vec![])
    }

    /// Serializes the fd-manipulating tests below.
    static FD_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// `1>file` sends writes on fd 1 into the file, and `restore` puts fd 1 back.
    #[test]
    fn test_redirect_output_and_restore() {
        // Recover from a poisoned lock so one failing test does not cascade to the others.
        let _guard = FD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut env = make_env();
        let tmp = std::env::temp_dir().join("yosh_redirect_test_output.txt");
        let path_str = tmp.to_str().unwrap().to_string();
        let redirects = vec![Redirect {
            fd: Some(1),
            kind: RedirectKind::Output(Word::literal(&path_str)),
        }];
        let mut state = RedirectState::new();
        state
            .apply(&redirects, &mut env, true)
            .expect("apply should succeed");
        use std::io::Write;
        use std::os::unix::io::FromRawFd;
        // Write through the raw fd 1 (not Rust's captured stdout) so the redirect is exercised.
        let mut stdout = unsafe { std::fs::File::from_raw_fd(1) };
        write!(stdout, "hello redirect").unwrap();
        // Don't let the File close fd 1 on drop; `restore` still needs it.
        std::mem::forget(stdout);
        state.restore();
        let contents = std::fs::read_to_string(&tmp).unwrap_or_default();
        assert!(
            contents.contains("hello redirect"),
            "file should contain written text"
        );
        let _ = std::fs::remove_file(&tmp);
    }

    /// `0<file` makes reads from fd 0 return the file's contents.
    #[test]
    fn test_redirect_input() {
        let _guard = FD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        use std::io::Read;
        use std::os::unix::io::FromRawFd;
        let mut env = make_env();
        let tmp = std::env::temp_dir().join("yosh_redirect_test_input.txt");
        std::fs::write(&tmp, "test input\n").unwrap();
        let path_str = tmp.to_str().unwrap().to_string();
        let redirects = vec![Redirect {
            fd: Some(0),
            kind: RedirectKind::Input(Word::literal(&path_str)),
        }];
        let mut state = RedirectState::new();
        state
            .apply(&redirects, &mut env, true)
            .expect("apply should succeed");
        let mut buf = String::new();
        let mut stdin = unsafe { std::fs::File::from_raw_fd(0) };
        stdin.read_to_string(&mut buf).ok();
        // Don't let the File close fd 0 on drop; `restore` still needs it.
        std::mem::forget(stdin);
        state.restore();
        assert!(buf.contains("test input"));
        let _ = std::fs::remove_file(&tmp);
    }

    /// When a later redirect in the list fails, `apply` must undo the earlier ones:
    /// fd 1 must no longer point at the first redirect's file afterwards.
    #[test]
    fn test_apply_rolls_back_on_second_redirect_failure() {
        let _guard = FD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let mut env = make_env();
        let tmp_ok = std::env::temp_dir().join("yosh_apply_rollback_ok.txt");
        let _ = std::fs::remove_file(&tmp_ok);
        // A path inside a directory assumed not to exist, so opening it fails.
        let bad_path = "/no/such/dir/should-not-exist-yosh-test/file.txt";
        let redirects = vec![
            Redirect {
                fd: Some(1),
                kind: RedirectKind::Output(Word::literal(tmp_ok.to_str().unwrap())),
            },
            Redirect {
                fd: Some(2),
                kind: RedirectKind::Output(Word::literal(bad_path)),
            },
        ];
        // The test's own copy of fd 1, used below to repair stdout even if rollback is broken.
        let orig_stdout = unsafe { libc::dup(1) };
        assert!(orig_stdout >= 0, "dup(1) failed");
        let mut state = RedirectState::new();
        let result = state.apply(&redirects, &mut env, true);
        assert!(result.is_err(), "expected apply to fail on the bad path");
        assert!(
            state.saved_fds.is_empty(),
            "saved_fds should be empty after rollback, got {} entries",
            state.saved_fds.len()
        );
        // If rollback worked, this goes to the real stdout rather than into tmp_ok.
        let marker = b"post-rollback-marker\n";
        unsafe {
            libc::write(1, marker.as_ptr() as *const _, marker.len());
        }
        let written = std::fs::read_to_string(&tmp_ok).unwrap_or_default();
        // Reinstate fd 1 from our own copy before asserting, so a failed rollback does not
        // leave stdout pointing at tmp_ok once the test ends.
        unsafe {
            libc::dup2(orig_stdout, 1);
            libc::close(orig_stdout);
        }
        let _ = std::fs::remove_file(&tmp_ok);
        assert!(
            !written.contains("post-rollback-marker"),
            "fd 1 should not still point at tmp_ok after rollback; tmp_ok contained: {written:?}"
        );
    }
}