codama_syn_helpers/extensions/parse_buffer.rs

1use proc_macro2::{TokenStream, TokenTree};
2use syn::{parse::ParseBuffer, Token};
3
4pub trait ParseBufferExtension<'a> {
5    fn get_self(&self) -> &ParseBuffer<'a>;
6
7    /// Advance the buffer until we reach a comma or the end of the buffer.
8    /// Returns the consumed tokens as a `TokenStream`.
9    fn parse_arg(&self) -> syn::Result<TokenStream> {
10        self.get_self().step(|cursor| {
11            let mut tts = Vec::new();
12            let mut rest = *cursor;
13            while let Some((tt, next)) = rest.token_tree() {
14                match &tt {
15                    TokenTree::Punct(punct) if punct.as_char() == ',' => {
16                        return Ok((tts.into_iter().collect(), rest));
17                    }
18                    _ => {
19                        tts.push(tt);
20                        rest = next
21                    }
22                }
23            }
24            Ok((tts.into_iter().collect(), rest))
25        })
26    }
27
28    /// Fork the current buffer and move the original buffer to the end of the argument.
29    fn fork_arg(&self) -> syn::Result<ParseBuffer<'a>> {
30        let this = self.get_self();
31        let fork = this.fork();
32        this.parse_arg()?;
33        Ok(fork)
34    }
35
36    // Check if the buffer is empty or the next token is a comma.
37    fn is_end_of_arg(&self) -> bool {
38        let this = self.get_self();
39        this.is_empty() || this.peek(Token![,])
40    }
41
42    /// Check if the next token is an empty group.
43    fn is_empty_group(&self) -> bool {
44        match self.get_self().fork().parse::<proc_macro2::Group>() {
45            Ok(group) => group.stream().is_empty(),
46            Err(_) => false,
47        }
48    }
49}
50
51impl<'a> ParseBufferExtension<'a> for ParseBuffer<'a> {
52    fn get_self(&self) -> &ParseBuffer<'a> {
53        self
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `&str -> syn::Result<$ty>` closure. The closure parses the input with
    // `syn::parse_str`, runs `$callback` on the resulting `ParseStream`, and then swallows any
    // remaining tokens so that `parse_str` does not reject partially consumed input.
    macro_rules! test_buffer {
        ($ty:ty, $callback:expr) => {{
            struct TestBuffer($ty);
            impl syn::parse::Parse for TestBuffer {
                fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
                    let result = ($callback)(input)?; // Call the provided callback
                    input.parse::<proc_macro2::TokenStream>()?; // Consume the rest of the input
                    Ok(Self(result))
                }
            }
            move |input: &str| -> syn::Result<$ty> {
                syn::parse_str::<TestBuffer>(input).map(|x| x.0)
            }
        }};
    }

    #[test]
    fn parse_arg() {
        let test = test_buffer!(
            (String, String),
            |input: syn::parse::ParseStream| -> syn::Result<(String, String)> {
                let arg = input.parse_arg()?;
                Ok((arg.to_string(), input.to_string()))
            }
        );

        assert_eq!(
            test("foo , bar , baz").unwrap(),
            ("foo".into(), ", bar , baz".into())
        );
        assert_eq!(
            test(", bar , baz").unwrap(),
            ("".into(), ", bar , baz".into())
        );
        assert_eq!(
            test("(foo , bar), baz").unwrap(),
            ("(foo , bar)".into(), ", baz".into())
        );
        // Commas nested in any delimited group (not only parentheses) stay in the argument.
        assert_eq!(
            test("[foo , bar], baz").unwrap(),
            ("[foo , bar]".into(), ", baz".into())
        );
        assert_eq!(
            test("foo bar baz").unwrap(),
            ("foo bar baz".into(), "".into())
        );
        assert_eq!(test("").unwrap(), ("".into(), "".into()));
    }

    #[test]
    fn fork_arg() {
        let test = test_buffer!(
            (String, String),
            |input: syn::parse::ParseStream| -> syn::Result<(String, String)> {
                let arg = input.fork_arg()?;
                Ok((arg.to_string(), input.to_string()))
            }
        );

        // The fork keeps every remaining token; only the original buffer is advanced.
        assert_eq!(
            test("foo , bar , baz").unwrap(),
            ("foo , bar , baz".into(), ", bar , baz".into())
        );
        assert_eq!(
            test("foo bar baz").unwrap(),
            ("foo bar baz".into(), "".into())
        );
        // An empty argument leaves the original buffer where it was.
        assert_eq!(test(", bar").unwrap(), (", bar".into(), ", bar".into()));
        assert_eq!(test("").unwrap(), ("".into(), "".into()));
    }

    #[test]
    fn is_end_of_arg() {
        let test = test_buffer!(
            bool,
            |input: syn::parse::ParseStream| -> syn::Result<bool> { Ok(input.is_end_of_arg()) }
        );

        assert!(test("").unwrap());
        assert!(test(", bar").unwrap());
        assert!(!test("foo").unwrap());
        // Only the *next* token matters; a later comma does not count.
        assert!(!test("foo , bar").unwrap());
    }

    #[test]
    fn is_empty_group() {
        let test = test_buffer!(
            bool,
            |input: syn::parse::ParseStream| -> syn::Result<bool> { Ok(input.is_empty_group()) }
        );

        assert!(test("()").unwrap());
        assert!(test("[]").unwrap());
        assert!(test("{}").unwrap());
        assert!(test("() , foo").unwrap());
        assert!(!test("").unwrap());
        assert!(!test("foo").unwrap());
        assert!(!test("(foo)").unwrap());
        assert!(!test("[foo]").unwrap());
    }
}