1#![cfg_attr(
2 rust_comp_feature = "unstable_features",
3 feature(internal_output_capture)
4)]
5
6use std::{
7 io::{Read as _, Write as _},
8 process::Termination,
9};
10
/// Crate-wide result type. The error type defaults to a boxed [`std::error::Error`]
/// trait object so that `?` can propagate I/O errors, UTF-8 decoding errors, etc.
/// through the same signature.
pub type Result<T, E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
12
/// Sentinel that `redirect_stdout` writes to the redirected stdout before running the
/// user closure: the byte `0xFF` followed by the ASCII text `MARKER_FOR_REDIRECT_STDOUT`.
/// `0xFF` never occurs in valid UTF-8, so no ordinary text output can start with it.
static MARKER: &[u8] = b"\xffMARKER_FOR_REDIRECT_STDOUT";
17pub fn redirect_stdout(f: impl FnOnce()) -> Result<Vec<u8>> {
18 let mut lock = std::io::stdout().lock();
19 lock.flush()?;
20
21 let stdout = tempfile::NamedTempFile::new()?;
22
23 let guard = stdio_override::StdoutOverride::override_file(&stdout)?;
24 #[cfg(rust_comp_feature = "unstable_features")]
25 let old_capture = std::io::set_output_capture(None);
26 std::io::stdout().write_all(MARKER)?;
27 std::io::stdout().flush()?;
28 f();
29 std::io::stdout().flush()?;
30 #[cfg(rust_comp_feature = "unstable_features")]
31 std::io::set_output_capture(old_capture);
32 drop(guard);
33
34 let mut res = Vec::new();
35 stdout.as_file().read_to_end(&mut res)?;
36 if let Some(ret) = res.strip_prefix(MARKER) {
37 return Ok(ret.to_vec());
38 }
39 eprintln!("WARNING: Couldn't read stdout. Please run tests with --nocapture: cargo test -- --nocapture");
40 Ok(b"Use `cargo test -- --nocapture` if you care the stdout".to_vec())
41}
42pub fn redirect_stdin<T>(input: impl AsRef<[u8]>, f: impl FnOnce() -> T) -> Result<T> {
43 let mut lock = std::io::stdout().lock();
44 lock.flush()?;
45
46 let stdin = tempfile::NamedTempFile::new()?;
47 stdin.as_file().write_all(input.as_ref())?;
48 let _guard = stdio_override::StdinOverride::override_file(&stdin)?;
49 #[cfg(feature = "proconio")]
50 {
51 proconio::input_interactive! {};
53 }
54 Ok(f())
55}
56
57pub fn redirect_stdio(input: impl AsRef<[u8]>, f: impl FnOnce()) -> Result<Vec<u8>> {
58 redirect_stdin(input.as_ref(), || redirect_stdout(f))?
59}
60
61pub fn redirect_stdio_utf8(input: impl AsRef<[u8]>, f: impl FnOnce()) -> Result<String> {
62 Ok(String::from_utf8(redirect_stdio(input.as_ref(), f)?)?)
63}
64
/// Wraps a `main`-like function into a closure that runs it and asserts that it
/// reported success.
///
/// The value returned by `f` is converted with [`Termination::report`], the same
/// conversion used for the return value of `fn main`. That means `()`, `Result<(), E>`
/// and `ExitCode` are all accepted.
///
/// # Panics
///
/// The returned closure panics if the reported exit code is not
/// [`ExitCode::SUCCESS`](std::process::ExitCode::SUCCESS).
pub fn wrap_assert_success<T: Termination>(f: impl FnOnce() -> T) -> impl FnOnce() {
    move || {
        let exit_code = f().report();
        // `ExitCode` has no stable accessor for its numeric value, and its size is
        // platform-dependent, so transmuting it to `u8` is neither sound nor portable.
        // It does implement `Debug`, and equal codes format identically, so comparing
        // the renderings is a safe, portable way to check for success.
        let success = std::process::ExitCode::SUCCESS;
        assert!(
            format!("{:?}", exit_code) == format!("{:?}", success),
            "program did not exit successfully: got {:?}, expected {:?}",
            exit_code,
            success
        );
    }
}
72
/// Runs the `main`-like `$main` with stdin set to `$input`, panicking unless it reports
/// success.
///
/// `$main` may return any [`Termination`](std::process::Termination) value. Its stdout
/// is redirected and the captured bytes are discarded. The macro also panics if the
/// redirection itself fails.
#[macro_export]
macro_rules! assert_success_with_input {
    ($stdin: expr, $program: expr) => {{
        let outcome = $crate::redirect_stdio($stdin, $crate::wrap_assert_success($program));
        outcome.unwrap();
    }};
}
/// Runs two `main`-like programs on the same `$input` and asserts that both succeed and
/// print identical stdout.
///
/// When both outputs are valid UTF-8 they are compared as `&str`, so a mismatch is shown
/// as readable text. Otherwise the raw bytes are compared. `$input` is evaluated exactly
/// once, so owned values such as a `String` variable can be passed. Each program's
/// expression is evaluated once.
///
/// # Panics
///
/// Panics if either program does not exit successfully, if a redirection fails, or if
/// the outputs differ.
#[macro_export]
macro_rules! assert_eq_output_for_input {
    ($input: expr, $left: expr, $right: expr) => {
        // `match` (rather than `let`) keeps any temporaries in `$input` alive for the
        // whole comparison.
        match $input {
            input => {
                let left =
                    $crate::redirect_stdio(&input, $crate::wrap_assert_success($left)).unwrap();
                let right =
                    $crate::redirect_stdio(&input, $crate::wrap_assert_success($right)).unwrap();
                if let (Ok(left), Ok(right)) =
                    (std::str::from_utf8(&left), std::str::from_utf8(&right))
                {
                    assert_eq!(left, right)
                } else {
                    assert_eq!(left, right)
                }
            }
        }
    };
}
91
/// Proptest counterpart of `assert_success_with_input!`. It runs the `main`-like
/// `$main` with stdin set to `$input` and fails the current property (via
/// `prop_assert!`) if the stdin/stdout redirection returns an error.
///
/// NOTE(review): a non-success exit status makes `wrap_assert_success` panic rather than
/// return an error. That case therefore surfaces as a panic inside the property, not as
/// a `prop_assert!` failure.
#[cfg(feature = "proptest")]
#[macro_export]
macro_rules! prop_assert_success_with_input {
    ($input: expr, $main: expr) => {{
        let ret = $crate::redirect_stdio($input, $crate::wrap_assert_success($main));
        proptest::prop_assert!(ret.is_ok())
    }};
}
/// Proptest counterpart of `assert_eq_output_for_input!`. It runs two `main`-like
/// programs on the same `$input` and fails the current property (via
/// `prop_assert_eq!`) if their stdout differs.
///
/// UTF-8 outputs are compared as `&str` for readable failure messages; otherwise the raw
/// bytes are compared. `$input` is evaluated exactly once, so owned values such as a
/// `String` variable can be passed.
///
/// # Panics
///
/// Panics (rather than failing the property) if a redirection fails or a program does
/// not exit successfully.
#[cfg(feature = "proptest")]
#[macro_export]
macro_rules! prop_assert_eq_output_for_input {
    ($input: expr, $left: expr, $right: expr) => {
        // `match` (rather than `let`) keeps any temporaries in `$input` alive for the
        // whole comparison.
        match $input {
            input => {
                let left =
                    $crate::redirect_stdio(&input, $crate::wrap_assert_success($left)).unwrap();
                let right =
                    $crate::redirect_stdio(&input, $crate::wrap_assert_success($right)).unwrap();
                if let (Ok(left), Ok(right)) =
                    (std::str::from_utf8(&left), std::str::from_utf8(&right))
                {
                    proptest::prop_assert_eq!(left, right)
                } else {
                    proptest::prop_assert_eq!(left, right)
                }
            }
        }
    };
}
113
// Unit tests for the redirection helpers. The ones that inspect stdout depend on
// `print!` output reaching the redirected stdout. Per the warning in `redirect_stdout`,
// that presumably requires `cargo test -- --nocapture` unless the `unstable_features`
// cfg is enabled.
#[cfg(test)]
mod tests {
    use super::*;

