// mxsh 0.2.0
//
// Embeddable POSIX-style shell parser and runtime.
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
use crate::args::InitArgs;
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
use crate::parser::{ParseError, Parser};
use crate::sys::Runtime;

use super::path::resolve_against_shell_cwd;
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
use super::shell_errln;
use super::{OPT_NOEXEC, ShellState, maybe_warn_vi_unsupported, shell_expand, sync_monitor_mode};
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
use super::{run_string, run_string_with_source};

/// Opens `path` for sourcing, treating a missing file as "nothing to source".
///
/// Returns `Ok(Some(file))` when the file opens, `Ok(None)` when the open fails
/// with `NotFound`, and propagates every other I/O error unchanged.
pub(super) fn read_source_file(path: &Path) -> io::Result<Option<fs::File>> {
    fs::File::open(path).map(Some).or_else(|open_err| {
        if open_err.kind() == io::ErrorKind::NotFound {
            Ok(None)
        } else {
            Err(open_err)
        }
    })
}

/// Sources the script at `path` into the running shell.
///
/// Returns `Ok(false)` when the file does not exist and `Ok(true)` once the
/// opened file has been handed to the source reader. Any other open failure
/// is returned as `Err`. The status produced by running the script is
/// discarded.
fn source_path<R: Runtime>(
    state: &mut ShellState,
    runtime: &mut R,
    path: &Path,
) -> io::Result<bool> {
    match read_source_file(path)? {
        None => Ok(false),
        Some(script) => {
            let display_name = path.to_string_lossy();
            let _ = super::run::run_source_reader(state, runtime, script, Some(&*display_name));
            Ok(true)
        }
    }
}

/// Opens the command file from `init`, if one was given.
///
/// The path is resolved against the shell's current directory before opening.
/// Returns `Ok(None)` when `init.command_file` is unset. Otherwise it returns
/// the open file paired with the resolved path (always `Some`). Any open
/// failure, including a missing file, is returned as `Err`.
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
fn load_source_with_cwd(
    state: &ShellState,
    init: &InitArgs,
) -> io::Result<Option<(fs::File, Option<PathBuf>)>> {
    let Some(command_file) = init.command_file.as_ref() else {
        return Ok(None);
    };
    let resolved = resolve_against_shell_cwd(state, Path::new(command_file));
    let opened = fs::File::open(&resolved)?;
    Ok(Some((opened, Some(resolved))))
}

/// Sources the configured profile scripts: the system profile first, then the
/// user profile.
///
/// An absolute user profile path is used as-is. A relative one is joined onto
/// `$HOME`, and is skipped when `HOME` is unset. Missing files and open errors
/// are ignored.
fn source_profile<R: Runtime>(state: &mut ShellState, runtime: &mut R) {
    let system_script = state
        .definition
        .startup_sources
        .system_profile()
        .map(Path::to_path_buf);
    if let Some(script) = system_script {
        let _ = source_path(state, runtime, &script);
    }

    // Resolved only after the system profile ran, since that script may have
    // changed HOME.
    let user_script = match state.definition.startup_sources.user_profile() {
        None => None,
        Some(profile) if profile.is_absolute() => Some(profile.to_path_buf()),
        Some(profile) => state
            .env_get("HOME")
            .map(|home| PathBuf::from(home).join(profile)),
    };
    if let Some(script) = user_script {
        let _ = source_path(state, runtime, &script);
    }
}

/// Sources the env-hook file named by the configured env-file variable.
///
/// Does nothing unless a variable name is configured and that variable is set.
/// The value is tilde-expanded and resolved against the shell's current
/// directory before sourcing. Missing files and open errors are ignored.
fn source_env<R: Runtime>(state: &mut ShellState, runtime: &mut R) {
    // The hook variable is configurable so interactive embedders can supply
    // helper functions and prompt customizations from their own rc file.
    let Some(hook_var) = state.definition.startup_sources.env_file_var() else {
        return;
    };
    let Some(hook_value) = state.env_get(hook_var).map(String::from) else {
        return;
    };
    let expanded = shell_expand::expand_tilde(state, &hook_value);
    let hook_path = resolve_against_shell_cwd(state, Path::new(&expanded));
    let _ = source_path(state, runtime, &hook_path);
}

/// Reports whether a parse failure looks like the input simply ran out.
///
/// An error whose span ends at or beyond the end of `buffer` is treated as an
/// unfinished construct that more input may complete. Any other error counts
/// as a hard syntax error.
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
fn parse_error_needs_more_input(buffer: &str, err: &ParseError) -> bool {
    let input_end = buffer.len();
    let error_end = err.range.end.offset;
    !(error_end < input_end)
}

