ask_cli/
lib.rs

1#![feature(read_buf)]
2#![feature(core_io_borrowed_buf)]
3
4use std::{
5    io,
6    io::{BorrowedBuf, Read, Write},
7    mem,
8    mem::MaybeUninit,
9    process::{ExitCode, Termination},
10};
11
/// The outcome of a yes/no question posed through [ask].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Answer {
    /// The question was answered affirmatively.
    Yes,
    /// The question was answered negatively.
    No,
    /// No answer was obtained, e.g. [ask] ran out of input before a reply
    /// was accepted.
    Unknown,
}
19
20impl Termination for Answer {
21    fn report(self) -> ExitCode {
22        match self {
23            Self::Yes => ExitCode::SUCCESS,
24            Self::No => ExitCode::FAILURE,
25            Self::Unknown => ExitCode::from(2),
26        }
27    }
28}
29
/// The states of the prompt/parse loop driven by [ask].
enum State {
    /// Initial state; immediately transitions to [State::Ask].
    Start,
    /// Print the question, then move on to [State::Read].
    Ask {
        /// Not used while asking; handed through unchanged to [State::Read].
        pending_crlf: bool,
    },
    /// Keep pulling bytes from stdin until a newline byte (`\n` or `\r`) shows
    /// up in the buffer, yielding the index of that newline's first byte.
    ///
    /// Edge cases handled in this state:
    /// - On EOF the function only gives up *after* every reply that is still
    ///   buffered has been processed.
    /// - A `\r` sitting in the very last slot of the buffer means the reply
    ///   is longer than any valid one, so the reply is rejected. A `\n`
    ///   directly following it belongs to the same line ending and has to be
    ///   swallowed, but only *after* the question has been printed again:
    ///   reading earlier could block on stdin when no further bytes are
    ///   available yet.
    /// - A buffer that fills up without containing any newline also cannot
    ///   hold a valid reply, so the reply is rejected.
    Read {
        /// The current reply has already been rejected, but its terminating
        /// newline has not been seen yet.
        failed: bool,
        /// The previous reply ended in a `\r` whose `\n` partner (if there is
        /// one) has not been read yet; a leading `\n` in the next read must
        /// be skipped.
        pending_crlf: bool,
    },
    /// A complete reply is buffered and ready to be parsed.
    HandleReply {
        /// Index of the first newline byte in the buffer; the reply is
        /// everything before it.
        newline_index: usize,
    },
}
61
62/// Ask the user a yes or no question on stdout, reading the reply from stdin.
63///
64/// Replies are delimited by newlines of any kind and must be one of '' (maps to
65/// yes), 'y', 'yes', 'n', 'no', case-insensitive. If the reply fails to parse,
66/// the question will be asked again ad infinitum.
67///
68/// # Examples
69///
70/// ```
71/// # use std::{io, str::from_utf8};
72/// use ask_cli::{ask, Answer};
73///
74/// assert!(matches!(
75///     ask(
76///         "Continue? [Y/n] ",
77///         Answer::Yes,
78///         &mut "y\n".as_bytes(),
79///         &mut io::sink()
80///     ),
81///     Ok(Answer::Yes)
82/// ));
83/// assert!(matches!(
84///     ask(
85///         "Continue? [Y/n] ",
86///         Answer::Yes,
87///         &mut "n\n".as_bytes(),
88///         &mut io::sink()
89///     ),
90///     Ok(Answer::No)
91/// ));
92/// assert!(matches!(
93///     ask(
94///         "Continue? [Y/n] ",
95///         Answer::Yes,
96///         &mut "".as_bytes(),
97///         &mut io::sink()
98///     ),
99///     Ok(Answer::Unknown)
100/// ));
101/// assert!(matches!(
102///     ask(
103///         "Continue? [y/N] ",
104///         Answer::No,
105///         &mut "\n".as_bytes(),
106///         &mut io::sink()
107///     ),
108///     Ok(Answer::No)
109/// ));
110///
111/// // Here we use 3 different kinds of line endings
112/// let mut stdout = Vec::new();
113/// let answer = ask(
114///     "Continue? [Y/n] ",
115///     Answer::Yes,
116///     &mut "a\nb\rc\r\nyes\n".as_bytes(),
117///     &mut stdout,
118/// )
119/// .unwrap();
120/// assert_eq!(
121///     "Continue? [Y/n] Continue? [Y/n] Continue? [Y/n] Continue? [Y/n] ",
122///     from_utf8(&stdout).unwrap()
123/// );
124/// assert!(matches!(answer, Answer::Yes));
125/// ```
126///
127/// # Errors
128///
129/// Underlying I/O errors are bubbled up.
130pub fn ask<Q: AsRef<[u8]>, In: Read, Out: Write>(
131    question: Q,
132    default: Answer,
133    stdin: &mut In,
134    stdout: &mut Out,
135) -> Result<Answer, io::Error> {
136    // max_len(yes, no, y, n) = 3 -> 3 + 2 bytes for new lines
137    const BUF_LEN: usize = 5;
138
139    let (mut buf, mut buf2) = (
140        [MaybeUninit::uninit(); BUF_LEN],
141        [MaybeUninit::uninit(); BUF_LEN],
142    );
143    let (mut buf, mut buf2) = (
144        BorrowedBuf::from(buf.as_mut()),
145        BorrowedBuf::from(buf2.as_mut()),
146    );
147
148    macro_rules! consume_bytes {
149        ($count:expr) => {
150            buf2.clear();
151            buf2.unfilled().append(&buf.filled()[$count..]);
152            mem::swap(&mut buf, &mut buf2);
153        };
154    }
155
156    macro_rules! consume_newline {
157        ($newline_index:expr) => {
158            let newline_index = $newline_index;
159            let is_crlf = buf.filled()[newline_index] == b'\r'
160                && matches!(buf.filled().get(newline_index + 1), Some(b'\n'));
161            let skip = if is_crlf { 2 } else { 1 };
162            consume_bytes!(newline_index + skip);
163        };
164    }
165
166    let mut state = State::Start;
167    loop {
168        state = match state {
169            State::Start => State::Ask {
170                pending_crlf: false,
171            },
172            State::Ask { pending_crlf } => {
173                stdout.write_all(question.as_ref())?;
174                stdout.flush()?;
175                State::Read {
176                    failed: false,
177                    pending_crlf,
178                }
179            }
180            State::Read {
181                failed,
182                pending_crlf,
183            } => {
184                debug_assert!(buf.len() < buf.capacity());
185
186                let prev_count = buf.len();
187                stdin.read_buf(buf.unfilled())?;
188
189                if pending_crlf && matches!(buf.filled().first(), Some(b'\n')) {
190                    consume_bytes!(1);
191                }
192
193                match buf.filled().iter().position(|&b| b == b'\n' || b == b'\r') {
194                    Some(newline_index) if newline_index == BUF_LEN - 1 => {
195                        let pending_crlf = buf.filled()[newline_index] == b'\r';
196                        buf.clear();
197                        State::Ask { pending_crlf }
198                    }
199                    Some(newline_index) if failed => {
200                        consume_newline!(newline_index);
201                        State::Ask {
202                            pending_crlf: false,
203                        }
204                    }
205                    Some(newline_index) => State::HandleReply { newline_index },
206                    None if buf.len() == buf.capacity() => {
207                        buf.clear();
208                        State::Read {
209                            failed: true,
210                            pending_crlf: false,
211                        }
212                    }
213                    None if !pending_crlf && buf.len() == prev_count => {
214                        // Reached EOF
215                        return Ok(Answer::Unknown);
216                    }
217                    None => State::Read {
218                        failed,
219                        pending_crlf: false,
220                    },
221                }
222            }
223            State::HandleReply { newline_index } => {
224                let reply = &mut buf.filled_mut()[..newline_index];
225                reply.make_ascii_lowercase();
226                match &*reply {
227                    b"" => return Ok(default),
228                    b"y" | b"yes" => return Ok(Answer::Yes),
229                    b"n" | b"no" => return Ok(Answer::No),
230                    _ => {
231                        consume_newline!(newline_index);
232                        State::Ask {
233                            pending_crlf: false,
234                        }
235                    }
236                }
237            }
238        }
239    }
240}
241
/// Kani proof harness: for every possible 4-byte input and either explicit
/// default answer, [ask] must return `Ok` without panicking.
#[cfg(kani)]
#[kani::proof]
#[kani::unwind(9)]
fn ask_proof() {
    let input: [u8; 4] = kani::any();
    // `ask` requires a default answer; let the verifier pick either one.
    let default = if kani::any() { Answer::Yes } else { Answer::No };
    let output = ask("?", default, &mut input.as_slice(), &mut io::sink());

    // Reading from a slice and writing to a sink cannot fail, so any `Err`
    // here would be a bug in `ask`.
    output.unwrap();
}