use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::process::Stdio;
/// One entry of a descriptor table: an owned open file, a raw descriptor
/// number the entry refers to, or an explicit "closed" marker.
#[derive(Debug)]
pub enum FileDescriptor {
    /// An open file owned by this entry; dropping the entry closes it.
    File(File),
    /// A raw descriptor number. `RawFd` is a plain integer, so this entry
    /// never closes it.
    Duplicate(RawFd),
    /// The descriptor was explicitly closed.
    Closed,
}

impl FileDescriptor {
    /// Produces an independent copy of this entry.
    ///
    /// `File` entries are duplicated with [`File::try_clone`]. `Duplicate`
    /// entries copy the raw number as-is (no new descriptor is created), and
    /// `Closed` stays `Closed`.
    ///
    /// # Errors
    ///
    /// Returns a message when duplicating the underlying file fails.
    pub fn try_clone(&self) -> Result<Self, String> {
        let copy = match *self {
            Self::File(ref file) => {
                let cloned = file
                    .try_clone()
                    .map_err(|err| format!("Failed to clone file: {}", err))?;
                Self::File(cloned)
            }
            Self::Duplicate(raw) => Self::Duplicate(raw),
            Self::Closed => Self::Closed,
        };
        Ok(copy)
    }
}
/// Table of per-descriptor-number redirections (open files, duplicates,
/// explicit closes), plus saved copies of process descriptors that
/// `restore_fd` can put back.
///
/// Ownership invariants:
/// - `File` entries close their descriptor when dropped.
/// - `Duplicate` entries never own anything. Because `duplicate_fd` clones
///   `File` sources, they only ever name 0..=2.
/// - Raw descriptors in `saved_fds` are owned by this table. Each is closed
///   exactly once, by `save_fd` (when replaced), `restore_fd`, or `clear`.
#[derive(Debug)]
pub struct FileDescriptorTable {
    /// Redirections keyed by descriptor number (accepted range 0..=1024).
    fds: HashMap<i32, FileDescriptor>,
    /// Raw descriptors produced by `save_fd`, keyed by the number they were
    /// saved from.
    saved_fds: HashMap<i32, RawFd>,
}

impl FileDescriptorTable {
    /// Highest descriptor number the table accepts; the lowest is 0.
    const MAX_FD_NUM: i32 = 1024;

    /// Creates an empty table with no redirections and nothing saved.
    pub fn new() -> Self {
        Self {
            fds: HashMap::new(),
            saved_fds: HashMap::new(),
        }
    }

    /// Returns whether `fd_num` lies in the accepted range `0..=MAX_FD_NUM`.
    fn is_valid_fd_num(fd_num: i32) -> bool {
        (0..=Self::MAX_FD_NUM).contains(&fd_num)
    }

    /// Duplicates `raw_fd` onto the lowest free descriptor number, with
    /// `FD_CLOEXEC` set so the copy is not inherited by spawned children.
    /// Returns `None` if `fcntl` fails (e.g. `raw_fd` is not open).
    fn dup_cloexec(raw_fd: RawFd) -> Option<RawFd> {
        // SAFETY: F_DUPFD_CLOEXEC only inspects the descriptor number; an
        // invalid descriptor makes fcntl return -1 (EBADF), not UB.
        let new_fd = unsafe { libc::fcntl(raw_fd, libc::F_DUPFD_CLOEXEC, 0) };
        if new_fd >= 0 {
            Some(new_fd)
        } else {
            None
        }
    }

    /// Closes every raw descriptor yielded by `raw_fds`; close errors are ignored.
    fn close_raw_fds<I: IntoIterator<Item = RawFd>>(raw_fds: I) {
        for raw_fd in raw_fds {
            // SAFETY: callers pass only descriptors this table owns and have
            // already removed from its bookkeeping, so each is closed once.
            unsafe {
                libc::close(raw_fd);
            }
        }
    }

    /// Opens `path` and records it as the redirection for `fd_num`, replacing
    /// (and thereby closing) any previous `File` entry for that number.
    ///
    /// The file is created when `write` or `append` is set. With `create_new`,
    /// opening fails if `path` already exists (O_CREAT | O_EXCL); this takes
    /// precedence over `truncate`.
    ///
    /// # Errors
    ///
    /// Returns a message if `fd_num` is outside 0..=1024 or the open fails.
    /// This includes `OpenOptions` rejecting the flag combination, e.g.
    /// `create_new` or `truncate` without write access.
    // One flag per `OpenOptions` knob; bundling them would change the public API.
    #[allow(clippy::too_many_arguments)]
    pub fn open_fd(
        &mut self,
        fd_num: i32,
        path: &str,
        read: bool,
        write: bool,
        append: bool,
        truncate: bool,
        create_new: bool,
    ) -> Result<(), String> {
        if !Self::is_valid_fd_num(fd_num) {
            return Err(format!("Invalid file descriptor number: {}", fd_num));
        }
        let mut opts = OpenOptions::new();
        opts.read(read)
            .write(write)
            .append(append)
            .truncate(truncate)
            .create(write || append);
        if create_new {
            // Previously computed but never applied, so "must not exist" was ignored.
            opts.create_new(true);
        }
        let file = opts
            .open(path)
            .map_err(|e| format!("Cannot open {}: {}", path, e))?;
        self.fds.insert(fd_num, FileDescriptor::File(file));
        Ok(())
    }

