1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::{
5 parse_macro_input, parse_quote, Attribute, Block, Error, Expr, ExprLit, ImplItem, ImplItemFn,
6 Item, ItemFn, ItemImpl, ItemMod, Lit, Meta, Result, Signature, Type,
7};
8
9#[proc_macro_attribute]
64pub fn cfg_or_panic(args: TokenStream, input: TokenStream) -> TokenStream {
65 let expander = Expander::new(args);
66 let mut item = parse_macro_input!(input as Item);
67
68 if let Err(e) = expander.expand_item(&mut item) {
69 return e.to_compile_error().into();
70 }
71 item.into_token_stream().into()
72}
73
74struct Expander {
75 args: TokenStream2,
76}
77
78impl Expander {
79 fn new(args: impl Into<TokenStream2>) -> Self {
80 Self { args: args.into() }
81 }
82
83 fn expand_item(&self, item: &mut Item) -> Result<()> {
84 match item {
85 Item::Fn(item_fn) => self.expand_fn(item_fn),
86 Item::Impl(item_impl) => self.expand_impl(item_impl),
87 Item::Mod(item_mod) => self.expand_mod(item_mod),
88 _ => Err(Error::new_spanned(
89 item,
90 "`#[cfg_or_panic]` can only be used on functions, `mod`, and `impl` blocks",
91 )),
92 }
93 }
94
95 fn expand_mod(&self, item_mod: &mut ItemMod) -> Result<()> {
96 let Some((_, content)) = &mut item_mod.content else {
97 return Ok(());
98 };
99
100 for item in content {
101 self.expand_item(item).ok();
102 }
103
104 Ok(())
105 }
106
107 fn expand_impl(&self, item_impl: &mut ItemImpl) -> Result<()> {
108 for item in &mut item_impl.items {
109 #[allow(clippy::single_match)]
110 match item {
111 ImplItem::Fn(impl_item_fn) => self.expand_impl_fn(impl_item_fn)?,
112 _ => {}
113 }
114 }
115
116 Ok(())
117 }
118
119 fn expand_fn(&self, f: &mut ItemFn) -> Result<()> {
120 self.expand_fn_inner(&f.sig, &mut f.block, &mut f.attrs)
121 }
122
123 fn expand_impl_fn(&self, f: &mut ImplItemFn) -> Result<()> {
124 self.expand_fn_inner(&f.sig, &mut f.block, &mut f.attrs)
125 }
126
127 fn expand_fn_inner(
128 &self,
129 sig: &Signature,
130 fn_block: &mut Block,
131 fn_attrs: &mut Vec<Attribute>,
132 ) -> Result<()> {
133 let name = &sig.ident;
134 let args = &self.args;
135
136 let return_ty = {
137 let mut return_ty = None;
138 let mut new_fn_attrs = Vec::new();
139
140 for fn_attr in fn_attrs.drain(..) {
142 if let Some(ty) = extract_panic_return_attr(&fn_attr) {
143 return_ty = Some(ty?);
144 } else {
145 new_fn_attrs.push(fn_attr);
146 }
147 }
148
149 *fn_attrs = new_fn_attrs;
150 return_ty
151 };
152
153 let msg = format!(
154 "function `{}` unimplemented unless `#[cfg({})]` is activated",
155 name, args
156 );
157 let unimplemented = quote!(
158 panic!(#msg);
159 );
160
161 let may_with_ret_ty = if let Some(ty) = return_ty {
162 quote!(
163 #[allow(unreachable_code, clippy::diverging_sub_expression)]
164 {
165 let __ret: #ty = #unimplemented;
166 return __ret;
167 }
168 )
169 } else {
170 unimplemented
171 };
172
173 let block = std::mem::replace(fn_block, parse_quote!({}));
174 *fn_block = parse_quote!({
175 #[cfg(not(#args))]
176 #may_with_ret_ty
177 #[cfg(#args)]
178 #block
179 });
180
181 let attr = parse_quote!(
182 #[cfg_attr(not(#args), allow(unused_variables))]
183 );
184 fn_attrs.push(attr);
185
186 Ok(())
187 }
188}
189
190fn extract_panic_return_attr(attr: &Attribute) -> Option<Result<Type>> {
191 let Meta::NameValue(name_value) = &attr.meta else {
192 return None;
193 };
194 if name_value.path.get_ident()? != "panic_return" {
195 return None;
196 }
197
198 Some(parse_panic_return_attr(name_value.value.clone()))
199}
200
201fn parse_panic_return_attr(value_expr: Expr) -> Result<Type> {
202 let Expr::Lit(ExprLit {
203 lit: Lit::Str(lit_str),
204 ..
205 }) = value_expr
206 else {
207 return Err(Error::new_spanned(value_expr, "expected a string literal"));
208 };
209
210 syn::parse_str(&lit_str.value())
211}