1use proc_macro2::{Group, Ident, TokenStream, TokenTree};
67use quote::{quote, ToTokens};
68use syn::{
69 bracketed,
70 parse::{Parse, ParseStream},
71 parse_macro_input,
72 punctuated::Punctuated,
73 Arm, Expr, ExprMatch, Pat, Token,
74};
75
76#[proc_macro]
80pub fn match_template(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
81 let mt = parse_macro_input!(input as MatchTemplate);
82 mt.expand().into()
83}
84
85struct MatchTemplate {
86 template_ident: Ident,
87 substitutes: Punctuated<Substitution, Token![,]>,
88 match_exp: Box<Expr>,
89 template_arm: Arm,
90 remaining_arms: Vec<Arm>,
91}
92
93impl Parse for MatchTemplate {
94 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
95 let template_ident = input.parse()?;
96 input.parse::<Token![=]>()?;
97 let substitutes_tokens;
98 bracketed!(substitutes_tokens in input);
99 let substitutes =
100 Punctuated::<Substitution, Token![,]>::parse_terminated(&substitutes_tokens)?;
101 input.parse::<Token![,]>()?;
102 let m: ExprMatch = input.parse()?;
103 let mut arms = m.arms;
104 arms.iter_mut().for_each(|arm| arm.comma = None);
105 assert!(!arms.is_empty(), "Expect at least 1 match arm");
106 let template_arm = arms.remove(0);
107 assert!(template_arm.guard.is_none(), "Expect no match arm guard");
108
109 Ok(Self {
110 template_ident,
111 substitutes,
112 match_exp: m.expr,
113 template_arm,
114 remaining_arms: arms,
115 })
116 }
117}
118
119impl MatchTemplate {
120 fn expand(self) -> TokenStream {
121 let Self {
122 template_ident,
123 substitutes,
124 match_exp,
125 template_arm,
126 remaining_arms,
127 } = self;
128 let match_arms = substitutes.into_iter().map(|substitute| {
129 let mut arm = template_arm.clone();
130 let (left_tokens, right_tokens) = match substitute {
131 Substitution::Identical(ident) => {
132 (ident.clone().into_token_stream(), ident.into_token_stream())
133 }
134 Substitution::Map(left_ident, right_tokens) => {
135 (left_ident.into_token_stream(), right_tokens)
136 }
137 };
138 arm.pat = replace_in_token_stream(
139 arm.pat,
140 Pat::parse_multi_with_leading_vert,
141 &template_ident,
142 &left_tokens,
143 );
144 arm.body =
145 replace_in_token_stream(arm.body, Parse::parse, &template_ident, &right_tokens);
146 arm
147 });
148 quote! {
149 match #match_exp {
150 #(#match_arms,)*
151 #(#remaining_arms,)*
152 }
153 }
154 }
155}
156
157#[derive(Debug)]
158enum Substitution {
159 Identical(Ident),
160 Map(Ident, TokenStream),
161}
162
163impl Parse for Substitution {
164 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
165 let left_ident = input.parse()?;
166 let fat_arrow: Option<Token![=>]> = input.parse()?;
167 if fat_arrow.is_some() {
168 let mut right_tokens: Vec<TokenTree> = vec![];
169 while !input.peek(Token![,]) && !input.is_empty() {
170 right_tokens.push(input.parse()?);
171 }
172 Ok(Substitution::Map(
173 left_ident,
174 right_tokens.into_iter().collect(),
175 ))
176 } else {
177 Ok(Substitution::Identical(left_ident))
178 }
179 }
180}
181
182fn replace_in_token_stream<T: ToTokens, P: Fn(ParseStream) -> syn::Result<T>>(
183 input: T,
184 parse: P,
185 from_ident: &Ident,
186 to_tokens: &TokenStream,
187) -> T {
188 let mut tokens = TokenStream::new();
189 input.to_tokens(&mut tokens);
190
191 let tokens: TokenStream = tokens
192 .into_iter()
193 .flat_map(|token| match token {
194 TokenTree::Ident(ident) if ident == *from_ident => to_tokens.clone(),
195 TokenTree::Group(group) => Group::new(
196 group.delimiter(),
197 replace_in_token_stream(group.stream(), Parse::parse, from_ident, to_tokens),
198 )
199 .into_token_stream(),
200 other => other.into(),
201 })
202 .collect();
203
204 syn::parse::Parser::parse2(parse, tokens).unwrap()
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a `MatchTemplate`, expands it, and asserts that the
    /// output is token-for-token identical to `expected`.
    fn assert_expands_to(input: &str, expected: &str) {
        let expected_stream: TokenStream = expected.parse().unwrap();
        let template: MatchTemplate = syn::parse_str(input).unwrap();
        assert_eq!(template.expand().to_string(), expected_stream.to_string());
    }

    #[test]
    fn test_basic() {
        assert_expands_to(
            r#"
            T = [Int, Real, Double],
            match foo() {
                EvalType::T => { panic!("{}", EvalType::T); },
                EvalType::Other => unreachable!(),
            }
            "#,
            r#"
            match foo() {
                EvalType::Int => { panic!("{}", EvalType::Int); },
                EvalType::Real => { panic!("{}", EvalType::Real); },
                EvalType::Double => { panic!("{}", EvalType::Double); },
                EvalType::Other => unreachable!(),
            }
            "#,
        );
    }

    #[test]
    fn test_wildcard() {
        assert_expands_to(
            r#"
            TT = [Foo, Bar],
            match v {
                VectorValue::TT => EvalType::TT,
                _ => unreachable!(),
            }
            "#,
            r#"
            match v {
                VectorValue::Foo => EvalType::Foo,
                VectorValue::Bar => EvalType::Bar,
                _ => unreachable!(),
            }
            "#,
        );
    }

    #[test]
    fn test_map() {
        assert_expands_to(
            r#"
            TT = [Foo, Bar => Baz, Bark => <&'static Whooh>()],
            match v {
                VectorValue::TT => EvalType::TT,
                EvalType::Other => unreachable!(),
            }
            "#,
            r#"
            match v {
                VectorValue::Foo => EvalType::Foo,
                VectorValue::Bar => EvalType::Baz,
                VectorValue::Bark => EvalType:: < & 'static Whooh>(),
                EvalType::Other => unreachable!(),
            }
            "#,
        );
    }
}