/// Runs a script read incrementally from the shell's stdin and returns its status.
///
/// Lines are accumulated into a pending buffer, each re-terminated with `'\n'`.
/// After every line the whole buffer is test-parsed using the shell's language
/// configuration and a snapshot of the current aliases. Then:
///
/// * if it parses, it is executed right away and the buffer is cleared, so each
///   complete command runs before later input (or EOF) arrives;
/// * if the parse error reaches the end of the buffer, the construct is assumed
///   to be unfinished and another line is read;
/// * any other parse error hands the buffer to `run_string` (presumably so it
///   reports the error) and its status is returned without reading further.
///
/// Reading also stops once an executed chunk leaves `state.exit_code >= 0`
/// (presumably set by `exit` — TODO confirm).
///
/// At EOF, leftover buffered input is run as-is. Otherwise the status of the
/// last executed chunk is returned (0 if nothing ran). A read error returns 1.
///
/// NOTE(review): the buffer is re-parsed from the start after every line, so a
/// long multi-line construct costs quadratic parse time.
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
fn run_streaming_stdin<R: Runtime>(state: &mut ShellState, runtime: &mut R) -> i32 {
    let mut pending = String::new();
    let mut status = 0;
    loop {
        match state.stdin_fd.read_line() {
            Ok(Some(line)) => {
                pending.push_str(&line);
                // NOTE(review): assumes `read_line` strips the line terminator — confirm.
                pending.push('\n');
            }
            Ok(None) => {
                if pending.is_empty() {
                    return status;
                }
                // Input ended mid-construct; run what we have so the shell
                // handles (presumably reports) the incomplete program.
                return run_string(state, runtime, &pending);
            }
            Err(_) => return 1,
        }
        // `pending` always ends with the '\n' pushed above, so it is never empty here.

        let mut parser = Parser::from_string(&pending);
        crate::parser::configure_parser_for_language(&mut parser, &state.definition.language);
        let aliases = state.aliases_snapshot();
        parser.set_alias_func(crate::parser::AliasFn::new(move |name| {
            aliases.get(name).cloned()
        }));
        match parser.parse_program() {
            Ok(_) => {
                status = run_string(state, runtime, &pending);
                pending.clear();
                if state.exit_code >= 0 {
                    return status;
                }
            }
            // Unfinished construct: keep buffering until it parses or EOF.
            Err(err) if parse_error_needs_more_input(&pending, &err) => {}
            Err(_) => return run_string(state, runtime, &pending),
        }
    }
}

/// Prepares `state` for a new session and runs the configured startup files.
///
/// In order, it:
///
/// * calls `sync_monitor_mode`;
/// * populates the environment if `IFS` is not yet set;
/// * runs the vi-mode warning check;
/// * sources the profile scripts when the startup policy asks for it;
/// * sources the env-hook file for interactive shells without `OPT_NOEXEC`
///   whose startup policy allows it.
pub(crate) fn initialize_shell_session<R: Runtime>(state: &mut ShellState, runtime: &mut R) {
    sync_monitor_mode(state);
    let ifs_missing = state.env_get("IFS").is_none();
    if ifs_missing {
        state.populate_env();
    }
    maybe_warn_vi_unsupported(state);

    let wants_profile = state.definition.startup_policy.should_source_profile();
    if wants_profile {
        source_profile(state, runtime);
    }

    // Decided only after the profile ran, because profile scripts may change
    // shell options. The env hook lets embedders use their own rc naming
    // without forking the CLI.
    let wants_env_hook = state.interactive
        && !state.has_option(OPT_NOEXEC)
        && state.definition.startup_policy.should_source_env();
    if wants_env_hook {
        source_env(state, runtime);
    }
}

/// Runs a non-interactive session and returns its exit status.
///
/// Input is chosen in priority order:
///
/// 1. the command string (`init.command_str`);
/// 2. the command file (`init.command_file`, resolved against the shell's cwd);
/// 3. streaming stdin.
///
/// If the command file cannot be opened, a diagnostic is printed and 1 is
/// returned.
#[cfg(any(feature = "frontend", all(test, feature = "test-support")))]
pub(crate) fn run_non_interactive<R: Runtime>(
    state: &mut ShellState,
    runtime: &mut R,
    init: &InitArgs,
) -> i32 {
    if let Some(command) = &init.command_str {
        return run_string_with_source(state, runtime, command, None);
    }

    let opened = match load_source_with_cwd(state, init) {
        Ok(opened) => opened,
        Err(err) => {
            let message = match init.command_file.as_deref() {
                Some(path) => format!("failed to open {path} for reading: {err}"),
                None => format!("{err}"),
            };
            shell_errln(state, &message);
            return 1;
        }
    };

    let Some((script, script_path)) = opened else {
        return run_streaming_stdin(state, runtime);
    };
    let display_name: Option<String> =
        script_path.map(|resolved| resolved.to_string_lossy().into_owned());
    super::run::run_source_reader(state, runtime, script, display_name.as_deref())
}

#[cfg(all(test, feature = "test-support"))]
mod tests {
    use std::fs::File;
    use std::io::Write;
    use std::os::fd::FromRawFd;
    use std::sync::mpsc;
    use std::thread;
    use std::time::{Duration, Instant};

    use crate::sys::{self, DeterministicRuntime, FileDescriptor};

    use super::*;

