askama_enum/
lib.rs

1// Copyright © 2022 René Kijewski <crates.io@k6i.de>
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
8// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
9// AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
10// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
11// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
12// OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
13// PERFORMANCE OF THIS SOFTWARE.
14
15#![forbid(unsafe_code)]
16#![deny(elided_lifetimes_in_paths)]
17#![deny(unreachable_pub)]
18
19//! ## askama-enum
20//!
21//! [![GitHub Workflow Status](https://img.shields.io/github/workflow/status/Kijewski/askama-enum/CI?logo=github)](https://github.com/Kijewski/askama-enum/actions/workflows/ci.yml)
22//! [![Crates.io](https://img.shields.io/crates/v/askama-enum?logo=rust)](https://crates.io/crates/askama-enum)
23//! ![Minimum supported Rust version](https://img.shields.io/badge/rustc-1.53+-important?logo=rust "Minimum Supported Rust Version")
24//! ![License](https://img.shields.io/badge/license-ISC%2FMIT%2FApache--2.0%20WITH%20LLVM--exception-informational?logo=apache)
25//!
26//! Implement different [Askama](https://crates.io/crates/askama) templates for different enum variants.
27//!
28//! You can add a default `#[template]` for variants that don't have a specific `#[template]` attribute.
29//! If omitted, then every variant needs its own `#[template]` attribute.
30//! The `#[template]` attribute is not interpreted, but simply copied to be used by askama.
31//!
32//! ```rust
33//! # #[cfg(feature = "askama")] fn main() {
34//! # use askama_enum::EnumTemplate;
35//! #[derive(EnumTemplate)]
36//! #[template(ext = "html", source = "default")] // default, optional
37//! enum MyEnum<'a, T: std::fmt::Display> {
38//!     // uses the default `#[template]`
39//!     A,
40//!
41//!     // uses specific `#[template]`
42//!     #[template(ext = "html", source = "B")]
43//!     B,
44//!
45//!     // you can use tuple structs
46//!     #[template(
47//!         ext = "html",
48//!         source = "{{self.0}} {{self.1}} {{self.2}} {{self.3}}",
49//!     )]
50//!     C(u8, &'a u16, u32, &'a u64),
51//!
52//!     // and named fields, too
53//!     #[template(ext = "html", source = "{{some}} {{fields}}")]
54//!     D { some: T, fields: T },
55//! }
56//!
57//! assert_eq!(
58//!     MyEnum::A::<&str>.to_string(),
59//!     "default",
60//! );
61//! assert_eq!(
62//!     MyEnum::B::<&str>.to_string(),
63//!     "B",
64//! );
65//! assert_eq!(
66//!     MyEnum::C::<&str>(1, &2, 3, &4).to_string(),
67//!     "1 2 3 4",
68//! );
69//! assert_eq!(
70//!     MyEnum::D { some: "some", fields: "fields" }.to_string(),
71//!     "some fields",
72//! );
73//! # }
74//! ```
75//!
76
77use std::iter::FromIterator;
78
79use proc_macro::TokenStream;
80use quote::{quote, ToTokens};
81use syn::punctuated::Punctuated;
82use syn::spanned::Spanned;
83use syn::{parse_quote, DeriveInput, Token};
84
85/// Implement different Askama templates for different enum variants
86///
87/// Please see the [crate] documentation for more examples.
88#[proc_macro_derive(EnumTemplate, attributes(template))]
89pub fn derive_enum_template(input: TokenStream) -> TokenStream {
90    let ast: syn::DeriveInput = syn::parse(input).unwrap();
91
92    let data = match &ast.data {
93        syn::Data::Enum(data) => data,
94        syn::Data::Struct(data) => {
95            return fail_at(
96                data.struct_token,
97                "#[derive(EnumTemplate)] can only be used with enums",
98            );
99        }
100        syn::Data::Union(data) => {
101            return fail_at(
102                data.union_token,
103                "#[derive(EnumTemplate)] can only be used with enums",
104            );
105        }
106    };
107
108    let mut global_meta = None;
109    for attr in &ast.attrs {
110        let meta_list = match attr.parse_meta() {
111            Ok(syn::Meta::List(attr)) => attr,
112            _ => continue,
113        };
114        if meta_list.path.is_ident("template") {
115            if global_meta.is_some() {
116                return fail_at(
117                    meta_list.path,
118                    "cannot have more than one #[template] attribute for a type",
119                );
120            }
121            global_meta = Some(attr);
122        }
123    }
124
125    let mut default_variant_name = None;
126    let variant_definitions =
127        make_variant_definitions(global_meta, &ast, data, &mut default_variant_name);
128    let variant_definitions = match variant_definitions {
129        Ok(variant_definitions) => variant_definitions,
130        Err(err) => return err,
131    };
132    let match_render_impl = make_render_impl(&ast, data, "render", Punctuated::new());
133    let match_render_into_impl = make_render_impl(
134        &ast,
135        data,
136        "render_into",
137        Punctuated::from_iter([syn::Expr::Path(parse_quote!(writer))]),
138    );
139    let dflt_or_fst_variant_name =
140        default_variant_name.unwrap_or_else(|| variant_definitions[0].ident.clone());
141
142    let mut static_ty_generics = quote!(::<);
143    for g in ast.generics.params.iter() {
144        match g {
145            syn::GenericParam::Type(param) => {
146                param.ident.to_tokens(&mut static_ty_generics);
147            }
148            syn::GenericParam::Const(param) => {
149                param.ident.to_tokens(&mut static_ty_generics);
150            }
151            _ => (),
152        }
153    }
154    static_ty_generics.extend(quote!(>));
155
156    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
157    let enum_name = &ast.ident;
158    let mut result = quote! {
159        impl #impl_generics askama::Template for #enum_name #ty_generics #where_clause {
160            fn render(&self) -> askama::Result<::std::string::String> {
161                #match_render_impl
162            }
163
164            fn render_into(
165                &self,
166                writer: &mut (impl ::std::fmt::Write + ?::std::marker::Sized),
167            ) -> askama::Result<()> {
168                #match_render_into_impl
169            }
170
171            const EXTENSION: ::std::option::Option<&'static str> =
172                <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::EXTENSION;
173            const SIZE_HINT: ::std::primitive::usize =
174                <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::SIZE_HINT;
175            const MIME_TYPE: &'static ::std::primitive::str =
176                <#dflt_or_fst_variant_name #static_ty_generics as askama::Template>::MIME_TYPE;
177        }
178    };
179    for variant_definition in variant_definitions {
180        variant_definition.to_tokens(&mut result);
181    }
182    let result = quote! {
183        #[allow(non_camel_case_types, non_snake_case, unused_qualifications)]
184        const _: () = {
185            #result
186
187            impl #impl_generics ::std::fmt::Display for #enum_name #ty_generics #where_clause {
188                #[inline]
189                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
190                    askama::Template::render_into(self, f).map_err(|_| ::std::fmt::Error {})
191                }
192            }
193        };
194    };
195    result.into()
196}
197
198fn make_render_impl(
199    ast: &DeriveInput,
200    data: &syn::DataEnum,
201    meth_name: &'static str,
202    args: Punctuated<syn::Expr, syn::token::Comma>,
203) -> syn::ExprMatch {
204    let mut generics = ast.generics.clone();
205    generics.params.push(parse_quote!('_));
206    let (_, inst_ty_generics, _) = generics.split_for_impl();
207    let inst_ty_generics = inst_ty_generics.as_turbofish();
208
209    let match_render_impl = data
210        .variants
211        .iter()
212        .enumerate()
213        .map(|(index, variant)| {
214            let self_variant_name = &variant.ident;
215
216            let variant_name = &format!("_{}_{}_{}", &ast.ident, index, variant.ident);
217            let variant_span = variant.ident.span();
218            let variant_name = syn::Ident::new(variant_name, variant_span);
219
220            let (pat, base) = match &variant.fields {
221                syn::Fields::Named(fields) => {
222                    let tmp_names = fields
223                        .named
224                        .iter()
225                        .enumerate()
226                        .map(|(index, field)| syn::Ident::new(&format!("_{}", index), field.span()))
227                        .collect::<Vec<_>>();
228
229                    let source_elems = tmp_names
230                        .iter()
231                        .zip(fields.named.iter())
232                        .map(|(dest, source)| syn::FieldPat {
233                            attrs: vec![],
234                            member: syn::Member::Named(source.ident.clone().unwrap()),
235                            colon_token: Some(Token![:](variant_span)),
236                            pat: parse_quote!(#dest),
237                        })
238                        .collect();
239                    let pat = syn::Pat::Struct(syn::PatStruct {
240                        attrs: vec![],
241                        path: parse_quote!(Self::#self_variant_name),
242                        brace_token: syn::token::Brace(variant_span),
243                        fields: source_elems,
244                        dot2_token: None,
245                    });
246
247                    let mut fields = tmp_names
248                        .iter()
249                        .zip(fields.named.iter())
250                        .map(|(tmp, source)| syn::FieldValue {
251                            attrs: vec![],
252                            member: syn::Member::Named(source.ident.clone().unwrap()),
253                            colon_token: Some(Token![:](variant_span)),
254                            expr: parse_quote!(#tmp),
255                        })
256                        .collect::<Punctuated<syn::FieldValue, Token![,]>>();
257                    fields.push(parse_quote!(#variant_name: ::std::marker::PhantomData));
258                    let base = syn::Expr::Struct(syn::ExprStruct {
259                        attrs: vec![],
260                        path: parse_quote!(#variant_name #inst_ty_generics),
261                        brace_token: syn::token::Brace(variant_span),
262                        fields,
263                        dot2_token: None,
264                        rest: None,
265                    });
266
267                    (pat, base)
268                }
269                syn::Fields::Unnamed(fields) => {
270                    let tmp_names = fields
271                        .unnamed
272                        .iter()
273                        .enumerate()
274                        .map(|(index, field)| syn::Ident::new(&format!("_{}", index), field.span()))
275                        .collect::<Vec<_>>();
276
277                    let source_elems = tmp_names
278                        .iter()
279                        .map(|ident| {
280                            syn::Pat::Ident(syn::PatIdent {
281                                attrs: vec![],
282                                by_ref: None,
283                                mutability: None,
284                                ident: ident.clone(),
285                                subpat: None,
286                            })
287                        })
288                        .collect();
289                    let pat = syn::Pat::TupleStruct(syn::PatTupleStruct {
290                        attrs: vec![],
291                        path: parse_quote!(Self::#self_variant_name),
292                        pat: syn::PatTuple {
293                            attrs: vec![],
294                            paren_token: syn::token::Paren(variant_span),
295                            elems: source_elems,
296                        },
297                    });
298
299                    let mut args = tmp_names
300                        .iter()
301                        .map(|field_name| {
302                            let expr: syn::Expr = parse_quote!(#field_name);
303                            expr
304                        })
305                        .collect::<Punctuated<syn::Expr, Token![,]>>();
306                    args.push(parse_quote!(::std::marker::PhantomData));
307                    let base = syn::Expr::Call(syn::ExprCall {
308                        attrs: vec![],
309                        func: parse_quote!(#variant_name #inst_ty_generics),
310                        paren_token: syn::token::Paren(variant_span),
311                        args,
312                    });
313
314                    (pat, base)
315                }
316                syn::Fields::Unit => {
317                    let pat = parse_quote!(Self :: #self_variant_name);
318                    let base =
319                        parse_quote!(#variant_name #inst_ty_generics(::std::marker::PhantomData));
320                    (pat, base)
321                }
322            };
323            let field = syn::Expr::Field(syn::ExprField {
324                attrs: vec![],
325                base: Box::new(base),
326                dot_token: Token![.](variant_span),
327                member: syn::Member::Named(syn::Ident::new(meth_name, variant_span)),
328            });
329            let call = syn::Expr::Call(syn::ExprCall {
330                attrs: vec![],
331                func: field.into(),
332                paren_token: syn::token::Paren(variant_span),
333                args: args.clone(),
334            });
335            syn::Arm {
336                attrs: vec![],
337                pat,
338                guard: None,
339                fat_arrow_token: Token![=>](variant_span),
340                body: call.into(),
341                comma: Some(Token![,](variant_span)),
342            }
343        })
344        .collect();
345    syn::ExprMatch {
346        attrs: vec![],
347        match_token: Token![match](data.brace_token.span),
348        expr: parse_quote!(self),
349        brace_token: syn::token::Brace(data.brace_token.span),
350        arms: match_render_impl,
351    }
352}
353
354fn make_variant_definitions(
355    global_meta: Option<&syn::Attribute>,
356    ast: &DeriveInput,
357    data: &syn::DataEnum,
358    default_variant_name: &mut Option<syn::Ident>,
359) -> Result<Vec<syn::DeriveInput>, TokenStream> {
360    data.variants
361        .iter()
362        .enumerate()
363        .map(|(index, variant)| {
364            let variant_name = &format!("_{}_{}_{}", &ast.ident, index, variant.ident);
365            let variant_span = variant.ident.span();
366            let variant_lifetime = syn::Lifetime::new(&format!("'{}", variant_name), variant_span);
367            let variant_name = syn::Ident::new(variant_name, variant_span);
368
369            let mut local_meta = None;
370            for attr in &variant.attrs {
371                let meta_list = match attr.parse_meta() {
372                    Ok(syn::Meta::List(attr)) => attr,
373                    _ => continue,
374                };
375                if meta_list.path.is_ident("template") {
376                    if local_meta.is_some() {
377                        return Err(fail_at(
378                            meta_list.path,
379                            "cannot have more than one #[template] attribute for a variant",
380                        ));
381                    }
382                    local_meta = Some(attr);
383                }
384            }
385            if local_meta.is_none() && default_variant_name.is_none() {
386                *default_variant_name = Some(variant_name.clone());
387            }
388            let meta = match local_meta.or(global_meta) {
389                Some(meta) => meta,
390                None => return Err(fail_at(&variant.ident, "need a #[template] attribute")),
391            };
392
393            let (_, ty_generics, _) = ast.generics.split_for_impl();
394            let enum_name = &ast.ident;
395            let phantom_type = parse_quote!(::std::marker::PhantomData::<
396                & #variant_lifetime #enum_name #ty_generics,
397            >);
398            let fields = match &variant.fields {
399                syn::Fields::Named(fields) => {
400                    let mut fields = fields
401                        .named
402                        .iter()
403                        .map(|field| {
404                            let mut field = field.clone();
405                            field.ty = syn::Type::Reference(syn::TypeReference {
406                                and_token: Token![&](field.span()),
407                                lifetime: Some(variant_lifetime.clone()),
408                                mutability: None,
409                                elem: field.ty.into(),
410                            });
411                            field
412                        })
413                        .collect::<Vec<syn::Field>>();
414                    fields.push(syn::Field {
415                        attrs: vec![],
416                        vis: syn::Visibility::Inherited,
417                        ident: Some(variant_name.clone()),
418                        colon_token: Some(Token![:](variant_span)),
419                        ty: phantom_type,
420                    });
421                    syn::Fields::Named(syn::FieldsNamed {
422                        brace_token: syn::token::Brace(variant_span),
423                        named: Punctuated::from_iter(fields),
424                    })
425                }
426                syn::Fields::Unnamed(fields) => {
427                    let mut fields = fields
428                        .unnamed
429                        .iter()
430                        .map(|field| {
431                            let mut field = field.clone();
432                            field.ty = syn::Type::Reference(syn::TypeReference {
433                                and_token: Token![&](field.span()),
434                                lifetime: Some(variant_lifetime.clone()),
435                                mutability: None,
436                                elem: field.ty.into(),
437                            });
438                            field
439                        })
440                        .collect::<Vec<syn::Field>>();
441                    fields.push(syn::Field {
442                        attrs: vec![],
443                        vis: syn::Visibility::Inherited,
444                        ident: None,
445                        colon_token: None,
446                        ty: phantom_type,
447                    });
448                    syn::Fields::Unnamed(syn::FieldsUnnamed {
449                        paren_token: syn::token::Paren(variant_span),
450                        unnamed: Punctuated::from_iter(fields),
451                    })
452                }
453                syn::Fields::Unit => syn::Fields::Unnamed(syn::FieldsUnnamed {
454                    paren_token: syn::token::Paren(variant_span),
455                    unnamed: Punctuated::from_iter([syn::Field {
456                        attrs: vec![],
457                        vis: syn::Visibility::Inherited,
458                        ident: None,
459                        colon_token: None,
460                        ty: phantom_type,
461                    }]),
462                }),
463            };
464
465            let mut generics = ast.generics.clone();
466            generics.params.push(parse_quote!(#variant_lifetime));
467            Ok(syn::DeriveInput {
468                attrs: vec![
469                    parse_quote!(#[::std::prelude::v1::derive(
470                        askama::Template,
471                        ::std::prelude::v1::Clone,
472                        ::std::prelude::v1::Copy,
473                        ::std::prelude::v1::Debug,
474                    )]),
475                    meta.clone(),
476                ],
477                vis: syn::Visibility::Inherited,
478                ident: variant_name,
479                generics,
480                data: syn::Data::Struct(syn::DataStruct {
481                    struct_token: Token![struct](variant_span),
482                    fields,
483                    semi_token: None,
484                }),
485            })
486        })
487        .collect()
488}
489
490fn fail_at(spanned: impl Spanned, msg: &str) -> TokenStream {
491    syn::Error::new(spanned.span(), msg)
492        .into_compile_error()
493        .into()
494}