//! mxsh 0.2.0
//!
//! Embeddable POSIX-style shell parser and runtime.
use std::io;
use std::os::fd::RawFd;

/// A raw POSIX file-descriptor number wrapped in a `Copy` newtype.
///
/// The wrapper does not own the descriptor: because it is `Copy` it cannot
/// implement `Drop`, so letting a value go out of scope never closes anything.
/// Negative numbers are representable (see `FileDescriptor::INVALID`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FileDescriptor(i32);

impl FileDescriptor {
    pub const STDIN: Self = Self(0);
    pub const STDOUT: Self = Self(1);
    pub const STDERR: Self = Self(2);
    pub const INVALID: Self = Self(-1);

    pub const fn new(fd: i32) -> Self {
        Self(fd)
    }

    pub const fn as_i32(self) -> i32 {
        self.0
    }

    pub const fn is_valid(self) -> bool {
        self.0 >= 0
    }

    #[cfg(feature = "embed")]
    pub(crate) fn is_open(self) -> bool {
        self.is_valid() && unsafe { libc::fcntl(self.into_raw_fd(), libc::F_GETFD) } >= 0
    }

    pub fn is_terminal(self) -> bool {
        self.is_valid() && unsafe { libc::isatty(self.into_raw_fd()) == 1 }
    }

    #[cfg(unix)]
    pub const fn into_raw_fd(self) -> RawFd {
        self.0
    }

    /// Read one line from this fd. Returns None at EOF. The trailing newline is stripped.
    pub fn read_line(self) -> Result<Option<String>, io::Error> {
        read_line_fd_inner(self).map(|line| line.map(|(line, _terminated_by_newline)| line))
    }

    /// Read one line from this fd, reporting whether it ended with a newline.
    ///
    /// Returns `Ok(Some((line, terminated)))` where `terminated` is true when
    /// the line ended with `\n`, or `Ok(None)` at EOF.
    pub fn read_line_with_status(self) -> Result<Option<(String, bool)>, io::Error> {
        read_line_fd_inner(self)
    }

    /// Write a string to this fd.
    pub fn write_str(self, s: &str) -> Result<(), io::Error> {
        let bytes = s.as_bytes();
        let mut written = 0;
        while written < bytes.len() {
            let n = unsafe {
                libc::write(
                    self.into_raw_fd(),
                    bytes[written..].as_ptr() as *const libc::c_void,
                    bytes.len() - written,
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            written += n as usize;
        }
        Ok(())
    }

    /// Write a string followed by a newline to this fd.
    pub fn write_line(self, s: &str) -> Result<(), io::Error> {
        self.write_str(s)?;
        self.write_str("\n")
    }

    /// Read all bytes from this fd and return as a String.
    pub fn read_all(self) -> String {
        let mut output = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            let n = unsafe {
                libc::read(
                    self.into_raw_fd(),
                    chunk.as_mut_ptr() as *mut libc::c_void,
                    chunk.len(),
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                break;
            }
            if n == 0 {
                break;
            }
            output.extend_from_slice(&chunk[..n as usize]);
        }
        String::from_utf8_lossy(&output).into_owned()
    }

    /// Read all bytes from this fd and return as a UTF-8 lossy String.
    pub fn read_to_string(self) -> Result<String, io::Error> {
        let mut output = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            let n = unsafe {
                libc::read(
                    self.into_raw_fd(),
                    chunk.as_mut_ptr() as *mut libc::c_void,
                    chunk.len(),
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                break;
            }
            output.extend_from_slice(&chunk[..n as usize]);
        }
        Ok(String::from_utf8_lossy(&output).into_owned())
    }

    /// Read bytes from this fd and return as a UTF-8 lossy String, stopping if the
    /// payload exceeds `max_bytes`.
    pub fn read_to_string_with_limit(self, max_bytes: usize) -> Result<String, io::Error> {
        let mut output = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            let n = unsafe {
                libc::read(
                    self.into_raw_fd(),
                    chunk.as_mut_ptr() as *mut libc::c_void,
                    chunk.len(),
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                break;
            }
            if output.len().saturating_add(n as usize) > max_bytes {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "payload exceeds maximum size",
                ));
            }
            output.extend_from_slice(&chunk[..n as usize]);
        }
        Ok(String::from_utf8_lossy(&output).into_owned())
    }

