is_macro/
lib.rs

1extern crate proc_macro;
2
3use heck::ToSnakeCase;
4use proc_macro2::Span;
5use quote::{quote, ToTokens};
6use syn::{
7    parse,
8    parse::Parse,
9    parse2, parse_quote,
10    punctuated::{Pair, Punctuated},
11    spanned::Spanned,
12    Data, DataEnum, DeriveInput, Expr, ExprLit, Field, Fields, Generics, Ident, ImplItem, ItemImpl,
13    Lit, Meta, MetaNameValue, Path, Token, Type, TypePath, TypeReference, TypeTuple, WhereClause,
14};
15
16/// A proc macro to generate methods like is_variant / expect_variant.
17///
18///
19/// # Example
20///
21/// ```rust
22/// 
23/// use is_macro::Is;
24/// #[derive(Debug, Is)]
25/// pub enum Enum<T> {
26///     A,
27///     B(T),
28///     C(Option<T>),
29/// }
30///
31/// // Rust's type inference cannot handle this.
32/// assert!(Enum::<()>::A.is_a());
33///
34/// assert_eq!(Enum::B(String::from("foo")).b(), Some(String::from("foo")));
35///
36/// assert_eq!(Enum::B(String::from("foo")).expect_b(), String::from("foo"));
37/// ```
38///
39/// # Renaming
40///
41/// ```rust
42/// 
43/// use is_macro::Is;
44/// #[derive(Debug, Is)]
45/// pub enum Enum {
46///     #[is(name = "video_mp4")]
47///     VideoMp4,
48/// }
49///
50/// assert!(Enum::VideoMp4.is_video_mp4());
51/// ```
52#[proc_macro_derive(Is, attributes(is))]
53pub fn is(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54    let input: DeriveInput = syn::parse(input).expect("failed to parse derive input");
55    let generics: Generics = input.generics.clone();
56
57    let items = match input.data {
58        Data::Enum(e) => expand(e),
59        _ => panic!("`Is` can be applied only on enums"),
60    };
61
62    ItemImpl {
63        attrs: vec![],
64        defaultness: None,
65        unsafety: None,
66        impl_token: Default::default(),
67        generics: Default::default(),
68        trait_: None,
69        self_ty: Box::new(Type::Path(TypePath {
70            qself: None,
71            path: Path::from(input.ident),
72        })),
73        brace_token: Default::default(),
74        items,
75    }
76    .with_generics(generics)
77    .into_token_stream()
78    .into()
79}
80
/// Parsed contents of a variant's `#[is(...)]` attribute.
#[derive(Debug)]
struct Input {
    /// Base name used for the generated methods (`is_<name>`, `as_<name>`, ...).
    name: String,
}
85
86impl Parse for Input {
87    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
88        let _: Ident = input.parse()?;
89        let _: Token![=] = input.parse()?;
90
91        let name = input.parse::<ExprLit>()?;
92
93        Ok(Input {
94            name: match name.lit {
95                Lit::Str(s) => s.value(),
96                _ => panic!("is(name = ...) expects a string literal"),
97            },
98        })
99    }
100}
101
102fn expand(input: DataEnum) -> Vec<ImplItem> {
103    let mut items = vec![];
104
105    for v in &input.variants {
106        let attrs = v
107            .attrs
108            .iter()
109            .filter(|attr| attr.path().is_ident("is"))
110            .collect::<Vec<_>>();
111        if attrs.len() >= 2 {
112            panic!("derive(Is) expects no attribute or one attribute")
113        }
114        let i = match attrs.into_iter().next() {
115            None => Input {
116                name: {
117                    v.ident.to_string().to_snake_case()
118                    //
119                },
120            },
121            Some(attr) => {
122                //
123
124                let mut input = Input {
125                    name: Default::default(),
126                };
127
128                let mut apply = |v: &MetaNameValue| {
129                    assert!(
130                        v.path.is_ident("name"),
131                        "Currently, is() only supports `is(name = 'foo')`"
132                    );
133
134                    input.name = match &v.value {
135                        Expr::Lit(ExprLit {
136                            lit: Lit::Str(s), ..
137                        }) => s.value(),
138                        _ => unimplemented!(
139                            "is(): name must be a string literal but {:?} is provided",
140                            v.value
141                        ),
142                    };
143                };
144
145                match &attr.meta {
146                    Meta::NameValue(v) => {
147                        //
148                        apply(v)
149                    }
150                    Meta::List(l) => {
151                        // Handle is(name = "foo")
152                        input = parse2(l.tokens.clone()).expect("failed to parse input");
153                    }
154                    _ => unimplemented!("is({:?})", attr.meta),
155                }
156
157                input
158            }
159        };
160
161        let name = &*i.name;
162        {
163            let name_of_is = Ident::new(&format!("is_{name}"), v.ident.span());
164            let docs_of_is = format!(
165                "Returns `true` if `self` is of variant [`{variant}`].\n\n[`{variant}`]: \
166                 #variant.{variant}",
167                variant = v.ident,
168            );
169
170            let variant = &v.ident;
171
172            let item_impl: ItemImpl = parse_quote!(
173                impl Type {
174                    #[doc = #docs_of_is]
175                    #[inline]
176                    pub const fn #name_of_is(&self) -> bool {
177                        match *self {
178                            Self::#variant { .. } => true,
179                            _ => false,
180                        }
181                    }
182                }
183            );
184
185            items.extend(item_impl.items);
186        }
187
188        {
189            let name_of_cast = Ident::new(&format!("as_{name}"), v.ident.span());
190            let name_of_cast_mut = Ident::new(&format!("as_mut_{name}"), v.ident.span());
191            let name_of_expect = Ident::new(&format!("expect_{name}"), v.ident.span());
192            let name_of_take = Ident::new(name, v.ident.span());
193
194            let docs_of_cast = format!(
195                "Returns `Some` if `self` is a reference of variant [`{variant}`], and `None` \
196                 otherwise.\n\n[`{variant}`]: #variant.{variant}",
197                variant = v.ident,
198            );
199            let docs_of_cast_mut = format!(
200                "Returns `Some` if `self` is a mutable reference of variant [`{variant}`], and \
201                 `None` otherwise.\n\n[`{variant}`]: #variant.{variant}",
202                variant = v.ident,
203            );
204            let docs_of_expect = format!(
205                "Unwraps the value, yielding the content of [`{variant}`].\n\n# Panics\n\nPanics \
206                 if the value is not [`{variant}`], with a panic message including the content of \
207                 `self`.\n\n[`{variant}`]: #variant.{variant}",
208                variant = v.ident,
209            );
210            let docs_of_take = format!(
211                "Returns `Some` if `self` is of variant [`{variant}`], and `None` \
212                 otherwise.\n\n[`{variant}`]: #variant.{variant}",
213                variant = v.ident,
214            );
215
216            if let Fields::Unnamed(fields) = &v.fields {
217                let types = fields.unnamed.iter().map(|f| f.ty.clone());
218                let cast_ty = types_to_type(types.clone().map(|ty| add_ref(false, ty)));
219                let cast_ty_mut = types_to_type(types.clone().map(|ty| add_ref(true, ty)));
220                let ty = types_to_type(types);
221
222                let mut fields: Punctuated<Ident, Token![,]> = fields
223                    .unnamed
224                    .clone()
225                    .into_pairs()
226                    .enumerate()
227                    .map(|(i, pair)| {
228                        let handle = |f: Field| {
229                            //
230                            Ident::new(&format!("v{i}"), f.span())
231                        };
232                        match pair {
233                            Pair::Punctuated(v, p) => Pair::Punctuated(handle(v), p),
234                            Pair::End(v) => Pair::End(handle(v)),
235                        }
236                    })
237                    .collect();
238
239                // Make sure that we don't have any trailing punctuation
240                // This ensure that if we have a single unnamed field,
241                // we will produce a value of the form `(v)`,
242                // not a single-element tuple `(v,)`
243                if let Some(mut pair) = fields.pop() {
244                    if let Pair::Punctuated(v, _) = pair {
245                        pair = Pair::End(v);
246                    }
247                    fields.extend(std::iter::once(pair));
248                }
249
250                let variant = &v.ident;
251
252                let item_impl: ItemImpl = parse_quote!(
253                    impl #ty {
254                        #[doc = #docs_of_cast]
255                        #[inline]
256                        pub fn #name_of_cast(&self) -> Option<#cast_ty> {
257                            match self {
258                                Self::#variant(#fields) => Some((#fields)),
259                                _ => None,
260                            }
261                        }
262
263                        #[doc = #docs_of_cast_mut]
264                        #[inline]
265                        pub fn #name_of_cast_mut(&mut self) -> Option<#cast_ty_mut> {
266                            match self {
267                                Self::#variant(#fields) => Some((#fields)),
268                                _ => None,
269                            }
270                        }
271
272                        #[doc = #docs_of_expect]
273                        #[inline]
274                        pub fn #name_of_expect(self) -> #ty
275                        where
276                            Self: ::std::fmt::Debug,
277                        {
278                            match self {
279                                Self::#variant(#fields) => (#fields),
280                                _ => panic!("called expect on {:?}", self),
281                            }
282                        }
283
284                        #[doc = #docs_of_take]
285                        #[inline]
286                        pub fn #name_of_take(self) -> Option<#ty> {
287                            match self {
288                                Self::#variant(#fields) => Some((#fields)),
289                                _ => None,
290                            }
291                        }
292                    }
293                );
294
295                items.extend(item_impl.items);
296            }
297        }
298    }
299
300    items
301}
302
303fn types_to_type(types: impl Iterator<Item = Type>) -> Type {
304    let mut types: Punctuated<_, _> = types.collect();
305    if types.len() == 1 {
306        types.pop().expect("len is 1").into_value()
307    } else {
308        TypeTuple {
309            paren_token: Default::default(),
310            elems: types,
311        }
312        .into()
313    }
314}
315
316fn add_ref(mutable: bool, ty: Type) -> Type {
317    Type::Reference(TypeReference {
318        and_token: Default::default(),
319        lifetime: None,
320        mutability: if mutable {
321            Some(Default::default())
322        } else {
323            None
324        },
325        elem: Box::new(ty),
326    })
327}
328
329/// Extension trait for `ItemImpl` (impl block).
330trait ItemImplExt {
331    /// Instead of
332    ///
333    /// ```rust,ignore
334    /// let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
335    ///
336    /// let item: Item = Quote::new(def_site::<Span>())
337    ///     .quote_with(smart_quote!(
338    /// Vars {
339    /// Type: type_name,
340    /// impl_generics,
341    /// ty_generics,
342    /// where_clause,
343    /// },
344    /// {
345    /// impl impl_generics ::swc_common::AstNode for Type ty_generics
346    /// where_clause {}
347    /// }
348    /// )).parse();
349    /// ```
350    ///
351    /// You can use this like
352    ///
353    /// ```rust,ignore
354    // let item = Quote::new(def_site::<Span>())
355    ///     .quote_with(smart_quote!(Vars { Type: type_name }, {
356    ///         impl ::swc_common::AstNode for Type {}
357    ///     }))
358    ///     .parse::<ItemImpl>()
359    ///     .with_generics(input.generics);
360    /// ```
361    fn with_generics(self, generics: Generics) -> Self;
362}
363
364impl ItemImplExt for ItemImpl {
365    fn with_generics(mut self, mut generics: Generics) -> Self {
366        // TODO: Check conflicting name
367
368        let need_new_punct = !generics.params.empty_or_trailing();
369        if need_new_punct {
370            generics
371                .params
372                .push_punct(syn::token::Comma(Span::call_site()));
373        }
374
375        // Respan
376        if let Some(t) = generics.lt_token {
377            self.generics.lt_token = Some(t)
378        }
379        if let Some(t) = generics.gt_token {
380            self.generics.gt_token = Some(t)
381        }
382
383        let ty = self.self_ty;
384
385        // Handle generics defined on struct, enum, or union.
386        let mut item: ItemImpl = {
387            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
388            let item = if let Some((ref polarity, ref path, ref for_token)) = self.trait_ {
389                quote! {
390                    impl #impl_generics #polarity #path #for_token #ty #ty_generics #where_clause {}
391                }
392            } else {
393                quote! {
394                    impl #impl_generics #ty #ty_generics #where_clause {}
395
396                }
397            };
398            parse2(item.into_token_stream())
399                .unwrap_or_else(|err| panic!("with_generics failed: {}", err))
400        };
401
402        // Handle generics added by proc-macro.
403        item.generics
404            .params
405            .extend(self.generics.params.into_pairs());
406        match self.generics.where_clause {
407            Some(WhereClause {
408                ref mut predicates, ..
409            }) => predicates.extend(
410                generics
411                    .where_clause
412                    .into_iter()
413                    .flat_map(|wc| wc.predicates.into_pairs()),
414            ),
415            ref mut opt @ None => *opt = generics.where_clause,
416        }
417
418        ItemImpl {
419            attrs: self.attrs,
420            defaultness: self.defaultness,
421            unsafety: self.unsafety,
422            impl_token: self.impl_token,
423            brace_token: self.brace_token,
424            items: self.items,
425            ..item
426        }
427    }
428}