assert_tokenstreams_eq/
lib.rs

1use std::{
2    io::Write,
3    process::{Command, Stdio},
4    string::FromUtf8Error,
5};
6
7use pretty_assertions::assert_eq;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11enum RustfmtError {
12    #[error("Could not create child process for rustfmt.")]
13    Io(#[from] std::io::Error),
14    #[error("Could not convert output bytes to UTF-8.")]
15    Utf8(#[from] FromUtf8Error),
16    #[error("Could not read from rustfmt child process.")]
17    Output,
18    #[error("Could not access stdin of rustfmt child process.")]
19    Stdin,
20    #[error("Input is no valid rust code.\n{0}")]
21    InvalidRust(String),
22}
23
24/// # Panics
25///
26/// This function will panic for different reasons:
27/// - The tokenstreams are not equal.
28/// - Any tokenstream is no valid rust code.
29/// - rustfmt is not installed or configured wrong.
30/// - It was not possible to create a child process to run rustfmt.
31/// - The output of rustfmt could not be converted to Utf-8.
32pub fn compare_tokenstreams(first_tokenstream: &impl ToString, second_tokenstream: &impl ToString) {
33    let first_formatted = match apply_rustfmt(first_tokenstream) {
34        Ok(tokens) => tokens,
35        Err(e) => panic!("{}", e),
36    };
37    let second_formatted = match apply_rustfmt(second_tokenstream) {
38        Ok(tokens) => tokens,
39        Err(e) => panic!("{}", e),
40    };
41    assert_eq!(first_formatted, second_formatted);
42}
43
44fn apply_rustfmt(tokens: &impl ToString) -> Result<String, RustfmtError> {
45    let mut process = Command::new("rustfmt")
46        .arg("--")
47        .stdin(Stdio::piped())
48        .stdout(Stdio::piped())
49        .stderr(Stdio::piped())
50        .spawn()?;
51
52    let Some(stdin) = process.stdin.as_mut() else {
53        return Err(RustfmtError::Stdin);
54    };
55
56    stdin.write_all(tokens.to_string().as_bytes())?;
57
58    let output = process
59        .wait_with_output()
60        .map_err(|_| RustfmtError::Output)?;
61
62    if output.status.success() {
63        let fmt_tokens = String::from_utf8(output.stdout)?;
64        Ok(fmt_tokens)
65    } else {
66        let err = String::from_utf8(output.stderr)?;
67        Err(RustfmtError::InvalidRust(err))
68    }
69}
70
/// Asserts that two token streams are equal after both have been formatted by `rustfmt`.
///
/// Both arguments are forwarded unchanged to `compare_tokenstreams`, so each must be a
/// reference to a value implementing `ToString`. As with `assert_eq!`, an optional trailing
/// comma is accepted.
///
/// ```ignore
/// assert_tokenstreams_eq!(&first, &second);
/// assert_tokenstreams_eq!(&first, &second,);
/// ```
#[macro_export]
macro_rules! assert_tokenstreams_eq {
    ($left:expr, $right:expr $(,)?) => {{
        $crate::compare_tokenstreams($left, $right);
    }};
}
77
#[cfg(test)]
mod tests {
    use quote::quote;

    use crate::assert_tokenstreams_eq;

    /// Two identical token streams compare equal.
    #[test]
    fn test_compare_tokenstreams_equal() {
        let first = quote! {
            fn test(a: String, b: String) {
                return a;
            }
        };
        let second = quote! {
            fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&first, &second);
    }

    /// Token streams that differ in function name and return value make the assertion panic.
    #[test]
    #[should_panic]
    fn test_compare_tokenstreams_unequal() {
        let first = quote! {
            fn test(a: String, b: String) {
                let test = 2;
                return a;
            }
        };
        let second = quote! {
            fn test2(a: String, b: String) {
                let test = 2;
                return b;
            }
        };
        assert_tokenstreams_eq!(&first, &second);
    }

    /// Input that is not valid Rust makes the assertion panic.
    #[test]
    #[should_panic]
    fn test_compare_tokenstreams_invalid_rust_code() {
        // The missing `;` after the async block is deliberate: it is what makes this invalid.
        let first = quote! {
            async fn test(a: String, b: String) {
                let test = async {
                    a + b
                }
                return test;
            }
        };
        let second = quote! {
            async fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&first, &second);
    }

    /// Differing async functions make the assertion panic.
    ///
    /// NOTE(review): whether this panics because the streams are unequal or because `rustfmt`
    /// rejects `async fn` presumably depends on the edition `rustfmt` runs with — confirm.
    #[test]
    #[should_panic]
    fn test_compare_tokenstreams_unequal_async() {
        // The statement is terminated with `;` so that, unlike the invalid-code test above,
        // this input is syntactically well-formed and only differs from `second`.
        let first = quote! {
            async fn test(a: String, b: String) {
                let test = async {
                    a + b
                }.await;
                return test;
            }
        };
        let second = quote! {
            async fn test(a: String, b: String) {
                return a;
            }
        };
        assert_tokenstreams_eq!(&first, &second);
    }
}