// mxsh 0.1.0
//
// Embeddable POSIX-style shell parser and runtime.
// Documentation
use std::collections::HashMap;
use std::io;
use std::os::fd::RawFd;
use std::sync::{Arc, Mutex, OnceLock};

/// Byte buffer holding input that has been read from a descriptor but not yet
/// handed out as lines.
///
/// `bytes[start..]` is the unconsumed region. Bytes before `start` were
/// already returned to a caller and are reclaimed lazily, either when the
/// buffer drains completely or when the next append decides to compact.
#[derive(Default)]
struct LineBuffer {
    /// Accumulated input, possibly including an already-consumed prefix.
    bytes: Vec<u8>,
    /// Offset of the first byte that has not been consumed yet.
    start: usize,
}

impl LineBuffer {
    /// Returns `true` when no unconsumed bytes remain.
    fn is_empty(&self) -> bool {
        self.bytes.len() <= self.start
    }

    /// Discards all content and rewinds the read offset.
    fn clear(&mut self) {
        self.start = 0;
        self.bytes.clear();
    }

    /// Pops the next newline-terminated line, if a complete one is buffered.
    ///
    /// The newline itself is consumed but not included in the returned text,
    /// and invalid UTF-8 is replaced lossily. The boolean is always `true`
    /// here: it marks the line as newline-terminated.
    fn take_line(&mut self) -> Option<(String, bool)> {
        let pending = &self.bytes[self.start..];
        let line_len = pending.iter().position(|&b| b == b'\n')?;
        let text = String::from_utf8_lossy(&pending[..line_len]).into_owned();

        // Step over the line and its terminating newline.
        self.start += line_len + 1;
        if self.start == self.bytes.len() {
            // Everything was consumed; reset instead of carrying a dead prefix.
            self.clear();
        }
        Some((text, true))
    }

    /// Pops whatever unconsumed bytes are left as one lossily-decoded string.
    ///
    /// Returns `None` when nothing is pending. The buffer is reset either way.
    fn take_remainder(&mut self) -> Option<String> {
        let rest = (!self.is_empty())
            .then(|| String::from_utf8_lossy(&self.bytes[self.start..]).into_owned());
        self.clear();
        rest
    }

    /// Reclaims the consumed prefix before `incoming_len` more bytes arrive.
    ///
    /// The prefix is dropped when it is at least as long as the incoming chunk
    /// or makes up at least half of the buffer; otherwise it is left alone.
    fn compact_before_append(&mut self, incoming_len: usize) {
        let consumed = self.start;
        if consumed == 0 {
            return;
        }
        if self.is_empty() {
            self.clear();
            return;
        }

        let worth_compacting = consumed >= incoming_len || consumed * 2 >= self.bytes.len();
        if worth_compacting {
            // Shift the live tail to the front by removing the dead prefix.
            self.bytes.drain(..consumed);
            self.start = 0;
        }
    }

    /// Appends freshly read bytes, compacting first when that is worthwhile.
    fn append(&mut self, chunk: &[u8]) {
        self.compact_before_append(chunk.len());
        self.bytes.extend(chunk);
    }
}

/// A [`LineBuffer`] behind `Arc<Mutex<..>>` so that more than one descriptor
/// number can refer to the same buffered data.
type SharedLineBuffer = Arc<Mutex<LineBuffer>>;
/// Mutex-guarded table mapping a raw descriptor number to its line buffer.
type LineBufferMap = Mutex<HashMap<RawFd, SharedLineBuffer>>;

/// Thin, copyable wrapper around a raw OS file descriptor number.
///
/// The wrapper does not own the descriptor: copying the value does not
/// duplicate the underlying descriptor, and nothing is closed automatically.
/// Use [`FileDescriptor::close`] to close it explicitly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct FileDescriptor(i32);

impl FileDescriptor {
    /// Standard input (descriptor 0).
    pub const STDIN: Self = Self(0);
    /// Standard output (descriptor 1).
    pub const STDOUT: Self = Self(1);
    /// Standard error (descriptor 2).
    pub const STDERR: Self = Self(2);
    /// Sentinel value (-1) for which [`is_valid`](Self::is_valid) is `false`.
    pub const INVALID: Self = Self(-1);