    // The input given to `redirect_stdin` is what `std::io::stdin()` yields inside the closure.
    #[test]
    fn test_stdin() {
        let s = redirect_stdin("This is input", || {
            std::io::stdin().lines().next().unwrap().unwrap()
        })
        .unwrap();
        assert_eq!(s, "This is input");
    }

    // `println!` output inside the closure is returned by `redirect_stdout`.
    #[test]
    fn test_stdout() {
        let s = redirect_stdout(|| {
            println!("This is output");
        })
        .unwrap();
        assert_eq!(s, "This is output\n".as_bytes());
    }

    // Echoing stdin line by line round-trips the input. The trailing "\n" comes from
    // `println!`, since the input itself has no newline.
    #[test]
    fn test_stdio() {
        let s = redirect_stdio_utf8("This is input", || {
            for line in std::io::stdin().lines() {
                println!("{}", line.unwrap());
            }
        })
        .unwrap();
        assert_eq!(s, "This is input\n");
    }

    // A program that returns `Ok(())` after reading the expected input passes the macro.
    #[test]
    fn test_success() {
        assert_success_with_input!("foo", || -> Result<(), ()> {
            println!("aaa");
            if std::io::read_to_string(std::io::stdin()).unwrap() == "foo" {
                Ok(())
            } else {
                Err(())
            }
        })
    }

