use crate::error::*;
use errno;
use libc::{self, c_int};
use std::fmt;
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use tracing::{debug, error};
/// Shorthand for `io::Result`, used by the low-level terminal operations in this module.
pub type IoResult<T> = io::Result<T>;
/// Converts a libc-style status code into an `IoResult`.
///
/// `0` means success. Any other value is treated as failure, with the error taken from
/// `errno` via [`io::Error::last_os_error`].
fn to_io_result(ret: c_int) -> IoResult<()> {
    if ret == 0 {
        Ok(())
    } else {
        Err(io::Error::last_os_error())
    }
}
/// Local-mode (`c_lflag`) terminal flags that can be toggled through
/// [`AbstractTerminalAttributes`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TerminalFlag {
    /// Echo typed characters back to the terminal (`ECHO`).
    Echo,
    /// Echo newline characters even when `ECHO` is off (`ECHONL`).
    EchoNewlines,
}
impl TerminalFlag {
    /// Returns the `c_lflag` bit mask that represents this flag.
    fn to_value(&self) -> libc::tcflag_t {
        match self {
            Self::Echo => libc::ECHO,
            Self::EchoNewlines => libc::ECHONL,
        }
    }
}
/// Interface for switching individual [`TerminalFlag`]s on or off in a set of terminal
/// attributes.
///
/// The implementation for `TerminalAttributes` edits the stored `termios` value's `c_lflag`
/// field. The change only reaches a terminal when the attributes are passed to
/// [`AbstractStream::set_attributes`].
pub trait AbstractTerminalAttributes {
    /// Turns `flag` on.
    fn enable(&mut self, flag: TerminalFlag);
    /// Turns `flag` off.
    fn disable(&mut self, flag: TerminalFlag);
}
/// An owned copy of a terminal's `termios` settings.
pub struct TerminalAttributes {
    inner: libc::termios,
}
impl TerminalAttributes {
    /// Reads the current attributes of the terminal open on `fd` (via `tcgetattr`).
    fn new(fd: c_int) -> IoResult<Self> {
        let mut slot = MaybeUninit::<libc::termios>::uninit();
        // SAFETY: `slot` is writable storage of the correct type for tcgetattr to fill in.
        let status = unsafe { libc::tcgetattr(fd, slot.as_mut_ptr()) };
        to_io_result(status)?;
        // SAFETY: tcgetattr reported success, so it has written the termios struct.
        let inner = unsafe { slot.assume_init() };
        Ok(Self { inner })
    }
    /// Returns attributes with every field zeroed, so no flags are set.
    pub fn new_empty() -> Self {
        // SAFETY: `termios` is a plain C struct of integers and integer arrays; all-zero bytes
        // form a valid value.
        let inner: libc::termios = unsafe { std::mem::zeroed() };
        Self { inner }
    }
    /// Applies these attributes to the terminal on `fd` immediately (`TCSANOW`).
    fn apply(&self, fd: c_int) -> IoResult<()> {
        // SAFETY: `self.inner` is a valid termios that tcsetattr only reads.
        let status = unsafe { libc::tcsetattr(fd, libc::TCSANOW, &self.inner) };
        to_io_result(status)
    }
    /// Returns whether the local-mode bit for `flag` is set.
    pub fn is_enabled(&self, flag: TerminalFlag) -> bool {
        let mask = flag.to_value();
        (self.inner.c_lflag & mask) != 0
    }
}
/// Field-by-field equality of the underlying `termios` values.
impl PartialEq for TerminalAttributes {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.inner, &other.inner);
        let modes_match = (lhs.c_iflag, lhs.c_oflag, lhs.c_cflag, lhs.c_lflag)
            == (rhs.c_iflag, rhs.c_oflag, rhs.c_cflag, rhs.c_lflag);
        modes_match
            && lhs.c_line == rhs.c_line
            && lhs.c_cc == rhs.c_cc
            && lhs.c_ispeed == rhs.c_ispeed
            && lhs.c_ospeed == rhs.c_ospeed
    }
}
impl Eq for TerminalAttributes {}
/// Renders the bits of a `termios` flag field as `NAME | NAME | ...`.
///
/// Every entry of `fs` whose mask shares at least one bit with `v` contributes its name, in
/// table order. Any bits of `v` not cleared by a matching entry are appended as
/// `(extra: <hex>)`. A zero `v` yields an empty string.
fn debug_format_flag_field(
    v: libc::tcflag_t,
    fs: &'static [(&'static str, libc::tcflag_t)],
) -> std::result::Result<String, fmt::Error> {
    use fmt::Write;
    let mut rendered = String::new();
    let mut leftover = v;
    for &(label, mask) in fs.iter().filter(|&&(_, mask)| v & mask != 0) {
        if !rendered.is_empty() {
            rendered.push_str(" | ");
        }
        rendered.push_str(label);
        leftover &= !mask;
    }
    if leftover != 0 {
        if !rendered.is_empty() {
            rendered.push(' ');
        }
        write!(rendered, "(extra: {:x})", leftover)?;
    }
    Ok(rendered)
}
/// Formats selected control-character slots as `NAME:value` pairs separated by `", "`.
///
/// Only the slots named in the table below are shown, in table order. Formatting into a
/// `String` cannot fail; the `Result` return type mirrors `debug_format_flag_field`.
fn debug_format_c_cc_field(c_cc: &[libc::cc_t; 32]) -> std::result::Result<String, fmt::Error> {
    // Control-character slot names paired with their index into `c_cc`.
    const SLOTS: &[(&str, usize)] = &[
        ("VDISCARD", libc::VDISCARD),
        ("VEOF", libc::VEOF),
        ("VEOL", libc::VEOL),
        ("VEOL2", libc::VEOL2),
        ("VERASE", libc::VERASE),
        ("VINTR", libc::VINTR),
        ("VKILL", libc::VKILL),
        ("VLNEXT", libc::VLNEXT),
        ("VMIN", libc::VMIN),
        ("VQUIT", libc::VQUIT),
        ("VREPRINT", libc::VREPRINT),
        ("VSTART", libc::VSTART),
        ("VSTOP", libc::VSTOP),
        ("VSUSP", libc::VSUSP),
        ("VSWTC", libc::VSWTC),
        ("VTIME", libc::VTIME),
        ("VWERASE", libc::VWERASE),
    ];
    let pairs: Vec<String> = SLOTS
        .iter()
        .map(|&(name, idx)| format!("{}:{}", name, c_cc[idx]))
        .collect();
    Ok(pairs.join(", "))
}
impl fmt::Debug for TerminalAttributes {
    /// Renders each `termios` field in human-readable form:
    /// - flag fields as `|`-separated names, with unnamed bits shown as `(extra: <hex>)`;
    /// - the control-character array as `NAME:value` pairs;
    /// - line speeds as reported by `cfgetispeed` / `cfgetospeed`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Names for the input-mode (`c_iflag`) bits.
        const IFLAG_NAMES: &[(&str, libc::tcflag_t)] = &[
            ("IGNBRK", libc::IGNBRK),
            ("BRKINT", libc::BRKINT),
            ("IGNPAR", libc::IGNPAR),
            ("PARMRK", libc::PARMRK),
            ("INPCK", libc::INPCK),
            ("ISTRIP", libc::ISTRIP),
            ("INLCR", libc::INLCR),
            ("IGNCR", libc::IGNCR),
            ("ICRNL", libc::ICRNL),
            ("IXON", libc::IXON),
            ("IXANY", libc::IXANY),
            ("IXOFF", libc::IXOFF),
            ("IMAXBEL", libc::IMAXBEL),
            ("IUTF8", libc::IUTF8),
        ];
        // Names for the output-mode (`c_oflag`) bits and delay masks.
        const OFLAG_NAMES: &[(&str, libc::tcflag_t)] = &[
            ("OPOST", libc::OPOST),
            ("OLCUC", libc::OLCUC),
            ("ONLCR", libc::ONLCR),
            ("ONOCR", libc::ONOCR),
            ("ONLRET", libc::ONLRET),
            ("OFILL", libc::OFILL),
            ("OFDEL", libc::OFDEL),
            ("NLDLY", libc::NLDLY),
            ("CRDLY", libc::CRDLY),
            ("TABDLY", libc::TABDLY),
            ("BSDLY", libc::BSDLY),
            ("VTDLY", libc::VTDLY),
            ("FFDLY", libc::FFDLY),
        ];
        // Names for the control-mode (`c_cflag`) bits and masks.
        const CFLAG_NAMES: &[(&str, libc::tcflag_t)] = &[
            ("CBAUD", libc::CBAUD),
            ("CBAUDEX", libc::CBAUDEX),
            ("CSIZE", libc::CSIZE),
            ("CSTOPB", libc::CSTOPB),
            ("CREAD", libc::CREAD),
            ("PARENB", libc::PARENB),
            ("PARODD", libc::PARODD),
            ("HUPCL", libc::HUPCL),
            ("CLOCAL", libc::CLOCAL),
            ("CIBAUD", libc::CIBAUD),
            ("CMSPAR", libc::CMSPAR),
            ("CRTSCTS", libc::CRTSCTS),
        ];
        // Names for the local-mode (`c_lflag`) bits.
        const LFLAG_NAMES: &[(&str, libc::tcflag_t)] = &[
            ("ISIG", libc::ISIG),
            ("ICANON", libc::ICANON),
            ("ECHO", libc::ECHO),
            ("ECHOE", libc::ECHOE),
            ("ECHOK", libc::ECHOK),
            ("ECHONL", libc::ECHONL),
            ("ECHOCTL", libc::ECHOCTL),
            ("ECHOPRT", libc::ECHOPRT),
            ("ECHOKE", libc::ECHOKE),
            ("FLUSHO", libc::FLUSHO),
            ("NOFLSH", libc::NOFLSH),
            ("TOSTOP", libc::TOSTOP),
            ("PENDIN", libc::PENDIN),
            ("IEXTEN", libc::IEXTEN),
        ];