    /// Wraps a raw descriptor number without validating it.
    pub const fn new(fd: i32) -> Self {
        Self(fd)
    }

    /// Returns the wrapped descriptor number.
    pub const fn as_i32(self) -> i32 {
        self.0
    }

    /// Returns `true` when the descriptor number is non-negative.
    ///
    /// This only checks the number; it does not ask the OS whether the
    /// descriptor is actually open.
    pub const fn is_valid(self) -> bool {
        self.0 >= 0
    }

    /// Returns the descriptor as a [`RawFd`].
    #[cfg(unix)]
    pub const fn into_raw_fd(self) -> RawFd {
        self.0
    }

    /// Read one line from this fd. The trailing newline is stripped.
    ///
    /// Returns `Ok(None)` once end of input is reached and no buffered bytes
    /// remain. Bytes read past the returned line stay in a per-descriptor
    /// buffer for the next call.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by the underlying read.
    pub fn read_line(self) -> Result<Option<String>, io::Error> {
        read_line_fd_inner(self).map(|line| line.map(|(line, _terminated_by_newline)| line))
    }

    /// Read one line from this fd, reporting whether it ended with a newline.
    ///
    /// Returns `Ok(Some((line, terminated)))` where `terminated` is true when
    /// the line ended with `\n`, or `Ok(None)` at EOF.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by the underlying read.
    pub fn read_line_with_status(self) -> Result<Option<(String, bool)>, io::Error> {
        read_line_fd_inner(self)
    }

    /// Write a string to this fd.
    ///
    /// # Errors
    ///
    /// See [`write_all`](Self::write_all).
    pub fn write_str(self, s: &str) -> Result<(), io::Error> {
        self.write_all(s.as_bytes())
    }

    /// Write a string followed by a newline to this fd.
    ///
    /// The text and the newline are submitted as a single buffer so that the
    /// newline is not issued as a separate `write` call that another writer on
    /// the same pipe or terminal could get in front of.
    ///
    /// # Errors
    ///
    /// See [`write_all`](Self::write_all).
    pub fn write_line(self, s: &str) -> Result<(), io::Error> {
        let mut line = Vec::with_capacity(s.len() + 1);
        line.extend_from_slice(s.as_bytes());
        line.push(b'\n');
        self.write_all(&line)
    }

    /// Read all bytes from this fd and return as a String.
    ///
    /// This is the best-effort variant: a read error ends the loop and the
    /// bytes gathered so far are returned (lossily decoded as UTF-8). Use
    /// [`read_to_string`](Self::read_to_string) to observe the error instead.
    ///
    /// Any line buffer registered for this descriptor is dropped first, so
    /// bytes previously buffered by `read_line` are not part of the result.
    pub fn read_all(self) -> String {
        clear_line_buffer(self);
        let mut output = Vec::new();
        // Deliberately ignore the error: partial output is the contract here.
        let _ = self.read_until_eof(&mut output);
        String::from_utf8_lossy(&output).into_owned()
    }

    /// Read all bytes from this fd and return as a UTF-8 lossy String.
    ///
    /// Any line buffer registered for this descriptor is dropped first, so
    /// bytes previously buffered by `read_line` are not part of the result.
    ///
    /// # Errors
    ///
    /// Returns the first read error other than `EINTR`; bytes read before the
    /// failure are discarded.
    pub fn read_to_string(self) -> Result<String, io::Error> {
        clear_line_buffer(self);
        let mut output = Vec::new();
        self.read_until_eof(&mut output)?;
        Ok(String::from_utf8_lossy(&output).into_owned())
    }

