1extern crate proc_macro;
2
3use heck::ToSnakeCase;
4use proc_macro2::Span;
5use quote::{quote, ToTokens};
6use syn::{
7 parse,
8 parse::Parse,
9 parse2, parse_quote,
10 punctuated::{Pair, Punctuated},
11 spanned::Spanned,
12 Data, DataEnum, DeriveInput, Expr, ExprLit, Field, Fields, Generics, Ident, ImplItem, ItemImpl,
13 Lit, Meta, MetaNameValue, Path, Token, Type, TypePath, TypeReference, TypeTuple, WhereClause,
14};
15
16#[proc_macro_derive(Is, attributes(is))]
53pub fn is(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54 let input: DeriveInput = syn::parse(input).expect("failed to parse derive input");
55 let generics: Generics = input.generics.clone();
56
57 let items = match input.data {
58 Data::Enum(e) => expand(e),
59 _ => panic!("`Is` can be applied only on enums"),
60 };
61
62 ItemImpl {
63 attrs: vec![],
64 defaultness: None,
65 unsafety: None,
66 impl_token: Default::default(),
67 generics: Default::default(),
68 trait_: None,
69 self_ty: Box::new(Type::Path(TypePath {
70 qself: None,
71 path: Path::from(input.ident),
72 })),
73 brace_token: Default::default(),
74 items,
75 }
76 .with_generics(generics)
77 .into_token_stream()
78 .into()
79}
80
/// Per-variant configuration, either parsed from the tokens inside `#[is(...)]` or derived
/// from the variant identifier when no such attribute is present.
#[derive(Debug)]
struct Input {
    /// Name fragment used to build the generated method names
    /// (`is_{name}`, `as_{name}`, `as_mut_{name}`, `expect_{name}` and `{name}`).
    name: String,
}
85
86impl Parse for Input {
87 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
88 let _: Ident = input.parse()?;
89 let _: Token![=] = input.parse()?;
90
91 let name = input.parse::<ExprLit>()?;
92
93 Ok(Input {
94 name: match name.lit {
95 Lit::Str(s) => s.value(),
96 _ => panic!("is(name = ...) expects a string literal"),
97 },
98 })
99 }
100}
101
102fn expand(input: DataEnum) -> Vec<ImplItem> {
103 let mut items = vec![];
104
105 for v in &input.variants {
106 let attrs = v
107 .attrs
108 .iter()
109 .filter(|attr| attr.path().is_ident("is"))
110 .collect::<Vec<_>>();
111 if attrs.len() >= 2 {
112 panic!("derive(Is) expects no attribute or one attribute")
113 }
114 let i = match attrs.into_iter().next() {
115 None => Input {
116 name: {
117 v.ident.to_string().to_snake_case()
118 },
120 },
121 Some(attr) => {
122 let mut input = Input {
125 name: Default::default(),
126 };
127
128 let mut apply = |v: &MetaNameValue| {
129 assert!(
130 v.path.is_ident("name"),
131 "Currently, is() only supports `is(name = 'foo')`"
132 );
133
134 input.name = match &v.value {
135 Expr::Lit(ExprLit {
136 lit: Lit::Str(s), ..
137 }) => s.value(),
138 _ => unimplemented!(
139 "is(): name must be a string literal but {:?} is provided",
140 v.value
141 ),
142 };
143 };
144
145 match &attr.meta {
146 Meta::NameValue(v) => {
147 apply(v)
149 }
150 Meta::List(l) => {
151 input = parse2(l.tokens.clone()).expect("failed to parse input");
153 }
154 _ => unimplemented!("is({:?})", attr.meta),
155 }
156
157 input
158 }
159 };
160
161 let name = &*i.name;
162 {
163 let name_of_is = Ident::new(&format!("is_{name}"), v.ident.span());
164 let docs_of_is = format!(
165 "Returns `true` if `self` is of variant [`{variant}`].\n\n[`{variant}`]: \
166 #variant.{variant}",
167 variant = v.ident,
168 );
169
170 let variant = &v.ident;
171
172 let item_impl: ItemImpl = parse_quote!(
173 impl Type {
174 #[doc = #docs_of_is]
175 #[inline]
176 pub const fn #name_of_is(&self) -> bool {
177 match *self {
178 Self::#variant { .. } => true,
179 _ => false,
180 }
181 }
182 }
183 );
184
185 items.extend(item_impl.items);
186 }
187
188 {
189 let name_of_cast = Ident::new(&format!("as_{name}"), v.ident.span());
190 let name_of_cast_mut = Ident::new(&format!("as_mut_{name}"), v.ident.span());
191 let name_of_expect = Ident::new(&format!("expect_{name}"), v.ident.span());
192 let name_of_take = Ident::new(name, v.ident.span());
193
194 let docs_of_cast = format!(
195 "Returns `Some` if `self` is a reference of variant [`{variant}`], and `None` \
196 otherwise.\n\n[`{variant}`]: #variant.{variant}",
197 variant = v.ident,
198 );
199 let docs_of_cast_mut = format!(
200 "Returns `Some` if `self` is a mutable reference of variant [`{variant}`], and \
201 `None` otherwise.\n\n[`{variant}`]: #variant.{variant}",
202 variant = v.ident,
203 );
204 let docs_of_expect = format!(
205 "Unwraps the value, yielding the content of [`{variant}`].\n\n# Panics\n\nPanics \
206 if the value is not [`{variant}`], with a panic message including the content of \
207 `self`.\n\n[`{variant}`]: #variant.{variant}",
208 variant = v.ident,
209 );
210 let docs_of_take = format!(
211 "Returns `Some` if `self` is of variant [`{variant}`], and `None` \
212 otherwise.\n\n[`{variant}`]: #variant.{variant}",
213 variant = v.ident,
214 );
215
216 if let Fields::Unnamed(fields) = &v.fields {
217 let types = fields.unnamed.iter().map(|f| f.ty.clone());
218 let cast_ty = types_to_type(types.clone().map(|ty| add_ref(false, ty)));
219 let cast_ty_mut = types_to_type(types.clone().map(|ty| add_ref(true, ty)));
220 let ty = types_to_type(types);
221
222 let mut fields: Punctuated<Ident, Token![,]> = fields
223 .unnamed
224 .clone()
225 .into_pairs()
226 .enumerate()
227 .map(|(i, pair)| {
228 let handle = |f: Field| {
229 Ident::new(&format!("v{i}"), f.span())
231 };
232 match pair {
233 Pair::Punctuated(v, p) => Pair::Punctuated(handle(v), p),
234 Pair::End(v) => Pair::End(handle(v)),
235 }
236 })
237 .collect();
238
239 if let Some(mut pair) = fields.pop() {
244 if let Pair::Punctuated(v, _) = pair {
245 pair = Pair::End(v);
246 }
247 fields.extend(std::iter::once(pair));
248 }
249
250 let variant = &v.ident;
251
252 let item_impl: ItemImpl = parse_quote!(
253 impl #ty {
254 #[doc = #docs_of_cast]
255 #[inline]
256 pub fn #name_of_cast(&self) -> Option<#cast_ty> {
257 match self {
258 Self::#variant(#fields) => Some((#fields)),
259 _ => None,
260 }
261 }
262
263 #[doc = #docs_of_cast_mut]
264 #[inline]
265 pub fn #name_of_cast_mut(&mut self) -> Option<#cast_ty_mut> {
266 match self {
267 Self::#variant(#fields) => Some((#fields)),
268 _ => None,
269 }
270 }
271
272 #[doc = #docs_of_expect]
273 #[inline]
274 pub fn #name_of_expect(self) -> #ty
275 where
276 Self: ::std::fmt::Debug,
277 {
278 match self {
279 Self::#variant(#fields) => (#fields),
280 _ => panic!("called expect on {:?}", self),
281 }
282 }
283
284 #[doc = #docs_of_take]
285 #[inline]
286 pub fn #name_of_take(self) -> Option<#ty> {
287 match self {
288 Self::#variant(#fields) => Some((#fields)),
289 _ => None,
290 }
291 }
292 }
293 );
294
295 items.extend(item_impl.items);
296 }
297 }
298 }
299
300 items
301}
302
303fn types_to_type(types: impl Iterator<Item = Type>) -> Type {
304 let mut types: Punctuated<_, _> = types.collect();
305 if types.len() == 1 {
306 types.pop().expect("len is 1").into_value()
307 } else {
308 TypeTuple {
309 paren_token: Default::default(),
310 elems: types,
311 }
312 .into()
313 }
314}
315
316fn add_ref(mutable: bool, ty: Type) -> Type {
317 Type::Reference(TypeReference {
318 and_token: Default::default(),
319 lifetime: None,
320 mutability: if mutable {
321 Some(Default::default())
322 } else {
323 None
324 },
325 elem: Box::new(ty),
326 })
327}
328
/// Extension trait adding generics handling to impl-like syntax nodes.
trait ItemImplExt {
    /// Consumes `self` and returns it with the given `generics` applied.
    ///
    /// The implementation for `ItemImpl` in this file re-creates the `impl` header from
    /// `generics` (impl generics, type generics and where clause) and keeps the attributes,
    /// tokens and items of `self`.
    fn with_generics(self, generics: Generics) -> Self;
}
363
impl ItemImplExt for ItemImpl {
    /// Rebuilds this `impl` block so that its header carries `generics`.
    ///
    /// The header `impl<..> [Trait for] Type<..> where ..` is produced by quoting the pieces
    /// returned from `generics.split_for_impl()` and re-parsing the result. The generic
    /// parameters already present on `self` are appended afterwards. Attributes, defaultness,
    /// unsafety, the `impl`/brace tokens and the items are taken from `self`; everything else
    /// comes from the re-parsed header.
    ///
    /// # Panics
    ///
    /// Panics with `with_generics failed: ..` if the quoted header cannot be parsed back into
    /// an `ItemImpl`.
    fn with_generics(mut self, mut generics: Generics) -> Self {
        // Make sure the parameter list is empty or ends with a comma
        // (presumably so that further parameters can be appended — TODO confirm).
        let need_new_punct = !generics.params.empty_or_trailing();
        if need_new_punct {
            generics
                .params
                .push_punct(syn::token::Comma(Span::call_site()));
        }

        // Copy the angle-bracket tokens of `generics` onto `self.generics`.
        if let Some(t) = generics.lt_token {
            self.generics.lt_token = Some(t)
        }
        if let Some(t) = generics.gt_token {
            self.generics.gt_token = Some(t)
        }

        let ty = self.self_ty;

        // Build an empty impl whose header uses `generics`, then parse it back so that the
        // generics, trait reference and self type are proper syntax nodes.
        let mut item: ItemImpl = {
            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
            let item = if let Some((ref polarity, ref path, ref for_token)) = self.trait_ {
                quote! {
                    impl #impl_generics #polarity #path #for_token #ty #ty_generics #where_clause {}
                }
            } else {
                quote! {
                    impl #impl_generics #ty #ty_generics #where_clause {}

                }
            };
            parse2(item.into_token_stream())
                .unwrap_or_else(|err| panic!("with_generics failed: {}", err))
        };

        // Append the parameters that were already declared on `self`.
        item.generics
            .params
            .extend(self.generics.params.into_pairs());
        // NOTE(review): this merge writes into `self.generics.where_clause`, but the value
        // returned below takes its generics from `item` (via `..item`), so the merged clause
        // is not part of the result. Confirm whether predicates declared on `self` were meant
        // to be carried over.
        match self.generics.where_clause {
            Some(WhereClause {
                ref mut predicates, ..
            }) => predicates.extend(
                generics
                    .where_clause
                    .into_iter()
                    .flat_map(|wc| wc.predicates.into_pairs()),
            ),
            ref mut opt @ None => *opt = generics.where_clause,
        }

        // Keep the body-related parts of `self`; generics, trait and self type come from `item`.
        ItemImpl {
            attrs: self.attrs,
            defaultness: self.defaultness,
            unsafety: self.unsafety,
            impl_token: self.impl_token,
            brace_token: self.brace_token,
            items: self.items,
            ..item
        }
    }
}