    /// Reads from `fd` until the accumulated bytes contain `needle`, returning
    /// everything read so far (lossily decoded as UTF-8).
    ///
    /// Panics if `timeout` elapses first, if the writer side reaches EOF before
    /// `needle` appears, or if `poll`/`read` fail with an error other than
    /// `EINTR`. `EINTR` is retried.
    fn read_until(fd: FileDescriptor, needle: &str, timeout: Duration) -> String {
        let raw_fd = fd.into_raw_fd();
        let deadline = Instant::now() + timeout;
        let mut buf = Vec::new();
        let mut chunk = [0_u8; 4096];
        while Instant::now() < deadline {
            let remaining = deadline.saturating_duration_since(Instant::now());
            let wait_ms = remaining.as_millis().min(i32::MAX as u128) as i32;
            let mut poll_fd = libc::pollfd {
                fd: raw_fd,
                events: libc::POLLIN,
                revents: 0,
            };
            // SAFETY: `poll_fd` is a valid, exclusively borrowed `pollfd`, and the
            // count passed (1) matches the single entry it points to.
            let rc = unsafe { libc::poll(&mut poll_fd, 1, wait_ms) };
            if rc < 0 {
                let err = io::Error::last_os_error();
                // A signal interrupting the wait is not a test failure; retry
                // (the loop condition still enforces the deadline).
                if err.kind() == io::ErrorKind::Interrupted {
                    continue;
                }
                panic!("poll failed: {err}");
            }
            if rc == 0 {
                continue;
            }
            // SAFETY: `chunk` is a live, writable buffer of exactly `chunk.len()`
            // bytes for the duration of the call.
            let read = unsafe {
                libc::read(
                    raw_fd,
                    chunk.as_mut_ptr() as *mut libc::c_void,
                    chunk.len(),
                )
            };
            if read < 0 {
                let err = io::Error::last_os_error();
                if err.kind() == io::ErrorKind::Interrupted {
                    continue;
                }
                panic!("read failed: {err}");
            }
            if read == 0 {
                // Report EOF distinctly instead of as a misleading timeout.
                panic!(
                    "reached EOF before {needle:?}; partial output: {:?}",
                    String::from_utf8_lossy(&buf)
                );
            }
            buf.extend_from_slice(&chunk[..read as usize]);
            let text = String::from_utf8_lossy(&buf);
            if text.contains(needle) {
                return text.into_owned();
            }
        }
        panic!(
            "timed out waiting for {needle:?}; partial output: {:?}",
            String::from_utf8_lossy(&buf)
        );
    }

    /// A complete command read from stdin must execute before later input (and
    /// EOF) arrives: "first" must appear on stdout while the writer is still
    /// blocked, holding back the second line.
    #[test]
    fn stdin_programs_execute_before_input_eof() {
        let stdin = sys::OsPipe::new().expect("stdin pipe");
        let stdout = sys::OsPipe::new().expect("stdout pipe");
        let (first_line_written_tx, first_line_written_rx) = mpsc::channel();
        let (allow_finish_tx, allow_finish_rx) = mpsc::channel();
        let stdin_write_fd = stdin.write_fd;
        let writer = thread::spawn(move || {
            // SAFETY: the `File` takes ownership of the stdin pipe's write end;
            // dropping it at the end of this closure closes it, which delivers
            // EOF to the shell. NOTE(review): assumes `OsPipe` does not also
            // close `write_fd` on drop — confirm.
            let mut file = unsafe { File::from_raw_fd(stdin_write_fd.into_raw_fd()) };
            file.write_all(b"echo first\n").expect("write first line");
            file.flush().expect("flush first line");
            first_line_written_tx
                .send(())
                .expect("signal first line write");
            allow_finish_rx.recv().expect("allow finish");
            file.write_all(b"echo second\n").expect("write second line");
            file.flush().expect("flush second line");
        });
        let stdin_read_fd = stdin.read_fd;
        let stdout_write_fd = stdout.write_fd;
        let shell = thread::spawn(move || {
            let mut state = ShellState::new();
            state.populate_env();
            state.stdin_fd = stdin_read_fd;
            state.stdout_fd = stdout_write_fd;
            let mut runtime = DeterministicRuntime::new();
            let status = run_non_interactive(&mut state, &mut runtime, &InitArgs::default());
            stdin_read_fd.close();
            // Closing stdout lets the final `read_all` below observe EOF.
            stdout_write_fd.close();
            status
        });

        first_line_written_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("writer should send first line");
        let first = read_until(stdout.read_fd, "first\n", Duration::from_secs(2));
        assert!(
            first.contains("first\n"),
            "expected first output in {first:?}"
        );

        allow_finish_tx.send(()).expect("allow writer to finish");
        let status = shell.join().expect("shell thread should join");
        assert_eq!(status, 0);

        let rest = stdout.read_fd.read_all();
        let combined = format!("{first}{rest}");
        assert!(
            combined.contains("second\n"),
            "expected second output in {combined:?}"
        );
        writer.join().expect("writer thread should join");
    }
}