use std::{
process::Stdio,
sync::atomic::{AtomicBool, Ordering},
};
use bytes::{Bytes, BytesMut};
use dialoguer::FuzzySelect;
use futures::prelude::*;
use indoc::indoc;
use is_root::is_root;
use itertools::{chain, Itertools};
use regex::Regex;
use tap::prelude::*;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
process::Command as Exec,
task::JoinHandle,
};
#[allow(clippy::wildcard_imports)]
use tokio_util::{
codec::{BytesCodec, FramedRead},
compat::*,
either::Either,
};
use which::which;
use crate::{
error::{Error, Result},
print::{println_quoted, prompt, question_theme},
};
/// How `Cmd::exec` should treat a command.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Mode {
    /// Only print the command (with the "canceled" prompt); nothing is spawned.
    PrintCmd,
    /// Run the command without announcing it or echoing its output, capturing both
    /// stdout and stderr.
    Mute,
    /// Announce the command, run it, and echo + capture both stdout and stderr.
    CheckAll,
    /// Announce the command, run it, and echo + capture stderr only.
    CheckErr,
    /// Ask for confirmation first; if confirmed, behave like `CheckErr`.
    Prompt,
}

/// Exit status code reported by a finished child process.
pub(crate) type StatusCode = i32;
fn exit_result(code: Option<StatusCode>, output: Output) -> Result<Output> {
match code {
Some(0) => Ok(output),
Some(code) => Err(Error::CmdStatusCodeError { code, output }),
None => Err(Error::CmdInterruptedError),
}
}
pub(crate) type Output = Vec<u8>;
/// A command line split into its command, flag and keyword parts.
///
/// When built, the parts are passed in the order `cmd`, `flags`, `kws`. The first
/// element of `cmd` is the program to run, unless the whole line is wrapped in `sudo -S`.
#[must_use]
#[derive(Default, Clone, Debug)]
pub(crate) struct Cmd {
    /// Whether to run through `sudo -S`. This only takes effect when the current user
    /// is not already root.
    pub sudo: bool,
    /// Leading part of the command line. Its first element is the program name.
    pub cmd: Vec<String>,
    /// Flags placed after `cmd`.
    pub flags: Vec<String>,
    /// Keywords placed after `flags`.
    pub kws: Vec<String>,
}
impl Cmd {
    /// Creates a command from its leading parts (program name first). It has no sudo,
    /// no flags and no keywords.
    pub(crate) fn new(cmd: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        let cmd = cmd
            .into_iter()
            .map(|part| String::from(part.as_ref()))
            .collect();
        Self {
            cmd,
            ..Self::default()
        }
    }

    /// Same as `Cmd::new`, but with `sudo` turned on.
    pub(crate) fn with_sudo(cmd: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        Self {
            sudo: true,
            ..Self::new(cmd)
        }
    }

    /// Replaces the flags of this command.
    pub(crate) fn flags(self, flags: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        let flags = flags
            .into_iter()
            .map(|flag| String::from(flag.as_ref()))
            .collect();
        Self { flags, ..self }
    }

    /// Replaces the keywords of this command.
    pub(crate) fn kws(self, kws: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        let kws = kws
            .into_iter()
            .map(|kw| String::from(kw.as_ref()))
            .collect();
        Self { kws, ..self }
    }

    /// Sets whether this command should be run through `sudo`.
    pub(crate) fn sudo(self, sudo: bool) -> Self {
        Self { sudo, ..self }
    }

    /// Whether `sudo -S` must actually be prepended: sudo was requested and the current
    /// user is not root. `is_root` is only consulted when sudo was requested.
    #[must_use]
    fn should_sudo(&self) -> bool {
        self.sudo && !is_root()
    }

