assert_tokenstreams_eq/
lib.rs

1use std::{
2    io::Write,
3    process::{Command, Stdio},
4    string::FromUtf8Error,
5};
6
7use pretty_assertions::assert_eq;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11enum RustfmtError {
12    #[error("Could not create child process for rustfmt.")]
13    Io(#[from] std::io::Error),
14    #[error("Could not convert output bytes to UTF-8.")]
15    Utf8(#[from] FromUtf8Error),
16    #[error("Could not read from rustfmt child process.")]
17    Output,
18    #[error("Could not access stdin of rustfmt child process.")]
19    Stdin,
20    #[error("Input is not valid Rust code.\n{0}")]
21    InvalidRust(String),
22}
23
24/// # Panics
25///
26/// This function will panic if:
27/// - The token streams are not equal.
28/// - Either token stream is not valid Rust code.
29/// - `rustfmt` is not installed or is configured incorrectly.
30/// - It was not possible to create a child process to run `rustfmt`.
31/// - The output of `rustfmt` could not be converted to UTF-8.
32pub fn compare_tokenstreams(
33    first_tokenstream: &impl std::fmt::Display,
34    second_tokenstream: &impl std::fmt::Display,
35) {
36    let first_formatted = match apply_rustfmt(first_tokenstream) {
37        Ok(tokens) => tokens,
38        Err(e) => panic!("Error formatting first token stream: {e}"),
39    };
40    let second_formatted = match apply_rustfmt(second_tokenstream) {
41        Ok(tokens) => tokens,
42        Err(e) => panic!("Error formatting second token stream: {e}"),
43    };
44    assert_eq!(first_formatted, second_formatted);
45}
46
47fn apply_rustfmt(tokens: &impl std::fmt::Display) -> Result<String, RustfmtError> {
48    let mut process = Command::new("rustfmt")
49        .arg("--")
50        .stdin(Stdio::piped())
51        .stdout(Stdio::piped())
52        .stderr(Stdio::piped())
53        .spawn()?;
54
55    let Some(stdin) = process.stdin.as_mut() else {
56        return Err(RustfmtError::Stdin);
57    };
58
59    write!(stdin, "{tokens}")?;
60
61    let output = process
62        .wait_with_output()
63        .map_err(|_| RustfmtError::Output)?;
64
65    if output.status.success() {
66        let fmt_tokens = String::from_utf8(output.stdout)?;
67        Ok(fmt_tokens)
68    } else {
69        let err = String::from_utf8(output.stderr)?;
70        Err(RustfmtError::InvalidRust(err))
71    }
72}
73
/// Asserts that two token streams format identically under `rustfmt`.
///
/// Both arguments may be any expression whose value implements
/// [`Display`](std::fmt::Display), such as the output of `quote!`. The macro
/// borrows them itself, so both owned values and references are accepted, and
/// a trailing comma is allowed.
///
/// Expands to a call to [`compare_tokenstreams`]; see that function for the
/// conditions under which this macro panics.
#[macro_export]
macro_rules! assert_tokenstreams_eq {
    ($left:expr, $right:expr $(,)?) => {{
        // Borrowing here stays compatible with callers that already pass
        // `&expr`: `&&T` still satisfies `&impl Display` because `&T: Display`.
        $crate::compare_tokenstreams(&$left, &$right);
    }};
}
80
#[cfg(test)]
mod tests {
    use crate::assert_tokenstreams_eq;
    use quote::quote;

    /// Two identical functions must compare equal without panicking.
    #[test]
    fn test_compare_tokenstreams_equal() {
        let left = quote! {
            fn test(a: String, b: String) {
                return a;
            }
        };
        let right = quote! {
            fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&left, &right);
    }

    /// Functions that differ in name and return value must trip the equality assertion.
    #[test]
    #[should_panic(expected = "(left == right)")]
    fn test_compare_tokenstreams_unequal() {
        let left = quote! {
            fn test(a: String, b: String) {
                let test = 2;
                return a;
            }
        };
        let right = quote! {
            fn test2(a: String, b: String) {
                let test = 2;
                return b;
            }
        };
        assert_tokenstreams_eq!(&left, &right);
    }

    /// The left stream lacks the `;` after its `let` binding, so `rustfmt`
    /// rejects it and the formatting error is surfaced as a panic.
    #[test]
    #[should_panic(expected = "Input is not valid Rust code.")]
    fn test_compare_tokenstreams_invalid_rust_code() {
        let left = quote! {
            async fn test(a: String, b: String) {
                let test = async {
                    a + b
                }
                return test;
            }
        };
        let right = quote! {
            async fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&left, &right);
    }

    /// Valid but different `async` functions must trip the equality assertion.
    /// This relies on `rustfmt` accepting `async` syntax.
    #[test]
    #[should_panic(expected = "(left == right)")]
    fn test_compare_tokenstreams_unequal_async() {
        let left = quote! {
            async fn test(a: String, b: String) {
                let test = async {
                    a + b
                }.await;
                return test;
            }
        };
        let right = quote! {
            async fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&left, &right);
    }
}