match_template/
lib.rs

1// Copyright 2022 TiKV Project Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! This crate provides a macro that generates a match expression with
//! multiple arms: the first arm acts as a template in which a placeholder
//! identifier is substituted, so that one template arm expands into many arms.
18//!
19//! For example, the following code
20//!
21//! ```ignore
22//! match_template! {
23//!     T = [Int, Real, Double],
24//!     match Foo {
25//!         EvalType::T => { panic!("{}", EvalType::T); },
26//!         EvalType::Other => unreachable!(),
27//!     }
28//! }
29//! ```
30//!
31//! generates
32//!
33//! ```ignore
34//! match Foo {
35//!     EvalType::Int => { panic!("{}", EvalType::Int); },
36//!     EvalType::Real => { panic!("{}", EvalType::Real); },
37//!     EvalType::Double => { panic!("{}", EvalType::Double); },
38//!     EvalType::Other => unreachable!(),
39//! }
40//! ```
41//!
42//! In addition, substitution can vary on two sides of the arms.
43//!
44//! For example,
45//!
46//! ```ignore
47//! match_template! {
48//!     T = [Foo, Bar => Baz],
49//!     match Foo {
50//!         EvalType::T => { panic!("{}", EvalType::T); },
51//!     }
52//! }
53//! ```
54//!
55//! generates
56//!
57//! ```ignore
58//! match Foo {
59//!     EvalType::Foo => { panic!("{}", EvalType::Foo); },
60//!     EvalType::Bar => { panic!("{}", EvalType::Baz); },
61//! }
62//! ```
63//!
64//! Wildcard match arm is also supported (but there will be no substitution).
65
66use proc_macro2::{Group, Ident, TokenStream, TokenTree};
67use quote::{quote, ToTokens};
68use syn::{
69    bracketed,
70    parse::{Parse, ParseStream},
71    parse_macro_input,
72    punctuated::Punctuated,
73    Arm, Expr, ExprMatch, Pat, Token,
74};
75
76/// A procedural macro that generates repeated match arms by pattern.
77///
78/// See the [module-level documentation](self) for more details.
79#[proc_macro]
80pub fn match_template(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
81    let mt = parse_macro_input!(input as MatchTemplate);
82    mt.expand().into()
83}
84
85struct MatchTemplate {
86    template_ident: Ident,
87    substitutes: Punctuated<Substitution, Token![,]>,
88    match_exp: Box<Expr>,
89    template_arm: Arm,
90    remaining_arms: Vec<Arm>,
91}
92
93impl Parse for MatchTemplate {
94    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
95        let template_ident = input.parse()?;
96        input.parse::<Token![=]>()?;
97        let substitutes_tokens;
98        bracketed!(substitutes_tokens in input);
99        let substitutes =
100            Punctuated::<Substitution, Token![,]>::parse_terminated(&substitutes_tokens)?;
101        input.parse::<Token![,]>()?;
102        let m: ExprMatch = input.parse()?;
103        let mut arms = m.arms;
104        arms.iter_mut().for_each(|arm| arm.comma = None);
105        assert!(!arms.is_empty(), "Expect at least 1 match arm");
106        let template_arm = arms.remove(0);
107        assert!(template_arm.guard.is_none(), "Expect no match arm guard");
108
109        Ok(Self {
110            template_ident,
111            substitutes,
112            match_exp: m.expr,
113            template_arm,
114            remaining_arms: arms,
115        })
116    }
117}
118
119impl MatchTemplate {
120    fn expand(self) -> TokenStream {
121        let Self {
122            template_ident,
123            substitutes,
124            match_exp,
125            template_arm,
126            remaining_arms,
127        } = self;
128        let match_arms = substitutes.into_iter().map(|substitute| {
129            let mut arm = template_arm.clone();
130            let (left_tokens, right_tokens) = match substitute {
131                Substitution::Identical(ident) => {
132                    (ident.clone().into_token_stream(), ident.into_token_stream())
133                }
134                Substitution::Map(left_ident, right_tokens) => {
135                    (left_ident.into_token_stream(), right_tokens)
136                }
137            };
138            arm.pat = replace_in_token_stream(
139                arm.pat,
140                Pat::parse_multi_with_leading_vert,
141                &template_ident,
142                &left_tokens,
143            );
144            arm.body =
145                replace_in_token_stream(arm.body, Parse::parse, &template_ident, &right_tokens);
146            arm
147        });
148        quote! {
149            match #match_exp {
150                #(#match_arms,)*
151                #(#remaining_arms,)*
152            }
153        }
154    }
155}
156
157#[derive(Debug)]
158enum Substitution {
159    Identical(Ident),
160    Map(Ident, TokenStream),
161}
162
163impl Parse for Substitution {
164    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
165        let left_ident = input.parse()?;
166        let fat_arrow: Option<Token![=>]> = input.parse()?;
167        if fat_arrow.is_some() {
168            let mut right_tokens: Vec<TokenTree> = vec![];
169            while !input.peek(Token![,]) && !input.is_empty() {
170                right_tokens.push(input.parse()?);
171            }
172            Ok(Substitution::Map(
173                left_ident,
174                right_tokens.into_iter().collect(),
175            ))
176        } else {
177            Ok(Substitution::Identical(left_ident))
178        }
179    }
180}
181
182fn replace_in_token_stream<T: ToTokens, P: Fn(ParseStream) -> syn::Result<T>>(
183    input: T,
184    parse: P,
185    from_ident: &Ident,
186    to_tokens: &TokenStream,
187) -> T {
188    let mut tokens = TokenStream::new();
189    input.to_tokens(&mut tokens);
190
191    let tokens: TokenStream = tokens
192        .into_iter()
193        .flat_map(|token| match token {
194            TokenTree::Ident(ident) if ident == *from_ident => to_tokens.clone(),
195            TokenTree::Group(group) => Group::new(
196                group.delimiter(),
197                replace_in_token_stream(group.stream(), Parse::parse, from_ident, to_tokens),
198            )
199            .into_token_stream(),
200            other => other.into(),
201        })
202        .collect();
203
204    syn::parse::Parser::parse2(parse, tokens).unwrap()
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a `MatchTemplate`, expands it, and checks that the
    /// rendered tokens equal the rendered tokens of `expect_output`.
    fn assert_expansion(input: &str, expect_output: &str) {
        let expect_output_stream: TokenStream = expect_output.parse().unwrap();

        let template: MatchTemplate = syn::parse_str(input).unwrap();
        let actual = template.expand();
        assert_eq!(actual.to_string(), expect_output_stream.to_string());
    }

    /// Identical substitutions with an explicit trailing arm.
    #[test]
    fn test_basic() {
        let input = r#"
            T = [Int, Real, Double],
            match foo() {
                EvalType::T => { panic!("{}", EvalType::T); },
                EvalType::Other => unreachable!(),
            }
        "#;

        let expect_output = r#"
            match foo() {
                EvalType::Int => { panic!("{}", EvalType::Int); },
                EvalType::Real => { panic!("{}", EvalType::Real); },
                EvalType::Double => { panic!("{}", EvalType::Double); },
                EvalType::Other => unreachable!(),
            }
        "#;

        assert_expansion(input, expect_output);
    }

    /// A trailing wildcard arm is passed through untouched.
    #[test]
    fn test_wildcard() {
        let input = r#"
            TT = [Foo, Bar],
            match v {
                VectorValue::TT => EvalType::TT,
                _ => unreachable!(),
            }
        "#;

        let expect_output = r#"
            match v {
                VectorValue::Foo => EvalType::Foo,
                VectorValue::Bar => EvalType::Bar,
                _ => unreachable!(),
            }
        "#;

        assert_expansion(input, expect_output);
    }

    /// `left => right` substitutions use different tokens in pattern and body.
    #[test]
    fn test_map() {
        let input = r#"
            TT = [Foo, Bar => Baz, Bark => <&'static Whooh>()],
            match v {
                VectorValue::TT => EvalType::TT,
                EvalType::Other => unreachable!(),
            }
        "#;

        let expect_output = r#"
            match v {
                VectorValue::Foo => EvalType::Foo,
                VectorValue::Bar => EvalType::Baz,
                VectorValue::Bark => EvalType:: < & 'static Whooh>(),
                EvalType::Other => unreachable!(),
            }
        "#;

        assert_expansion(input, expect_output);
    }
}