ask_cli/lib.rs
1#![feature(read_buf)]
2#![feature(core_io_borrowed_buf)]
3
4use std::{
5 io,
6 io::{BorrowedBuf, Read, Write},
7 mem,
8 mem::MaybeUninit,
9 process::{ExitCode, Termination},
10};
11
/// Outcome of a yes/no question posed through [`ask`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Answer {
    /// The reply was affirmative (`y` / `yes`).
    Yes,
    /// The reply was negative (`n` / `no`).
    No,
    /// Input ended before a valid reply could be read.
    Unknown,
}
19
20impl Termination for Answer {
21 fn report(self) -> ExitCode {
22 match self {
23 Self::Yes => ExitCode::SUCCESS,
24 Self::No => ExitCode::FAILURE,
25 Self::Unknown => ExitCode::from(2),
26 }
27 }
28}
29
/// The states of the prompt loop driven by [`ask`].
enum State {
    /// Entry point of the state machine.
    Start,
    /// The question has to be (re)printed before any more input is read.
    Ask {
        /// Handed over unchanged to the following [`State::Read`].
        pending_crlf: bool,
    },
    /// Keeps reading from stdin until a newline shows up, then hands the
    /// index of the newline's first byte to [`State::HandleReply`].
    ///
    /// Edge cases covered by this state:
    /// - Once stdin hits EOF, whatever is still buffered is processed first;
    ///   only then does the loop stop with [`Answer::Unknown`].
    /// - A `\r` in the very last slot of the buffer sits at a higher index
    ///   than any valid reply can reach, so the reply is rejected. A `\n`
    ///   directly following it belongs to the same line ending and has to be
    ///   swallowed. That must only happen after the question was printed
    ///   again, otherwise we could block on stdin when no further bytes are
    ///   available.
    /// - A buffer that fills up without containing any newline cannot hold a
    ///   valid reply, so the reply is rejected.
    Read {
        /// The previous line ended in a `\r` that was the last buffered byte;
        /// a `\n` at the start of the next read completes that CRLF.
        pending_crlf: bool,
        /// The current reply is already known to be invalid, but its
        /// terminating newline has not been seen yet.
        failed: bool,
    },
    /// A newline was found; the bytes in front of it are parsed as the reply.
    HandleReply {
        /// Index of the newline's first byte within the buffer.
        newline_index: usize,
    },
}
61
62/// Ask the user a yes or no question on stdout, reading the reply from stdin.
63///
64/// Replies are delimited by newlines of any kind and must be one of '' (maps to
65/// yes), 'y', 'yes', 'n', 'no', case-insensitive. If the reply fails to parse,
66/// the question will be asked again ad infinitum.
67///
68/// # Examples
69///
70/// ```
71/// # use std::{io, str::from_utf8};
72/// use ask_cli::{ask, Answer};
73///
74/// assert!(matches!(
75/// ask(
76/// "Continue? [Y/n] ",
77/// Answer::Yes,
78/// &mut "y\n".as_bytes(),
79/// &mut io::sink()
80/// ),
81/// Ok(Answer::Yes)
82/// ));
83/// assert!(matches!(
84/// ask(
85/// "Continue? [Y/n] ",
86/// Answer::Yes,
87/// &mut "n\n".as_bytes(),
88/// &mut io::sink()
89/// ),
90/// Ok(Answer::No)
91/// ));
92/// assert!(matches!(
93/// ask(
94/// "Continue? [Y/n] ",
95/// Answer::Yes,
96/// &mut "".as_bytes(),
97/// &mut io::sink()
98/// ),
99/// Ok(Answer::Unknown)
100/// ));
101/// assert!(matches!(
102/// ask(
103/// "Continue? [y/N] ",
104/// Answer::No,
105/// &mut "\n".as_bytes(),
106/// &mut io::sink()
107/// ),
108/// Ok(Answer::No)
109/// ));
110///
111/// // Here we use 3 different kinds of line endings
112/// let mut stdout = Vec::new();
113/// let answer = ask(
114/// "Continue? [Y/n] ",
115/// Answer::Yes,
116/// &mut "a\nb\rc\r\nyes\n".as_bytes(),
117/// &mut stdout,
118/// )
119/// .unwrap();
120/// assert_eq!(
121/// "Continue? [Y/n] Continue? [Y/n] Continue? [Y/n] Continue? [Y/n] ",
122/// from_utf8(&stdout).unwrap()
123/// );
124/// assert!(matches!(answer, Answer::Yes));
125/// ```
126///
127/// # Errors
128///
129/// Underlying I/O errors are bubbled up.
130pub fn ask<Q: AsRef<[u8]>, In: Read, Out: Write>(
131 question: Q,
132 default: Answer,
133 stdin: &mut In,
134 stdout: &mut Out,
135) -> Result<Answer, io::Error> {
136 // max_len(yes, no, y, n) = 3 -> 3 + 2 bytes for new lines
137 const BUF_LEN: usize = 5;
138
139 let (mut buf, mut buf2) = (
140 [MaybeUninit::uninit(); BUF_LEN],
141 [MaybeUninit::uninit(); BUF_LEN],
142 );
143 let (mut buf, mut buf2) = (
144 BorrowedBuf::from(buf.as_mut()),
145 BorrowedBuf::from(buf2.as_mut()),
146 );
147
148 macro_rules! consume_bytes {
149 ($count:expr) => {
150 buf2.clear();
151 buf2.unfilled().append(&buf.filled()[$count..]);
152 mem::swap(&mut buf, &mut buf2);
153 };
154 }
155
156 macro_rules! consume_newline {
157 ($newline_index:expr) => {
158 let newline_index = $newline_index;
159 let is_crlf = buf.filled()[newline_index] == b'\r'
160 && matches!(buf.filled().get(newline_index + 1), Some(b'\n'));
161 let skip = if is_crlf { 2 } else { 1 };
162 consume_bytes!(newline_index + skip);
163 };
164 }
165
166 let mut state = State::Start;
167 loop {
168 state = match state {
169 State::Start => State::Ask {
170 pending_crlf: false,
171 },
172 State::Ask { pending_crlf } => {
173 stdout.write_all(question.as_ref())?;
174 stdout.flush()?;
175 State::Read {
176 failed: false,
177 pending_crlf,
178 }
179 }
180 State::Read {
181 failed,
182 pending_crlf,
183 } => {
184 debug_assert!(buf.len() < buf.capacity());
185
186 let prev_count = buf.len();
187 stdin.read_buf(buf.unfilled())?;
188
189 if pending_crlf && matches!(buf.filled().first(), Some(b'\n')) {
190 consume_bytes!(1);
191 }
192
193 match buf.filled().iter().position(|&b| b == b'\n' || b == b'\r') {
194 Some(newline_index) if newline_index == BUF_LEN - 1 => {
195 let pending_crlf = buf.filled()[newline_index] == b'\r';
196 buf.clear();
197 State::Ask { pending_crlf }
198 }
199 Some(newline_index) if failed => {
200 consume_newline!(newline_index);
201 State::Ask {
202 pending_crlf: false,
203 }
204 }
205 Some(newline_index) => State::HandleReply { newline_index },
206 None if buf.len() == buf.capacity() => {
207 buf.clear();
208 State::Read {
209 failed: true,
210 pending_crlf: false,
211 }
212 }
213 None if !pending_crlf && buf.len() == prev_count => {
214 // Reached EOF
215 return Ok(Answer::Unknown);
216 }
217 None => State::Read {
218 failed,
219 pending_crlf: false,
220 },
221 }
222 }
223 State::HandleReply { newline_index } => {
224 let reply = &mut buf.filled_mut()[..newline_index];
225 reply.make_ascii_lowercase();
226 match &*reply {
227 b"" => return Ok(default),
228 b"y" | b"yes" => return Ok(Answer::Yes),
229 b"n" | b"no" => return Ok(Answer::No),
230 _ => {
231 consume_newline!(newline_index);
232 State::Ask {
233 pending_crlf: false,
234 }
235 }
236 }
237 }
238 }
239 }
240}
241
/// Kani proof harness: for any 4 bytes of input, [`ask`] must return `Ok`
/// without panicking.
#[cfg(kani)]
#[kani::proof]
#[kani::unwind(9)]
fn ask_proof() {
    let input: [u8; 4] = kani::any();
    // `ask` takes the default answer as its second argument; it is only
    // returned for an empty reply, so a fixed variant is sufficient here.
    let output = ask("?", Answer::Yes, &mut input.as_slice(), &mut io::sink());

    output.unwrap();
}