use std::fmt::Arguments;
use std::io::{self, BufRead, Write};
use std::str::FromStr;
/// Error returned by the line-reading helpers in this module.
///
/// `E` is the error type of the target type's [`FromStr`] implementation.
#[derive(Debug)]
pub enum InputError<E> {
    /// Flushing the prompt to stdout or reading the line from the input failed.
    Io(io::Error),
    /// A line was read but could not be parsed into the requested type.
    Parse(E),
    /// The input was already exhausted: zero bytes were read.
    Eof,
}

// Only `Display` is needed to render the message; requiring `Debug` here would
// needlessly reject parse-error types that implement `Display` alone.
impl<E: std::fmt::Display> std::fmt::Display for InputError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            InputError::Io(e) => write!(f, "I/O error: {}", e),
            InputError::Parse(e) => write!(f, "Parse error: {}", e),
            InputError::Eof => write!(f, "EOF encountered"),
        }
    }
}

// `std::error::Error` requires `Self: Debug`, and the derived `Debug` needs `E: Debug`.
impl<E: std::fmt::Display + std::fmt::Debug> std::error::Error for InputError<E> {}
/// Reads a single line from `reader` and parses it as `T`.
///
/// If `prompt` is given, it is written to standard output (not to `reader`). Stdout is
/// then flushed before reading, so a prompt without a trailing newline is still shown.
///
/// Exactly one line terminator (`"\n"` or `"\r\n"`) is removed before parsing. All other
/// characters, including surrounding whitespace, are passed to [`FromStr::from_str`]
/// unchanged.
///
/// `reader` may be any `BufRead`, including an unsized `dyn BufRead`.
///
/// # Errors
///
/// * [`InputError::Io`] if flushing stdout or reading from `reader` fails.
/// * [`InputError::Eof`] if `reader` yields zero bytes.
/// * [`InputError::Parse`] if the line is not a valid `T`.
pub fn read_input_from<R, T>(
    reader: &mut R,
    prompt: Option<Arguments<'_>>,
) -> Result<T, InputError<T::Err>>
where
    R: BufRead + ?Sized,
    T: FromStr,
    T::Err: std::fmt::Display + std::fmt::Debug,
{
    if let Some(prompt_args) = prompt {
        print!("{}", prompt_args);
        io::stdout().flush().map_err(InputError::Io)?;
    }

    let mut input = String::new();
    if reader.read_line(&mut input).map_err(InputError::Io)? == 0 {
        return Err(InputError::Eof);
    }

    // Strip only the terminator `read_line` kept (`\n`, optionally preceded by `\r`).
    // Stripping every trailing CR/LF would silently truncate values that end in `\r`.
    let line = match input.strip_suffix('\n') {
        Some(rest) => rest.strip_suffix('\r').unwrap_or(rest),
        None => input.as_str(),
    };
    line.parse::<T>().map_err(InputError::Parse)
}
/// Reads one line from standard input and parses it as `T`. No prompt is printed.
///
/// # Errors
///
/// Returns [`InputError::Io`] if reading stdin fails, [`InputError::Eof`] if stdin is
/// exhausted, and [`InputError::Parse`] if the line is not a valid `T`.
pub fn read_input<T>() -> Result<T, InputError<T::Err>>
where
    T: FromStr,
    T::Err: std::fmt::Display + std::fmt::Debug,
{
    // The stdin handle and its lock are temporaries that live until the call returns.
    read_input_from(&mut io::stdin().lock(), None)
}
/// Prints `prompt` to standard output, then reads one line from standard input and
/// parses it as `T`.
///
/// Build the prompt with `format_args!`, e.g.
/// `read_input_with_prompt::<i32>(format_args!("Enter a number: "))`.
///
/// # Errors
///
/// Returns [`InputError::Io`] if flushing stdout or reading stdin fails,
/// [`InputError::Eof`] if stdin is exhausted, and [`InputError::Parse`] if the line is
/// not a valid `T`.
pub fn read_input_with_prompt<T>(prompt: Arguments<'_>) -> Result<T, InputError<T::Err>>
where
    T: FromStr,
    T::Err: std::fmt::Display + std::fmt::Debug,
{
    let stdin_handle = io::stdin();
    let wrapped_prompt = Some(prompt);
    read_input_from(&mut stdin_handle.lock(), wrapped_prompt)
}
/// Reads and parses one line from stdin, optionally printing a `format!`-style prompt first.
///
/// The prompt is printed as given (no newline is appended), and stdout is flushed before
/// reading. Evaluates to `Result<Option<T>, InputError<T::Err>>`:
///
/// * `Ok(Some(value))` when a line was read and parsed.
/// * `Ok(None)` when stdin was already at end of input.
/// * `Err(_)` for I/O or parse failures.
#[macro_export]
macro_rules! input {
    () => {{
        match $crate::read_input_from(&mut ::std::io::stdin().lock(), None) {
            Ok(val) => Ok(Some(val)),
            Err($crate::InputError::Eof) => Ok(None),
            Err(err) => Err(err),
        }
    }};
    ($($arg:tt)*) => {{
        // Absolute path so a caller-defined `format_args!` cannot hijack the expansion.
        match $crate::read_input_from(
            &mut ::std::io::stdin().lock(),
            Some(::std::format_args!($($arg)*)),
        ) {
            Ok(val) => Ok(Some(val)),
            Err($crate::InputError::Eof) => Ok(None),
            Err(err) => Err(err),
        }
    }};
}
/// Prints a `format!`-style prompt followed by a newline, then reads and parses one line
/// from stdin.
///
/// With no arguments nothing is printed, and it behaves like a plain read. Evaluates to
/// `Result<Option<T>, InputError<T::Err>>`:
///
/// * `Ok(Some(value))` on success.
/// * `Ok(None)` when stdin was already at end of input.
/// * `Err(InputError::Io(_))` if flushing stdout or reading stdin fails.
/// * `Err(InputError::Parse(_))` if the line is not a valid `T`.
#[macro_export]
macro_rules! inputln {
    () => {{
        match $crate::read_input_from(&mut ::std::io::stdin().lock(), None) {
            Ok(val) => Ok(Some(val)),
            Err($crate::InputError::Eof) => Ok(None),
            Err(err) => Err(err),
        }
    }};
    ($($arg:tt)*) => {{
        // Send the prompt (plus its newline) through `read_input_from`. It flushes stdout
        // and reports a flush failure as `InputError::Io` instead of panicking.
        match $crate::read_input_from(
            &mut ::std::io::stdin().lock(),
            Some(::std::format_args!("{}\n", ::std::format_args!($($arg)*))),
        ) {
            Ok(val) => Ok(Some(val)),
            Err($crate::InputError::Eof) => Ok(None),
            Err(err) => Err(err),
        }
    }};
}
/// Reads and parses one line from stdin, optionally printing a `format!`-style prompt first.
///
/// Unlike `input!`, end of input is not folded into `Ok(None)`. The macro evaluates
/// directly to `Result<T, InputError<T::Err>>`, with exhaustion reported as
/// `Err(InputError::Eof)`.
#[macro_export]
macro_rules! input_no_eof {
    () => {{
        $crate::read_input_from(&mut ::std::io::stdin().lock(), None)
    }};
    ($($arg:tt)*) => {{
        // Absolute path so a caller-defined `format_args!` cannot hijack the expansion.
        $crate::read_input_from(
            &mut ::std::io::stdin().lock(),
            Some(::std::format_args!($($arg)*)),
        )
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Error, ErrorKind};

    #[test]
    fn test_read_input_integer() {
        let mut reader = Cursor::new("42\n");
        let res: Result<i32, _> = read_input_from(&mut reader, None);
        assert_eq!(res.unwrap(), 42);
    }

    #[test]
    fn test_read_input_float() {
        // Avoid PI-like literals such as 3.14159: clippy's deny-by-default
        // `approx_constant` lint rejects them.
        let mut reader = Cursor::new("2.5\n");
        let res: Result<f64, _> = read_input_from(&mut reader, None);
        assert!((res.unwrap() - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_read_input_unsigned() {
        let mut reader = Cursor::new("255\n");
        let res: Result<u32, _> = read_input_from(&mut reader, None);
        assert_eq!(res.unwrap(), 255);
    }

    #[test]
    fn test_read_input_eof() {
        let mut reader = Cursor::new("");
        let res: Result<i32, _> = read_input_from(&mut reader, None);
        assert!(matches!(res, Err(InputError::Eof)));
    }

    #[test]
    fn test_read_input_parse_error() {
        let mut reader = Cursor::new("not an int\n");
        let res: Result<i32, _> = read_input_from(&mut reader, None);
        assert!(matches!(res, Err(InputError::Parse(_))));
    }

    #[test]
    fn test_read_input_string() {
        let mut reader = Cursor::new("hello world\r\n");
        let res: Result<String, _> = read_input_from(&mut reader, None);
        assert_eq!(res.unwrap(), "hello world");
    }

    #[test]
    fn test_read_input_without_trailing_newline() {
        // The last line of a stream may lack a terminator; it must still be parsed.
        let mut reader = Cursor::new("7");
        let res: Result<i32, _> = read_input_from(&mut reader, None);
        assert_eq!(res.unwrap(), 7);
    }

    #[test]
    fn test_with_prompt() {
        let mut reader = Cursor::new("100\n");
        let res: Result<i32, _> = read_input_from(&mut reader, Some(format_args!("Enter: ")));
        assert_eq!(res.unwrap(), 100);
    }

    #[test]
    fn test_multiple_lines_valid() {
        let mut reader = Cursor::new("123\n456\n");
        let first: i32 = read_input_from(&mut reader, None).unwrap();
        assert_eq!(first, 123);
        let second: i32 = read_input_from(&mut reader, None).unwrap();
        assert_eq!(second, 456);
    }

    #[test]
    fn test_multiple_lines_parse_error_then_eof() {
        let mut reader = Cursor::new("42\nnotanint\n");
        let first: i32 = read_input_from(&mut reader, None).unwrap();
        assert_eq!(first, 42);
        let second = read_input_from::<_, i32>(&mut reader, None);
        assert!(matches!(second, Err(InputError::Parse(_))));
        let third = read_input_from::<_, i32>(&mut reader, None);
        assert!(matches!(third, Err(InputError::Eof)));
    }

    #[test]
    fn test_empty_line_behavior() {
        let mut reader = Cursor::new("\n");
        let res: Result<i32, _> = read_input_from(&mut reader, None);
        assert!(matches!(res, Err(InputError::Parse(_))));
    }

    #[test]
    fn test_empty_line_as_string() {
        let mut reader = Cursor::new("\n");
        let res: Result<String, _> = read_input_from(&mut reader, None);
        assert_eq!(res.unwrap(), "");
    }

    #[test]
    fn test_read_input_single_token_string() {
        let mut reader = Cursor::new("HelloFromMacro\n");
        let result: Result<String, _> = read_input_from(&mut reader, None);
        assert_eq!(result.unwrap(), "HelloFromMacro");
    }

    #[test]
    fn test_error_display() {
        let eof: InputError<std::num::ParseIntError> = InputError::Eof;
        assert_eq!(eof.to_string(), "EOF encountered");
        let io_err: InputError<std::num::ParseIntError> =
            InputError::Io(Error::new(ErrorKind::Other, "boom"));
        assert_eq!(io_err.to_string(), "I/O error: boom");
    }

    #[test]
    fn test_io_error() {
        struct ErrorReader;
        impl BufRead for ErrorReader {
            fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
                Err(Error::new(ErrorKind::Other, "Simulated I/O failure"))
            }
            fn consume(&mut self, _amt: usize) {}
        }
        impl std::io::Read for ErrorReader {
            fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
                Err(Error::new(ErrorKind::Other, "Simulated I/O failure"))
            }
        }
        let mut reader = ErrorReader;
        let res: Result<String, _> = read_input_from(&mut reader, None);
        assert!(matches!(res, Err(InputError::Io(_))));
    }
}