    /// Makes `target_fd` refer to whatever `source_fd` currently refers to.
    ///
    /// A `File` source is cloned with `File::try_clone`. The target then owns
    /// its own descriptor for the same open file, and it stays valid if the
    /// source entry is later closed or replaced. Other sources (a `Duplicate`,
    /// or an untouched 0..=2) are recorded as `Duplicate(raw_fd)`. When
    /// `source_fd == target_fd`, nothing changes.
    ///
    /// # Errors
    ///
    /// Returns a message if either number is outside 0..=1024, the source is
    /// closed or unknown, or cloning the source file fails.
    pub fn duplicate_fd(&mut self, source_fd: i32, target_fd: i32) -> Result<(), String> {
        if !Self::is_valid_fd_num(source_fd) {
            return Err(format!("Invalid source file descriptor: {}", source_fd));
        }
        if !Self::is_valid_fd_num(target_fd) {
            return Err(format!("Invalid target file descriptor: {}", target_fd));
        }
        if source_fd == target_fd {
            return Ok(());
        }
        let descriptor = match self.fds.get(&source_fd) {
            // Aliasing the raw number of a File owned by another entry would
            // dangle once that entry is dropped, so take an owned copy instead.
            Some(FileDescriptor::File(file)) => {
                let copy = file.try_clone().map_err(|e| {
                    format!("Failed to duplicate file descriptor {}: {}", source_fd, e)
                })?;
                FileDescriptor::File(copy)
            }
            // Duplicate / Closed / absent: fall back to the raw-number lookup.
            _ => {
                let raw_fd = self.get_raw_fd(source_fd).ok_or_else(|| {
                    format!("File descriptor {} is not open or is closed", source_fd)
                })?;
                FileDescriptor::Duplicate(raw_fd)
            }
        };
        self.fds.insert(target_fd, descriptor);
        Ok(())
    }

    /// Marks `fd_num` as explicitly closed. Any `File` previously recorded for
    /// it is dropped, which closes it.
    ///
    /// # Errors
    ///
    /// Returns a message if `fd_num` is outside 0..=1024.
    pub fn close_fd(&mut self, fd_num: i32) -> Result<(), String> {
        if !Self::is_valid_fd_num(fd_num) {
            return Err(format!("Invalid file descriptor number: {}", fd_num));
        }
        self.fds.insert(fd_num, FileDescriptor::Closed);
        Ok(())
    }

    /// Saves a close-on-exec copy of the *process's* descriptor `fd_num` so
    /// that `restore_fd` can reinstate it later. Saving again before restoring
    /// replaces the earlier copy and closes it.
    ///
    /// # Errors
    ///
    /// Returns a message if `fd_num` is outside 0..=1024 or the process
    /// descriptor cannot be duplicated (e.g. it is not open).
    pub fn save_fd(&mut self, fd_num: i32) -> Result<(), String> {
        if !Self::is_valid_fd_num(fd_num) {
            return Err(format!("Invalid file descriptor number: {}", fd_num));
        }
        let saved_fd = Self::dup_cloexec(fd_num)
            .ok_or_else(|| format!("Failed to save file descriptor {}", fd_num))?;
        if let Some(previous) = self.saved_fds.insert(fd_num, saved_fd) {
            // Overwriting without closing used to leak the earlier copy.
            Self::close_raw_fds(Some(previous));
        }
        Ok(())
    }

    /// Restores the process descriptor `fd_num` from its saved copy (via
    /// `dup2`), closes the copy, and forgets the table entry for `fd_num`.
    /// Does nothing if `fd_num` was not saved.
    ///
    /// # Errors
    ///
    /// Returns a message if `fd_num` is outside 0..=1024 or `dup2` fails. The
    /// saved copy is closed and forgotten even on failure.
    pub fn restore_fd(&mut self, fd_num: i32) -> Result<(), String> {
        if !Self::is_valid_fd_num(fd_num) {
            return Err(format!("Invalid file descriptor number: {}", fd_num));
        }
        if let Some(saved_fd) = self.saved_fds.remove(&fd_num) {
            // SAFETY: `saved_fd` is owned solely by this table and was just
            // removed from `saved_fds`; dup2 only reads/replaces fd numbers.
            let result = unsafe { libc::dup2(saved_fd, fd_num) };
            Self::close_raw_fds(Some(saved_fd));
            if result < 0 {
                return Err(format!("Failed to restore file descriptor {}", fd_num));
            }
            self.fds.remove(&fd_num);
        }
        Ok(())
    }

