ralgo_test_util/
lib.rs

1#![cfg_attr(
2    rust_comp_feature = "unstable_features",
3    feature(internal_output_capture)
4)]
5
6use std::{
7    io::{Read as _, Write as _},
8    process::Termination,
9};
10
/// Crate-wide result alias whose error type defaults to a boxed
/// [`std::error::Error`] trait object.
///
/// Callers may still spell out a concrete error type as the second parameter
/// (for example `Result<(), ()>`).
pub type Result<T, E = Box<dyn std::error::Error>> = ::core::result::Result<T, E>;
12
/// Sentinel written to the redirected stdout before the user closure runs.
///
/// The value is the byte `0xFF` followed by the ASCII text
/// `MARKER_FOR_REDIRECT_STDOUT`. `redirect_stdout` strips it from the front of
/// the captured bytes and treats its absence as "the capture did not work".
static MARKER: &[u8] = b"\xffMARKER_FOR_REDIRECT_STDOUT";
17pub fn redirect_stdout(f: impl FnOnce()) -> Result<Vec<u8>> {
18    let mut lock = std::io::stdout().lock();
19    lock.flush()?;
20
21    let stdout = tempfile::NamedTempFile::new()?;
22
23    let guard = stdio_override::StdoutOverride::override_file(&stdout)?;
24    #[cfg(rust_comp_feature = "unstable_features")]
25    let old_capture = std::io::set_output_capture(None);
26    std::io::stdout().write_all(MARKER)?;
27    std::io::stdout().flush()?;
28    f();
29    std::io::stdout().flush()?;
30    #[cfg(rust_comp_feature = "unstable_features")]
31    std::io::set_output_capture(old_capture);
32    drop(guard);
33
34    let mut res = Vec::new();
35    stdout.as_file().read_to_end(&mut res)?;
36    if let Some(ret) = res.strip_prefix(MARKER) {
37        return Ok(ret.to_vec());
38    }
39    eprintln!("WARNING: Couldn't read stdout. Please run tests with --nocapture: cargo test -- --nocapture");
40    Ok(b"Use `cargo test -- --nocapture` if you care the stdout".to_vec())
41}
42pub fn redirect_stdin<T>(input: impl AsRef<[u8]>, f: impl FnOnce() -> T) -> Result<T> {
43    let mut lock = std::io::stdout().lock();
44    lock.flush()?;
45
46    let stdin = tempfile::NamedTempFile::new()?;
47    stdin.as_file().write_all(input.as_ref())?;
48    let _guard = stdio_override::StdinOverride::override_file(&stdin)?;
49    #[cfg(feature = "proconio")]
50    {
51        // let proconio use line source
52        proconio::input_interactive! {};
53    }
54    Ok(f())
55}
56
57pub fn redirect_stdio(input: impl AsRef<[u8]>, f: impl FnOnce()) -> Result<Vec<u8>> {
58    redirect_stdin(input.as_ref(), || redirect_stdout(f))?
59}
60
61pub fn redirect_stdio_utf8(input: impl AsRef<[u8]>, f: impl FnOnce()) -> Result<String> {
62    Ok(String::from_utf8(redirect_stdio(input.as_ref(), f)?)?)
63}
64
/// Wraps `f` in a closure that panics unless `f`'s return value reports
/// success.
///
/// The returned closure calls `f`, converts the result to an exit code through
/// [`Termination::report`], and asserts that it equals
/// [`std::process::ExitCode::SUCCESS`].
///
/// # Panics
///
/// The returned closure panics when the reported exit code is not the success
/// code.
pub fn wrap_assert_success<T: Termination>(f: impl FnOnce() -> T) -> impl FnOnce() {
    move || {
        let exit_code = f().report();
        // Compare through the public `PartialEq` impl instead of transmuting
        // `ExitCode` to `u8`: its layout is private and platform-specific.
        assert_eq!(exit_code, std::process::ExitCode::SUCCESS);
    }
}
72
/// Runs an entry point with the given bytes on stdin and panics unless it
/// reports success.
///
/// The first argument is the stdin content, the second the callable to run.
/// The call is routed through `redirect_stdio` and `wrap_assert_success`; the
/// captured stdout is discarded. Panics as well when the redirection itself
/// returns an error.
#[macro_export]
macro_rules! assert_success_with_input {
    ($stdin:expr, $entry:expr) => {{
        let _captured =
            $crate::redirect_stdio($stdin, $crate::wrap_assert_success($entry)).unwrap();
    }};
}
/// Runs two entry points on the same stdin content and asserts that they
/// produce identical stdout.
///
/// `$input` is the stdin content and is evaluated exactly once; `$left` and
/// `$right` are the callables to compare. Each is run through
/// `redirect_stdio` and `wrap_assert_success`, so a non-success exit code or a
/// redirection error panics.
///
/// When both outputs are valid UTF-8 they are compared as strings, which gives
/// readable failure messages; otherwise the raw bytes are compared.
#[macro_export]
macro_rules! assert_eq_output_for_input {
    ($input: expr, $left: expr, $right: expr) => {{
        // Bind once so both sides see the same bytes and an owned input is not
        // moved twice.
        let input = $input;
        let left = $crate::redirect_stdio(&input, $crate::wrap_assert_success($left)).unwrap();
        let right = $crate::redirect_stdio(&input, $crate::wrap_assert_success($right)).unwrap();
        if let (Ok(left), Ok(right)) =
            (::std::str::from_utf8(&left), ::std::str::from_utf8(&right))
        {
            assert_eq!(left, right)
        } else {
            assert_eq!(left, right)
        }
    }};
}
91
/// Proptest flavour of `assert_success_with_input!`.
///
/// Runs the entry point with the given stdin content through `redirect_stdio`
/// and `wrap_assert_success`, then `prop_assert!`s that the redirection
/// returned `Ok`. A non-success exit code still surfaces as a panic, because
/// `wrap_assert_success` asserts inside the wrapped closure.
#[cfg(feature = "proptest")]
#[macro_export]
macro_rules! prop_assert_success_with_input {
    ($stdin:expr, $entry:expr) => {{
        let ret = $crate::redirect_stdio($stdin, $crate::wrap_assert_success($entry));
        proptest::prop_assert!(ret.is_ok())
    }};
}
/// Proptest flavour of `assert_eq_output_for_input!`.
///
/// Runs `$left` and `$right` on the same stdin content and
/// `prop_assert_eq!`s their captured stdout. `$input` is evaluated exactly
/// once. A non-success exit code or a redirection error panics (via
/// `wrap_assert_success` and `unwrap`).
///
/// When both outputs are valid UTF-8 they are compared as strings; otherwise
/// the raw bytes are compared.
#[cfg(feature = "proptest")]
#[macro_export]
macro_rules! prop_assert_eq_output_for_input {
    ($input: expr, $left: expr, $right: expr) => {{
        // Bind once so both sides see the same bytes and an owned input is not
        // moved twice.
        let input = $input;
        let left = $crate::redirect_stdio(&input, $crate::wrap_assert_success($left)).unwrap();
        let right = $crate::redirect_stdio(&input, $crate::wrap_assert_success($right)).unwrap();
        if let (Ok(left), Ok(right)) =
            (::std::str::from_utf8(&left), ::std::str::from_utf8(&right))
        {
            proptest::prop_assert_eq!(left, right)
        } else {
            proptest::prop_assert_eq!(left, right)
        }
    }};
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Prints a line, then succeeds only when the whole of stdin is exactly `foo`.
    fn stdin_must_be_foo() -> Result<(), ()> {
        println!("aaa");
        match std::io::read_to_string(std::io::stdin()).unwrap().as_str() {
            "foo" => Ok(()),
            _ => Err(()),
        }
    }

    /// The redirected stdin yields the provided text as its first line.
    #[test]
    fn test_stdin() {
        let read_first_line = || std::io::stdin().lines().next().unwrap().unwrap();
        let first_line = redirect_stdin("This is input", read_first_line).unwrap();
        assert_eq!(first_line, "This is input");
    }

    /// Output printed inside the closure is returned as bytes.
    #[test]
    fn test_stdout() {
        let captured = redirect_stdout(|| println!("This is output")).unwrap();
        assert_eq!(captured, &b"This is output\n"[..]);
    }

    /// Echoing stdin to stdout round-trips through both redirections.
    #[test]
    fn test_stdio() {
        let echoed = redirect_stdio_utf8("This is input", || {
            std::io::stdin()
                .lines()
                .map(|line| line.unwrap())
                .for_each(|line| println!("{}", line));
        })
        .unwrap();
        assert_eq!(echoed, "This is input\n");
    }

    /// An entry point returning `Ok` passes the success assertion.
    #[test]
    fn test_success() {
        assert_success_with_input!("foo", stdin_must_be_foo);
    }

    /// An entry point returning `Err` makes the success assertion panic.
    #[test]
    fn test_runtime_error() {
        let mut err_lock = std::io::stderr().lock();
        err_lock.flush().unwrap();

        // stderr is redirected into a temporary file for the rest of the test.
        let err_sink = tempfile::NamedTempFile::new().unwrap();
        let _redirect = stdio_override::StderrOverride::override_file(&err_sink).unwrap();
        let outcome = std::panic::catch_unwind(|| {
            assert_success_with_input!("bar", stdin_must_be_foo)
        });
        assert!(outcome.is_err());
    }

    #[cfg(feature = "proptest")]
    mod proptest_features {
        use proptest::prelude::*;
        proptest! {
            /// Whatever string is fed in arrives unchanged on stdin.
            #[test]
            fn prop_assert(s: String) {
                prop_assert_success_with_input!(&s, || -> Result<(), ()> {
                    println!("aaa");
                    let received = std::io::read_to_string(std::io::stdin()).unwrap();
                    (received == s).then_some(()).ok_or(())
                })
            }
        }

        /// Adds two integers from stdin using wrapping `i32` arithmetic.
        fn wa() {
            let text = std::io::read_to_string(std::io::stdin()).unwrap();
            let mut tokens = text.split_whitespace();
            let mut next_i32 = || tokens.next().unwrap().parse::<i32>().unwrap();
            let (a, b) = (next_i32(), next_i32());
            println!("{}", a.wrapping_add(b));
        }

        /// Adds two integers from stdin using `i64` arithmetic.
        fn ac() {
            let text = std::io::read_to_string(std::io::stdin()).unwrap();
            let mut tokens = text.split_whitespace();
            let mut next_i64 = || tokens.next().unwrap().parse::<i64>().unwrap();
            let (a, b) = (next_i64(), next_i64());
            println!("{}", a + b);
        }

        /// Comparing `wa` against `ac` over arbitrary `i32` pairs is expected to fail.
        #[test]
        #[should_panic]
        fn test_failure() {
            std::env::set_var("PROPTEST_DISABLE_FAILURE_PERSISTENCE", "true");
            proptest!(move |(a: i32, b: i32)| {
                prop_assert_eq_output_for_input!(format!("{} {}", a, b), wa, ac);
            });
        }
    }
}