assert_tokenstreams_eq/
lib.rs1use std::{
2 io::Write,
3 process::{Command, Stdio},
4 string::FromUtf8Error,
5};
6
7use pretty_assertions::assert_eq;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11enum RustfmtError {
12 #[error("Could not create child process for rustfmt.")]
13 Io(#[from] std::io::Error),
14 #[error("Could not convert output bytes to UTF-8.")]
15 Utf8(#[from] FromUtf8Error),
16 #[error("Could not read from rustfmt child process.")]
17 Output,
18 #[error("Could not access stdin of rustfmt child process.")]
19 Stdin,
20 #[error("Input is not valid Rust code.\n{0}")]
21 InvalidRust(String),
22}
23
24pub fn compare_tokenstreams(
33 first_tokenstream: &impl std::fmt::Display,
34 second_tokenstream: &impl std::fmt::Display,
35) {
36 let first_formatted = match apply_rustfmt(first_tokenstream) {
37 Ok(tokens) => tokens,
38 Err(e) => panic!("Error formatting first token stream: {e}"),
39 };
40 let second_formatted = match apply_rustfmt(second_tokenstream) {
41 Ok(tokens) => tokens,
42 Err(e) => panic!("Error formatting second token stream: {e}"),
43 };
44 assert_eq!(first_formatted, second_formatted);
45}
46
47fn apply_rustfmt(tokens: &impl std::fmt::Display) -> Result<String, RustfmtError> {
48 let mut process = Command::new("rustfmt")
49 .arg("--")
50 .stdin(Stdio::piped())
51 .stdout(Stdio::piped())
52 .stderr(Stdio::piped())
53 .spawn()?;
54
55 let Some(stdin) = process.stdin.as_mut() else {
56 return Err(RustfmtError::Stdin);
57 };
58
59 write!(stdin, "{tokens}")?;
60
61 let output = process
62 .wait_with_output()
63 .map_err(|_| RustfmtError::Output)?;
64
65 if output.status.success() {
66 let fmt_tokens = String::from_utf8(output.stdout)?;
67 Ok(fmt_tokens)
68 } else {
69 let err = String::from_utf8(output.stderr)?;
70 Err(RustfmtError::InvalidRust(err))
71 }
72}
73
/// Asserts that two token streams are equal once both have been formatted by `rustfmt`.
///
/// Each argument may be any expression whose type implements
/// [`std::fmt::Display`]. It can be passed by value or by reference: the macro
/// borrows it rather than moving it. A trailing comma is accepted.
///
/// # Panics
///
/// Panics if either argument cannot be formatted by `rustfmt`, or if the
/// formatted outputs differ.
#[macro_export]
macro_rules! assert_tokenstreams_eq {
    ($left:expr, $right:expr $(,)?) => {{
        // Borrowing here means both `tokens` and `&tokens` work: `&T` is itself
        // `Display` whenever `T` is.
        $crate::compare_tokenstreams(&$left, &$right);
    }};
}
80
// Every case below goes through `assert_tokenstreams_eq!`, which invokes the
// `rustfmt` binary, so it must be available when running `cargo test`.
#[cfg(test)]
mod tests {
    use crate::assert_tokenstreams_eq;
    use quote::quote;

    // Identical token streams format to identical text, so no panic occurs.
    #[test]
    fn test_compare_tokenstreams_equal() {
        let (left, right) = (
            quote! {
                fn test(a: String, b: String) {
                    return a;
                }
            },
            quote! {
                fn test(a: String, b: String) {
                    return a;
                }
            },
        );
        assert_tokenstreams_eq!(&left, &right);
    }

    // The function name and returned argument differ, so the formatted
    // sources differ and the equality assertion panics.
    #[test]
    #[should_panic(expected = "(left == right)")]
    fn test_compare_tokenstreams_unequal() {
        let (left, right) = (
            quote! {
                fn test(a: String, b: String) {
                    let test = 2;
                    return a;
                }
            },
            quote! {
                fn test2(a: String, b: String) {
                    let test = 2;
                    return b;
                }
            },
        );
        assert_tokenstreams_eq!(&left, &right);
    }

    // The `async` block bound to `test` lacks its terminating `;`, so
    // formatting the first stream fails before any comparison is made.
    #[test]
    #[should_panic(expected = "Input is not valid Rust code.")]
    fn test_compare_tokenstreams_invalid_rust_code() {
        let (left, right) = (
            quote! {
                async fn test(a: String, b: String) {
                    let test = async {
                        a + b
                    }
                    return test;
                }
            },
            quote! {
                async fn test(a: String, b: String) {
                    return a;
                }
            },
        );
        assert_tokenstreams_eq!(&left, &right);
    }

    // Both streams are syntactically valid `async` code but differ, so the
    // failure comes from the comparison rather than from formatting.
    #[test]
    #[should_panic(expected = "(left == right)")]
    fn test_compare_tokenstreams_unequal_async() {
        let (left, right) = (
            quote! {
                async fn test(a: String, b: String) {
                    let test = async {
                        a + b
                    }.await;
                    return test;
                }
            },
            quote! {
                async fn test(a: String, b: String) {
                    return a;
                }
            },
        );
        assert_tokenstreams_eq!(&left, &right);
    }
}