askit/
prompt.rs

1use std::io::{self, BufRead, Write};
2use std::str::FromStr;
3
/// Everything that can go wrong while prompting with `askit`.
#[derive(Debug)]
pub enum Error {
    /// Reading the input line, or writing/flushing the prompt text, failed.
    Io(io::Error),
    /// The input (or a string default) could not be parsed into the requested type.
    ///
    /// `ty` names the target type and `cause` holds the parser's error message.
    Parse { ty: &'static str, cause: String },
    /// The input was empty (or ended) and no default was available.
    EmptyNotAllowed,
    /// Every permitted attempt was used up without obtaining a value.
    RetriesExceeded,
    /// A user-supplied validator rejected the value; carries the configured message.
    Validation(String),
}
20
21impl std::fmt::Display for Error {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Error::Io(e) => write!(f, "I/O error: {e}"),
25            Error::Parse { ty, cause } => write!(f, "Failed to parse as {ty}: {cause}"),
26            Error::EmptyNotAllowed => write!(f, "Empty input (no default provided)"),
27            Error::RetriesExceeded => write!(f, "Maximum retry attempts exceeded"),
28            Error::Validation(msg) => write!(f, "Validation failed: {msg}"),
29        }
30    }
31}
32
33impl std::error::Error for Error {}
34
35impl From<io::Error> for Error {
36    fn from(e: io::Error) -> Self {
37        Error::Io(e)
38    }
39}
40
/// Untyped builder for reading and parsing CLI input.
///
/// Configure it with the builder methods, then either call `get` / `get_with`
/// with a target type, or convert it into a typed builder with `to::<T>()`.
#[derive(Debug, Clone)]
#[must_use = "a Prompt does nothing until `get` or `get_with` is called"]
pub struct Prompt<'a> {
    /// Text written to the output before every read attempt.
    message: &'a str,
    /// Fallback used when the input is empty or ends; parsed only when it is used.
    default_str: Option<String>,
    /// Extra attempts allowed after a failed one (`0` means a single attempt).
    retries: usize,
    /// Whether surrounding whitespace is trimmed from the input before parsing.
    trim_input: bool,
}
48
49impl<'a> Prompt<'a> {
50    /// Create a new prompt with a message.
51    pub fn new(message: &'a str) -> Self {
52        Self {
53            message,
54            default_str: None,
55            retries: 0,
56            trim_input: true,
57        }
58    }
59
60    /// Provide a default value **as string**. If the user hits ENTER with empty input,
61    /// `default` will be used and parsed as the target type.
62    pub fn default(mut self, default: &str) -> Self {
63        self.default_str = Some(default.to_string());
64        self
65    }
66
67    /// Number of times to retry when parsing fails or input is empty w/o default.
68    pub fn retries(mut self, retries: usize) -> Self {
69        self.retries = retries;
70        self
71    }
72
73    /// Whether to trim whitespace (default: true).
74    pub fn trim(mut self, yes: bool) -> Self {
75        self.trim_input = yes;
76        self
77    }
78
79    /// Convert to a typed builder, enabling `.default_val()` and `.validate()`.
80    pub fn to<T>(self) -> TypedPrompt<'a, T>
81    where
82        T: FromStr,
83        T::Err: std::fmt::Display,
84    {
85        TypedPrompt {
86            base: self,
87            default_val: None,
88            validator: None,
89            validation_msg: None,
90        }
91    }
92
93    /// Read from **stdin**, parse and return the desired type.
94    pub fn get<T>(&self) -> Result<T, Error>
95    where
96        T: FromStr,
97        T::Err: std::fmt::Display,
98    {
99        let stdin = io::stdin();
100        let mut lock = stdin.lock();
101        let mut stdout = io::stdout();
102        self.get_with(&mut lock, &mut stdout)
103    }
104
105    /// Same as `get()`, but allows providing a custom **reader** and **writer**.
106    pub fn get_with<T, R, W>(&self, reader: &mut R, writer: &mut W) -> Result<T, Error>
107    where
108        T: FromStr,
109        T::Err: std::fmt::Display,
110        R: BufRead,
111        W: Write,
112    {
113        let mut attempts_left = self.retries + 1;
114
115        loop {
116            // Render message (with default hint, if any)
117            {
118                let mut msg = String::new();
119                msg.push_str(self.message);
120                if let Some(def) = &self.default_str {
121                    if !self.message.contains('[') && !self.message.contains("(default") {
122                        use std::fmt::Write as _;
123                        let _ = write!(msg, "[default: {}] ", def);
124                    } else {
125                        msg.push(' ');
126                    }
127                }
128                writer.write_all(msg.as_bytes())?;
129                writer.flush()?;
130            }
131
132            let mut line = String::new();
133            let bytes = reader.read_line(&mut line)?;
134            if bytes == 0 {
135                // EOF
136                if let Some(def) = &self.default_str {
137                    return parse_as::<T>(def);
138                }
139                return Err(Error::EmptyNotAllowed);
140            }
141
142            let mut s = line;
143            if self.trim_input {
144                s = s.trim().to_string();
145            }
146
147            // Empty input handling
148            if s.is_empty() {
149                if let Some(def) = &self.default_str {
150                    match parse_as::<T>(def) {
151                        Ok(val) => return Ok(val),
152                        Err(e) => return Err(e), // misconfigured default
153                    }
154                } else {
155                    if self.retries == 0 {
156                        return Err(Error::EmptyNotAllowed);
157                    }
158                    attempts_left -= 1;
159                    if attempts_left == 0 {
160                        return Err(Error::RetriesExceeded);
161                    }
162                    writeln!(writer, "Empty input. Please try again.")?;
163                    continue;
164                }
165            }
166
167            // Try parse
168            match parse_as::<T>(&s) {
169                Ok(val) => return Ok(val),
170                Err(e) => {
171                    attempts_left -= 1;
172                    if attempts_left == 0 {
173                        return Err(Error::RetriesExceeded);
174                    }
175                    writeln!(writer, "{e}")?;
176                    continue;
177                }
178            }
179        }
180    }
181}
182
183/// Typed builder for extras: `.default_val`, `.validate`, `.message`.
184pub struct TypedPrompt<'a, T>
185where
186    T: FromStr,
187    T::Err: std::fmt::Display,
188{
189    pub(crate) base: Prompt<'a>,
190    default_val: Option<T>,
191    validator: Option<Box<dyn Fn(&T) -> bool + 'static>>,
192    validation_msg: Option<String>,
193}
194
195impl<'a, T> TypedPrompt<'a, T>
196where
197    T: FromStr,
198    T::Err: std::fmt::Display,
199{
200    /// Provide a typed default (não precisa parsear).
201    pub fn default_val(mut self, val: T) -> Self {
202        self.default_val = Some(val);
203        self
204    }
205
206    /// function with validation `Fn(&T) -> bool`.
207    pub fn validate<F>(mut self, f: F) -> Self
208    where
209        F: Fn(&T) -> bool + 'static,
210    {
211        self.validator = Some(Box::new(f));
212        self
213    }
214
215    /// Message `.validate` return `false`.
216    pub fn message(mut self, msg: &str) -> Self {
217        self.validation_msg = Some(msg.to_string());
218        self
219    }
220
221    /// Number of times to retry (aplica no `Prompt` base).
222    pub fn retries(mut self, retries: usize) -> Self {
223        self.base.retries = retries;
224        self
225    }
226
227    /// Whether to trim whitespace (default: true).
228    pub fn trim(mut self, yes: bool) -> Self {
229        self.base.trim_input = yes;
230        self
231    }
232
233    /// read and return `T` using stdin/stdout.
234    pub fn get(self) -> Result<T, Error> {
235        let mut reader = io::stdin().lock();
236        let mut writer = io::stdout();
237        self.get_with_io(&mut reader, &mut writer)
238    }
239
240    pub fn get_with_io<R, W>(self, reader: &mut R, writer: &mut W) -> Result<T, Error>
241    where
242        R: BufRead,
243        W: Write,
244    {
245        let mut attempts_left = self.base.retries + 1;
246        let mut default_val = self.default_val;
247        let validator = self.validator;
248        let validation_msg = self.validation_msg;
249        let base = self.base;
250
251        loop {
252            // Render message (com dica de default)
253            {
254                let mut msg = String::new();
255                msg.push_str(base.message);
256                if default_val.is_some() {
257                    if !base.message.contains('[') && !base.message.contains("(default") {
258                        use std::fmt::Write as _;
259                        let _ = write!(msg, "[default set] ");
260                    } else {
261                        msg.push(' ');
262                    }
263                } else if let Some(def) = &base.default_str {
264                    if !base.message.contains('[') && !base.message.contains("(default") {
265                        use std::fmt::Write as _;
266                        let _ = write!(msg, "[default: {}] ", def);
267                    } else {
268                        msg.push(' ');
269                    }
270                }
271                writer.write_all(msg.as_bytes())?;
272                writer.flush()?;
273            }
274
275            let mut line = String::new();
276            let bytes = reader.read_line(&mut line)?;
277            if bytes == 0 {
278                if let Some(v) = default_val.take() {
279                    return Ok(v);
280                } else if let Some(def) = &base.default_str {
281                    return parse_as::<T>(def);
282                } else {
283                    return Err(Error::EmptyNotAllowed);
284                }
285            }
286
287            let mut s = line;
288            if base.trim_input {
289                s = s.trim().to_string();
290            }
291
292            if s.is_empty() {
293                if let Some(v) = default_val.take() {
294                    return Ok(v);
295                } else if let Some(def) = &base.default_str {
296                    match parse_as::<T>(def) {
297                        Ok(val) => return Ok(val),
298                        Err(e) => return Err(e),
299                    }
300                } else {
301                    if base.retries == 0 {
302                        return Err(Error::EmptyNotAllowed);
303                    }
304                    attempts_left -= 1;
305                    if attempts_left == 0 {
306                        return Err(Error::RetriesExceeded);
307                    }
308                    writeln!(writer, "Empty input. Please try again.")?;
309                    continue;
310                }
311            }
312
313            // Parse
314            let val = match parse_as::<T>(&s) {
315                Ok(v) => v,
316                Err(e) => {
317                    attempts_left -= 1;
318                    if attempts_left == 0 {
319                        return Err(Error::RetriesExceeded);
320                    }
321                    writeln!(writer, "{e}")?;
322                    continue;
323                }
324            };
325
326            // Validação
327            if let Some(vf) = &validator {
328                if !vf(&val) {
329                    let msg = validation_msg.clone().unwrap_or_else(|| "Invalid value".to_string());
330                    attempts_left -= 1;
331                    if attempts_left == 0 {
332                        return Err(Error::Validation(msg));
333                    }
334                    writeln!(writer, "{msg}")?;
335                    continue;
336                }
337            }
338
339            return Ok(val);
340        }
341    }
342}
343
344/// parse helper
345fn parse_as<T>(s: &str) -> Result<T, Error>
346where
347    T: FromStr,
348    T::Err: std::fmt::Display,
349{
350    T::from_str(s).map_err(|e| Error::Parse {
351        ty: std::any::type_name::<T>(),
352        cause: e.to_string(),
353    })
354}
355
356/// Entry-point function to create a `Prompt`.
357pub fn prompt(message: &str) -> Prompt<'_> {
358    Prompt::new(message)
359}