    // With unexpected input the program returns `Err(())`, so the macro must panic.
    #[test]
    fn test_runtime_error() {
        // Flush stderr, then redirect it into a temporary file for the rest of the test.
        // NOTE(review): presumably this keeps the expected failure report and panic
        // message out of the test output.
        let mut lock = std::io::stderr().lock();
        lock.flush().unwrap();

        let stderr = tempfile::NamedTempFile::new().unwrap();
        let _guard = stdio_override::StderrOverride::override_file(&stderr).unwrap();
        let result = std::panic::catch_unwind(|| {
            assert_success_with_input!("bar", || -> Result<(), ()> {
                println!("aaa");
                if std::io::read_to_string(std::io::stdin()).unwrap() == "foo" {
                    Ok(())
                } else {
                    Err(())
                }
            })
        });
        assert!(result.is_err());
    }

    // Property-based tests, compiled only with the `proptest` feature.
    #[cfg(feature = "proptest")]
    mod proptest_features {
        use proptest::prelude::*;
        proptest! {
            // For arbitrary strings, the program reads back exactly the input it was given.
            #[test]
            fn prop_assert(s: String) {
                prop_assert_success_with_input!(&s, || -> Result<(), ()> {
                    println!("aaa");
                    let input = std::io::read_to_string(std::io::stdin()).unwrap();
                    if input == s {
                        Ok(())
                    } else {
                        Err(())
                    }
                })
            }
        }

        // Deliberately wrong solution: sums two i32 values with wrapping. Its output
        // differs from `ac` whenever the sum overflows i32.
        fn wa() {
            let input = std::io::read_to_string(std::io::stdin()).unwrap();
            let mut input = input.split_whitespace();
            let a = input.next().unwrap().parse::<i32>().unwrap();
            let b = input.next().unwrap().parse::<i32>().unwrap();
            println!("{}", a.wrapping_add(b));
        }

        // Correct solution: sums in i64, which cannot overflow for two i32-range inputs.
        fn ac() {
            let input = std::io::read_to_string(std::io::stdin()).unwrap();
            let mut input = input.split_whitespace();
            let a = input.next().unwrap().parse::<i64>().unwrap();
            let b = input.next().unwrap().parse::<i64>().unwrap();
            println!("{}", a + b);
        }

        // `wa` and `ac` disagree on overflowing inputs, so proptest is expected to find a
        // counterexample and the test panics.
        #[test]
        #[should_panic]
        fn test_failure() {
            // NOTE(review): presumably stops proptest from persisting this expected
            // failure to a regressions file.
            std::env::set_var("PROPTEST_DISABLE_FAILURE_PERSISTENCE", "true");
            proptest!(move |(a: i32, b: i32)| {
                prop_assert_eq_output_for_input!(format!("{} {}", a, b), wa, ac);
            });
        }
    }
}