    /// Appends everything readable from this fd to `output` until EOF.
    ///
    /// Reads interrupted by a signal (`EINTR`) are retried. On any other error
    /// the bytes appended so far are left in `output` and the error is
    /// returned, letting each caller choose between partial data and failure.
    fn read_until_eof(self, output: &mut Vec<u8>) -> Result<(), io::Error> {
        let mut chunk = [0u8; 4096];
        loop {
            // SAFETY: the pointer and length describe `chunk`, a live local
            // array that read(2) may fill with up to `chunk.len()` bytes.
            let n = unsafe {
                libc::read(
                    self.into_raw_fd(),
                    chunk.as_mut_ptr() as *mut libc::c_void,
                    chunk.len(),
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                return Ok(());
            }
            output.extend_from_slice(&chunk[..n as usize]);
        }
    }

    /// Write all bytes to this fd without closing it.
    ///
    /// Short writes are continued and writes interrupted by a signal
    /// (`EINTR`) are retried until every byte has been written.
    ///
    /// # Errors
    ///
    /// Returns the OS error of a failed write, or an error of kind
    /// [`io::ErrorKind::WriteZero`] if the OS reports that zero bytes were
    /// written while data is still pending.
    pub fn write_all(self, data: &[u8]) -> Result<(), io::Error> {
        let mut written = 0;
        while written < data.len() {
            // SAFETY: the pointer and length describe the unwritten tail of
            // `data`, which stays borrowed for the duration of the call and is
            // only read by write(2).
            let n = unsafe {
                libc::write(
                    self.into_raw_fd(),
                    data[written..].as_ptr() as *const libc::c_void,
                    data.len() - written,
                )
            };
            if n < 0 {
                let err = io::Error::last_os_error();
                if err.raw_os_error() == Some(libc::EINTR) {
                    continue;
                }
                return Err(err);
            }
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            written += n as usize;
        }
        Ok(())
    }

    /// Close this file descriptor.
    ///
    /// Does nothing when the descriptor number is negative. The line buffer
    /// registered for this descriptor number is dropped as well. The result of
    /// the underlying `close` call is not reported.
    pub fn close(self) {
        if self.is_valid() {
            clear_line_buffer(self);
            // SAFETY: close(2) takes a plain integer and touches no Rust memory.
            unsafe { libc::close(self.into_raw_fd()) };
        }
    }

    /// Duplicate this file descriptor.
    ///
    /// The new descriptor is registered with the same line buffer as this one,
    /// so lines buffered through either are visible through both.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `dup` fails.
    pub fn dup(self) -> Result<FileDescriptor, io::Error> {
        // SAFETY: dup(2) takes a plain integer and touches no Rust memory.
        let new_fd = unsafe { libc::dup(self.into_raw_fd()) };
        if new_fd < 0 {
            Err(io::Error::last_os_error())
        } else {
            let new_fd = Self(new_fd);
            clone_line_buffer(self, new_fd);
            Ok(new_fd)
        }
    }

    /// Redirect this fd to `new_fd` via `dup2`.
    ///
    /// After success `new_fd` refers to the same open file as `self` and is
    /// registered with the same line buffer.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `dup2` fails.
    pub fn dup2(self, new_fd: FileDescriptor) -> Result<(), io::Error> {
        // SAFETY: dup2(2) takes plain integers and touches no Rust memory.
        let rc = unsafe { libc::dup2(self.into_raw_fd(), new_fd.into_raw_fd()) };
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            clone_line_buffer(self, new_fd);
            Ok(())
        }
    }
}

impl From<i32> for FileDescriptor {
    fn from(value: i32) -> Self {
        Self(value)
    }
}

/// Unwraps the raw descriptor number; equivalent to [`FileDescriptor::as_i32`].
impl From<FileDescriptor> for i32 {
    fn from(fd: FileDescriptor) -> Self {
        fd.as_i32()
    }
}

/// The two ends of an OS pipe.
pub struct OsPipe {
    /// Descriptor for the end data is read from.
    pub read_fd: FileDescriptor,
    /// Descriptor for the end data is written to.
    pub write_fd: FileDescriptor,
}

impl OsPipe {
    /// Opens a new pipe and marks both ends close-on-exec.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the pipe cannot be created, or if setting the
    /// close-on-exec flag fails; in the latter case both ends are closed
    /// before the error is returned.
    pub fn new() -> Result<Self, io::Error> {
        let mut ends = [0i32; 2];
        let rc = unsafe { libc::pipe(ends.as_mut_ptr()) };
        if rc != 0 {
            return Err(io::Error::last_os_error());
        }

        let [reader, writer] = ends;
        match set_cloexec(reader).and_then(|()| set_cloexec(writer)) {
            Ok(()) => Ok(Self {
                read_fd: FileDescriptor::from(reader),
                write_fd: FileDescriptor::from(writer),
            }),
            Err(err) => {
                // Do not leak the freshly created descriptors on failure.
                unsafe {
                    libc::close(reader);
                    libc::close(writer);
                }
                Err(err)
            }
        }
    }
}