    /// Write all bytes to this fd without closing it.
    pub fn write_all(self, data: &[u8]) -> Result<(), io::Error> {
        let mut written = 0;
        while written < data.len() {
            let n = unsafe {
                libc::write(
                    self.into_raw_fd(),
                    data[written..].as_ptr() as *const libc::c_void,
                    data.len() - written,
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            written += n as usize;
        }
        Ok(())
    }

    /// Close this file descriptor.
    pub fn close(self) {
        if self.is_valid() {
            unsafe { libc::close(self.into_raw_fd()) };
        }
    }

    /// Duplicate this file descriptor.
    pub fn dup(self) -> Result<FileDescriptor, io::Error> {
        let new_fd = unsafe { libc::dup(self.into_raw_fd()) };
        if new_fd < 0 {
            Err(io::Error::last_os_error())
        } else {
            set_cloexec(new_fd)?;
            Ok(FileDescriptor::from(new_fd))
        }
    }

    /// Redirect this fd to `new_fd` via `dup2`.
    pub fn dup2(self, new_fd: FileDescriptor) -> Result<(), io::Error> {
        if unsafe { libc::dup2(self.into_raw_fd(), new_fd.into_raw_fd()) } < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }
}

impl From<i32> for FileDescriptor {
    fn from(value: i32) -> Self {
        Self(value)
    }
}

impl From<FileDescriptor> for i32 {
    fn from(value: FileDescriptor) -> Self {
        value.0
    }
}

/// An OS pipe represented as (read_fd, write_fd).
///
/// Both fields are plain `Copy` descriptor numbers, so dropping an `OsPipe`
/// does not close either end; close each end explicitly when done.
#[derive(Debug)]
pub struct OsPipe {
    /// The read end of the pipe.
    pub read_fd: FileDescriptor,
    /// The write end of the pipe.
    pub write_fd: FileDescriptor,
}

impl OsPipe {
    /// Create a new OS pipe whose two ends are both marked close-on-exec.
    ///
    /// # Errors
    /// The error from `pipe(2)`, or from setting `FD_CLOEXEC` on either end. In
    /// the latter case both ends are closed before the error is returned.
    pub fn new() -> Result<Self, io::Error> {
        let mut raw: [i32; 2] = [0; 2];
        // SAFETY: `raw` is a writable array of two ints, as pipe(2) requires.
        let rc = unsafe { libc::pipe(raw.as_mut_ptr()) };
        if rc != 0 {
            return Err(io::Error::last_os_error());
        }
        let [read_end, write_end] = raw;
        match set_cloexec(read_end).and_then(|()| set_cloexec(write_end)) {
            Ok(()) => Ok(Self {
                read_fd: FileDescriptor::from(read_end),
                write_fd: FileDescriptor::from(write_end),
            }),
            Err(err) => {
                // SAFETY: both descriptors were just returned by pipe(2) and are
                // not referenced anywhere else yet.
                unsafe {
                    libc::close(read_end);
                    libc::close(write_end);
                }
                Err(err)
            }
        }
    }
}

/// Set `FD_CLOEXEC` on `fd`, preserving any other descriptor flags.
///
/// Uses a read-modify-write with `F_GETFD` followed by `F_SETFD`.
fn set_cloexec(fd: RawFd) -> Result<(), io::Error> {
    // SAFETY: F_GETFD takes no pointer argument.
    let current = unsafe { libc::fcntl(fd, libc::F_GETFD) };
    if current < 0 {
        return Err(io::Error::last_os_error());
    }
    let updated = current | libc::FD_CLOEXEC;
    // SAFETY: F_SETFD takes an integer flag argument, not a pointer.
    match unsafe { libc::fcntl(fd, libc::F_SETFD, updated) } {
        rc if rc < 0 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}

/// Shared implementation of `read_line` / `read_line_with_status`.
///
/// Reads one byte per `read(2)` so that nothing past the newline is consumed
/// from `fd`. Returns `Ok(None)` at EOF with no pending bytes. Otherwise it
/// returns `Ok(Some((line, terminated)))`, where `line` excludes the `\n` and is
/// decoded as UTF-8 lossily, and `terminated` says whether a `\n` was seen.
/// `EINTR` is retried; any other read error is returned.
fn read_line_fd_inner(fd: FileDescriptor) -> Result<Option<(String, bool)>, io::Error> {
    let mut buf: Vec<u8> = Vec::new();
    let terminated = loop {
        let mut one = [0u8; 1];
        // SAFETY: `one` is valid for writes of exactly one byte.
        let got = unsafe { libc::read(fd.into_raw_fd(), one.as_mut_ptr().cast::<libc::c_void>(), 1) };
        match got {
            0 => break false,
            n if n < 0 => {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            _ if one[0] == b'\n' => break true,
            _ => buf.push(one[0]),
        }
    };
    if !terminated && buf.is_empty() {
        return Ok(None);
    }
    Ok(Some((String::from_utf8_lossy(&buf).into_owned(), terminated)))
}

/// Report whether `fd` currently has the `FD_CLOEXEC` flag set.
///
/// # Errors
/// The `fcntl(F_GETFD)` error, e.g. when `fd` is not open.
#[cfg(feature = "unix-runtime")]
pub(crate) fn fd_has_cloexec(fd: RawFd) -> Result<bool, io::Error> {
    // SAFETY: F_GETFD takes no pointer argument.
    match unsafe { libc::fcntl(fd, libc::F_GETFD) } {
        flags if flags >= 0 => Ok(flags & libc::FD_CLOEXEC != 0),
        _ => Err(io::Error::last_os_error()),
    }
}