1use std::io::{self, BufRead, Write};
2use std::str::FromStr;
3
/// Everything that can go wrong while prompting for, reading, and parsing a value.
#[derive(Debug)]
pub enum Error {
    /// Reading from the input or writing to the output failed.
    Io(io::Error),
    /// The text could not be converted to the requested type.
    Parse {
        /// Name of the target type, as reported by `std::any::type_name`.
        ty: &'static str,
        /// The `FromStr` error rendered through its `Display` impl.
        cause: String,
    },
    /// The input was empty (or the stream ended) and no default was configured.
    EmptyNotAllowed,
    /// Every allowed attempt was used without obtaining an acceptable value.
    RetriesExceeded,
    /// The value parsed but was rejected by the validator on the final attempt;
    /// carries the validation message.
    Validation(String),
}
20
21impl std::fmt::Display for Error {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Error::Io(e) => write!(f, "I/O error: {e}"),
25 Error::Parse { ty, cause } => write!(f, "Failed to parse as {ty}: {cause}"),
26 Error::EmptyNotAllowed => write!(f, "Empty input (no default provided)"),
27 Error::RetriesExceeded => write!(f, "Maximum retry attempts exceeded"),
28 Error::Validation(msg) => write!(f, "Validation failed: {msg}"),
29 }
30 }
31}
32
33impl std::error::Error for Error {}
34
35impl From<io::Error> for Error {
36 fn from(e: io::Error) -> Self {
37 Error::Io(e)
38 }
39}
40
/// Builder describing a single question to ask on an input/output pair.
pub struct Prompt<'a> {
    /// Text written to the output before each read.
    message: &'a str,
    /// Textual default, parsed into the target type when the input is empty.
    default_str: Option<String>,
    /// Number of additional attempts allowed after the first one.
    retries: usize,
    /// Whether surrounding whitespace is trimmed from the input before parsing.
    trim_input: bool,
}
48
49impl<'a> Prompt<'a> {
50 pub fn new(message: &'a str) -> Self {
52 Self {
53 message,
54 default_str: None,
55 retries: 0,
56 trim_input: true,
57 }
58 }
59
60 pub fn default(mut self, default: &str) -> Self {
63 self.default_str = Some(default.to_string());
64 self
65 }
66
67 pub fn retries(mut self, retries: usize) -> Self {
69 self.retries = retries;
70 self
71 }
72
73 pub fn trim(mut self, yes: bool) -> Self {
75 self.trim_input = yes;
76 self
77 }
78
79 pub fn to<T>(self) -> TypedPrompt<'a, T>
81 where
82 T: FromStr,
83 T::Err: std::fmt::Display,
84 {
85 TypedPrompt {
86 base: self,
87 default_val: None,
88 validator: None,
89 validation_msg: None,
90 }
91 }
92
93 pub fn get<T>(&self) -> Result<T, Error>
95 where
96 T: FromStr,
97 T::Err: std::fmt::Display,
98 {
99 let stdin = io::stdin();
100 let mut lock = stdin.lock();
101 let mut stdout = io::stdout();
102 self.get_with(&mut lock, &mut stdout)
103 }
104
105 pub fn get_with<T, R, W>(&self, reader: &mut R, writer: &mut W) -> Result<T, Error>
107 where
108 T: FromStr,
109 T::Err: std::fmt::Display,
110 R: BufRead,
111 W: Write,
112 {
113 let mut attempts_left = self.retries + 1;
114
115 loop {
116 {
118 let mut msg = String::new();
119 msg.push_str(self.message);
120 if let Some(def) = &self.default_str {
121 if !self.message.contains('[') && !self.message.contains("(default") {
122 use std::fmt::Write as _;
123 let _ = write!(msg, "[default: {}] ", def);
124 } else {
125 msg.push(' ');
126 }
127 }
128 writer.write_all(msg.as_bytes())?;
129 writer.flush()?;
130 }
131
132 let mut line = String::new();
133 let bytes = reader.read_line(&mut line)?;
134 if bytes == 0 {
135 if let Some(def) = &self.default_str {
137 return parse_as::<T>(def);
138 }
139 return Err(Error::EmptyNotAllowed);
140 }
141
142 let mut s = line;
143 if self.trim_input {
144 s = s.trim().to_string();
145 }
146
147 if s.is_empty() {
149 if let Some(def) = &self.default_str {
150 match parse_as::<T>(def) {
151 Ok(val) => return Ok(val),
152 Err(e) => return Err(e), }
154 } else {
155 if self.retries == 0 {
156 return Err(Error::EmptyNotAllowed);
157 }
158 attempts_left -= 1;
159 if attempts_left == 0 {
160 return Err(Error::RetriesExceeded);
161 }
162 writeln!(writer, "Empty input. Please try again.")?;
163 continue;
164 }
165 }
166
167 match parse_as::<T>(&s) {
169 Ok(val) => return Ok(val),
170 Err(e) => {
171 attempts_left -= 1;
172 if attempts_left == 0 {
173 return Err(Error::RetriesExceeded);
174 }
175 writeln!(writer, "{e}")?;
176 continue;
177 }
178 }
179 }
180 }
181}
182
183pub struct TypedPrompt<'a, T>
185where
186 T: FromStr,
187 T::Err: std::fmt::Display,
188{
189 pub(crate) base: Prompt<'a>,
190 default_val: Option<T>,
191 validator: Option<Box<dyn Fn(&T) -> bool + 'static>>,
192 validation_msg: Option<String>,
193}
194
195impl<'a, T> TypedPrompt<'a, T>
196where
197 T: FromStr,
198 T::Err: std::fmt::Display,
199{
200 pub fn default_val(mut self, val: T) -> Self {
202 self.default_val = Some(val);
203 self
204 }
205
206 pub fn validate<F>(mut self, f: F) -> Self
208 where
209 F: Fn(&T) -> bool + 'static,
210 {
211 self.validator = Some(Box::new(f));
212 self
213 }
214
215 pub fn message(mut self, msg: &str) -> Self {
217 self.validation_msg = Some(msg.to_string());
218 self
219 }
220
221 pub fn retries(mut self, retries: usize) -> Self {
223 self.base.retries = retries;
224 self
225 }
226
227 pub fn trim(mut self, yes: bool) -> Self {
229 self.base.trim_input = yes;
230 self
231 }
232
233 pub fn get(self) -> Result<T, Error> {
235 let mut reader = io::stdin().lock();
236 let mut writer = io::stdout();
237 self.get_with_io(&mut reader, &mut writer)
238 }
239
240 pub fn get_with_io<R, W>(self, reader: &mut R, writer: &mut W) -> Result<T, Error>
241 where
242 R: BufRead,
243 W: Write,
244 {
245 let mut attempts_left = self.base.retries + 1;
246 let mut default_val = self.default_val;
247 let validator = self.validator;
248 let validation_msg = self.validation_msg;
249 let base = self.base;
250
251 loop {
252 {
254 let mut msg = String::new();
255 msg.push_str(base.message);
256 if default_val.is_some() {
257 if !base.message.contains('[') && !base.message.contains("(default") {
258 use std::fmt::Write as _;
259 let _ = write!(msg, "[default set] ");
260 } else {
261 msg.push(' ');
262 }
263 } else if let Some(def) = &base.default_str {
264 if !base.message.contains('[') && !base.message.contains("(default") {
265 use std::fmt::Write as _;
266 let _ = write!(msg, "[default: {}] ", def);
267 } else {
268 msg.push(' ');
269 }
270 }
271 writer.write_all(msg.as_bytes())?;
272 writer.flush()?;
273 }
274
275 let mut line = String::new();
276 let bytes = reader.read_line(&mut line)?;
277 if bytes == 0 {
278 if let Some(v) = default_val.take() {
279 return Ok(v);
280 } else if let Some(def) = &base.default_str {
281 return parse_as::<T>(def);
282 } else {
283 return Err(Error::EmptyNotAllowed);
284 }
285 }
286
287 let mut s = line;
288 if base.trim_input {
289 s = s.trim().to_string();
290 }
291
292 if s.is_empty() {
293 if let Some(v) = default_val.take() {
294 return Ok(v);
295 } else if let Some(def) = &base.default_str {
296 match parse_as::<T>(def) {
297 Ok(val) => return Ok(val),
298 Err(e) => return Err(e),
299 }
300 } else {
301 if base.retries == 0 {
302 return Err(Error::EmptyNotAllowed);
303 }
304 attempts_left -= 1;
305 if attempts_left == 0 {
306 return Err(Error::RetriesExceeded);
307 }
308 writeln!(writer, "Empty input. Please try again.")?;
309 continue;
310 }
311 }
312
313 let val = match parse_as::<T>(&s) {
315 Ok(v) => v,
316 Err(e) => {
317 attempts_left -= 1;
318 if attempts_left == 0 {
319 return Err(Error::RetriesExceeded);
320 }
321 writeln!(writer, "{e}")?;
322 continue;
323 }
324 };
325
326 if let Some(vf) = &validator {
328 if !vf(&val) {
329 let msg = validation_msg.clone().unwrap_or_else(|| "Invalid value".to_string());
330 attempts_left -= 1;
331 if attempts_left == 0 {
332 return Err(Error::Validation(msg));
333 }
334 writeln!(writer, "{msg}")?;
335 continue;
336 }
337 }
338
339 return Ok(val);
340 }
341 }
342}
343
344fn parse_as<T>(s: &str) -> Result<T, Error>
346where
347 T: FromStr,
348 T::Err: std::fmt::Display,
349{
350 T::from_str(s).map_err(|e| Error::Parse {
351 ty: std::any::type_name::<T>(),
352 cause: e.to_string(),
353 })
354}
355
356pub fn prompt(message: &str) -> Prompt<'_> {
358 Prompt::new(message)
359}