/// Adds `FD_CLOEXEC` to the descriptor flags of `fd`, keeping existing flags.
///
/// Returns the OS error if either the flag query or the flag update fails.
fn set_cloexec(fd: RawFd) -> Result<(), io::Error> {
    let current = unsafe { libc::fcntl(fd, libc::F_GETFD) };
    if current < 0 {
        return Err(io::Error::last_os_error());
    }

    let rc = unsafe { libc::fcntl(fd, libc::F_SETFD, current | libc::FD_CLOEXEC) };
    if rc < 0 {
        Err(io::Error::last_os_error())
    } else {
        Ok(())
    }
}

/// Returns the process-wide table of line buffers, creating it empty on
/// first use.
fn line_buffers() -> &'static LineBufferMap {
    static REGISTRY: OnceLock<LineBufferMap> = OnceLock::new();
    REGISTRY.get_or_init(|| Mutex::new(HashMap::default()))
}

/// Returns the line buffer registered for `fd`, registering an empty one if
/// the descriptor has none yet.
///
/// A poisoned table lock is recovered rather than propagated.
fn line_buffer_for_fd(fd: FileDescriptor) -> SharedLineBuffer {
    let mut table = match line_buffers().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let slot = table.entry(fd.into_raw_fd()).or_default();
    Arc::clone(slot)
}

/// Removes the table entry for `fd`, if any, so its buffered bytes are no
/// longer reachable through that descriptor number.
///
/// A poisoned table lock is recovered rather than propagated.
fn clear_line_buffer(fd: FileDescriptor) {
    let raw = fd.into_raw_fd();
    let mut table = match line_buffers().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    table.remove(&raw);
}

/// Registers `new_fd` with the same line buffer as `old_fd`.
///
/// If `old_fd` has no buffer yet an empty one is created for it first, so
/// both descriptors end up sharing one buffer. Any buffer previously
/// registered for `new_fd` is replaced.
fn clone_line_buffer(old_fd: FileDescriptor, new_fd: FileDescriptor) {
    let mut table = match line_buffers().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let shared = Arc::clone(table.entry(old_fd.into_raw_fd()).or_default());
    table.insert(new_fd.into_raw_fd(), shared);
}

/// Reads the next line from `fd` through its registered line buffer.
///
/// Returns `Ok(Some((line, true)))` for a newline-terminated line,
/// `Ok(Some((rest, false)))` when EOF is hit with unterminated bytes still
/// buffered, and `Ok(None)` when EOF is hit with nothing buffered. Reads
/// interrupted by a signal (`EINTR`) are retried; any other read error is
/// returned. The buffer's lock is held for the whole call.
fn read_line_fd_inner(fd: FileDescriptor) -> Result<Option<(String, bool)>, io::Error> {
    let shared = line_buffer_for_fd(fd);
    let mut pending = shared.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    let raw = fd.into_raw_fd();
    let mut scratch = [0u8; 4096];

    loop {
        // Serve from already-buffered data before touching the descriptor.
        if let Some(found) = pending.take_line() {
            return Ok(Some(found));
        }

        let got = unsafe {
            libc::read(
                raw,
                scratch.as_mut_ptr() as *mut libc::c_void,
                scratch.len(),
            )
        };
        match got {
            // EOF: hand out whatever unterminated tail is left.
            0 => return Ok(pending.take_remainder().map(|tail| (tail, false))),
            n if n > 0 => pending.append(&scratch[..n as usize]),
            _ => {
                let err = io::Error::last_os_error();
                if err.raw_os_error() != Some(libc::EINTR) {
                    return Err(err);
                }
                // Interrupted by a signal: loop around and retry the read.
            }
        }
    }
}

/// Reports whether `fd` currently has the `FD_CLOEXEC` descriptor flag set.
///
/// Returns the OS error if the flags cannot be queried.
#[cfg(feature = "unix-runtime")]
pub(crate) fn fd_has_cloexec(fd: RawFd) -> Result<bool, io::Error> {
    match unsafe { libc::fcntl(fd, libc::F_GETFD) } {
        failed if failed < 0 => Err(io::Error::last_os_error()),
        flags => Ok((flags & libc::FD_CLOEXEC) != 0),
    }
}