1#![forbid(unsafe_code)]
16#![deny(elided_lifetimes_in_paths)]
17#![deny(unreachable_pub)]
18
19use std::iter::FromIterator;
78
79use proc_macro::TokenStream;
80use quote::{quote, ToTokens};
81use syn::punctuated::Punctuated;
82use syn::spanned::Spanned;
83use syn::{parse_quote, DeriveInput, Token};
84
85#[proc_macro_derive(EnumTemplate, attributes(template))]
89pub fn derive_enum_template(input: TokenStream) -> TokenStream {
90 let ast: syn::DeriveInput = syn::parse(input).unwrap();
91
92 let data = match &ast.data {
93 syn::Data::Enum(data) => data,
94 syn::Data::Struct(data) => {
95 return fail_at(
96 data.struct_token,
97 "#[derive(EnumTemplate)] can only be used with enums",
98 );
99 }
100 syn::Data::Union(data) => {
101 return fail_at(
102 data.union_token,
103 "#[derive(EnumTemplate)] can only be used with enums",
104 );
105 }
106 };
107
108 let mut global_meta = None;
109 for attr in &ast.attrs {
110 let meta_list = match attr.parse_meta() {
111 Ok(syn::Meta::List(attr)) => attr,
112 _ => continue,
113 };
114 if meta_list.path.is_ident("template") {
115 if global_meta.is_some() {
116 return fail_at(
117 meta_list.path,
118 "cannot have more than one #[template] attribute for a type",
119 );
120 }
121 global_meta = Some(attr);
122 }
123 }
124
125 let mut default_variant_name = None;
126 let variant_definitions =
127 make_variant_definitions(global_meta, &ast, data, &mut default_variant_name);
128 let variant_definitions = match variant_definitions {
129 Ok(variant_definitions) => variant_definitions,
130 Err(err) => return err,
131 };
132 let match_render_impl = make_render_impl(&ast, data, "render", Punctuated::new());
133 let match_render_into_impl = make_render_impl(
134 &ast,
135 data,
136 "render_into",
137 Punctuated::from_iter([syn::Expr::Path(parse_quote!(writer))]),
138 );
139 let dflt_or_fst_variant_name =
140 default_variant_name.unwrap_or_else(|| variant_definitions[0].ident.clone());
141
142 let mut static_ty_generics = quote!(::<);
143 for g in ast.generics.params.iter() {
144 match g {
145 syn::GenericParam::Type(param) => {
146 param.ident.to_tokens(&mut static_ty_generics);
147 }
148 syn::GenericParam::Const(param) => {
149 param.ident.to_tokens(&mut static_ty_generics);
150 }
151 _ => (),
152 }
153 }
154 static_ty_generics.extend(quote!(>));
155
156 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
157 let enum_name = &ast.ident;
158 let mut result = quote! {
159 impl #impl_generics askama::Template for #enum_name #ty_generics #where_clause {
160 fn render(&self) -> askama::Result<::std::string::String> {
161 #match_render_impl
162 }
163
164 fn render_into(
165 &self,
166 writer: &mut (impl ::std::fmt::Write + ?::std::marker::Sized),
167 ) -> askama::Result<()> {
168 #match_render_into_impl
169 }
170
171 const EXTENSION: ::std::option::Option<&'static str> =
172 <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::EXTENSION;
173 const SIZE_HINT: ::std::primitive::usize =
174 <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::SIZE_HINT;
175 const MIME_TYPE: &'static ::std::primitive::str =
176 <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::MIME_TYPE;
177 }
178 };
179 for variant_definition in variant_definitions {
180 variant_definition.to_tokens(&mut result);
181 }
182 let result = quote! {
183 #[allow(non_camel_case_types, non_snake_case, unused_qualifications)]
184 const _: () = {
185 #result
186
187 impl #impl_generics ::std::fmt::Display for #enum_name #ty_generics #where_clause {
188 #[inline]
189 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
190 askama::Template::render_into(self, f).map_err(|_| ::std::fmt::Error {})
191 }
192 }
193 };
194 };
195 result.into()
196}
197
198fn make_render_impl(
199 ast: &DeriveInput,
200 data: &syn::DataEnum,
201 meth_name: &'static str,
202 args: Punctuated<syn::Expr, syn::token::Comma>,
203) -> syn::ExprMatch {
204 let mut generics = ast.generics.clone();
205 generics.params.push(parse_quote!('_));
206 let (_, inst_ty_generics, _) = generics.split_for_impl();
207 let inst_ty_generics = inst_ty_generics.as_turbofish();
208
209 let match_render_impl = data
210 .variants
211 .iter()
212 .enumerate()
213 .map(|(index, variant)| {
214 let self_variant_name = &variant.ident;
215
216 let variant_name = &format!("_{}_{}_{}", &ast.ident, index, variant.ident);
217 let variant_span = variant.ident.span();
218 let variant_name = syn::Ident::new(variant_name, variant_span);
219
220 let (pat, base) = match &variant.fields {
221 syn::Fields::Named(fields) => {
222 let tmp_names = fields
223 .named
224 .iter()
225 .enumerate()
226 .map(|(index, field)| syn::Ident::new(&format!("_{}", index), field.span()))
227 .collect::<Vec<_>>();
228
229 let source_elems = tmp_names
230 .iter()
231 .zip(fields.named.iter())
232 .map(|(dest, source)| syn::FieldPat {
233 attrs: vec![],
234 member: syn::Member::Named(source.ident.clone().unwrap()),
235 colon_token: Some(Token),
236 pat: parse_quote!(#dest),
237 })
238 .collect();
239 let pat = syn::Pat::Struct(syn::PatStruct {
240 attrs: vec![],
241 path: parse_quote!(Self::#self_variant_name),
242 brace_token: syn::token::Brace(variant_span),
243 fields: source_elems,
244 dot2_token: None,
245 });
246
247 let mut fields = tmp_names
248 .iter()
249 .zip(fields.named.iter())
250 .map(|(tmp, source)| syn::FieldValue {
251 attrs: vec![],
252 member: syn::Member::Named(source.ident.clone().unwrap()),
253 colon_token: Some(Token),
254 expr: parse_quote!(#tmp),
255 })
256 .collect::<Punctuated<syn::FieldValue, Token![,]>>();
257 fields.push(parse_quote!(#variant_name: ::std::marker::PhantomData));
258 let base = syn::Expr::Struct(syn::ExprStruct {
259 attrs: vec![],
260 path: parse_quote!(#variant_name #inst_ty_generics),
261 brace_token: syn::token::Brace(variant_span),
262 fields,
263 dot2_token: None,
264 rest: None,
265 });
266
267 (pat, base)
268 }
269 syn::Fields::Unnamed(fields) => {
270 let tmp_names = fields
271 .unnamed
272 .iter()
273 .enumerate()
274 .map(|(index, field)| syn::Ident::new(&format!("_{}", index), field.span()))
275 .collect::<Vec<_>>();
276
277 let source_elems = tmp_names
278 .iter()
279 .map(|ident| {
280 syn::Pat::Ident(syn::PatIdent {
281 attrs: vec![],
282 by_ref: None,
283 mutability: None,
284 ident: ident.clone(),
285 subpat: None,
286 })
287 })
288 .collect();
289 let pat = syn::Pat::TupleStruct(syn::PatTupleStruct {
290 attrs: vec![],
291 path: parse_quote!(Self::#self_variant_name),
292 pat: syn::PatTuple {
293 attrs: vec![],
294 paren_token: syn::token::Paren(variant_span),
295 elems: source_elems,
296 },
297 });
298
299 let mut args = tmp_names
300 .iter()
301 .map(|field_name| {
302 let expr: syn::Expr = parse_quote!(#field_name);
303 expr
304 })
305 .collect::<Punctuated<syn::Expr, Token![,]>>();
306 args.push(parse_quote!(::std::marker::PhantomData));
307 let base = syn::Expr::Call(syn::ExprCall {
308 attrs: vec![],
309 func: parse_quote!(#variant_name #inst_ty_generics),
310 paren_token: syn::token::Paren(variant_span),
311 args,
312 });
313
314 (pat, base)
315 }
316 syn::Fields::Unit => {
317 let pat = parse_quote!(Self :: #self_variant_name);
318 let base =
319 parse_quote!(#variant_name #inst_ty_generics(::std::marker::PhantomData));
320 (pat, base)
321 }
322 };
323 let field = syn::Expr::Field(syn::ExprField {
324 attrs: vec![],
325 base: Box::new(base),
326 dot_token: Token,
327 member: syn::Member::Named(syn::Ident::new(meth_name, variant_span)),
328 });
329 let call = syn::Expr::Call(syn::ExprCall {
330 attrs: vec![],
331 func: field.into(),
332 paren_token: syn::token::Paren(variant_span),
333 args: args.clone(),
334 });
335 syn::Arm {
336 attrs: vec![],
337 pat,
338 guard: None,
339 fat_arrow_token: Token,
340 body: call.into(),
341 comma: Some(Token),
342 }
343 })
344 .collect();
345 syn::ExprMatch {
346 attrs: vec![],
347 match_token: Token,
348 expr: parse_quote!(self),
349 brace_token: syn::token::Brace(data.brace_token.span),
350 arms: match_render_impl,
351 }
352}
353
354fn make_variant_definitions(
355 global_meta: Option<&syn::Attribute>,
356 ast: &DeriveInput,
357 data: &syn::DataEnum,
358 default_variant_name: &mut Option<syn::Ident>,
359) -> Result<Vec<syn::DeriveInput>, TokenStream> {
360 data.variants
361 .iter()
362 .enumerate()
363 .map(|(index, variant)| {
364 let variant_name = &format!("_{}_{}_{}", &ast.ident, index, variant.ident);
365 let variant_span = variant.ident.span();
366 let variant_lifetime = syn::Lifetime::new(&format!("'{}", variant_name), variant_span);
367 let variant_name = syn::Ident::new(variant_name, variant_span);
368
369 let mut local_meta = None;
370 for attr in &variant.attrs {
371 let meta_list = match attr.parse_meta() {
372 Ok(syn::Meta::List(attr)) => attr,
373 _ => continue,
374 };
375 if meta_list.path.is_ident("template") {
376 if local_meta.is_some() {
377 return Err(fail_at(
378 meta_list.path,
379 "cannot have more than one #[template] attribute for a variant",
380 ));
381 }
382 local_meta = Some(attr);
383 }
384 }
385 if local_meta.is_none() && default_variant_name.is_none() {
386 *default_variant_name = Some(variant_name.clone());
387 }
388 let meta = match local_meta.or(global_meta) {
389 Some(meta) => meta,
390 None => return Err(fail_at(&variant.ident, "need a #[template] attribute")),
391 };
392
393 let (_, ty_generics, _) = ast.generics.split_for_impl();
394 let enum_name = &ast.ident;
395 let phantom_type = parse_quote!(::std::marker::PhantomData::<
396 & #variant_lifetime #enum_name #ty_generics,
397 >);
398 let fields = match &variant.fields {
399 syn::Fields::Named(fields) => {
400 let mut fields = fields
401 .named
402 .iter()
403 .map(|field| {
404 let mut field = field.clone();
405 field.ty = syn::Type::Reference(syn::TypeReference {
406 and_token: Token),
407 lifetime: Some(variant_lifetime.clone()),
408 mutability: None,
409 elem: field.ty.into(),
410 });
411 field
412 })
413 .collect::<Vec<syn::Field>>();
414 fields.push(syn::Field {
415 attrs: vec![],
416 vis: syn::Visibility::Inherited,
417 ident: Some(variant_name.clone()),
418 colon_token: Some(Token),
419 ty: phantom_type,
420 });
421 syn::Fields::Named(syn::FieldsNamed {
422 brace_token: syn::token::Brace(variant_span),
423 named: Punctuated::from_iter(fields),
424 })
425 }
426 syn::Fields::Unnamed(fields) => {
427 let mut fields = fields
428 .unnamed
429 .iter()
430 .map(|field| {
431 let mut field = field.clone();
432 field.ty = syn::Type::Reference(syn::TypeReference {
433 and_token: Token),
434 lifetime: Some(variant_lifetime.clone()),
435 mutability: None,
436 elem: field.ty.into(),
437 });
438 field
439 })
440 .collect::<Vec<syn::Field>>();
441 fields.push(syn::Field {
442 attrs: vec![],
443 vis: syn::Visibility::Inherited,
444 ident: None,
445 colon_token: None,
446 ty: phantom_type,
447 });
448 syn::Fields::Unnamed(syn::FieldsUnnamed {
449 paren_token: syn::token::Paren(variant_span),
450 unnamed: Punctuated::from_iter(fields),
451 })
452 }
453 syn::Fields::Unit => syn::Fields::Unnamed(syn::FieldsUnnamed {
454 paren_token: syn::token::Paren(variant_span),
455 unnamed: Punctuated::from_iter([syn::Field {
456 attrs: vec![],
457 vis: syn::Visibility::Inherited,
458 ident: None,
459 colon_token: None,
460 ty: phantom_type,
461 }]),
462 }),
463 };
464
465 let mut generics = ast.generics.clone();
466 generics.params.push(parse_quote!(#variant_lifetime));
467 Ok(syn::DeriveInput {
468 attrs: vec![
469 parse_quote!(#[::std::prelude::v1::derive(
470 askama::Template,
471 ::std::prelude::v1::Clone,
472 ::std::prelude::v1::Copy,
473 ::std::prelude::v1::Debug,
474 )]),
475 meta.clone(),
476 ],
477 vis: syn::Visibility::Inherited,
478 ident: variant_name,
479 generics,
480 data: syn::Data::Struct(syn::DataStruct {
481 struct_token: Token,
482 fields,
483 semi_token: None,
484 }),
485 })
486 })
487 .collect()
488}
489
490fn fail_at(spanned: impl Spanned, msg: &str) -> TokenStream {
491 syn::Error::new(spanned.span(), msg)
492 .into_compile_error()
493 .into()
494}