    /// Turns this command into a process builder.
    ///
    /// Without sudo, the first `cmd` element becomes the program and the rest of `cmd`,
    /// then `flags`, then `kws` become its arguments. With sudo, the program is `sudo`
    /// and the arguments are `-S` followed by all of `cmd`, `flags` and `kws`.
    ///
    /// # Panics
    /// Panics when sudo is not used and `cmd` is empty.
    #[must_use]
    fn build(self) -> Exec {
        if !self.should_sudo() {
            let (program, rest) = self
                .cmd
                .split_first()
                .expect("Failed to build Cmd, command is empty");
            return Exec::new(program).tap_mut(|exec| {
                exec.args(rest);
                exec.args(&self.flags);
                exec.args(&self.kws);
            });
        }
        Exec::new("sudo").tap_mut(|exec| {
            exec.arg("-S");
            for group in [&self.cmd, &self.flags, &self.kws] {
                exec.args(group);
            }
        })
    }
}
/// Drains `src` into an in-memory buffer and returns the collected bytes. When `out` is
/// `Some`, every chunk is also written to `out`.
///
/// # Errors
/// Returns an error if reading from `src` or writing to either destination fails.
async fn exec_tee(
    src: impl Stream<Item = io::Result<Bytes>>,
    out: Option<impl AsyncWrite>,
) -> Result<Vec<u8>> {
    let mut captured = Vec::<u8>::new();
    let capture_sink = (&mut captured).into_sink();
    // The two arms build different sink types, so wrap them in an `Either` sink.
    let sink = match out {
        Some(writer) => capture_sink
            .fanout(writer.compat_write().into_sink())
            .left_sink(),
        None => capture_sink.right_sink(),
    };
    src.forward(sink).await?;
    Ok(captured)
}
// Expands to the shared "# Errors" rustdoc section used by the `Cmd::exec*` methods
// through `#[doc = docs_errors_exec!()]`. The string below is the rendered doc text,
// so edit it with care.
macro_rules! docs_errors_exec {
() => {
indoc! {"
# Errors
This function might return one of the following errors:
- [`Error::CmdJoinError`]
- [`Error::CmdNoHandleError`]
- [`Error::CmdSpawnError`]
- [`Error::CmdWaitError`]
- [`Error::CmdStatusCodeError`]
- [`Error::CmdInterruptedError`]
"}
};
}
impl Cmd {
    /// Runs this command. `mode` controls how the command is announced and how its
    /// output is captured:
    ///
    /// * `PrintCmd`: prints the command with the "canceled" prompt and returns an empty
    ///   output. Nothing is spawned.
    /// * `Mute`: runs it silently (no announcement, no echo) and captures stdout and
    ///   stderr.
    /// * `CheckAll` / `CheckErr`: announces it with the "running" prompt, then runs it,
    ///   echoing and capturing stdout+stderr (`CheckAll`) or stderr only (`CheckErr`).
    /// * `Prompt`: asks for confirmation first (see `exec_prompt`).
    #[doc = docs_errors_exec!()]
    pub(crate) async fn exec(self, mode: Mode) -> Result<Output> {
        match mode {
            Mode::PrintCmd => {
                println_quoted(&*prompt::CANCELED, &self);
                Ok(Output::default())
            }
            Mode::Mute => self.exec_checkall(true).await,
            Mode::Prompt => self.exec_prompt(false).await,
            Mode::CheckAll | Mode::CheckErr => {
                println_quoted(&*prompt::RUNNING, &self);
                if matches!(mode, Mode::CheckAll) {
                    self.exec_checkall(false).await
                } else {
                    self.exec_checkerr(false).await
                }
            }
        }
    }

    /// Spawns the command and tees its output.
    ///
    /// stderr is always piped.
    ///
    /// * When `merge` is set, stdout is piped too. Both streams are interleaved,
    ///   captured, and echoed to our stdout.
    /// * Otherwise only stderr is captured and echoed to our stderr, and the child's
    ///   stdout is left unpiped.
    ///
    /// With `mute`, nothing is echoed but the output is still captured. The exit status
    /// is awaited on a spawned task while the output is drained. The captured bytes are
    /// returned only on a zero exit code.
    #[doc = docs_errors_exec!()]
    async fn exec_check_output(self, mute: bool, merge: bool) -> Result<Output> {
        use tokio_stream::StreamExt;
        use Error::{CmdJoinError, CmdNoHandleError, CmdSpawnError, CmdWaitError};

        // Turns a captured child pipe into a byte stream, or reports which pipe is missing.
        fn pipe_stream(
            handle: Option<impl AsyncRead>,
            label: &str,
        ) -> Result<impl Stream<Item = io::Result<Bytes>>> {
            match handle {
                Some(pipe) => Ok(into_bytes(pipe)),
                None => Err(CmdNoHandleError {
                    handle: label.into(),
                }),
            }
        }

        let mut command = self.build();
        command.stderr(Stdio::piped());
        if merge {
            command.stdout(Stdio::piped());
        }
        let mut child = command.spawn().map_err(CmdSpawnError)?;

        // stderr is taken first, so a missing stderr handle is reported before stdout.
        let err_stream = pipe_stream(child.stderr.take(), "stderr")?;
        let mut source = if merge {
            let out_stream = pipe_stream(child.stdout.take(), "stdout")?;
            StreamExt::merge(out_stream, err_stream).left_stream()
        } else {
            err_stream.right_stream()
        };

        let mut echo = if merge {
            Either::Left(io::stdout())
        } else {
            Either::Right(io::stderr())
        };

        let waiter: JoinHandle<Result<Option<i32>>> = tokio::spawn(async move {
            let status = child.wait().await.map_err(CmdWaitError)?;
            Ok(status.code())
        });

        let captured = exec_tee(&mut source, (!mute).then_some(&mut echo)).await?;
        let code = waiter.await.map_err(CmdJoinError)??;
        exit_result(code, captured)
    }

    /// Runs the command, capturing both stdout and stderr. Unless `mute` is set, they
    /// are also echoed to stdout.
    #[doc = docs_errors_exec!()]
    async fn exec_checkall(self, mute: bool) -> Result<Output> {
        self.exec_check_output(mute, /* merge: */ true).await
    }

    /// Runs the command, capturing stderr only. Unless `mute` is set, it is also echoed
    /// to stderr.
    #[doc = docs_errors_exec!()]
    async fn exec_checkerr(self, mute: bool) -> Result<Output> {
        self.exec_check_output(mute, /* merge: */ false).await
    }

    /// Asks the user whether to run this command.
    ///
    /// Choices are "Yes", "All" and "No". Once "All" has been chosen, this and every
    /// later call skip the question. On "No", nothing runs and an empty output is
    /// returned. Otherwise the command is announced and run like `Mode::CheckErr`.
    ///
    /// NOTE(review): `tokio::task::block_in_place` panics on a current-thread runtime,
    /// so this presumably requires the multi-threaded runtime. Confirm at call sites.
    #[doc = docs_errors_exec!()]
    async fn exec_prompt(self, mute: bool) -> Result<Output> {
        // Set once the user answers "All"; shared by every later prompt in this process.
        static ALL: AtomicBool = AtomicBool::new(false);

        let proceed = if ALL.load(Ordering::SeqCst) {
            true
        } else {
            println_quoted(&*prompt::PENDING, &self);
            let choice = tokio::task::block_in_place(move || {
                prompt(
                    "Proceed",
                    "with the previous command?",
                    &["Yes", "All", "No"],
                )
            })?;
            match choice {
                0 => true,
                1 => {
                    ALL.store(true, Ordering::SeqCst);
                    true
                }
                2 => false,
                // Only three items are offered, so no other index can be returned.
                _ => unreachable!(),
            }
        };

        if !proceed {
            return Ok(Output::default());
        }
        println_quoted(&*prompt::RUNNING, &self);
        self.exec_checkerr(mute).await
    }
}
impl std::fmt::Display for Cmd {
    /// Renders the command line as it will be run: `cmd`, `flags` and `kws` joined by
    /// single spaces. The line is prefixed with `sudo -S ` when sudo will actually be
    /// used.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.should_sudo() {
            f.write_str("sudo -S ")?;
        }
        f.write_str(&chain!(&self.cmd, &self.flags, &self.kws).join(" "))
    }
}
/// Shows an interactive fuzzy-select menu and returns the index of the chosen item.
///
/// * `prompt`: label passed to `question_theme` to build the menu theme.
/// * `question`: the question shown to the user.
/// * `expected`: the selectable answers. The first one (index `0`) is pre-selected.
///
/// # Errors
/// Propagates the error returned by `FuzzySelect::interact`.
// NOTE(review): `missing_panics_doc` only checks publicly visible functions, so this
// `allow` is likely redundant on a private fn. Confirm before removing.
#[allow(clippy::missing_panics_doc)]
fn prompt(prompt: &str, question: &str, expected: &[&str]) -> io::Result<usize> {
    FuzzySelect::with_theme(&question_theme(prompt))
        .with_prompt(question)
        .items(expected)
        .default(0)
        .interact()
}
// Expands to the shared "# Errors" rustdoc section for `grep` and `grep_print`, used as
// `#[doc = docs_errors_grep!()]`. The string below is the rendered doc text.
macro_rules! docs_errors_grep {
() => {
indoc! {"
# Errors
Returns an [`Error::OtherError`] when any of the
regex patterns is ill-formed.
"}
};
}
/// Returns the lines of `text` that match every regex in `patterns`, in their original
/// order. With no patterns, every line is returned.
#[doc = docs_errors_grep!()]
fn grep<'t>(text: &'t str, patterns: &[&str]) -> Result<Vec<&'t str>> {
    // Compile all patterns up front and stop at the first ill-formed one.
    let regexes = patterns
        .iter()
        .map(|&pat| {
            Regex::new(pat)
                .map_err(|_e| Error::OtherError(format!("Pattern `{pat}` is ill-formed")))
        })
        .collect::<Result<Vec<Regex>>>()?;
    let matching = text
        .lines()
        .filter(|line| regexes.iter().all(|re| re.is_match(line)))
        .collect();
    Ok(matching)
}
/// Prints each line of `text` that matches all of `patterns`, one per line on stdout.
/// Nothing is printed if a pattern fails to compile.
#[doc = docs_errors_grep!()]
pub(crate) fn grep_print(text: &str, patterns: &[&str]) -> Result<()> {
    for line in grep(text, patterns)? {
        println!("{line}");
    }
    Ok(())
}
#[must_use]
pub(crate) fn is_exe(name: &str, path: &str) -> bool {
(!path.is_empty() && which(path).is_ok()) || (!name.is_empty() && which(name).is_ok())
}
fn into_bytes(reader: impl AsyncRead) -> impl Stream<Item = io::Result<Bytes>> {
FramedRead::new(reader, BytesCodec::new()).map_ok(BytesMut::freeze)
}