    /// Returns an independent copy of the table.
    ///
    /// Every `File` entry and every saved descriptor is duplicated, so each
    /// table can close or restore its own descriptors without affecting the
    /// other.
    ///
    /// # Errors
    ///
    /// Returns a message if any duplication fails. Copies made so far are
    /// closed before returning.
    pub fn deep_clone(&self) -> Result<Self, String> {
        let mut fds = HashMap::with_capacity(self.fds.len());
        for (&fd_num, descriptor) in &self.fds {
            fds.insert(fd_num, descriptor.try_clone()?);
        }
        // Sharing the raw numbers (the old shallow copy) let one table's
        // restore_fd close descriptors the other still expected to use.
        let mut saved_fds = HashMap::with_capacity(self.saved_fds.len());
        for (&fd_num, &saved_fd) in &self.saved_fds {
            match Self::dup_cloexec(saved_fd) {
                Some(copy) => {
                    saved_fds.insert(fd_num, copy);
                }
                None => {
                    Self::close_raw_fds(saved_fds.into_iter().map(|(_, raw_fd)| raw_fd));
                    return Err(format!(
                        "Failed to clone saved file descriptor {}",
                        fd_num
                    ));
                }
            }
        }
        Ok(Self { fds, saved_fds })
    }

    /// Saves every descriptor number that has a table entry. It then also
    /// tries to save 0..=2 where they have no entry.
    ///
    /// # Errors
    ///
    /// Fails on the first table-entry number that cannot be saved. Earlier
    /// saves are kept. Failures for 0..=2 without an entry are deliberately
    /// ignored (e.g. that stream may not be open). NOTE(review): `save_fd`
    /// dups the process descriptor of the same number, so an entry for a
    /// number that is not open in the process makes this fail. Confirm that
    /// is intended.
    pub fn save_all_fds(&mut self) -> Result<(), String> {
        let fd_nums: Vec<i32> = self.fds.keys().copied().collect();
        for fd_num in fd_nums {
            self.save_fd(fd_num)?;
        }
        for fd in 0..=2 {
            if !self.fds.contains_key(&fd) {
                // Best effort: the standard stream might legitimately be closed.
                let _ = self.save_fd(fd);
            }
        }
        Ok(())
    }

    /// Restores every saved descriptor.
    ///
    /// # Errors
    ///
    /// Attempts all restores even if some fail, then returns the first error.
    /// Previously, an early failure left the remaining descriptors redirected.
    pub fn restore_all_fds(&mut self) -> Result<(), String> {
        let fd_nums: Vec<i32> = self.saved_fds.keys().copied().collect();
        let mut first_error: Option<String> = None;
        for fd_num in fd_nums {
            if let Err(e) = self.restore_fd(fd_num) {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
        match first_error {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }

    /// Builds a `Stdio` from a fresh close-on-exec duplicate of the entry for
    /// `fd_num`. Returns `None` for closed or absent entries, or if duplication
    /// fails. The table's own descriptor is left untouched.
    #[allow(dead_code)] // Not referenced in this file; kept as public API.
    pub fn get_stdio(&self, fd_num: i32) -> Option<Stdio> {
        let file = match self.fds.get(&fd_num)? {
            // try_clone sets FD_CLOEXEC, so the copy is only visible in a child
            // at the stdio slot it is installed into.
            FileDescriptor::File(file) => file.try_clone().ok()?,
            FileDescriptor::Duplicate(raw_fd) => {
                let dup_fd = Self::dup_cloexec(*raw_fd)?;
                // SAFETY: `dup_fd` was just created and nothing else owns it.
                unsafe { File::from_raw_fd(dup_fd) }
            }
            FileDescriptor::Closed => return None,
        };
        Some(Stdio::from(file))
    }

    /// Returns the raw descriptor `fd_num` currently refers to. That is the
    /// owned file's descriptor, the recorded duplicate number, or `fd_num`
    /// itself for 0..=2 when the table has no entry (presumably the inherited
    /// standard streams). Returns `None` if closed or unknown.
    pub fn get_raw_fd(&self, fd_num: i32) -> Option<RawFd> {
        match self.fds.get(&fd_num) {
            Some(FileDescriptor::File(file)) => Some(file.as_raw_fd()),
            Some(FileDescriptor::Duplicate(raw_fd)) => Some(*raw_fd),
            Some(FileDescriptor::Closed) => None,
            None => {
                if (0..=2).contains(&fd_num) {
                    Some(fd_num)
                } else {
                    None
                }
            }
        }
    }

    /// Returns whether the table holds a `File` or `Duplicate` entry for
    /// `fd_num`. Untouched 0..=2 count as not open here.
    pub fn is_open(&self, fd_num: i32) -> bool {
        matches!(
            self.fds.get(&fd_num),
            Some(FileDescriptor::File(_)) | Some(FileDescriptor::Duplicate(_))
        )
    }

    /// Returns whether `fd_num` has been explicitly closed with `close_fd`.
    pub fn is_closed(&self, fd_num: i32) -> bool {
        matches!(self.fds.get(&fd_num), Some(FileDescriptor::Closed))
    }

    /// Drops every redirection (closing owned files) and closes every saved
    /// descriptor without restoring it. Previously the saved descriptors were
    /// only forgotten, which leaked them.
    pub fn clear(&mut self) {
        self.fds.clear();
        Self::close_raw_fds(self.saved_fds.drain().map(|(_, raw_fd)| raw_fd));
    }
}
impl Default for FileDescriptorTable {
fn default() -> Self {
Self::new()
}
}