cfg_or_panic/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5    parse_macro_input, parse_quote, Attribute, Block, Error, Expr, ExprLit, ImplItem, ImplItemFn,
6    Item, ItemFn, ItemImpl, ItemMod, Lit, Meta, Result, Signature, Type,
7};
8
9/// Keep the function body under `#[cfg(..)]`, or replace it with `unimplemented!()` under `#[cfg(not(..))]`.
10///
11/// # Examples
12///
13/// `#[cfg_or_panic]` can be used on functions, `mod`, and `impl` blocks.
14///
15/// ## Function
16/// ```should_panic
17/// # use cfg_or_panic::cfg_or_panic;
18/// #[cfg_or_panic(foo)]
19/// fn foo() {
20///   println!("foo");
21/// }
22/// # fn main() { foo(); }
23/// ```
24///
25/// ## `mod`
26/// ```should_panic
27/// # use cfg_or_panic::cfg_or_panic;
28/// #[cfg_or_panic(foo)]
29/// mod foo {
30///   pub fn foo() {
31///     println!("foo");
32///   }
33/// }
34/// # fn main() { foo::foo(); }
35/// ```
36///
37/// ## `impl`
38/// ```should_panic
39/// # use cfg_or_panic::cfg_or_panic;
40/// struct Foo(String);
41///
42/// #[cfg_or_panic(foo)]
43/// impl Foo {
44///   fn foo(&self) {
45///     println!("foo: {}", self.0);
46///   }
47/// }
48/// # fn main() { Foo("bar".to_owned()).foo(); }
49/// ```
50///
51/// ## Dummy return type
52/// For the functions returning an `impl Trait`, you may have to specify a dummy return type for the panic branch.
53/// This can be done by adding `#[panic_return = "dummy::return::Type"]` to the function.
54/// ```should_panic
55/// # use cfg_or_panic::cfg_or_panic;
56/// #[cfg_or_panic(foo)]
57/// #[panic_return = "std::iter::Empty<_>"]
58/// fn my_iter() -> impl Iterator<Item = i32> {
59///   (0..10).into_iter()
60/// }
61/// # fn main() { my_iter().count(); }
62/// ```
63#[proc_macro_attribute]
64pub fn cfg_or_panic(args: TokenStream, input: TokenStream) -> TokenStream {
65    let expander = Expander::new(args);
66    let mut item = parse_macro_input!(input as Item);
67
68    if let Err(e) = expander.expand_item(&mut item) {
69        return e.to_compile_error().into();
70    }
71    item.into_token_stream().into()
72}
73
74struct Expander {
75    args: TokenStream2,
76}
77
78impl Expander {
79    fn new(args: impl Into<TokenStream2>) -> Self {
80        Self { args: args.into() }
81    }
82
83    fn expand_item(&self, item: &mut Item) -> Result<()> {
84        match item {
85            Item::Fn(item_fn) => self.expand_fn(item_fn),
86            Item::Impl(item_impl) => self.expand_impl(item_impl),
87            Item::Mod(item_mod) => self.expand_mod(item_mod),
88            _ => Err(Error::new_spanned(
89                item,
90                "`#[cfg_or_panic]` can only be used on functions, `mod`, and `impl` blocks",
91            )),
92        }
93    }
94
95    fn expand_mod(&self, item_mod: &mut ItemMod) -> Result<()> {
96        let Some((_, content)) = &mut item_mod.content else {
97            return Ok(());
98        };
99
100        for item in content {
101            self.expand_item(item).ok();
102        }
103
104        Ok(())
105    }
106
107    fn expand_impl(&self, item_impl: &mut ItemImpl) -> Result<()> {
108        for item in &mut item_impl.items {
109            #[allow(clippy::single_match)]
110            match item {
111                ImplItem::Fn(impl_item_fn) => self.expand_impl_fn(impl_item_fn)?,
112                _ => {}
113            }
114        }
115
116        Ok(())
117    }
118
119    fn expand_fn(&self, f: &mut ItemFn) -> Result<()> {
120        self.expand_fn_inner(&f.sig, &mut f.block, &mut f.attrs)
121    }
122
123    fn expand_impl_fn(&self, f: &mut ImplItemFn) -> Result<()> {
124        self.expand_fn_inner(&f.sig, &mut f.block, &mut f.attrs)
125    }
126
127    fn expand_fn_inner(
128        &self,
129        sig: &Signature,
130        fn_block: &mut Block,
131        fn_attrs: &mut Vec<Attribute>,
132    ) -> Result<()> {
133        let name = &sig.ident;
134        let args = &self.args;
135
136        let return_ty = {
137            let mut return_ty = None;
138            let mut new_fn_attrs = Vec::new();
139
140            // TODO: use `extract_if` when stable
141            for fn_attr in fn_attrs.drain(..) {
142                if let Some(ty) = extract_panic_return_attr(&fn_attr) {
143                    return_ty = Some(ty?);
144                } else {
145                    new_fn_attrs.push(fn_attr);
146                }
147            }
148
149            *fn_attrs = new_fn_attrs;
150            return_ty
151        };
152
153        let msg = format!(
154            "function `{}` unimplemented unless `#[cfg({})]` is activated",
155            name, args
156        );
157        let unimplemented = quote!(
158            panic!(#msg);
159        );
160
161        let may_with_ret_ty = if let Some(ty) = return_ty {
162            quote!(
163                #[allow(unreachable_code, clippy::diverging_sub_expression)]
164                {
165                    let __ret: #ty = #unimplemented;
166                    return __ret;
167                }
168            )
169        } else {
170            unimplemented
171        };
172
173        let block = std::mem::replace(fn_block, parse_quote!({}));
174        *fn_block = parse_quote!({
175            #[cfg(not(#args))]
176            #may_with_ret_ty
177            #[cfg(#args)]
178            #block
179        });
180
181        let attr = parse_quote!(
182            #[cfg_attr(not(#args), allow(unused_variables))]
183        );
184        fn_attrs.push(attr);
185
186        Ok(())
187    }
188}
189
190fn extract_panic_return_attr(attr: &Attribute) -> Option<Result<Type>> {
191    let Meta::NameValue(name_value) = &attr.meta else {
192        return None;
193    };
194    if name_value.path.get_ident()? != "panic_return" {
195        return None;
196    }
197
198    Some(parse_panic_return_attr(name_value.value.clone()))
199}
200
201fn parse_panic_return_attr(value_expr: Expr) -> Result<Type> {
202    let Expr::Lit(ExprLit {
203        lit: Lit::Str(lit_str),
204        ..
205    }) = value_expr
206    else {
207        return Err(Error::new_spanned(value_expr, "expected a string literal"));
208    };
209
210    syn::parse_str(&lit_str.value())
211}