        let raw = &self.inner;
        let iflag = debug_format_flag_field(raw.c_iflag, IFLAG_NAMES)?;
        let oflag = debug_format_flag_field(raw.c_oflag, OFLAG_NAMES)?;
        let cflag = debug_format_flag_field(raw.c_cflag, CFLAG_NAMES)?;
        let lflag = debug_format_flag_field(raw.c_lflag, LFLAG_NAMES)?;
        let cc = debug_format_c_cc_field(&raw.c_cc)?;
        // SAFETY: `raw` refers to an initialized termios owned by `self`; both calls only read it.
        let (ispeed, ospeed) = unsafe { (libc::cfgetispeed(raw), libc::cfgetospeed(raw)) };

        f.debug_struct("TerminalAttributes")
            .field("c_iflag", &iflag)
            .field("c_oflag", &oflag)
            .field("c_cflag", &cflag)
            .field("c_lflag", &lflag)
            .field("c_cc", &cc)
            .field("c_ispeed", &ispeed)
            .field("c_ospeed", &ospeed)
            .finish()
    }
}
/// Toggles bits in the local-mode field (`c_lflag`) of the stored `termios`.
///
/// Only the in-memory copy changes; nothing is sent to the terminal here.
impl AbstractTerminalAttributes for TerminalAttributes {
    fn enable(&mut self, flag: TerminalFlag) {
        let bit = flag.to_value();
        self.inner.c_lflag |= bit;
    }
    fn disable(&mut self, flag: TerminalFlag) {
        let bit = flag.to_value();
        self.inner.c_lflag &= !bit;
    }
}
/// A stream that may be attached to a terminal; the prompting functions in this module are
/// generic over it.
pub trait AbstractStream {
    /// Terminal attribute type read by `get_attributes` and accepted by `set_attributes`.
    type Attributes: AbstractTerminalAttributes + fmt::Debug;
    /// Returns whether the stream refers to a terminal.
    fn isatty(&self) -> bool;
    /// Reads the stream's current terminal attributes.
    fn get_attributes(&self) -> IoResult<Self::Attributes>;
    /// Applies `attributes` to the stream's terminal.
    fn set_attributes(&mut self, attributes: &Self::Attributes) -> IoResult<()>;
    /// Returns a reader for the stream, or `None` if the stream cannot be read from.
    fn as_reader(&self) -> Option<Box<dyn Read>>;
    /// Returns a writer for the stream, or `None` if the stream cannot be written to.
    fn as_writer(&self) -> Option<Box<dyn Write>>;
}
/// One of the process's three standard streams.
#[derive(Debug)]
pub enum Stream {
    Stdout,
    Stderr,
    Stdin,
}
impl Stream {
    /// Returns the standard file descriptor number for this stream.
    fn to_fd(&self) -> c_int {
        match self {
            Self::Stdin => libc::STDIN_FILENO,
            Self::Stdout => libc::STDOUT_FILENO,
            Self::Stderr => libc::STDERR_FILENO,
        }
    }
}
impl AbstractStream for Stream {
type Attributes = TerminalAttributes;
fn isatty(&self) -> bool {
let ret = unsafe { libc::isatty(self.to_fd()) };
let error: i32 = errno::errno().into();
match ret {
1 => true,
0 => match error {
libc::EBADF => false,
libc::ENOTTY => false,
_ => {
debug!(
"Unrecognized isatty errno: {}; assuming {:?} is not a TTY",
error, *self
);
false
}
},
_ => {
debug!(
"Unrecognized isatty return code: {}; assuming {:?} is not a TTY",
ret, *self
);
false
}
}
}
fn get_attributes(&self) -> IoResult<Self::Attributes> {
TerminalAttributes::new(self.to_fd())
}
fn set_attributes(&mut self, attributes: &Self::Attributes) -> IoResult<()> {
attributes.apply(self.to_fd())?;
let applied = Self::Attributes::new(self.to_fd())?;
if applied != *attributes {
return Err(io::Error::new(
io::ErrorKind::Other,
"tcsetattr did not fully apply the requested attributes",
));
}
Ok(())
}
fn as_reader(&self) -> Option<Box<dyn Read>> {
match *self {
Stream::Stdin => Some(Box::new(io::stdin())),
_ => None,
}
}
fn as_writer(&self) -> Option<Box<dyn Write>> {
match *self {
Stream::Stdout => Some(Box::new(io::stdout())),
Stream::Stderr => Some(Box::new(io::stderr())),
_ => None,
}
}
}
struct DisableEcho<'s, S: AbstractStream> {
stream: &'s mut S,
initial_attributes: S::Attributes,
}
impl<'s, S: AbstractStream> DisableEcho<'s, S> {
fn new(stream: &'s mut S) -> Result<Self> {
let initial_attributes = stream.get_attributes()?;
debug!("Initial stream attributes: {:#?}", initial_attributes);
let mut attributes = stream.get_attributes()?;
attributes.disable(TerminalFlag::Echo);
attributes.enable(TerminalFlag::EchoNewlines);
debug!("Setting attributes to: {:#?}", attributes);
stream.set_attributes(&attributes)?;
Ok(DisableEcho {
stream,
initial_attributes,
})
}
}
impl<'s, S: AbstractStream> Drop for DisableEcho<'s, S> {
fn drop(&mut self) {
if let Err(e) = self.stream.set_attributes(&self.initial_attributes) {
error!("failed to restore terminal attributes: {}", e);
}
}
}
/// Returns `Error::Precondition` unless `s` is attached to a terminal.
fn require_isatty<S: AbstractStream>(s: &mut S) -> Result<()> {
    if s.isatty() {
        return Ok(());
    }
    Err(Error::Precondition(
        "cannot prompt interactively when the I/O streams are not TTYs".to_string(),
    ))
}
fn build_input_reader<IS: AbstractStream>(
input_stream: &mut IS,
) -> Result<io::BufReader<Box<dyn Read>>> {
require_isatty(input_stream)?;
Ok(io::BufReader::new(match input_stream.as_reader() {
None => {
return Err(Error::Precondition(
"the given input stream must support `Read`".to_string(),
));
}
Some(r) => r,
}))
}
/// Strips the line terminator that `BufRead::read_line` leaves on `s`.
///
/// A trailing `'\n'` is required and removed, together with a `'\r'` directly before it.
///
/// # Errors
/// Returns an `UnexpectedEof` I/O error when `s` does not end in `'\n'`, meaning input
/// ended before a complete line was read.
fn remove_newline(mut s: String) -> Result<String> {
    let content_len = match s.strip_suffix('\n') {
        Some(line) => line.strip_suffix('\r').unwrap_or(line).len(),
        None => {
            return Err(
                io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected end of input").into(),
            );
        }
    };
    s.truncate(content_len);
    Ok(s)
}
fn prompt_for_string_impl<IS: AbstractStream, OS: AbstractStream>(
input_stream: &mut IS,
input_reader: &mut io::BufReader<Box<dyn Read>>,
output_stream: &mut OS,
prompt: &str,
is_sensitive: bool,
) -> Result<String> {
use io::BufRead;
require_isatty(output_stream)?;
let mut writer = match output_stream.as_writer() {
None => {
return Err(Error::Precondition(
"the given output stream must support `Write`".to_string(),
));
}
Some(w) => w,
};
write!(writer, "{}", prompt)?;
writer.flush()?;
Ok({
let _disable_echo = if is_sensitive {
Some(DisableEcho::new(input_stream)?)
} else {
None
};
let mut ret = String::new();
input_reader.read_line(&mut ret)?;
remove_newline(ret)?
})
}
/// Prompts on `output_stream` and reads one line from `input_stream`.
///
/// With `is_sensitive`, echo is disabled while the line is typed. The trailing line
/// terminator is stripped from the result.
///
/// # Errors
/// Fails if either stream is not a TTY or lacks `Read`/`Write` support, if an I/O operation
/// fails, or if input ends before a full line is read.
pub fn prompt_for_string<IS: AbstractStream, OS: AbstractStream>(
    mut input_stream: IS,
    mut output_stream: OS,
    prompt: &str,
    is_sensitive: bool,
) -> Result<String> {
    build_input_reader(&mut input_stream).and_then(|mut reader| {
        prompt_for_string_impl(
            &mut input_stream,
            &mut reader,
            &mut output_stream,
            prompt,
            is_sensitive,
        )
    })
}
/// Prompts with `prompt` and then `"Confirm: "`, repeating until both entries are identical.
///
/// After a mismatch, a notice is written to `output_stream` on a best-effort basis (write
/// errors are ignored) and both questions are asked again. Errors from either prompt are
/// returned immediately.
fn prompt_for_string_confirm_impl<IS: AbstractStream, OS: AbstractStream>(
    input_stream: &mut IS,
    input_reader: &mut io::BufReader<Box<dyn Read>>,
    output_stream: &mut OS,
    prompt: &str,
    is_sensitive: bool,
) -> Result<String> {
    loop {
        let first = prompt_for_string_impl(
            input_stream,
            input_reader,
            output_stream,
            prompt,
            is_sensitive,
        )?;
        let second = prompt_for_string_impl(
            input_stream,
            input_reader,
            output_stream,
            "Confirm: ",
            is_sensitive,
        )?;
        if first == second {
            break Ok(first);
        }
        if let Some(mut notice) = output_stream.as_writer() {
            let _ = writeln!(notice, "Entries did not match; please try again.");
            let _ = notice.flush();
        }
    }
}
/// Prompts for a value twice and returns it once both entries match.
///
/// With `is_sensitive`, echo is disabled while each entry is typed.
///
/// # Errors
/// Fails if either stream is not a TTY or lacks `Read`/`Write` support, if an I/O operation
/// fails, or if input ends before a full line is read.
pub fn prompt_for_string_confirm<IS: AbstractStream, OS: AbstractStream>(
    mut input_stream: IS,
    mut output_stream: OS,
    prompt: &str,
    is_sensitive: bool,
) -> Result<String> {
    build_input_reader(&mut input_stream).and_then(|mut reader| {
        prompt_for_string_confirm_impl(
            &mut input_stream,
            &mut reader,
            &mut output_stream,
            prompt,
            is_sensitive,
        )
    })
}
/// A string that was either supplied up front or obtained by prompting the user.
pub struct MaybePromptedString {
    value: String,
    // True when the value came from `provided` rather than from a prompt.
    was_provided: bool,
}
impl MaybePromptedString {
    /// Uses `provided` if it is `Some`; otherwise prompts for the value on the given streams.
    ///
    /// When prompting, `is_sensitive` disables echo while the user types. `confirm` asks for
    /// the value twice and repeats until both entries match.
    ///
    /// # Errors
    /// Only the prompting path can fail: the streams may not be TTYs or may lack
    /// `Read`/`Write` support, I/O may fail, or input may end early.
    pub fn new<IS: AbstractStream, OS: AbstractStream>(
        provided: Option<&str>,
        mut input_stream: IS,
        mut output_stream: OS,
        prompt: &str,
        is_sensitive: bool,
        confirm: bool,
    ) -> Result<Self> {
        if let Some(given) = provided {
            return Ok(MaybePromptedString {
                value: given.to_owned(),
                was_provided: true,
            });
        }
        let mut reader = build_input_reader(&mut input_stream)?;
        let value = if confirm {
            prompt_for_string_confirm_impl(
                &mut input_stream,
                &mut reader,
                &mut output_stream,
                prompt,
                is_sensitive,
            )?
        } else {
            prompt_for_string_impl(
                &mut input_stream,
                &mut reader,
                &mut output_stream,
                prompt,
                is_sensitive,
            )?
        };
        Ok(MaybePromptedString {
            value,
            was_provided: false,
        })
    }
    /// Returns true if the value was supplied by the caller rather than prompted for.
    pub fn was_provided(&self) -> bool {
        self.was_provided
    }
    /// Consumes `self`, returning the string value.
    pub fn into_inner(self) -> String {
        self.value
    }
}
pub fn continue_confirmation<IS: AbstractStream, OS: AbstractStream>(
mut input_stream: IS,
mut output_stream: OS,
description: &str,
) -> Result<bool> {
let mut input_reader = build_input_reader(&mut input_stream)?;
let prompt = format!("{}Continue? [Yes/No] ", description);
loop {
let original_response = prompt_for_string_impl(
&mut input_stream,
&mut input_reader,
&mut output_stream,
prompt.as_str(),
false,
)?;
let response = original_response.trim().to_lowercase();
if response == "y" || response == "yes" {
return Ok(true);
} else if response == "n" || response == "no" {
return Ok(false);
} else {
let mut writer = match output_stream.as_writer() {
None => {
return Err(Error::Precondition(
"the given output stream must support `Write`".to_string(),
));
}
Some(w) => w,
};
let sanitized: String = original_response.escape_debug().collect();
writeln!(writer, "Invalid response '{}'.", sanitized)?;
writer.flush()?;
}
}
}