sqlx_conditional_queries_core/
parse.rs

1use syn::{parenthesized, parse::Parse};
2
3#[derive(Clone, Debug)]
4pub(crate) struct ParsedConditionalQueryAs {
5    /// This is the equivalent of sqlx's output type in a `query_as!` macro.
6    pub(crate) output_type: syn::Ident,
7    /// The actual string of the query.
8    pub(crate) query_string: syn::LitStr,
9    /// All compile time bindings, each with its variables and associated `match` statement.
10    pub(crate) compile_time_bindings: Vec<(
11        OneOrPunctuated<syn::Ident, syn::token::Comma>,
12        syn::ExprMatch,
13    )>,
14}
15
16/// This enum represents the identifier (`#foo`, `#(foo, bar)`) of single binding expression
17/// inside a query.
18///
19/// Normal statements such as `#foo = match something {...}` are represented by the `One(T)`
20/// variant.
21///
22/// It's also possible to match tuples such as:
23/// `#(order_dir, order_dir_rev) = match order_dir {...}`
24/// These are represented by the `Punctuated(...)` variant.
25#[derive(Clone, Debug)]
26pub(crate) enum OneOrPunctuated<T, P> {
27    One(T),
28    Punctuated(syn::punctuated::Punctuated<T, P>, proc_macro2::Span),
29}
30
31impl<T: syn::spanned::Spanned, P> OneOrPunctuated<T, P> {
32    pub(crate) fn span(&self) -> proc_macro2::Span {
33        match self {
34            OneOrPunctuated::One(t) => t.span(),
35            OneOrPunctuated::Punctuated(_, span) => *span,
36        }
37    }
38}
39
40impl<T, P> IntoIterator for OneOrPunctuated<T, P> {
41    type Item = T;
42
43    type IntoIter = std::vec::IntoIter<T>;
44
45    fn into_iter(self) -> Self::IntoIter {
46        match self {
47            OneOrPunctuated::One(item) => vec![item].into_iter(),
48            OneOrPunctuated::Punctuated(punctuated, _) => {
49                punctuated.into_iter().collect::<Vec<_>>().into_iter()
50            }
51        }
52    }
53}
54
55impl Parse for ParsedConditionalQueryAs {
56    /// Take a given raw token stream from a macro invocation and parse it into our own
57    /// representation for further processing.
58    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
59        // Parse the ident of the output type that we're going to pass to `query_as!`.
60        let output_type = input.parse::<syn::Ident>()?;
61        input.parse::<syn::token::Comma>()?;
62
63        // Parse the actual query string literal.
64        let query_string = input.parse::<syn::LitStr>()?;
65
66        // The rest of the input has to be an optional sequence of compile-time binding
67        // expressions.
68        let mut compile_time_bindings = Vec::new();
69        while !input.is_empty() {
70            // Every binding expression has to be preceded by a comma, and we also allow the final
71            // comma to be optional.
72            input.parse::<syn::token::Comma>()?;
73            if input.is_empty() {
74                break;
75            }
76
77            // Every binding expression starts with a #.
78            input.parse::<syn::token::Pound>()?;
79
80            // Then we parse the binding names.
81            let binding_names = if input.peek(syn::token::Paren) {
82                // If the binding names start with parens we're parsing a tuple of binding names.
83                let content;
84                let paren_token = parenthesized!(content in input);
85                OneOrPunctuated::Punctuated(
86                    content.parse_terminated(syn::Ident::parse, syn::token::Comma)?,
87                    paren_token.span.join(),
88                )
89            } else {
90                // Otherwise we only parse a single ident.
91                let name = input.parse::<syn::Ident>()?;
92                OneOrPunctuated::One(name)
93            };
94
95            // Binding names and match is delimited by equals sign.
96            input.parse::<syn::token::Eq>()?;
97
98            // And finally we parse a match expression.
99            let match_expression = input.parse::<syn::ExprMatch>()?;
100
101            compile_time_bindings.push((binding_names, match_expression));
102        }
103
104        Ok(ParsedConditionalQueryAs {
105            output_type,
106            query_string,
107            compile_time_bindings,
108        })
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier to compare parsed names against.
    fn ident(name: &str) -> syn::Ident {
        syn::Ident::new(name, proc_macro2::Span::call_site())
    }

    /// Collects the names of one compile-time binding into a `Vec` for easy comparison.
    fn names_of(
        binding: (OneOrPunctuated<syn::Ident, syn::token::Comma>, syn::ExprMatch),
    ) -> Vec<syn::Ident> {
        binding.0.into_iter().collect()
    }

    #[test]
    fn valid_syntax() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r#"
                SomeType,
                "some SQL query",
                #binding = match foo {
                    bar => "baz",
                },
                #(a, b) = match c {
                    d => ("e", "f"),
                },
            "#,
        )
        .unwrap();

        assert_eq!(parsed.output_type, ident("SomeType"));

        assert_eq!(
            parsed.query_string,
            syn::LitStr::new("some SQL query", proc_macro2::Span::call_site()),
        );

        let mut bindings = parsed.compile_time_bindings.into_iter();
        assert_eq!(names_of(bindings.next().unwrap()), [ident("binding")]);
        assert_eq!(names_of(bindings.next().unwrap()), [ident("a"), ident("b")]);
        assert!(bindings.next().is_none());
    }

    #[test]
    fn no_bindings_with_or_without_trailing_comma() {
        for input in [r#"SomeType, "q""#, r#"SomeType, "q","#] {
            let parsed = syn::parse_str::<ParsedConditionalQueryAs>(input).unwrap();
            assert!(
                parsed.compile_time_bindings.is_empty(),
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn final_comma_is_optional() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r#"SomeType, "q", #x = match y { z => "w" }"#,
        )
        .unwrap();

        let mut bindings = parsed.compile_time_bindings.into_iter();
        assert_eq!(names_of(bindings.next().unwrap()), [ident("x")]);
        assert!(bindings.next().is_none());
    }

    #[test]
    fn tuple_binding_accepts_trailing_comma() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r#"SomeType, "q", #(a, b,) = match c { d => ("e", "f") }"#,
        )
        .unwrap();

        let mut bindings = parsed.compile_time_bindings.into_iter();
        assert_eq!(names_of(bindings.next().unwrap()), [ident("a"), ident("b")]);
        assert!(bindings.next().is_none());
    }

    #[test]
    fn malformed_input_is_rejected() {
        let cases = [
            // Missing comma after the output type.
            r#"SomeType "q""#,
            // Binding without the leading `#`.
            r#"SomeType, "q", x = match y { z => "w" }"#,
            // Two bindings not separated by a comma.
            r#"SomeType, "q", #x = match y { z => "w" } #a = match b { c => "d" }"#,
            // Right-hand side of a binding is not a `match` expression.
            r#"SomeType, "q", #x = "w""#,
        ];

        for input in cases {
            assert!(
                syn::parse_str::<ParsedConditionalQueryAs>(input).is_err(),
                "unexpectedly accepted: {}",
                input
            );
        }
    }
}